cog_task/action/core/
branch.rs

1use crate::action::{Action, ActionSignal, Props, StatefulAction, DEFAULT, INFINITE, VISUAL};
2use crate::comm::{QWriter, Signal, SignalId};
3use crate::resource::{IoManager, ResourceAddr, ResourceManager};
4use crate::server::{AsyncSignal, Config, State, SyncSignal};
5use crate::util::f64_as_i64;
6use eframe::egui;
7use eyre::{eyre, Context, Result};
8use itertools::Itertools;
9use serde::{Deserialize, Serialize};
10use serde_cbor::Value;
11use std::collections::BTreeSet;
12
/// Action that runs exactly one of its children, chosen when it starts.
///
/// The child index is taken from the state value of `in_control` at start
/// time; if that signal has no value, `default` is used instead.
#[derive(Debug, Deserialize, Serialize)]
pub struct Branch {
    /// Index of the child to run when `in_control` has no value at start.
    /// Falls back to `usize::default()` (0) when omitted; `init` rejects
    /// values that are not smaller than `children.len()`.
    #[serde(default)]
    default: usize,
    /// Candidate actions; `init` requires at least one.
    children: Vec<Box<dyn Action>>,
    /// Signal whose state value selects the child to run (an integer, or a
    /// float that `f64_as_i64` accepts as an integer).
    in_control: SignalId,
}
20
/// Which child a `StatefulBranch` runs.
///
/// `Temporary` holds the fallback index before the branch has started;
/// `Final` holds the index actually chosen by `start`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Decision {
    /// Not started yet; the index is the configured default.
    Temporary(usize),
    /// Started; the index of the child that is running.
    Final(usize),
}
25
// Runtime (stateful) counterpart of `Branch`, generated as `StatefulBranch`
// by the `stateful!` macro (defined elsewhere; it presumably also adds the
// `done` flag that `Branch::stateful` initialises — TODO confirm).
// `decision` starts as `Temporary(default)` and becomes `Final(index)` once
// `start` has picked the child to run.
stateful!(Branch {
    children: Vec<Box<dyn StatefulAction>>,
    in_control: SignalId,
    decision: Decision,
});
31
32impl Action for Branch {
33    #[inline]
34    fn in_signals(&self) -> BTreeSet<SignalId> {
35        let mut signals = BTreeSet::new();
36        signals.insert(self.in_control);
37        for c in self.children.iter() {
38            signals.extend(c.in_signals());
39        }
40        signals
41    }
42
43    #[inline]
44    fn out_signals(&self) -> BTreeSet<SignalId> {
45        let mut signals = BTreeSet::new();
46        for c in self.children.iter() {
47            signals.extend(c.out_signals());
48        }
49        signals
50    }
51
52    #[inline]
53    fn resources(&self, config: &Config) -> Vec<ResourceAddr> {
54        self.children
55            .iter()
56            .flat_map(|c| c.resources(config))
57            .unique()
58            .collect()
59    }
60
61    fn init(self) -> Result<Box<dyn Action>> {
62        if self.children.is_empty() {
63            Err(eyre!("Branch should have at least one child."))
64        } else if self.default >= self.children.len() {
65            Err(eyre!(
66                "Branch default value should be less than the number of its children."
67            ))
68        } else {
69            Ok(Box::new(self))
70        }
71    }
72
73    fn stateful(
74        &self,
75        io: &IoManager,
76        res: &ResourceManager,
77        config: &Config,
78        sync_writer: &QWriter<SyncSignal>,
79        async_writer: &QWriter<AsyncSignal>,
80    ) -> Result<Box<dyn StatefulAction>> {
81        let mut children = vec![];
82        for c in self.children.iter() {
83            children.push(c.stateful(io, res, config, sync_writer, async_writer)?);
84        }
85
86        Ok(Box::new(StatefulBranch {
87            done: false,
88            children,
89            in_control: self.in_control,
90            decision: Decision::Temporary(self.default),
91        }))
92    }
93}
94
95impl StatefulAction for StatefulBranch {
96    impl_stateful!();
97
98    #[inline]
99    fn props(&self) -> Props {
100        if let Decision::Final(i) = self.decision {
101            self.children[i].props()
102        } else {
103            self.children
104                .iter()
105                .fold(DEFAULT, |mut state, c| {
106                    let c = c.props();
107                    if c.visual() {
108                        state |= VISUAL;
109                    }
110                    if c.infinite() {
111                        state |= INFINITE;
112                    }
113                    state
114                })
115                .into()
116        }
117    }
118
119    #[inline]
120    fn start(
121        &mut self,
122        sync_writer: &mut QWriter<SyncSignal>,
123        async_writer: &mut QWriter<AsyncSignal>,
124        state: &State,
125    ) -> Result<Signal> {
126        let decision = if let Decision::Temporary(i) = self.decision {
127            i
128        } else {
129            return Err(eyre!("Tried to restart branch."));
130        };
131
132        let decision = match state.get(&self.in_control) {
133            Some(Value::Integer(i)) => {
134                if *i < self.children.len() as i128 {
135                    *i as usize
136                } else {
137                    return Err(eyre!("Branch request is out of bounds."));
138                }
139            }
140            Some(Value::Float(x)) => {
141                let x = f64_as_i64(*x).wrap_err("Non-integer number supplied to branch.")?;
142                if (0..self.children.len() as i64).contains(&x) {
143                    x as usize
144                } else {
145                    return Err(eyre!("Branch request is out of bounds."));
146                }
147            }
148            None => decision,
149            _ => return Err(eyre!("Branch control is in invalid state.")),
150        };
151
152        self.decision = Decision::Final(decision);
153        self.children[decision].start(sync_writer, async_writer, state)
154    }
155
156    #[inline]
157    fn update(
158        &mut self,
159        signal: &ActionSignal,
160        sync_writer: &mut QWriter<SyncSignal>,
161        async_writer: &mut QWriter<AsyncSignal>,
162        state: &State,
163    ) -> Result<Signal> {
164        if let Decision::Final(i) = self.decision {
165            let news = self.children[i].update(signal, sync_writer, async_writer, state)?;
166            if self.children[i].is_over()? {
167                self.done = true;
168            }
169            Ok(news)
170        } else {
171            Ok(Signal::none())
172        }
173    }
174
175    fn show(
176        &mut self,
177        ui: &mut egui::Ui,
178        sync_writer: &mut QWriter<SyncSignal>,
179        async_writer: &mut QWriter<AsyncSignal>,
180        state: &State,
181    ) -> Result<()> {
182        if let Decision::Final(i) = self.decision {
183            self.children[i].show(ui, sync_writer, async_writer, state)
184        } else {
185            Ok(())
186        }
187    }
188
189    #[inline]
190    fn stop(
191        &mut self,
192        sync_writer: &mut QWriter<SyncSignal>,
193        async_writer: &mut QWriter<AsyncSignal>,
194        state: &State,
195    ) -> Result<Signal> {
196        if let Decision::Final(i) = self.decision {
197            self.children[i].stop(sync_writer, async_writer, state)
198        } else {
199            Ok(Signal::none())
200        }
201    }
202}