use alloc::{format, vec::Vec};
use bevy_utils::prelude::DebugName;
use core::marker::PhantomData;
use crate::{
change_detection::{CheckChangeTicks, Tick},
error::ErrorContext,
prelude::World,
query::FilteredAccessSet,
schedule::InternedSystemSet,
system::{input::SystemInput, SystemIn, SystemParamValidationError},
world::unsafe_world_cell::UnsafeWorldCell,
};
use super::{IntoSystem, ReadOnlySystem, RunSystemError, System};
/// Customizes how a [`CombinatorSystem`] runs its two inner systems.
///
/// An implementor decides whether, and in what order, the runners for systems
/// `A` and `B` are invoked, and how their results are merged into one output.
#[diagnostic::on_unimplemented(
message = "`{Self}` can not combine systems `{A}` and `{B}`",
label = "invalid system combination",
note = "the inputs and outputs of `{A}` and `{B}` are not compatible with this combiner"
)]
pub trait Combine<A: System, B: System> {
/// Input type of the combined system; [`CombinatorSystem`] uses it as its `System::In`.
type In: SystemInput;
/// Output type of the combined system; [`CombinatorSystem`] uses it as its `System::Out`.
type Out;
/// Produces the combined output from `input` and the two system runners.
///
/// * `input` - the input handed to the combined system.
/// * `data` - an opaque value that must be passed, by mutable reference, to
///   whichever runner is invoked.
/// * `a` / `b` - runners for the first and second system. Each is `FnOnce`,
///   so it can be called at most once, and each returns the system's output
///   or a [`RunSystemError`].
///
/// Returns the combined output, or a [`RunSystemError`] chosen by the implementor.
fn combine<T>(
input: <Self::In as SystemInput>::Inner<'_>,
data: &mut T,
a: impl FnOnce(SystemIn<'_, A>, &mut T) -> Result<A::Out, RunSystemError>,
b: impl FnOnce(SystemIn<'_, B>, &mut T) -> Result<B::Out, RunSystemError>,
) -> Result<Self::Out, RunSystemError>;
}
/// A [`System`] built from two systems `A` and `B`, combined according to `Func`.
///
/// `Func` is used only at the type level (it must implement [`Combine`] for the
/// `System` implementation to apply); no value of that type is stored.
pub struct CombinatorSystem<Func, A, B> {
/// Ties the combinator type `Func` to this struct without storing a value of it.
_marker: PhantomData<fn() -> Func>,
/// First inner system.
a: A,
/// Second inner system.
b: B,
/// Name reported by `System::name` for the combined system.
name: DebugName,
}
impl<Func, A, B> CombinatorSystem<Func, A, B> {
pub fn new(a: A, b: B, name: DebugName) -> Self {
Self {
_marker: PhantomData,
a,
b,
name,
}
}
}
/// Exposes the pair of inner systems as a single [`System`] whose run logic is
/// delegated to `Func::combine`.
impl<A, B, Func> System for CombinatorSystem<Func, A, B>
where
Func: Combine<A, B> + 'static,
A: System,
B: System,
{
type In = Func::In;
type Out = Func::Out;
/// Returns a clone of the name supplied at construction.
fn name(&self) -> DebugName {
self.name.clone()
}
/// Returns the union of the flags reported by both inner systems.
#[inline]
fn flags(&self) -> super::SystemStateFlags {
self.a.flags() | self.b.flags()
}
/// Runs the inner systems as directed by `Func::combine`.
///
/// Each inner system is validated immediately before it runs. A
/// `RunSystemError::Failed` from an inner system is handed to the world's
/// default error handler and replaced by a generic "System `..` failed"
/// error; `Ok` and `Skipped` results pass through unchanged.
///
/// # Safety
///
/// `world` is forwarded to the inner systems' `validate_param_unsafe` and
/// `run_unsafe`, so the caller must uphold their requirements.
/// NOTE(review): presumably those documented on `System::run_unsafe` — confirm.
unsafe fn run_unsafe(
&mut self,
input: SystemIn<'_, Self>,
world: UnsafeWorldCell,
) -> Result<Self::Out, RunSystemError> {
// Wrapper declared locally so that `Func::combine`, which only sees it as the
// generic `T`, can do nothing with it except hand it back to the closures below.
struct PrivateUnsafeWorldCell<'w>(UnsafeWorldCell<'w>);
// Validates and runs a single inner system. `Func::combine` controls how the
// closures' results are used, so failures that should reach the world's error
// handler are intercepted here.
unsafe fn run_system<S: System>(
system: &mut S,
input: SystemIn<S>,
world: &mut PrivateUnsafeWorldCell,
) -> Result<S::Out, RunSystemError> {
// Require explicit `unsafe` blocks for unsafe operations inside this `unsafe fn`.
#![deny(unsafe_op_in_unsafe_fn)]
// The immediately-invoked closure lets `?` funnel a validation failure into
// the same `match` as a run failure.
// NOTE(review): soundness of these calls relies on the contract of the outer
// `run_unsafe` being upheld by its caller — confirm.
match (|| unsafe {
system.validate_param_unsafe(world.0)?;
system.run_unsafe(input, world.0)
})() {
Err(RunSystemError::Failed(err)) => {
// `initialize` registers a read of `DefaultErrorHandler` for this system.
// NOTE(review): confirm that this satisfies `default_error_handler`'s contract.
(unsafe { world.0.default_error_handler() })(
err,
ErrorContext::System {
name: system.name(),
last_run: system.get_last_run(),
},
);
// The handler took `err` by value, so a fresh error naming the system is returned.
Err(format!("System `{}` failed", system.name()).into())
}
// Successes and skips are returned as-is and never reach the error handler.
result @ (Ok(_) | Err(RunSystemError::Skipped(_))) => result,
}
}
// Both closures require the same `&mut PrivateUnsafeWorldCell`, so `combine`
// cannot have them executing at the same time.
// NOTE(review): the accesses of `a` and `b` are merged in `initialize`; confirm
// this is what makes forwarding `world` to both of them sound.
Func::combine(
input,
&mut PrivateUnsafeWorldCell(world),
|input, world| unsafe { run_system(&mut self.a, input, world) },
|input, world| unsafe { run_system(&mut self.b, input, world) },
)
}
/// Forwards the hot-patch refresh to both inner systems.
#[cfg(feature = "hotpatching")]
#[inline]
fn refresh_hotpatch(&mut self) {
self.a.refresh_hotpatch();
self.b.refresh_hotpatch();
}
/// Applies deferred operations of `a`, then of `b`.
#[inline]
fn apply_deferred(&mut self, world: &mut World) {
self.a.apply_deferred(world);
self.b.apply_deferred(world);
}
/// Queues deferred operations of `a`, then of `b`.
#[inline]
fn queue_deferred(&mut self, mut world: crate::world::DeferredWorld) {
// `a` gets a reborrow so that `world` itself can still be moved into `b`'s call.
self.a.queue_deferred(world.reborrow());
self.b.queue_deferred(world);
}
/// Always succeeds: each inner system is validated inside `run_unsafe`,
/// immediately before it is run.
#[inline]
unsafe fn validate_param_unsafe(
&mut self,
_world: UnsafeWorldCell,
) -> Result<(), SystemParamValidationError> {
Ok(())
}
/// Initializes both inner systems and returns their merged access, plus a
/// read of `DefaultErrorHandler`.
fn initialize(&mut self, world: &mut World) -> FilteredAccessSet {
let mut a_access = self.a.initialize(world);
let b_access = self.b.initialize(world);
a_access.extend(b_access);
// `run_unsafe` fetches the default error handler to report inner-system failures,
// so that read is registered as part of this system's access.
let error_resource = world.register_resource::<crate::error::DefaultErrorHandler>();
a_access.add_unfiltered_resource_read(error_resource);
a_access
}
/// Forwards the change-tick check to both inner systems.
fn check_change_tick(&mut self, check: CheckChangeTicks) {
self.a.check_change_tick(check);
self.b.check_change_tick(check);
}
/// Returns the default sets of `a` followed by those of `b`.
fn default_system_sets(&self) -> Vec<InternedSystemSet> {
let mut default_sets = self.a.default_system_sets();
default_sets.append(&mut self.b.default_system_sets());
default_sets
}
/// Returns the last-run tick of `a`; `set_last_run` writes the same tick to both systems.
fn get_last_run(&self) -> Tick {
self.a.get_last_run()
}
/// Sets the last-run tick on both inner systems.
fn set_last_run(&mut self, last_run: Tick) {
self.a.set_last_run(last_run);
self.b.set_last_run(last_run);
}
}
// SAFETY: both inner systems are required to be `ReadOnlySystem`s, and the only
// extra access this type registers in `initialize` is a resource read.
// NOTE(review): confirm this covers everything the `ReadOnlySystem` contract requires.
unsafe impl<Func, A, B> ReadOnlySystem for CombinatorSystem<Func, A, B>
where
Func: Combine<A, B> + 'static,
A: ReadOnlySystem,
B: ReadOnlySystem,
{
}
impl<Func, A, B> Clone for CombinatorSystem<Func, A, B>
where
A: Clone,
B: Clone,
{
fn clone(&self) -> Self {
CombinatorSystem::new(self.a.clone(), self.b.clone(), self.name.clone())
}
}
/// A not-yet-converted pair of values that becomes a [`PipeSystem`] through
/// its [`IntoSystem`] implementation.
#[derive(Clone)]
pub struct IntoPipeSystem<A, B> {
/// Value converted into the first system of the pipe.
a: A,
/// Value converted into the second system of the pipe.
b: B,
}
impl<A, B> IntoPipeSystem<A, B> {
pub const fn new(a: A, b: B) -> Self {
Self { a, b }
}
}
/// Marker type used as the first element of the marker tuple in the
/// [`IntoSystem`] implementation for [`IntoPipeSystem`].
#[doc(hidden)]
pub struct IsPipeSystemMarker;
impl<A, B, IA, OA, IB, OB, MA, MB> IntoSystem<IA, OB, (IsPipeSystemMarker, OA, IB, MA, MB)>
for IntoPipeSystem<A, B>
where
IA: SystemInput,
A: IntoSystem<IA, OA, MA>,
B: IntoSystem<IB, OB, MB>,
for<'a> IB: SystemInput<Inner<'a> = OA>,
{
type System = PipeSystem<A::System, B::System>;
fn into_system(this: Self) -> Self::System {
let system_a = IntoSystem::into_system(this.a);
let system_b = IntoSystem::into_system(this.b);
let name = format!("Pipe({}, {})", system_a.name(), system_b.name());
PipeSystem::new(system_a, system_b, DebugName::owned(name))
}
}
/// A [`System`] that runs `a` and feeds its output into `b` as input.
pub struct PipeSystem<A, B> {
/// First system; its output becomes the input of `b`.
a: A,
/// Second system; its output is the output of the pipe.
b: B,
/// Name reported by `System::name` for the piped system.
name: DebugName,
}
impl<A, B> PipeSystem<A, B>
where
A: System,
B: System,
for<'a> B::In: SystemInput<Inner<'a> = A::Out>,
{
pub fn new(a: A, b: B, name: DebugName) -> Self {
Self { a, b, name }
}
}
/// Runs `a`, then `b` with `a`'s output as its input, as a single [`System`].
impl<A, B> System for PipeSystem<A, B>
where
    A: System,
    B: System,
    for<'a> B::In: SystemInput<Inner<'a> = A::Out>,
{
    type In = A::In;
    type Out = B::Out;

    /// Returns a clone of the name supplied at construction.
    fn name(&self) -> DebugName {
        self.name.clone()
    }

    /// Returns the union of the flags reported by both inner systems.
    #[inline]
    fn flags(&self) -> super::SystemStateFlags {
        let first = self.a.flags();
        let second = self.b.flags();
        first | second
    }

    /// Runs `a`, validates `b`, then runs `b` on `a`'s output.
    ///
    /// Any error from `a`, or from validating `b`, is returned before `b` runs.
    ///
    /// # Safety
    ///
    /// `world` is forwarded to the inner systems, so the caller must uphold
    /// their requirements.
    /// NOTE(review): presumably those documented on `System::run_unsafe` — confirm.
    unsafe fn run_unsafe(
        &mut self,
        input: SystemIn<'_, Self>,
        world: UnsafeWorldCell,
    ) -> Result<Self::Out, RunSystemError> {
        // SAFETY: delegated to the inner systems; relies on this function's caller.
        let piped = unsafe { self.a.run_unsafe(input, world)? };
        // `validate_param_unsafe` on this type only covers `a`, so `b` is
        // validated here, after `a` has run.
        // SAFETY: as above.
        unsafe { self.b.validate_param_unsafe(world)? };
        // SAFETY: as above.
        unsafe { self.b.run_unsafe(piped, world) }
    }

    /// Forwards the hot-patch refresh to both inner systems.
    #[cfg(feature = "hotpatching")]
    #[inline]
    fn refresh_hotpatch(&mut self) {
        self.a.refresh_hotpatch();
        self.b.refresh_hotpatch();
    }

    /// Applies deferred operations of `a`, then of `b`.
    fn apply_deferred(&mut self, world: &mut World) {
        self.a.apply_deferred(world);
        self.b.apply_deferred(world);
    }

    /// Queues deferred operations of `a`, then of `b`.
    fn queue_deferred(&mut self, mut world: crate::world::DeferredWorld) {
        // `a` gets a reborrow so that `world` itself can be moved into `b`'s call.
        let for_a = world.reborrow();
        self.a.queue_deferred(for_a);
        self.b.queue_deferred(world);
    }

    /// Validates only `a`; `b` is validated in `run_unsafe` once `a` has run.
    ///
    /// # Safety
    ///
    /// Same requirements as `a`'s `validate_param_unsafe`, to which this delegates.
    unsafe fn validate_param_unsafe(
        &mut self,
        world: UnsafeWorldCell,
    ) -> Result<(), SystemParamValidationError> {
        // SAFETY: delegated to `a`; relies on this function's caller.
        unsafe { self.a.validate_param_unsafe(world) }
    }

    /// Initializes both inner systems and returns their merged access.
    fn initialize(&mut self, world: &mut World) -> FilteredAccessSet {
        let mut access = self.a.initialize(world);
        access.extend(self.b.initialize(world));
        access
    }

    /// Forwards the change-tick check to both inner systems.
    fn check_change_tick(&mut self, check: CheckChangeTicks) {
        self.a.check_change_tick(check);
        self.b.check_change_tick(check);
    }

    /// Returns the default sets of `a` followed by those of `b`.
    fn default_system_sets(&self) -> Vec<InternedSystemSet> {
        let mut sets = self.a.default_system_sets();
        sets.extend(self.b.default_system_sets());
        sets
    }

    /// Returns the last-run tick of `a`; `set_last_run` writes the same tick to both systems.
    fn get_last_run(&self) -> Tick {
        self.a.get_last_run()
    }

    /// Sets the last-run tick on both inner systems.
    fn set_last_run(&mut self, last_run: Tick) {
        self.a.set_last_run(last_run);
        self.b.set_last_run(last_run);
    }
}
// SAFETY: both inner systems are required to be `ReadOnlySystem`s, and the pipe's
// `run_unsafe` only validates and runs those two systems.
// NOTE(review): confirm this covers everything the `ReadOnlySystem` contract requires.
unsafe impl<A, B> ReadOnlySystem for PipeSystem<A, B>
where
A: ReadOnlySystem,
B: ReadOnlySystem,
for<'a> B::In: SystemInput<Inner<'a> = A::Out>,
{
}
#[cfg(test)]
mod tests {
use crate::error::DefaultErrorHandler;
use crate::prelude::*;
use bevy_utils::prelude::DebugName;
use crate::{
schedule::OrMarker,
system::{assert_system_does_not_conflict, CombinatorSystem},
};
// Exercises a combinator whose inner system mutably accesses `DefaultErrorHandler`,
// the same resource the combinator registers a read of in `initialize`.
#[test]
fn combinator_with_error_handler_access() {
// Standalone system that takes the error handler mutably.
fn my_system(_: ResMut<DefaultErrorHandler>) {}
fn a() -> bool {
true
}
// Inner system of the combinator that also takes the error handler mutably.
fn b(_: ResMut<DefaultErrorHandler>) -> bool {
true
}
// Sink that accepts the combinator's `bool` output through a pipe.
fn asdf(_: In<bool>) {}
let mut world = World::new();
world.insert_resource(DefaultErrorHandler::default());
let system = CombinatorSystem::<OrMarker, _, _>::new(
IntoSystem::into_system(a),
IntoSystem::into_system(b),
DebugName::borrowed("a OR b"),
);
// Judging by the helper's name, the combinator must not be reported as conflicting with itself.
assert_system_does_not_conflict(system.clone());
let mut schedule = Schedule::default();
schedule.add_systems((my_system, system.pipe(asdf)));
schedule.initialize(&mut world).unwrap();
// At least one conflict is expected: `my_system` takes `ResMut<DefaultErrorHandler>`,
// which `b` also takes and which the combinator registers a read of.
assert!(!schedule.graph().conflicting_systems().is_empty());
schedule.run(&mut world);
}
// An exclusive system (taking `&mut World`) can be the first half of a pipe.
#[test]
fn exclusive_system_piping_is_possible() {
fn my_exclusive_system(_world: &mut World) -> u32 {
1
}
// Checks that the value produced by the exclusive system arrives as input.
fn out_pipe(input: In<u32>) {
assert!(input.0 == 1);
}
let mut world = World::new();
let mut schedule = Schedule::default();
schedule.add_systems(my_exclusive_system.pipe(out_pipe));
schedule.run(&mut world);
}
}