use bevy_ecs::prelude::Entity;
use std::{
collections::{BTreeMap, HashMap, HashSet},
fmt::Debug,
hash::Hash,
};
use thiserror::Error as ThisError;
use crate::{Builder, Chain, ConnectToSplit, OperationResult, Output, UnusedTarget};
/// A value that can be broken apart into items which are routed to the
/// separately-connected outputs of a split operation.
pub trait Splittable: Sized {
    /// Selects which output an item is routed to; outputs are connected by key.
    type Key: 'static + Send + Sync + Eq + Hash + Clone + Debug;

    /// Sent alongside every item as the first element of the
    /// `(Identifier, Item)` tuple that each output receives.
    type Identifier: 'static + Send + Sync;

    /// The type of each element produced by the split.
    type Item: 'static + Send + Sync;

    /// Returns `false` for keys that can never be connected. The split builder
    /// reports such keys as [`SplitConnectionError::KeyOutOfBounds`].
    fn validate(key: &Self::Key) -> bool;

    /// Produces the key that follows `key` when iterating over a split builder,
    /// starting from `None`. Returning `None` ends the iteration.
    fn next(key: &Option<Self::Key>) -> Option<Self::Key>;

    /// Consumes `self` and pushes its items into the outputs exposed by `dispatcher`.
    fn split(
        self,
        dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
    ) -> OperationResult;
}
/// Connects the outputs of a split operation, one key at a time, while holding
/// on to the workflow [`Builder`] that new chains are created with.
#[must_use]
pub struct SplitBuilder<'w, 's, 'a, 'b, T: Splittable> {
    /// Record of which keys have already been connected for this split.
    outputs: SplitOutputs<T>,
    /// Builder used to spawn targets and create chains for new outputs.
    builder: &'b mut Builder<'w, 's, 'a>,
}

impl<'w, 's, 'a, 'b, T: Splittable> std::fmt::Debug for SplitBuilder<'w, 's, 'a, 'b, T> {
    /// Prints only the connection record; the borrowed builder is omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = formatter.debug_struct("SplitBuilder");
        debug.field("outputs", &self.outputs);
        debug.finish()
    }
}
impl<'w, 's, 'a, 'b, T: 'static + Splittable> SplitBuilder<'w, 's, 'a, 'b, T> {
    /// Stop building and keep only the record of which keys have been connected.
    ///
    /// The returned [`SplitOutputs`] can be turned back into a builder with
    /// [`SplitOutputs::build`].
    pub fn outputs(self) -> SplitOutputs<T> {
        self.outputs
    }

    /// Take the split builder apart, returning both the record of connected keys
    /// and the workflow [`Builder`] it was borrowing.
    pub fn unpack(self) -> (SplitOutputs<T>, &'b mut Builder<'w, 's, 'a>) {
        (self.outputs, self.builder)
    }

    /// Connect a new output for `key` and hand a [`Chain`] built from it to `f`.
    ///
    /// Whatever `f` returns is passed back to the caller. This only borrows the
    /// split builder, so more keys can be connected afterwards.
    ///
    /// # Errors
    ///
    /// Returns the error from [`Self::output_for`] (key out of bounds or already
    /// used); `f` is not called in that case.
    pub fn chain_for<U>(
        &mut self,
        key: T::Key,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>) -> U,
    ) -> SplitChainResult<U> {
        // `?` forwards KeyOutOfBounds / KeyAlreadyUsed unchanged.
        let output = self.output_for(key)?;
        Ok(f(self.builder.chain(output)))
    }

    /// Same as [`Self::chain_for`], using the key produced by
    /// [`FromSpecific::from_specific`].
    pub fn specific_chain<U>(
        &mut self,
        specific_key: <T::Key as FromSpecific>::SpecificKey,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>) -> U,
    ) -> SplitChainResult<U>
    where
        T::Key: FromSpecific,
    {
        self.chain_for(T::Key::from_specific(specific_key), f)
    }

    /// Same as [`Self::chain_for`], using the key produced by
    /// [`FromSequential::from_sequential`].
    pub fn sequential_chain<U>(
        &mut self,
        sequence_number: usize,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>) -> U,
    ) -> SplitChainResult<U>
    where
        T::Key: FromSequential,
    {
        self.chain_for(T::Key::from_sequential(sequence_number), f)
    }

    /// Same as [`Self::chain_for`], using the key produced by
    /// [`ForRemaining::for_remaining`].
    pub fn remaining_chain<U>(
        &mut self,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>) -> U,
    ) -> SplitChainResult<U>
    where
        T::Key: ForRemaining,
    {
        self.chain_for(T::Key::for_remaining(), f)
    }

    /// Connect an output for the next key yielded by this builder's [`Iterator`]
    /// implementation, then hand that key and a [`Chain`] to `f`.
    ///
    /// This consumes the split builder.
    ///
    /// # Errors
    ///
    /// Returns [`SplitConnectionError::KeyOutOfBounds`] when the iterator has no
    /// further keys.
    pub fn next_chain<U>(
        mut self,
        f: impl FnOnce(T::Key, Chain<(T::Identifier, T::Item)>) -> U,
    ) -> SplitChainResult<U> {
        let Some((key, output)) = self.next() else {
            return Err(SplitConnectionError::KeyOutOfBounds);
        };
        Ok(f(key, self.builder.chain(output)))
    }

    /// Connect a new output for `key`, pass a [`Chain`] for it to `f`, and return
    /// the split builder so that more branches can be chained on.
    ///
    /// # Errors
    ///
    /// On failure the split builder is handed back together with the error, so
    /// building can continue (for example via
    /// [`IgnoreSplitChainResult::ignore_result`]). `f` is not called.
    pub fn branch_for(
        mut self,
        key: T::Key,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>),
    ) -> SplitBranchResult<'w, 's, 'a, 'b, T> {
        let output = match self.output_for(key) {
            Ok(output) => output,
            // The builder must be returned alongside the error, so `?` cannot be used here.
            Err(err) => return Err((self, err)),
        };
        f(output.chain(self.builder));
        Ok(self)
    }

    /// Same as [`Self::branch_for`], using the key produced by
    /// [`FromSpecific::from_specific`].
    pub fn specific_branch(
        self,
        specific_key: <T::Key as FromSpecific>::SpecificKey,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>),
    ) -> SplitBranchResult<'w, 's, 'a, 'b, T>
    where
        T::Key: FromSpecific,
    {
        self.branch_for(T::Key::from_specific(specific_key), f)
    }

    /// Same as [`Self::branch_for`], using the key produced by
    /// [`FromSequential::from_sequential`].
    pub fn sequential_branch(
        self,
        sequence_number: usize,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>),
    ) -> SplitBranchResult<'w, 's, 'a, 'b, T>
    where
        T::Key: FromSequential,
    {
        self.branch_for(T::Key::from_sequential(sequence_number), f)
    }

    /// Same as [`Self::branch_for`], using the key produced by
    /// [`ForRemaining::for_remaining`].
    pub fn remaining_branch(
        self,
        f: impl FnOnce(Chain<(T::Identifier, T::Item)>),
    ) -> SplitBranchResult<'w, 's, 'a, 'b, T>
    where
        T::Key: ForRemaining,
    {
        self.branch_for(T::Key::for_remaining(), f)
    }

    /// Connect an output for the next key yielded by this builder's [`Iterator`]
    /// implementation and pass that key plus a [`Chain`] to `f`.
    ///
    /// # Errors
    ///
    /// Returns the builder together with [`SplitConnectionError::KeyOutOfBounds`]
    /// when the iterator has no further keys.
    pub fn next_branch(
        mut self,
        f: impl FnOnce(T::Key, Chain<(T::Identifier, T::Item)>),
    ) -> SplitBranchResult<'w, 's, 'a, 'b, T> {
        let Some((key, output)) = self.next() else {
            return Err((self, SplitConnectionError::KeyOutOfBounds));
        };
        f(key, output.chain(self.builder));
        Ok(self)
    }

    /// Create an [`Output`] for `key` without building anything on it yet.
    ///
    /// A fresh entity carrying `UnusedTarget` is spawned as the target, and a
    /// `ConnectToSplit` command is queued to connect it to this split's source.
    ///
    /// # Errors
    ///
    /// * [`SplitConnectionError::KeyOutOfBounds`] if [`Splittable::validate`]
    ///   rejects `key`.
    /// * [`SplitConnectionError::KeyAlreadyUsed`] if an output was already created
    ///   for `key` through this builder.
    pub fn output_for(
        &mut self,
        key: T::Key,
    ) -> Result<Output<(T::Identifier, T::Item)>, SplitConnectionError> {
        if !T::validate(&key) {
            return Err(SplitConnectionError::KeyOutOfBounds);
        }
        // `insert` returns false when the key was already present.
        if !self.outputs.used.insert(key.clone()) {
            return Err(SplitConnectionError::KeyAlreadyUsed);
        }
        let target = self.builder.commands.spawn(UnusedTarget).id();
        self.builder.commands.queue(ConnectToSplit::<T> {
            source: self.outputs.source,
            target,
            key,
        });
        Ok(Output::new(self.outputs.scope, target))
    }

    /// Same as [`Self::output_for`], using the key produced by
    /// [`FromSpecific::from_specific`].
    pub fn specific_output(
        &mut self,
        specific_key: <T::Key as FromSpecific>::SpecificKey,
    ) -> Result<Output<(T::Identifier, T::Item)>, SplitConnectionError>
    where
        T::Key: FromSpecific,
    {
        self.output_for(T::Key::from_specific(specific_key))
    }

    /// Same as [`Self::output_for`], using the key produced by
    /// [`FromSequential::from_sequential`].
    pub fn sequential_output(
        &mut self,
        sequence_number: usize,
    ) -> Result<Output<(T::Identifier, T::Item)>, SplitConnectionError>
    where
        T::Key: FromSequential,
    {
        self.output_for(T::Key::from_sequential(sequence_number))
    }

    /// Same as [`Self::output_for`], using the key produced by
    /// [`ForRemaining::for_remaining`].
    pub fn remaining_output(
        &mut self,
    ) -> Result<Output<(T::Identifier, T::Item)>, SplitConnectionError>
    where
        T::Key: ForRemaining,
    {
        self.output_for(T::Key::for_remaining())
    }

    /// Consume the split builder without connecting any more keys.
    ///
    /// `SplitBuilder` is `#[must_use]`, so this gives a chain of branches an
    /// explicit end.
    pub fn unused(self) {}

    /// Start building connections for the split operation at `source` in the
    /// builder's current scope.
    pub(crate) fn new(source: Entity, builder: &'b mut Builder<'w, 's, 'a>) -> Self {
        Self {
            outputs: SplitOutputs::new(builder.scope(), source),
            builder,
        }
    }
}
/// Walks the keys produced by [`Splittable::next`], creating an output for each
/// one that has not been connected yet.
impl<'w, 's, 'a, 'b, T: 'static + Splittable> Iterator for SplitBuilder<'w, 's, 'a, 'b, T> {
    type Item = (T::Key, Output<(T::Identifier, T::Item)>);

    /// Returns the next unused key together with a freshly created output.
    ///
    /// Keys that were already connected are skipped. Iteration ends when
    /// `Splittable::next` returns `None` or yields a key that fails validation.
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(candidate) = T::next(&self.outputs.last_key) {
            // Remember the candidate even if it gets skipped, so the next call moves past it.
            self.outputs.last_key = Some(candidate.clone());
            match self.output_for(candidate.clone()) {
                Ok(output) => return Some((candidate, output)),
                Err(SplitConnectionError::KeyAlreadyUsed) => {}
                Err(SplitConnectionError::KeyOutOfBounds) => return None,
            }
        }
        None
    }
}
/// Record of which keys of a split operation have been connected.
///
/// This can be held apart from any [`Builder`] and resumed later with
/// [`SplitOutputs::build`].
#[must_use]
pub struct SplitOutputs<T: Splittable> {
    /// Scope the split was created in; checked against the builder in `build`.
    scope: Entity,
    /// Source entity passed to `ConnectToSplit` whenever a key is connected.
    source: Entity,
    /// Last key visited by the split builder's iterator (`None` before iteration starts).
    last_key: Option<T::Key>,
    /// Every key that has already been given an output.
    used: HashSet<T::Key>,
}

impl<T: Splittable> std::fmt::Debug for SplitOutputs<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct(&format!("SplitOutputs<{}>", std::any::type_name::<T>()))
            .field("scope", &self.scope)
            .field("source", &self.source)
            .field("last_key", &self.last_key)
            .field("used", &self.used)
            .finish()
    }
}

impl<T: Splittable> SplitOutputs<T> {
    /// Resume connecting outputs for this split using `builder`.
    ///
    /// # Panics
    ///
    /// Panics if `builder` belongs to a different scope than the one this split
    /// was created in.
    pub fn build<'w, 's, 'a, 'b>(
        self,
        builder: &'b mut Builder<'w, 's, 'a>,
    ) -> SplitBuilder<'w, 's, 'a, 'b, T> {
        assert_eq!(
            self.scope,
            builder.scope(),
            "SplitOutputs must be built with a Builder from the same scope it was created in"
        );
        SplitBuilder {
            outputs: self,
            builder,
        }
    }

    /// Create an empty record for the split at `source` inside `scope`.
    pub(crate) fn new(scope: Entity, source: Entity) -> Self {
        Self {
            scope,
            source,
            last_key: None,
            used: Default::default(),
        }
    }
}
/// Result of the `*_branch` methods of [`SplitBuilder`].
///
/// The builder is returned on success and on failure, so that building can continue.
pub type SplitBranchResult<'w, 's, 'a, 'b, T> = Result<
    SplitBuilder<'w, 's, 'a, 'b, T>,
    (SplitBuilder<'w, 's, 'a, 'b, T>, SplitConnectionError),
>;

/// Result of the `*_chain` methods of [`SplitBuilder`]: whatever the user
/// callback returned, or the reason the connection could not be made.
pub type SplitChainResult<U> = Result<U, SplitConnectionError>;
/// Lets a branch result be turned back into its split builder, whether or not
/// the connection succeeded.
pub trait IgnoreSplitChainResult<'w, 's, 'a, 'b, T: Splittable> {
    /// Discard any connection error and recover the split builder.
    fn ignore_result(self) -> SplitBuilder<'w, 's, 'a, 'b, T>;
}

impl<'w, 's, 'a, 'b, T: Splittable> IgnoreSplitChainResult<'w, 's, 'a, 'b, T>
    for SplitBranchResult<'w, 's, 'a, 'b, T>
{
    fn ignore_result(self) -> SplitBuilder<'w, 's, 'a, 'b, T> {
        // Both arms carry a builder; the error half of the failure tuple is dropped.
        self.unwrap_or_else(|(split, _error)| split)
    }
}
#[derive(ThisError, Debug, Clone)]
#[error("An error occurred while trying to connect to a split")]
pub enum SplitConnectionError {
KeyAlreadyUsed,
KeyOutOfBounds,
}
/// Gives a [`Splittable`] implementation access to the per-connection output
/// buffers while it is splitting.
pub struct SplitDispatcher<'a, Key, Identifier, Item> {
    /// Maps each connected key to the index of its buffer in `outputs`.
    pub(crate) connections: &'a HashMap<Key, usize>,
    /// One buffer of `(Identifier, Item)` pairs per connection index.
    pub(crate) outputs: &'a mut Vec<Vec<(Identifier, Item)>>,
}

impl<'a, Key, Identifier, Item> SplitDispatcher<'a, Key, Identifier, Item>
where
    Key: 'static + Send + Sync + Eq + Hash + Clone + Debug,
    Identifier: 'static + Send + Sync,
    Item: 'static + Send + Sync,
{
    /// Returns the buffer for `key`, or `None` if nothing is connected to `key`.
    ///
    /// The buffer list is grown with empty buffers as needed so the key's index exists.
    pub fn outputs_for<'o>(&'o mut self, key: &Key) -> Option<&'o mut Vec<(Identifier, Item)>> {
        let slot = *self.connections.get(key)?;
        let required_len = slot + 1;
        if self.outputs.len() < required_len {
            self.outputs.resize_with(required_len, Vec::new);
        }
        self.outputs.get_mut(slot)
    }
}
/// Key types that can address an output by its position in a sequence.
pub trait FromSequential {
    /// Build the key for sequence position `index`.
    fn from_sequential(index: usize) -> Self;
}

/// Key types that have a catch-all key for items not claimed by any other key.
pub trait ForRemaining {
    /// Build the catch-all key.
    fn for_remaining() -> Self;
}

/// Key types that can address an output by a caller-chosen specific value.
pub trait FromSpecific {
    /// The value used to pick a specific output.
    type SpecificKey;

    /// Wrap `specific_key` into a full key.
    fn from_specific(specific_key: Self::SpecificKey) -> Self;
}
/// Key for splitting list-like containers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum ListSplitKey {
    /// The element at this index.
    Sequential(usize),
    /// Every element whose index has no `Sequential` connection of its own.
    Remaining,
}

impl FromSequential for ListSplitKey {
    fn from_sequential(position: usize) -> Self {
        Self::Sequential(position)
    }
}

impl ForRemaining for ListSplitKey {
    fn for_remaining() -> Self {
        Self::Remaining
    }
}
/// Adapter that splits any iterable collection as a list, identifying each
/// item by its position.
pub struct SplitAsList<T: 'static + Send + Sync + IntoIterator> {
    /// The collection to be split.
    pub contents: T,
}

impl<T: 'static + Send + Sync + IntoIterator> SplitAsList<T> {
    /// Wrap `contents` so it can be split as a list.
    pub fn new(contents: T) -> Self {
        SplitAsList { contents }
    }
}
impl<T> Splittable for SplitAsList<T>
where
    T: 'static + Send + Sync + IntoIterator,
    T::Item: 'static + Send + Sync,
{
    type Key = ListSplitKey;
    type Identifier = usize;
    type Item = T::Item;

    /// The length is unknown until the split runs, so every key is accepted.
    fn validate(_key: &Self::Key) -> bool {
        true
    }

    /// Yields `Sequential(0)`, `Sequential(1)`, ... with no upper bound;
    /// `Remaining` is never produced and has no successor.
    fn next(key: &Option<Self::Key>) -> Option<Self::Key> {
        match key {
            None => Some(ListSplitKey::Sequential(0)),
            Some(ListSplitKey::Sequential(previous)) => Some(ListSplitKey::Sequential(previous + 1)),
            Some(ListSplitKey::Remaining) => None,
        }
    }

    /// Sends each element, tagged with its index, to that index's `Sequential`
    /// output if connected, otherwise to `Remaining` if connected. Otherwise the
    /// element is dropped.
    fn split(
        self,
        mut dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
    ) -> OperationResult {
        for (index, value) in self.contents.into_iter().enumerate() {
            if let Some(sequential) = dispatcher.outputs_for(&ListSplitKey::Sequential(index)) {
                sequential.push((index, value));
            } else if let Some(remaining) = dispatcher.outputs_for(&ListSplitKey::Remaining) {
                remaining.push((index, value));
            }
        }
        Ok(())
    }
}
impl<T: 'static + Send + Sync> Splittable for Vec<T> {
type Key = ListSplitKey;
type Identifier = usize;
type Item = T;
fn validate(_: &Self::Key) -> bool {
true
}
fn next(key: &Option<Self::Key>) -> Option<Self::Key> {
SplitAsList::<Self>::next(key)
}
fn split(
self,
dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
) -> OperationResult {
SplitAsList::new(self).split(dispatcher)
}
}
impl<T: 'static + Send + Sync, const N: usize> Splittable for smallvec::SmallVec<[T; N]> {
type Key = ListSplitKey;
type Identifier = usize;
type Item = T;
fn validate(_: &Self::Key) -> bool {
true
}
fn next(key: &Option<Self::Key>) -> Option<Self::Key> {
SplitAsList::<Self>::next(key)
}
fn split(
self,
dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
) -> OperationResult {
SplitAsList::new(self).split(dispatcher)
}
}
/// Fixed-size arrays split like [`SplitAsList`], but sequential keys are
/// bounded by the array length `N`.
impl<T: 'static + Send + Sync, const N: usize> Splittable for [T; N] {
    type Key = ListSplitKey;
    type Identifier = usize;
    type Item = T;

    /// `Sequential(i)` is valid only for `i < N`; `Remaining` is always valid.
    fn validate(key: &Self::Key) -> bool {
        match key {
            ListSplitKey::Sequential(s) => *s < N,
            ListSplitKey::Remaining => true,
        }
    }

    /// Yields `Sequential(0)` through `Sequential(N - 1)`, then `None`.
    fn next(key: &Option<Self::Key>) -> Option<Self::Key> {
        // Drop the list successor as soon as it would index past the end of the array.
        SplitAsList::<Self>::next(key).filter(Self::validate)
    }

    fn split(
        self,
        dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
    ) -> OperationResult {
        SplitAsList::new(self).split(dispatcher)
    }
}
/// Key for splitting map-like containers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MapSplitKey<K> {
    /// The entry whose map key equals this value.
    Specific(K),
    /// The n-th entry (in iteration order) that has no `Specific` connection.
    Sequential(usize),
    /// Every entry not claimed by a `Specific` or `Sequential` connection.
    Remaining,
}

impl<K> MapSplitKey<K> {
    /// Returns the inner value of a `Specific` key, or `None` for other variants.
    pub fn specific(self) -> Option<K> {
        // Variants are listed explicitly so adding a new one forces this match to be revisited.
        match self {
            MapSplitKey::Specific(key) => Some(key),
            MapSplitKey::Sequential(_) | MapSplitKey::Remaining => None,
        }
    }

    /// Successor function used when iterating over a split builder.
    ///
    /// Starting from `None` it yields `Sequential(0)`, `Sequential(1)`, ... with
    /// no upper bound. `Specific` and `Remaining` keys have no successor.
    pub fn next(this: &Option<Self>) -> Option<Self> {
        match this {
            None => Some(MapSplitKey::Sequential(0)),
            Some(MapSplitKey::Sequential(index)) => Some(MapSplitKey::Sequential(index + 1)),
            Some(MapSplitKey::Specific(_)) | Some(MapSplitKey::Remaining) => None,
        }
    }
}
/// A bare map key converts to the `Specific` variant.
impl<K> From<K> for MapSplitKey<K> {
    fn from(map_key: K) -> Self {
        Self::Specific(map_key)
    }
}

impl<K> FromSpecific for MapSplitKey<K> {
    type SpecificKey = K;

    fn from_specific(map_key: Self::SpecificKey) -> Self {
        MapSplitKey::Specific(map_key)
    }
}

impl<K> FromSequential for MapSplitKey<K> {
    fn from_sequential(position: usize) -> Self {
        MapSplitKey::Sequential(position)
    }
}

impl<K> ForRemaining for MapSplitKey<K> {
    fn for_remaining() -> Self {
        MapSplitKey::Remaining
    }
}
/// Adapter that splits any iterable of `(key, value)` pairs as a map.
pub struct SplitAsMap<K, V, M>
where
    K: 'static + Send + Sync + Eq + Hash + Clone + Debug,
    V: 'static + Send + Sync,
    M: 'static + Send + Sync + IntoIterator<Item = (K, V)>,
{
    /// The collection of pairs to be split.
    pub contents: M,
    /// Ties `K` and `V` to the struct without storing values of those types.
    _ignore: std::marker::PhantomData<(K, V)>,
}

impl<K, V, M> SplitAsMap<K, V, M>
where
    K: 'static + Send + Sync + Eq + Hash + Clone + Debug,
    V: 'static + Send + Sync,
    M: 'static + Send + Sync + IntoIterator<Item = (K, V)>,
{
    /// Wrap `contents` so it can be split as a map.
    pub fn new(contents: M) -> Self {
        SplitAsMap {
            contents,
            _ignore: std::marker::PhantomData,
        }
    }
}
impl<K, V, M> Splittable for SplitAsMap<K, V, M>
where
    K: 'static + Send + Sync + Eq + Hash + Clone + Debug,
    V: 'static + Send + Sync,
    M: 'static + Send + Sync + IntoIterator<Item = (K, V)>,
{
    type Key = MapSplitKey<K>;
    type Identifier = K;
    type Item = V;

    /// Every key is acceptable for a map.
    fn validate(_key: &Self::Key) -> bool {
        true
    }

    fn next(key: &Option<Self::Key>) -> Option<Self::Key> {
        MapSplitKey::next(key)
    }

    /// Routes each `(key, value)` entry, in the collection's iteration order, to:
    /// 1. the output connected to `Specific(key)`, if there is one; otherwise
    /// 2. the output connected to `Sequential(n)`, where `n` counts the entries
    ///    seen so far that had no specific connection; otherwise
    /// 3. the `Remaining` output, if there is one.
    ///
    /// Entries with no matching output are dropped.
    fn split(
        self,
        mut dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
    ) -> OperationResult {
        let mut unclaimed = 0;
        for (specific_key, value) in self.contents {
            let key = MapSplitKey::Specific(specific_key);
            if let Some(outputs) = dispatcher.outputs_for(&key) {
                outputs.push((key.specific().unwrap(), value));
                continue;
            }

            // Only entries without a specific connection advance the sequence counter.
            let sequential = MapSplitKey::Sequential(unclaimed);
            unclaimed += 1;
            if let Some(outputs) = dispatcher.outputs_for(&sequential) {
                outputs.push((key.specific().unwrap(), value));
            } else if let Some(outputs) = dispatcher.outputs_for(&MapSplitKey::Remaining) {
                outputs.push((key.specific().unwrap(), value));
            }
        }
        Ok(())
    }
}
impl<K, V> Splittable for HashMap<K, V>
where
K: 'static + Send + Sync + Eq + Hash + Clone + Debug,
V: 'static + Send + Sync,
{
type Key = MapSplitKey<K>;
type Identifier = K;
type Item = V;
fn validate(_: &Self::Key) -> bool {
true
}
fn next(key: &Option<Self::Key>) -> Option<Self::Key> {
SplitAsMap::<K, V, Self>::next(key)
}
fn split(
self,
dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
) -> OperationResult {
SplitAsMap::new(self).split(dispatcher)
}
}
impl<K, V> Splittable for BTreeMap<K, V>
where
K: 'static + Send + Sync + Eq + Hash + Clone + Debug,
V: 'static + Send + Sync,
{
type Key = MapSplitKey<K>;
type Identifier = K;
type Item = V;
fn validate(_: &Self::Key) -> bool {
true
}
fn next(key: &Option<Self::Key>) -> Option<Self::Key> {
SplitAsMap::<K, V, Self>::next(key)
}
fn split(
self,
dispatcher: SplitDispatcher<'_, Self::Key, Self::Identifier, Self::Item>,
) -> OperationResult {
SplitAsMap::new(self).split(dispatcher)
}
}
#[cfg(test)]
mod tests {
    //! Workflow-level tests for splitting arrays and maps.
    use crate::{testing::*, *};
    use std::collections::{BTreeMap, HashMap};

    /// Splits fixed-size arrays by index and checks the builder's error reporting.
    #[test]
    fn test_split_array() {
        let mut context = TestingContext::minimal_plugins();
        // Connect indices 0, 2 and 4 of a 5-element array. Each branch adds its
        // own index to the value, and the three results are joined into a vec.
        let workflow = context.spawn_io_workflow(|scope, builder| {
            builder
                .chain(scope.start)
                .split(|split| {
                    let mut outputs = Vec::new();
                    split
                        .sequential_branch(0, |chain| {
                            outputs.push(chain.value().map_block(|v| v + 0.0).output());
                        })
                        .unwrap()
                        .sequential_branch(2, |chain| {
                            outputs
                                .push(chain.value().map_async(|v| async move { v + 2.0 }).output());
                        })
                        .unwrap()
                        .sequential_branch(4, |chain| {
                            outputs.push(chain.value().map_block(|v| v + 4.0).output());
                        })
                        .unwrap()
                        .unused();
                    outputs
                })
                .join_vec::<5>(builder)
                .connect(scope.terminate);
        });
        // Index 0: 5.0 + 0.0, index 2: 3.0 + 2.0, index 4: 1.0 + 4.0 -> all 5.0.
        let r = context.resolve_request([5.0, 4.0, 3.0, 2.0, 1.0], workflow);
        assert_eq!(r, [5.0, 5.0, 5.0].into());
        // Second workflow: exercise KeyOutOfBounds / KeyAlreadyUsed and recovery
        // via `ignore_result`. Only index 2 is connected to the terminal.
        let workflow = context.spawn_io_workflow(|scope: Scope<[f64; 3], f64>, builder| {
            builder.chain(scope.start).split(|split| {
                let split = split
                    .sequential_branch(0, |chain| {
                        chain
                            .fork_clone((
                                |chain: Chain<_>| chain.unused(),
                                |chain: Chain<_>| chain.unused(),
                                |chain: Chain<_>| chain.unused(),
                            ));
                    })
                    .unwrap();
                // Index 3 does not exist in a [f64; 3].
                let err = split.sequential_branch(3, |chain| {
                    chain.value().connect(scope.terminate);
                });
                assert!(matches!(
                    &err,
                    Err((_, SplitConnectionError::KeyOutOfBounds))
                ));
                let split = err
                    .ignore_result()
                    .sequential_branch(1, |chain| {
                        chain.unused();
                    })
                    .unwrap();
                // Index 0 was already connected above.
                let err = split.sequential_branch(0, |chain| {
                    chain.value().connect(scope.terminate);
                });
                assert!(matches!(
                    &err,
                    Err((_, SplitConnectionError::KeyAlreadyUsed))
                ));
                err.ignore_result()
                    .sequential_branch(2, |chain| {
                        chain.value().connect(scope.terminate);
                    })
                    .unwrap()
                    .unused();
            });
        });
        // The element at index 2 (3.0) is the only one routed to the terminal.
        // NOTE(review): the trailing `1` is presumably a flush/iteration limit for
        // `try_resolve_request` — confirm against the testing helper.
        let r = context
            .try_resolve_request([1.0, 2.0, 3.0], workflow, 1)
            .unwrap();
        assert_eq!(r, 3.0);
    }

    /// Splits a BTreeMap using specific, sequential and remaining keys.
    #[test]
    fn test_split_map() {
        let mut context = TestingContext::minimal_plugins();
        let km_to_miles = 0.621371;
        let per_second_to_per_hour = 3600.0;
        let convert_speed = move |v: f64| v * km_to_miles * per_second_to_per_hour;
        let convert_distance = move |d: f64| d * km_to_miles;
        let workflow =
            context.spawn_io_workflow(|scope: Scope<BTreeMap<String, f64>, _>, builder| {
                let collector = builder.create_collect_all::<_, 16>();
                builder.chain(scope.start).split(|split| {
                    split
                        .specific_branch("speed".to_owned(), |chain| {
                            chain
                                .map_block(move |(k, v)| (k, convert_speed(v)))
                                .connect(collector.input);
                        })
                        .ignore_result()
                        .specific_branch("velocity".to_owned(), |chain| {
                            chain
                                .map_async(move |(k, v)| async move { (k, convert_speed(v)) })
                                .connect(collector.input);
                        })
                        .unwrap()
                        .specific_branch("distance".to_owned(), |chain| {
                            chain
                                .map_block(move |(k, v)| (k, convert_distance(v)))
                                .connect(collector.input);
                        })
                        .unwrap()
                        // Sequential indices count only entries without a specific
                        // connection, in the BTreeMap's sorted order.
                        .sequential_branch(0, |chain| {
                            chain
                                .map_block(move |(k, v)| (k, 0.0 * v))
                                .connect(collector.input);
                        })
                        .unwrap()
                        .sequential_branch(1, |chain| {
                            chain
                                .map_async(move |(k, v)| async move { (k, 1.0 * v) })
                                .connect(collector.input);
                        })
                        .unwrap()
                        .sequential_branch(2, |chain| {
                            chain
                                .map_block(move |(k, v)| (k, 2.0 * v))
                                .connect(collector.input);
                        })
                        .unwrap()
                        // Everything else passes through unchanged.
                        .remaining_branch(|chain| {
                            chain.connect(collector.input);
                        })
                        .unwrap()
                        .unused();
                });
                builder
                    .chain(collector.output)
                    .map_block(|v| HashMap::<String, f64>::from_iter(v))
                    .connect(scope.terminate);
            });
        let input_map: BTreeMap<String, f64> = [
            ("a", 3.14159),
            ("b", 2.71828),
            ("c", 4.0),
            ("speed", 16.1),
            ("velocity", -32.4),
            ("distance", 4325.78),
            ("foo", 42.0),
            ("fib", 78.3),
            ("dib", -22.1),
        ]
        .into_iter()
        .map(|(k, v)| (k.to_owned(), v))
        .collect();
        let result = context.resolve_request(input_map.clone(), workflow);
        // Sorted order: a, b, c, dib, distance, fib, foo, speed, velocity.
        // Unclaimed entries a/b/c take sequential slots 0/1/2; dib/fib/foo go to
        // the remaining branch.
        assert_eq!(result.len(), input_map.len());
        assert_eq!(result["a"], input_map["a"] * 0.0);
        assert_eq!(result["b"], input_map["b"] * 1.0);
        assert_eq!(result["c"], input_map["c"] * 2.0);
        assert_eq!(result["speed"], convert_speed(input_map["speed"]));
        assert_eq!(result["velocity"], convert_speed(input_map["velocity"]));
        assert_eq!(result["distance"], convert_distance(input_map["distance"]));
        assert_eq!(result["foo"], input_map["foo"]);
        assert_eq!(result["fib"], input_map["fib"]);
        assert_eq!(result["dib"], input_map["dib"]);
    }

    /// `next_branch` on a 4-element array succeeds four times and then fails.
    #[test]
    fn test_array_split_limit() {
        let mut context = TestingContext::minimal_plugins();
        let workflow = context.spawn_io_workflow(|scope, builder| {
            builder.chain(scope.start).split(|split| {
                let err = split
                    .next_branch(|_, chain| {
                        chain.value().connect(scope.terminate);
                    })
                    .unwrap()
                    .next_branch(|_, chain| {
                        chain.value().connect(scope.terminate);
                    })
                    .unwrap()
                    .next_branch(|_, chain| {
                        chain.value().connect(scope.terminate);
                    })
                    .unwrap()
                    .next_branch(|_, chain| {
                        chain.value().connect(scope.terminate);
                    })
                    .unwrap()
                    // A fifth key would be index 4, past the end of [_; 4].
                    .next_branch(|_, chain| {
                        chain.value().connect(scope.terminate);
                    });
                assert!(matches!(err, Err(_)));
            })
        });
        // All four elements race to the terminal; the test expects the element
        // at index 0 to arrive first.
        let result = context
            .try_resolve_request([1, 2, 3, 4], workflow, 1)
            .unwrap();
        assert_eq!(result, 1);
    }
}