crossflow/
buffer.rs

1/*
2 * Copyright (C) 2024 Open Source Robotics Foundation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16*/
17
18use bevy_ecs::{
19    change_detection::Mut,
20    prelude::{Commands, Entity, EntityRef, Query, World},
21    query::QueryEntityError,
22    system::{SystemParam, SystemState},
23};
24
25use std::{
26    any::TypeId,
27    collections::HashMap,
28    ops::RangeBounds,
29    sync::{Arc, Mutex, OnceLock},
30};
31
32use thiserror::Error as ThisError;
33
34use crate::{GateState, InputSlot, NotifyBufferUpdate, OperationError};
35
36mod any_buffer;
37pub use any_buffer::*;
38
39mod buffer_access_lifecycle;
40pub use buffer_access_lifecycle::BufferKeyLifecycle;
41pub(crate) use buffer_access_lifecycle::*;
42
43mod buffer_key_builder;
44pub use buffer_key_builder::*;
45
46mod buffer_gate;
47pub use buffer_gate::*;
48
49mod buffer_map;
50pub use buffer_map::*;
51
52mod buffer_storage;
53pub(crate) use buffer_storage::*;
54
55mod buffering;
56pub use buffering::*;
57
58mod bufferable;
59pub use bufferable::*;
60
61mod manage_buffer;
62pub use manage_buffer::*;
63
64#[cfg(feature = "diagram")]
65mod json_buffer;
66#[cfg(feature = "diagram")]
67pub use json_buffer::*;
68
69mod fetch_from_buffer;
70pub use fetch_from_buffer::*;
71
72/// A buffer is a special type of node within a workflow that is able to store
73/// and release data. When a session is finished, the buffered data from the
74/// session will be automatically cleared.
75pub struct Buffer<T> {
76    pub(crate) location: BufferLocation,
77    pub(crate) _ignore: std::marker::PhantomData<fn(T)>,
78}
79
80impl<T: 'static + Send + Sync> Buffer<T> {
81    /// Specify that you want this Buffer to join by cloning an element. This
82    /// can be used by operations like join to tell them that they should clone
83    /// from the buffer instead of consuming from it.
84    ///
85    /// This can only be used for message types that support [`Clone`].
86    pub fn join_by_cloning(self) -> CloneFromBuffer<T>
87    where
88        T: Clone,
89    {
90        CloneFromBuffer::new(self.location)
91    }
92
93    /// Get an input slot for this buffer.
94    pub fn input_slot(self) -> InputSlot<T> {
95        InputSlot::new(self.scope(), self.id())
96    }
97
98    /// Get the entity ID of the buffer.
99    pub fn id(&self) -> Entity {
100        self.location.source
101    }
102
103    /// Get the ID of the workflow that the buffer is associated with.
104    pub fn scope(&self) -> Entity {
105        self.location.scope
106    }
107
108    /// Get general information about the buffer.
109    pub fn location(&self) -> BufferLocation {
110        self.location
111    }
112}
113
114impl<T> Clone for Buffer<T> {
115    fn clone(&self) -> Self {
116        *self
117    }
118}
119
120impl<T> Copy for Buffer<T> {}
121
122/// The general identifying information for a buffer to locate it within the
123/// world. This does not indicate anything about the type of messages that the
124/// buffer can contain.
125#[derive(Clone, Copy, Debug)]
126pub struct BufferLocation {
127    /// The entity ID of the buffer.
128    pub scope: Entity,
129    /// The ID of the workflow that the buffer is associated with.
130    pub source: Entity,
131}
132
133#[derive(Clone)]
134pub struct CloneFromBuffer<T: Clone + Send + Sync + 'static> {
135    location: BufferLocation,
136    _ignore: std::marker::PhantomData<fn(T)>,
137}
138
139impl<T: Clone + Send + Sync + 'static> Copy for CloneFromBuffer<T> {}
140
141impl<T: Clone + Send + Sync + 'static> CloneFromBuffer<T> {
142    /// Get an input slot for this buffer.
143    pub fn input_slot(self) -> InputSlot<T> {
144        InputSlot::new(self.scope(), self.id())
145    }
146
147    /// Get the entity ID of the buffer.
148    pub fn id(&self) -> Entity {
149        self.location.source
150    }
151
152    /// Get the ID of the workflow that the buffer is associated with.
153    pub fn scope(&self) -> Entity {
154        self.location.scope
155    }
156
157    /// Get general information about the buffer.
158    pub fn location(&self) -> BufferLocation {
159        self.location
160    }
161
162    /// Specify that you want this Buffer to join by pulling an element. This
163    /// is the default behavior of a Buffer, so you don't generally need to call
164    /// this method, but you can use it to change from the join-by-cloning
165    /// setting back to join-by-pulling.
166    #[must_use]
167    pub fn join_by_pulling(self) -> Buffer<T> {
168        Buffer {
169            location: self.location,
170            _ignore: Default::default(),
171        }
172    }
173
174    fn new(location: BufferLocation) -> Self {
175        Self::register_clone_for_join();
176        Self {
177            location,
178            _ignore: Default::default(),
179        }
180    }
181
182    /// This function ensures that [`AnyBuffer`] can be downcast into a
183    /// [`CloneFromBuffer`] and that it can correctly transfer any Cloning join
184    /// behavior if it gets downcast to a [`FetchFromBuffer`].
185    pub fn register_clone_for_join() {
186        static REGISTER_CLONE: OnceLock<Mutex<HashMap<TypeId, ()>>> = OnceLock::new();
187        let register_clone = REGISTER_CLONE.get_or_init(|| Mutex::default());
188
189        // TODO(@mxgrey): Consider whether there is a way to avoid needing all
190        // these mutex locks and hashmap lookups every time we create a CloneFromBuffer.
191        let mut register_mut = register_clone.lock().unwrap();
192        register_mut.entry(TypeId::of::<T>()).or_insert_with(|| {
193            let interface = AnyBuffer::interface_for::<T>();
194            interface.register_cloning(
195                clone_for_any_join::<T>,
196                &(clone_for_join::<T> as FetchFromBufferFn<T>),
197            );
198            interface.register_buffer_downcast(
199                TypeId::of::<CloneFromBuffer<T>>(),
200                Box::new(|buffer: AnyBuffer| {
201                    Ok(Box::new(CloneFromBuffer::<T>::new(buffer.location)))
202                }),
203            );
204        });
205    }
206}
207
208fn clone_for_any_join<T: 'static + Send + Sync + Clone>(
209    entity_ref: &EntityRef,
210    session: Entity,
211) -> Result<AnyMessageBox, OperationError> {
212    // In general we expect pulling to imply pulling the oldest since the most
213    // typical pattern when information is being pulled from a source would be
214    // FIFO. It's very unlikely that someone pulling information would want the
215    // order that they're pulling data to be backwards. However for cloning it's
216    // much less obvious: Would someone want to clone the oldest piece of data
217    // or the newest?
218    //
219    // For now we will assume that If the user is cloning information without
220    // removing it, then they are leaving the old data in there to act as a
221    // history, and therefore when cloning they will always want the newest.
222    //
223    // TODO(@mxgrey): Allow users to set whether pull and/or clone should
224    // be operating on the oldest data or the newest, by default. Also allow
225    // the join operations themselves override this, e.g. putting a setting
226    // into BufferLocation to change which side is being pulled from.
227    entity_ref
228        .clone_from_buffer::<T>(session)
229        .map(to_any_message)
230}
231
232impl<T: Clone + Send + Sync> From<CloneFromBuffer<T>> for Buffer<T> {
233    fn from(value: CloneFromBuffer<T>) -> Self {
234        Buffer {
235            location: value.location,
236            _ignore: Default::default(),
237        }
238    }
239}
240
/// Settings to describe the behavior of a buffer.
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Default, Clone, Copy, Debug)]
pub struct BufferSettings {
    retention: RetentionPolicy,
}

impl BufferSettings {
    /// Build buffer settings that use the given retention policy.
    pub fn new(retention: RetentionPolicy) -> Self {
        Self { retention }
    }

    /// Settings that retain only the newest `n` items
    /// ([`RetentionPolicy::KeepLast`]`(n)`).
    pub fn keep_last(n: usize) -> Self {
        Self {
            retention: RetentionPolicy::KeepLast(n),
        }
    }

    /// Settings that retain only the first `n` items that arrive
    /// ([`RetentionPolicy::KeepFirst`]`(n)`).
    pub fn keep_first(n: usize) -> Self {
        Self {
            retention: RetentionPolicy::KeepFirst(n),
        }
    }

    /// Settings that never limit how many items are retained
    /// ([`RetentionPolicy::KeepAll`]).
    pub fn keep_all() -> Self {
        Self {
            retention: RetentionPolicy::KeepAll,
        }
    }

    /// The retention policy of the buffer.
    pub fn retention(&self) -> RetentionPolicy {
        let Self { retention } = *self;
        retention
    }

    /// Mutable access to the retention policy of the buffer.
    pub fn retention_mut(&mut self) -> &mut RetentionPolicy {
        let Self { retention } = self;
        retention
    }
}

/// Describe how data within a buffer gets retained. Most mechanisms that pull
/// data from a buffer will remove the oldest item in the buffer, so this policy
/// is for dealing with situations where items are being stored faster than they
/// are being pulled.
///
/// The default value is KeepLast(1).
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum RetentionPolicy {
    /// Keep the last N items that were stored into the buffer. Once the limit
    /// is reached, the oldest item will be removed any time a new item arrives.
    KeepLast(usize),
    /// Keep the first N items that are stored into the buffer. Once the limit
    /// is reached, any new item that arrives will be discarded.
    KeepFirst(usize),
    /// Do not limit how many items can be stored in the buffer.
    KeepAll,
}

impl Default for RetentionPolicy {
    // Unless configured otherwise, a buffer retains only its single newest item.
    fn default() -> Self {
        RetentionPolicy::KeepLast(1)
    }
}
312
313/// This key can unlock access to the contents of a buffer by passing it into
314/// [`BufferAccess`] or [`BufferAccessMut`].
315///
316/// To obtain a `BufferKey`, use [`Chain::with_access`][1], or [`listen`][2].
317///
318/// [1]: crate::Chain::with_access
319/// [2]: crate::Accessible::listen
320pub struct BufferKey<T> {
321    tag: BufferKeyTag,
322    _ignore: std::marker::PhantomData<fn(T)>,
323}
324
325impl<T> Clone for BufferKey<T> {
326    fn clone(&self) -> Self {
327        Self {
328            tag: self.tag.clone(),
329            _ignore: Default::default(),
330        }
331    }
332}
333
334impl<T> BufferKey<T> {
335    /// The buffer ID of this key.
336    pub fn buffer(&self) -> Entity {
337        self.tag.buffer
338    }
339
340    /// The session that this key belongs to.
341    pub fn session(&self) -> Entity {
342        self.tag.session
343    }
344
345    pub fn tag(&self) -> &BufferKeyTag {
346        &self.tag
347    }
348}
349
350impl<T: 'static + Send + Sync> BufferKeyLifecycle for BufferKey<T> {
351    type TargetBuffer = Buffer<T>;
352
353    fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self {
354        BufferKey {
355            tag: builder.make_tag(buffer.id()),
356            _ignore: Default::default(),
357        }
358    }
359
360    fn is_in_use(&self) -> bool {
361        self.tag.is_in_use()
362    }
363
364    fn deep_clone(&self) -> Self {
365        Self {
366            tag: self.tag.deep_clone(),
367            _ignore: Default::default(),
368        }
369    }
370}
371
372impl<T> std::fmt::Debug for BufferKey<T> {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        f.debug_struct("BufferKey")
375            .field("message_type_name", &std::any::type_name::<T>())
376            .field("tag", &self.tag)
377            .finish()
378    }
379}
380
381/// The identifying information for a buffer key. This does not indicate
382/// anything about the type of messages that the buffer can contain.
383#[derive(Clone)]
384pub struct BufferKeyTag {
385    pub buffer: Entity,
386    pub session: Entity,
387    pub accessor: Entity,
388    pub lifecycle: Option<Arc<BufferAccessLifecycle>>,
389}
390
391impl BufferKeyTag {
392    pub fn is_in_use(&self) -> bool {
393        self.lifecycle.as_ref().is_some_and(|l| l.is_in_use())
394    }
395
396    pub fn deep_clone(&self) -> Self {
397        let mut deep = self.clone();
398        deep.lifecycle = self
399            .lifecycle
400            .as_ref()
401            .map(|l| Arc::new(l.as_ref().clone()));
402        deep
403    }
404}
405
406impl std::fmt::Debug for BufferKeyTag {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        f.debug_struct("BufferKeyTag")
409            .field("buffer", &self.buffer)
410            .field("session", &self.session)
411            .field("accessor", &self.accessor)
412            .field("in_use", &self.is_in_use())
413            .finish()
414    }
415}
416
417/// This system parameter lets you get read-only access to a buffer that exists
418/// within a workflow. Use a [`BufferKey`] to unlock the access.
419///
420/// See [`BufferAccessMut`] for mutable access.
421///
422/// # Examples
423///
424/// ```
425/// use crossflow::{prelude::*, testing::*};
426///
427/// fn get_largest_value(
428///     In(input): In<((), BufferKey<i32>)>,
429///     access: BufferAccess<i32>,
430/// ) -> Option<i32> {
431///     let access = access.get(&input.1).ok()?;
432///     access.iter().max().cloned()
433/// }
434///
435/// fn push_values(
436///     In(input): In<(Vec<i32>, BufferKey<i32>)>,
437///     mut access: BufferAccessMut<i32>,
438/// ) {
439///     let Ok(mut access) = access.get_mut(&input.1) else {
440///         return;
441///     };
442///
443///     for value in input.0 {
444///         access.push(value);
445///     }
446/// }
447///
448/// let mut context = TestingContext::minimal_plugins();
449///
450/// let workflow = context.spawn_io_workflow(|scope, builder| {
451///     let buffer = builder.create_buffer(BufferSettings::keep_all());
452///     builder
453///         .chain(scope.start)
454///         .with_access(buffer)
455///         .then(push_values.into_blocking_callback())
456///         .with_access(buffer)
457///         .then(get_largest_value.into_blocking_callback())
458///         .connect(scope.terminate);
459/// });
460///
461/// let r = context.resolve_request(vec![-3, 2, 10], workflow);
462/// assert_eq!(r, Some(10));
463/// ```
464#[derive(SystemParam)]
465pub struct BufferAccess<'w, 's, T>
466where
467    T: 'static + Send + Sync,
468{
469    query: Query<'w, 's, &'static BufferStorage<T>>,
470}
471
472impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> {
473    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
474        let session = key.session();
475        self.query
476            .get(key.buffer())
477            .map(|storage| BufferView { storage, session })
478    }
479
480    pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
481        self.get(key).ok().map(|view| view.newest()).flatten()
482    }
483}
484
485/// This system parameter lets you get mutable access to a buffer that exists
486/// within a workflow. Use a [`BufferKey`] to unlock the access.
487///
488/// See [`BufferAccess`] for read-only access.
489///
490/// # Examples
491///
492/// ```
493/// use crossflow::{prelude::*, testing::*};
494///
495/// fn get_largest_value(
496///     In(input): In<((), BufferKey<i32>)>,
497///     access: BufferAccess<i32>,
498/// ) -> Option<i32> {
499///     let access = access.get(&input.1).ok()?;
500///     access.iter().max().cloned()
501/// }
502///
503/// fn push_values(
504///     In(input): In<(Vec<i32>, BufferKey<i32>)>,
505///     mut access: BufferAccessMut<i32>,
506/// ) {
507///     let Ok(mut access) = access.get_mut(&input.1) else {
508///         return;
509///     };
510///
511///     for value in input.0 {
512///         access.push(value);
513///     }
514/// }
515///
516/// let mut context = TestingContext::minimal_plugins();
517///
518/// let workflow = context.spawn_io_workflow(|scope, builder| {
519///     let buffer = builder.create_buffer(BufferSettings::keep_all());
520///     builder
521///         .chain(scope.start)
522///         .with_access(buffer)
523///         .then(push_values.into_blocking_callback())
524///         .with_access(buffer)
525///         .then(get_largest_value.into_blocking_callback())
526///         .connect(scope.terminate);
527/// });
528///
529/// let r = context.resolve_request(vec![-3, 2, 10], workflow);
530/// assert_eq!(r, Some(10));
531/// ```
532#[derive(SystemParam)]
533pub struct BufferAccessMut<'w, 's, T>
534where
535    T: 'static + Send + Sync,
536{
537    query: Query<'w, 's, &'static mut BufferStorage<T>>,
538    commands: Commands<'w, 's>,
539}
540
541impl<'w, 's, T> BufferAccessMut<'w, 's, T>
542where
543    T: 'static + Send + Sync,
544{
545    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
546        let session = key.session();
547        self.query
548            .get(key.buffer())
549            .map(|storage| BufferView { storage, session })
550    }
551
552    pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
553        self.get(key).ok().map(|view| view.newest()).flatten()
554    }
555
556    pub fn get_mut<'a>(
557        &'a mut self,
558        key: &BufferKey<T>,
559    ) -> Result<BufferMut<'w, 's, 'a, T>, QueryEntityError> {
560        let buffer = key.buffer();
561        let session = key.session();
562        let accessor = key.tag.accessor;
563        self.query
564            .get_mut(key.buffer())
565            .map(|storage| BufferMut::new(storage, buffer, session, accessor, &mut self.commands))
566    }
567}
568
569/// This trait allows [`World`] to give you access to any buffer using a [`BufferKey`]
570pub trait BufferWorldAccess {
571    /// Call this to get read-only access to a buffer from a [`World`].
572    ///
573    /// Alternatively you can use [`BufferAccess`] as a regular bevy system parameter,
574    /// which does not need direct world access.
575    fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
576    where
577        T: 'static + Send + Sync;
578
579    /// Call this to get read-only access to the gate of a buffer from a [`World`].
580    fn buffer_gate_view(
581        &self,
582        key: impl Into<AnyBufferKey>,
583    ) -> Result<BufferGateView<'_>, BufferError>;
584
585    /// Call this to get mutable access to a buffer.
586    ///
587    /// Pass in a callback that will receive [`BufferMut`], allowing it to view
588    /// and modify the contents of the buffer.
589    fn buffer_mut<T, U>(
590        &mut self,
591        key: &BufferKey<T>,
592        f: impl FnOnce(BufferMut<T>) -> U,
593    ) -> Result<U, BufferError>
594    where
595        T: 'static + Send + Sync;
596
597    /// Call this to get mutable access to the gate of a buffer.
598    ///
599    /// Pass in a callback that will receive [`BufferGateMut`], allowing it to
600    /// view and modify the gate of the buffer.
601    fn buffer_gate_mut<U>(
602        &mut self,
603        key: impl Into<AnyBufferKey>,
604        f: impl FnOnce(BufferGateMut) -> U,
605    ) -> Result<U, BufferError>;
606}
607
608impl BufferWorldAccess for World {
609    fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
610    where
611        T: 'static + Send + Sync,
612    {
613        let buffer_ref = self
614            .get_entity(key.tag.buffer)
615            .map_err(|_| BufferError::BufferMissing)?;
616        let storage = buffer_ref
617            .get::<BufferStorage<T>>()
618            .ok_or(BufferError::BufferMissing)?;
619        Ok(BufferView {
620            storage,
621            session: key.tag.session,
622        })
623    }
624
625    fn buffer_gate_view(
626        &self,
627        key: impl Into<AnyBufferKey>,
628    ) -> Result<BufferGateView<'_>, BufferError> {
629        let key: AnyBufferKey = key.into();
630        let buffer_ref = self
631            .get_entity(key.tag.buffer)
632            .or(Err(BufferError::BufferMissing))?;
633        let gate = buffer_ref
634            .get::<GateState>()
635            .ok_or(BufferError::BufferMissing)?;
636        Ok(BufferGateView {
637            gate,
638            session: key.tag.session,
639        })
640    }
641
642    fn buffer_mut<T, U>(
643        &mut self,
644        key: &BufferKey<T>,
645        f: impl FnOnce(BufferMut<T>) -> U,
646    ) -> Result<U, BufferError>
647    where
648        T: 'static + Send + Sync,
649    {
650        let mut state = SystemState::<BufferAccessMut<T>>::new(self);
651        let mut buffer_access_mut = state.get_mut(self);
652        let buffer_mut = buffer_access_mut
653            .get_mut(key)
654            .map_err(|_| BufferError::BufferMissing)?;
655        Ok(f(buffer_mut))
656    }
657
658    fn buffer_gate_mut<U>(
659        &mut self,
660        key: impl Into<AnyBufferKey>,
661        f: impl FnOnce(BufferGateMut) -> U,
662    ) -> Result<U, BufferError> {
663        let mut state = SystemState::<BufferGateAccessMut>::new(self);
664        let mut buffer_gate_access_mut = state.get_mut(self);
665        let buffer_mut = buffer_gate_access_mut
666            .get_mut(key)
667            .map_err(|_| BufferError::BufferMissing)?;
668        Ok(f(buffer_mut))
669    }
670}
671
672/// Access to view a buffer that exists inside a workflow.
673pub struct BufferView<'a, T>
674where
675    T: 'static + Send + Sync,
676{
677    storage: &'a BufferStorage<T>,
678    session: Entity,
679}
680
681impl<'a, T> BufferView<'a, T>
682where
683    T: 'static + Send + Sync,
684{
685    /// Iterate over the contents in the buffer
686    pub fn iter(&self) -> IterBufferView<'a, T> {
687        self.storage.iter(self.session)
688    }
689
690    /// Borrow the oldest item in the buffer.
691    pub fn oldest(&self) -> Option<&'a T> {
692        self.storage.oldest(self.session)
693    }
694
695    /// Borrow the newest item in the buffer.
696    pub fn newest(&self) -> Option<&'a T> {
697        self.storage.newest(self.session)
698    }
699
700    /// Borrow an item from the buffer. Index 0 is the oldest item in the buffer
701    /// with the highest index being the newest item in the buffer.
702    pub fn get(&self, index: usize) -> Option<&'a T> {
703        self.storage.get(self.session, index)
704    }
705
706    /// How many items are in the buffer?
707    pub fn len(&self) -> usize {
708        self.storage.count(self.session)
709    }
710
711    /// Check if the buffer is empty.
712    pub fn is_empty(&self) -> bool {
713        self.len() == 0
714    }
715}
716
717/// Access to mutate a buffer that exists inside a workflow.
718pub struct BufferMut<'w, 's, 'a, T>
719where
720    T: 'static + Send + Sync,
721{
722    storage: Mut<'a, BufferStorage<T>>,
723    buffer: Entity,
724    session: Entity,
725    accessor: Option<Entity>,
726    commands: &'a mut Commands<'w, 's>,
727    modified: bool,
728}
729
730impl<'w, 's, 'a, T> BufferMut<'w, 's, 'a, T>
731where
732    T: 'static + Send + Sync,
733{
734    /// When you make a modification using this `BufferMut`, anything listening
735    /// to the buffer will be notified about the update. This can create
736    /// unintentional infinite loops where a node in the workflow wakes itself
737    /// up every time it runs because of a modification it makes to a buffer.
738    ///
739    /// By default this closed loop is disabled by keeping track of which
740    /// listener created the key that's being used to modify the buffer, and
741    /// then skipping that listener when notifying about the modification.
742    ///
743    /// In some cases a key can be used far downstream of the listener. In that
744    /// case, there may be nodes downstream of the listener that do want to be
745    /// woken up by the modification. Use this function to allow that closed
746    /// loop to happen. It will be up to you to prevent the closed loop from
747    /// being a problem.
748    pub fn allow_closed_loops(mut self) -> Self {
749        self.accessor = None;
750        self
751    }
752
753    /// Iterate over the contents in the buffer.
754    pub fn iter(&self) -> IterBufferView<'_, T> {
755        self.storage.iter(self.session)
756    }
757
758    /// Look at the oldest item in the buffer.
759    pub fn oldest(&self) -> Option<&T> {
760        self.storage.oldest(self.session)
761    }
762
763    /// Look at the newest item in the buffer.
764    pub fn newest(&self) -> Option<&T> {
765        self.storage.newest(self.session)
766    }
767
768    /// Borrow an item from the buffer. Index 0 is the oldest item in the buffer
769    /// with the highest index being the newest item in the buffer.
770    pub fn get(&self, index: usize) -> Option<&T> {
771        self.storage.get(self.session, index)
772    }
773
774    /// How many items are in the buffer?
775    pub fn len(&self) -> usize {
776        self.storage.count(self.session)
777    }
778
779    /// Check if the buffer is empty.
780    pub fn is_empty(&self) -> bool {
781        self.len() == 0
782    }
783
784    /// Iterate over mutable borrows of the contents in the buffer.
785    pub fn iter_mut(&mut self) -> IterBufferMut<'_, T> {
786        self.modified = true;
787        self.storage.iter_mut(self.session)
788    }
789
790    /// Modify the oldest item in the buffer.
791    pub fn oldest_mut(&mut self) -> Option<&mut T> {
792        self.modified = true;
793        self.storage.oldest_mut(self.session)
794    }
795
796    /// Modify the newest item in the buffer.
797    pub fn newest_mut(&mut self) -> Option<&mut T> {
798        self.modified = true;
799        self.storage.newest_mut(self.session)
800    }
801
802    /// Modify the newest item in the buffer or create a default-initialized
803    /// item to modify if the buffer was empty.
804    ///
805    /// This may fail to provide a mutable borrow if the buffer was already
806    /// expired or if the buffer capacity was zero.
807    pub fn newest_mut_or_default(&mut self) -> Option<&mut T>
808    where
809        T: Default,
810    {
811        self.newest_mut_or_else(|| T::default())
812    }
813
814    /// Modify the newest item in the buffer or initialize an item if the
815    /// buffer was empty.
816    ///
817    /// This may fail to provide a mutable borrow if the buffer was already
818    /// expired or if the buffer capacity was zero.
819    pub fn newest_mut_or_else(&mut self, f: impl FnOnce() -> T) -> Option<&mut T> {
820        self.modified = true;
821        self.storage.newest_mut_or_else(self.session, f)
822    }
823
824    /// Modify an item in the buffer. Index 0 is the oldest item in the buffer
825    /// with the highest index being the newest item in the buffer.
826    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
827        self.modified = true;
828        self.storage.get_mut(self.session, index)
829    }
830
831    /// Drain items out of the buffer
832    pub fn drain<R>(&mut self, range: R) -> DrainBuffer<'_, T>
833    where
834        R: RangeBounds<usize>,
835    {
836        self.modified = true;
837        self.storage.drain(self.session, range)
838    }
839
840    /// Pull the oldest item from the buffer.
841    pub fn pull(&mut self) -> Option<T> {
842        self.modified = true;
843        self.storage.pull(self.session)
844    }
845
846    /// Pull the item that was most recently put into the buffer (instead of
847    /// the oldest, which is what [`Self::pull`] gives).
848    pub fn pull_newest(&mut self) -> Option<T> {
849        self.modified = true;
850        self.storage.pull_newest(self.session)
851    }
852
853    /// Push a new value into the buffer. If the buffer is at its limit, this
854    /// will return the value that needed to be removed.
855    pub fn push(&mut self, value: T) -> Option<T> {
856        self.modified = true;
857        self.storage.push(self.session, value)
858    }
859
860    /// Push a value into the buffer as if it is the oldest value of the buffer.
861    /// If the buffer is at its limit, this will return the value that needed to
862    /// be removed.
863    pub fn push_as_oldest(&mut self, value: T) -> Option<T> {
864        self.modified = true;
865        self.storage.push_as_oldest(self.session, value)
866    }
867
868    /// Trigger the listeners for this buffer to wake up even if nothing in the
869    /// buffer has changed. This could be used for timers or timeout elements
870    /// in a workflow.
871    pub fn pulse(&mut self) {
872        self.modified = true;
873    }
874
875    fn new(
876        storage: Mut<'a, BufferStorage<T>>,
877        buffer: Entity,
878        session: Entity,
879        accessor: Entity,
880        commands: &'a mut Commands<'w, 's>,
881    ) -> Self {
882        Self {
883            storage,
884            buffer,
885            session,
886            accessor: Some(accessor),
887            commands,
888            modified: false,
889        }
890    }
891}
892
893impl<'w, 's, 'a, T> Drop for BufferMut<'w, 's, 'a, T>
894where
895    T: 'static + Send + Sync,
896{
897    fn drop(&mut self) {
898        if self.modified {
899            self.commands.queue(NotifyBufferUpdate::new(
900                self.buffer,
901                self.session,
902                self.accessor,
903            ));
904        }
905    }
906}
907
908#[derive(ThisError, Debug, Clone)]
909pub enum BufferError {
910    #[error("The key was unable to identify a buffer")]
911    BufferMissing,
912}
913
#[cfg(test)]
mod tests {
    use crate::{AddBufferToMap, Gate, prelude::*, testing::*};
    use std::future::Future;

    /// Exercises reading from and pulling out of buffers through `BufferKey`s
    /// obtained in several ways: by listening to buffers, by attaching access
    /// with `with_access`, and via `Builder::create_buffer_access`.
    #[test]
    fn test_buffer_key_access() {
        let mut context = TestingContext::minimal_plugins();

        let add_buffers_by_pull_cb = add_buffers_by_pull.into_blocking_callback();
        let add_from_buffer_cb = add_from_buffer.into_blocking_callback();
        let multiply_buffers_by_copy_cb = multiply_buffers_by_copy.into_blocking_callback();

        // Listening to the unzipped pair yields a pair of buffer keys; the
        // callback reads the oldest value of each and multiplies them.
        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
            builder
                .chain(scope.start)
                .unzip()
                .listen(builder)
                .then(multiply_buffers_by_copy_cb)
                .connect(scope.terminate);
        });

        let r = context.resolve_request((2.0, 3.0), workflow);
        assert_eq!(r, 6.0);

        // Same shape, but the callback pulls from both buffers. It yields
        // `None` (which is disposed) until both buffers hold a value.
        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
            builder
                .chain(scope.start)
                .unzip()
                .listen(builder)
                .then(add_buffers_by_pull_cb)
                .dispose_on_none()
                .connect(scope.terminate);
        });

        let r = context.resolve_request((4.0, 5.0), workflow);
        assert_eq!(r, 9.0);

        // One half of the input is stored in a buffer; the other half is
        // paired with a key to that buffer and repeatedly tries to pull from it.
        let workflow =
            context.spawn_io_workflow(|scope: Scope<(f64, f64), Result<f64, f64>>, builder| {
                let (branch_to_adder, branch_to_buffer) = builder.chain(scope.start).unzip();
                let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
                builder.connect(branch_to_buffer, buffer.input_slot());

                let adder_node = builder
                    .chain(branch_to_adder)
                    .with_access(buffer)
                    .then_node(add_from_buffer_cb.clone());

                builder.chain(adder_node.output).fork_result(
                    // If the buffer had an item in it, we send it to another
                    // node that tries to pull a second time (we expect the
                    // buffer to be empty this second time) and then
                    // terminates.
                    |chain| {
                        chain
                            .with_access(buffer)
                            .then(add_from_buffer_cb.clone())
                            .connect(scope.terminate)
                    },
                    // If the buffer was empty, keep looping back until there
                    // is a value available.
                    |chain| chain.with_access(buffer).connect(adder_node.input),
                );
            });

        // 2.0 + 3.0 succeeds once, then the second pull finds the buffer
        // empty and hands back its input as `Err(5.0)`.
        let r = context.resolve_request((2.0, 3.0), workflow);
        assert!(r.is_err_and(|n| n == 5.0));

        // Same as previous test, but using Builder::create_buffer_access instead
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let (branch_to_adder, branch_to_buffer) = builder.chain(scope.start).unzip();
            let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
            builder.connect(branch_to_buffer, buffer.input_slot());

            let access = builder.create_buffer_access(buffer);
            builder.connect(branch_to_adder, access.input);
            builder
                .chain(access.output)
                .then(add_from_buffer_cb.clone())
                .fork_result(
                    |ok| {
                        let (output, builder) = ok.unpack();
                        let second_access = builder.create_buffer_access(buffer);
                        builder.connect(output, second_access.input);
                        builder
                            .chain(second_access.output)
                            .then(add_from_buffer_cb.clone())
                            .connect(scope.terminate);
                    },
                    |err| err.connect(access.input),
                );
        });

        let r = context.resolve_request((2.0, 3.0), workflow);
        assert!(r.is_err_and(|n| n == 5.0));
    }

    /// Pulls one value out of the keyed buffer and adds it to `lhs`.
    ///
    /// Returns `Err(lhs)` if the key does not resolve to a buffer or the
    /// buffer is empty, so the workflow can route the original value onward.
    fn add_from_buffer(
        In((lhs, key)): In<(f64, BufferKey<f64>)>,
        mut access: BufferAccessMut<f64>,
    ) -> Result<f64, f64> {
        let rhs = access.get_mut(&key).map_err(|_| lhs)?.pull().ok_or(lhs)?;
        Ok(lhs + rhs)
    }

    /// Multiplies the oldest values of two buffers through read-only access.
    ///
    /// Panics if either key is invalid or either buffer is empty.
    fn multiply_buffers_by_copy(
        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
        access: BufferAccess<f64>,
    ) -> f64 {
        *access.get(&key_a).unwrap().oldest().unwrap()
            * *access.get(&key_b).unwrap().oldest().unwrap()
    }

    /// Pulls one value from each of two buffers and returns their sum.
    ///
    /// Returns `None` without pulling anything unless both buffers currently
    /// hold a value, so neither buffer is drained while the other is empty.
    fn add_buffers_by_pull(
        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
        mut access: BufferAccessMut<f64>,
    ) -> Option<f64> {
        if access.get(&key_a).unwrap().is_empty() {
            return None;
        }

        if access.get(&key_b).unwrap().is_empty() {
            return None;
        }

        // `lhs` comes from the first key and `rhs` from the second, matching
        // the order of the input pair.
        let lhs = access.get_mut(&key_a).unwrap().pull().unwrap();
        let rhs = access.get_mut(&key_b).unwrap().pull().unwrap();
        Some(lhs + rhs)
    }

    /// Checks that buffer keys stay alive while they travel through blocking
    /// and async callbacks, including after a forced disposal. A workflow
    /// must finish when some node pushes the register into the buffer, and
    /// must fail when none ever does. Presumably it fails because no path to
    /// termination remains once every key has been dropped.
    #[test]
    fn test_buffer_key_lifecycle() {
        let mut context = TestingContext::minimal_plugins();

        // Test a workflow where each node in a long chain repeatedly accesses
        // a buffer and might be the one to push a value into it.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());

            // The only path to termination is from listening to the buffer.
            builder
                .listen(buffer)
                .then(pull_register_from_buffer.into_blocking_callback())
                .dispose_on_none()
                .connect(scope.terminate);

            let decrement_register_cb = decrement_register.into_blocking_callback();
            let async_decrement_register_cb = async_decrement_register.as_callback();
            builder
                .chain(scope.start)
                .with_access(buffer)
                .then(decrement_register_cb.clone())
                .with_access(buffer)
                .then(async_decrement_register_cb.clone())
                .dispose_on_none()
                .with_access(buffer)
                .then(decrement_register_cb.clone())
                .with_access(buffer)
                .then(async_decrement_register_cb)
                .unused();
        });

        // The chain has four decrementing nodes. A start value of up to 3
        // reaches zero before the last node, which then pushes the register
        // into the buffer. Larger values never reach the buffer.
        run_register_test(workflow, 0, true, &mut context);
        run_register_test(workflow, 1, true, &mut context);
        run_register_test(workflow, 2, true, &mut context);
        run_register_test(workflow, 3, true, &mut context);
        run_register_test(workflow, 4, false, &mut context);
        run_register_test(workflow, 5, false, &mut context);
        run_register_test(workflow, 6, false, &mut context);

        // Test a workflow where only one buffer accessor node is used, but the
        // key is passed through a long chain in the workflow, with a disposal
        // being forced as well.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());

            // The only path to termination is from listening to the buffer.
            builder
                .listen(buffer)
                .then(pull_register_from_buffer.into_blocking_callback())
                .dispose_on_none()
                .connect(scope.terminate);

            let decrement_register_and_pass_keys_cb =
                decrement_register_and_pass_keys.into_blocking_callback();
            let async_decrement_register_and_pass_keys_cb =
                async_decrement_register_and_pass_keys.as_callback();
            let (loose_end, dead_end): (_, Output<Option<Register>>) = builder
                .chain(scope.start)
                .with_access(buffer)
                .then(decrement_register_and_pass_keys_cb.clone())
                .then(async_decrement_register_and_pass_keys_cb.clone())
                .dispose_on_none()
                .map_block(|v| (v, None))
                .unzip();

            // Force the workflow to trigger a disposal while the key is still in flight
            builder.chain(dead_end).dispose_on_none().unused();

            builder
                .chain(loose_end)
                .then(async_decrement_register_and_pass_keys_cb)
                .dispose_on_none()
                .then(decrement_register_and_pass_keys_cb)
                .unused();
        });

        // This variant also has four decrementing nodes, so the same cutoff applies.
        run_register_test(workflow, 0, true, &mut context);
        run_register_test(workflow, 1, true, &mut context);
        run_register_test(workflow, 2, true, &mut context);
        run_register_test(workflow, 3, true, &mut context);
        run_register_test(workflow, 4, false, &mut context);
        run_register_test(workflow, 5, false, &mut context);
        run_register_test(workflow, 6, false, &mut context);
    }

    /// Runs `workflow` on a fresh `Register` that starts at `initial_value`.
    ///
    /// When `expect_success` is true, the request must succeed and return a
    /// register whose `in_slot` is 0 and whose `out_slot` equals
    /// `initial_value`. Otherwise the request must return an error.
    fn run_register_test(
        workflow: Service<Register, Register>,
        initial_value: u64,
        expect_success: bool,
        context: &mut TestingContext,
    ) {
        let r = context.try_resolve_request(Register::new(initial_value), workflow, ());
        if expect_success {
            assert!(r.unwrap().finished_with(initial_value));
        } else {
            assert!(r.is_err());
        }
    }

    // We use this struct to keep track of operations that have occurred in the
    // test workflow. Values from in_slot get moved to out_slot until in_slot
    // reaches 0, then the whole struct gets put into a buffer where a listener
    // will then send it to the terminal node.
    #[derive(Clone, Copy, Debug)]
    struct Register {
        in_slot: u64,
        out_slot: u64,
    }

    impl Register {
        /// Creates a register with `start_from` units waiting in `in_slot`.
        fn new(start_from: u64) -> Self {
            Self {
                in_slot: start_from,
                out_slot: 0,
            }
        }

        /// True when every unit has been moved and `out_slot` holds exactly `out_slot`.
        fn finished_with(&self, out_slot: u64) -> bool {
            self.in_slot == 0 && self.out_slot == out_slot
        }
    }

    /// Pulls the next register from the keyed buffer.
    ///
    /// Returns `None` if the key is invalid or the buffer has nothing to pull.
    fn pull_register_from_buffer(
        In(key): In<BufferKey<Register>>,
        mut access: BufferAccessMut<Register>,
    ) -> Option<Register> {
        access.get_mut(&key).ok()?.pull()
    }

    /// Moves one unit from `in_slot` to `out_slot`.
    ///
    /// If `in_slot` is already zero, the register is instead pushed into the
    /// keyed buffer (for the listener to pick up) and returned unchanged.
    fn decrement_register(
        In((mut register, key)): In<(Register, BufferKey<Register>)>,
        mut access: BufferAccessMut<Register>,
    ) -> Register {
        if register.in_slot == 0 {
            access.get_mut(&key).unwrap().push(register);
            return register;
        }

        register.in_slot -= 1;
        register.out_slot += 1;
        register
    }

    /// Same as `decrement_register`, but also returns the key so that
    /// downstream nodes keep holding it.
    fn decrement_register_and_pass_keys(
        In((mut register, key)): In<(Register, BufferKey<Register>)>,
        mut access: BufferAccessMut<Register>,
    ) -> (Register, BufferKey<Register>) {
        if register.in_slot == 0 {
            access.get_mut(&key).unwrap().push(register);
            return (register, key);
        }

        register.in_slot -= 1;
        register.out_slot += 1;
        (register, key)
    }

    /// Async wrapper that runs `decrement_register` as a blocking callback
    /// through the async channel.
    ///
    /// Yields `None` if `request_outcome` returns an error.
    fn async_decrement_register(
        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
    ) -> impl Future<Output = Option<Register>> + use<> {
        async move {
            input
                .channel
                .request_outcome(input.request, decrement_register.into_blocking_callback())
                .await
                .ok()
        }
    }

    /// Async wrapper that runs `decrement_register_and_pass_keys` as a
    /// blocking callback through the async channel.
    ///
    /// Yields `None` if `request_outcome` returns an error.
    fn async_decrement_register_and_pass_keys(
        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
    ) -> impl Future<Output = Option<(Register, BufferKey<Register>)>> + use<> {
        async move {
            input
                .channel
                .request_outcome(
                    input.request,
                    decrement_register_and_pass_keys.into_blocking_callback(),
                )
                .await
                .ok()
        }
    }

    /// Loops a value through a buffer. Each listener notice closes the
    /// buffer's gate, and the service reopens it with its own key. The value
    /// is incremented on every pass until it reaches 5.
    #[test]
    fn test_buffer_key_gate_control() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let service = builder.commands().spawn_service(gate_access_test_open_loop);

            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder.connect(scope.start, buffer.input_slot());
            builder
                .listen(buffer)
                .then_gate_close(buffer)
                .then(service)
                .fork_unzip((
                    |chain: Chain<_>| chain.dispose_on_none().connect(buffer.input_slot()),
                    |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
                ));
        });

        let r = context.resolve_request(0, workflow);
        assert_eq!(r, 5);
    }

    /// Used to verify that when a key is used to open a buffer gate, it will not
    /// trigger the key's listener to wake up again.
    fn gate_access_test_open_loop(
        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
        mut access: BufferAccessMut<u64>,
        mut gate_access: BufferGateAccessMut,
    ) -> (Option<u64>, Option<u64>) {
        // We should never see a spurious wake-up in this service because the
        // gate opening is done by the key of this service.
        let mut buffer = access.get_mut(&key).unwrap();
        let value = buffer.pull().unwrap();

        // The gate should have previously been closed before reaching this
        // service
        let mut gate = gate_access.get_mut(key).unwrap();
        assert_eq!(gate.get(), Gate::Closed);
        // Open the gate, which would normally trigger a notice, but the notice
        // should not come to this service because we're using the key without
        // closed loops allowed.
        gate.open_gate();

        // First slot feeds the value back into the buffer; second slot terminates.
        if value >= 5 {
            (None, Some(value))
        } else {
            (Some(value + 1), None)
        }
    }

    /// Checks that a key with closed loops allowed wakes its own listener.
    ///
    /// Each pull sends `value + 1` back into the buffer only after a delay.
    /// The workflow can only terminate (with 0) if the service runs while
    /// the buffer is empty. That happens when the pull itself wakes the
    /// listener again.
    #[test]
    fn test_closed_loop_key_access() {
        let mut context = TestingContext::minimal_plugins();

        let delay = context.spawn_delay(Duration::from_secs_f32(0.1));

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let service = builder
                .commands()
                .spawn_service(gate_access_test_closed_loop);

            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder.connect(scope.start, buffer.input_slot());
            builder.listen(buffer).then(service).fork_unzip((
                |chain: Chain<_>| {
                    chain
                        .dispose_on_none()
                        .then(delay)
                        .connect(buffer.input_slot())
                },
                |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
            ));
        });

        let r = context.resolve_request(3, workflow);
        assert_eq!(r, 0);
    }

    /// Used to verify that we get spurious wakeups when closed loops are allowed
    fn gate_access_test_closed_loop(
        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
        mut access: BufferAccessMut<u64>,
    ) -> (Option<u64>, Option<u64>) {
        let mut buffer = access.get_mut(&key).unwrap().allow_closed_loops();
        if let Some(value) = buffer.pull() {
            (Some(value + 1), None)
        } else {
            // Reached only when woken with an empty buffer.
            (None, Some(0))
        }
    }

    /// Joins a type-erased `BufferMap` into `JoinByCloneTest`.
    ///
    /// The message is sent only once, while the count is fed back eleven
    /// times. Every join therefore needs the message buffer, which is joined
    /// by cloning, to still hold its value.
    #[test]
    fn test_any_buffer_join_by_clone() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let message_buffer = builder.create_buffer(Default::default()).join_by_cloning();
            let count_buffer = builder.create_buffer(Default::default());
            let (message, count) = builder.chain(scope.start).unzip();
            builder.connect(message, message_buffer.input_slot());
            builder.connect(count, count_buffer.input_slot());

            // Make absolutely sure that the type information has been erased
            // before we assemble the buffer map.
            let any_message_buffer = message_buffer.as_any_buffer();
            let any_count_buffer = count_buffer.as_any_buffer();

            let mut buffer_map = BufferMap::default();
            buffer_map.insert_buffer("message", any_message_buffer);
            buffer_map.insert_buffer("count", any_count_buffer);

            builder
                .try_join::<JoinByCloneTest>(&buffer_map)
                .unwrap()
                .map_block(|joined| {
                    if joined.count < 10 {
                        // Increment the count buffer
                        Err(joined.count + 1)
                    } else {
                        Ok(joined)
                    }
                })
                .fork_result(
                    |ok| ok.connect(scope.terminate),
                    |err| err.connect(count_buffer.input_slot()),
                );
        });

        let r = context.resolve_request((String::from("hello"), 0), workflow);
        assert_eq!(r.count, 10);
        assert_eq!(r.message, "hello");
    }

    /// Target of the join in `test_any_buffer_join_by_clone`; field names
    /// match the `BufferMap` keys.
    #[derive(Joined)]
    struct JoinByCloneTest {
        count: i64,
        message: String,
    }

    /// Returns the largest value in the keyed buffer without removing anything,
    /// or `None` if the key is invalid or the buffer is empty.
    fn get_largest_value(
        In(input): In<((), BufferKey<i32>)>,
        access: BufferAccess<i32>,
    ) -> Option<i32> {
        let access = access.get(&input.1).ok()?;
        access.iter().max().copied()
    }

    /// Pushes every value from the input into the keyed buffer, in order.
    /// Does nothing if the key does not resolve to a buffer.
    fn push_values(In(input): In<(Vec<i32>, BufferKey<i32>)>, mut access: BufferAccessMut<i32>) {
        let Ok(mut access) = access.get_mut(&input.1) else {
            return;
        };

        for value in input.0 {
            access.push(value);
        }
    }

    /// Fills a buffer through one accessor node and reads its maximum
    /// through a second one.
    #[test]
    fn test_buffer_access_example() {
        let mut context = TestingContext::minimal_plugins();

        let workflow = context.spawn_io_workflow(|scope, builder| {
            let buffer = builder.create_buffer(BufferSettings::keep_all());
            builder
                .chain(scope.start)
                .with_access(buffer)
                .then(push_values.into_blocking_callback())
                .with_access(buffer)
                .then(get_largest_value.into_blocking_callback())
                .connect(scope.terminate);
        });

        let r = context.resolve_request(vec![-3, 2, 10], workflow);
        assert_eq!(r.unwrap(), 10);
    }
}