1use crate::action::{Action, ActionSignal, Props, StatefulAction, INFINITE};
2use crate::comm::{QWriter, Signal, SignalId};
3use crate::resource::{IoManager, Key, LoggerSignal, ResourceManager};
4use crate::server::{AsyncSignal, Config, State, SyncSignal};
5use eyre::{eyre, Error, Result};
6use serde::{Deserialize, Serialize};
7use serde_cbor::Value;
8use std::collections::BTreeSet;
9use std::time::{Duration, Instant};
10
/// Configuration of a reaction-time task: relevant key presses are matched
/// against a list of target times and scored as hits (within `tol` after a
/// target) or incorrect presses.
#[derive(Debug, Deserialize, Serialize)]
#[serde(deny_unknown_fields)]
pub struct Reaction {
    /// Target times in seconds, measured from when the action starts.
    /// Sorted ascending by `init`.
    times: Vec<f32>,
    /// Logger group that start/stop events and per-press entries are appended to.
    #[serde(default = "defaults::group")]
    group: String,
    /// Keys that count as responses; when empty, every key press counts.
    #[serde(default = "defaults::keys")]
    keys: BTreeSet<Key>,
    /// Tolerance in seconds: a press at most `tol` after a target counts as a hit.
    #[serde(default = "defaults::tol")]
    tol: f32,
    /// Signal that receives each hit's reaction time in seconds
    /// (only emitted when the id is `> 0`).
    #[serde(default)]
    out_rt: SignalId,
    /// Signal that receives the final accuracy on stop (only emitted when `> 0`).
    #[serde(default)]
    out_accuracy: SignalId,
    /// Signal that receives the final mean reaction time on stop (only emitted when `> 0`).
    #[serde(default)]
    out_mean_rt: SignalId,
    /// Signal that receives the final recall on stop (only emitted when `> 0`).
    #[serde(default)]
    out_recall: SignalId,
}
30
// Runtime state of `Reaction`, generated by the `stateful!` macro as
// `StatefulReaction`. The macro presumably also adds the `done` flag that
// `Reaction::stateful` initialises — NOTE(review): confirm in the macro definition.
stateful!(Reaction {
    // Logger group name.
    group: String,
    // Accepted response keys; empty means any key.
    keys: BTreeSet<Key>,
    // Target times relative to `since`, in ascending order.
    times: Vec<Duration>,
    // Maximum delay after a target for a press to count as a hit.
    tol: Duration,
    // Reference instant that key-press times are measured from (reset on start).
    since: Instant,
    // Index of the earliest target not yet hit or skipped; `None` once exhausted.
    next: Option<usize>,
    // One entry per relevant key press: whether it hit a target.
    reaction_correct: Vec<bool>,
    // One entry per relevant key press: seconds since `since`.
    reaction_times: Vec<f32>,
    // One entry per hit: seconds between the target time and the press.
    reaction_rts: Vec<f32>,
    // Output signal ids; each is only emitted when `> 0`.
    out_rt: SignalId,
    out_accuracy: SignalId,
    out_mean_rt: SignalId,
    out_recall: SignalId,
});
46
47mod defaults {
48 use crate::resource::Key;
49 use std::collections::BTreeSet;
50
51 #[inline(always)]
52 pub fn group() -> String {
53 "reaction".to_owned()
54 }
55
56 #[inline(always)]
57 pub fn keys() -> BTreeSet<Key> {
58 BTreeSet::new()
59 }
60
61 #[inline(always)]
62 pub fn tol() -> f32 {
63 2.0
64 }
65}
66
67impl Action for Reaction {
68 #[inline]
69 fn out_signals(&self) -> BTreeSet<SignalId> {
70 BTreeSet::from([
71 self.out_rt,
72 self.out_accuracy,
73 self.out_mean_rt,
74 self.out_recall,
75 ])
76 }
77
78 #[inline(always)]
79 fn init(mut self) -> Result<Box<dyn Action>, Error>
80 where
81 Self: 'static + Sized,
82 {
83 if self.group.is_empty() {
84 return Err(eyre!("Reaction `group` cannot be an empty string"));
85 }
86
87 self.times.sort_by(|a, b| a.partial_cmp(b).unwrap());
88 Ok(Box::new(self))
89 }
90
91 fn stateful(
92 &self,
93 _io: &IoManager,
94 _res: &ResourceManager,
95 _config: &Config,
96 _sync_writer: &QWriter<SyncSignal>,
97 _async_writer: &QWriter<AsyncSignal>,
98 ) -> Result<Box<dyn StatefulAction>> {
99 Ok(Box::new(StatefulReaction {
100 done: false,
101 group: self.group.clone(),
102 keys: self.keys.clone(),
103 times: self
104 .times
105 .iter()
106 .map(|t| Duration::from_secs_f32(*t))
107 .collect(),
108 tol: Duration::from_secs_f32(self.tol),
109 since: Instant::now(),
110 next: Some(0),
111 reaction_correct: vec![],
112 reaction_times: vec![],
113 reaction_rts: vec![],
114 out_rt: self.out_rt,
115 out_accuracy: self.out_accuracy,
116 out_recall: self.out_recall,
117 out_mean_rt: self.out_mean_rt,
118 }))
119 }
120}
121
122impl StatefulAction for StatefulReaction {
123 impl_stateful!();
124
125 #[inline(always)]
126 fn props(&self) -> Props {
127 INFINITE.into()
128 }
129
130 fn start(
131 &mut self,
132 _sync_writer: &mut QWriter<SyncSignal>,
133 async_writer: &mut QWriter<AsyncSignal>,
134 _state: &State,
135 ) -> Result<Signal> {
136 self.since = Instant::now();
137 async_writer.push(LoggerSignal::Append(
138 self.group.clone(),
139 ("event".to_owned(), Value::Text("start".to_owned())),
140 ));
141 Ok(Signal::none())
142 }
143
144 fn update(
145 &mut self,
146 signal: &ActionSignal,
147 sync_writer: &mut QWriter<SyncSignal>,
148 async_writer: &mut QWriter<AsyncSignal>,
149 _state: &State,
150 ) -> Result<Signal> {
151 let (time, keys) = match signal {
152 ActionSignal::KeyPress(t, k) => (t.duration_since(self.since), k),
153 _ => return Ok(Signal::none()),
154 };
155
156 if !self.keys.is_empty() && keys.is_disjoint(&self.keys) {
157 return Ok(Signal::none());
158 }
159
160 self.reaction_times.push(time.as_secs_f32());
161
162 let mut correct = false;
163 if self.next.is_none() {
164 self.reaction_correct.push(false);
165 } else {
166 while let Some(i) = self.next {
167 let target = self.times[i];
168 if time < target {
169 self.reaction_correct.push(false);
170 break;
171 } else if time <= target + self.tol {
172 correct = true;
173 let rt = (time - target).as_secs_f32();
174 self.reaction_correct.push(true);
175 self.reaction_rts.push(rt);
176 if i < self.times.len() - 1 {
177 self.next = Some(i + 1);
178 } else {
179 self.next = None;
180 }
181 if self.out_rt > 0 {
182 sync_writer.push(SyncSignal::Emit(
183 Instant::now(),
184 vec![(self.out_rt, Value::Float(rt as f64))].into(),
185 ))
186 }
187 break;
188 } else if i < self.times.len() - 1 {
189 self.next = Some(i + 1);
190 } else {
191 self.next = None;
192 self.reaction_correct.push(false);
193 break;
194 }
195 }
196 }
197
198 let entry = if correct {
199 (
200 "correct".to_string(),
201 Value::Array(vec![
202 Value::Float(self.reaction_times[self.reaction_times.len() - 1] as f64),
203 Value::Float(self.reaction_rts[self.reaction_rts.len() - 1] as f64),
204 ]),
205 )
206 } else {
207 (
208 "incorrect".to_string(),
209 Value::Array(vec![Value::Float(
210 self.reaction_times[self.reaction_times.len() - 1] as f64,
211 )]),
212 )
213 };
214 async_writer.push(LoggerSignal::Append(self.group.clone(), entry));
215
216 Ok(Signal::none())
217 }
218
219 #[inline]
220 fn stop(
221 &mut self,
222 _sync_writer: &mut QWriter<SyncSignal>,
223 async_writer: &mut QWriter<AsyncSignal>,
224 _state: &State,
225 ) -> Result<Signal> {
226 let accuracy = self.accuracy();
227 let mean_rt = self.mean_rt();
228 let recall = self.recall();
229
230 async_writer.push(LoggerSignal::Extend(
231 self.group.clone(),
232 vec![
233 ("event".to_owned(), Value::Text("stop".to_owned())),
234 ("accuracy".to_owned(), Value::Float(accuracy)),
235 ("mean_rt".to_owned(), Value::Float(mean_rt)),
236 ("recall".to_owned(), Value::Float(recall)),
237 ],
238 ));
239
240 let mut news = vec![];
241 if self.out_accuracy > 0 {
242 news.push((self.out_accuracy, Value::Float(accuracy)))
243 }
244 if self.out_mean_rt > 0 {
245 news.push((self.out_mean_rt, Value::Float(mean_rt)))
246 }
247 if self.out_recall > 0 {
248 news.push((self.out_recall, Value::Float(recall)))
249 }
250 Ok(news.into())
251 }
252
253 fn debug(&self) -> Vec<(&str, String)> {
254 <dyn StatefulAction>::debug(self)
255 .into_iter()
256 .chain([("group", format!("{:?}", self.group))])
257 .collect()
258 }
259}
260
261impl StatefulReaction {
262 #[inline(always)]
263 fn accuracy(&self) -> f64 {
264 self.reaction_rts.len() as f64 / self.reaction_correct.len() as f64
265 }
266
267 #[inline(always)]
268 fn recall(&self) -> f64 {
269 self.reaction_rts.len() as f64 / self.times.len() as f64
270 }
271
272 #[inline(always)]
273 fn mean_rt(&self) -> f64 {
274 self.reaction_rts.iter().sum::<f32>() as f64 / self.reaction_rts.len() as f64
275 }
276}