crossflow/
buffer.rs

1/*
2 * Copyright (C) 2024 Open Source Robotics Foundation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *
16*/
17
18use bevy_ecs::{
19    change_detection::Mut,
20    prelude::{Commands, Entity, EntityRef, Query, World},
21    query::QueryEntityError,
22    system::{SystemParam, SystemState},
23};
24
25use std::{
26    any::TypeId,
27    collections::HashMap,
28    ops::RangeBounds,
29    sync::{Arc, Mutex, OnceLock},
30};
31
32use thiserror::Error as ThisError;
33
34use crate::{GateState, InputSlot, NotifyBufferUpdate, OperationError};
35
36mod any_buffer;
37pub use any_buffer::*;
38
39mod buffer_access_lifecycle;
40pub use buffer_access_lifecycle::BufferKeyLifecycle;
41pub(crate) use buffer_access_lifecycle::*;
42
43mod buffer_key_builder;
44pub use buffer_key_builder::*;
45
46mod buffer_gate;
47pub use buffer_gate::*;
48
49mod buffer_map;
50pub use buffer_map::*;
51
52mod buffer_storage;
53pub(crate) use buffer_storage::*;
54
55mod buffering;
56pub use buffering::*;
57
58mod bufferable;
59pub use bufferable::*;
60
61mod manage_buffer;
62pub use manage_buffer::*;
63
64#[cfg(feature = "diagram")]
65mod json_buffer;
66#[cfg(feature = "diagram")]
67pub use json_buffer::*;
68
69mod fetch_from_buffer;
70pub use fetch_from_buffer::*;
71
72/// A buffer is a special type of node within a workflow that is able to store
73/// and release data. When a session is finished, the buffered data from the
74/// session will be automatically cleared.
75pub struct Buffer<T> {
76    pub(crate) location: BufferLocation,
77    pub(crate) _ignore: std::marker::PhantomData<fn(T)>,
78}
79
80impl<T: 'static + Send + Sync> Buffer<T> {
81    /// Specify that you want this Buffer to join by cloning an element. This
82    /// can be used by operations like join to tell them that they should clone
83    /// from the buffer instead of consuming from it.
84    ///
85    /// This can only be used for message types that support [`Clone`].
86    pub fn join_by_cloning(self) -> CloneFromBuffer<T>
87    where
88        T: Clone,
89    {
90        CloneFromBuffer::new(self.location)
91    }
92
93    /// Get an input slot for this buffer.
94    pub fn input_slot(self) -> InputSlot<T> {
95        InputSlot::new(self.scope(), self.id())
96    }
97
98    /// Get the entity ID of the buffer.
99    pub fn id(&self) -> Entity {
100        self.location.source
101    }
102
103    /// Get the ID of the workflow that the buffer is associated with.
104    pub fn scope(&self) -> Entity {
105        self.location.scope
106    }
107
108    /// Get general information about the buffer.
109    pub fn location(&self) -> BufferLocation {
110        self.location
111    }
112}
113
114impl<T> Clone for Buffer<T> {
115    fn clone(&self) -> Self {
116        *self
117    }
118}
119
120impl<T> Copy for Buffer<T> {}
121
122/// The general identifying information for a buffer to locate it within the
123/// world. This does not indicate anything about the type of messages that the
124/// buffer can contain.
125#[derive(Clone, Copy, Debug)]
126pub struct BufferLocation {
127    /// The entity ID of the buffer.
128    pub scope: Entity,
129    /// The ID of the workflow that the buffer is associated with.
130    pub source: Entity,
131}
132
133#[derive(Clone)]
134pub struct CloneFromBuffer<T: Clone + Send + Sync + 'static> {
135    location: BufferLocation,
136    _ignore: std::marker::PhantomData<fn(T)>,
137}
138
139impl<T: Clone + Send + Sync + 'static> Copy for CloneFromBuffer<T> {}
140
141impl<T: Clone + Send + Sync + 'static> CloneFromBuffer<T> {
142    /// Get an input slot for this buffer.
143    pub fn input_slot(self) -> InputSlot<T> {
144        InputSlot::new(self.scope(), self.id())
145    }
146
147    /// Get the entity ID of the buffer.
148    pub fn id(&self) -> Entity {
149        self.location.source
150    }
151
152    /// Get the ID of the workflow that the buffer is associated with.
153    pub fn scope(&self) -> Entity {
154        self.location.scope
155    }
156
157    /// Get general information about the buffer.
158    pub fn location(&self) -> BufferLocation {
159        self.location
160    }
161
162    /// Specify that you want this Buffer to join by pulling an element. This
163    /// is the default behavior of a Buffer, so you don't generally need to call
164    /// this method, but you can use it to change from the join-by-cloning
165    /// setting back to join-by-pulling.
166    #[must_use]
167    pub fn join_by_pulling(self) -> Buffer<T> {
168        Buffer {
169            location: self.location,
170            _ignore: Default::default(),
171        }
172    }
173
174    fn new(location: BufferLocation) -> Self {
175        Self::register_clone_for_join();
176        Self {
177            location,
178            _ignore: Default::default(),
179        }
180    }
181
182    /// This function ensures that [`AnyBuffer`] can be downcast into a
183    /// [`CloneFromBuffer`] and that it can correctly transfer any Cloning join
184    /// behavior if it gets downcast to a [`FetchFromBuffer`].
185    pub fn register_clone_for_join() {
186        static REGISTER_CLONE: OnceLock<Mutex<HashMap<TypeId, ()>>> = OnceLock::new();
187        let register_clone = REGISTER_CLONE.get_or_init(|| Mutex::default());
188
189        // TODO(@mxgrey): Consider whether there is a way to avoid needing all
190        // these mutex locks and hashmap lookups every time we create a CloneFromBuffer.
191        let mut register_mut = register_clone.lock().unwrap();
192        register_mut.entry(TypeId::of::<T>()).or_insert_with(|| {
193            let interface = AnyBuffer::interface_for::<T>();
194            interface.register_cloning(
195                clone_for_any_join::<T>,
196                &(clone_for_join::<T> as FetchFromBufferFn<T>),
197            );
198            interface.register_buffer_downcast(
199                TypeId::of::<CloneFromBuffer<T>>(),
200                Box::new(|buffer: AnyBuffer| {
201                    Ok(Box::new(CloneFromBuffer::<T>::new(buffer.location)))
202                }),
203            );
204        });
205    }
206}
207
208fn clone_for_any_join<T: 'static + Send + Sync + Clone>(
209    entity_ref: &EntityRef,
210    session: Entity,
211) -> Result<AnyMessageBox, OperationError> {
212    // In general we expect pulling to imply pulling the oldest since the most
213    // typical pattern when information is being pulled from a source would be
214    // FIFO. It's very unlikely that someone pulling information would want the
215    // order that they're pulling data to be backwards. However for cloning it's
216    // much less obvious: Would someone want to clone the oldest piece of data
217    // or the newest?
218    //
219    // For now we will assume that If the user is cloning information without
220    // removing it, then they are leaving the old data in there to act as a
221    // history, and therefore when cloning they will always want the newest.
222    //
223    // TODO(@mxgrey): Allow users to set whether pull and/or clone should
224    // be operating on the oldest data or the newest, by default. Also allow
225    // the join operations themselves override this, e.g. putting a setting
226    // into BufferLocation to change which side is being pulled from.
227    entity_ref
228        .clone_from_buffer::<T>(session)
229        .map(to_any_message)
230}
231
232impl<T: Clone + Send + Sync> From<CloneFromBuffer<T>> for Buffer<T> {
233    fn from(value: CloneFromBuffer<T>) -> Self {
234        Buffer {
235            location: value.location,
236            _ignore: Default::default(),
237        }
238    }
239}
240
241/// Settings to describe the behavior of a buffer.
242#[cfg_attr(
243    feature = "diagram",
244    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
245    serde(rename_all = "snake_case")
246)]
247#[derive(Default, Clone, Copy, Debug)]
248pub struct BufferSettings {
249    retention: RetentionPolicy,
250}
251
252impl BufferSettings {
253    /// Define new buffer settings
254    pub fn new(retention: RetentionPolicy) -> Self {
255        Self { retention }
256    }
257
258    /// Create `BufferSettings` with a retention policy of [`RetentionPolicy::KeepLast`]`(n)`.
259    pub fn keep_last(n: usize) -> Self {
260        Self::new(RetentionPolicy::KeepLast(n))
261    }
262
263    /// Create `BufferSettings` with a retention policy of [`RetentionPolicy::KeepFirst`]`(n)`.
264    pub fn keep_first(n: usize) -> Self {
265        Self::new(RetentionPolicy::KeepFirst(n))
266    }
267
268    /// Create `BufferSettings` with a retention policy of [`RetentionPolicy::KeepAll`].
269    pub fn keep_all() -> Self {
270        Self::new(RetentionPolicy::KeepAll)
271    }
272
273    /// Get the retention policy for the buffer.
274    pub fn retention(&self) -> RetentionPolicy {
275        self.retention
276    }
277
278    /// Modify the retention policy for the buffer.
279    pub fn retention_mut(&mut self) -> &mut RetentionPolicy {
280        &mut self.retention
281    }
282}
283
/// Describe how data within a buffer gets retained. Most mechanisms that pull
/// data from a buffer will remove the oldest item in the buffer, so this policy
/// is for dealing with situations where items are being stored faster than they
/// are being pulled.
///
/// The default value is KeepLast(1).
// NOTE: the declaration order of the variants is significant, because the
// derived `PartialOrd`/`Ord` compare variants by that order first.
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum RetentionPolicy {
    /// Keep the last N items that were stored into the buffer. Once the limit
    /// is reached, the oldest item will be removed any time a new item arrives.
    KeepLast(usize),
    /// Keep the first N items that are stored into the buffer. Once the limit
    /// is reached, any new item that arrives will be discarded.
    KeepFirst(usize),
    /// Do not limit how many items can be stored in the buffer.
    KeepAll,
}
306
307impl Default for RetentionPolicy {
308    fn default() -> Self {
309        Self::KeepLast(1)
310    }
311}
312
313/// This key can unlock access to the contents of a buffer by passing it into
314/// [`BufferAccess`] or [`BufferAccessMut`].
315///
316/// To obtain a `BufferKey`, use [`Chain::with_access`][1], or [`listen`][2].
317///
318/// [1]: crate::Chain::with_access
319/// [2]: crate::Accessible::listen
320pub struct BufferKey<T> {
321    tag: BufferKeyTag,
322    _ignore: std::marker::PhantomData<fn(T)>,
323}
324
325impl<T> Clone for BufferKey<T> {
326    fn clone(&self) -> Self {
327        Self {
328            tag: self.tag.clone(),
329            _ignore: Default::default(),
330        }
331    }
332}
333
334impl<T> BufferKey<T> {
335    /// The buffer ID of this key.
336    pub fn buffer(&self) -> Entity {
337        self.tag.buffer
338    }
339
340    /// The session that this key belongs to.
341    pub fn session(&self) -> Entity {
342        self.tag.session
343    }
344
345    pub fn tag(&self) -> &BufferKeyTag {
346        &self.tag
347    }
348}
349
350impl<T: 'static + Send + Sync> BufferKeyLifecycle for BufferKey<T> {
351    type TargetBuffer = Buffer<T>;
352
353    fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self {
354        BufferKey {
355            tag: builder.make_tag(buffer.id()),
356            _ignore: Default::default(),
357        }
358    }
359
360    fn is_in_use(&self) -> bool {
361        self.tag.is_in_use()
362    }
363
364    fn deep_clone(&self) -> Self {
365        Self {
366            tag: self.tag.deep_clone(),
367            _ignore: Default::default(),
368        }
369    }
370}
371
372impl<T> std::fmt::Debug for BufferKey<T> {
373    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374        f.debug_struct("BufferKey")
375            .field("message_type_name", &std::any::type_name::<T>())
376            .field("tag", &self.tag)
377            .finish()
378    }
379}
380
381/// The identifying information for a buffer key. This does not indicate
382/// anything about the type of messages that the buffer can contain.
383#[derive(Clone)]
384pub struct BufferKeyTag {
385    pub buffer: Entity,
386    pub session: Entity,
387    pub accessor: Entity,
388    pub lifecycle: Option<Arc<BufferAccessLifecycle>>,
389}
390
391impl BufferKeyTag {
392    pub fn is_in_use(&self) -> bool {
393        self.lifecycle.as_ref().is_some_and(|l| l.is_in_use())
394    }
395
396    pub fn deep_clone(&self) -> Self {
397        let mut deep = self.clone();
398        deep.lifecycle = self
399            .lifecycle
400            .as_ref()
401            .map(|l| Arc::new(l.as_ref().clone()));
402        deep
403    }
404}
405
406impl std::fmt::Debug for BufferKeyTag {
407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
408        f.debug_struct("BufferKeyTag")
409            .field("buffer", &self.buffer)
410            .field("session", &self.session)
411            .field("accessor", &self.accessor)
412            .field("in_use", &self.is_in_use())
413            .finish()
414    }
415}
416
417/// This system parameter lets you get read-only access to a buffer that exists
418/// within a workflow. Use a [`BufferKey`] to unlock the access.
419///
420/// See [`BufferAccessMut`] for mutable access.
421#[derive(SystemParam)]
422pub struct BufferAccess<'w, 's, T>
423where
424    T: 'static + Send + Sync,
425{
426    query: Query<'w, 's, &'static BufferStorage<T>>,
427}
428
429impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> {
430    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
431        let session = key.session();
432        self.query
433            .get(key.buffer())
434            .map(|storage| BufferView { storage, session })
435    }
436
437    pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
438        self.get(key).ok().map(|view| view.newest()).flatten()
439    }
440}
441
442/// This system parameter lets you get mutable access to a buffer that exists
443/// within a workflow. Use a [`BufferKey`] to unlock the access.
444///
445/// See [`BufferAccess`] for read-only access.
446#[derive(SystemParam)]
447pub struct BufferAccessMut<'w, 's, T>
448where
449    T: 'static + Send + Sync,
450{
451    query: Query<'w, 's, &'static mut BufferStorage<T>>,
452    commands: Commands<'w, 's>,
453}
454
455impl<'w, 's, T> BufferAccessMut<'w, 's, T>
456where
457    T: 'static + Send + Sync,
458{
459    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
460        let session = key.session();
461        self.query
462            .get(key.buffer())
463            .map(|storage| BufferView { storage, session })
464    }
465
466    pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
467        self.get(key).ok().map(|view| view.newest()).flatten()
468    }
469
470    pub fn get_mut<'a>(
471        &'a mut self,
472        key: &BufferKey<T>,
473    ) -> Result<BufferMut<'w, 's, 'a, T>, QueryEntityError> {
474        let buffer = key.buffer();
475        let session = key.session();
476        let accessor = key.tag.accessor;
477        self.query
478            .get_mut(key.buffer())
479            .map(|storage| BufferMut::new(storage, buffer, session, accessor, &mut self.commands))
480    }
481}
482
483/// This trait allows [`World`] to give you access to any buffer using a [`BufferKey`]
484pub trait BufferWorldAccess {
485    /// Call this to get read-only access to a buffer from a [`World`].
486    ///
487    /// Alternatively you can use [`BufferAccess`] as a regular bevy system parameter,
488    /// which does not need direct world access.
489    fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
490    where
491        T: 'static + Send + Sync;
492
493    /// Call this to get read-only access to the gate of a buffer from a [`World`].
494    fn buffer_gate_view(
495        &self,
496        key: impl Into<AnyBufferKey>,
497    ) -> Result<BufferGateView<'_>, BufferError>;
498
499    /// Call this to get mutable access to a buffer.
500    ///
501    /// Pass in a callback that will receive [`BufferMut`], allowing it to view
502    /// and modify the contents of the buffer.
503    fn buffer_mut<T, U>(
504        &mut self,
505        key: &BufferKey<T>,
506        f: impl FnOnce(BufferMut<T>) -> U,
507    ) -> Result<U, BufferError>
508    where
509        T: 'static + Send + Sync;
510
511    /// Call this to get mutable access to the gate of a buffer.
512    ///
513    /// Pass in a callback that will receive [`BufferGateMut`], allowing it to
514    /// view and modify the gate of the buffer.
515    fn buffer_gate_mut<U>(
516        &mut self,
517        key: impl Into<AnyBufferKey>,
518        f: impl FnOnce(BufferGateMut) -> U,
519    ) -> Result<U, BufferError>;
520}
521
522impl BufferWorldAccess for World {
523    fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
524    where
525        T: 'static + Send + Sync,
526    {
527        let buffer_ref = self
528            .get_entity(key.tag.buffer)
529            .map_err(|_| BufferError::BufferMissing)?;
530        let storage = buffer_ref
531            .get::<BufferStorage<T>>()
532            .ok_or(BufferError::BufferMissing)?;
533        Ok(BufferView {
534            storage,
535            session: key.tag.session,
536        })
537    }
538
539    fn buffer_gate_view(
540        &self,
541        key: impl Into<AnyBufferKey>,
542    ) -> Result<BufferGateView<'_>, BufferError> {
543        let key: AnyBufferKey = key.into();
544        let buffer_ref = self
545            .get_entity(key.tag.buffer)
546            .or(Err(BufferError::BufferMissing))?;
547        let gate = buffer_ref
548            .get::<GateState>()
549            .ok_or(BufferError::BufferMissing)?;
550        Ok(BufferGateView {
551            gate,
552            session: key.tag.session,
553        })
554    }
555
556    fn buffer_mut<T, U>(
557        &mut self,
558        key: &BufferKey<T>,
559        f: impl FnOnce(BufferMut<T>) -> U,
560    ) -> Result<U, BufferError>
561    where
562        T: 'static + Send + Sync,
563    {
564        let mut state = SystemState::<BufferAccessMut<T>>::new(self);
565        let mut buffer_access_mut = state.get_mut(self);
566        let buffer_mut = buffer_access_mut
567            .get_mut(key)
568            .map_err(|_| BufferError::BufferMissing)?;
569        Ok(f(buffer_mut))
570    }
571
572    fn buffer_gate_mut<U>(
573        &mut self,
574        key: impl Into<AnyBufferKey>,
575        f: impl FnOnce(BufferGateMut) -> U,
576    ) -> Result<U, BufferError> {
577        let mut state = SystemState::<BufferGateAccessMut>::new(self);
578        let mut buffer_gate_access_mut = state.get_mut(self);
579        let buffer_mut = buffer_gate_access_mut
580            .get_mut(key)
581            .map_err(|_| BufferError::BufferMissing)?;
582        Ok(f(buffer_mut))
583    }
584}
585
586/// Access to view a buffer that exists inside a workflow.
587pub struct BufferView<'a, T>
588where
589    T: 'static + Send + Sync,
590{
591    storage: &'a BufferStorage<T>,
592    session: Entity,
593}
594
595impl<'a, T> BufferView<'a, T>
596where
597    T: 'static + Send + Sync,
598{
599    /// Iterate over the contents in the buffer
600    pub fn iter(&self) -> IterBufferView<'a, T> {
601        self.storage.iter(self.session)
602    }
603
604    /// Borrow the oldest item in the buffer.
605    pub fn oldest(&self) -> Option<&'a T> {
606        self.storage.oldest(self.session)
607    }
608
609    /// Borrow the newest item in the buffer.
610    pub fn newest(&self) -> Option<&'a T> {
611        self.storage.newest(self.session)
612    }
613
614    /// Borrow an item from the buffer. Index 0 is the oldest item in the buffer
615    /// with the highest index being the newest item in the buffer.
616    pub fn get(&self, index: usize) -> Option<&'a T> {
617        self.storage.get(self.session, index)
618    }
619
620    /// How many items are in the buffer?
621    pub fn len(&self) -> usize {
622        self.storage.count(self.session)
623    }
624
625    /// Check if the buffer is empty.
626    pub fn is_empty(&self) -> bool {
627        self.len() == 0
628    }
629}
630
631/// Access to mutate a buffer that exists inside a workflow.
632pub struct BufferMut<'w, 's, 'a, T>
633where
634    T: 'static + Send + Sync,
635{
636    storage: Mut<'a, BufferStorage<T>>,
637    buffer: Entity,
638    session: Entity,
639    accessor: Option<Entity>,
640    commands: &'a mut Commands<'w, 's>,
641    modified: bool,
642}
643
644impl<'w, 's, 'a, T> BufferMut<'w, 's, 'a, T>
645where
646    T: 'static + Send + Sync,
647{
648    /// When you make a modification using this `BufferMut`, anything listening
649    /// to the buffer will be notified about the update. This can create
650    /// unintentional infinite loops where a node in the workflow wakes itself
651    /// up every time it runs because of a modification it makes to a buffer.
652    ///
653    /// By default this closed loop is disabled by keeping track of which
654    /// listener created the key that's being used to modify the buffer, and
655    /// then skipping that listener when notifying about the modification.
656    ///
657    /// In some cases a key can be used far downstream of the listener. In that
658    /// case, there may be nodes downstream of the listener that do want to be
659    /// woken up by the modification. Use this function to allow that closed
660    /// loop to happen. It will be up to you to prevent the closed loop from
661    /// being a problem.
662    pub fn allow_closed_loops(mut self) -> Self {
663        self.accessor = None;
664        self
665    }
666
667    /// Iterate over the contents in the buffer.
668    pub fn iter(&self) -> IterBufferView<'_, T> {
669        self.storage.iter(self.session)
670    }
671
672    /// Look at the oldest item in the buffer.
673    pub fn oldest(&self) -> Option<&T> {
674        self.storage.oldest(self.session)
675    }
676
677    /// Look at the newest item in the buffer.
678    pub fn newest(&self) -> Option<&T> {
679        self.storage.newest(self.session)
680    }
681
682    /// Borrow an item from the buffer. Index 0 is the oldest item in the buffer
683    /// with the highest index being the newest item in the buffer.
684    pub fn get(&self, index: usize) -> Option<&T> {
685        self.storage.get(self.session, index)
686    }
687
688    /// How many items are in the buffer?
689    pub fn len(&self) -> usize {
690        self.storage.count(self.session)
691    }
692
693    /// Check if the buffer is empty.
694    pub fn is_empty(&self) -> bool {
695        self.len() == 0
696    }
697
698    /// Iterate over mutable borrows of the contents in the buffer.
699    pub fn iter_mut(&mut self) -> IterBufferMut<'_, T> {
700        self.modified = true;
701        self.storage.iter_mut(self.session)
702    }
703
704    /// Modify the oldest item in the buffer.
705    pub fn oldest_mut(&mut self) -> Option<&mut T> {
706        self.modified = true;
707        self.storage.oldest_mut(self.session)
708    }
709
710    /// Modify the newest item in the buffer.
711    pub fn newest_mut(&mut self) -> Option<&mut T> {
712        self.modified = true;
713        self.storage.newest_mut(self.session)
714    }
715
716    /// Modify the newest item in the buffer or create a default-initialized
717    /// item to modify if the buffer was empty.
718    ///
719    /// This may fail to provide a mutable borrow if the buffer was already
720    /// expired or if the buffer capacity was zero.
721    pub fn newest_mut_or_default(&mut self) -> Option<&mut T>
722    where
723        T: Default,
724    {
725        self.newest_mut_or_else(|| T::default())
726    }
727
728    /// Modify the newest item in the buffer or initialize an item if the
729    /// buffer was empty.
730    ///
731    /// This may fail to provide a mutable borrow if the buffer was already
732    /// expired or if the buffer capacity was zero.
733    pub fn newest_mut_or_else(&mut self, f: impl FnOnce() -> T) -> Option<&mut T> {
734        self.modified = true;
735        self.storage.newest_mut_or_else(self.session, f)
736    }
737
738    /// Modify an item in the buffer. Index 0 is the oldest item in the buffer
739    /// with the highest index being the newest item in the buffer.
740    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
741        self.modified = true;
742        self.storage.get_mut(self.session, index)
743    }
744
745    /// Drain items out of the buffer
746    pub fn drain<R>(&mut self, range: R) -> DrainBuffer<'_, T>
747    where
748        R: RangeBounds<usize>,
749    {
750        self.modified = true;
751        self.storage.drain(self.session, range)
752    }
753
754    /// Pull the oldest item from the buffer.
755    pub fn pull(&mut self) -> Option<T> {
756        self.modified = true;
757        self.storage.pull(self.session)
758    }
759
760    /// Pull the item that was most recently put into the buffer (instead of
761    /// the oldest, which is what [`Self::pull`] gives).
762    pub fn pull_newest(&mut self) -> Option<T> {
763        self.modified = true;
764        self.storage.pull_newest(self.session)
765    }
766
767    /// Push a new value into the buffer. If the buffer is at its limit, this
768    /// will return the value that needed to be removed.
769    pub fn push(&mut self, value: T) -> Option<T> {
770        self.modified = true;
771        self.storage.push(self.session, value)
772    }
773
774    /// Push a value into the buffer as if it is the oldest value of the buffer.
775    /// If the buffer is at its limit, this will return the value that needed to
776    /// be removed.
777    pub fn push_as_oldest(&mut self, value: T) -> Option<T> {
778        self.modified = true;
779        self.storage.push_as_oldest(self.session, value)
780    }
781
782    /// Trigger the listeners for this buffer to wake up even if nothing in the
783    /// buffer has changed. This could be used for timers or timeout elements
784    /// in a workflow.
785    pub fn pulse(&mut self) {
786        self.modified = true;
787    }
788
789    fn new(
790        storage: Mut<'a, BufferStorage<T>>,
791        buffer: Entity,
792        session: Entity,
793        accessor: Entity,
794        commands: &'a mut Commands<'w, 's>,
795    ) -> Self {
796        Self {
797            storage,
798            buffer,
799            session,
800            accessor: Some(accessor),
801            commands,
802            modified: false,
803        }
804    }
805}
806
807impl<'w, 's, 'a, T> Drop for BufferMut<'w, 's, 'a, T>
808where
809    T: 'static + Send + Sync,
810{
811    fn drop(&mut self) {
812        if self.modified {
813            self.commands.queue(NotifyBufferUpdate::new(
814                self.buffer,
815                self.session,
816                self.accessor,
817            ));
818        }
819    }
820}
821
822#[derive(ThisError, Debug, Clone)]
823pub enum BufferError {
824    #[error("The key was unable to identify a buffer")]
825    BufferMissing,
826}
827
828#[cfg(test)]
829mod tests {
830    use crate::{prelude::*, testing::*, AddBufferToMap, Gate};
831    use std::future::Future;
832
833    #[test]
834    fn test_buffer_key_access() {
835        let mut context = TestingContext::minimal_plugins();
836
837        let add_buffers_by_pull_cb = add_buffers_by_pull.into_blocking_callback();
838        let add_from_buffer_cb = add_from_buffer.into_blocking_callback();
839        let multiply_buffers_by_copy_cb = multiply_buffers_by_copy.into_blocking_callback();
840
841        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
842            scope
843                .input
844                .chain(builder)
845                .unzip()
846                .listen(builder)
847                .then(multiply_buffers_by_copy_cb)
848                .connect(scope.terminate);
849        });
850
851        let mut promise =
852            context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());
853
854        context.run_with_conditions(&mut promise, Duration::from_secs(2));
855        assert!(promise.take().available().is_some_and(|value| value == 6.0));
856        assert!(context.no_unhandled_errors());
857
858        let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
859            scope
860                .input
861                .chain(builder)
862                .unzip()
863                .listen(builder)
864                .then(add_buffers_by_pull_cb)
865                .dispose_on_none()
866                .connect(scope.terminate);
867        });
868
869        let mut promise =
870            context.command(|commands| commands.request((4.0, 5.0), workflow).take_response());
871
872        context.run_with_conditions(&mut promise, Duration::from_secs(2));
873        assert!(promise.take().available().is_some_and(|value| value == 9.0));
874        assert!(context.no_unhandled_errors());
875
876        let workflow =
877            context.spawn_io_workflow(|scope: Scope<(f64, f64), Result<f64, f64>>, builder| {
878                let (branch_to_adder, branch_to_buffer) = scope.input.chain(builder).unzip();
879                let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
880                builder.connect(branch_to_buffer, buffer.input_slot());
881
882                let adder_node = branch_to_adder
883                    .chain(builder)
884                    .with_access(buffer)
885                    .then_node(add_from_buffer_cb.clone());
886
887                adder_node.output.chain(builder).fork_result(
888                    // If the buffer had an item in it, we send it to another
889                    // node that tries to pull a second time (we expect the
890                    // buffer to be empty this second time) and then
891                    // terminates.
892                    |chain| {
893                        chain
894                            .with_access(buffer)
895                            .then(add_from_buffer_cb.clone())
896                            .connect(scope.terminate)
897                    },
898                    // If the buffer was empty, keep looping back until there
899                    // is a value available.
900                    |chain| chain.with_access(buffer).connect(adder_node.input),
901                );
902            });
903
904        let mut promise =
905            context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());
906
907        context.run_with_conditions(&mut promise, Duration::from_secs(2));
908        assert!(promise
909            .take()
910            .available()
911            .is_some_and(|value| value.is_err_and(|n| n == 5.0)));
912        assert!(context.no_unhandled_errors());
913
914        // Same as previous test, but using Builder::create_buffer_access instead
915        let workflow = context.spawn_io_workflow(|scope, builder| {
916            let (branch_to_adder, branch_to_buffer) = scope.input.chain(builder).unzip();
917            let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
918            builder.connect(branch_to_buffer, buffer.input_slot());
919
920            let access = builder.create_buffer_access(buffer);
921            builder.connect(branch_to_adder, access.input);
922            access
923                .output
924                .chain(builder)
925                .then(add_from_buffer_cb.clone())
926                .fork_result(
927                    |ok| {
928                        let (output, builder) = ok.unpack();
929                        let second_access = builder.create_buffer_access(buffer);
930                        builder.connect(output, second_access.input);
931                        second_access
932                            .output
933                            .chain(builder)
934                            .then(add_from_buffer_cb.clone())
935                            .connect(scope.terminate);
936                    },
937                    |err| err.connect(access.input),
938                );
939        });
940
941        let mut promise =
942            context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());
943
944        context.run_with_conditions(&mut promise, Duration::from_secs(2));
945        assert!(promise
946            .take()
947            .available()
948            .is_some_and(|value| value.is_err_and(|n| n == 5.0)));
949        assert!(context.no_unhandled_errors());
950    }
951
952    fn add_from_buffer(
953        In((lhs, key)): In<(f64, BufferKey<f64>)>,
954        mut access: BufferAccessMut<f64>,
955    ) -> Result<f64, f64> {
956        let rhs = access.get_mut(&key).map_err(|_| lhs)?.pull().ok_or(lhs)?;
957        Ok(lhs + rhs)
958    }
959
960    fn multiply_buffers_by_copy(
961        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
962        access: BufferAccess<f64>,
963    ) -> f64 {
964        *access.get(&key_a).unwrap().oldest().unwrap()
965            * *access.get(&key_b).unwrap().oldest().unwrap()
966    }
967
968    fn add_buffers_by_pull(
969        In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
970        mut access: BufferAccessMut<f64>,
971    ) -> Option<f64> {
972        if access.get(&key_a).unwrap().is_empty() {
973            return None;
974        }
975
976        if access.get(&key_b).unwrap().is_empty() {
977            return None;
978        }
979
980        let rhs = access.get_mut(&key_a).unwrap().pull().unwrap();
981        let lhs = access.get_mut(&key_b).unwrap().pull().unwrap();
982        Some(rhs + lhs)
983    }
984
985    #[test]
986    fn test_buffer_key_lifecycle() {
987        let mut context = TestingContext::minimal_plugins();
988
989        // Test a workflow where each node in a long chain repeatedly accesses
990        // a buffer and might be the one to push a value into it.
991        let workflow = context.spawn_io_workflow(|scope, builder| {
992            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());
993
994            // The only path to termination is from listening to the buffer.
995            builder
996                .listen(buffer)
997                .then(pull_register_from_buffer.into_blocking_callback())
998                .dispose_on_none()
999                .connect(scope.terminate);
1000
1001            let decrement_register_cb = decrement_register.into_blocking_callback();
1002            let async_decrement_register_cb = async_decrement_register.as_callback();
1003            scope
1004                .input
1005                .chain(builder)
1006                .with_access(buffer)
1007                .then(decrement_register_cb.clone())
1008                .with_access(buffer)
1009                .then(async_decrement_register_cb.clone())
1010                .dispose_on_none()
1011                .with_access(buffer)
1012                .then(decrement_register_cb.clone())
1013                .with_access(buffer)
1014                .then(async_decrement_register_cb)
1015                .unused();
1016        });
1017
1018        run_register_test(workflow, 0, true, &mut context);
1019        run_register_test(workflow, 1, true, &mut context);
1020        run_register_test(workflow, 2, true, &mut context);
1021        run_register_test(workflow, 3, true, &mut context);
1022        run_register_test(workflow, 4, false, &mut context);
1023        run_register_test(workflow, 5, false, &mut context);
1024        run_register_test(workflow, 6, false, &mut context);
1025
1026        // Test a workflow where only one buffer accessor node is used, but the
1027        // key is passed through a long chain in the workflow, with a disposal
1028        // being forced as well.
1029        let workflow = context.spawn_io_workflow(|scope, builder| {
1030            let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());
1031
1032            // The only path to termination is from listening to the buffer.
1033            builder
1034                .listen(buffer)
1035                .then(pull_register_from_buffer.into_blocking_callback())
1036                .dispose_on_none()
1037                .connect(scope.terminate);
1038
1039            let decrement_register_and_pass_keys_cb =
1040                decrement_register_and_pass_keys.into_blocking_callback();
1041            let async_decrement_register_and_pass_keys_cb =
1042                async_decrement_register_and_pass_keys.as_callback();
1043            let (loose_end, dead_end): (_, Output<Option<Register>>) = scope
1044                .input
1045                .chain(builder)
1046                .with_access(buffer)
1047                .then(decrement_register_and_pass_keys_cb.clone())
1048                .then(async_decrement_register_and_pass_keys_cb.clone())
1049                .dispose_on_none()
1050                .map_block(|v| (v, None))
1051                .unzip();
1052
1053            // Force the workflow to trigger a disposal while the key is still in flight
1054            dead_end.chain(builder).dispose_on_none().unused();
1055
1056            loose_end
1057                .chain(builder)
1058                .then(async_decrement_register_and_pass_keys_cb)
1059                .dispose_on_none()
1060                .then(decrement_register_and_pass_keys_cb)
1061                .unused();
1062        });
1063
1064        run_register_test(workflow, 0, true, &mut context);
1065        run_register_test(workflow, 1, true, &mut context);
1066        run_register_test(workflow, 2, true, &mut context);
1067        run_register_test(workflow, 3, true, &mut context);
1068        run_register_test(workflow, 4, false, &mut context);
1069        run_register_test(workflow, 5, false, &mut context);
1070        run_register_test(workflow, 6, false, &mut context);
1071    }
1072
1073    fn run_register_test(
1074        workflow: Service<Register, Register>,
1075        initial_value: u64,
1076        expect_success: bool,
1077        context: &mut TestingContext,
1078    ) {
1079        let mut promise = context.command(|commands| {
1080            commands
1081                .request(Register::new(initial_value), workflow)
1082                .take_response()
1083        });
1084
1085        context.run_while_pending(&mut promise);
1086        if expect_success {
1087            assert!(promise
1088                .take()
1089                .available()
1090                .is_some_and(|r| r.finished_with(initial_value)));
1091        } else {
1092            assert!(promise.take().is_cancelled());
1093        }
1094        assert!(context.no_unhandled_errors());
1095    }
1096
1097    // We use this struct to keep track of operations that have occurred in the
1098    // test workflow. Values from in_slot get moved to out_slot until in_slot
1099    // reaches 0, then the whole struct gets put into a buffer where a listener
1100    // will then send it to the terminal node.
1101    #[derive(Clone, Copy, Debug)]
1102    struct Register {
1103        in_slot: u64,
1104        out_slot: u64,
1105    }
1106
1107    impl Register {
1108        fn new(start_from: u64) -> Self {
1109            Self {
1110                in_slot: start_from,
1111                out_slot: 0,
1112            }
1113        }
1114
1115        fn finished_with(&self, out_slot: u64) -> bool {
1116            self.in_slot == 0 && self.out_slot == out_slot
1117        }
1118    }
1119
1120    fn pull_register_from_buffer(
1121        In(key): In<BufferKey<Register>>,
1122        mut access: BufferAccessMut<Register>,
1123    ) -> Option<Register> {
1124        access.get_mut(&key).ok()?.pull()
1125    }
1126
1127    fn decrement_register(
1128        In((mut register, key)): In<(Register, BufferKey<Register>)>,
1129        mut access: BufferAccessMut<Register>,
1130    ) -> Register {
1131        if register.in_slot == 0 {
1132            access.get_mut(&key).unwrap().push(register);
1133            return register;
1134        }
1135
1136        register.in_slot -= 1;
1137        register.out_slot += 1;
1138        register
1139    }
1140
1141    fn decrement_register_and_pass_keys(
1142        In((mut register, key)): In<(Register, BufferKey<Register>)>,
1143        mut access: BufferAccessMut<Register>,
1144    ) -> (Register, BufferKey<Register>) {
1145        if register.in_slot == 0 {
1146            access.get_mut(&key).unwrap().push(register);
1147            return (register, key);
1148        }
1149
1150        register.in_slot -= 1;
1151        register.out_slot += 1;
1152        (register, key)
1153    }
1154
1155    fn async_decrement_register(
1156        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
1157    ) -> impl Future<Output = Option<Register>> {
1158        async move {
1159            input
1160                .channel
1161                .query(input.request, decrement_register.into_blocking_callback())
1162                .await
1163                .available()
1164        }
1165    }
1166
1167    fn async_decrement_register_and_pass_keys(
1168        In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
1169    ) -> impl Future<Output = Option<(Register, BufferKey<Register>)>> {
1170        async move {
1171            input
1172                .channel
1173                .query(
1174                    input.request,
1175                    decrement_register_and_pass_keys.into_blocking_callback(),
1176                )
1177                .await
1178                .available()
1179        }
1180    }
1181
1182    #[test]
1183    fn test_buffer_key_gate_control() {
1184        let mut context = TestingContext::minimal_plugins();
1185
1186        let workflow = context.spawn_io_workflow(|scope, builder| {
1187            let service = builder.commands().spawn_service(gate_access_test_open_loop);
1188
1189            let buffer = builder.create_buffer(BufferSettings::keep_all());
1190            builder.connect(scope.input, buffer.input_slot());
1191            builder
1192                .listen(buffer)
1193                .then_gate_close(buffer)
1194                .then(service)
1195                .fork_unzip((
1196                    |chain: Chain<_>| chain.dispose_on_none().connect(buffer.input_slot()),
1197                    |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
1198                ));
1199        });
1200
1201        let mut promise = context.command(|commands| commands.request(0, workflow).take_response());
1202
1203        context.run_with_conditions(&mut promise, Duration::from_secs(2));
1204        assert!(promise.take().available().is_some_and(|v| v == 5));
1205        assert!(context.no_unhandled_errors());
1206    }
1207
    /// Used to verify that when a key is used to open a buffer gate, it will not
    /// trigger the key's listener to wake up again.
    ///
    /// Pulls one value from the keyed buffer, checks that the gate was closed
    /// upstream, and reopens it through the same key. Returns
    /// `(Some(value + 1), None)` while `value < 5`, so the value can be fed
    /// back into the buffer, and `(None, Some(value))` once `value >= 5`, so
    /// the workflow can terminate.
    ///
    /// # Panics
    ///
    /// Panics if the buffer cannot be accessed or is empty when this service
    /// runs, or if the gate was not closed beforehand.
    fn gate_access_test_open_loop(
        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
        mut access: BufferAccessMut<u64>,
        mut gate_access: BufferGateAccessMut,
    ) -> (Option<u64>, Option<u64>) {
        // We should never see a spurious wake-up in this service because the
        // gate opening is done by the key of this service. A spurious wake-up
        // would presumably find the buffer empty and panic on this `unwrap`.
        let mut buffer = access.get_mut(&key).unwrap();
        let value = buffer.pull().unwrap();

        // The gate should have previously been closed before reaching this
        // service (the test workflow routes through `then_gate_close`).
        let mut gate = gate_access.get_mut(key).unwrap();
        assert_eq!(gate.get(), Gate::Closed);
        // Open the gate, which would normally trigger a notice, but the notice
        // should not come to this service because we're using the key without
        // closed loops allowed.
        gate.open_gate();

        // Keep looping until the value reaches 5, then hand it to the
        // terminating branch.
        if value >= 5 {
            (None, Some(value))
        } else {
            (Some(value + 1), None)
        }
    }
1235
1236    #[test]
1237    fn test_closed_loop_key_access() {
1238        let mut context = TestingContext::minimal_plugins();
1239
1240        let delay = context.spawn_delay(Duration::from_secs_f32(0.1));
1241
1242        let workflow = context.spawn_io_workflow(|scope, builder| {
1243            let service = builder
1244                .commands()
1245                .spawn_service(gate_access_test_closed_loop);
1246
1247            let buffer = builder.create_buffer(BufferSettings::keep_all());
1248            builder.connect(scope.input, buffer.input_slot());
1249            builder.listen(buffer).then(service).fork_unzip((
1250                |chain: Chain<_>| {
1251                    chain
1252                        .dispose_on_none()
1253                        .then(delay)
1254                        .connect(buffer.input_slot())
1255                },
1256                |chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
1257            ));
1258        });
1259
1260        let mut promise = context.command(|commands| commands.request(3, workflow).take_response());
1261
1262        context.run_with_conditions(&mut promise, Duration::from_secs(2));
1263        assert!(promise.take().available().is_some_and(|v| v == 0));
1264        assert!(context.no_unhandled_errors());
1265    }
1266
1267    /// Used to verify that we get spurious wakeups when closed loops are allowed
1268    fn gate_access_test_closed_loop(
1269        In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
1270        mut access: BufferAccessMut<u64>,
1271    ) -> (Option<u64>, Option<u64>) {
1272        let mut buffer = access.get_mut(&key).unwrap().allow_closed_loops();
1273        if let Some(value) = buffer.pull() {
1274            (Some(value + 1), None)
1275        } else {
1276            (None, Some(0))
1277        }
1278    }
1279
1280    #[test]
1281    fn test_any_buffer_join_by_clone() {
1282        let mut context = TestingContext::minimal_plugins();
1283
1284        let workflow = context.spawn_io_workflow(|scope, builder| {
1285            let message_buffer = builder.create_buffer(Default::default()).join_by_cloning();
1286            let count_buffer = builder.create_buffer(Default::default());
1287            let (message, count) = builder.chain(scope.input).unzip();
1288            builder.connect(message, message_buffer.input_slot());
1289            builder.connect(count, count_buffer.input_slot());
1290
1291            // Make absolutely sure that the type information has been erased
1292            // before we assemble the buffer map.
1293            let any_message_buffer = message_buffer.as_any_buffer();
1294            let any_count_buffer = count_buffer.as_any_buffer();
1295
1296            let mut buffer_map = BufferMap::default();
1297            buffer_map.insert_buffer("message", any_message_buffer);
1298            buffer_map.insert_buffer("count", any_count_buffer);
1299
1300            builder
1301                .try_join::<JoinByCloneTest>(&buffer_map)
1302                .unwrap()
1303                .map_block(|joined| {
1304                    if joined.count < 10 {
1305                        // Increment the count buffer
1306                        Err(joined.count + 1)
1307                    } else {
1308                        Ok(joined)
1309                    }
1310                })
1311                .fork_result(
1312                    |ok| ok.connect(scope.terminate),
1313                    |err| err.connect(count_buffer.input_slot()),
1314                );
1315        });
1316
1317        let mut promise = context.command(|commands| {
1318            commands
1319                .request((String::from("hello"), 0), workflow)
1320                .take_response()
1321        });
1322
1323        context.run_with_conditions(&mut promise, Duration::from_secs(2));
1324        let r = promise.take().available().unwrap();
1325        assert_eq!(r.count, 10);
1326        assert_eq!(r.message, "hello");
1327    }
1328
    // Target type for `try_join` in `test_any_buffer_join_by_clone`. Its field
    // names line up with the keys inserted into that test's `BufferMap`
    // ("count" and "message"); presumably `Joined` matches buffers by field
    // name — confirm against the derive's documentation.
    #[derive(Joined)]
    struct JoinByCloneTest {
        // Value joined from the buffer registered under "count".
        count: i64,
        // Value joined from the buffer registered under "message"; that buffer
        // was created with `join_by_cloning`, which the test relies on to see
        // the same message after repeated joins.
        message: String,
    }
1334}