logic_mesh/tokio_impl/engine/
single_threaded.rs

1// Copyright (c) 2022-2023, Radu Racariu.
2
3//!
4//! Single threaded engine implementation
5//!
6//! Spawn a local task for each block to be executed on the current thread.
7
8use std::{cell::Cell, cell::RefCell, collections::BTreeMap, rc::Rc};
9
10use anyhow::{anyhow, Result};
11use libhaystack::val::Value;
12
13use tokio::{
14    sync::mpsc::{self, Receiver, Sender},
15    task::LocalSet,
16};
17use uuid::Uuid;
18
19use super::message_dispatch::dispatch_message;
20use super::{block_pointer::BlockPropsPointer, schedule_block_on_engine};
21use crate::{
22    base::{
23        block::{
24            connect::{connect_input, connect_output, disconnect_block},
25            Block, BlockProps, BlockState,
26        },
27        engine::{
28            messages::{ChangeSource, EngineMessage, WatchMessage},
29            Engine,
30        },
31        program::data::{BlockData, LinkData},
32    },
33    blocks::registry::get_block,
34    tokio_impl::input::{Reader, Writer},
35};
36
37// The concrete trait for the block properties
38pub(super) trait BlockPropsType = BlockProps<Writer = Writer, Reader = Reader>;
39
40/// The concrete type for the engine messages
41pub type Messages = EngineMessage<Sender<WatchMessage>>;
42
43/// Creates single threaded execution environment for Blocks to be run on.
44///
45/// Each block would be executed inside a local task in the engine's local context.
46///
47pub struct SingleThreadedEngine {
48    /// Use to schedule task on the current thread
49    local: LocalSet,
50    /// Blocks registered with this engine, indexed by block id
51    block_props: BTreeMap<Uuid, Rc<Cell<BlockPropsPointer>>>,
52    /// Messaging channel used by external processes to control
53    /// and inspect this engines execution
54    sender: Sender<Messages>,
55    // Multi-producer single-consumer channel for receiving messages
56    receiver: Receiver<Messages>,
57    /// Senders used to reply to issued commands
58    /// Each sender would be associated to an external process
59    /// issuing commands to the engine.
60    pub(super) reply_senders: BTreeMap<uuid::Uuid, Sender<Messages>>,
61    /// Watchers for changes in block pins
62    pub(super) watchers: Rc<RefCell<BTreeMap<Uuid, Sender<WatchMessage>>>>,
63}
64
65impl Default for SingleThreadedEngine {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl Engine for SingleThreadedEngine {
72    type Writer = Writer;
73    type Reader = Reader;
74
75    type Channel = Sender<Messages>;
76
77    fn schedule<B: Block<Writer = Self::Writer, Reader = Self::Reader> + 'static>(
78        &mut self,
79        mut block: B,
80    ) {
81        let props = Rc::new(Cell::new(BlockPropsPointer::new(
82            &mut block as &mut dyn BlockPropsType,
83        )));
84        self.block_props.insert(*block.id(), props.clone());
85
86        let watchers = self.watchers.clone();
87
88        self.local.spawn_local(async move {
89            // Must do here also so we get the correct address
90            // of the moved block instance
91            props.set(BlockPropsPointer::new(
92                &mut block as &mut dyn BlockPropsType,
93            ));
94
95            // Tacks changes to block pins
96            let mut last_pin_values = BTreeMap::<String, Value>::new();
97
98            loop {
99                block.execute().await;
100
101                change_of_value_check(&watchers, &block, &mut last_pin_values);
102
103                if block.state() == BlockState::Terminated {
104                    break;
105                }
106            }
107        });
108    }
109
110    fn load_blocks_and_links(&mut self, blocks: &[BlockData], links: &[LinkData]) -> Result<()> {
111        blocks.iter().try_for_each(|block| -> Result<()> {
112            let id = Uuid::try_from(block.id.as_str()).ok();
113            if id.is_none() {
114                return Err(anyhow!("Invalid block id"));
115            }
116
117            let block = get_block(&block.name, Some(block.lib.clone()))
118                .ok_or_else(|| anyhow!("Block not found"))?;
119            schedule_block_on_engine(&block.desc, id, self)?;
120
121            Ok(())
122        })?;
123
124        links
125            .iter()
126            .try_for_each(|link| self.connect_blocks(link).map(|_| ()))
127    }
128
129    async fn run(&mut self) {
130        let mut is_paused = false;
131        loop {
132            let local_tasks = &self.local;
133            let mut engine_msg = None;
134
135            if !is_paused {
136                local_tasks
137                    .run_until(async {
138                        engine_msg = self.receiver.recv().await;
139                    })
140                    .await;
141            } else {
142                engine_msg = self.receiver.recv().await;
143            }
144
145            if let Some(message) = engine_msg {
146                if matches!(message, EngineMessage::Shutdown) {
147                    break;
148                } else if matches!(message, EngineMessage::Reset) {
149                    self.blocks_iter_mut().for_each(|block| {
150                        block.set_state(BlockState::Terminated);
151                    });
152
153                    self.block_props.clear();
154                    continue;
155                } else if matches!(message, EngineMessage::Pause) {
156                    is_paused = true;
157                    continue;
158                } else if matches!(message, EngineMessage::Resume) {
159                    is_paused = false;
160                    continue;
161                }
162
163                dispatch_message(self, message).await;
164            }
165        }
166    }
167
168    fn create_message_channel(
169        &mut self,
170        sender_id: uuid::Uuid,
171        sender_channel: Self::Channel,
172    ) -> Self::Channel {
173        self.reply_senders.insert(sender_id, sender_channel);
174
175        self.sender.clone()
176    }
177}
178
179impl SingleThreadedEngine {
180    /// Construct
181    pub fn new() -> Self {
182        // Create a multi-producer single-consumer channel with a buffer of 32 messages
183        let (sender, receiver) = mpsc::channel(32);
184
185        Self {
186            local: LocalSet::new(),
187            sender,
188            receiver,
189            block_props: BTreeMap::new(),
190            reply_senders: BTreeMap::new(),
191            watchers: Rc::default(),
192        }
193    }
194
195    /// Get a list of all the blocks that are currently
196    /// scheduled on this engine.
197    pub fn blocks(&self) -> Vec<&dyn BlockPropsType> {
198        self.blocks_iter_mut().map(|prop| &*prop).collect()
199    }
200
201    /// Get a list of all the blocks that are currently
202    /// scheduled on this engine.
203    pub fn blocks_mut(&self) -> Vec<&mut dyn BlockPropsType> {
204        self.blocks_iter_mut().collect()
205    }
206
207    pub(super) fn blocks_iter_mut(&self) -> impl Iterator<Item = &mut dyn BlockPropsType> {
208        self.block_props
209            .values()
210            .filter_map(|props| {
211                let props = props.get();
212                props.get()
213            })
214            .map(|prop| unsafe { &mut *prop })
215    }
216
217    pub(super) fn connect_blocks(&mut self, link_data: &LinkData) -> Result<LinkData> {
218        let (source_block_uuid, target_block_uuid) = (
219            Uuid::try_from(link_data.source_block_uuid.as_str())?,
220            Uuid::try_from(link_data.target_block_uuid.as_str())?,
221        );
222
223        let Some(source_block) = self.get_block_props_mut(&source_block_uuid) else {
224            return Err(anyhow!(
225                "Source block '{}' not found",
226                link_data.source_block_uuid
227            ));
228        };
229        let Some(target_block) = self.get_block_props_mut(&target_block_uuid) else {
230            return Err(anyhow!(
231                "Target block '{}' not found",
232                link_data.target_block_uuid
233            ));
234        };
235
236        let Some(target_input) = target_block.get_input_mut(&link_data.target_block_pin_name)
237        else {
238            return Err(anyhow!(
239                "Target input pin '{}' not found",
240                link_data.target_block_pin_name
241            ));
242        };
243
244        let link_id;
245        if let Some(source_input) = source_block.get_input_mut(&link_data.source_block_pin_name) {
246            link_id = connect_input(source_input, target_input).map_err(|err| anyhow!(err))?;
247
248            if let Some(val) = source_input.get_value() {
249                target_input
250                    .writer()
251                    .try_send(val.clone())
252                    .map_err(|err| anyhow!(err))?;
253            }
254            reset_connected_inputs(target_block, &link_data.target_block_pin_name)?;
255        } else if let Some(source_output) =
256            source_block.get_output_mut(&link_data.source_block_pin_name)
257        {
258            link_id = connect_output(source_output, target_input).map_err(|err| anyhow!(err))?;
259
260            // After connection, send the current value of the source output to the target input
261            if source_output.value().has_value() {
262                // Send the current value of the source output to the target input
263                target_input
264                    .writer()
265                    .try_send(source_output.value().clone())
266                    .map_err(|err| anyhow!(err))?;
267            }
268            reset_connected_inputs(target_block, &link_data.target_block_pin_name)?;
269        } else {
270            return Err(anyhow!(
271                "Source pin '{}' not found",
272                link_data.source_block_pin_name
273            ));
274        }
275        Ok(LinkData {
276            id: Some(link_id.to_string()),
277            ..link_data.clone()
278        })
279    }
280
281    pub(super) fn save_blocks_and_links(&mut self) -> Result<(Vec<BlockData>, Vec<LinkData>)> {
282        let blocks = self
283            .blocks_iter_mut()
284            .map(|block| BlockData {
285                id: block.id().to_string(),
286                name: block.name().to_string(),
287                dis: block.desc().dis.to_string(),
288                lib: block.desc().library.clone(),
289                category: block.desc().category.clone(),
290                ver: block.desc().ver.clone(),
291            })
292            .collect();
293
294        let mut links: Vec<LinkData> = Vec::new();
295        for block in self.blocks_iter_mut() {
296            for (pin_name, pin_links) in block.links() {
297                for link in pin_links {
298                    links.push(LinkData {
299                        id: Some(link.id().to_string()),
300                        source_block_pin_name: pin_name.to_string(),
301                        source_block_uuid: block.id().to_string(),
302                        target_block_pin_name: link.target_input().to_string(),
303                        target_block_uuid: link.target_block_id().to_string(),
304                    });
305                }
306            }
307        }
308
309        Ok((blocks, links))
310    }
311
312    pub(super) fn get_block_props_mut(
313        &self,
314        block_id: &Uuid,
315    ) -> Option<&mut (dyn BlockPropsType + 'static)> {
316        self.block_props.get(block_id).and_then(|ptr| {
317            let fat_ptr = (**ptr).get();
318            fat_ptr.get().map(|ptr| unsafe { &mut *ptr })
319        })
320    }
321
322    pub(super) fn add_block(
323        &mut self,
324        block_name: String,
325        block_id: Option<Uuid>,
326        lib: Option<String>,
327    ) -> Result<Uuid> {
328        let block =
329            get_block(block_name.as_str(), lib).ok_or_else(|| anyhow!("Block not found"))?;
330
331        schedule_block_on_engine(&block.desc, block_id, self)
332    }
333
334    pub(super) fn remove_block(&mut self, block_id: &Uuid) -> Result<Uuid> {
335        // Terminate the block
336        match self.get_block_props_mut(block_id) {
337            Some(block) => {
338                block.set_state(BlockState::Terminated);
339
340                disconnect_block(block, |id, name| {
341                    self.decrement_refresh_block_input(id, name)
342                });
343            }
344            None => return Err(anyhow!("Block not found")),
345        };
346
347        // Remove the block from any links
348        self.blocks_iter_mut().for_each(|block| {
349            if block.id() == block_id {
350                return;
351            }
352
353            let mut outs = block.outputs_mut();
354            outs.iter_mut().for_each(|output| {
355                output.remove_target_block_links(block_id);
356            });
357
358            let mut ins = block.inputs_mut();
359            ins.iter_mut().for_each(|input| {
360                input.remove_target_block_links(block_id);
361            });
362        });
363
364        // Remove the block from the block props
365        self.block_props.remove(block_id);
366
367        Ok(*block_id)
368    }
369
370    /// Decrements the connection count of the target block input.
371    /// Sends the current value of the input to itself (refresh current value.)
372    pub(crate) fn decrement_refresh_block_input(
373        &self,
374        block_id: &Uuid,
375        input_name: &str,
376    ) -> Option<usize> {
377        let target_block = self.get_block_props_mut(block_id);
378        target_block.and_then(|target_block| {
379            target_block.get_input_mut(input_name).map(|input| {
380                let cnt = input.decrement_conn();
381                let value = input.get_value().cloned();
382                input.writer().try_send(value.unwrap_or_default()).ok();
383
384                cnt
385            })
386        })
387    }
388}
389
390/// Implements the logic for checking if the watched block pins
391/// have changed, and if so, dispatches a message to the watch sender.
392fn change_of_value_check<B: Block + 'static>(
393    notification_channels: &Rc<RefCell<BTreeMap<Uuid, Sender<WatchMessage>>>>,
394    block: &B,
395    last_pin_values: &mut BTreeMap<String, Value>,
396) {
397    if notification_channels.borrow().is_empty() {
398        if !last_pin_values.is_empty() {
399            last_pin_values.clear();
400        }
401        return;
402    }
403
404    let mut changes = BTreeMap::<String, ChangeSource>::new();
405
406    block.outputs().iter().for_each(|output| {
407        let pin = output.desc().name.to_string();
408        let val = output.value();
409        if last_pin_values.get(&pin) != Some(val) {
410            changes.insert(pin.clone(), ChangeSource::Output(pin.clone(), val.clone()));
411            last_pin_values.insert(pin, val.clone());
412        }
413    });
414
415    block.inputs().iter().for_each(|input| {
416        let val = input.get_value();
417        if let Some(val) = val {
418            let pin = input.name().to_string();
419            if last_pin_values.get(&pin) != Some(val) {
420                changes.insert(pin.clone(), ChangeSource::Input(pin.clone(), val.clone()));
421                last_pin_values.insert(pin, val.clone());
422            }
423        }
424    });
425
426    if !changes.is_empty() {
427        for sender in notification_channels.borrow().values() {
428            let _ = sender.try_send(WatchMessage {
429                block_id: *block.id(),
430                changes: changes.clone(),
431            });
432        }
433    }
434}
435
436/// Implements the logic for resetting the target block inputs when a new link is created.
437/// This is needed because the target block would be monitoring the current set of inputs
438/// added before the link was created, and would not be aware of the newly created link.
439fn reset_connected_inputs(
440    target_block: &mut dyn BlockPropsType,
441    input_to_ignore: &str,
442) -> Result<()> {
443    // If the target block has other connected inputs, send the current value of one of
444    // them to itself (refresh current value.) in order to trigger the block to execute.
445    if let Some(a_connected_input) = target_block.inputs().iter().find_map(|input| {
446        if input.is_connected() && input.name() != input_to_ignore {
447            Some(input.name().to_string())
448        } else {
449            None
450        }
451    }) {
452        if let Some(target_input) = target_block.get_input_mut(&a_connected_input) {
453            if let Some(value) = target_input.get_value().cloned() {
454                target_input
455                    .writer()
456                    .try_send(value.clone())
457                    .map_err(|err| anyhow!(err))?;
458            }
459        }
460    }
461    Ok(())
462}
463
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod test {
    use std::{thread, time::Duration};

    use crate::base;
    use crate::blocks::{math::Add, misc::SineWave};
    use base::block::{BlockConnect, BlockProps};
    use base::engine::messages::EngineMessage::{InspectBlockReq, InspectBlockRes, Shutdown};

    use crate::tokio_impl::engine::single_threaded::SingleThreadedEngine;
    use base::engine::Engine;
    use tokio::{runtime::Runtime, sync::mpsc, time::sleep};
    use uuid::Uuid;

    /// Runs an `Add` block fed by two `SineWave` blocks, inspects the `Add`
    /// block from a separate thread, then shuts the engine down and checks
    /// the inspection result.
    #[tokio::test(flavor = "current_thread")]
    async fn engine_test() {
        use crate::base::block::connect::connect_output;

        let mut add1 = Add::new();
        let add_uuid = *add1.id();

        let mut sine1 = SineWave::new();

        sine1.amplitude.val = Some(3.into());
        sine1.freq.val = Some(200.into());
        connect_output(&mut sine1.out, add1.inputs_mut()[0]).expect("Connected");

        let mut sine2 = SineWave::new();
        sine2.amplitude.val = Some(7.into());
        sine2.freq.val = Some(400.into());

        sine2
            .connect_output("out", add1.inputs_mut()[1])
            .expect("Connected");

        let mut eng = SingleThreadedEngine::new();

        let (sender, mut receiver) = mpsc::channel(32);
        let channel_id = Uuid::new_v4();
        let engine_sender = eng.create_message_channel(channel_id, sender.clone());

        // The inspector always sends `Shutdown` before returning the raw
        // response. A wrong response therefore fails the assertions below,
        // instead of leaving `eng.run()` waiting forever.
        let inspector = thread::spawn(move || {
            let rt = Runtime::new().expect("RT");

            rt.block_on(async move {
                sleep(Duration::from_millis(300)).await;

                let _ = engine_sender
                    .send(InspectBlockReq(channel_id, add_uuid))
                    .await;

                let res = receiver.recv().await;

                let _ = engine_sender.send(Shutdown).await;

                res
            })
        });

        eng.schedule(add1);
        eng.schedule(sine1);
        eng.schedule(sine2);

        eng.run().await;

        // Propagate any panic from the inspector thread into this test.
        let res = inspector.join().expect("Inspector thread panicked");

        match res {
            Some(InspectBlockRes(Ok(data))) => {
                assert_eq!(data.id, add_uuid.to_string());
                assert_eq!(data.name, "Add");
                assert_eq!(data.inputs.len(), 16);
                assert_eq!(data.outputs.len(), 1);
            }
            other => panic!("Failed to find block: {:?}", other),
        }
    }
}