use bevy_ecs::prelude::{Component, Entity, World};
use smallvec::SmallVec;
use std::collections::{hash_map::Entry, HashMap, HashSet};
use crate::{
emit_disposal, immediately_downstream_of, Cancellation, CleanupContents, Disposal,
FinalizeCleanup, FinalizeCleanupRequest, Input, InputBundle, ManageCancellation, ManageInput,
Operation, OperationCleanup, OperationError, OperationReachability, OperationRequest,
OperationResult, OperationSetup, OrBroken, ReachabilityResult, ScopeEntryStorage, ScopeStorage,
SingleInputStorage, SingleTargetStorage, TrimBranch, TrimPoint, TrimPolicy,
};
/// Workflow operation that, when it receives an input, triggers cleanup of a
/// set of nodes derived from `branches` and afterwards hands the input on to
/// `target`.
pub(crate) struct Trim<T> {
    /// Descriptions of which parts of the workflow should be trimmed.
    branches: SmallVec<[TrimBranch; 16]>,
    /// Operation that receives the input once trimming is finalized.
    target: Entity,
    /// Associates the operation with its input type `T` without storing a `T`.
    _ignore: std::marker::PhantomData<fn(T)>,
}

impl<T> Trim<T> {
    /// Build a trim operation over `branches` whose input is forwarded to `target`.
    pub(crate) fn new(branches: SmallVec<[TrimBranch; 16]>, target: Entity) -> Self {
        Self {
            target,
            branches,
            _ignore: std::marker::PhantomData,
        }
    }
}
/// Per-entity state of a trim operation: the configured branches and the node
/// set derived from them, which is computed lazily on first execution.
#[derive(Component)]
struct TrimStorage {
    /// Trim branches as supplied when the operation was created.
    branches: SmallVec<[TrimBranch; 16]>,
    /// `None` until first calculated; afterwards either the nodes to trim or
    /// the cancellation that was produced while calculating them.
    nodes: Option<Result<SmallVec<[Entity; 16]>, Cancellation>>,
}

impl TrimStorage {
    /// Store `branches` with no node set calculated yet.
    fn new(branches: SmallVec<[TrimBranch; 16]>) -> Self {
        TrimStorage {
            nodes: None,
            branches,
        }
    }
}

/// Inputs parked while their trim is in progress, keyed by the entity id that
/// identifies the cleanup request they are waiting on.
#[derive(Component)]
struct HoldingStorage<T> {
    map: HashMap<Entity, Input<T>>,
}

// Implemented by hand: `#[derive(Default)]` would demand `T: Default`.
impl<T> Default for HoldingStorage<T> {
    fn default() -> Self {
        HoldingStorage {
            map: HashMap::new(),
        }
    }
}
/// Operation behavior for [`Trim`].
///
/// Each incoming input is parked in [`HoldingStorage`] while the nodes
/// described by the trim branches are cleaned up; `Trim::finalize_trim` later
/// removes the parked input and gives it to the target operation.
impl<T: 'static + Send + Sync> Operation for Trim<T> {
    /// Register this operation as the single input of `target` and attach the
    /// components the trim relies on to its own (`source`) entity.
    fn setup(self, OperationSetup { source, world }: OperationSetup) -> OperationResult {
        world
            .get_entity_mut(self.target)
            .or_broken()?
            .insert(SingleInputStorage::new(source));
        world.entity_mut(source).insert((
            TrimStorage::new(self.branches),
            SingleTargetStorage::new(self.target),
            InputBundle::<T>::new(),
            CleanupContents::new(),
            // `finalize_trim` releases the held input and forwards it to the target.
            FinalizeCleanup::new(Self::finalize_trim),
            HoldingStorage::<T>::default(),
        ));
        Ok(())
    }

    /// Take one input, work out which nodes to trim, hold the input, and start
    /// cleanup of every node to be trimmed.
    fn execute(
        OperationRequest {
            source,
            world,
            roster,
        }: OperationRequest,
    ) -> OperationResult {
        let Input { session, data } = world
            .get_entity_mut(source)
            .or_broken()?
            .take_input::<T>()?;

        // Only a shared borrow is taken here; each arm below must stop using
        // `trim` before it touches `world` mutably.
        let source_ref = world.get_entity(source).or_broken()?;
        let trim = source_ref.get::<TrimStorage>().or_broken()?;
        let nodes = match &trim.nodes {
            // Node set was already calculated by an earlier execution.
            Some(Ok(nodes)) => nodes.clone(),
            // An earlier calculation produced a cancellation: cancel this
            // session as well instead of trimming.
            Some(Err(cancellation)) => {
                let cancellation = cancellation.clone();
                world.get_entity_mut(source).or_broken()?.emit_cancel(
                    session,
                    cancellation,
                    roster,
                );
                return Ok(());
            }
            // First execution: calculate the node set relative to the entry
            // of the enclosing scope and cache the outcome (success or
            // cancellation) in `TrimStorage::nodes`.
            None => {
                let scope = world.get::<ScopeStorage>(source).or_broken()?.get();
                let scope_entry = world.get::<ScopeEntryStorage>(scope).or_broken()?.0;
                match calculate_nodes(scope_entry, &trim.branches, world) {
                    Ok(Ok(nodes)) => {
                        world.get_mut::<TrimStorage>(source).or_broken()?.nodes =
                            Some(Ok(nodes.clone()));
                        nodes
                    }
                    Ok(Err(cancellation)) => {
                        let mut source_mut = world.get_entity_mut(source).or_broken()?;
                        source_mut.get_mut::<TrimStorage>().or_broken()?.nodes =
                            Some(Err(cancellation.clone()));
                        source_mut.emit_cancel(session, cancellation, roster);
                        return Ok(());
                    }
                    Err(broken) => {
                        return Err(broken);
                    }
                }
            }
        };

        // A fresh, component-less entity whose id tags this particular trim
        // request; the held input is stored under it and `finalize_trim`
        // looks it up by the same id.
        // NOTE(review): this entity is never despawned in this file — confirm
        // the cleanup machinery releases it, otherwise it accumulates per input.
        let cleanup_id = world.spawn(()).id();
        world
            .get_mut::<HoldingStorage<T>>(source)
            .or_broken()?
            .map
            .insert(cleanup_id, Input { data, session });
        world
            .get_mut::<CleanupContents>(source)
            .or_broken()?
            .add_cleanup(cleanup_id, nodes.clone());
        for node in nodes {
            OperationCleanup::new(source, node, session, cleanup_id, world, roster).clean();
        }
        Ok(())
    }

    /// Clean up this operation for a session: run the generic input and
    /// disposal cleanup helpers, then drop every held input belonging to that
    /// session before reporting completion.
    fn cleanup(mut clean: OperationCleanup) -> OperationResult {
        clean.cleanup_inputs::<T>()?;
        clean.cleanup_disposals()?;
        let session = clean.cleanup.session;
        clean
            .world
            .get_mut::<HoldingStorage<T>>(clean.source)
            .or_broken()?
            .map
            .retain(|_, input| input.session != session);
        clean.notify_cleaned()
    }

    /// Reachable if an input is waiting in the input buffer; otherwise defer
    /// to the reachability of the single upstream operation.
    // NOTE(review): inputs parked in `HoldingStorage` while a trim is in
    // progress are not checked here — confirm that is intended.
    fn is_reachable(mut reachability: OperationReachability) -> ReachabilityResult {
        if reachability.has_input::<T>()? {
            return Ok(true);
        }
        SingleInputStorage::is_reachable(&mut reachability)
    }
}
impl<T: 'static + Send + Sync> Trim<T> {
fn finalize_trim(
FinalizeCleanupRequest {
cleanup,
world,
roster,
}: FinalizeCleanupRequest,
) -> OperationResult {
let mut source_mut = world.get_entity_mut(cleanup.cleaner).or_broken()?;
let Input { session, data } = source_mut
.get_mut::<HoldingStorage<T>>()
.or_broken()?
.map
.remove(&cleanup.cleanup_id)
.or_not_ready()?;
let nodes = source_mut
.get::<TrimStorage>()
.or_broken()?
.nodes
.clone()
.and_then(|n| n.ok())
.unwrap_or(SmallVec::new());
let target = source_mut.get::<SingleTargetStorage>().or_broken()?.get();
let disposal = Disposal::trimming(cleanup.cleaner, nodes);
emit_disposal(cleanup.cleaner, cleanup.session, disposal, world, roster);
world
.get_entity_mut(target)
.or_broken()?
.give_input(session, data, roster)
}
}
/// Resolve every trim branch into the nodes it covers and merge the results.
///
/// Returns the sorted, de-duplicated union of all branch node sets. If any
/// branch yields a `Cancellation`, that cancellation is returned and later
/// branches are not evaluated. An `OperationError` from a branch is
/// propagated immediately.
fn calculate_nodes(
    scope_entry: Entity,
    branches: &SmallVec<[TrimBranch; 16]>,
    world: &World,
) -> Result<Result<SmallVec<[Entity; 16]>, Cancellation>, OperationError> {
    let mut merged: SmallVec<[Entity; 16]> = SmallVec::new();
    for branch in branches.iter() {
        let outcome = match branch.policy() {
            TrimPolicy::Downstream => {
                calculate_downstream(scope_entry, branch.from_point(), world)?
            }
            TrimPolicy::Span(span) => {
                calculate_all_spans(scope_entry, branch.from_point(), span, world)?
            }
        };
        match outcome {
            Ok(branch_nodes) => merged.extend(branch_nodes),
            Err(cancellation) => return Ok(Err(cancellation)),
        }
    }
    merged.sort();
    merged.dedup();
    Ok(Ok(merged))
}
/// Collect every node reachable downstream of `initial_point`.
///
/// Nodes that lie on a path from `scope_entry` to the starting point (other
/// than the starting point itself) are excluded from the search. A loop that
/// feeds back into them therefore does not pull them in. The collected nodes
/// are then filtered through `initial_point.accept`.
fn calculate_downstream(
    scope_entry: Entity,
    initial_point: TrimPoint,
    world: &World,
) -> Result<Result<SmallVec<[Entity; 16]>, Cancellation>, OperationError> {
    let start = initial_point.id();

    let mut upstream = calculate_span(scope_entry, start, &HashSet::new(), world);
    upstream.remove(&start);

    // Depth-first walk; a node is expanded only on its first visit.
    let mut reached: HashSet<Entity> = HashSet::new();
    let mut stack: Vec<Entity> = vec![start];
    while let Some(node) = stack.pop() {
        if upstream.contains(&node) || !reached.insert(node) {
            continue;
        }
        stack.extend(immediately_downstream_of(node, world));
    }

    // `start` is never in `upstream`, so `reached` always contains it and this
    // guard does not trigger; kept as a defensive check.
    if reached.is_empty() {
        return Ok(Err(Cancellation::invalid_span(start, None)));
    }

    let accepted: SmallVec<[Entity; 16]> = reached
        .into_iter()
        .filter(|node| initial_point.accept(*node))
        .collect();
    Ok(Ok(accepted))
}
/// Compute the nodes spanned from `initial_point` to each point in `span`.
///
/// Returns the sorted, de-duplicated union of all individual spans. The first
/// `Cancellation` (e.g. an invalid span) is returned as soon as it occurs, and
/// `OperationError`s are propagated.
fn calculate_all_spans(
    scope_entry: Entity,
    initial_point: TrimPoint,
    span: &SmallVec<[TrimPoint; 16]>,
    world: &World,
) -> Result<Result<SmallVec<[Entity; 16]>, Cancellation>, OperationError> {
    let mut covered: SmallVec<[Entity; 16]> = SmallVec::new();
    for &end_point in span.iter() {
        match calculate_span_nodes(scope_entry, initial_point, end_point, world)? {
            Ok(segment) => covered.extend(segment),
            Err(cancellation) => return Ok(Err(cancellation)),
        }
    }
    covered.sort();
    covered.dedup();
    Ok(Ok(covered))
}
/// Compute the nodes lying between `initial_point` and `to_point`.
///
/// Nodes on a path from `scope_entry` to `initial_point` (excluding the
/// initial point itself) are treated as off-limits. An empty span yields an
/// `invalid_span` cancellation. Otherwise the span is filtered through the
/// `accept` checks of both endpoints.
fn calculate_span_nodes(
    scope_entry: Entity,
    initial_point: TrimPoint,
    to_point: TrimPoint,
    world: &World,
) -> Result<Result<SmallVec<[Entity; 16]>, Cancellation>, OperationError> {
    let (from, to) = (initial_point.id(), to_point.id());

    let upstream = {
        let mut upstream = calculate_span(scope_entry, from, &HashSet::new(), world);
        upstream.remove(&from);
        upstream
    };

    let span = calculate_span(from, to, &upstream, world);
    if span.is_empty() {
        return Ok(Err(Cancellation::invalid_span(from, Some(to))));
    }

    let accepted: SmallVec<[Entity; 16]> = span
        .into_iter()
        .filter(|node| initial_point.accept(*node) && to_point.accept(*node))
        .collect();
    Ok(Ok(accepted))
}
/// Collect every node that lies on some path from `initial_point` to `to_point`.
///
/// Nodes in `filter` are treated as absent: the forward search never steps
/// onto them. The search also does not expand past `to_point`, so nodes only
/// reachable through `to_point` are excluded. When `initial_point == to_point`
/// the result is just that node.
///
/// Returns an empty set when `to_point` is in `filter` or cannot be reached
/// from `initial_point`. Callers rely on emptiness to detect an invalid span.
fn calculate_span(
    initial_point: Entity,
    to_point: Entity,
    filter: &HashSet<Entity>,
    world: &World,
) -> HashSet<Entity> {
    if filter.contains(&to_point) {
        return HashSet::new();
    }

    // Forward pass: explore everything reachable from `initial_point`,
    // recording for each discovered node the set of nodes that lead directly
    // into it (its predecessors). `to_point` is recorded but not expanded.
    let mut parents_of: HashMap<Entity, HashSet<Entity>> = HashMap::new();
    parents_of.insert(initial_point, HashSet::new());
    let mut queue: Vec<Entity> = vec![initial_point];
    while let Some(top) = queue.pop() {
        if top == to_point {
            continue;
        }

        for next in immediately_downstream_of(top, world) {
            if filter.contains(&next) {
                continue;
            }

            let entry = parents_of.entry(next);
            let first_visit = matches!(&entry, Entry::Vacant(_));
            entry.or_default().insert(top);
            if first_visit {
                queue.push(next);
            }
        }
    }

    // If the forward pass never discovered `to_point`, no path connects the
    // two points. Previously `to_point` was still inserted by the backward
    // pass below, producing a bogus one-node span that hid the invalid span.
    if !parents_of.contains_key(&to_point) {
        return HashSet::new();
    }

    // Backward pass: walk predecessors from `to_point`; every node visited is
    // reachable from `initial_point` and leads to `to_point`.
    let mut nodes = HashSet::new();
    queue.push(to_point);
    while let Some(top) = queue.pop() {
        if nodes.insert(top) {
            if let Some(parents) = parents_of.get(&top) {
                queue.extend(parents.iter().copied());
            }
        }
    }

    nodes
}