cog_task/action/core/
reaction.rs

1use crate::action::{Action, ActionSignal, Props, StatefulAction, INFINITE};
2use crate::comm::{QWriter, Signal, SignalId};
3use crate::resource::{IoManager, Key, LoggerSignal, ResourceManager};
4use crate::server::{AsyncSignal, Config, State, SyncSignal};
5use eyre::{eyre, Error, Result};
6use serde::{Deserialize, Serialize};
7use serde_cbor::Value;
8use std::collections::BTreeSet;
9use std::time::{Duration, Instant};
10
/// Configuration for the `Reaction` action, which scores key presses against a list
/// of target times.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Reaction {
    /// Target onset times, in seconds after the action starts (sorted ascending by `init`).
    times: Vec<f32>,
    /// Logger group that start/stop events and per-press entries are written to;
    /// `init` rejects an empty string.
    #[serde(default = "defaults::group")]
    group: String,
    /// Keys that count as responses; an empty set accepts any key.
    #[serde(default = "defaults::keys")]
    keys: BTreeSet<Key>,
    /// Response window in seconds: a press in `[target, target + tol]` is correct.
    #[serde(default = "defaults::tol")]
    tol: f32,
    /// Signal that receives the RT (seconds) of each correct press; only emitted when > 0.
    #[serde(default)]
    out_rt: SignalId,
    /// Signal that receives the final accuracy on stop; only emitted when > 0.
    #[serde(default)]
    out_accuracy: SignalId,
    /// Signal that receives the final mean RT on stop; only emitted when > 0.
    #[serde(default)]
    out_mean_rt: SignalId,
    /// Signal that receives the final recall on stop; only emitted when > 0.
    #[serde(default)]
    out_recall: SignalId,
}
30
// Runtime state for `Reaction`. This presumably expands to `StatefulReaction` with these
// fields (the constructor in `Reaction::stateful` additionally sets a `done` flag).
// NOTE(review): confirm against the `stateful!` macro definition.
stateful!(Reaction {
    group: String,
    keys: BTreeSet<Key>,
    // Target onsets, measured from `since`.
    times: Vec<Duration>,
    // Response window that follows each target.
    tol: Duration,
    // Reference instant; reset when the action starts.
    since: Instant,
    // Index into `times` of the next target still awaiting a response; `None` once
    // every target has been answered or missed.
    next: Option<usize>,
    // One entry per accepted key press: whether it matched a target.
    reaction_correct: Vec<bool>,
    // Seconds since `since` of every accepted key press.
    reaction_times: Vec<f32>,
    // Reaction time (seconds after the matched target) of every correct press.
    reaction_rts: Vec<f32>,
    out_rt: SignalId,
    out_accuracy: SignalId,
    out_mean_rt: SignalId,
    out_recall: SignalId,
});
46
47mod defaults {
48    use crate::resource::Key;
49    use std::collections::BTreeSet;
50
51    #[inline(always)]
52    pub fn group() -> String {
53        "reaction".to_owned()
54    }
55
56    #[inline(always)]
57    pub fn keys() -> BTreeSet<Key> {
58        BTreeSet::new()
59    }
60
61    #[inline(always)]
62    pub fn tol() -> f32 {
63        2.0
64    }
65}
66
67impl Action for Reaction {
68    #[inline]
69    fn out_signals(&self) -> BTreeSet<SignalId> {
70        BTreeSet::from([
71            self.out_rt,
72            self.out_accuracy,
73            self.out_mean_rt,
74            self.out_recall,
75        ])
76    }
77
78    #[inline(always)]
79    fn init(mut self) -> Result<Box<dyn Action>, Error>
80    where
81        Self: 'static + Sized,
82    {
83        if self.group.is_empty() {
84            return Err(eyre!("Reaction `group` cannot be an empty string"));
85        }
86
87        self.times.sort_by(|a, b| a.partial_cmp(b).unwrap());
88        Ok(Box::new(self))
89    }
90
91    fn stateful(
92        &self,
93        _io: &IoManager,
94        _res: &ResourceManager,
95        _config: &Config,
96        _sync_writer: &QWriter<SyncSignal>,
97        _async_writer: &QWriter<AsyncSignal>,
98    ) -> Result<Box<dyn StatefulAction>> {
99        Ok(Box::new(StatefulReaction {
100            done: false,
101            group: self.group.clone(),
102            keys: self.keys.clone(),
103            times: self
104                .times
105                .iter()
106                .map(|t| Duration::from_secs_f32(*t))
107                .collect(),
108            tol: Duration::from_secs_f32(self.tol),
109            since: Instant::now(),
110            next: Some(0),
111            reaction_correct: vec![],
112            reaction_times: vec![],
113            reaction_rts: vec![],
114            out_rt: self.out_rt,
115            out_accuracy: self.out_accuracy,
116            out_recall: self.out_recall,
117            out_mean_rt: self.out_mean_rt,
118        }))
119    }
120}
121
122impl StatefulAction for StatefulReaction {
123    impl_stateful!();
124
125    #[inline(always)]
126    fn props(&self) -> Props {
127        INFINITE.into()
128    }
129
130    fn start(
131        &mut self,
132        _sync_writer: &mut QWriter<SyncSignal>,
133        async_writer: &mut QWriter<AsyncSignal>,
134        _state: &State,
135    ) -> Result<Signal> {
136        self.since = Instant::now();
137        async_writer.push(LoggerSignal::Append(
138            self.group.clone(),
139            ("event".to_owned(), Value::Text("start".to_owned())),
140        ));
141        Ok(Signal::none())
142    }
143
144    fn update(
145        &mut self,
146        signal: &ActionSignal,
147        sync_writer: &mut QWriter<SyncSignal>,
148        async_writer: &mut QWriter<AsyncSignal>,
149        _state: &State,
150    ) -> Result<Signal> {
151        let (time, keys) = match signal {
152            ActionSignal::KeyPress(t, k) => (t.duration_since(self.since), k),
153            _ => return Ok(Signal::none()),
154        };
155
156        if !self.keys.is_empty() && keys.is_disjoint(&self.keys) {
157            return Ok(Signal::none());
158        }
159
160        self.reaction_times.push(time.as_secs_f32());
161
162        let mut correct = false;
163        if self.next.is_none() {
164            self.reaction_correct.push(false);
165        } else {
166            while let Some(i) = self.next {
167                let target = self.times[i];
168                if time < target {
169                    self.reaction_correct.push(false);
170                    break;
171                } else if time <= target + self.tol {
172                    correct = true;
173                    let rt = (time - target).as_secs_f32();
174                    self.reaction_correct.push(true);
175                    self.reaction_rts.push(rt);
176                    if i < self.times.len() - 1 {
177                        self.next = Some(i + 1);
178                    } else {
179                        self.next = None;
180                    }
181                    if self.out_rt > 0 {
182                        sync_writer.push(SyncSignal::Emit(
183                            Instant::now(),
184                            vec![(self.out_rt, Value::Float(rt as f64))].into(),
185                        ))
186                    }
187                    break;
188                } else if i < self.times.len() - 1 {
189                    self.next = Some(i + 1);
190                } else {
191                    self.next = None;
192                    self.reaction_correct.push(false);
193                    break;
194                }
195            }
196        }
197
198        let entry = if correct {
199            (
200                "correct".to_string(),
201                Value::Array(vec![
202                    Value::Float(self.reaction_times[self.reaction_times.len() - 1] as f64),
203                    Value::Float(self.reaction_rts[self.reaction_rts.len() - 1] as f64),
204                ]),
205            )
206        } else {
207            (
208                "incorrect".to_string(),
209                Value::Array(vec![Value::Float(
210                    self.reaction_times[self.reaction_times.len() - 1] as f64,
211                )]),
212            )
213        };
214        async_writer.push(LoggerSignal::Append(self.group.clone(), entry));
215
216        Ok(Signal::none())
217    }
218
219    #[inline]
220    fn stop(
221        &mut self,
222        _sync_writer: &mut QWriter<SyncSignal>,
223        async_writer: &mut QWriter<AsyncSignal>,
224        _state: &State,
225    ) -> Result<Signal> {
226        let accuracy = self.accuracy();
227        let mean_rt = self.mean_rt();
228        let recall = self.recall();
229
230        async_writer.push(LoggerSignal::Extend(
231            self.group.clone(),
232            vec![
233                ("event".to_owned(), Value::Text("stop".to_owned())),
234                ("accuracy".to_owned(), Value::Float(accuracy)),
235                ("mean_rt".to_owned(), Value::Float(mean_rt)),
236                ("recall".to_owned(), Value::Float(recall)),
237            ],
238        ));
239
240        let mut news = vec![];
241        if self.out_accuracy > 0 {
242            news.push((self.out_accuracy, Value::Float(accuracy)))
243        }
244        if self.out_mean_rt > 0 {
245            news.push((self.out_mean_rt, Value::Float(mean_rt)))
246        }
247        if self.out_recall > 0 {
248            news.push((self.out_recall, Value::Float(recall)))
249        }
250        Ok(news.into())
251    }
252
253    fn debug(&self) -> Vec<(&str, String)> {
254        <dyn StatefulAction>::debug(self)
255            .into_iter()
256            .chain([("group", format!("{:?}", self.group))])
257            .collect()
258    }
259}
260
261impl StatefulReaction {
262    #[inline(always)]
263    fn accuracy(&self) -> f64 {
264        self.reaction_rts.len() as f64 / self.reaction_correct.len() as f64
265    }
266
267    #[inline(always)]
268    fn recall(&self) -> f64 {
269        self.reaction_rts.len() as f64 / self.times.len() as f64
270    }
271
272    #[inline(always)]
273    fn mean_rt(&self) -> f64 {
274        self.reaction_rts.iter().sum::<f32>() as f64 / self.reaction_rts.len() as f64
275    }
276}