1use core::f64;
5use std::{collections::HashSet, path::Path};
6
7use serde::{Deserialize, Serialize};
8use serde_json;
9
/// String alias used for state (sequence) names.
pub type StateName = String;
11
12#[cfg(feature = "python")]
13use pyo3::prelude::*;
14
15use super::*;
16
17mod lookup;
18mod sequence;
19mod transition;
20
21pub use lookup::{InterpMethod, SequenceLookup};
22pub use sequence::Sequence;
23pub use transition::{ThreshOp, Timeout, Transition};
24
/// Mutable runtime bookkeeping for a [`SequenceMachine`].
///
/// This is skipped during (de)serialization, populated by `init`, and reset by
/// `terminate`.
#[derive(Default, Debug)]
struct ExecutionState {
    /// Elapsed time within the currently active sequence, in seconds.
    pub sequence_time_s: f64,

    /// Name of the sequence that is currently being played back.
    pub current_sequence: String,

    /// Tape index of each transition input, keyed by input name.
    pub input_index_map: BTreeMap<String, usize>,

    /// Controller timestep in seconds (derived from the controller's `dt_ns`).
    pub dt_s: f64,

    /// Tape indices of this calc's inputs, as handed to `init`.
    pub input_indices: Vec<usize>,

    /// Tape slots that the active sequence is evaluated into.
    pub output_range: Range<usize>,
}
48
49#[derive(Default, Debug, Serialize, Deserialize)]
51pub struct MachineCfg {
52 pub save_outputs: bool,
55
56 pub entry: String,
58
59 pub link_folder: Option<String>,
61
62 pub timeouts: BTreeMap<String, Timeout>,
64
65 pub transitions: BTreeMap<String, BTreeMap<String, Vec<Transition>>>,
67}
68
69#[derive(Serialize, Deserialize, Debug)]
75#[cfg_attr(feature = "python", pyclass)]
76pub struct SequenceMachine {
77 cfg: MachineCfg,
79
80 sequences: BTreeMap<String, Sequence>,
90
91 #[serde(skip)]
94 execution_state: ExecutionState,
95}
96
97impl Default for SequenceMachine {
98 fn default() -> Self {
99 Self {
100 cfg: MachineCfg {
101 entry: "Placeholder".into(),
102 ..Default::default()
103 },
104 sequences: BTreeMap::from([("Placeholder".into(), Sequence::default())]),
105 execution_state: ExecutionState::default(),
106 }
107 }
108}
109
110impl SequenceMachine {
111 pub fn new(
113 cfg: MachineCfg,
114 sequences: BTreeMap<String, Sequence>,
115 ) -> Result<Box<Self>, String> {
116 let input_indices = Vec::new();
119 let output_range = usize::MAX..usize::MAX;
120 let entry = cfg.entry.to_owned();
121
122 let machine = Self {
123 cfg,
124 sequences,
125 execution_state: ExecutionState {
126 sequence_time_s: f64::NAN,
127 current_sequence: entry,
128 input_index_map: BTreeMap::new(),
129 dt_s: f64::NAN,
130 input_indices,
131 output_range,
132 },
133 };
134
135 machine.validate()?;
136
137 Ok(Box::new(machine))
138 }
139
140 fn validate(&self) -> Result<(), String> {
142 for seq in self.sequences.values() {
144 seq.validate()?;
145 }
146
147 let seq_names: HashSet<String> = self.sequences.keys().cloned().collect();
149
150 let timeout_seq_names: HashSet<String> = self.cfg.timeouts.keys().cloned().collect();
152 if seq_names != timeout_seq_names {
153 return Err(format!(
154 "Timeouts do not match sequences. Sequence names that are present in sequences but not timeouts: `{:?}`. Sequence names that are present in timeouts but not sequences: {:?}",
155 seq_names.difference(&timeout_seq_names),
156 timeout_seq_names.difference(&seq_names)
157 ));
158 }
159
160 let transition_seq_names: HashSet<String> = self.cfg.timeouts.keys().cloned().collect();
162 if seq_names != transition_seq_names {
163 return Err(format!(
164 "Transitions do not match sequences. Sequence names that are present in sequences but not timeouts: `{:?}`. Sequence names that are present in transitions but not sequences: {:?}",
165 seq_names.difference(&transition_seq_names),
166 transition_seq_names.difference(&seq_names)
167 ));
168 }
169
170 for (seq, transitions) in self.cfg.transitions.iter() {
172 for target_sequence in transitions.keys() {
173 if !seq_names.contains(target_sequence) {
174 return Err(format!(
175 "Sequence `{seq}` has transition target sequence `{target_sequence}` which does not exist."
176 ));
177 }
178 }
179 }
180
181 Ok(())
182 }
183
184 #[cfg(feature = "python")]
185 fn add_transition(
186 &mut self,
187 source_sequence: String,
188 target_sequence: String,
189 transition: Transition,
190 ) -> Result<(), String> {
191 if !self.sequences.contains_key(&source_sequence) {
192 return Err(format!("Unknown source sequence: {source_sequence}"));
193 }
194 if !self.sequences.contains_key(&target_sequence) {
195 return Err(format!("Unknown target sequence: {target_sequence}"));
196 }
197
198 self.cfg
199 .transitions
200 .entry(source_sequence)
201 .or_insert_with(BTreeMap::new)
202 .entry(target_sequence)
203 .or_insert_with(Vec::new)
204 .push(transition);
205
206 Ok(())
207 }
208
209 fn current_sequence(&self) -> &Sequence {
211 &self.sequences[&self.execution_state.current_sequence]
212 }
213
214 fn entry_sequence(&self) -> &Sequence {
216 &self.sequences[&self.cfg.entry]
217 }
218
219 fn transition(&mut self, target_sequence: String) {
221 self.execution_state.current_sequence = target_sequence;
222 self.execution_state.sequence_time_s = self.current_sequence().get_start_time_s();
223 }
224
225 fn check_transitions(&mut self, sequence_time_s: f64, tape: &[f64]) -> Result<(), String> {
229 let sequence_name = &self.execution_state.current_sequence;
230
231 if sequence_time_s > self.current_sequence().get_end_time_s() {
233 return match &self.cfg.timeouts[sequence_name] {
234 Timeout::Transition(target_sequence) => {
235 self.transition(target_sequence.clone());
237 Ok(())
238 }
239 Timeout::Loop => {
240 self.transition(sequence_name.clone());
242 Ok(())
243 }
244 };
245 }
246
247 for (target_sequence, criteria) in self.cfg.transitions[sequence_name].iter() {
249 for criterion in criteria {
251 let should_transition = match criterion {
252 Transition::ConstantThresh(channel, op, thresh) => {
253 let i = self.execution_state.input_index_map[channel];
254 let v = tape[i];
255
256 op.eval(v, *thresh)
257 }
258 Transition::ChannelThresh(val_channel, op, thresh_channel) => {
259 let ival = self.execution_state.input_index_map[val_channel];
260 let ithresh = self.execution_state.input_index_map[thresh_channel];
261 let v = tape[ival];
262 let thresh = tape[ithresh];
263
264 op.eval(v, thresh)
265 }
266 Transition::LookupThresh(channel, op, lookup) => {
267 let i = self.execution_state.input_index_map[channel];
268 let v = tape[i];
269 let thresh = lookup.eval(sequence_time_s);
270
271 op.eval(v, thresh)
272 }
273 };
274
275 if should_transition {
278 self.transition(target_sequence.clone());
280 return Ok(());
281 }
282 }
283 }
284
285 Ok(())
287 }
288
289 pub fn load_folder(path: &dyn AsRef<Path>) -> Result<Box<Self>, String> {
293 let dir = std::fs::read_dir(path)
294 .map_err(|e| format!("Unable to read items in folder {:?}: {e}", path.as_ref()))?;
295
296 let mut csv_files = Vec::new();
297 let mut json_files = Vec::new();
298 for e in dir.flatten() {
299 let path = e.path();
300 if path.is_file() {
301 match path.extension() {
302 Some(ext) if ext.to_ascii_lowercase().to_str() == Some("csv") => {
303 csv_files.push(path)
304 }
305 Some(ext) if ext.to_ascii_lowercase().to_str() == Some("json") => {
306 json_files.push(path)
307 }
308 _ => {}
309 }
310 }
311 }
312
313 if json_files.is_empty() {
315 return Err("Did not find configuration json file".to_string());
316 }
317
318 if json_files.len() > 1 {
319 return Err(format!("Found multiple config json files: {json_files:?}"));
320 }
321
322 let json_file = &json_files[0];
324 let json_str = std::fs::read_to_string(json_file)
325 .map_err(|e| format!("Failed to read config json: {e}"))?;
326 let cfg: MachineCfg = serde_json::from_str(&json_str)
327 .map_err(|e| format!("Failed to parse config json: {e}"))?;
328
329 let mut sequences = BTreeMap::new();
331 for fp in csv_files {
332 let name = fp
334 .file_stem()
335 .ok_or_else(|| "Filename missing".to_string())?
336 .to_str()
337 .ok_or_else(|| "Filename is not valid unicode".to_string())?
338 .to_owned();
339 let seq: Sequence = Sequence::from_csv_file(&fp)?;
340 sequences.insert(name, seq);
341 }
342
343 Self::new(cfg, sequences)
344 }
345
346 pub fn save_folder(&self, path: &dyn AsRef<Path>) -> Result<(), String> {
348 let dir = path.as_ref();
349 std::fs::create_dir_all(dir)
350 .map_err(|e| format!("Unable to create folder {:?}: {e}", dir))?;
351
352 let cfg_path = dir.join("cfg.json");
353 let cfg_json = serde_json::to_string_pretty(&self.cfg)
354 .map_err(|e| format!("Failed to serialize config json: {e}"))?;
355 std::fs::write(&cfg_path, cfg_json)
356 .map_err(|e| format!("Failed to write config json: {e}"))?;
357
358 for (name, seq) in self.sequences.iter() {
359 let csv_path = dir.join(format!("{name}.csv"));
360 seq.to_csv(&csv_path)?;
361 }
362
363 Ok(())
364 }
365}
366
367#[typetag::serde]
368impl Calc for SequenceMachine {
369 fn init(
371 &mut self,
372 ctx: ControllerCtx,
373 input_indices: Vec<usize>,
374 output_range: Range<usize>,
375 ) -> Result<(), String> {
376 if let Some(rel_path) = &self.cfg.link_folder {
378 let folder = ctx.op_dir.join(rel_path);
379 *self = *Self::load_folder(&folder)
380 .map_err(|e| format!("Failed to load sequence machine from linked folder: {e}"))?;
381 }
382
383 self.terminate()?;
385
386 self.execution_state.input_indices = input_indices;
388 self.execution_state.output_range = output_range;
389 self.execution_state.dt_s = ctx.dt_ns as f64 / 1e9;
390
391 let entry_order: Vec<String> = self.current_sequence().data.keys().cloned().collect();
393 for s in self.sequences.values_mut() {
394 s.permute(&entry_order);
395 }
396
397 self.execution_state.input_index_map = BTreeMap::new();
400 for (i, name) in self
401 .execution_state
402 .input_indices
403 .iter()
404 .cloned()
405 .zip(self.get_input_names().iter())
406 {
407 self.execution_state.input_index_map.insert(name.clone(), i);
408 }
409
410 self.validate()
412 }
413
414 fn terminate(&mut self) -> Result<(), String> {
415 self.execution_state.input_indices.clear();
416 self.execution_state.output_range = usize::MAX..usize::MAX;
417 let start_time = self
418 .sequences
419 .get(&self.cfg.entry)
420 .ok_or_else(|| "Missing sequence".to_string())?
421 .get_start_time_s();
422 self.execution_state.sequence_time_s = start_time;
423 self.execution_state.current_sequence = self.cfg.entry.clone();
424 Ok(())
425 }
426
427 fn eval(&mut self, tape: &mut [f64]) -> Result<(), String> {
428 self.execution_state.sequence_time_s += self.execution_state.dt_s;
430 self.check_transitions(self.execution_state.sequence_time_s, tape)?;
432
433 self.current_sequence().eval(
435 self.execution_state.sequence_time_s,
436 self.execution_state.output_range.clone(),
437 tape,
438 );
439 Ok(())
440 }
441
442 fn get_input_map(&self) -> BTreeMap<CalcInputName, FieldName> {
445 let mut map = BTreeMap::new();
446
447 for transitions in self.cfg.transitions.values() {
448 for criteria in transitions.values() {
449 for criterion in criteria {
450 let names = criterion.get_input_names();
451 for name in names {
452 map.insert(name.clone(), name);
453 }
454 }
455 }
456 }
457
458 map
459 }
460
461 fn update_input_map(&mut self, _field: &str, _source: &str) -> Result<(), String> {
463 Err(
464 "SequenceMachine input map is derived from sequence transition criterion dependencies"
465 .to_string(),
466 )
467 }
468
469 fn get_input_names(&self) -> Vec<CalcInputName> {
471 self.get_input_map().keys().cloned().collect()
472 }
473
474 fn get_output_names(&self) -> Vec<CalcOutputName> {
476 let mut output_names = vec!["sequence_time_s".to_owned()];
477 self.entry_sequence()
478 .data
479 .keys()
480 .cloned()
481 .for_each(|n| output_names.push(n));
482 output_names
483 }
484
485 fn get_save_outputs(&self) -> bool {
487 self.cfg.save_outputs
488 }
489
490 fn set_save_outputs(&mut self, save_outputs: bool) {
492 self.cfg.save_outputs = save_outputs;
493 }
494
495 fn get_config(&self) -> BTreeMap<String, f64> {
497 BTreeMap::<String, f64>::new()
498 }
499
500 #[allow(unused)]
502 fn set_config(&mut self, cfg: &BTreeMap<String, f64>) -> Result<(), String> {
503 Err("No settable config fields".to_string())
504 }
505}
506
#[cfg(feature = "python")]
#[pymethods]
impl SequenceMachine {
    /// Create an empty machine whose entry will be the sequence named `entry`.
    ///
    /// Sequences, timeouts, and transitions are added afterwards. Output saving is
    /// enabled by default.
    #[new]
    fn py_new(entry: String) -> Self {
        Self {
            cfg: MachineCfg {
                save_outputs: true,
                entry,
                ..Default::default()
            },
            sequences: BTreeMap::new(),
            execution_state: ExecutionState::default(),
        }
    }

    /// Serialize the machine as a `Calc` trait object (through typetag), the same way
    /// it is serialized inside a controller.
    fn to_json(&self) -> PyResult<String> {
        let payload: &dyn Calc = self;
        serde_json::to_string(payload)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
    }

    /// Deserialize a machine from JSON.
    ///
    /// Accepts the bare machine struct. It also accepts a single-key
    /// `{"SequenceMachine": {...}}` wrapper, which is the form `to_json` can produce
    /// when typetag wraps trait objects by type name. How `Calc` is tagged is not
    /// visible here (NOTE(review): confirm against the trait definition).
    #[staticmethod]
    fn from_json(s: &str) -> PyResult<Self> {
        let to_py_err =
            |e: serde_json::Error| pyo3::exceptions::PyValueError::new_err(e.to_string());

        let direct_err = match serde_json::from_str::<Self>(s) {
            Ok(machine) => return Ok(machine),
            Err(e) => e,
        };

        if let Ok(serde_json::Value::Object(mut wrapper)) =
            serde_json::from_str::<serde_json::Value>(s)
        {
            if wrapper.len() == 1 {
                if let Some(inner) = wrapper.remove("SequenceMachine") {
                    return serde_json::from_value::<Self>(inner).map_err(to_py_err);
                }
            }
        }

        // Not the wrapped form either; report the error from parsing the bare struct.
        Err(to_py_err(direct_err))
    }

    /// Load a machine from a folder of sequence CSVs and one config JSON.
    #[staticmethod]
    #[pyo3(name = "load_folder")]
    fn py_load_folder(path: &str) -> PyResult<Self> {
        let path = Path::new(path);
        Self::load_folder(&path)
            .map(|machine| *machine)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e))
    }

    /// Save the machine to a folder as `cfg.json` plus one CSV per sequence.
    #[pyo3(name = "save_folder")]
    fn py_save_folder(&self, path: &str) -> PyResult<()> {
        let path = Path::new(path);
        Self::save_folder(self, &path).map_err(|e| pyo3::exceptions::PyValueError::new_err(e))
    }

    /// Name of the entry sequence.
    fn get_entry(&self) -> PyResult<String> {
        Ok(self.cfg.entry.clone())
    }

    /// Set the entry sequence.
    ///
    /// While no sequences exist yet, any name is accepted.
    fn set_entry(&mut self, entry: String) -> PyResult<()> {
        if !self.sequences.is_empty() && !self.sequences.contains_key(&entry) {
            return Err(pyo3::exceptions::PyKeyError::new_err(format!(
                "Unknown entry sequence: {entry}"
            )));
        }
        self.cfg.entry = entry;
        Ok(())
    }

    /// Folder the machine is reloaded from during `init`, if any.
    fn get_link_folder(&self) -> PyResult<Option<String>> {
        Ok(self.cfg.link_folder.clone())
    }

    /// Set or clear the folder the machine is reloaded from during `init`.
    fn set_link_folder(&mut self, link_folder: Option<String>) -> PyResult<()> {
        self.cfg.link_folder = link_folder;
        Ok(())
    }

    /// Timeout target of `sequence`, or `None` if it loops.
    fn get_timeout(&self, sequence: String) -> PyResult<Option<String>> {
        let timeout = self.cfg.timeouts.get(&sequence).ok_or_else(|| {
            pyo3::exceptions::PyKeyError::new_err(format!("Unknown sequence: {sequence}"))
        })?;

        match timeout {
            Timeout::Loop => Ok(None),
            Timeout::Transition(target) => Ok(Some(target.clone())),
        }
    }

    /// Set the timeout of `sequence` to transition to `target`, or to loop if `None`.
    ///
    /// Both sequences must already exist.
    fn set_timeout(&mut self, sequence: String, target: Option<String>) -> PyResult<()> {
        if !self.sequences.contains_key(&sequence) {
            return Err(pyo3::exceptions::PyKeyError::new_err(format!(
                "Unknown sequence: {sequence}"
            )));
        }

        let timeout = match target {
            Some(target_sequence) => {
                if !self.sequences.contains_key(&target_sequence) {
                    return Err(pyo3::exceptions::PyKeyError::new_err(format!(
                        "Unknown target sequence: {target_sequence}"
                    )));
                }
                Timeout::Transition(target_sequence)
            }
            None => Timeout::Loop,
        };

        self.cfg.timeouts.insert(sequence, timeout);
        Ok(())
    }

    /// Add a new sequence named `name`.
    ///
    /// `tables` maps each output channel to `(time_s, values, interp_method)`. The
    /// channel set must match any existing sequences.
    ///
    /// `timeout` is the sequence to move to once this one ends, or `None` to loop. The
    /// target is not checked here, so sequences can reference ones added later.
    ///
    /// The new sequence starts with an empty transition table.
    fn add_sequence(
        &mut self,
        name: String,
        tables: BTreeMap<String, (Vec<f64>, Vec<f64>, String)>,
        timeout: Option<String>,
    ) -> PyResult<()> {
        if tables.is_empty() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Sequence data is empty".to_string(),
            ));
        }

        if self.sequences.contains_key(&name) {
            return Err(pyo3::exceptions::PyKeyError::new_err(format!(
                "Sequence already exists: {name}"
            )));
        }

        // `channel` (not `name`) so the sequence name is not shadowed.
        let mut data = BTreeMap::new();
        for (channel, (time_s, vals, method)) in tables {
            let method = InterpMethod::try_parse(&method).map_err(|e| {
                pyo3::exceptions::PyValueError::new_err(format!(
                    "Output `{channel}` has invalid interp method: {e}"
                ))
            })?;

            let lookup = SequenceLookup::new(method, time_s, vals).map_err(|e| {
                pyo3::exceptions::PyValueError::new_err(format!(
                    "Output `{channel}` has invalid lookup data: {e}"
                ))
            })?;
            data.insert(channel, lookup);
        }

        if let Some(existing) = self.sequences.values().next() {
            let expected: HashSet<String> = existing.data.keys().cloned().collect();
            let provided: HashSet<String> = data.keys().cloned().collect();

            if expected != provided {
                let mut missing: Vec<String> = expected.difference(&provided).cloned().collect();
                let mut extra: Vec<String> = provided.difference(&expected).cloned().collect();

                missing.sort();
                extra.sort();

                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Sequence outputs must match existing sequences. Missing: {missing:?}. Extra: {extra:?}"
                )));
            }
        }

        let sequence = Sequence { data };
        sequence.validate().map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Invalid Sequence: {e:?}"))
        })?;

        self.sequences.insert(name.clone(), sequence);
        let timeout = match timeout {
            Some(target_state) => Timeout::Transition(target_state),
            None => Timeout::Loop,
        };
        self.cfg.timeouts.insert(name.clone(), timeout);
        self.cfg.transitions.entry(name).or_default();

        Ok(())
    }

    /// Add a transition that fires when `channel` compares (per `op`) against a
    /// constant `threshold`.
    ///
    /// `op` is parsed by `ThreshOp::try_parse`.
    fn add_constant_thresh_transition(
        &mut self,
        source_target: (String, String),
        channel: String,
        op: (&str, f64),
        threshold: f64,
    ) -> PyResult<()> {
        let (source_sequence, target_sequence) = source_target;
        let op = ThreshOp::try_parse(op).map_err(|e| pyo3::exceptions::PyValueError::new_err(e))?;

        let transition = Transition::ConstantThresh(channel, op, threshold);
        self.add_transition(source_sequence, target_sequence, transition)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e))
    }

    /// Add a transition that fires when `channel` compares (per `op`) against the
    /// current value of `threshold_channel`.
    fn add_channel_thresh_transition(
        &mut self,
        source_target: (String, String),
        channel: String,
        op: (&str, f64),
        threshold_channel: String,
    ) -> PyResult<()> {
        let (source_sequence, target_sequence) = source_target;
        let op = ThreshOp::try_parse(op).map_err(|e| pyo3::exceptions::PyValueError::new_err(e))?;

        let transition = Transition::ChannelThresh(channel, op, threshold_channel);
        self.add_transition(source_sequence, target_sequence, transition)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e))
    }

    /// Add a transition that fires when `channel` compares (per `op`) against a
    /// threshold looked up at the current sequence time.
    ///
    /// `threshold_lookup` is `(time_s, values, interp_method)`.
    fn add_lookup_thresh_transition(
        &mut self,
        source_target: (String, String),
        channel: String,
        op: (&str, f64),
        threshold_lookup: (Vec<f64>, Vec<f64>, &str),
    ) -> PyResult<()> {
        let (source_sequence, target_sequence) = source_target;
        let (time_s, vals, method) = threshold_lookup;
        let op = ThreshOp::try_parse(op).map_err(|e| pyo3::exceptions::PyValueError::new_err(e))?;
        let method = InterpMethod::try_parse(&method).map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!(
                "Lookup has invalid interp method: {e}"
            ))
        })?;

        let lookup = SequenceLookup::new(method, time_s, vals).map_err(|e| {
            pyo3::exceptions::PyValueError::new_err(format!("Lookup has invalid data: {e}"))
        })?;

        let transition = Transition::LookupThresh(channel, op, lookup);
        self.add_transition(source_sequence, target_sequence, transition)
            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e))
    }
}
754
#[cfg(test)]
mod tests {
    use super::SequenceMachine;
    use std::path::PathBuf;

    /// Load the example machine, save it to a scratch folder, and load it back.
    /// The reloaded machine must serialize identically to the original.
    #[test]
    fn roundtrip_sequence_machine_folder() {
        let root = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        let src_dir = root.join("examples").join("machine");
        // Include the process id so concurrent test runs do not share (and clobber)
        // the same scratch folder.
        let tmp_dir = std::env::temp_dir().join(format!(
            "deimos_sequence_roundtrip_{}",
            std::process::id()
        ));

        let _ = std::fs::remove_dir_all(&tmp_dir);
        std::fs::create_dir_all(&tmp_dir).unwrap();

        let original = *SequenceMachine::load_folder(&src_dir).unwrap();
        let original_json = serde_json::to_string_pretty(&original).unwrap();

        let saved = original.save_folder(&tmp_dir);
        let reloaded = saved.and_then(|()| SequenceMachine::load_folder(&tmp_dir));
        // Clean up before asserting so a failure does not leave the folder behind.
        let _ = std::fs::remove_dir_all(&tmp_dir);

        let roundtrip = *reloaded.unwrap();
        let roundtrip_json = serde_json::to_string_pretty(&roundtrip).unwrap();

        assert_eq!(original_json, roundtrip_json);
    }
}
780}