1use bevy_ecs::{
19 change_detection::Mut,
20 prelude::{Commands, Entity, EntityRef, Query, World},
21 query::QueryEntityError,
22 system::{SystemParam, SystemState},
23};
24
25use std::{
26 any::TypeId,
27 collections::HashMap,
28 ops::RangeBounds,
29 sync::{Arc, Mutex, OnceLock},
30};
31
32use thiserror::Error as ThisError;
33
34use crate::{GateState, InputSlot, NotifyBufferUpdate, OperationError};
35
36mod any_buffer;
37pub use any_buffer::*;
38
39mod buffer_access_lifecycle;
40pub use buffer_access_lifecycle::BufferKeyLifecycle;
41pub(crate) use buffer_access_lifecycle::*;
42
43mod buffer_key_builder;
44pub use buffer_key_builder::*;
45
46mod buffer_gate;
47pub use buffer_gate::*;
48
49mod buffer_map;
50pub use buffer_map::*;
51
52mod buffer_storage;
53pub(crate) use buffer_storage::*;
54
55mod buffering;
56pub use buffering::*;
57
58mod bufferable;
59pub use bufferable::*;
60
61mod manage_buffer;
62pub use manage_buffer::*;
63
64#[cfg(feature = "diagram")]
65mod json_buffer;
66#[cfg(feature = "diagram")]
67pub use json_buffer::*;
68
69mod fetch_from_buffer;
70pub use fetch_from_buffer::*;
71
/// A typed handle to a buffer that lives inside a workflow.
///
/// The handle is lightweight: it only stores the [`BufferLocation`] (the scope
/// the buffer belongs to and the entity that identifies the buffer). The
/// message type `T` is tracked only at the type level.
pub struct Buffer<T> {
    /// Where the buffer lives: its scope and its source entity.
    pub(crate) location: BufferLocation,
    // `PhantomData<fn(T)>` records `T` without storing one. A `fn(T)` marker is
    // always `Send + Sync`, so the handle's auto traits do not depend on `T`.
    pub(crate) _ignore: std::marker::PhantomData<fn(T)>,
}
79
80impl<T: 'static + Send + Sync> Buffer<T> {
81 pub fn join_by_cloning(self) -> CloneFromBuffer<T>
87 where
88 T: Clone,
89 {
90 CloneFromBuffer::new(self.location)
91 }
92
93 pub fn input_slot(self) -> InputSlot<T> {
95 InputSlot::new(self.scope(), self.id())
96 }
97
98 pub fn id(&self) -> Entity {
100 self.location.source
101 }
102
103 pub fn scope(&self) -> Entity {
105 self.location.scope
106 }
107
108 pub fn location(&self) -> BufferLocation {
110 self.location
111 }
112}
113
// `Buffer` is `Copy`, so cloning is a plain copy. Implemented by hand rather
// than derived so that cloning does not require `T: Clone`.
impl<T> Clone for Buffer<T> {
    fn clone(&self) -> Self {
        *self
    }
}
119
// Manual impl so that `Buffer<T>` is `Copy` for every `T`; a derive would add a
// `T: Copy` bound.
impl<T> Copy for Buffer<T> {}
121
/// Identifies where a buffer lives inside a workflow.
#[derive(Clone, Copy, Debug)]
pub struct BufferLocation {
    /// The entity of the scope that the buffer belongs to.
    pub scope: Entity,
    /// The entity that identifies the buffer itself.
    pub source: Entity,
}
132
/// A buffer handle whose messages are cloned, rather than pulled, when the
/// buffer takes part in a join.
///
/// Obtain one with [`Buffer::join_by_cloning`]. Convert back to a plain
/// [`Buffer`] with [`CloneFromBuffer::join_by_pulling`] or `Buffer::from`.
#[derive(Clone)]
pub struct CloneFromBuffer<T: Clone + Send + Sync + 'static> {
    /// Where the underlying buffer lives.
    location: BufferLocation,
    /// Type-level marker for the message type; no `T` is stored.
    _ignore: std::marker::PhantomData<fn(T)>,
}
138
// `CloneFromBuffer` only holds a location and a marker, so it can be `Copy`.
impl<T: Clone + Send + Sync + 'static> Copy for CloneFromBuffer<T> {}
140
141impl<T: Clone + Send + Sync + 'static> CloneFromBuffer<T> {
142 pub fn input_slot(self) -> InputSlot<T> {
144 InputSlot::new(self.scope(), self.id())
145 }
146
147 pub fn id(&self) -> Entity {
149 self.location.source
150 }
151
152 pub fn scope(&self) -> Entity {
154 self.location.scope
155 }
156
157 pub fn location(&self) -> BufferLocation {
159 self.location
160 }
161
162 #[must_use]
167 pub fn join_by_pulling(self) -> Buffer<T> {
168 Buffer {
169 location: self.location,
170 _ignore: Default::default(),
171 }
172 }
173
174 fn new(location: BufferLocation) -> Self {
175 Self::register_clone_for_join();
176 Self {
177 location,
178 _ignore: Default::default(),
179 }
180 }
181
182 pub fn register_clone_for_join() {
186 static REGISTER_CLONE: OnceLock<Mutex<HashMap<TypeId, ()>>> = OnceLock::new();
187 let register_clone = REGISTER_CLONE.get_or_init(|| Mutex::default());
188
189 let mut register_mut = register_clone.lock().unwrap();
192 register_mut.entry(TypeId::of::<T>()).or_insert_with(|| {
193 let interface = AnyBuffer::interface_for::<T>();
194 interface.register_cloning(
195 clone_for_any_join::<T>,
196 &(clone_for_join::<T> as FetchFromBufferFn<T>),
197 );
198 interface.register_buffer_downcast(
199 TypeId::of::<CloneFromBuffer<T>>(),
200 Box::new(|buffer: AnyBuffer| {
201 Ok(Box::new(CloneFromBuffer::<T>::new(buffer.location)))
202 }),
203 );
204 });
205 }
206}
207
208fn clone_for_any_join<T: 'static + Send + Sync + Clone>(
209 entity_ref: &EntityRef,
210 session: Entity,
211) -> Result<AnyMessageBox, OperationError> {
212 entity_ref
228 .clone_from_buffer::<T>(session)
229 .map(to_any_message)
230}
231
232impl<T: Clone + Send + Sync> From<CloneFromBuffer<T>> for Buffer<T> {
233 fn from(value: CloneFromBuffer<T>) -> Self {
234 Buffer {
235 location: value.location,
236 _ignore: Default::default(),
237 }
238 }
239}
240
// Settings that describe how a buffer behaves. Currently this only holds the
// retention policy; `Default` uses `RetentionPolicy::default()`, which is
// `RetentionPolicy::KeepLast(1)`.
// (Plain `//` comments are used so the optional schemars schema is unchanged.)
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Default, Clone, Copy, Debug)]
pub struct BufferSettings {
    // Which messages the buffer keeps.
    retention: RetentionPolicy,
}
251
252impl BufferSettings {
253 pub fn new(retention: RetentionPolicy) -> Self {
255 Self { retention }
256 }
257
258 pub fn keep_last(n: usize) -> Self {
260 Self::new(RetentionPolicy::KeepLast(n))
261 }
262
263 pub fn keep_first(n: usize) -> Self {
265 Self::new(RetentionPolicy::KeepFirst(n))
266 }
267
268 pub fn keep_all() -> Self {
270 Self::new(RetentionPolicy::KeepAll)
271 }
272
273 pub fn retention(&self) -> RetentionPolicy {
275 self.retention
276 }
277
278 pub fn retention_mut(&mut self) -> &mut RetentionPolicy {
280 &mut self.retention
281 }
282}
283
// Describes which messages a buffer keeps. The eviction itself is carried out
// by the buffer storage; the descriptions below follow the variant names.
// (Plain `//` comments are used so the optional schemars schema is unchanged.)
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum RetentionPolicy {
    // Keep the last (most recently arrived) `n` messages.
    KeepLast(usize),
    // Keep the first `n` messages.
    KeepFirst(usize),
    // Keep every message.
    KeepAll,
}
306
307impl Default for RetentionPolicy {
308 fn default() -> Self {
309 Self::KeepLast(1)
310 }
311}
312
/// A key that grants access to a buffer's contents for one session.
///
/// Use it with [`BufferAccess`] / [`BufferAccessMut`] inside systems, or with
/// [`BufferWorldAccess`] when you hold a `World` directly.
pub struct BufferKey<T> {
    /// Identifies the buffer, session and accessor, plus optional lifecycle
    /// tracking.
    tag: BufferKeyTag,
    /// Type-level marker for the message type; no `T` is stored.
    _ignore: std::marker::PhantomData<fn(T)>,
}
324
325impl<T> Clone for BufferKey<T> {
326 fn clone(&self) -> Self {
327 Self {
328 tag: self.tag.clone(),
329 _ignore: Default::default(),
330 }
331 }
332}
333
334impl<T> BufferKey<T> {
335 pub fn buffer(&self) -> Entity {
337 self.tag.buffer
338 }
339
340 pub fn session(&self) -> Entity {
342 self.tag.session
343 }
344
345 pub fn tag(&self) -> &BufferKeyTag {
346 &self.tag
347 }
348}
349
350impl<T: 'static + Send + Sync> BufferKeyLifecycle for BufferKey<T> {
351 type TargetBuffer = Buffer<T>;
352
353 fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self {
354 BufferKey {
355 tag: builder.make_tag(buffer.id()),
356 _ignore: Default::default(),
357 }
358 }
359
360 fn is_in_use(&self) -> bool {
361 self.tag.is_in_use()
362 }
363
364 fn deep_clone(&self) -> Self {
365 Self {
366 tag: self.tag.deep_clone(),
367 _ignore: Default::default(),
368 }
369 }
370}
371
372impl<T> std::fmt::Debug for BufferKey<T> {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 f.debug_struct("BufferKey")
375 .field("message_type_name", &std::any::type_name::<T>())
376 .field("tag", &self.tag)
377 .finish()
378 }
379}
380
/// The type-erased identity of a buffer key: which buffer, which session,
/// which accessor, plus an optional lifecycle tracker.
///
/// The derived `Clone` shares the lifecycle `Arc`; use
/// [`BufferKeyTag::deep_clone`] to give the copy its own tracker.
#[derive(Clone)]
pub struct BufferKeyTag {
    /// The buffer entity the key refers to.
    pub buffer: Entity,
    /// The session whose buffer contents the key can access.
    pub session: Entity,
    /// The accessor entity associated with the key. `BufferAccessMut::get_mut`
    /// passes it on to the update notification sent by `BufferMut`.
    pub accessor: Entity,
    /// Lifecycle tracker consulted by `is_in_use`; `None` means the key is
    /// never reported as in use.
    pub lifecycle: Option<Arc<BufferAccessLifecycle>>,
}
390
391impl BufferKeyTag {
392 pub fn is_in_use(&self) -> bool {
393 self.lifecycle.as_ref().is_some_and(|l| l.is_in_use())
394 }
395
396 pub fn deep_clone(&self) -> Self {
397 let mut deep = self.clone();
398 deep.lifecycle = self
399 .lifecycle
400 .as_ref()
401 .map(|l| Arc::new(l.as_ref().clone()));
402 deep
403 }
404}
405
406impl std::fmt::Debug for BufferKeyTag {
407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408 f.debug_struct("BufferKeyTag")
409 .field("buffer", &self.buffer)
410 .field("session", &self.session)
411 .field("accessor", &self.accessor)
412 .field("in_use", &self.is_in_use())
413 .finish()
414 }
415}
416
/// System parameter for read-only access to buffers of message type `T`.
///
/// Pass a [`BufferKey`] to [`BufferAccess::get`] to get a [`BufferView`] of the
/// key's session.
#[derive(SystemParam)]
pub struct BufferAccess<'w, 's, T>
where
    T: 'static + Send + Sync,
{
    // Read-only query over every entity that has a `BufferStorage<T>`.
    query: Query<'w, 's, &'static BufferStorage<T>>,
}
471
472impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> {
473 pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
474 let session = key.session();
475 self.query
476 .get(key.buffer())
477 .map(|storage| BufferView { storage, session })
478 }
479
480 pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
481 self.get(key).ok().map(|view| view.newest()).flatten()
482 }
483}
484
/// System parameter for read-write access to buffers of message type `T`.
///
/// Pass a [`BufferKey`] to [`BufferAccessMut::get_mut`] to get a [`BufferMut`].
/// If the buffer is modified through it, an update notification is queued on
/// `commands` when the `BufferMut` is dropped.
#[derive(SystemParam)]
pub struct BufferAccessMut<'w, 's, T>
where
    T: 'static + Send + Sync,
{
    // Mutable query over every entity that has a `BufferStorage<T>`.
    query: Query<'w, 's, &'static mut BufferStorage<T>>,
    // Lent to `BufferMut` so it can queue `NotifyBufferUpdate` on drop.
    commands: Commands<'w, 's>,
}
540
541impl<'w, 's, T> BufferAccessMut<'w, 's, T>
542where
543 T: 'static + Send + Sync,
544{
545 pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
546 let session = key.session();
547 self.query
548 .get(key.buffer())
549 .map(|storage| BufferView { storage, session })
550 }
551
552 pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
553 self.get(key).ok().map(|view| view.newest()).flatten()
554 }
555
556 pub fn get_mut<'a>(
557 &'a mut self,
558 key: &BufferKey<T>,
559 ) -> Result<BufferMut<'w, 's, 'a, T>, QueryEntityError> {
560 let buffer = key.buffer();
561 let session = key.session();
562 let accessor = key.tag.accessor;
563 self.query
564 .get_mut(key.buffer())
565 .map(|storage| BufferMut::new(storage, buffer, session, accessor, &mut self.commands))
566 }
567}
568
/// Buffer access for code that holds a `World` directly instead of system
/// parameters.
pub trait BufferWorldAccess {
    /// Get a read-only view of the buffer for `key`, restricted to the key's
    /// session.
    ///
    /// # Errors
    /// A [`BufferError`] if the key does not identify a buffer of message type
    /// `T`.
    fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
    where
        T: 'static + Send + Sync;

    /// Get a read-only view of the gate state of the buffer for `key`,
    /// restricted to the key's session.
    ///
    /// # Errors
    /// A [`BufferError`] if the key does not identify a buffer with gate state.
    fn buffer_gate_view(
        &self,
        key: impl Into<AnyBufferKey>,
    ) -> Result<BufferGateView<'_>, BufferError>;

    /// Run `f` with mutable access to the buffer for `key` and return what `f`
    /// returns.
    ///
    /// # Errors
    /// A [`BufferError`] if the key does not identify a buffer of message type
    /// `T`; `f` is not called in that case.
    fn buffer_mut<T, U>(
        &mut self,
        key: &BufferKey<T>,
        f: impl FnOnce(BufferMut<T>) -> U,
    ) -> Result<U, BufferError>
    where
        T: 'static + Send + Sync;

    /// Run `f` with mutable access to the gate of the buffer for `key` and
    /// return what `f` returns.
    ///
    /// # Errors
    /// A [`BufferError`] if the key does not identify a buffer; `f` is not
    /// called in that case.
    fn buffer_gate_mut<U>(
        &mut self,
        key: impl Into<AnyBufferKey>,
        f: impl FnOnce(BufferGateMut) -> U,
    ) -> Result<U, BufferError>;
}
607
608impl BufferWorldAccess for World {
609 fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
610 where
611 T: 'static + Send + Sync,
612 {
613 let buffer_ref = self
614 .get_entity(key.tag.buffer)
615 .map_err(|_| BufferError::BufferMissing)?;
616 let storage = buffer_ref
617 .get::<BufferStorage<T>>()
618 .ok_or(BufferError::BufferMissing)?;
619 Ok(BufferView {
620 storage,
621 session: key.tag.session,
622 })
623 }
624
625 fn buffer_gate_view(
626 &self,
627 key: impl Into<AnyBufferKey>,
628 ) -> Result<BufferGateView<'_>, BufferError> {
629 let key: AnyBufferKey = key.into();
630 let buffer_ref = self
631 .get_entity(key.tag.buffer)
632 .or(Err(BufferError::BufferMissing))?;
633 let gate = buffer_ref
634 .get::<GateState>()
635 .ok_or(BufferError::BufferMissing)?;
636 Ok(BufferGateView {
637 gate,
638 session: key.tag.session,
639 })
640 }
641
642 fn buffer_mut<T, U>(
643 &mut self,
644 key: &BufferKey<T>,
645 f: impl FnOnce(BufferMut<T>) -> U,
646 ) -> Result<U, BufferError>
647 where
648 T: 'static + Send + Sync,
649 {
650 let mut state = SystemState::<BufferAccessMut<T>>::new(self);
651 let mut buffer_access_mut = state.get_mut(self);
652 let buffer_mut = buffer_access_mut
653 .get_mut(key)
654 .map_err(|_| BufferError::BufferMissing)?;
655 Ok(f(buffer_mut))
656 }
657
658 fn buffer_gate_mut<U>(
659 &mut self,
660 key: impl Into<AnyBufferKey>,
661 f: impl FnOnce(BufferGateMut) -> U,
662 ) -> Result<U, BufferError> {
663 let mut state = SystemState::<BufferGateAccessMut>::new(self);
664 let mut buffer_gate_access_mut = state.get_mut(self);
665 let buffer_mut = buffer_gate_access_mut
666 .get_mut(key)
667 .map_err(|_| BufferError::BufferMissing)?;
668 Ok(f(buffer_mut))
669 }
670}
671
/// A read-only view of one session's messages in a buffer.
pub struct BufferView<'a, T>
where
    T: 'static + Send + Sync,
{
    /// The buffer's storage; every query below is made for `session`.
    storage: &'a BufferStorage<T>,
    /// The session this view is restricted to.
    session: Entity,
}
680
681impl<'a, T> BufferView<'a, T>
682where
683 T: 'static + Send + Sync,
684{
685 pub fn iter(&self) -> IterBufferView<'a, T> {
687 self.storage.iter(self.session)
688 }
689
690 pub fn oldest(&self) -> Option<&'a T> {
692 self.storage.oldest(self.session)
693 }
694
695 pub fn newest(&self) -> Option<&'a T> {
697 self.storage.newest(self.session)
698 }
699
700 pub fn get(&self, index: usize) -> Option<&'a T> {
703 self.storage.get(self.session, index)
704 }
705
706 pub fn len(&self) -> usize {
708 self.storage.count(self.session)
709 }
710
711 pub fn is_empty(&self) -> bool {
713 self.len() == 0
714 }
715}
716
/// Mutable access to one session of a buffer.
///
/// Every method that may modify the buffer (and [`BufferMut::pulse`]) marks it
/// as modified. When a modified `BufferMut` is dropped, it queues a
/// `NotifyBufferUpdate` command.
pub struct BufferMut<'w, 's, 'a, T>
where
    T: 'static + Send + Sync,
{
    /// Change-detected mutable borrow of the buffer's storage.
    storage: Mut<'a, BufferStorage<T>>,
    /// The buffer entity, reported in the update notification.
    buffer: Entity,
    /// The session whose messages this accessor reads and writes.
    session: Entity,
    /// The accessor reported in the update notification; `None` after
    /// [`BufferMut::allow_closed_loops`].
    accessor: Option<Entity>,
    /// Where the update notification is queued on drop.
    commands: &'a mut Commands<'w, 's>,
    /// Whether a notification must be queued on drop.
    modified: bool,
}
729
730impl<'w, 's, 'a, T> BufferMut<'w, 's, 'a, T>
731where
732 T: 'static + Send + Sync,
733{
734 pub fn allow_closed_loops(mut self) -> Self {
749 self.accessor = None;
750 self
751 }
752
753 pub fn iter(&self) -> IterBufferView<'_, T> {
755 self.storage.iter(self.session)
756 }
757
758 pub fn oldest(&self) -> Option<&T> {
760 self.storage.oldest(self.session)
761 }
762
763 pub fn newest(&self) -> Option<&T> {
765 self.storage.newest(self.session)
766 }
767
768 pub fn get(&self, index: usize) -> Option<&T> {
771 self.storage.get(self.session, index)
772 }
773
774 pub fn len(&self) -> usize {
776 self.storage.count(self.session)
777 }
778
779 pub fn is_empty(&self) -> bool {
781 self.len() == 0
782 }
783
784 pub fn iter_mut(&mut self) -> IterBufferMut<'_, T> {
786 self.modified = true;
787 self.storage.iter_mut(self.session)
788 }
789
790 pub fn oldest_mut(&mut self) -> Option<&mut T> {
792 self.modified = true;
793 self.storage.oldest_mut(self.session)
794 }
795
796 pub fn newest_mut(&mut self) -> Option<&mut T> {
798 self.modified = true;
799 self.storage.newest_mut(self.session)
800 }
801
802 pub fn newest_mut_or_default(&mut self) -> Option<&mut T>
808 where
809 T: Default,
810 {
811 self.newest_mut_or_else(|| T::default())
812 }
813
814 pub fn newest_mut_or_else(&mut self, f: impl FnOnce() -> T) -> Option<&mut T> {
820 self.modified = true;
821 self.storage.newest_mut_or_else(self.session, f)
822 }
823
824 pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
827 self.modified = true;
828 self.storage.get_mut(self.session, index)
829 }
830
831 pub fn drain<R>(&mut self, range: R) -> DrainBuffer<'_, T>
833 where
834 R: RangeBounds<usize>,
835 {
836 self.modified = true;
837 self.storage.drain(self.session, range)
838 }
839
840 pub fn pull(&mut self) -> Option<T> {
842 self.modified = true;
843 self.storage.pull(self.session)
844 }
845
846 pub fn pull_newest(&mut self) -> Option<T> {
849 self.modified = true;
850 self.storage.pull_newest(self.session)
851 }
852
853 pub fn push(&mut self, value: T) -> Option<T> {
856 self.modified = true;
857 self.storage.push(self.session, value)
858 }
859
860 pub fn push_as_oldest(&mut self, value: T) -> Option<T> {
864 self.modified = true;
865 self.storage.push_as_oldest(self.session, value)
866 }
867
868 pub fn pulse(&mut self) {
872 self.modified = true;
873 }
874
875 fn new(
876 storage: Mut<'a, BufferStorage<T>>,
877 buffer: Entity,
878 session: Entity,
879 accessor: Entity,
880 commands: &'a mut Commands<'w, 's>,
881 ) -> Self {
882 Self {
883 storage,
884 buffer,
885 session,
886 accessor: Some(accessor),
887 commands,
888 modified: false,
889 }
890 }
891}
892
893impl<'w, 's, 'a, T> Drop for BufferMut<'w, 's, 'a, T>
894where
895 T: 'static + Send + Sync,
896{
897 fn drop(&mut self) {
898 if self.modified {
899 self.commands.queue(NotifyBufferUpdate::new(
900 self.buffer,
901 self.session,
902 self.accessor,
903 ));
904 }
905 }
906}
907
/// Errors from accessing a buffer through a key (e.g. via
/// [`BufferWorldAccess`]).
#[derive(ThisError, Debug, Clone)]
pub enum BufferError {
    /// The key's buffer entity does not exist or does not carry the expected
    /// buffer component.
    #[error("The key was unable to identify a buffer")]
    BufferMissing,
}
913
#[cfg(test)]
mod tests {
    use crate::{AddBufferToMap, Gate, prelude::*, testing::*};
    use std::future::Future;

    /// Exercises `BufferAccess` / `BufferAccessMut` with keys produced by
    /// `listen`, `with_access`, and `create_buffer_access`.
    #[test]
    fn test_buffer_key_access() {
        let mut context = TestingContext::minimal_plugins();

        let add_buffers_by_pull_cb = add_buffers_by_pull.into_blocking_callback();
        let add_from_buffer_cb = add_from_buffer.into_blocking_callback();
        let multiply_buffers_by_copy_cb = multiply_buffers_by_copy.into_blocking_callback();

        // Unzip the request and listen to both halves; the callback multiplies
        // the oldest value seen through each key without removing it.
        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
            builder
                .chain(scope.start)
                .unzip()
                .listen(builder)
                .then(multiply_buffers_by_copy_cb)
                .connect(scope.terminate);
        });

        let r = context.resolve_request((2.0, 3.0), workflow);
        assert_eq!(r, 6.0);

        // Same layout, but pull one value through each key and add them. The
        // callback returns `None` while either buffer is empty, and those
        // results are disposed.
        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
            builder
                .chain(scope.start)
                .unzip()
                .listen(builder)
                .then(add_buffers_by_pull_cb)
                .dispose_on_none()
                .connect(scope.terminate);
        });

        let r = context.resolve_request((4.0, 5.0), workflow);
        assert_eq!(r, 9.0);

        // One half of the request goes into a buffer, the other into an adder
        // that pulls from that buffer. If the buffer is still empty, the adder
        // returns `Err(lhs)`, which is routed back into the adder to retry. A
        // successful sum is sent through the adder a second time; that pull
        // finds the buffer empty, so the workflow ends with `Err(sum)`.
        let workflow =
            context.spawn_io_workflow(|scope: Scope<(f64, f64), Result<f64, f64>>, builder| {
                let (branch_to_adder, branch_to_buffer) = builder.chain(scope.start).unzip();
                let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
                builder.connect(branch_to_buffer, buffer.input_slot());

                let adder_node = builder
                    .chain(branch_to_adder)
                    .with_access(buffer)
                    .then_node(add_from_buffer_cb.clone());

                builder.chain(adder_node.output).fork_result(
                    |chain| {
                        // Second pass: the buffer is empty by now, so this
                        // produces `Err(sum)`.
                        chain
                            .with_access(buffer)
                            .then(add_from_buffer_cb.clone())
                            .connect(scope.terminate)
                    },
                    // Retry with a fresh key when the buffer was empty.
                    |chain| chain.with_access(buffer).connect(adder_node.input),
                );
            });

        let r = context.resolve_request((2.0, 3.0), workflow);
        assert!(r.is_err_and(|n| n == 5.0));

        // The same scenario, built with explicit buffer-access nodes instead
        // of `with_access`.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let (branch_to_adder, branch_to_buffer) = builder.chain(scope.start).unzip();
            let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
            builder.connect(branch_to_buffer, buffer.input_slot());

            let access = builder.create_buffer_access(buffer);
            builder.connect(branch_to_adder, access.input);
            builder
                .chain(access.output)
                .then(add_from_buffer_cb.clone())
                .fork_result(
                    |ok| {
                        let (output, builder) = ok.unpack();
                        let second_access = builder.create_buffer_access(buffer);
                        builder.connect(output, second_access.input);
                        builder
                            .chain(second_access.output)
                            .then(add_from_buffer_cb.clone())
                            .connect(scope.terminate);
                    },
                    |err| err.connect(access.input),
                );
        });

        let r = context.resolve_request((2.0, 3.0), workflow);
        assert!(r.is_err_and(|n| n == 5.0));
    }

    /// Pull one value through `key` and add it to `lhs`. Returns `Err(lhs)` if
    /// the buffer cannot be accessed or is empty.
    fn add_from_buffer(
        In((lhs, key)): In<(f64, BufferKey<f64>)>,
        mut access: BufferAccessMut<f64>,
    ) -> Result<f64, f64> {
        let rhs = access.get_mut(&key).map_err(|_| lhs)?.pull().ok_or(lhs)?;
        Ok(lhs + rhs)
    }

    /// Multiply the oldest values seen through two keys without removing them.
    /// Panics if either buffer is missing or empty.
    fn multiply_buffers_by_copy(
        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
        access: BufferAccess<f64>,
    ) -> f64 {
        *access.get(&key_a).unwrap().oldest().unwrap()
            * *access.get(&key_b).unwrap().oldest().unwrap()
    }

    /// Pull one value through each key and add them; `None` (without pulling
    /// anything) if either buffer is empty.
    fn add_buffers_by_pull(
        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
        mut access: BufferAccessMut<f64>,
    ) -> Option<f64> {
        if access.get(&key_a).unwrap().is_empty() {
            return None;
        }

        if access.get(&key_b).unwrap().is_empty() {
            return None;
        }

        let rhs = access.get_mut(&key_a).unwrap().pull().unwrap();
        let lhs = access.get_mut(&key_b).unwrap().pull().unwrap();
        Some(rhs + lhs)
    }

    /// Four decrement steps each hold a key to the same buffer. A step pushes
    /// the register into the buffer (whose listener then terminates the
    /// workflow) only when it receives `in_slot == 0`, so starting values 0-3
    /// finish and values of 4 or more never reach the buffer. For those, the
    /// request is expected to fail — presumably because the workflow is
    /// cancelled once no buffer key remains in use (hence the test name).
    #[test]
    fn test_buffer_key_lifecycle() {
        let mut context = TestingContext::minimal_plugins();

        // Variant 1: a fresh key is attached before every step.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());

            builder
                // Whatever lands in the buffer is pulled out and terminates.
                .listen(buffer)
                .then(pull_register_from_buffer.into_blocking_callback())
                .dispose_on_none()
                .connect(scope.terminate);

            let decrement_register_cb = decrement_register.into_blocking_callback();
            let async_decrement_register_cb = async_decrement_register.as_callback();
            builder
                .chain(scope.start)
                .with_access(buffer)
                .then(decrement_register_cb.clone())
                .with_access(buffer)
                .then(async_decrement_register_cb.clone())
                .dispose_on_none()
                .with_access(buffer)
                .then(decrement_register_cb.clone())
                .with_access(buffer)
                .then(async_decrement_register_cb)
                .unused();
        });

        run_register_test(workflow, 0, true, &mut context);
        run_register_test(workflow, 1, true, &mut context);
        run_register_test(workflow, 2, true, &mut context);
        run_register_test(workflow, 3, true, &mut context);
        run_register_test(workflow, 4, false, &mut context);
        run_register_test(workflow, 5, false, &mut context);
        run_register_test(workflow, 6, false, &mut context);

        // Variant 2: one key is created up front and carried along in the
        // messages, including through an `unzip` whose other branch is
        // disposed.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());

            builder
                .listen(buffer)
                .then(pull_register_from_buffer.into_blocking_callback())
                .dispose_on_none()
                .connect(scope.terminate);

            let decrement_register_and_pass_keys_cb =
                decrement_register_and_pass_keys.into_blocking_callback();
            let async_decrement_register_and_pass_keys_cb =
                async_decrement_register_and_pass_keys.as_callback();
            let (loose_end, dead_end): (_, Output<Option<Register>>) = builder
                .chain(scope.start)
                .with_access(buffer)
                .then(decrement_register_and_pass_keys_cb.clone())
                .then(async_decrement_register_and_pass_keys_cb.clone())
                .dispose_on_none()
                .map_block(|v| (v, None))
                .unzip();

            // The second half of the unzip is always `None` and is discarded.
            builder.chain(dead_end).dispose_on_none().unused();

            builder
                .chain(loose_end)
                .then(async_decrement_register_and_pass_keys_cb)
                .dispose_on_none()
                .then(decrement_register_and_pass_keys_cb)
                .unused();
        });

        run_register_test(workflow, 0, true, &mut context);
        run_register_test(workflow, 1, true, &mut context);
        run_register_test(workflow, 2, true, &mut context);
        run_register_test(workflow, 3, true, &mut context);
        run_register_test(workflow, 4, false, &mut context);
        run_register_test(workflow, 5, false, &mut context);
        run_register_test(workflow, 6, false, &mut context);
    }

    /// Resolve `workflow` for a register starting at `initial_value`. On
    /// expected success the register must come back with `in_slot == 0` and
    /// `out_slot == initial_value`; otherwise the request must fail.
    fn run_register_test(
        workflow: Service<Register, Register>,
        initial_value: u64,
        expect_success: bool,
        context: &mut TestingContext,
    ) {
        let r = context.try_resolve_request(Register::new(initial_value), workflow, ());
        if expect_success {
            assert!(r.unwrap().finished_with(initial_value));
        } else {
            assert!(r.is_err());
        }
    }

    // Toy state for the lifecycle test: each decrement step moves one unit
    // from `in_slot` to `out_slot`.
    #[derive(Clone, Copy, Debug)]
    struct Register {
        in_slot: u64,
        out_slot: u64,
    }

    impl Register {
        fn new(start_from: u64) -> Self {
            Self {
                in_slot: start_from,
                out_slot: 0,
            }
        }

        fn finished_with(&self, out_slot: u64) -> bool {
            self.in_slot == 0 && self.out_slot == out_slot
        }
    }

    /// Pull a register out of the buffer; `None` if the buffer cannot be
    /// accessed or is empty.
    fn pull_register_from_buffer(
        In(key): In<BufferKey<Register>>,
        mut access: BufferAccessMut<Register>,
    ) -> Option<Register> {
        access.get_mut(&key).ok()?.pull()
    }

    /// If `in_slot` is already zero, push the register into the buffer and
    /// return it unchanged; otherwise move one unit from `in_slot` to
    /// `out_slot`.
    fn decrement_register(
        In((mut register, key)): In<(Register, BufferKey<Register>)>,
        mut access: BufferAccessMut<Register>,
    ) -> Register {
        if register.in_slot == 0 {
            access.get_mut(&key).unwrap().push(register);
            return register;
        }

        register.in_slot -= 1;
        register.out_slot += 1;
        register
    }

    /// Same as `decrement_register`, but also hands the key on to the next
    /// step.
    fn decrement_register_and_pass_keys(
        In((mut register, key)): In<(Register, BufferKey<Register>)>,
        mut access: BufferAccessMut<Register>,
    ) -> (Register, BufferKey<Register>) {
        if register.in_slot == 0 {
            access.get_mut(&key).unwrap().push(register);
            return (register, key);
        }

        register.in_slot -= 1;
        register.out_slot += 1;
        (register, key)
    }

    /// Async wrapper that runs `decrement_register` through the async channel;
    /// `None` if that request yields no outcome.
    fn async_decrement_register(
        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
    ) -> impl Future<Output = Option<Register>> + use<> {
        async move {
            input
                .channel
                .request_outcome(input.request, decrement_register.into_blocking_callback())
                .await
                .ok()
        }
    }

    /// Async wrapper that runs `decrement_register_and_pass_keys` through the
    /// async channel; `None` if that request yields no outcome.
    fn async_decrement_register_and_pass_keys(
        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
    ) -> impl Future<Output = Option<(Register, BufferKey<Register>)>> + use<> {
        async move {
            input
                .channel
                .request_outcome(
                    input.request,
                    decrement_register_and_pass_keys.into_blocking_callback(),
                )
                .await
                .ok()
        }
    }

    /// The listener closes the buffer's gate before calling the service. The
    /// service checks that the gate is closed, reopens it, and feeds
    /// `value + 1` back into the buffer until the value reaches 5, which is
    /// sent to `scope.terminate`.
    #[test]
    fn test_buffer_key_gate_control() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let service = builder.commands().spawn_service(gate_access_test_open_loop);

            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder.connect(scope.start, buffer.input_slot());
            builder
                .listen(buffer)
                .then_gate_close(buffer)
                .then(service)
                .fork_unzip((
                    |chain: Chain<_>| chain.dispose_on_none().connect(buffer.input_slot()),
                    |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
                ));
        });

        let r = context.resolve_request(0, workflow);
        assert_eq!(r, 5);
    }

    /// Returns `(Some(next), None)` to loop back into the buffer, or
    /// `(None, Some(value))` to finish once the value reaches 5.
    fn gate_access_test_open_loop(
        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
        mut access: BufferAccessMut<u64>,
        mut gate_access: BufferGateAccessMut,
    ) -> (Option<u64>, Option<u64>) {
        let mut buffer = access.get_mut(&key).unwrap();
        let value = buffer.pull().unwrap();

        // The listener chain closed the gate with `then_gate_close`.
        let mut gate = gate_access.get_mut(key).unwrap();
        assert_eq!(gate.get(), Gate::Closed);
        gate.open_gate();

        if value >= 5 {
            (None, Some(value))
        } else {
            (Some(value + 1), None)
        }
    }

    /// The service pulls with `allow_closed_loops` and sends `value + 1` back
    /// into the buffer through a 0.1 s delay. Presumably its own pull
    /// re-triggers the listener, which finds the buffer empty and terminates
    /// with 0 before the delayed value returns — hence the expected result
    /// of 0.
    #[test]
    fn test_closed_loop_key_access() {
        let mut context = TestingContext::minimal_plugins();

        let delay = context.spawn_delay(Duration::from_secs_f32(0.1));

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let service = builder
                .commands()
                .spawn_service(gate_access_test_closed_loop);

            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder.connect(scope.start, buffer.input_slot());
            builder.listen(buffer).then(service).fork_unzip((
                |chain: Chain<_>| {
                    chain
                        .dispose_on_none()
                        .then(delay)
                        .connect(buffer.input_slot())
                },
                |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
            ));
        });

        let r = context.resolve_request(3, workflow);
        assert_eq!(r, 0);
    }

    /// Returns `(Some(value + 1), None)` if a value could be pulled, otherwise
    /// `(None, Some(0))` to finish.
    fn gate_access_test_closed_loop(
        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
        mut access: BufferAccessMut<u64>,
    ) -> (Option<u64>, Option<u64>) {
        let mut buffer = access.get_mut(&key).unwrap().allow_closed_loops();
        if let Some(value) = buffer.pull() {
            (Some(value + 1), None)
        } else {
            (None, Some(0))
        }
    }

    /// The message buffer is joined by cloning, while the count buffer is
    /// joined normally. Each joined count below 10 is fed back into the count
    /// buffer incremented; at 10 the joined struct terminates the workflow,
    /// still carrying the original message.
    #[test]
    fn test_any_buffer_join_by_clone() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let message_buffer = builder.create_buffer(Default::default()).join_by_cloning();
            let count_buffer = builder.create_buffer(Default::default());
            let (message, count) = builder.chain(scope.start).unzip();
            builder.connect(message, message_buffer.input_slot());
            builder.connect(count, count_buffer.input_slot());

            // Join through type-erased buffers keyed by the target field names.
            let any_message_buffer = message_buffer.as_any_buffer();
            let any_count_buffer = count_buffer.as_any_buffer();

            let mut buffer_map = BufferMap::default();
            buffer_map.insert_buffer("message", any_message_buffer);
            buffer_map.insert_buffer("count", any_count_buffer);

            builder
                .try_join::<JoinByCloneTest>(&buffer_map)
                .unwrap()
                .map_block(|joined| {
                    if joined.count < 10 {
                        Err(joined.count + 1)
                    } else {
                        Ok(joined)
                    }
                })
                .fork_result(
                    |ok| ok.connect(scope.terminate),
                    |err| err.connect(count_buffer.input_slot()),
                );
        });

        let r = context.resolve_request((String::from("hello"), 0), workflow);
        assert_eq!(r.count, 10);
        assert_eq!(r.message, "hello");
    }

    // Joined from the "count" and "message" entries of the buffer map.
    #[derive(Joined)]
    struct JoinByCloneTest {
        count: i64,
        message: String,
    }

    /// The largest value currently in the buffer, or `None` if the buffer
    /// cannot be accessed or is empty.
    fn get_largest_value(
        In(input): In<((), BufferKey<i32>)>,
        access: BufferAccess<i32>,
    ) -> Option<i32> {
        let access = access.get(&input.1).ok()?;
        access.iter().max().cloned()
    }

    /// Push every value of the request into the buffer; silently does nothing
    /// if the buffer cannot be accessed.
    fn push_values(In(input): In<(Vec<i32>, BufferKey<i32>)>, mut access: BufferAccessMut<i32>) {
        let Ok(mut access) = access.get_mut(&input.1) else {
            return;
        };

        for value in input.0 {
            access.push(value);
        }
    }

    /// Push several values through one key, then read the maximum through
    /// another.
    #[test]
    fn test_buffer_access_example() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder
                .chain(scope.start)
                .with_access(buffer)
                .then(push_values.into_blocking_callback())
                .with_access(buffer)
                .then(get_largest_value.into_blocking_callback())
                .connect(scope.terminate);
        });

        let r = context.resolve_request(vec![-3, 2, 10], workflow);
        assert_eq!(r.unwrap(), 10);
    }
}