cog_task/action/core/
function.rs

1use crate::action::{Action, ActionSignal, Props, StatefulAction, DEFAULT, INFINITE};
2use crate::comm::{QWriter, Signal, SignalId};
3use crate::resource::{
4    Evaluator, Interpreter, IoManager, LoggerSignal, OptionalPath, OptionalString, ResourceAddr,
5    ResourceManager, ResourceValue,
6};
7use crate::server::{AsyncSignal, Config, State, SyncSignal};
8use eyre::{eyre, Context, Error, Result};
9use regex::Regex;
10use serde::{Deserialize, Serialize};
11use serde_cbor::Value;
12use std::collections::{BTreeMap, BTreeSet};
13use std::time::Instant;
14
/// Configuration of a function action: an expression that is evaluated over a set of named
/// variables, some of which mirror input signals, and whose result can be logged and emitted
/// as an output signal.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Function {
    /// Label used when logging results. When empty, evaluations are not logged.
    #[serde(default)]
    name: String,
    /// Inline expression to evaluate. Exactly one of `expr` and `src` must be set.
    #[serde(default)]
    expr: OptionalString,
    /// Path of a text resource holding the expression. Exactly one of `expr` and `src` must
    /// be set.
    #[serde(default)]
    src: OptionalPath,
    /// Inline initialization code handed to the interpreter together with the expression.
    /// At most one of `init_expr` and `init_src` may be set.
    #[serde(default)]
    init_expr: OptionalString,
    /// Path of a text resource holding the initialization code. At most one of `init_expr`
    /// and `init_src` may be set.
    #[serde(default)]
    init_src: OptionalPath,
    /// Initial variable values. A `"self"` entry is added (as `Value::Null`) when absent and
    /// is overwritten with the latest result after each evaluation.
    #[serde(default)]
    vars: BTreeMap<String, Value>,
    /// Interpreter used to parse the expression. `stateful` resolves it through
    /// `Interpreter::or` with the configured interpreter (presumably a fallback for the
    /// unset case; confirm against `Interpreter::or`).
    #[serde(default)]
    interpreter: Interpreter,
    /// Evaluate when the action starts. Defaults to `true`.
    #[serde(default = "defaults::on_start")]
    on_start: bool,
    /// Evaluate whenever a signal listed in `in_mapping` changes. Defaults to `true`.
    #[serde(default = "defaults::on_change")]
    on_change: bool,
    /// When set, the action reports `DEFAULT` props instead of `INFINITE`, and marks itself
    /// done after an evaluation performed on start or completed through `lo_response`.
    #[serde(default)]
    once: bool,
    /// Maps input signal ids to the variable each one updates. Variable names must start
    /// with a letter, contain only word characters, must not be `"self"`, and must already
    /// exist in `vars`.
    #[serde(default)]
    in_mapping: BTreeMap<SignalId, String>,
    /// Signal whose change triggers an evaluation regardless of `on_change`.
    #[serde(default)]
    in_update: SignalId,
    /// Loopback signal. When non-zero, evaluation is lazy and the result is delivered back
    /// to this action through this signal.
    #[serde(default)]
    lo_response: SignalId,
    /// Signal on which results are emitted. Zero disables emission.
    #[serde(default)]
    out_result: SignalId,
}
47
// Declares the runtime state paired with `Function`. The macro definition is not visible in
// this file; judging by `Function::stateful`, it presumably expands to a `StatefulFunction`
// struct holding the fields below plus a `done` flag (TODO confirm against the macro).
stateful!(Function {
    name: String,
    vars: BTreeMap<String, Value>,
    evaluator: Evaluator,
    on_start: bool,
    on_change: bool,
    once: bool,
    in_mapping: BTreeMap<SignalId, String>,
    in_update: SignalId,
    lo_response: SignalId,
    out_result: SignalId,
});
60
/// Serde default providers for `Function` flags that are enabled unless stated otherwise.
mod defaults {
    /// Shared value for every flag that defaults to "on".
    const ENABLED: bool = true;

    /// Default for `Function::on_start`.
    pub fn on_start() -> bool {
        ENABLED
    }

    /// Default for `Function::on_change`.
    pub fn on_change() -> bool {
        ENABLED
    }
}
70
71impl Action for Function {
72    #[inline(always)]
73    fn init(self) -> Result<Box<dyn Action>, Error>
74    where
75        Self: 'static + Sized,
76    {
77        match (self.expr.is_some(), self.src.is_some()) {
78            (false, false) => Err(eyre!("`expr` and `src` cannot both be empty."))?,
79            (true, true) => Err(eyre!("Only one of `expr` and `src` should be set."))?,
80            _ => {}
81        };
82
83        if self.init_expr.is_some() && self.init_src.is_some() {
84            return Err(eyre!(
85                "Only one of `init_expr` and `init_src` should be set."
86            ));
87        }
88
89        let re = Regex::new(r"^[[:alpha:]][[:word:]]*$").unwrap();
90        for (_, var) in self.in_mapping.iter() {
91            if var.as_str() == "self" {
92                return Err(eyre!(
93                    "Reserved variable (\"self\") of Fn cannot be included in `in_mapping`."
94                ));
95            } else if !re.is_match(var) {
96                return Err(eyre!("Invalid variable name ({var}) in `in_mapping`."));
97            }
98        }
99
100        if self.out_result != 0
101            && (self.in_mapping.contains_key(&self.out_result) || self.in_update == self.out_result)
102        {
103            return Err(eyre!("Recursive expression not allowed."));
104        }
105
106        if self.in_update != 0 && self.in_mapping.contains_key(&self.in_update) {
107            return Err(eyre!("`in_update` cannot overlap with `in_mapping`."));
108        }
109
110        Ok(Box::new(self))
111    }
112
113    #[inline]
114    fn in_signals(&self) -> BTreeSet<SignalId> {
115        let mut signals: BTreeSet<_> = self.in_mapping.keys().cloned().collect();
116        signals.extend([self.in_update, self.lo_response]);
117        signals
118    }
119
120    #[inline]
121    fn out_signals(&self) -> BTreeSet<SignalId> {
122        BTreeSet::from([self.lo_response, self.out_result])
123    }
124
125    fn resources(&self, _config: &Config) -> Vec<ResourceAddr> {
126        let mut resources = vec![];
127        if let OptionalPath::Some(src) = &self.src {
128            resources.push(ResourceAddr::Text(src.clone()));
129        }
130        if let OptionalPath::Some(src) = &self.init_src {
131            resources.push(ResourceAddr::Text(src.clone()));
132        }
133        resources
134    }
135
136    fn stateful(
137        &self,
138        _io: &IoManager,
139        res: &ResourceManager,
140        config: &Config,
141        _sync_writer: &QWriter<SyncSignal>,
142        _async_writer: &QWriter<AsyncSignal>,
143    ) -> Result<Box<dyn StatefulAction>> {
144        let interpreter = self.interpreter.or(&config.interpreter());
145
146        let init = if let OptionalPath::Some(src) = &self.init_src {
147            match res.fetch(&ResourceAddr::Text(src.clone()))? {
148                ResourceValue::Text(expr) => (*expr).clone(),
149                _ => return Err(eyre!("Resource address and value types don't match.")),
150            }
151        } else if let OptionalString::Some(expr) = &self.init_expr {
152            expr.clone()
153        } else {
154            "".to_owned()
155        }
156        .trim()
157        .to_owned();
158
159        let expr = if let OptionalPath::Some(src) = &self.src {
160            match res.fetch(&ResourceAddr::Text(src.clone()))? {
161                ResourceValue::Text(expr) => (*expr).clone(),
162                _ => return Err(eyre!("Resource address and value types don't match.")),
163            }
164        } else if let OptionalString::Some(expr) = &self.expr {
165            expr.clone()
166        } else {
167            "".to_owned()
168        }
169        .trim()
170        .to_owned();
171
172        if expr.is_empty() {
173            return Err(eyre!("Fn expression cannot be empty."));
174        }
175
176        let mut vars = self.vars.clone();
177        vars.entry("self".to_owned()).or_insert(Value::Null);
178
179        for (_, var) in self.in_mapping.iter() {
180            if !vars.contains_key(var) {
181                return Err(eyre!("Undefined variable ({var}) in `in_mapping`."));
182            }
183        }
184
185        let evaluator = interpreter
186            .parse(&init, &expr, &mut vars)
187            .wrap_err("Failed to initialize function evaluator.")?;
188
189        Ok(Box::new(StatefulFunction {
190            done: false,
191            name: self.name.clone(),
192            vars,
193            evaluator,
194            on_start: self.on_start,
195            on_change: self.on_change,
196            once: self.once,
197            in_mapping: self.in_mapping.clone(),
198            in_update: self.in_update,
199            lo_response: self.lo_response,
200            out_result: self.out_result,
201        }))
202    }
203}
204
205impl StatefulAction for StatefulFunction {
206    impl_stateful!();
207
208    #[inline(always)]
209    fn props(&self) -> Props {
210        if self.once { DEFAULT } else { INFINITE }.into()
211    }
212
213    fn start(
214        &mut self,
215        sync_writer: &mut QWriter<SyncSignal>,
216        async_writer: &mut QWriter<AsyncSignal>,
217        state: &State,
218    ) -> Result<Signal> {
219        for (id, var) in self.in_mapping.iter() {
220            if let Some(entry) = self.vars.get_mut(var) {
221                if let Some(value) = state.get(id) {
222                    *entry = value.clone();
223                }
224            }
225        }
226
227        if self.on_start {
228            if self.once && self.lo_response == 0 {
229                self.done = true;
230                sync_writer.push(SyncSignal::UpdateGraph);
231            }
232
233            self.eval(sync_writer, async_writer)
234                .wrap_err("Failed to evaluate function.")
235        } else {
236            Ok(Signal::none())
237        }
238    }
239
240    fn update(
241        &mut self,
242        signal: &ActionSignal,
243        sync_writer: &mut QWriter<SyncSignal>,
244        async_writer: &mut QWriter<AsyncSignal>,
245        state: &State,
246    ) -> Result<Signal> {
247        let mut news: Vec<(SignalId, Value)> = vec![];
248        let mut changed = false;
249        let mut updated = false;
250        if let ActionSignal::StateChanged(_, signal) = signal {
251            for id in signal {
252                if let Some(var) = self.in_mapping.get(id) {
253                    if let Some(entry) = self.vars.get_mut(var) {
254                        *entry = state.get(id).unwrap().clone();
255                    }
256                    changed = true;
257                }
258
259                if *id == self.lo_response {
260                    let result = state.get(id).unwrap();
261                    self.vars.insert("self".to_owned(), result.clone());
262
263                    if !self.name.is_empty() {
264                        async_writer.push(LoggerSignal::Append(
265                            "function".to_owned(),
266                            (self.name.clone(), result.clone()),
267                        ));
268                    }
269
270                    if self.out_result > 0 {
271                        news.push((self.out_result, result.clone()));
272                    }
273
274                    if self.once {
275                        self.done = true;
276                        sync_writer.push(SyncSignal::UpdateGraph);
277                    }
278                }
279            }
280
281            if signal.contains(&self.in_update) {
282                updated = true;
283            }
284        }
285
286        if (changed && self.on_change) || updated {
287            news.extend(
288                self.eval(sync_writer, async_writer)
289                    .wrap_err("Failed to evaluate function.")?,
290            );
291        }
292
293        Ok(news.into())
294    }
295
296    fn debug(&self) -> Vec<(&str, String)> {
297        <dyn StatefulAction>::debug(self)
298            .into_iter()
299            .chain([("name", format!("{:?}", self.name))])
300            .collect()
301    }
302}
303
304impl StatefulFunction {
305    #[inline(always)]
306    fn eval(
307        &mut self,
308        sync_writer: &mut QWriter<SyncSignal>,
309        async_writer: &mut QWriter<AsyncSignal>,
310    ) -> Result<Signal> {
311        if self.lo_response > 0 {
312            self.eval_lazy(sync_writer)
313        } else {
314            self.eval_blocking(async_writer)
315        }
316    }
317
318    fn eval_blocking(&mut self, async_writer: &mut QWriter<AsyncSignal>) -> Result<Signal> {
319        let result = self.evaluator.eval(&mut self.vars)?;
320
321        self.vars.insert("self".to_owned(), result.clone());
322
323        if !self.name.is_empty() {
324            async_writer.push(LoggerSignal::Append(
325                "function".to_owned(),
326                (self.name.clone(), result.clone()),
327            ));
328        }
329
330        if self.out_result > 0 {
331            Ok(vec![(self.out_result, result)].into())
332        } else {
333            Ok(Signal::none())
334        }
335    }
336
337    fn eval_lazy(&mut self, sync_writer: &mut QWriter<SyncSignal>) -> Result<Signal> {
338        let loopback = {
339            let signal_id = self.lo_response;
340            let mut sync_writer = sync_writer.clone();
341
342            Box::new(move |value: Value| {
343                sync_writer.push(SyncSignal::Emit(
344                    Instant::now(),
345                    Signal::from(vec![(signal_id, value)]),
346                ));
347            })
348        };
349
350        let error = {
351            let mut sync_writer = sync_writer.clone();
352
353            Box::new(move |e: Error| {
354                sync_writer.push(SyncSignal::Error(e));
355            })
356        };
357
358        self.evaluator.eval_lazy(&mut self.vars, loopback, error)?;
359        Ok(Signal::none())
360    }
361}