1use bevy_ecs::{
19 change_detection::Mut,
20 prelude::{Commands, Entity, EntityRef, Query, World},
21 query::QueryEntityError,
22 system::{SystemParam, SystemState},
23};
24
25use std::{
26 any::TypeId,
27 collections::HashMap,
28 ops::RangeBounds,
29 sync::{Arc, Mutex, OnceLock},
30};
31
32use thiserror::Error as ThisError;
33
34use crate::{GateState, InputSlot, NotifyBufferUpdate, OperationError};
35
36mod any_buffer;
37pub use any_buffer::*;
38
39mod buffer_access_lifecycle;
40pub use buffer_access_lifecycle::BufferKeyLifecycle;
41pub(crate) use buffer_access_lifecycle::*;
42
43mod buffer_key_builder;
44pub use buffer_key_builder::*;
45
46mod buffer_gate;
47pub use buffer_gate::*;
48
49mod buffer_map;
50pub use buffer_map::*;
51
52mod buffer_storage;
53pub(crate) use buffer_storage::*;
54
55mod buffering;
56pub use buffering::*;
57
58mod bufferable;
59pub use bufferable::*;
60
61mod manage_buffer;
62pub use manage_buffer::*;
63
64#[cfg(feature = "diagram")]
65mod json_buffer;
66#[cfg(feature = "diagram")]
67pub use json_buffer::*;
68
69mod fetch_from_buffer;
70pub use fetch_from_buffer::*;
71
72pub struct Buffer<T> {
76 pub(crate) location: BufferLocation,
77 pub(crate) _ignore: std::marker::PhantomData<fn(T)>,
78}
79
80impl<T: 'static + Send + Sync> Buffer<T> {
81 pub fn join_by_cloning(self) -> CloneFromBuffer<T>
87 where
88 T: Clone,
89 {
90 CloneFromBuffer::new(self.location)
91 }
92
93 pub fn input_slot(self) -> InputSlot<T> {
95 InputSlot::new(self.scope(), self.id())
96 }
97
98 pub fn id(&self) -> Entity {
100 self.location.source
101 }
102
103 pub fn scope(&self) -> Entity {
105 self.location.scope
106 }
107
108 pub fn location(&self) -> BufferLocation {
110 self.location
111 }
112}
113
114impl<T> Clone for Buffer<T> {
115 fn clone(&self) -> Self {
116 *self
117 }
118}
119
120impl<T> Copy for Buffer<T> {}
121
/// Where a buffer lives: the scope it belongs to and the buffer's own entity.
#[derive(Clone, Copy, Debug)]
pub struct BufferLocation {
    /// Entity of the scope this buffer belongs to (used as the scope of the
    /// buffer's `InputSlot`).
    pub scope: Entity,
    /// The buffer's own entity, returned by `Buffer::id`. Buffer keys are
    /// built from this entity and use it to look up the buffer's storage.
    pub source: Entity,
}
132
133#[derive(Clone)]
134pub struct CloneFromBuffer<T: Clone + Send + Sync + 'static> {
135 location: BufferLocation,
136 _ignore: std::marker::PhantomData<fn(T)>,
137}
138
139impl<T: Clone + Send + Sync + 'static> Copy for CloneFromBuffer<T> {}
140
141impl<T: Clone + Send + Sync + 'static> CloneFromBuffer<T> {
142 pub fn input_slot(self) -> InputSlot<T> {
144 InputSlot::new(self.scope(), self.id())
145 }
146
147 pub fn id(&self) -> Entity {
149 self.location.source
150 }
151
152 pub fn scope(&self) -> Entity {
154 self.location.scope
155 }
156
157 pub fn location(&self) -> BufferLocation {
159 self.location
160 }
161
162 #[must_use]
167 pub fn join_by_pulling(self) -> Buffer<T> {
168 Buffer {
169 location: self.location,
170 _ignore: Default::default(),
171 }
172 }
173
174 fn new(location: BufferLocation) -> Self {
175 Self::register_clone_for_join();
176 Self {
177 location,
178 _ignore: Default::default(),
179 }
180 }
181
182 pub fn register_clone_for_join() {
186 static REGISTER_CLONE: OnceLock<Mutex<HashMap<TypeId, ()>>> = OnceLock::new();
187 let register_clone = REGISTER_CLONE.get_or_init(|| Mutex::default());
188
189 let mut register_mut = register_clone.lock().unwrap();
192 register_mut.entry(TypeId::of::<T>()).or_insert_with(|| {
193 let interface = AnyBuffer::interface_for::<T>();
194 interface.register_cloning(
195 clone_for_any_join::<T>,
196 &(clone_for_join::<T> as FetchFromBufferFn<T>),
197 );
198 interface.register_buffer_downcast(
199 TypeId::of::<CloneFromBuffer<T>>(),
200 Box::new(|buffer: AnyBuffer| {
201 Ok(Box::new(CloneFromBuffer::<T>::new(buffer.location)))
202 }),
203 );
204 });
205 }
206}
207
208fn clone_for_any_join<T: 'static + Send + Sync + Clone>(
209 entity_ref: &EntityRef,
210 session: Entity,
211) -> Result<AnyMessageBox, OperationError> {
212 entity_ref
228 .clone_from_buffer::<T>(session)
229 .map(to_any_message)
230}
231
232impl<T: Clone + Send + Sync> From<CloneFromBuffer<T>> for Buffer<T> {
233 fn from(value: CloneFromBuffer<T>) -> Self {
234 Buffer {
235 location: value.location,
236 _ignore: Default::default(),
237 }
238 }
239}
240
/// Settings used when creating a buffer. Currently this only holds the
/// buffer's [`RetentionPolicy`].
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Default, Clone, Copy, Debug)]
pub struct BufferSettings {
    retention: RetentionPolicy,
}

impl BufferSettings {
    /// Settings with the given retention policy.
    pub fn new(retention: RetentionPolicy) -> Self {
        Self { retention }
    }

    /// Shorthand for `RetentionPolicy::KeepLast(n)`.
    pub fn keep_last(n: usize) -> Self {
        Self {
            retention: RetentionPolicy::KeepLast(n),
        }
    }

    /// Shorthand for `RetentionPolicy::KeepFirst(n)`.
    pub fn keep_first(n: usize) -> Self {
        Self {
            retention: RetentionPolicy::KeepFirst(n),
        }
    }

    /// Shorthand for `RetentionPolicy::KeepAll`.
    pub fn keep_all() -> Self {
        Self {
            retention: RetentionPolicy::KeepAll,
        }
    }

    /// The current retention policy.
    pub fn retention(&self) -> RetentionPolicy {
        self.retention
    }

    /// Mutable access to the retention policy.
    pub fn retention_mut(&mut self) -> &mut RetentionPolicy {
        &mut self.retention
    }
}

/// How many messages a buffer retains per session.
///
/// NOTE(review): the exact eviction behaviour lives in `BufferStorage`; the
/// variant descriptions below follow the names and should be confirmed there.
/// Variant order is significant: it defines the derived `Ord`.
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum RetentionPolicy {
    /// Keep at most `n` messages, presumably favouring the most recent ones.
    KeepLast(usize),
    /// Keep at most `n` messages, presumably favouring the earliest ones.
    KeepFirst(usize),
    /// Keep every message.
    KeepAll,
}

impl Default for RetentionPolicy {
    /// The default policy is `KeepLast(1)`.
    fn default() -> Self {
        RetentionPolicy::KeepLast(1)
    }
}
312
313pub struct BufferKey<T> {
321 tag: BufferKeyTag,
322 _ignore: std::marker::PhantomData<fn(T)>,
323}
324
325impl<T> Clone for BufferKey<T> {
326 fn clone(&self) -> Self {
327 Self {
328 tag: self.tag.clone(),
329 _ignore: Default::default(),
330 }
331 }
332}
333
334impl<T> BufferKey<T> {
335 pub fn buffer(&self) -> Entity {
337 self.tag.buffer
338 }
339
340 pub fn session(&self) -> Entity {
342 self.tag.session
343 }
344
345 pub fn tag(&self) -> &BufferKeyTag {
346 &self.tag
347 }
348}
349
350impl<T: 'static + Send + Sync> BufferKeyLifecycle for BufferKey<T> {
351 type TargetBuffer = Buffer<T>;
352
353 fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self {
354 BufferKey {
355 tag: builder.make_tag(buffer.id()),
356 _ignore: Default::default(),
357 }
358 }
359
360 fn is_in_use(&self) -> bool {
361 self.tag.is_in_use()
362 }
363
364 fn deep_clone(&self) -> Self {
365 Self {
366 tag: self.tag.deep_clone(),
367 _ignore: Default::default(),
368 }
369 }
370}
371
372impl<T> std::fmt::Debug for BufferKey<T> {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 f.debug_struct("BufferKey")
375 .field("message_type_name", &std::any::type_name::<T>())
376 .field("tag", &self.tag)
377 .finish()
378 }
379}
380
381#[derive(Clone)]
384pub struct BufferKeyTag {
385 pub buffer: Entity,
386 pub session: Entity,
387 pub accessor: Entity,
388 pub lifecycle: Option<Arc<BufferAccessLifecycle>>,
389}
390
391impl BufferKeyTag {
392 pub fn is_in_use(&self) -> bool {
393 self.lifecycle.as_ref().is_some_and(|l| l.is_in_use())
394 }
395
396 pub fn deep_clone(&self) -> Self {
397 let mut deep = self.clone();
398 deep.lifecycle = self
399 .lifecycle
400 .as_ref()
401 .map(|l| Arc::new(l.as_ref().clone()));
402 deep
403 }
404}
405
406impl std::fmt::Debug for BufferKeyTag {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 f.debug_struct("BufferKeyTag")
409 .field("buffer", &self.buffer)
410 .field("session", &self.session)
411 .field("accessor", &self.accessor)
412 .field("in_use", &self.is_in_use())
413 .finish()
414 }
415}
416
417#[derive(SystemParam)]
422pub struct BufferAccess<'w, 's, T>
423where
424 T: 'static + Send + Sync,
425{
426 query: Query<'w, 's, &'static BufferStorage<T>>,
427}
428
429impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> {
430 pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
431 let session = key.session();
432 self.query
433 .get(key.buffer())
434 .map(|storage| BufferView { storage, session })
435 }
436
437 pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
438 self.get(key).ok().map(|view| view.newest()).flatten()
439 }
440}
441
442#[derive(SystemParam)]
447pub struct BufferAccessMut<'w, 's, T>
448where
449 T: 'static + Send + Sync,
450{
451 query: Query<'w, 's, &'static mut BufferStorage<T>>,
452 commands: Commands<'w, 's>,
453}
454
455impl<'w, 's, T> BufferAccessMut<'w, 's, T>
456where
457 T: 'static + Send + Sync,
458{
459 pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
460 let session = key.session();
461 self.query
462 .get(key.buffer())
463 .map(|storage| BufferView { storage, session })
464 }
465
466 pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
467 self.get(key).ok().map(|view| view.newest()).flatten()
468 }
469
470 pub fn get_mut<'a>(
471 &'a mut self,
472 key: &BufferKey<T>,
473 ) -> Result<BufferMut<'w, 's, 'a, T>, QueryEntityError> {
474 let buffer = key.buffer();
475 let session = key.session();
476 let accessor = key.tag.accessor;
477 self.query
478 .get_mut(key.buffer())
479 .map(|storage| BufferMut::new(storage, buffer, session, accessor, &mut self.commands))
480 }
481}
482
483pub trait BufferWorldAccess {
485 fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
490 where
491 T: 'static + Send + Sync;
492
493 fn buffer_gate_view(
495 &self,
496 key: impl Into<AnyBufferKey>,
497 ) -> Result<BufferGateView<'_>, BufferError>;
498
499 fn buffer_mut<T, U>(
504 &mut self,
505 key: &BufferKey<T>,
506 f: impl FnOnce(BufferMut<T>) -> U,
507 ) -> Result<U, BufferError>
508 where
509 T: 'static + Send + Sync;
510
511 fn buffer_gate_mut<U>(
516 &mut self,
517 key: impl Into<AnyBufferKey>,
518 f: impl FnOnce(BufferGateMut) -> U,
519 ) -> Result<U, BufferError>;
520}
521
522impl BufferWorldAccess for World {
523 fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
524 where
525 T: 'static + Send + Sync,
526 {
527 let buffer_ref = self
528 .get_entity(key.tag.buffer)
529 .map_err(|_| BufferError::BufferMissing)?;
530 let storage = buffer_ref
531 .get::<BufferStorage<T>>()
532 .ok_or(BufferError::BufferMissing)?;
533 Ok(BufferView {
534 storage,
535 session: key.tag.session,
536 })
537 }
538
539 fn buffer_gate_view(
540 &self,
541 key: impl Into<AnyBufferKey>,
542 ) -> Result<BufferGateView<'_>, BufferError> {
543 let key: AnyBufferKey = key.into();
544 let buffer_ref = self
545 .get_entity(key.tag.buffer)
546 .or(Err(BufferError::BufferMissing))?;
547 let gate = buffer_ref
548 .get::<GateState>()
549 .ok_or(BufferError::BufferMissing)?;
550 Ok(BufferGateView {
551 gate,
552 session: key.tag.session,
553 })
554 }
555
556 fn buffer_mut<T, U>(
557 &mut self,
558 key: &BufferKey<T>,
559 f: impl FnOnce(BufferMut<T>) -> U,
560 ) -> Result<U, BufferError>
561 where
562 T: 'static + Send + Sync,
563 {
564 let mut state = SystemState::<BufferAccessMut<T>>::new(self);
565 let mut buffer_access_mut = state.get_mut(self);
566 let buffer_mut = buffer_access_mut
567 .get_mut(key)
568 .map_err(|_| BufferError::BufferMissing)?;
569 Ok(f(buffer_mut))
570 }
571
572 fn buffer_gate_mut<U>(
573 &mut self,
574 key: impl Into<AnyBufferKey>,
575 f: impl FnOnce(BufferGateMut) -> U,
576 ) -> Result<U, BufferError> {
577 let mut state = SystemState::<BufferGateAccessMut>::new(self);
578 let mut buffer_gate_access_mut = state.get_mut(self);
579 let buffer_mut = buffer_gate_access_mut
580 .get_mut(key)
581 .map_err(|_| BufferError::BufferMissing)?;
582 Ok(f(buffer_mut))
583 }
584}
585
586pub struct BufferView<'a, T>
588where
589 T: 'static + Send + Sync,
590{
591 storage: &'a BufferStorage<T>,
592 session: Entity,
593}
594
595impl<'a, T> BufferView<'a, T>
596where
597 T: 'static + Send + Sync,
598{
599 pub fn iter(&self) -> IterBufferView<'a, T> {
601 self.storage.iter(self.session)
602 }
603
604 pub fn oldest(&self) -> Option<&'a T> {
606 self.storage.oldest(self.session)
607 }
608
609 pub fn newest(&self) -> Option<&'a T> {
611 self.storage.newest(self.session)
612 }
613
614 pub fn get(&self, index: usize) -> Option<&'a T> {
617 self.storage.get(self.session, index)
618 }
619
620 pub fn len(&self) -> usize {
622 self.storage.count(self.session)
623 }
624
625 pub fn is_empty(&self) -> bool {
627 self.len() == 0
628 }
629}
630
631pub struct BufferMut<'w, 's, 'a, T>
633where
634 T: 'static + Send + Sync,
635{
636 storage: Mut<'a, BufferStorage<T>>,
637 buffer: Entity,
638 session: Entity,
639 accessor: Option<Entity>,
640 commands: &'a mut Commands<'w, 's>,
641 modified: bool,
642}
643
644impl<'w, 's, 'a, T> BufferMut<'w, 's, 'a, T>
645where
646 T: 'static + Send + Sync,
647{
648 pub fn allow_closed_loops(mut self) -> Self {
663 self.accessor = None;
664 self
665 }
666
667 pub fn iter(&self) -> IterBufferView<'_, T> {
669 self.storage.iter(self.session)
670 }
671
672 pub fn oldest(&self) -> Option<&T> {
674 self.storage.oldest(self.session)
675 }
676
677 pub fn newest(&self) -> Option<&T> {
679 self.storage.newest(self.session)
680 }
681
682 pub fn get(&self, index: usize) -> Option<&T> {
685 self.storage.get(self.session, index)
686 }
687
688 pub fn len(&self) -> usize {
690 self.storage.count(self.session)
691 }
692
693 pub fn is_empty(&self) -> bool {
695 self.len() == 0
696 }
697
698 pub fn iter_mut(&mut self) -> IterBufferMut<'_, T> {
700 self.modified = true;
701 self.storage.iter_mut(self.session)
702 }
703
704 pub fn oldest_mut(&mut self) -> Option<&mut T> {
706 self.modified = true;
707 self.storage.oldest_mut(self.session)
708 }
709
710 pub fn newest_mut(&mut self) -> Option<&mut T> {
712 self.modified = true;
713 self.storage.newest_mut(self.session)
714 }
715
716 pub fn newest_mut_or_default(&mut self) -> Option<&mut T>
722 where
723 T: Default,
724 {
725 self.newest_mut_or_else(|| T::default())
726 }
727
728 pub fn newest_mut_or_else(&mut self, f: impl FnOnce() -> T) -> Option<&mut T> {
734 self.modified = true;
735 self.storage.newest_mut_or_else(self.session, f)
736 }
737
738 pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
741 self.modified = true;
742 self.storage.get_mut(self.session, index)
743 }
744
745 pub fn drain<R>(&mut self, range: R) -> DrainBuffer<'_, T>
747 where
748 R: RangeBounds<usize>,
749 {
750 self.modified = true;
751 self.storage.drain(self.session, range)
752 }
753
754 pub fn pull(&mut self) -> Option<T> {
756 self.modified = true;
757 self.storage.pull(self.session)
758 }
759
760 pub fn pull_newest(&mut self) -> Option<T> {
763 self.modified = true;
764 self.storage.pull_newest(self.session)
765 }
766
767 pub fn push(&mut self, value: T) -> Option<T> {
770 self.modified = true;
771 self.storage.push(self.session, value)
772 }
773
774 pub fn push_as_oldest(&mut self, value: T) -> Option<T> {
778 self.modified = true;
779 self.storage.push_as_oldest(self.session, value)
780 }
781
782 pub fn pulse(&mut self) {
786 self.modified = true;
787 }
788
789 fn new(
790 storage: Mut<'a, BufferStorage<T>>,
791 buffer: Entity,
792 session: Entity,
793 accessor: Entity,
794 commands: &'a mut Commands<'w, 's>,
795 ) -> Self {
796 Self {
797 storage,
798 buffer,
799 session,
800 accessor: Some(accessor),
801 commands,
802 modified: false,
803 }
804 }
805}
806
807impl<'w, 's, 'a, T> Drop for BufferMut<'w, 's, 'a, T>
808where
809 T: 'static + Send + Sync,
810{
811 fn drop(&mut self) {
812 if self.modified {
813 self.commands.queue(NotifyBufferUpdate::new(
814 self.buffer,
815 self.session,
816 self.accessor,
817 ));
818 }
819 }
820}
821
/// Errors returned when accessing a buffer through a key, e.g. by the
/// `BufferWorldAccess` methods.
#[derive(ThisError, Debug, Clone)]
pub enum BufferError {
    /// The key's buffer entity does not exist, or does not carry the expected
    /// buffer component.
    #[error("The key was unable to identify a buffer")]
    BufferMissing,
}
827
#[cfg(test)]
mod tests {
    use crate::{prelude::*, testing::*, AddBufferToMap, Gate};
    use std::future::Future;

    /// Exercises reading and pulling buffered values through keys delivered by
    /// `listen`, `with_access`, and `create_buffer_access`.
    #[test]
    fn test_buffer_key_access() {
        let mut context = TestingContext::minimal_plugins();

        let add_buffers_by_pull_cb = add_buffers_by_pull.into_blocking_callback();
        let add_from_buffer_cb = add_from_buffer.into_blocking_callback();
        let multiply_buffers_by_copy_cb = multiply_buffers_by_copy.into_blocking_callback();

        // Read both values through their keys without pulling: 2 * 3 = 6.
        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
            scope
                .input
                .chain(builder)
                .unzip()
                .listen(builder)
                .then(multiply_buffers_by_copy_cb)
                .connect(scope.terminate);
        });

        let mut promise =
            context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());

        context.run_with_conditions(&mut promise, Duration::from_secs(2));
        assert!(promise.take().available().is_some_and(|value| value == 6.0));
        assert!(context.no_unhandled_errors());

        // Pull both values once both are present: 4 + 5 = 9. While either
        // buffer is empty, `add_buffers_by_pull` returns None, which is disposed.
        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
            scope
                .input
                .chain(builder)
                .unzip()
                .listen(builder)
                .then(add_buffers_by_pull_cb)
                .dispose_on_none()
                .connect(scope.terminate);
        });

        let mut promise =
            context.command(|commands| commands.request((4.0, 5.0), workflow).take_response());

        context.run_with_conditions(&mut promise, Duration::from_secs(2));
        assert!(promise.take().available().is_some_and(|value| value == 9.0));
        assert!(context.no_unhandled_errors());

        // The first value goes to an adder that pulls from a buffer holding the
        // second value. An Ok sum pulls from the buffer once more and
        // terminates with that result; an Err loops back into the adder. With
        // input (2.0, 3.0), the test expects Err(5.0).
        let workflow =
            context.spawn_io_workflow(|scope: Scope<(f64, f64), Result<f64, f64>>, builder| {
                let (branch_to_adder, branch_to_buffer) = scope.input.chain(builder).unzip();
                let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
                builder.connect(branch_to_buffer, buffer.input_slot());

                let adder_node = branch_to_adder
                    .chain(builder)
                    .with_access(buffer)
                    .then_node(add_from_buffer_cb.clone());

                adder_node.output.chain(builder).fork_result(
                    |chain| {
                        chain
                            .with_access(buffer)
                            .then(add_from_buffer_cb.clone())
                            .connect(scope.terminate)
                    },
                    |chain| chain.with_access(buffer).connect(adder_node.input),
                );
            });

        let mut promise =
            context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());

        context.run_with_conditions(&mut promise, Duration::from_secs(2));
        assert!(promise
            .take()
            .available()
            .is_some_and(|value| value.is_err_and(|n| n == 5.0)));
        assert!(context.no_unhandled_errors());

        // Same topology as above, but using explicit `create_buffer_access`
        // nodes instead of `with_access`; the expected result is the same.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let (branch_to_adder, branch_to_buffer) = scope.input.chain(builder).unzip();
            let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
            builder.connect(branch_to_buffer, buffer.input_slot());

            let access = builder.create_buffer_access(buffer);
            builder.connect(branch_to_adder, access.input);
            access
                .output
                .chain(builder)
                .then(add_from_buffer_cb.clone())
                .fork_result(
                    |ok| {
                        let (output, builder) = ok.unpack();
                        let second_access = builder.create_buffer_access(buffer);
                        builder.connect(output, second_access.input);
                        second_access
                            .output
                            .chain(builder)
                            .then(add_from_buffer_cb.clone())
                            .connect(scope.terminate);
                    },
                    |err| err.connect(access.input),
                );
        });

        let mut promise =
            context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());

        context.run_with_conditions(&mut promise, Duration::from_secs(2));
        assert!(promise
            .take()
            .available()
            .is_some_and(|value| value.is_err_and(|n| n == 5.0)));
        assert!(context.no_unhandled_errors());
    }

    /// Add `lhs` to a value pulled from the keyed buffer. Returns `Err(lhs)` if
    /// the buffer cannot be accessed or holds nothing to pull.
    fn add_from_buffer(
        In((lhs, key)): In<(f64, BufferKey<f64>)>,
        mut access: BufferAccessMut<f64>,
    ) -> Result<f64, f64> {
        let rhs = access.get_mut(&key).map_err(|_| lhs)?.pull().ok_or(lhs)?;
        Ok(lhs + rhs)
    }

    /// Multiply the oldest values of two buffers without removing them.
    fn multiply_buffers_by_copy(
        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
        access: BufferAccess<f64>,
    ) -> f64 {
        *access.get(&key_a).unwrap().oldest().unwrap()
            * *access.get(&key_b).unwrap().oldest().unwrap()
    }

    /// Pull one value from each buffer and add them, or return None (pulling
    /// nothing) if either buffer is still empty.
    fn add_buffers_by_pull(
        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
        mut access: BufferAccessMut<f64>,
    ) -> Option<f64> {
        if access.get(&key_a).unwrap().is_empty() {
            return None;
        }

        if access.get(&key_b).unwrap().is_empty() {
            return None;
        }

        let rhs = access.get_mut(&key_a).unwrap().pull().unwrap();
        let lhs = access.get_mut(&key_b).unwrap().pull().unwrap();
        Some(rhs + lhs)
    }

    /// Each workflow decrements a `Register` through four stages (two blocking,
    /// two async). A stage that receives `in_slot == 0` pushes the register
    /// into the buffer, and the listener then terminates with it. Inputs 0..=3
    /// reach zero in time and are expected to succeed. Inputs 4..=6 never get
    /// pushed, and the test expects those requests to be cancelled.
    #[test]
    fn test_buffer_key_lifecycle() {
        let mut context = TestingContext::minimal_plugins();

        // Keys are obtained fresh at each stage with `with_access`.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());

            builder
                .listen(buffer)
                .then(pull_register_from_buffer.into_blocking_callback())
                .dispose_on_none()
                .connect(scope.terminate);

            let decrement_register_cb = decrement_register.into_blocking_callback();
            let async_decrement_register_cb = async_decrement_register.as_callback();
            scope
                .input
                .chain(builder)
                .with_access(buffer)
                .then(decrement_register_cb.clone())
                .with_access(buffer)
                .then(async_decrement_register_cb.clone())
                .dispose_on_none()
                .with_access(buffer)
                .then(decrement_register_cb.clone())
                .with_access(buffer)
                .then(async_decrement_register_cb)
                .unused();
        });

        run_register_test(workflow, 0, true, &mut context);
        run_register_test(workflow, 1, true, &mut context);
        run_register_test(workflow, 2, true, &mut context);
        run_register_test(workflow, 3, true, &mut context);
        run_register_test(workflow, 4, false, &mut context);
        run_register_test(workflow, 5, false, &mut context);
        run_register_test(workflow, 6, false, &mut context);

        // Here a single key is obtained once and passed along from stage to
        // stage. A second branch carries a key-less `None` and is disposed.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());

            builder
                .listen(buffer)
                .then(pull_register_from_buffer.into_blocking_callback())
                .dispose_on_none()
                .connect(scope.terminate);

            let decrement_register_and_pass_keys_cb =
                decrement_register_and_pass_keys.into_blocking_callback();
            let async_decrement_register_and_pass_keys_cb =
                async_decrement_register_and_pass_keys.as_callback();
            let (loose_end, dead_end): (_, Output<Option<Register>>) = scope
                .input
                .chain(builder)
                .with_access(buffer)
                .then(decrement_register_and_pass_keys_cb.clone())
                .then(async_decrement_register_and_pass_keys_cb.clone())
                .dispose_on_none()
                .map_block(|v| (v, None))
                .unzip();

            dead_end.chain(builder).dispose_on_none().unused();

            loose_end
                .chain(builder)
                .then(async_decrement_register_and_pass_keys_cb)
                .dispose_on_none()
                .then(decrement_register_and_pass_keys_cb)
                .unused();
        });

        run_register_test(workflow, 0, true, &mut context);
        run_register_test(workflow, 1, true, &mut context);
        run_register_test(workflow, 2, true, &mut context);
        run_register_test(workflow, 3, true, &mut context);
        run_register_test(workflow, 4, false, &mut context);
        run_register_test(workflow, 5, false, &mut context);
        run_register_test(workflow, 6, false, &mut context);
    }

    /// Run `workflow` with a register starting at `initial_value`. On expected
    /// success, check that it finished with `in_slot == 0` and
    /// `out_slot == initial_value`; otherwise check that the request was
    /// cancelled.
    fn run_register_test(
        workflow: Service<Register, Register>,
        initial_value: u64,
        expect_success: bool,
        context: &mut TestingContext,
    ) {
        let mut promise = context.command(|commands| {
            commands
                .request(Register::new(initial_value), workflow)
                .take_response()
        });

        context.run_while_pending(&mut promise);
        if expect_success {
            assert!(promise
                .take()
                .available()
                .is_some_and(|r| r.finished_with(initial_value)));
        } else {
            assert!(promise.take().is_cancelled());
        }
        assert!(context.no_unhandled_errors());
    }

    /// Counter moved from `in_slot` to `out_slot` one unit per decrement stage.
    #[derive(Clone, Copy, Debug)]
    struct Register {
        in_slot: u64,
        out_slot: u64,
    }

    impl Register {
        fn new(start_from: u64) -> Self {
            Self {
                in_slot: start_from,
                out_slot: 0,
            }
        }

        fn finished_with(&self, out_slot: u64) -> bool {
            self.in_slot == 0 && self.out_slot == out_slot
        }
    }

    fn pull_register_from_buffer(
        In(key): In<BufferKey<Register>>,
        mut access: BufferAccessMut<Register>,
    ) -> Option<Register> {
        access.get_mut(&key).ok()?.pull()
    }

    /// Move one unit from `in_slot` to `out_slot`, or, if `in_slot` is already
    /// zero, push the register into the keyed buffer.
    fn decrement_register(
        In((mut register, key)): In<(Register, BufferKey<Register>)>,
        mut access: BufferAccessMut<Register>,
    ) -> Register {
        if register.in_slot == 0 {
            access.get_mut(&key).unwrap().push(register);
            return register;
        }

        register.in_slot -= 1;
        register.out_slot += 1;
        register
    }

    /// Same as `decrement_register`, but also hands the key on to the next stage.
    fn decrement_register_and_pass_keys(
        In((mut register, key)): In<(Register, BufferKey<Register>)>,
        mut access: BufferAccessMut<Register>,
    ) -> (Register, BufferKey<Register>) {
        if register.in_slot == 0 {
            access.get_mut(&key).unwrap().push(register);
            return (register, key);
        }

        register.in_slot -= 1;
        register.out_slot += 1;
        (register, key)
    }

    /// Async wrapper that runs `decrement_register` through the channel.
    fn async_decrement_register(
        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
    ) -> impl Future<Output = Option<Register>> {
        async move {
            input
                .channel
                .query(input.request, decrement_register.into_blocking_callback())
                .await
                .available()
        }
    }

    /// Async wrapper that runs `decrement_register_and_pass_keys` through the channel.
    fn async_decrement_register_and_pass_keys(
        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
    ) -> impl Future<Output = Option<(Register, BufferKey<Register>)>> {
        async move {
            input
                .channel
                .query(
                    input.request,
                    decrement_register_and_pass_keys.into_blocking_callback(),
                )
                .await
                .available()
        }
    }

    /// The gate is closed before the service runs. The service checks it is
    /// closed, reopens it, and feeds an incremented value back into the buffer
    /// until the value reaches 5.
    #[test]
    fn test_buffer_key_gate_control() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let service = builder.commands().spawn_service(gate_access_test_open_loop);

            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder.connect(scope.input, buffer.input_slot());
            builder
                .listen(buffer)
                .then_gate_close(buffer)
                .then(service)
                .fork_unzip((
                    |chain: Chain<_>| chain.dispose_on_none().connect(buffer.input_slot()),
                    |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
                ));
        });

        let mut promise = context.command(|commands| commands.request(0, workflow).take_response());

        context.run_with_conditions(&mut promise, Duration::from_secs(2));
        assert!(promise.take().available().is_some_and(|v| v == 5));
        assert!(context.no_unhandled_errors());
    }

    /// Pull the value, assert the gate is closed, and reopen it. Then either
    /// loop with `value + 1` (first output) or finish with `value` once it
    /// reaches 5 (second output).
    fn gate_access_test_open_loop(
        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
        mut access: BufferAccessMut<u64>,
        mut gate_access: BufferGateAccessMut,
    ) -> (Option<u64>, Option<u64>) {
        let mut buffer = access.get_mut(&key).unwrap();
        let value = buffer.pull().unwrap();

        let mut gate = gate_access.get_mut(key).unwrap();
        assert_eq!(gate.get(), Gate::Closed);
        gate.open_gate();

        if value >= 5 {
            (None, Some(value))
        } else {
            (Some(value + 1), None)
        }
    }

    /// The service pulls with `allow_closed_loops` and feeds `value + 1` back
    /// through a delay. The test expects the service to be re-triggered by its
    /// own pull, find the buffer empty, and terminate with 0 before the delayed
    /// value returns.
    #[test]
    fn test_closed_loop_key_access() {
        let mut context = TestingContext::minimal_plugins();

        let delay = context.spawn_delay(Duration::from_secs_f32(0.1));

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let service = builder
                .commands()
                .spawn_service(gate_access_test_closed_loop);

            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder.connect(scope.input, buffer.input_slot());
            builder.listen(buffer).then(service).fork_unzip((
                |chain: Chain<_>| {
                    chain
                        .dispose_on_none()
                        .then(delay)
                        .connect(buffer.input_slot())
                },
                |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
            ));
        });

        let mut promise = context.command(|commands| commands.request(3, workflow).take_response());

        context.run_with_conditions(&mut promise, Duration::from_secs(2));
        assert!(promise.take().available().is_some_and(|v| v == 0));
        assert!(context.no_unhandled_errors());
    }

    /// Pull with `allow_closed_loops`. Loop with `value + 1` when a value was
    /// present; otherwise finish with 0.
    fn gate_access_test_closed_loop(
        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
        mut access: BufferAccessMut<u64>,
    ) -> (Option<u64>, Option<u64>) {
        let mut buffer = access.get_mut(&key).unwrap().allow_closed_loops();
        if let Some(value) = buffer.pull() {
            (Some(value + 1), None)
        } else {
            (None, Some(0))
        }
    }

    /// Join a cloned message buffer with a pulled count buffer through type-
    /// erased buffers. The count is fed back until it reaches 10. Because the
    /// message is cloned rather than pulled, it remains available for every
    /// join, and the final result is expected to still carry "hello".
    #[test]
    fn test_any_buffer_join_by_clone() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let message_buffer = builder.create_buffer(Default::default()).join_by_cloning();
            let count_buffer = builder.create_buffer(Default::default());
            let (message, count) = builder.chain(scope.input).unzip();
            builder.connect(message, message_buffer.input_slot());
            builder.connect(count, count_buffer.input_slot());

            let any_message_buffer = message_buffer.as_any_buffer();
            let any_count_buffer = count_buffer.as_any_buffer();

            let mut buffer_map = BufferMap::default();
            buffer_map.insert_buffer("message", any_message_buffer);
            buffer_map.insert_buffer("count", any_count_buffer);

            builder
                .try_join::<JoinByCloneTest>(&buffer_map)
                .unwrap()
                .map_block(|joined| {
                    if joined.count < 10 {
                        Err(joined.count + 1)
                    } else {
                        Ok(joined)
                    }
                })
                .fork_result(
                    |ok| ok.connect(scope.terminate),
                    |err| err.connect(count_buffer.input_slot()),
                );
        });

        let mut promise = context.command(|commands| {
            commands
                .request((String::from("hello"), 0), workflow)
                .take_response()
        });

        context.run_with_conditions(&mut promise, Duration::from_secs(2));
        let r = promise.take().available().unwrap();
        assert_eq!(r.count, 10);
        assert_eq!(r.message, "hello");
    }

    #[derive(Joined)]
    struct JoinByCloneTest {
        count: i64,
        message: String,
    }
}