use bevy_ecs::{
change_detection::Mut,
prelude::{Commands, Entity, EntityRef, Query, World},
query::QueryEntityError,
system::{SystemParam, SystemState},
};
use std::{
any::TypeId,
collections::HashMap,
ops::RangeBounds,
sync::{Arc, Mutex, OnceLock},
};
use thiserror::Error as ThisError;
use crate::{GateState, InputSlot, NotifyBufferUpdate, OperationError};
mod any_buffer;
pub use any_buffer::*;
mod buffer_access_lifecycle;
pub use buffer_access_lifecycle::BufferKeyLifecycle;
pub(crate) use buffer_access_lifecycle::*;
mod buffer_key_builder;
pub use buffer_key_builder::*;
mod buffer_gate;
pub use buffer_gate::*;
mod buffer_map;
pub use buffer_map::*;
mod buffer_storage;
pub(crate) use buffer_storage::*;
mod buffering;
pub use buffering::*;
mod bufferable;
pub use bufferable::*;
mod manage_buffer;
pub use manage_buffer::*;
#[cfg(feature = "diagram")]
mod json_buffer;
#[cfg(feature = "diagram")]
pub use json_buffer::*;
mod fetch_from_buffer;
pub use fetch_from_buffer::*;
/// A typed handle that identifies a buffer by its [`BufferLocation`].
///
/// The handle stores no messages itself: `T` only appears inside a
/// `PhantomData` marker, so the handle stays small and copyable for any `T`.
pub struct Buffer<T> {
/// The scope entity and source entity that identify the buffer.
pub(crate) location: BufferLocation,
/// Zero-sized marker that ties this handle to the message type `T`.
pub(crate) _ignore: std::marker::PhantomData<fn(T)>,
}
impl<T: 'static + Send + Sync> Buffer<T> {
pub fn join_by_cloning(self) -> CloneFromBuffer<T>
where
T: Clone,
{
CloneFromBuffer::new(self.location)
}
pub fn input_slot(self) -> InputSlot<T> {
InputSlot::new(self.scope(), self.id())
}
pub fn id(&self) -> Entity {
self.location.source
}
pub fn scope(&self) -> Entity {
self.location.scope
}
pub fn location(&self) -> BufferLocation {
self.location
}
}
// `Clone` and `Copy` are implemented by hand instead of derived: a derive
// would add `T: Clone` / `T: Copy` bounds, but the handle never stores a `T`.
impl<T> Clone for Buffer<T> {
fn clone(&self) -> Self {
// The type is `Copy`, so cloning is a plain bitwise copy.
*self
}
}
impl<T> Copy for Buffer<T> {}
/// The pair of entities that identifies a buffer.
#[derive(Clone, Copy, Debug)]
pub struct BufferLocation {
/// The scope entity associated with the buffer.
// NOTE(review): presumably the workflow scope the buffer was created in — confirm.
pub scope: Entity,
/// The entity of the buffer itself; `Buffer::id` returns this value.
pub source: Entity,
}
/// A buffer handle that carries the same [`BufferLocation`] as a [`Buffer`]
/// but requires `T: Clone`.
///
/// Creating one through `Buffer::join_by_cloning` registers clone-based join
/// support for `T` (see `register_clone_for_join`). It can be converted back
/// into a plain [`Buffer`] with `join_by_pulling` or `From`.
// NOTE(review): presumably a join through this handle clones the buffered
// value instead of removing it — confirm against `clone_for_join`.
#[derive(Clone)]
pub struct CloneFromBuffer<T: Clone + Send + Sync + 'static> {
/// The scope entity and source entity that identify the buffer.
location: BufferLocation,
/// Zero-sized marker that ties this handle to the message type `T`.
_ignore: std::marker::PhantomData<fn(T)>,
}
// Written by hand so that `Copy` does not require `T: Copy`.
impl<T: Clone + Send + Sync + 'static> Copy for CloneFromBuffer<T> {}
impl<T: Clone + Send + Sync + 'static> CloneFromBuffer<T> {
pub fn input_slot(self) -> InputSlot<T> {
InputSlot::new(self.scope(), self.id())
}
pub fn id(&self) -> Entity {
self.location.source
}
pub fn scope(&self) -> Entity {
self.location.scope
}
pub fn location(&self) -> BufferLocation {
self.location
}
#[must_use]
pub fn join_by_pulling(self) -> Buffer<T> {
Buffer {
location: self.location,
_ignore: Default::default(),
}
}
fn new(location: BufferLocation) -> Self {
Self::register_clone_for_join();
Self {
location,
_ignore: Default::default(),
}
}
pub fn register_clone_for_join() {
static REGISTER_CLONE: OnceLock<Mutex<HashMap<TypeId, ()>>> = OnceLock::new();
let register_clone = REGISTER_CLONE.get_or_init(|| Mutex::default());
let mut register_mut = register_clone.lock().unwrap();
register_mut.entry(TypeId::of::<T>()).or_insert_with(|| {
let interface = AnyBuffer::interface_for::<T>();
interface.register_cloning(
clone_for_any_join::<T>,
&(clone_for_join::<T> as FetchFromBufferFn<T>),
);
interface.register_buffer_downcast(
TypeId::of::<CloneFromBuffer<T>>(),
Box::new(|buffer: AnyBuffer| {
Ok(Box::new(CloneFromBuffer::<T>::new(buffer.location)))
}),
);
});
}
}
fn clone_for_any_join<T: 'static + Send + Sync + Clone>(
entity_ref: &EntityRef,
session: Entity,
) -> Result<AnyMessageBox, OperationError> {
entity_ref
.clone_from_buffer::<T>(session)
.map(to_any_message)
}
impl<T: Clone + Send + Sync> From<CloneFromBuffer<T>> for Buffer<T> {
fn from(value: CloneFromBuffer<T>) -> Self {
Buffer {
location: value.location,
_ignore: Default::default(),
}
}
}
/// Settings for a buffer; currently just its [`RetentionPolicy`].
///
/// The derived `Default` uses `RetentionPolicy::default()`.
#[cfg_attr(
feature = "diagram",
derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
serde(rename_all = "snake_case")
)]
#[derive(Default, Clone, Copy, Debug)]
pub struct BufferSettings {
/// The retention policy selected for the buffer.
retention: RetentionPolicy,
}
impl BufferSettings {
pub fn new(retention: RetentionPolicy) -> Self {
Self { retention }
}
pub fn keep_last(n: usize) -> Self {
Self::new(RetentionPolicy::KeepLast(n))
}
pub fn keep_first(n: usize) -> Self {
Self::new(RetentionPolicy::KeepFirst(n))
}
pub fn keep_all() -> Self {
Self::new(RetentionPolicy::KeepAll)
}
pub fn retention(&self) -> RetentionPolicy {
self.retention
}
pub fn retention_mut(&mut self) -> &mut RetentionPolicy {
&mut self.retention
}
}
/// How many items a buffer retains.
// NOTE(review): the variant docs below are inferred from the variant names;
// the enforcement lives in the buffer storage code — confirm there.
#[cfg_attr(
    feature = "diagram",
    derive(serde::Serialize, serde::Deserialize, schemars::JsonSchema),
    serde(rename_all = "snake_case")
)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub enum RetentionPolicy {
    /// Presumably keeps only the `n` most recent items.
    KeepLast(usize),
    /// Presumably keeps only the first `n` items.
    KeepFirst(usize),
    /// Presumably keeps every item.
    KeepAll,
}

/// The default policy is `KeepLast(1)`.
// This impl cannot be derived with `#[default]`, which only applies to unit variants.
impl Default for RetentionPolicy {
    fn default() -> Self {
        RetentionPolicy::KeepLast(1)
    }
}
/// A typed key used to access a buffer through [`BufferAccess`],
/// [`BufferAccessMut`] or [`BufferWorldAccess`].
///
/// All identifying data lives in the untyped [`BufferKeyTag`]; `T` is only a
/// marker for the message type.
pub struct BufferKey<T> {
/// The buffer, session and accessor entities plus the optional lifecycle tracker.
tag: BufferKeyTag,
/// Zero-sized marker that ties this key to the message type `T`.
_ignore: std::marker::PhantomData<fn(T)>,
}
impl<T> Clone for BufferKey<T> {
fn clone(&self) -> Self {
Self {
tag: self.tag.clone(),
_ignore: Default::default(),
}
}
}
impl<T> BufferKey<T> {
pub fn buffer(&self) -> Entity {
self.tag.buffer
}
pub fn session(&self) -> Entity {
self.tag.session
}
pub fn tag(&self) -> &BufferKeyTag {
&self.tag
}
}
impl<T: 'static + Send + Sync> BufferKeyLifecycle for BufferKey<T> {
    type TargetBuffer = Buffer<T>;

    /// Creates a key for `buffer`, asking `builder` for a tag bound to the
    /// buffer's entity.
    fn create_key(buffer: &Self::TargetBuffer, builder: &BufferKeyBuilder) -> Self {
        let tag = builder.make_tag(buffer.id());
        Self {
            tag,
            _ignore: std::marker::PhantomData,
        }
    }

    /// Delegates to [`BufferKeyTag::is_in_use`].
    fn is_in_use(&self) -> bool {
        let Self { tag, .. } = self;
        tag.is_in_use()
    }

    /// Clones the key with a deep-cloned tag, so the new key gets its own
    /// lifecycle tracker instead of sharing the `Arc`.
    fn deep_clone(&self) -> Self {
        let tag = self.tag.deep_clone();
        Self {
            tag,
            _ignore: std::marker::PhantomData,
        }
    }
}
/// Debug output shows the message type name and the tag; written by hand so
/// that `T: Debug` is not required.
impl<T> std::fmt::Debug for BufferKey<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message_type_name = std::any::type_name::<T>();
        let mut output = f.debug_struct("BufferKey");
        output.field("message_type_name", &message_type_name);
        output.field("tag", &self.tag);
        output.finish()
    }
}
/// The untyped identifying data carried by every buffer key.
///
/// The derived `Clone` shares `lifecycle` through its `Arc`; use
/// `deep_clone` to get an independent lifecycle tracker.
#[derive(Clone)]
pub struct BufferKeyTag {
/// The entity of the buffer the key refers to.
pub buffer: Entity,
/// The session the key is bound to.
pub session: Entity,
/// The accessor entity; it is passed along with update notifications
/// issued through `BufferMut`.
pub accessor: Entity,
/// Optional shared lifecycle tracker; when `None`, `is_in_use` reports `false`.
pub lifecycle: Option<Arc<BufferAccessLifecycle>>,
}
impl BufferKeyTag {
    /// Whether the lifecycle tracker reports the key as in use.
    /// Always `false` when the tag has no lifecycle tracker.
    pub fn is_in_use(&self) -> bool {
        match &self.lifecycle {
            Some(lifecycle) => lifecycle.is_in_use(),
            None => false,
        }
    }

    /// Clones the tag and gives the clone its own lifecycle tracker: a clone
    /// of the tracker's contents in a fresh `Arc`, not another reference to
    /// the shared one.
    pub fn deep_clone(&self) -> Self {
        let shallow = self.clone();
        let lifecycle = shallow
            .lifecycle
            .as_deref()
            .map(|shared| Arc::new(shared.clone()));
        Self {
            lifecycle,
            ..shallow
        }
    }
}
/// Debug output lists the three entities and reports the lifecycle only as
/// an `in_use` flag.
impl std::fmt::Debug for BufferKeyTag {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut output = f.debug_struct("BufferKeyTag");
        output.field("buffer", &self.buffer);
        output.field("session", &self.session);
        output.field("accessor", &self.accessor);
        output.field("in_use", &self.is_in_use());
        output.finish()
    }
}
/// System parameter that gives read-only access to buffers holding `T`.
///
/// Use a [`BufferKey`] with `get` or `get_newest` to look at one buffer.
#[derive(SystemParam)]
pub struct BufferAccess<'w, 's, T>
where
T: 'static + Send + Sync,
{
/// Read-only query over the storage component of every `T` buffer.
query: Query<'w, 's, &'static BufferStorage<T>>,
}
impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> {
    /// Returns a read-only view of the buffer identified by `key`, scoped to
    /// the key's session.
    ///
    /// # Errors
    ///
    /// Returns the [`QueryEntityError`] produced by the query when the key's
    /// buffer entity cannot be fetched.
    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
        let session = key.session();
        self.query
            .get(key.buffer())
            .map(|storage| BufferView { storage, session })
    }

    /// Returns the newest item in the buffer for the key's session, or `None`
    /// when the buffer cannot be fetched or holds nothing for that session.
    pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
        // `?` replaces the former `.map(..).flatten()` chain.
        self.get(key).ok()?.newest()
    }
}
/// System parameter that gives mutable access to buffers holding `T`.
///
/// Use a [`BufferKey`] with `get_mut` to obtain a [`BufferMut`]; read-only
/// access is available through `get` and `get_newest`.
#[derive(SystemParam)]
pub struct BufferAccessMut<'w, 's, T>
where
T: 'static + Send + Sync,
{
/// Mutable query over the storage component of every `T` buffer.
query: Query<'w, 's, &'static mut BufferStorage<T>>,
/// Command queue handed to each `BufferMut`, which uses it to queue a
/// `NotifyBufferUpdate` when it is dropped after being modified.
commands: Commands<'w, 's>,
}
impl<'w, 's, T> BufferAccessMut<'w, 's, T>
where
    T: 'static + Send + Sync,
{
    /// Returns a read-only view of the buffer identified by `key`, scoped to
    /// the key's session.
    ///
    /// # Errors
    ///
    /// Returns the [`QueryEntityError`] produced by the query when the key's
    /// buffer entity cannot be fetched.
    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
        let session = key.session();
        self.query
            .get(key.buffer())
            .map(|storage| BufferView { storage, session })
    }

    /// Returns the newest item in the buffer for the key's session, or `None`
    /// when the buffer cannot be fetched or holds nothing for that session.
    pub fn get_newest<'a>(&'a self, key: &BufferKey<T>) -> Option<&'a T> {
        // `?` replaces the former `.map(..).flatten()` chain.
        self.get(key).ok()?.newest()
    }

    /// Returns mutable access to the buffer identified by `key`.
    ///
    /// The returned [`BufferMut`] borrows this parameter's command queue so
    /// it can queue an update notification when dropped after a modification.
    ///
    /// # Errors
    ///
    /// Returns the [`QueryEntityError`] produced by the query when the key's
    /// buffer entity cannot be fetched.
    pub fn get_mut<'a>(
        &'a mut self,
        key: &BufferKey<T>,
    ) -> Result<BufferMut<'w, 's, 'a, T>, QueryEntityError> {
        let buffer = key.buffer();
        let session = key.session();
        let accessor = key.tag.accessor;
        // Borrow the command queue separately from the query so both
        // borrows can coexist.
        let commands = &mut self.commands;
        self.query
            .get_mut(buffer)
            .map(move |storage| BufferMut::new(storage, buffer, session, accessor, commands))
    }
}
/// Buffer access for code that works with a world value directly rather than
/// through system parameters. This file implements it for [`World`].
pub trait BufferWorldAccess {
/// Returns a read-only view of the buffer identified by `key`.
///
/// Returns a [`BufferError`] when the key cannot be resolved to a buffer.
fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
where
T: 'static + Send + Sync;
/// Returns a read-only view of the gate of the buffer identified by `key`.
///
/// Returns a [`BufferError`] when the key cannot be resolved to a buffer.
fn buffer_gate_view(
&self,
key: impl Into<AnyBufferKey>,
) -> Result<BufferGateView<'_>, BufferError>;
/// Runs `f` with mutable access to the buffer identified by `key` and
/// returns whatever `f` returns.
///
/// Returns a [`BufferError`] (without calling `f`) when the key cannot be
/// resolved to a buffer.
fn buffer_mut<T, U>(
&mut self,
key: &BufferKey<T>,
f: impl FnOnce(BufferMut<T>) -> U,
) -> Result<U, BufferError>
where
T: 'static + Send + Sync;
/// Runs `f` with mutable access to the gate of the buffer identified by
/// `key` and returns whatever `f` returns.
///
/// Returns a [`BufferError`] (without calling `f`) when the key cannot be
/// resolved to a buffer.
fn buffer_gate_mut<U>(
&mut self,
key: impl Into<AnyBufferKey>,
f: impl FnOnce(BufferGateMut) -> U,
) -> Result<U, BufferError>;
}
impl BufferWorldAccess for World {
    /// Looks up the buffer entity named by `key` and returns a read-only view
    /// of its storage, scoped to the key's session.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::BufferMissing`] when the entity does not exist
    /// or has no `BufferStorage<T>` component.
    fn buffer_view<T>(&self, key: &BufferKey<T>) -> Result<BufferView<'_, T>, BufferError>
    where
        T: 'static + Send + Sync,
    {
        let buffer_ref = self
            .get_entity(key.tag.buffer)
            .map_err(|_| BufferError::BufferMissing)?;
        let storage = buffer_ref
            .get::<BufferStorage<T>>()
            .ok_or(BufferError::BufferMissing)?;
        Ok(BufferView {
            storage,
            session: key.tag.session,
        })
    }

    /// Looks up the buffer entity named by `key` and returns a read-only view
    /// of its gate state, scoped to the key's session.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::BufferMissing`] when the entity does not exist
    /// or has no `GateState` component.
    fn buffer_gate_view(
        &self,
        key: impl Into<AnyBufferKey>,
    ) -> Result<BufferGateView<'_>, BufferError> {
        let key: AnyBufferKey = key.into();
        let buffer_ref = self
            .get_entity(key.tag.buffer)
            .map_err(|_| BufferError::BufferMissing)?;
        let gate = buffer_ref
            .get::<GateState>()
            .ok_or(BufferError::BufferMissing)?;
        Ok(BufferGateView {
            gate,
            session: key.tag.session,
        })
    }

    /// Runs `f` with mutable access to the buffer identified by `key`, then
    /// applies the commands queued during that access.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::BufferMissing`] (without calling `f`) when the
    /// buffer cannot be fetched.
    fn buffer_mut<T, U>(
        &mut self,
        key: &BufferKey<T>,
        f: impl FnOnce(BufferMut<T>) -> U,
    ) -> Result<U, BufferError>
    where
        T: 'static + Send + Sync,
    {
        let mut state = SystemState::<BufferAccessMut<T>>::new(self);
        let mut buffer_access_mut = state.get_mut(self);
        let outcome = match buffer_access_mut.get_mut(key) {
            Ok(buffer_mut) => Ok(f(buffer_mut)),
            Err(_) => Err(BufferError::BufferMissing),
        };
        // Release the borrows of `state` and the world before applying.
        drop(buffer_access_mut);
        // A modified `BufferMut` queues `NotifyBufferUpdate` on the commands
        // owned by `state`. They only take effect once applied; without this
        // call they would be discarded when `state` is dropped.
        state.apply(self);
        outcome
    }

    /// Runs `f` with mutable access to the gate of the buffer identified by
    /// `key`, then applies any commands queued during that access.
    ///
    /// # Errors
    ///
    /// Returns [`BufferError::BufferMissing`] (without calling `f`) when the
    /// buffer gate cannot be fetched.
    fn buffer_gate_mut<U>(
        &mut self,
        key: impl Into<AnyBufferKey>,
        f: impl FnOnce(BufferGateMut) -> U,
    ) -> Result<U, BufferError> {
        let mut state = SystemState::<BufferGateAccessMut>::new(self);
        let mut buffer_gate_access_mut = state.get_mut(self);
        let outcome = match buffer_gate_access_mut.get_mut(key) {
            Ok(buffer_mut) => Ok(f(buffer_mut)),
            Err(_) => Err(BufferError::BufferMissing),
        };
        // Release the borrows of `state` and the world before applying.
        drop(buffer_gate_access_mut);
        // Apply any deferred work queued through the system parameter; this
        // is a no-op when nothing was queued.
        state.apply(self);
        outcome
    }
}
/// A read-only view of one buffer, restricted to a single session.
///
/// Every accessor passes `session` to the underlying storage.
pub struct BufferView<'a, T>
where
T: 'static + Send + Sync,
{
/// The storage component of the buffer being viewed.
storage: &'a BufferStorage<T>,
/// The session whose items this view exposes.
session: Entity,
}
impl<'a, T> BufferView<'a, T>
where
    T: 'static + Send + Sync,
{
    /// Iterates over the items stored for this view's session.
    pub fn iter(&self) -> IterBufferView<'a, T> {
        let Self { storage, session } = *self;
        storage.iter(session)
    }

    /// Borrows the oldest item stored for this view's session, if any.
    pub fn oldest(&self) -> Option<&'a T> {
        let Self { storage, session } = *self;
        storage.oldest(session)
    }

    /// Borrows the newest item stored for this view's session, if any.
    pub fn newest(&self) -> Option<&'a T> {
        let Self { storage, session } = *self;
        storage.newest(session)
    }

    /// Borrows the item at `index` for this view's session, if any.
    pub fn get(&self, index: usize) -> Option<&'a T> {
        let Self { storage, session } = *self;
        storage.get(session, index)
    }

    /// The number of items stored for this view's session.
    pub fn len(&self) -> usize {
        let Self { storage, session } = *self;
        storage.count(session)
    }

    /// Whether nothing is stored for this view's session.
    pub fn is_empty(&self) -> bool {
        let item_count = self.len();
        item_count == 0
    }
}
/// Mutable access to one buffer, restricted to a single session.
///
/// When a `BufferMut` that was marked as modified is dropped, it queues a
/// `NotifyBufferUpdate` command for the buffer.
pub struct BufferMut<'w, 's, 'a, T>
where
T: 'static + Send + Sync,
{
/// Mutable handle to the storage component of the buffer.
storage: Mut<'a, BufferStorage<T>>,
/// The entity of the buffer being accessed.
buffer: Entity,
/// The session whose items this access exposes.
session: Entity,
/// The accessor passed to the update notification; `allow_closed_loops`
/// clears it to `None`.
// NOTE(review): presumably the notification skips waking this accessor so
// it does not trigger itself — confirm in `NotifyBufferUpdate`.
accessor: Option<Entity>,
/// Command queue used to issue the update notification on drop.
commands: &'a mut Commands<'w, 's>,
/// Set to `true` by every mutable accessor and by `pulse`.
modified: bool,
}
impl<'w, 's, 'a, T> BufferMut<'w, 's, 'a, T>
where
    T: 'static + Send + Sync,
{
    /// Clears the accessor so the update notification is sent with
    /// `accessor = None`.
    // NOTE(review): presumably this lets the accessor be woken by its own
    // changes (a "closed loop") — confirm in `NotifyBufferUpdate`.
    pub fn allow_closed_loops(mut self) -> Self {
        self.accessor = None;
        self
    }

    /// Iterates over the items stored for this session.
    pub fn iter(&self) -> IterBufferView<'_, T> {
        let session = self.session;
        self.storage.iter(session)
    }

    /// Borrows the oldest item stored for this session, if any.
    pub fn oldest(&self) -> Option<&T> {
        let session = self.session;
        self.storage.oldest(session)
    }

    /// Borrows the newest item stored for this session, if any.
    pub fn newest(&self) -> Option<&T> {
        let session = self.session;
        self.storage.newest(session)
    }

    /// Borrows the item at `index` for this session, if any.
    pub fn get(&self, index: usize) -> Option<&T> {
        let session = self.session;
        self.storage.get(session, index)
    }

    /// The number of items stored for this session.
    pub fn len(&self) -> usize {
        let session = self.session;
        self.storage.count(session)
    }

    /// Whether nothing is stored for this session.
    pub fn is_empty(&self) -> bool {
        let item_count = self.len();
        item_count == 0
    }

    /// Iterates mutably over the items of this session.
    /// Marks the buffer as modified.
    pub fn iter_mut(&mut self) -> IterBufferMut<'_, T> {
        let session = self.session;
        self.modified = true;
        self.storage.iter_mut(session)
    }

    /// Mutably borrows the oldest item of this session, if any.
    /// Marks the buffer as modified even when `None` is returned.
    pub fn oldest_mut(&mut self) -> Option<&mut T> {
        let session = self.session;
        self.modified = true;
        self.storage.oldest_mut(session)
    }

    /// Mutably borrows the newest item of this session, if any.
    /// Marks the buffer as modified even when `None` is returned.
    pub fn newest_mut(&mut self) -> Option<&mut T> {
        let session = self.session;
        self.modified = true;
        self.storage.newest_mut(session)
    }

    /// Like `newest_mut_or_else`, using `T::default` as the fallback.
    pub fn newest_mut_or_default(&mut self) -> Option<&mut T>
    where
        T: Default,
    {
        self.newest_mut_or_else(T::default)
    }

    /// Delegates to the storage's `newest_mut_or_else` with `f` as the
    /// fallback constructor. Marks the buffer as modified.
    pub fn newest_mut_or_else(&mut self, f: impl FnOnce() -> T) -> Option<&mut T> {
        let session = self.session;
        self.modified = true;
        self.storage.newest_mut_or_else(session, f)
    }

    /// Mutably borrows the item at `index` for this session, if any.
    /// Marks the buffer as modified even when `None` is returned.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
        let session = self.session;
        self.modified = true;
        self.storage.get_mut(session, index)
    }

    /// Drains the given index range from this session's items.
    /// Marks the buffer as modified.
    pub fn drain<R>(&mut self, range: R) -> DrainBuffer<'_, T>
    where
        R: RangeBounds<usize>,
    {
        let session = self.session;
        self.modified = true;
        self.storage.drain(session, range)
    }

    /// Delegates to the storage's `pull` for this session.
    /// Marks the buffer as modified.
    pub fn pull(&mut self) -> Option<T> {
        let session = self.session;
        self.modified = true;
        self.storage.pull(session)
    }

    /// Delegates to the storage's `pull_newest` for this session.
    /// Marks the buffer as modified.
    pub fn pull_newest(&mut self) -> Option<T> {
        let session = self.session;
        self.modified = true;
        self.storage.pull_newest(session)
    }

    /// Delegates to the storage's `push` for this session and returns
    /// whatever it returns. Marks the buffer as modified.
    pub fn push(&mut self, value: T) -> Option<T> {
        let session = self.session;
        self.modified = true;
        self.storage.push(session, value)
    }

    /// Delegates to the storage's `push_as_oldest` for this session and
    /// returns whatever it returns. Marks the buffer as modified.
    pub fn push_as_oldest(&mut self, value: T) -> Option<T> {
        let session = self.session;
        self.modified = true;
        self.storage.push_as_oldest(session, value)
    }

    /// Marks the buffer as modified without touching its contents, so an
    /// update notification is queued on drop.
    pub fn pulse(&mut self) {
        self.modified = true;
    }

    /// Creates an unmodified access handle that remembers `accessor`.
    fn new(
        storage: Mut<'a, BufferStorage<T>>,
        buffer: Entity,
        session: Entity,
        accessor: Entity,
        commands: &'a mut Commands<'w, 's>,
    ) -> Self {
        let accessor = Some(accessor);
        Self {
            modified: false,
            storage,
            buffer,
            session,
            accessor,
            commands,
        }
    }
}
/// On drop, a modified `BufferMut` queues a `NotifyBufferUpdate` command
/// carrying the buffer, session and accessor. Nothing is queued otherwise.
impl<'w, 's, 'a, T> Drop for BufferMut<'w, 's, 'a, T>
where
    T: 'static + Send + Sync,
{
    fn drop(&mut self) {
        if !self.modified {
            return;
        }

        let notification = NotifyBufferUpdate::new(self.buffer, self.session, self.accessor);
        self.commands.queue(notification);
    }
}
/// Errors returned by the [`BufferWorldAccess`] methods.
#[derive(ThisError, Debug, Clone)]
pub enum BufferError {
/// The entity named by the key does not exist or lacks the component
/// needed for the requested access.
#[error("The key was unable to identify a buffer")]
BufferMissing,
}
#[cfg(test)]
mod tests {
use crate::{AddBufferToMap, Gate, prelude::*, testing::*};
use std::future::Future;
// Exercises buffer keys in four workflows: read-only access, pulling from two
// buffers, and two ways of looping a result back through buffer access.
#[test]
fn test_buffer_key_access() {
let mut context = TestingContext::minimal_plugins();
let add_buffers_by_pull_cb = add_buffers_by_pull.into_blocking_callback();
let add_from_buffer_cb = add_from_buffer.into_blocking_callback();
let multiply_buffers_by_copy_cb = multiply_buffers_by_copy.into_blocking_callback();
// Workflow 1: read the oldest value of each buffer and multiply them.
let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
builder
.chain(scope.start)
.unzip()
.listen(builder)
.then(multiply_buffers_by_copy_cb)
.connect(scope.terminate);
});
let r = context.resolve_request((2.0, 3.0), workflow);
assert_eq!(r, 6.0);
// Workflow 2: pull one value from each buffer and add them; a `None`
// result (either buffer empty) is disposed.
let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
builder
.chain(scope.start)
.unzip()
.listen(builder)
.then(add_buffers_by_pull_cb)
.dispose_on_none()
.connect(scope.terminate);
});
let r = context.resolve_request((4.0, 5.0), workflow);
assert_eq!(r, 9.0);
// Workflow 3: one half of the input goes into a buffer, the other half is
// added to a value pulled from that buffer. An `Ok` sum goes through the
// adder once more on its way to termination; an `Err` is routed back to
// the adder node's input.
let workflow =
context.spawn_io_workflow(|scope: Scope<(f64, f64), Result<f64, f64>>, builder| {
let (branch_to_adder, branch_to_buffer) = builder.chain(scope.start).unzip();
let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
builder.connect(branch_to_buffer, buffer.input_slot());
let adder_node = builder
.chain(branch_to_adder)
.with_access(buffer)
.then_node(add_from_buffer_cb.clone());
builder.chain(adder_node.output).fork_result(
|chain| {
chain
.with_access(buffer)
.then(add_from_buffer_cb.clone())
.connect(scope.terminate)
},
|chain| chain.with_access(buffer).connect(adder_node.input),
);
});
let r = context.resolve_request((2.0, 3.0), workflow);
// The final result is expected to be `Err(5.0)`.
assert!(r.is_err_and(|n| n == 5.0));
// Workflow 4: same idea as workflow 3, but built with explicit buffer
// access nodes instead of `with_access`.
let workflow = context.spawn_io_workflow(|scope, builder| {
let (branch_to_adder, branch_to_buffer) = builder.chain(scope.start).unzip();
let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
builder.connect(branch_to_buffer, buffer.input_slot());
let access = builder.create_buffer_access(buffer);
builder.connect(branch_to_adder, access.input);
builder
.chain(access.output)
.then(add_from_buffer_cb.clone())
.fork_result(
|ok| {
let (output, builder) = ok.unpack();
let second_access = builder.create_buffer_access(buffer);
builder.connect(output, second_access.input);
builder
.chain(second_access.output)
.then(add_from_buffer_cb.clone())
.connect(scope.terminate);
},
|err| err.connect(access.input),
);
});
let r = context.resolve_request((2.0, 3.0), workflow);
// The final result is expected to be `Err(5.0)`.
assert!(r.is_err_and(|n| n == 5.0));
}
/// Pulls one value from the buffer behind `key` and adds it to `lhs`.
/// Returns `Err(lhs)` when the buffer cannot be accessed or has nothing to pull.
fn add_from_buffer(
    In((lhs, key)): In<(f64, BufferKey<f64>)>,
    mut access: BufferAccessMut<f64>,
) -> Result<f64, f64> {
    let Ok(mut buffer) = access.get_mut(&key) else {
        return Err(lhs);
    };
    match buffer.pull() {
        Some(rhs) => Ok(lhs + rhs),
        None => Err(lhs),
    }
}
/// Multiplies the oldest values of the two buffers without removing them.
/// Panics if either buffer is inaccessible or empty.
fn multiply_buffers_by_copy(
    In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
    access: BufferAccess<f64>,
) -> f64 {
    let oldest_a: f64 = *access.get(&key_a).unwrap().oldest().unwrap();
    let oldest_b: f64 = *access.get(&key_b).unwrap().oldest().unwrap();
    oldest_a * oldest_b
}
/// Pulls one value from each buffer and returns their sum.
/// Returns `None` without pulling anything if either buffer is empty.
fn add_buffers_by_pull(
    In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
    mut access: BufferAccessMut<f64>,
) -> Option<f64> {
    // Check both buffers first so a value is never pulled from only one.
    let either_empty = [&key_a, &key_b]
        .into_iter()
        .any(|key| access.get(key).unwrap().is_empty());
    if either_empty {
        return None;
    }

    let from_a = access.get_mut(&key_a).unwrap().pull().unwrap();
    let from_b = access.get_mut(&key_b).unwrap().pull().unwrap();
    Some(from_a + from_b)
}
// Runs a `Register` through a chain of four decrement steps. A register only
// reaches the buffer (and so the terminating listener) when a step sees
// `in_slot == 0`; the test expects success for start values 0..=3 and an
// error for 4..=6.
#[test]
fn test_buffer_key_lifecycle() {
let mut context = TestingContext::minimal_plugins();
// Variant 1: every step requests fresh buffer access with `with_access`.
let workflow = context.spawn_io_workflow(|scope, builder| {
let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());
// Whatever lands in the buffer is pulled out and terminates the workflow.
builder
.listen(buffer)
.then(pull_register_from_buffer.into_blocking_callback())
.dispose_on_none()
.connect(scope.terminate);
let decrement_register_cb = decrement_register.into_blocking_callback();
let async_decrement_register_cb = async_decrement_register.as_callback();
builder
.chain(scope.start)
.with_access(buffer)
.then(decrement_register_cb.clone())
.with_access(buffer)
.then(async_decrement_register_cb.clone())
.dispose_on_none()
.with_access(buffer)
.then(decrement_register_cb.clone())
.with_access(buffer)
.then(async_decrement_register_cb)
.unused();
});
run_register_test(workflow, 0, true, &mut context);
run_register_test(workflow, 1, true, &mut context);
run_register_test(workflow, 2, true, &mut context);
run_register_test(workflow, 3, true, &mut context);
run_register_test(workflow, 4, false, &mut context);
run_register_test(workflow, 5, false, &mut context);
run_register_test(workflow, 6, false, &mut context);
// Variant 2: access is requested once and the same key is passed along
// from step to step.
let workflow = context.spawn_io_workflow(|scope, builder| {
let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());
builder
.listen(buffer)
.then(pull_register_from_buffer.into_blocking_callback())
.dispose_on_none()
.connect(scope.terminate);
let decrement_register_and_pass_keys_cb =
decrement_register_and_pass_keys.into_blocking_callback();
let async_decrement_register_and_pass_keys_cb =
async_decrement_register_and_pass_keys.as_callback();
let (loose_end, dead_end): (_, Output<Option<Register>>) = builder
.chain(scope.start)
.with_access(buffer)
.then(decrement_register_and_pass_keys_cb.clone())
.then(async_decrement_register_and_pass_keys_cb.clone())
.dispose_on_none()
.map_block(|v| (v, None))
.unzip();
// `dead_end` only ever carries `None`, so it is always disposed.
builder.chain(dead_end).dispose_on_none().unused();
builder
.chain(loose_end)
.then(async_decrement_register_and_pass_keys_cb)
.dispose_on_none()
.then(decrement_register_and_pass_keys_cb)
.unused();
});
run_register_test(workflow, 0, true, &mut context);
run_register_test(workflow, 1, true, &mut context);
run_register_test(workflow, 2, true, &mut context);
run_register_test(workflow, 3, true, &mut context);
run_register_test(workflow, 4, false, &mut context);
run_register_test(workflow, 5, false, &mut context);
run_register_test(workflow, 6, false, &mut context);
}
/// Sends `Register::new(initial_value)` through `workflow`.
///
/// When `expect_success` is set, the result must be a register that finished
/// with `out_slot == initial_value`; otherwise the request must fail.
fn run_register_test(
    workflow: Service<Register, Register>,
    initial_value: u64,
    expect_success: bool,
    context: &mut TestingContext,
) {
    let r = context.try_resolve_request(Register::new(initial_value), workflow, ());
    if !expect_success {
        assert!(r.is_err());
        return;
    }
    assert!(r.unwrap().finished_with(initial_value));
}
/// A pair of counters used by the lifecycle test: the decrement helpers move
/// one unit at a time from `in_slot` to `out_slot`.
#[derive(Clone, Copy, Debug)]
struct Register {
    in_slot: u64,
    out_slot: u64,
}

impl Register {
    /// Starts with `start_from` units in `in_slot` and none in `out_slot`.
    fn new(start_from: u64) -> Self {
        let (in_slot, out_slot) = (start_from, 0);
        Self { in_slot, out_slot }
    }

    /// True when `in_slot` is empty and `out_slot` equals the given value.
    fn finished_with(&self, out_slot: u64) -> bool {
        (self.in_slot, self.out_slot) == (0, out_slot)
    }
}
/// Pulls one register out of the buffer behind `key`.
/// Returns `None` when the buffer is inaccessible or has nothing to pull.
fn pull_register_from_buffer(
    In(key): In<BufferKey<Register>>,
    mut access: BufferAccessMut<Register>,
) -> Option<Register> {
    let mut buffer = access.get_mut(&key).ok()?;
    buffer.pull()
}
/// Moves one unit from `in_slot` to `out_slot`. When `in_slot` is already
/// zero, pushes the register into the buffer instead and returns it unchanged.
fn decrement_register(
    In((mut register, key)): In<(Register, BufferKey<Register>)>,
    mut access: BufferAccessMut<Register>,
) -> Register {
    match register.in_slot {
        0 => {
            access.get_mut(&key).unwrap().push(register);
        }
        remaining => {
            register.in_slot = remaining - 1;
            register.out_slot += 1;
        }
    }
    register
}
/// Same as `decrement_register`, but also returns the key so the next step
/// can reuse it.
fn decrement_register_and_pass_keys(
    In((mut register, key)): In<(Register, BufferKey<Register>)>,
    mut access: BufferAccessMut<Register>,
) -> (Register, BufferKey<Register>) {
    match register.in_slot {
        0 => {
            access.get_mut(&key).unwrap().push(register);
        }
        remaining => {
            register.in_slot = remaining - 1;
            register.out_slot += 1;
        }
    }
    (register, key)
}
/// Async wrapper that runs `decrement_register` through the callback's
/// channel. The awaited result is converted with `.ok()`, so a failure
/// becomes `None`.
fn async_decrement_register(
    In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
) -> impl Future<Output = Option<Register>> + use<> {
    async move {
        let outcome = input
            .channel
            .request_outcome(input.request, decrement_register.into_blocking_callback())
            .await;
        outcome.ok()
    }
}
/// Async wrapper that runs `decrement_register_and_pass_keys` through the
/// callback's channel. The awaited result is converted with `.ok()`, so a
/// failure becomes `None`.
fn async_decrement_register_and_pass_keys(
    In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
) -> impl Future<Output = Option<(Register, BufferKey<Register>)>> + use<> {
    async move {
        let outcome = input
            .channel
            .request_outcome(
                input.request,
                decrement_register_and_pass_keys.into_blocking_callback(),
            )
            .await;
        outcome.ok()
    }
}
// Loops a counter through a buffer whose gate is closed before each service
// call; the service re-opens the gate itself. Starting from 0, the workflow
// is expected to terminate with 5.
#[test]
fn test_buffer_key_gate_control() {
let mut context = TestingContext::minimal_plugins();
let workflow = context.spawn_io_workflow(|scope, builder| {
let service = builder.commands().spawn_service(gate_access_test_open_loop);
let buffer = builder.create_buffer(BufferSettings::keep_all());
builder.connect(scope.start, buffer.input_slot());
builder
.listen(buffer)
.then_gate_close(buffer)
.then(service)
.fork_unzip((
// First output: the incremented value goes back into the buffer.
|chain: Chain<_>| chain.dispose_on_none().connect(buffer.input_slot()),
// Second output: the final value terminates the workflow.
|chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
));
});
let r = context.resolve_request(0, workflow);
assert_eq!(r, 5);
}
/// Pulls a value from the buffer, checks that the gate is closed, and opens it.
///
/// Returns `(Some(value + 1), None)` while `value < 5`, and
/// `(None, Some(value))` once it has reached 5.
fn gate_access_test_open_loop(
    In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
    mut access: BufferAccessMut<u64>,
    mut gate_access: BufferGateAccessMut,
) -> (Option<u64>, Option<u64>) {
    let mut buffer = access.get_mut(&key).unwrap();
    let value = buffer.pull().unwrap();

    let mut gate = gate_access.get_mut(key).unwrap();
    assert_eq!(gate.get(), Gate::Closed);
    gate.open_gate();

    match value {
        finished if finished >= 5 => (None, Some(finished)),
        pending => (Some(pending + 1), None),
    }
}
// Feeds the service's own output back into the buffer it listens to. The
// service accesses the buffer with `allow_closed_loops()`. Starting from 3,
// the workflow is expected to terminate with 0, the value the service emits
// when it finds nothing to pull.
// NOTE(review): presumably the closed-loop notification is what wakes the
// service again on an empty buffer — confirm in `NotifyBufferUpdate`.
#[test]
fn test_closed_loop_key_access() {
let mut context = TestingContext::minimal_plugins();
let delay = context.spawn_delay(Duration::from_secs_f32(0.1));
let workflow = context.spawn_io_workflow(|scope, builder| {
let service = builder
.commands()
.spawn_service(gate_access_test_closed_loop);
let buffer = builder.create_buffer(BufferSettings::keep_all());
builder.connect(scope.start, buffer.input_slot());
builder.listen(buffer).then(service).fork_unzip((
// First output: the incremented value is delayed, then pushed back
// into the buffer.
|chain: Chain<_>| {
chain
.dispose_on_none()
.then(delay)
.connect(buffer.input_slot())
},
// Second output: terminates the workflow.
|chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
));
});
let r = context.resolve_request(3, workflow);
assert_eq!(r, 0);
}
/// Pulls a value from the buffer using closed-loop access.
///
/// Returns `(Some(value + 1), None)` when a value was pulled, and
/// `(None, Some(0))` when the buffer had nothing to pull.
fn gate_access_test_closed_loop(
    In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
    mut access: BufferAccessMut<u64>,
) -> (Option<u64>, Option<u64>) {
    let mut buffer = access.get_mut(&key).unwrap().allow_closed_loops();
    match buffer.pull() {
        Some(value) => (Some(value + 1), None),
        None => (None, Some(0)),
    }
}
// Joins a `join_by_cloning` message buffer with a regular count buffer through
// type-erased `AnyBuffer`s. The count is fed back until it reaches 10, while
// the message must still read "hello" in the final joined value.
#[test]
fn test_any_buffer_join_by_clone() {
let mut context = TestingContext::minimal_plugins();
let workflow = context.spawn_io_workflow(|scope, builder| {
let message_buffer = builder.create_buffer(Default::default()).join_by_cloning();
let count_buffer = builder.create_buffer(Default::default());
let (message, count) = builder.chain(scope.start).unzip();
builder.connect(message, message_buffer.input_slot());
builder.connect(count, count_buffer.input_slot());
let any_message_buffer = message_buffer.as_any_buffer();
let any_count_buffer = count_buffer.as_any_buffer();
let mut buffer_map = BufferMap::default();
buffer_map.insert_buffer("message", any_message_buffer);
buffer_map.insert_buffer("count", any_count_buffer);
builder
.try_join::<JoinByCloneTest>(&buffer_map)
.unwrap()
.map_block(|joined| {
// Below 10, send the incremented count back for another join.
if joined.count < 10 {
Err(joined.count + 1)
} else {
Ok(joined)
}
})
.fork_result(
|ok| ok.connect(scope.terminate),
|err| err.connect(count_buffer.input_slot()),
);
});
let r = context.resolve_request((String::from("hello"), 0), workflow);
assert_eq!(r.count, 10);
assert_eq!(r.message, "hello");
}
// Joined value for `test_any_buffer_join_by_clone`; the field names match the
// keys inserted into the `BufferMap` there ("count" and "message").
#[derive(Joined)]
struct JoinByCloneTest {
count: i64,
message: String,
}
/// Returns the largest value currently in the buffer, or `None` when the
/// buffer is inaccessible or empty.
fn get_largest_value(
    In(input): In<((), BufferKey<i32>)>,
    access: BufferAccess<i32>,
) -> Option<i32> {
    let (_, key) = input;
    let view = access.get(&key).ok()?;
    let largest = view.iter().max()?;
    Some(*largest)
}
/// Pushes every value of the input vector into the buffer, in order.
/// Does nothing when the buffer cannot be accessed.
fn push_values(In(input): In<(Vec<i32>, BufferKey<i32>)>, mut access: BufferAccessMut<i32>) {
    let (values, key) = input;
    if let Ok(mut buffer) = access.get_mut(&key) {
        for value in values {
            buffer.push(value);
        }
    }
}
// Pushes a list of values into a buffer and then reads back the largest one.
#[test]
fn test_buffer_access_example() {
let mut context = TestingContext::minimal_plugins();
let workflow = context.spawn_io_workflow(|scope, builder| {
let buffer = builder.create_buffer(BufferSettings::keep_all());
builder
.chain(scope.start)
.with_access(buffer)
.then(push_values.into_blocking_callback())
.with_access(buffer)
.then(get_largest_value.into_blocking_callback())
.connect(scope.terminate);
});
let r = context.resolve_request(vec![-3, 2, 10], workflow);
assert_eq!(r.unwrap(), 10);
}
}