logic_mesh/tokio_impl/engine/
single_threaded.rs1use std::{cell::Cell, cell::RefCell, collections::BTreeMap, rc::Rc};
9
10use anyhow::{anyhow, Result};
11use libhaystack::val::Value;
12
13use tokio::{
14 sync::mpsc::{self, Receiver, Sender},
15 task::LocalSet,
16};
17use uuid::Uuid;
18
19use super::message_dispatch::dispatch_message;
20use super::{block_pointer::BlockPropsPointer, schedule_block_on_engine};
21use crate::{
22 base::{
23 block::{
24 connect::{connect_input, connect_output, disconnect_block},
25 Block, BlockProps, BlockState,
26 },
27 engine::{
28 messages::{ChangeSource, EngineMessage, WatchMessage},
29 Engine,
30 },
31 program::data::{BlockData, LinkData},
32 },
33 blocks::registry::get_block,
34 tokio_impl::input::{Reader, Writer},
35};
36
37pub(super) trait BlockPropsType = BlockProps<Writer = Writer, Reader = Reader>;
39
40pub type Messages = EngineMessage<Sender<WatchMessage>>;
42
43pub struct SingleThreadedEngine {
48 local: LocalSet,
50 block_props: BTreeMap<Uuid, Rc<Cell<BlockPropsPointer>>>,
52 sender: Sender<Messages>,
55 receiver: Receiver<Messages>,
57 pub(super) reply_senders: BTreeMap<uuid::Uuid, Sender<Messages>>,
61 pub(super) watchers: Rc<RefCell<BTreeMap<Uuid, Sender<WatchMessage>>>>,
63}
64
65impl Default for SingleThreadedEngine {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl Engine for SingleThreadedEngine {
72 type Writer = Writer;
73 type Reader = Reader;
74
75 type Channel = Sender<Messages>;
76
77 fn schedule<B: Block<Writer = Self::Writer, Reader = Self::Reader> + 'static>(
78 &mut self,
79 mut block: B,
80 ) {
81 let props = Rc::new(Cell::new(BlockPropsPointer::new(
82 &mut block as &mut dyn BlockPropsType,
83 )));
84 self.block_props.insert(*block.id(), props.clone());
85
86 let watchers = self.watchers.clone();
87
88 self.local.spawn_local(async move {
89 props.set(BlockPropsPointer::new(
92 &mut block as &mut dyn BlockPropsType,
93 ));
94
95 let mut last_pin_values = BTreeMap::<String, Value>::new();
97
98 loop {
99 block.execute().await;
100
101 change_of_value_check(&watchers, &block, &mut last_pin_values);
102
103 if block.state() == BlockState::Terminated {
104 break;
105 }
106 }
107 });
108 }
109
110 fn load_blocks_and_links(&mut self, blocks: &[BlockData], links: &[LinkData]) -> Result<()> {
111 blocks.iter().try_for_each(|block| -> Result<()> {
112 let id = Uuid::try_from(block.id.as_str()).ok();
113 if id.is_none() {
114 return Err(anyhow!("Invalid block id"));
115 }
116
117 let block = get_block(&block.name, Some(block.lib.clone()))
118 .ok_or_else(|| anyhow!("Block not found"))?;
119 schedule_block_on_engine(&block.desc, id, self)?;
120
121 Ok(())
122 })?;
123
124 links
125 .iter()
126 .try_for_each(|link| self.connect_blocks(link).map(|_| ()))
127 }
128
129 async fn run(&mut self) {
130 let mut is_paused = false;
131 loop {
132 let local_tasks = &self.local;
133 let mut engine_msg = None;
134
135 if !is_paused {
136 local_tasks
137 .run_until(async {
138 engine_msg = self.receiver.recv().await;
139 })
140 .await;
141 } else {
142 engine_msg = self.receiver.recv().await;
143 }
144
145 if let Some(message) = engine_msg {
146 if matches!(message, EngineMessage::Shutdown) {
147 break;
148 } else if matches!(message, EngineMessage::Reset) {
149 self.blocks_iter_mut().for_each(|block| {
150 block.set_state(BlockState::Terminated);
151 });
152
153 self.block_props.clear();
154 continue;
155 } else if matches!(message, EngineMessage::Pause) {
156 is_paused = true;
157 continue;
158 } else if matches!(message, EngineMessage::Resume) {
159 is_paused = false;
160 continue;
161 }
162
163 dispatch_message(self, message).await;
164 }
165 }
166 }
167
168 fn create_message_channel(
169 &mut self,
170 sender_id: uuid::Uuid,
171 sender_channel: Self::Channel,
172 ) -> Self::Channel {
173 self.reply_senders.insert(sender_id, sender_channel);
174
175 self.sender.clone()
176 }
177}
178
179impl SingleThreadedEngine {
180 pub fn new() -> Self {
182 let (sender, receiver) = mpsc::channel(32);
184
185 Self {
186 local: LocalSet::new(),
187 sender,
188 receiver,
189 block_props: BTreeMap::new(),
190 reply_senders: BTreeMap::new(),
191 watchers: Rc::default(),
192 }
193 }
194
195 pub fn blocks(&self) -> Vec<&dyn BlockPropsType> {
198 self.blocks_iter_mut().map(|prop| &*prop).collect()
199 }
200
201 pub fn blocks_mut(&self) -> Vec<&mut dyn BlockPropsType> {
204 self.blocks_iter_mut().collect()
205 }
206
207 pub(super) fn blocks_iter_mut(&self) -> impl Iterator<Item = &mut dyn BlockPropsType> {
208 self.block_props
209 .values()
210 .filter_map(|props| {
211 let props = props.get();
212 props.get()
213 })
214 .map(|prop| unsafe { &mut *prop })
215 }
216
217 pub(super) fn connect_blocks(&mut self, link_data: &LinkData) -> Result<LinkData> {
218 let (source_block_uuid, target_block_uuid) = (
219 Uuid::try_from(link_data.source_block_uuid.as_str())?,
220 Uuid::try_from(link_data.target_block_uuid.as_str())?,
221 );
222
223 let Some(source_block) = self.get_block_props_mut(&source_block_uuid) else {
224 return Err(anyhow!(
225 "Source block '{}' not found",
226 link_data.source_block_uuid
227 ));
228 };
229 let Some(target_block) = self.get_block_props_mut(&target_block_uuid) else {
230 return Err(anyhow!(
231 "Target block '{}' not found",
232 link_data.target_block_uuid
233 ));
234 };
235
236 let Some(target_input) = target_block.get_input_mut(&link_data.target_block_pin_name)
237 else {
238 return Err(anyhow!(
239 "Target input pin '{}' not found",
240 link_data.target_block_pin_name
241 ));
242 };
243
244 let link_id;
245 if let Some(source_input) = source_block.get_input_mut(&link_data.source_block_pin_name) {
246 link_id = connect_input(source_input, target_input).map_err(|err| anyhow!(err))?;
247
248 if let Some(val) = source_input.get_value() {
249 target_input
250 .writer()
251 .try_send(val.clone())
252 .map_err(|err| anyhow!(err))?;
253 }
254 reset_connected_inputs(target_block, &link_data.target_block_pin_name)?;
255 } else if let Some(source_output) =
256 source_block.get_output_mut(&link_data.source_block_pin_name)
257 {
258 link_id = connect_output(source_output, target_input).map_err(|err| anyhow!(err))?;
259
260 if source_output.value().has_value() {
262 target_input
264 .writer()
265 .try_send(source_output.value().clone())
266 .map_err(|err| anyhow!(err))?;
267 }
268 reset_connected_inputs(target_block, &link_data.target_block_pin_name)?;
269 } else {
270 return Err(anyhow!(
271 "Source pin '{}' not found",
272 link_data.source_block_pin_name
273 ));
274 }
275 Ok(LinkData {
276 id: Some(link_id.to_string()),
277 ..link_data.clone()
278 })
279 }
280
281 pub(super) fn save_blocks_and_links(&mut self) -> Result<(Vec<BlockData>, Vec<LinkData>)> {
282 let blocks = self
283 .blocks_iter_mut()
284 .map(|block| BlockData {
285 id: block.id().to_string(),
286 name: block.name().to_string(),
287 dis: block.desc().dis.to_string(),
288 lib: block.desc().library.clone(),
289 category: block.desc().category.clone(),
290 ver: block.desc().ver.clone(),
291 })
292 .collect();
293
294 let mut links: Vec<LinkData> = Vec::new();
295 for block in self.blocks_iter_mut() {
296 for (pin_name, pin_links) in block.links() {
297 for link in pin_links {
298 links.push(LinkData {
299 id: Some(link.id().to_string()),
300 source_block_pin_name: pin_name.to_string(),
301 source_block_uuid: block.id().to_string(),
302 target_block_pin_name: link.target_input().to_string(),
303 target_block_uuid: link.target_block_id().to_string(),
304 });
305 }
306 }
307 }
308
309 Ok((blocks, links))
310 }
311
312 pub(super) fn get_block_props_mut(
313 &self,
314 block_id: &Uuid,
315 ) -> Option<&mut (dyn BlockPropsType + 'static)> {
316 self.block_props.get(block_id).and_then(|ptr| {
317 let fat_ptr = (**ptr).get();
318 fat_ptr.get().map(|ptr| unsafe { &mut *ptr })
319 })
320 }
321
322 pub(super) fn add_block(
323 &mut self,
324 block_name: String,
325 block_id: Option<Uuid>,
326 lib: Option<String>,
327 ) -> Result<Uuid> {
328 let block =
329 get_block(block_name.as_str(), lib).ok_or_else(|| anyhow!("Block not found"))?;
330
331 schedule_block_on_engine(&block.desc, block_id, self)
332 }
333
334 pub(super) fn remove_block(&mut self, block_id: &Uuid) -> Result<Uuid> {
335 match self.get_block_props_mut(block_id) {
337 Some(block) => {
338 block.set_state(BlockState::Terminated);
339
340 disconnect_block(block, |id, name| {
341 self.decrement_refresh_block_input(id, name)
342 });
343 }
344 None => return Err(anyhow!("Block not found")),
345 };
346
347 self.blocks_iter_mut().for_each(|block| {
349 if block.id() == block_id {
350 return;
351 }
352
353 let mut outs = block.outputs_mut();
354 outs.iter_mut().for_each(|output| {
355 output.remove_target_block_links(block_id);
356 });
357
358 let mut ins = block.inputs_mut();
359 ins.iter_mut().for_each(|input| {
360 input.remove_target_block_links(block_id);
361 });
362 });
363
364 self.block_props.remove(block_id);
366
367 Ok(*block_id)
368 }
369
370 pub(crate) fn decrement_refresh_block_input(
373 &self,
374 block_id: &Uuid,
375 input_name: &str,
376 ) -> Option<usize> {
377 let target_block = self.get_block_props_mut(block_id);
378 target_block.and_then(|target_block| {
379 target_block.get_input_mut(input_name).map(|input| {
380 let cnt = input.decrement_conn();
381 let value = input.get_value().cloned();
382 input.writer().try_send(value.unwrap_or_default()).ok();
383
384 cnt
385 })
386 })
387 }
388}
389
390fn change_of_value_check<B: Block + 'static>(
393 notification_channels: &Rc<RefCell<BTreeMap<Uuid, Sender<WatchMessage>>>>,
394 block: &B,
395 last_pin_values: &mut BTreeMap<String, Value>,
396) {
397 if notification_channels.borrow().is_empty() {
398 if !last_pin_values.is_empty() {
399 last_pin_values.clear();
400 }
401 return;
402 }
403
404 let mut changes = BTreeMap::<String, ChangeSource>::new();
405
406 block.outputs().iter().for_each(|output| {
407 let pin = output.desc().name.to_string();
408 let val = output.value();
409 if last_pin_values.get(&pin) != Some(val) {
410 changes.insert(pin.clone(), ChangeSource::Output(pin.clone(), val.clone()));
411 last_pin_values.insert(pin, val.clone());
412 }
413 });
414
415 block.inputs().iter().for_each(|input| {
416 let val = input.get_value();
417 if let Some(val) = val {
418 let pin = input.name().to_string();
419 if last_pin_values.get(&pin) != Some(val) {
420 changes.insert(pin.clone(), ChangeSource::Input(pin.clone(), val.clone()));
421 last_pin_values.insert(pin, val.clone());
422 }
423 }
424 });
425
426 if !changes.is_empty() {
427 for sender in notification_channels.borrow().values() {
428 let _ = sender.try_send(WatchMessage {
429 block_id: *block.id(),
430 changes: changes.clone(),
431 });
432 }
433 }
434}
435
436fn reset_connected_inputs(
440 target_block: &mut dyn BlockPropsType,
441 input_to_ignore: &str,
442) -> Result<()> {
443 if let Some(a_connected_input) = target_block.inputs().iter().find_map(|input| {
446 if input.is_connected() && input.name() != input_to_ignore {
447 Some(input.name().to_string())
448 } else {
449 None
450 }
451 }) {
452 if let Some(target_input) = target_block.get_input_mut(&a_connected_input) {
453 if let Some(value) = target_input.get_value().cloned() {
454 target_input
455 .writer()
456 .try_send(value.clone())
457 .map_err(|err| anyhow!(err))?;
458 }
459 }
460 }
461 Ok(())
462}
463
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod test {
    use std::{thread, time::Duration};

    use crate::base;
    use crate::blocks::{math::Add, misc::SineWave};
    use base::block::{BlockConnect, BlockProps};
    use base::engine::messages::EngineMessage::{InspectBlockReq, InspectBlockRes, Shutdown};

    use crate::tokio_impl::engine::single_threaded::SingleThreadedEngine;
    use base::engine::Engine;
    use tokio::{runtime::Runtime, sync::mpsc, time::sleep};
    use uuid::Uuid;

    /// Schedules two sine wave blocks feeding an adder, inspects the adder
    /// through the engine's message channel from another thread, then shuts
    /// the engine down.
    #[tokio::test(flavor = "current_thread")]
    async fn engine_test() {
        use crate::base::block::connect::connect_output;

        let mut add1 = Add::new();
        let add_uuid = *add1.id();

        let mut sine1 = SineWave::new();

        sine1.amplitude.val = Some(3.into());
        sine1.freq.val = Some(200.into());
        connect_output(&mut sine1.out, add1.inputs_mut()[0]).expect("Connected");

        let mut sine2 = SineWave::new();
        sine2.amplitude.val = Some(7.into());
        sine2.freq.val = Some(400.into());

        sine2
            .connect_output("out", add1.inputs_mut()[1])
            .expect("Connected");

        let mut eng = SingleThreadedEngine::new();

        let (sender, mut receiver) = mpsc::channel(32);
        let channel_id = Uuid::new_v4();
        let engine_sender = eng.create_message_channel(channel_id, sender.clone());

        // The inspector only collects the response. Assertions are made on
        // the test thread so a failure fails the test instead of being lost
        // on the helper thread.
        let inspector = thread::spawn(move || {
            let rt = Runtime::new().expect("RT");

            rt.block_on(async move {
                sleep(Duration::from_millis(300)).await;

                let _ = engine_sender
                    .send(InspectBlockReq(channel_id, add_uuid))
                    .await;

                let response = receiver.recv().await;

                // Always stop the engine, whatever the response was;
                // otherwise `eng.run()` below would never return.
                let _ = engine_sender.send(Shutdown).await;

                response
            })
        });

        eng.schedule(add1);
        eng.schedule(sine1);
        eng.schedule(sine2);

        eng.run().await;

        let response = inspector.join().expect("inspector thread panicked");

        match response {
            Some(InspectBlockRes(Ok(data))) => {
                assert_eq!(data.id, add_uuid.to_string());
                assert_eq!(data.name, "Add");
                assert_eq!(data.inputs.len(), 16);
                assert_eq!(data.outputs.len(), 1);
            }
            other => panic!("Failed to find block: {:?}", other),
        }
    }
}