use bevy_ecs::prelude::{Component, Entity};
use crate::{
FunnelInputStorage, Input, InputBundle, Joining, ManageInput, Operation, OperationCleanup,
OperationReachability, OperationRequest, OperationResult, OperationSetup, OrBroken,
ReachabilityResult, SingleInputStorage, SingleTargetStorage,
};
pub(crate) struct Join<Buffers> {
buffers: Buffers,
target: Entity,
}
impl<Buffers> Join<Buffers> {
pub(crate) fn new(buffers: Buffers, target: Entity) -> Self {
Self { buffers, target }
}
}
#[derive(Component)]
struct BufferStorage<Buffers>(Buffers);
impl<Buffers: Joining + 'static + Send + Sync> Operation for Join<Buffers>
where
Buffers::Item: 'static + Send + Sync,
{
fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult {
world
.get_entity_mut(self.target)
.or_broken()?
.insert(SingleInputStorage::new(source));
self.buffers.add_listener(source, world)?;
world.entity_mut(source).insert((
FunnelInputStorage::from(self.buffers.as_input()),
BufferStorage(self.buffers),
InputBundle::<()>::new(),
SingleTargetStorage::new(self.target),
));
Ok(())
}
fn execute(
OperationRequest {
source,
world,
roster,
}: OperationRequest,
) -> OperationResult {
let mut source_mut = world.get_entity_mut(source).or_broken()?;
let Input { session, .. } = source_mut.take_input::<()>()?;
let target = source_mut.get::<SingleTargetStorage>().or_broken()?.get();
let buffers = source_mut
.get::<BufferStorage<Buffers>>()
.or_broken()?
.0
.clone();
loop {
if buffers.buffered_count(session, world)? < 1 {
return Ok(());
}
let output = buffers.fetch_for_join(session, world)?;
world
.get_entity_mut(target)
.or_broken()?
.give_input(session, output, roster)?;
}
}
fn cleanup(mut clean: OperationCleanup) -> OperationResult {
clean.cleanup_inputs::<()>()?;
clean.notify_cleaned()
}
fn is_reachable(mut r: OperationReachability) -> ReachabilityResult {
let buffers = r
.world
.get::<BufferStorage<Buffers>>(r.source)
.or_broken()?;
let inputs = r
.world
.get_entity(r.source)
.or_broken()?
.get::<FunnelInputStorage>()
.or_broken()?;
for input in &inputs.0 {
if !r.check_upstream(*input)? {
if buffers.0.buffered_count_for(*input, r.session, r.world)? == 0 {
return Ok(false);
}
}
}
Ok(true)
}
}
#[cfg(test)]
mod tests {
use crate::{prelude::*, testing::*};
#[test]
fn test_join_reachability() {
let mut context = TestingContext::minimal_plugins();
let workflow = context.spawn_io_workflow(|scope, builder| {
let node = builder.create_map(|input: AsyncMap<f64, StreamOf<f64>>| async move {
input.streams.send(input.request);
let never = async_std::future::pending::<()>();
let _ = async_std::future::timeout(Duration::from_millis(1), never).await;
input.request
});
builder.connect(scope.start, node.input);
let lhs_buffer = builder.create_buffer(BufferSettings::default());
builder.connect(node.streams, lhs_buffer.input_slot());
let rhs_buffer = builder.create_buffer(BufferSettings::default());
builder.connect(node.output, rhs_buffer.input_slot());
builder
.join((lhs_buffer, rhs_buffer))
.map_block(|(a, b)| a + b)
.connect(scope.terminate);
});
let r = context.resolve_request(2.0, workflow);
assert_eq!(r, 4.0);
}
#[test]
fn test_multi_join() {
let mut context = TestingContext::minimal_plugins();
let workflow = context.spawn_io_workflow(|scope, builder| {
let left_buffer = builder.create_buffer(BufferSettings::keep_all());
let right_buffer = builder.create_buffer(BufferSettings::keep_all());
let set_buffer = |In(input): In<(u64, BufferKey<u64>)>,
mut access: BufferAccessMut<u64>| {
let mut buffer = access.get_mut(&input.1).unwrap();
for i in 1..=input.0 {
buffer.push(i);
}
};
let set_buffer = set_buffer.into_blocking_callback();
builder.chain(scope.start).fork_clone((
|chain: Chain<_>| {
chain
.with_access(left_buffer)
.then(set_buffer.clone())
.unused();
},
|chain: Chain<_>| {
chain
.with_access(right_buffer)
.then(set_buffer.clone())
.unused();
},
));
builder
.join((left_buffer, right_buffer))
.collect_all::<8>()
.connect(scope.terminate);
});
let mut test_for_count = |count: u64| {
let r: Vec<(u64, u64)> = context
.resolve_request(count, workflow)
.into_iter()
.collect();
r
};
let r = test_for_count(1);
assert_eq!(r, vec![(1, 1)]);
let r = test_for_count(2);
assert_eq!(r, vec![(1, 1), (2, 2)]);
let r = test_for_count(3);
assert_eq!(r, vec![(1, 1), (2, 2), (3, 3)]);
let r = test_for_count(4);
assert_eq!(r, vec![(1, 1), (2, 2), (3, 3), (4, 4)]);
let r = test_for_count(5);
assert_eq!(r, vec![(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]);
}
}