use bevy_ecs::{
change_detection::Mut,
prelude::{Commands, Entity, Query},
query::QueryEntityError,
system::SystemParam,
};
use std::{ops::RangeBounds, sync::Arc};
use crate::{
Builder, Chain, Gate, GateState, InputSlot, NotifyBufferUpdate, OnNewBufferValue, UnusedTarget,
};
// Submodules that make up the buffer API. Crate-internal pieces are re-exported
// with `pub(crate) use` and public pieces with `pub use`, so all of them are
// reachable through this module.
mod buffer_access_lifecycle;
pub(crate) use buffer_access_lifecycle::*;
mod buffer_key_builder;
pub(crate) use buffer_key_builder::*;
mod buffer_storage;
pub(crate) use buffer_storage::*;
mod buffered;
pub use buffered::*;
mod bufferable;
pub use bufferable::*;
mod manage_buffer;
pub use manage_buffer::*;
/// Handle to a buffer of `T` values that belongs to one workflow scope.
///
/// The handle only stores entity ids; it never holds any `T` values itself.
pub struct Buffer<T> {
/// Scope that owns this buffer. `on_new_value` asserts that it matches the
/// scope of the builder it is used with.
pub(crate) scope: Entity,
/// Entity identifying the buffer; passed to `InputSlot::new` and
/// `OnNewBufferValue::new`.
pub(crate) source: Entity,
/// Ties the handle to `T` without owning a `T`. `fn(T)` is `Send + Sync` for
/// every `T`, so the handle is too.
pub(crate) _ignore: std::marker::PhantomData<fn(T)>,
}
impl<T> Buffer<T> {
    /// Spawn a fresh target entity, register an `OnNewBufferValue` command that
    /// links this buffer to it, and return a `Chain` starting at that target.
    ///
    /// Presumably the chain is triggered each time the buffer receives a new
    /// value (see `OnNewBufferValue`).
    ///
    /// # Panics
    ///
    /// Panics if `builder` belongs to a different scope than this buffer.
    pub fn on_new_value<'w, 's, 'a, 'b>(
        &self,
        builder: &'b mut Builder<'w, 's, 'a>,
    ) -> Chain<'w, 's, 'a, 'b, ()> {
        assert_eq!(self.scope, builder.scope);
        let target = builder.commands.spawn(UnusedTarget).id();
        let listener = OnNewBufferValue::new(self.source, target);
        builder.commands.add(listener);
        Chain::new(target, builder)
    }

    /// Convert this handle into a [`CloneFromBuffer`] for the same buffer.
    ///
    /// Only available when `T: Clone`.
    pub fn by_cloning(self) -> CloneFromBuffer<T>
    where
        T: Clone,
    {
        let Self { scope, source, .. } = self;
        CloneFromBuffer {
            scope,
            source,
            _ignore: std::marker::PhantomData,
        }
    }

    /// Get an [`InputSlot`] that refers to this buffer's scope and source.
    pub fn input_slot(self) -> InputSlot<T> {
        let Self { scope, source, .. } = self;
        InputSlot::new(scope, source)
    }
}
/// Handle to the same buffer as a [`Buffer`], produced by `Buffer::by_cloning`.
///
/// It carries the same scope/source ids but requires `T: Clone`.
pub struct CloneFromBuffer<T: Clone> {
/// Scope that owns the buffer.
pub(crate) scope: Entity,
/// Entity identifying the buffer.
pub(crate) source: Entity,
/// Ties the handle to `T` without owning a `T`.
pub(crate) _ignore: std::marker::PhantomData<fn(T)>,
}
/// Settings that describe how a buffer behaves.
///
/// For now this only holds the buffer's [`RetentionPolicy`]. The default
/// settings use `RetentionPolicy::KeepLast(1)`.
///
/// `Debug`, `PartialEq` and `Eq` are derived so settings can be logged,
/// asserted on, and compared.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub struct BufferSettings {
    retention: RetentionPolicy,
}

impl BufferSettings {
    /// Create settings that use the given retention policy.
    pub fn new(retention: RetentionPolicy) -> Self {
        Self { retention }
    }

    /// Settings using [`RetentionPolicy::KeepLast`]`(n)`.
    pub fn keep_last(n: usize) -> Self {
        Self::new(RetentionPolicy::KeepLast(n))
    }

    /// Settings using [`RetentionPolicy::KeepFirst`]`(n)`.
    pub fn keep_first(n: usize) -> Self {
        Self::new(RetentionPolicy::KeepFirst(n))
    }

    /// Settings using [`RetentionPolicy::KeepAll`].
    pub fn keep_all() -> Self {
        Self::new(RetentionPolicy::KeepAll)
    }

    /// The retention policy these settings use.
    pub fn retention(&self) -> RetentionPolicy {
        self.retention
    }

    /// Mutable access to the retention policy, for changing it in place.
    pub fn retention_mut(&mut self) -> &mut RetentionPolicy {
        &mut self.retention
    }
}

/// How many values a buffer retains.
///
/// This type only describes the policy. Enforcement happens outside this
/// file (presumably in the buffer storage). Defaults to `KeepLast(1)`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum RetentionPolicy {
    /// Keep the last (newest) `n` values.
    KeepLast(usize),
    /// Keep the first (oldest) `n` values.
    KeepFirst(usize),
    /// Do not limit how many values are kept.
    KeepAll,
}

impl Default for RetentionPolicy {
    fn default() -> Self {
        Self::KeepLast(1)
    }
}
// `Clone`/`Copy` are written by hand instead of derived: `#[derive]` would add
// `T: Clone`/`T: Copy` bounds, but these handles only hold entity ids and a
// `PhantomData<fn(T)>`, so they can be copied whatever `T` is.
impl<T> Clone for Buffer<T> {
fn clone(&self) -> Self {
*self
}
}
impl<T> Copy for Buffer<T> {}
// `CloneFromBuffer` still needs `T: Clone` because its definition requires it,
// but it does not need `T: Copy`.
impl<T: Clone> Clone for CloneFromBuffer<T> {
fn clone(&self) -> Self {
*self
}
}
impl<T: Clone> Copy for CloneFromBuffer<T> {}
/// Key that unlocks access to one buffer's data for one session.
///
/// Pass it to `BufferAccess::get` / `BufferAccessMut::get_mut` to obtain a
/// view or a mutable handle.
pub struct BufferKey<T> {
/// Buffer entity; looked up in the access queries and returned by `id()`.
buffer: Entity,
/// Session whose data the key refers to; returned by `session()`.
session: Entity,
/// Entity forwarded with update notifications from `BufferMut` (unless
/// `allow_closed_loops` clears it).
accessor: Entity,
/// Optional shared lifecycle tracker. Plain clones share it, while
/// `deep_clone` gives the copy its own. `is_in_use` delegates to it.
lifecycle: Option<Arc<BufferAccessLifecycle>>,
/// Ties the key to `T` without owning a `T`.
_ignore: std::marker::PhantomData<fn(T)>,
}
impl<T> Clone for BufferKey<T> {
fn clone(&self) -> Self {
Self {
buffer: self.buffer,
session: self.session,
accessor: self.accessor,
lifecycle: self.lifecycle.as_ref().map(Arc::clone),
_ignore: Default::default(),
}
}
}
impl<T> BufferKey<T> {
    /// Entity of the buffer this key unlocks.
    pub fn id(&self) -> Entity {
        self.buffer
    }

    /// Session whose data this key gives access to.
    pub fn session(&self) -> Entity {
        self.session
    }

    /// Whether the key's lifecycle tracker reports that it is in use.
    ///
    /// A key without a lifecycle is never considered in use.
    pub(crate) fn is_in_use(&self) -> bool {
        match &self.lifecycle {
            Some(lifecycle) => lifecycle.is_in_use(),
            None => false,
        }
    }

    /// Clone the key, giving the copy its own lifecycle tracker.
    ///
    /// A plain `clone()` shares the tracker. Here the inner
    /// `BufferAccessLifecycle` is cloned into a fresh `Arc` instead.
    pub(crate) fn deep_clone(&self) -> Self {
        let mut copy = self.clone();
        copy.lifecycle = self.lifecycle.as_deref().cloned().map(Arc::new);
        copy
    }
}
/// System parameter giving read-only access to every buffer that stores `T`.
///
/// Use [`BufferAccess::get`] with a [`BufferKey`] to view one buffer's data for
/// the key's session. See [`BufferAccessMut`] for mutable access.
#[derive(SystemParam)]
pub struct BufferAccess<'w, 's, T>
where
T: 'static + Send + Sync,
{
// Read-only query over each buffer entity's storage and gate state.
query: Query<'w, 's, (&'static BufferStorage<T>, &'static GateState)>,
}
impl<'w, 's, T: 'static + Send + Sync> BufferAccess<'w, 's, T> {
    /// Get a read-only view of the buffer `key` refers to, for the key's session.
    ///
    /// # Errors
    ///
    /// Returns the [`QueryEntityError`] from the query if the key's buffer
    /// entity cannot be fetched.
    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
        let (storage, gate) = self.query.get(key.buffer)?;
        Ok(BufferView {
            storage,
            gate,
            session: key.session,
        })
    }
}
/// System parameter giving read and write access to every buffer that stores `T`.
///
/// Use `get` for a read-only [`BufferView`] or `get_mut` for a [`BufferMut`].
#[derive(SystemParam)]
pub struct BufferAccessMut<'w, 's, T>
where
T: 'static + Send + Sync,
{
// Mutable query over each buffer entity's storage and gate state.
query: Query<'w, 's, (&'static mut BufferStorage<T>, &'static mut GateState)>,
// Lent to each `BufferMut` so it can queue a `NotifyBufferUpdate` on drop.
commands: Commands<'w, 's>,
}
impl<'w, 's, T> BufferAccessMut<'w, 's, T>
where
    T: 'static + Send + Sync,
{
    /// Get a read-only view of the buffer `key` refers to, for the key's session.
    ///
    /// # Errors
    ///
    /// Returns the [`QueryEntityError`] from the query if the key's buffer
    /// entity cannot be fetched.
    pub fn get<'a>(&'a self, key: &BufferKey<T>) -> Result<BufferView<'a, T>, QueryEntityError> {
        let (storage, gate) = self.query.get(key.buffer)?;
        Ok(BufferView {
            storage,
            gate,
            session: key.session,
        })
    }

    /// Get a mutable handle to the buffer `key` refers to, for the key's session.
    ///
    /// The handle borrows this parameter's `Commands` so that it can queue an
    /// update notification when it is dropped after a modification.
    ///
    /// # Errors
    ///
    /// Returns the [`QueryEntityError`] from the query if the key's buffer
    /// entity cannot be fetched.
    pub fn get_mut<'a>(
        &'a mut self,
        key: &BufferKey<T>,
    ) -> Result<BufferMut<'w, 's, 'a, T>, QueryEntityError> {
        // `query` and `commands` are disjoint fields, so both can be borrowed
        // mutably at the same time.
        let (storage, gate) = self.query.get_mut(key.buffer)?;
        Ok(BufferMut::new(
            storage,
            gate,
            key.buffer,
            key.session,
            key.accessor,
            &mut self.commands,
        ))
    }
}
/// Read-only view of one session's data in a buffer.
///
/// Obtained from `BufferAccess::get` or `BufferAccessMut::get`.
pub struct BufferView<'a, T>
where
T: 'static + Send + Sync,
{
/// Buffer storage; every read is scoped to `session`.
storage: &'a BufferStorage<T>,
/// Per-session gate states of the buffer.
gate: &'a GateState,
/// Session this view reads from.
session: Entity,
}
impl<'a, T> BufferView<'a, T>
where
    T: 'static + Send + Sync,
{
    /// Iterate over the values stored for this view's session.
    pub fn iter(&self) -> IterBufferView<'_, T> {
        let session = self.session;
        self.storage.iter(session)
    }

    /// The oldest value stored for this session, if any.
    pub fn oldest(&self) -> Option<&T> {
        let session = self.session;
        self.storage.oldest(session)
    }

    /// The newest value stored for this session, if any.
    pub fn newest(&self) -> Option<&T> {
        let session = self.session;
        self.storage.newest(session)
    }

    /// The value at `index` for this session, if there is one.
    pub fn get(&self, index: usize) -> Option<&T> {
        let session = self.session;
        self.storage.get(session, index)
    }

    /// Number of values stored for this session.
    pub fn len(&self) -> usize {
        self.storage.count(self.session)
    }

    /// `true` when no values are stored for this session.
    pub fn is_empty(&self) -> bool {
        self.storage.count(self.session) == 0
    }

    /// Gate state for this session; a session without an entry counts as
    /// `Gate::Open`.
    pub fn gate(&self) -> Gate {
        match self.gate.map.get(&self.session) {
            Some(state) => *state,
            None => Gate::Open,
        }
    }
}
/// Mutable handle to one session's data in a buffer.
///
/// Mutating methods set `modified`. When the handle is dropped with
/// `modified` set, it queues a `NotifyBufferUpdate` command through `commands`.
pub struct BufferMut<'w, 's, 'a, T>
where
T: 'static + Send + Sync,
{
/// Change-detected access to the buffer storage.
storage: Mut<'a, BufferStorage<T>>,
/// Change-detected access to the per-session gate states.
gate: Mut<'a, GateState>,
/// Buffer entity, forwarded to the update notification.
buffer: Entity,
/// Session whose data this handle reads and writes.
session: Entity,
/// Accessor forwarded to the update notification; `None` after
/// `allow_closed_loops`.
accessor: Option<Entity>,
/// Used to queue the update notification on drop.
commands: &'a mut Commands<'w, 's>,
/// Whether drop should queue an update notification.
modified: bool,
}
impl<'w, 's, 'a, T> BufferMut<'w, 's, 'a, T>
where
    T: 'static + Send + Sync,
{
    /// Stop forwarding the key's accessor with the drop-time update notification.
    ///
    /// By default the accessor entity is sent (as `Some`) with the
    /// `NotifyBufferUpdate` queued on drop. After this call `None` is sent
    /// instead. Presumably this lets the accessing node be re-triggered by its
    /// own modifications (a closed loop); confirm against `NotifyBufferUpdate`.
    pub fn allow_closed_loops(self) -> Self {
        let mut handle = self;
        handle.accessor = None;
        handle
    }

    /// Iterate over the values stored for this session.
    pub fn iter(&self) -> IterBufferView<'_, T> {
        let session = self.session;
        self.storage.iter(session)
    }

    /// The oldest value stored for this session, if any.
    pub fn oldest(&self) -> Option<&T> {
        let session = self.session;
        self.storage.oldest(session)
    }

    /// The newest value stored for this session, if any.
    pub fn newest(&self) -> Option<&T> {
        let session = self.session;
        self.storage.newest(session)
    }

    /// The value at `index` for this session, if there is one.
    pub fn get(&self, index: usize) -> Option<&T> {
        let session = self.session;
        self.storage.get(session, index)
    }

    /// Number of values stored for this session.
    pub fn len(&self) -> usize {
        self.storage.count(self.session)
    }

    /// `true` when no values are stored for this session.
    pub fn is_empty(&self) -> bool {
        self.storage.count(self.session) == 0
    }

    /// Gate state for this session; a session without an entry counts as
    /// `Gate::Open`.
    pub fn gate(&self) -> Gate {
        match self.gate.map.get(&self.session) {
            Some(state) => *state,
            None => Gate::Open,
        }
    }

    /// Flag the handle as modified, so dropping it queues an update
    /// notification, and return the session to operate on.
    fn mark_modified(&mut self) -> Entity {
        self.modified = true;
        self.session
    }

    /// Mutably iterate over this session's values. Flags the handle as modified.
    pub fn iter_mut(&mut self) -> IterBufferMut<'_, T> {
        let session = self.mark_modified();
        self.storage.iter_mut(session)
    }

    /// Mutable access to the oldest value. Flags the handle as modified.
    pub fn oldest_mut(&mut self) -> Option<&mut T> {
        let session = self.mark_modified();
        self.storage.oldest_mut(session)
    }

    /// Mutable access to the newest value. Flags the handle as modified.
    pub fn newest_mut(&mut self) -> Option<&mut T> {
        let session = self.mark_modified();
        self.storage.newest_mut(session)
    }

    /// Mutable access to the value at `index`. Flags the handle as modified.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut T> {
        let session = self.mark_modified();
        self.storage.get_mut(session, index)
    }

    /// Drain the values in `range` for this session. Flags the handle as
    /// modified.
    pub fn drain<R>(&mut self, range: R) -> DrainBuffer<'_, T>
    where
        R: RangeBounds<usize>,
    {
        let session = self.mark_modified();
        self.storage.drain(session, range)
    }

    /// Remove and return a value via the storage's `pull`. Flags the handle as
    /// modified even when nothing is returned.
    pub fn pull(&mut self) -> Option<T> {
        let session = self.mark_modified();
        self.storage.pull(session)
    }

    /// Remove and return a value via the storage's `pull_newest`. Flags the
    /// handle as modified even when nothing is returned.
    pub fn pull_newest(&mut self) -> Option<T> {
        let session = self.mark_modified();
        self.storage.pull_newest(session)
    }

    /// Push `value` via the storage's `push` and return whatever it returns.
    /// Flags the handle as modified.
    pub fn push(&mut self, value: T) -> Option<T> {
        let session = self.mark_modified();
        self.storage.push(session, value)
    }

    /// Push `value` via the storage's `push_as_oldest` and return whatever it
    /// returns. Flags the handle as modified.
    pub fn push_as_oldest(&mut self, value: T) -> Option<T> {
        let session = self.mark_modified();
        self.storage.push_as_oldest(session, value)
    }

    /// Open this session's gate.
    ///
    /// Does nothing if the session has no gate entry. The handle is flagged as
    /// modified only when the gate actually changes from closed to open.
    pub fn open_gate(&mut self) {
        let Some(gate) = self.gate.map.get_mut(&self.session) else {
            return;
        };
        if *gate == Gate::Open {
            return;
        }
        *gate = Gate::Open;
        self.modified = true;
    }

    /// Close this session's gate.
    ///
    /// Does nothing if the session has no gate entry. Unlike `open_gate`, this
    /// never flags the handle as modified, so closing a gate alone queues no
    /// update notification.
    pub fn close_gate(&mut self) {
        let Some(gate) = self.gate.map.get_mut(&self.session) else {
            return;
        };
        *gate = Gate::Closed;
    }

    /// Dispatch to `open_gate` or `close_gate` according to `action`.
    pub fn gate_action(&mut self, action: Gate) {
        match action {
            Gate::Open => self.open_gate(),
            Gate::Closed => self.close_gate(),
        }
    }

    /// Flag the handle as modified without touching the data, so that an
    /// update notification is queued on drop.
    pub fn pulse(&mut self) {
        self.modified = true;
    }

    /// Build a handle that starts unmodified and forwards `accessor` with its
    /// update notification.
    fn new(
        storage: Mut<'a, BufferStorage<T>>,
        gate: Mut<'a, GateState>,
        buffer: Entity,
        session: Entity,
        accessor: Entity,
        commands: &'a mut Commands<'w, 's>,
    ) -> Self {
        Self {
            storage,
            gate,
            buffer,
            session,
            accessor: Some(accessor),
            commands,
            modified: false,
        }
    }
}
/// On drop, queue a `NotifyBufferUpdate` for this buffer and session if the
/// handle was flagged as modified. The accessor (`None` after
/// `allow_closed_loops`) is passed along with it.
impl<'w, 's, 'a, T> Drop for BufferMut<'w, 's, 'a, T>
where
    T: 'static + Send + Sync,
{
    fn drop(&mut self) {
        if !self.modified {
            return;
        }
        let update = NotifyBufferUpdate::new(self.buffer, self.session, self.accessor);
        self.commands.add(update);
    }
}
// Workflow-level tests for buffer keys: reading and writing through
// `BufferAccess` / `BufferAccessMut`, key lifecycles, and gate control.
#[cfg(test)]
mod tests {
use crate::{prelude::*, testing::*, Gate};
use std::future::Future;
/// Reads and pulls buffered values through callbacks that receive buffer keys.
#[test]
fn test_buffer_key_access() {
let mut context = TestingContext::minimal_plugins();
let add_buffers_by_pull_cb = add_buffers_by_pull.into_blocking_callback();
let add_from_buffer_cb = add_from_buffer.into_blocking_callback();
let multiply_buffers_by_copy_cb = multiply_buffers_by_copy.into_blocking_callback();
// Split the input pair into two listened buffers and multiply their oldest
// values without removing them: 2.0 * 3.0 == 6.0.
let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
scope
.input
.chain(builder)
.unzip()
.listen(builder)
.then(multiply_buffers_by_copy_cb)
.connect(scope.terminate);
});
let mut promise =
context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());
context.run_with_conditions(&mut promise, Duration::from_secs(2));
assert!(promise.take().available().is_some_and(|value| value == 6.0));
assert!(context.no_unhandled_errors());
// Same split, but pull one value from each buffer and add them:
// 4.0 + 5.0 == 9.0. A `None` (one buffer still empty) is disposed.
let workflow = context.spawn_io_workflow(|scope: Scope<(f64, f64), f64>, builder| {
scope
.input
.chain(builder)
.unzip()
.listen(builder)
.then(add_buffers_by_pull_cb)
.dispose_on_none()
.connect(scope.terminate);
});
let mut promise =
context.command(|commands| commands.request((4.0, 5.0), workflow).take_response());
context.run_with_conditions(&mut promise, Duration::from_secs(2));
assert!(promise.take().available().is_some_and(|value| value == 9.0));
assert!(context.no_unhandled_errors());
// 3.0 goes into the buffer. 2.0 is added to a value pulled from it, giving
// Ok(5.0). The Ok branch then pulls again, finds the buffer empty, and
// terminates with Err(5.0). The Err branch would loop back to the adder.
let workflow =
context.spawn_io_workflow(|scope: Scope<(f64, f64), Result<f64, f64>>, builder| {
let (branch_to_adder, branch_to_buffer) = scope.input.chain(builder).unzip();
let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
builder.connect(branch_to_buffer, buffer.input_slot());
let adder_node = branch_to_adder
.chain(builder)
.with_access(buffer)
.then_node(add_from_buffer_cb.clone());
adder_node.output.chain(builder).fork_result(
|chain| {
chain
.with_access(buffer)
.then(add_from_buffer_cb.clone())
.connect(scope.terminate)
},
|chain| chain.with_access(buffer).connect(adder_node.input),
);
});
let mut promise =
context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());
context.run_with_conditions(&mut promise, Duration::from_secs(2));
assert!(promise
.take()
.available()
.is_some_and(|value| value.is_err_and(|n| n == 5.0)));
assert!(context.no_unhandled_errors());
// Same scenario as above, built with explicit `create_buffer_access` nodes
// instead of `with_access`.
let workflow = context.spawn_io_workflow(|scope, builder| {
let (branch_to_adder, branch_to_buffer) = scope.input.chain(builder).unzip();
let buffer = builder.create_buffer::<f64>(BufferSettings::keep_first(10));
builder.connect(branch_to_buffer, buffer.input_slot());
let access = builder.create_buffer_access(buffer);
builder.connect(branch_to_adder, access.input);
access
.output
.chain(builder)
.then(add_from_buffer_cb.clone())
.fork_result(
|ok| {
let (output, builder) = ok.unpack();
let second_access = builder.create_buffer_access(buffer);
builder.connect(output, second_access.input);
second_access
.output
.chain(builder)
.then(add_from_buffer_cb.clone())
.connect(scope.terminate);
},
|err| err.connect(access.input),
);
});
let mut promise =
context.command(|commands| commands.request((2.0, 3.0), workflow).take_response());
context.run_with_conditions(&mut promise, Duration::from_secs(2));
assert!(promise
.take()
.available()
.is_some_and(|value| value.is_err_and(|n| n == 5.0)));
assert!(context.no_unhandled_errors());
}
/// Pull one value from the keyed buffer and add it to `lhs`.
/// Returns `Err(lhs)` if the buffer cannot be accessed or is empty.
fn add_from_buffer(
In((lhs, key)): In<(f64, BufferKey<f64>)>,
mut access: BufferAccessMut<f64>,
) -> Result<f64, f64> {
let rhs = access.get_mut(&key).map_err(|_| lhs)?.pull().ok_or(lhs)?;
Ok(lhs + rhs)
}
/// Multiply the oldest values of two buffers without removing them.
/// Panics if either buffer is inaccessible or empty.
fn multiply_buffers_by_copy(
In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
access: BufferAccess<f64>,
) -> f64 {
*access.get(&key_a).unwrap().oldest().unwrap()
* *access.get(&key_b).unwrap().oldest().unwrap()
}
/// Return `None` unless both buffers hold a value; otherwise pull one value
/// from each and return their sum.
fn add_buffers_by_pull(
In((key_a, key_b)): In<(BufferKey<f64>, BufferKey<f64>)>,
mut access: BufferAccessMut<f64>,
) -> Option<f64> {
if access.get(&key_a).unwrap().is_empty() {
return None;
}
if access.get(&key_b).unwrap().is_empty() {
return None;
}
let rhs = access.get_mut(&key_a).unwrap().pull().unwrap();
let lhs = access.get_mut(&key_b).unwrap().pull().unwrap();
Some(rhs + lhs)
}
/// Runs a register through four decrement steps that each get a buffer key.
/// A register that has reached zero is pushed into the buffer, and the
/// listener pulls it to terminate the workflow. Inputs 0..=3 reach zero within
/// the four steps and succeed. Inputs 4..=6 are never pushed, and the request
/// is expected to be cancelled.
#[test]
fn test_buffer_key_lifecycle() {
let mut context = TestingContext::minimal_plugins();
// Keys are injected before each step with `with_access`.
let workflow = context.spawn_io_workflow(|scope, builder| {
let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());
builder
.listen(buffer)
.then(pull_register_from_buffer.into_blocking_callback())
.dispose_on_none()
.connect(scope.terminate);
let decrement_register_cb = decrement_register.into_blocking_callback();
let async_decrement_register_cb = async_decrement_register.as_callback();
scope
.input
.chain(builder)
.with_access(buffer)
.then(decrement_register_cb.clone())
.with_access(buffer)
.then(async_decrement_register_cb.clone())
.dispose_on_none()
.with_access(buffer)
.then(decrement_register_cb.clone())
.with_access(buffer)
.then(async_decrement_register_cb)
.unused();
});
run_register_test(workflow, 0, true, &mut context);
run_register_test(workflow, 1, true, &mut context);
run_register_test(workflow, 2, true, &mut context);
run_register_test(workflow, 3, true, &mut context);
run_register_test(workflow, 4, false, &mut context);
run_register_test(workflow, 5, false, &mut context);
run_register_test(workflow, 6, false, &mut context);
// Same sequence, but one key is injected up front and then travels inside
// the messages between steps. The unzip also produces a `None` branch that is
// disposed.
let workflow = context.spawn_io_workflow(|scope, builder| {
let buffer = builder.create_buffer::<Register>(BufferSettings::keep_all());
builder
.listen(buffer)
.then(pull_register_from_buffer.into_blocking_callback())
.dispose_on_none()
.connect(scope.terminate);
let decrement_register_and_pass_keys_cb =
decrement_register_and_pass_keys.into_blocking_callback();
let async_decrement_register_and_pass_keys_cb =
async_decrement_register_and_pass_keys.as_callback();
let (loose_end, dead_end): (_, Output<Option<Register>>) = scope
.input
.chain(builder)
.with_access(buffer)
.then(decrement_register_and_pass_keys_cb.clone())
.then(async_decrement_register_and_pass_keys_cb.clone())
.dispose_on_none()
.map_block(|v| (v, None))
.unzip();
dead_end.chain(builder).dispose_on_none().unused();
loose_end
.chain(builder)
.then(async_decrement_register_and_pass_keys_cb)
.dispose_on_none()
.then(decrement_register_and_pass_keys_cb)
.unused();
});
run_register_test(workflow, 0, true, &mut context);
run_register_test(workflow, 1, true, &mut context);
run_register_test(workflow, 2, true, &mut context);
run_register_test(workflow, 3, true, &mut context);
run_register_test(workflow, 4, false, &mut context);
run_register_test(workflow, 5, false, &mut context);
run_register_test(workflow, 6, false, &mut context);
}
/// Request `workflow` with a register starting at `initial_value`.
/// If `expect_success`, the response must be a register with `in_slot == 0`
/// and `out_slot == initial_value`; otherwise the promise must be cancelled.
fn run_register_test(
workflow: Service<Register, Register>,
initial_value: u64,
expect_success: bool,
context: &mut TestingContext,
) {
let mut promise = context.command(|commands| {
commands
.request(Register::new(initial_value), workflow)
.take_response()
});
context.run_while_pending(&mut promise);
if expect_success {
assert!(promise
.take()
.available()
.is_some_and(|r| r.finished_with(initial_value)));
} else {
assert!(promise.take().is_cancelled());
}
assert!(context.no_unhandled_errors());
}
/// Counter that moves units from `in_slot` to `out_slot`, one per decrement.
#[derive(Clone, Copy, Debug)]
struct Register {
in_slot: u64,
out_slot: u64,
}
impl Register {
/// Register with `start_from` units waiting in `in_slot`.
fn new(start_from: u64) -> Self {
Self {
in_slot: start_from,
out_slot: 0,
}
}
/// All units were moved and `out_slot` equals the expected count.
fn finished_with(&self, out_slot: u64) -> bool {
self.in_slot == 0 && self.out_slot == out_slot
}
}
/// Listener callback: pull a register from the buffer, if one is available.
fn pull_register_from_buffer(
In(key): In<BufferKey<Register>>,
mut access: BufferAccessMut<Register>,
) -> Option<Register> {
access.get_mut(&key).ok()?.pull()
}
/// Move one unit from `in_slot` to `out_slot`. A register that is already
/// empty is pushed into the buffer and passed on unchanged.
fn decrement_register(
In((mut register, key)): In<(Register, BufferKey<Register>)>,
mut access: BufferAccessMut<Register>,
) -> Register {
if register.in_slot == 0 {
access.get_mut(&key).unwrap().push(register);
return register;
}
register.in_slot -= 1;
register.out_slot += 1;
register
}
/// Like `decrement_register`, but also returns the key so later steps can
/// use it.
fn decrement_register_and_pass_keys(
In((mut register, key)): In<(Register, BufferKey<Register>)>,
mut access: BufferAccessMut<Register>,
) -> (Register, BufferKey<Register>) {
if register.in_slot == 0 {
access.get_mut(&key).unwrap().push(register);
return (register, key);
}
register.in_slot -= 1;
register.out_slot += 1;
(register, key)
}
/// Async wrapper that runs `decrement_register` through the channel.
/// Yields `None` if the query's response is not available.
fn async_decrement_register(
In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
) -> impl Future<Output = Option<Register>> {
async move {
input
.channel
.query(input.request, decrement_register.into_blocking_callback())
.await
.available()
}
}
/// Async wrapper that runs `decrement_register_and_pass_keys` through the
/// channel. Yields `None` if the query's response is not available.
fn async_decrement_register_and_pass_keys(
In(input): In<AsyncCallback<(Register, BufferKey<Register>)>>,
) -> impl Future<Output = Option<(Register, BufferKey<Register>)>> {
async move {
input
.channel
.query(
input.request,
decrement_register_and_pass_keys.into_blocking_callback(),
)
.await
.available()
}
}
/// The listener closes the buffer's gate before the service runs. The service
/// asserts it sees `Gate::Closed`, reopens the gate, and feeds `value + 1`
/// back into the buffer until the value reaches 5, which is returned.
#[test]
fn test_buffer_key_gate_control() {
let mut context = TestingContext::minimal_plugins();
let workflow = context.spawn_io_workflow(|scope, builder| {
let service = builder.commands().spawn_service(gate_access_test_open_loop);
let buffer = builder.create_buffer(BufferSettings::keep_all());
builder.connect(scope.input, buffer.input_slot());
builder
.listen(buffer)
.then_gate_close(buffer)
.then(service)
.fork_unzip((
|chain: Chain<_>| chain.dispose_on_none().connect(buffer.input_slot()),
|chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
));
});
let mut promise = context.command(|commands| commands.request(0, workflow).take_response());
context.run_with_conditions(&mut promise, Duration::from_secs(2));
assert!(promise.take().available().is_some_and(|v| v == 5));
assert!(context.no_unhandled_errors());
}
/// Service for the gate test. Returns `(Some(next), None)` to loop, or
/// `(None, Some(final))` to terminate once the value reaches 5.
fn gate_access_test_open_loop(
In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
mut access: BufferAccessMut<u64>,
) -> (Option<u64>, Option<u64>) {
let mut buffer = access.get_mut(&key).unwrap();
let value = buffer.pull().unwrap();
assert_eq!(buffer.gate(), Gate::Closed);
buffer.open_gate();
if value >= 5 {
(None, Some(value))
} else {
(Some(value + 1), None)
}
}
/// The service uses `allow_closed_loops`, so its own `pull` can re-trigger the
/// listener. The incremented value passes through a delay before returning to
/// the buffer, so the re-triggered run presumably finds the buffer empty and
/// terminates with 0, which is what the test expects.
#[test]
fn test_closed_loop_key_access() {
let mut context = TestingContext::minimal_plugins();
let delay = context.spawn_delay(Duration::from_secs_f32(0.1));
let workflow = context.spawn_io_workflow(|scope, builder| {
let service = builder
.commands()
.spawn_service(gate_access_test_closed_loop);
let buffer = builder.create_buffer(BufferSettings::keep_all());
builder.connect(scope.input, buffer.input_slot());
builder.listen(buffer).then(service).fork_unzip((
|chain: Chain<_>| {
chain
.dispose_on_none()
.then(delay)
.connect(buffer.input_slot())
},
|chain: Chain<_>| chain.dispose_on_none().connect(scope.terminate),
));
});
let mut promise = context.command(|commands| commands.request(3, workflow).take_response());
context.run_with_conditions(&mut promise, Duration::from_secs(2));
assert!(promise.take().available().is_some_and(|v| v == 0));
assert!(context.no_unhandled_errors());
}
/// Service for the closed-loop test. Pulling a value returns
/// `(Some(value + 1), None)`. An empty buffer returns `(None, Some(0))`.
fn gate_access_test_closed_loop(
In(BlockingService { request: key, .. }): BlockingServiceInput<BufferKey<u64>>,
mut access: BufferAccessMut<u64>,
) -> (Option<u64>, Option<u64>) {
let mut buffer = access.get_mut(&key).unwrap().allow_closed_loops();
if let Some(value) = buffer.pull() {
(Some(value + 1), None)
} else {
(None, Some(0))
}
}
}