cog_task/action/core/
branch.rs1use crate::action::{Action, ActionSignal, Props, StatefulAction, DEFAULT, INFINITE, VISUAL};
2use crate::comm::{QWriter, Signal, SignalId};
3use crate::resource::{IoManager, ResourceAddr, ResourceManager};
4use crate::server::{AsyncSignal, Config, State, SyncSignal};
5use crate::util::f64_as_i64;
6use eframe::egui;
7use eyre::{eyre, Context, Result};
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use serde_cbor::Value;
11use std::collections::BTreeSet;
12
/// Control-flow action that runs exactly one of its `children`.
///
/// The child is selected when the branch starts, from the value currently held by
/// the `in_control` signal; if that signal has no value, `default` is used.
#[derive(Debug, Deserialize, Serialize)]
pub struct Branch {
    /// Index of the child to run when `in_control` has no value at start time.
    /// `#[serde(default)]` makes it 0 when omitted; `init` rejects values that are
    /// not less than `children.len()`.
    #[serde(default)]
    default: usize,
    /// Candidate actions; only the selected one is started.
    children: Vec<Box<dyn Action>>,
    /// Signal whose value (an integer, or a float convertible to one) selects the
    /// index of the child to run.
    in_control: SignalId,
}
20
/// Which child a stateful branch is bound to.
///
/// The branch holds `Temporary(default)` until it is started. Starting it commits
/// to `Final(index)`, after which restarting is an error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Decision {
    /// Provisional choice (the configured default) before `start` has run.
    Temporary(usize),
    /// Child index chosen and started by `start`.
    Final(usize),
}
25
// Runtime (stateful) counterpart of `Branch`, used below as `StatefulBranch`.
// The macro evidently also adds a `done` flag: `Branch::stateful` builds the struct
// with `done: false`, and `update` sets `self.done = true`.
stateful!(Branch {
    // One stateful instance per configured child, in the same order as `Branch::children`.
    children: Vec<Box<dyn StatefulAction>>,
    // Signal read at start time to pick the child.
    in_control: SignalId,
    // Fallback index before `start`, committed index afterwards.
    decision: Decision,
});
31
32impl Action for Branch {
33 #[inline]
34 fn in_signals(&self) -> BTreeSet<SignalId> {
35 let mut signals = BTreeSet::new();
36 signals.insert(self.in_control);
37 for c in self.children.iter() {
38 signals.extend(c.in_signals());
39 }
40 signals
41 }
42
43 #[inline]
44 fn out_signals(&self) -> BTreeSet<SignalId> {
45 let mut signals = BTreeSet::new();
46 for c in self.children.iter() {
47 signals.extend(c.out_signals());
48 }
49 signals
50 }
51
52 #[inline]
53 fn resources(&self, config: &Config) -> Vec<ResourceAddr> {
54 self.children
55 .iter()
56 .flat_map(|c| c.resources(config))
57 .unique()
58 .collect()
59 }
60
61 fn init(self) -> Result<Box<dyn Action>> {
62 if self.children.is_empty() {
63 Err(eyre!("Branch should have at least one child."))
64 } else if self.default >= self.children.len() {
65 Err(eyre!(
66 "Branch default value should be less than the number of its children."
67 ))
68 } else {
69 Ok(Box::new(self))
70 }
71 }
72
73 fn stateful(
74 &self,
75 io: &IoManager,
76 res: &ResourceManager,
77 config: &Config,
78 sync_writer: &QWriter<SyncSignal>,
79 async_writer: &QWriter<AsyncSignal>,
80 ) -> Result<Box<dyn StatefulAction>> {
81 let mut children = vec![];
82 for c in self.children.iter() {
83 children.push(c.stateful(io, res, config, sync_writer, async_writer)?);
84 }
85
86 Ok(Box::new(StatefulBranch {
87 done: false,
88 children,
89 in_control: self.in_control,
90 decision: Decision::Temporary(self.default),
91 }))
92 }
93}
94
95impl StatefulAction for StatefulBranch {
96 impl_stateful!();
97
98 #[inline]
99 fn props(&self) -> Props {
100 if let Decision::Final(i) = self.decision {
101 self.children[i].props()
102 } else {
103 self.children
104 .iter()
105 .fold(DEFAULT, |mut state, c| {
106 let c = c.props();
107 if c.visual() {
108 state |= VISUAL;
109 }
110 if c.infinite() {
111 state |= INFINITE;
112 }
113 state
114 })
115 .into()
116 }
117 }
118
119 #[inline]
120 fn start(
121 &mut self,
122 sync_writer: &mut QWriter<SyncSignal>,
123 async_writer: &mut QWriter<AsyncSignal>,
124 state: &State,
125 ) -> Result<Signal> {
126 let decision = if let Decision::Temporary(i) = self.decision {
127 i
128 } else {
129 return Err(eyre!("Tried to restart branch."));
130 };
131
132 let decision = match state.get(&self.in_control) {
133 Some(Value::Integer(i)) => {
134 if *i < self.children.len() as i128 {
135 *i as usize
136 } else {
137 return Err(eyre!("Branch request is out of bounds."));
138 }
139 }
140 Some(Value::Float(x)) => {
141 let x = f64_as_i64(*x).wrap_err("Non-integer number supplied to branch.")?;
142 if (0..self.children.len() as i64).contains(&x) {
143 x as usize
144 } else {
145 return Err(eyre!("Branch request is out of bounds."));
146 }
147 }
148 None => decision,
149 _ => return Err(eyre!("Branch control is in invalid state.")),
150 };
151
152 self.decision = Decision::Final(decision);
153 self.children[decision].start(sync_writer, async_writer, state)
154 }
155
156 #[inline]
157 fn update(
158 &mut self,
159 signal: &ActionSignal,
160 sync_writer: &mut QWriter<SyncSignal>,
161 async_writer: &mut QWriter<AsyncSignal>,
162 state: &State,
163 ) -> Result<Signal> {
164 if let Decision::Final(i) = self.decision {
165 let news = self.children[i].update(signal, sync_writer, async_writer, state)?;
166 if self.children[i].is_over()? {
167 self.done = true;
168 }
169 Ok(news)
170 } else {
171 Ok(Signal::none())
172 }
173 }
174
175 fn show(
176 &mut self,
177 ui: &mut egui::Ui,
178 sync_writer: &mut QWriter<SyncSignal>,
179 async_writer: &mut QWriter<AsyncSignal>,
180 state: &State,
181 ) -> Result<()> {
182 if let Decision::Final(i) = self.decision {
183 self.children[i].show(ui, sync_writer, async_writer, state)
184 } else {
185 Ok(())
186 }
187 }
188
189 #[inline]
190 fn stop(
191 &mut self,
192 sync_writer: &mut QWriter<SyncSignal>,
193 async_writer: &mut QWriter<AsyncSignal>,
194 state: &State,
195 ) -> Result<Signal> {
196 if let Decision::Final(i) = self.decision {
197 self.children[i].stop(sync_writer, async_writer, state)
198 } else {
199 Ok(Signal::none())
200 }
201 }
202}