use std::any::{Any, TypeId};
use std::collections::HashSet;
use std::time::Duration;
pub use crate::budget::TickBudget;
pub use crate::components::{EntityId, Phase};
/// Type-erased world access handed to a running system.
///
/// Components are addressed by `(EntityId, TypeId)`. The typed helpers in
/// [`SystemContextExt`] always pass `TypeId::of::<T>()` as the key and box / downcast the
/// value as `T`. Implementations are therefore expected to store each component under the
/// `TypeId` of its concrete type.
/// NOTE(review): nothing in this trait enforces that pairing; confirm implementations agree.
pub trait SystemContext: crate::world::handle::WorldHandle {
/// Spawns an entity and returns its id.
/// NOTE(review): presumably a fresh, previously unused id — confirm with implementations.
fn spawn(&mut self) -> EntityId;
/// Despawns `entity`.
/// NOTE(review): behavior for an unknown or already-despawned id is not visible here.
fn despawn(&mut self, entity: EntityId);
/// Stores the boxed `component` on `entity` under the key `type_id`.
///
/// `SystemContextExt::set::<T>` calls this with `TypeId::of::<T>()` and a `Box<T>`.
/// NOTE(review): whether an existing component under the same key is replaced is not
/// visible here — confirm with implementations.
fn set_component(
&mut self,
entity: EntityId,
type_id: TypeId,
component: Box<dyn Any + Send + Sync>,
);
/// Returns the ids of entities holding a component stored under `type_id`
/// (presumably; ordering is unspecified by this trait).
fn entities_with(&self, type_id: TypeId) -> Vec<EntityId>;
/// Returns a shared, type-erased reference to `entity`'s component stored under
/// `type_id`, or `None` if there is none. `SystemContextExt::get` downcasts the result.
fn get_component(&self, entity: EntityId, type_id: TypeId) -> Option<&dyn Any>;
/// Mutable counterpart of [`SystemContext::get_component`];
/// `SystemContextExt::get_mut` downcasts the result.
fn get_component_mut(&mut self, entity: EntityId, type_id: TypeId) -> Option<&mut dyn Any>;
/// Returns the [`TickBudget`] exposed to the running system.
/// NOTE(review): its semantics are defined in `crate::budget`, not here.
fn budget(&self) -> &TickBudget;
}
/// Typed convenience wrappers over [`SystemContext`]'s `TypeId`-keyed methods.
///
/// Blanket-implemented for every `S: SystemContext + ?Sized` (so also for
/// `dyn SystemContext`). Each method keys by `TypeId::of::<T>()`.
pub trait SystemContextExt {
/// Returns `entity`'s `T` component. Returns `None` if there is no value under
/// `TypeId::of::<T>()` or the stored value does not downcast to `T`.
fn get<T: crate::components::Component>(&self, entity: EntityId) -> Option<&T>;
/// Mutable counterpart of [`SystemContextExt::get`].
fn get_mut<T: crate::components::Component>(&mut self, entity: EntityId) -> Option<&mut T>;
/// Boxes `component` and stores it on `entity` under `TypeId::of::<T>()`.
fn set<T: crate::components::Component>(&mut self, entity: EntityId, component: T);
/// Returns the entities that `SystemContext::entities_with` reports for `TypeId::of::<T>()`.
fn query<T: crate::components::Component>(&self) -> Vec<EntityId>;
}
/// Blanket implementation: every world context (sized or `dyn`) gets the typed helpers.
impl<S> SystemContextExt for S
where
    S: SystemContext + ?Sized,
{
    /// Fetches the erased component keyed by `T`'s `TypeId`, then downcasts it to `T`.
    fn get<T: crate::components::Component>(&self, entity: EntityId) -> Option<&T> {
        let key = TypeId::of::<T>();
        let erased = self.get_component(entity, key)?;
        erased.downcast_ref::<T>()
    }

    /// Mutable variant of `get`: same key, `downcast_mut` instead of `downcast_ref`.
    fn get_mut<T: crate::components::Component>(&mut self, entity: EntityId) -> Option<&mut T> {
        let key = TypeId::of::<T>();
        let erased = self.get_component_mut(entity, key)?;
        erased.downcast_mut::<T>()
    }

    /// Erases `component` into a box and stores it under `T`'s `TypeId`.
    fn set<T: crate::components::Component>(&mut self, entity: EntityId, component: T) {
        let key = TypeId::of::<T>();
        self.set_component(entity, key, Box::new(component));
    }

    /// Delegates to `entities_with` using `T`'s `TypeId` as the key.
    fn query<T: crate::components::Component>(&self) -> Vec<EntityId> {
        let key = TypeId::of::<T>();
        self.entities_with(key)
    }
}
/// A fully configured system: metadata plus its runnable body.
///
/// Produced by [`SystemBuilder::run`]; the defaults noted below are the ones
/// `SystemBuilder::new` applies.
pub struct SystemDescriptor {
/// System name, as passed to `SystemBuilder::new`.
pub name: String,
/// Phase the system is assigned to (builder default: `Phase::Simulate`).
pub phase: Phase,
/// Interval value (builder default: `1`).
/// NOTE(review): presumably "run once every `every` ticks"; confirm in the scheduler,
/// including how `0` is handled.
pub every: u64,
/// Declared component reads/writes; can be compared with `SystemAccess::conflicts_with`.
pub access: SystemAccess,
/// Optional time budget (`None` unless the builder set one).
/// NOTE(review): how this budget is enforced is not visible in this file.
pub budget: Option<Duration>,
/// The system body, invoked with a `&mut dyn SystemContext`.
pub runner: Box<dyn SystemRunner>,
}
/// Executable body of a system.
///
/// Any `FnMut(&mut dyn SystemContext) + Send` closure implements this through a blanket
/// impl. The `Send` supertrait makes `Box<dyn SystemRunner>` itself `Send`.
pub trait SystemRunner: Send {
/// Executes the system body once against `ctx`.
fn run(&mut self, ctx: &mut dyn SystemContext);
}
/// Lets any `Send` closure taking `&mut dyn SystemContext` serve directly as a [`SystemRunner`].
impl<F> SystemRunner for F
where
    F: FnMut(&mut dyn SystemContext) + Send,
{
    fn run(&mut self, ctx: &mut dyn SystemContext) {
        // Call the wrapped closure itself (through `FnMut`), forwarding the context.
        (*self)(ctx)
    }
}
/// The component types a system declares it reads and writes.
///
/// Two access sets conflict when either side writes a type that the other side reads or
/// writes; overlapping reads alone never conflict (see [`SystemAccess::conflicts_with`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SystemAccess {
    /// `TypeId`s of component types the system reads.
    pub reads: HashSet<TypeId>,
    /// `TypeId`s of component types the system writes.
    pub writes: HashSet<TypeId>,
}

impl SystemAccess {
    /// Creates an access set with no declared reads or writes.
    #[must_use]
    pub fn new() -> Self {
        Self {
            reads: HashSet::new(),
            writes: HashSet::new(),
        }
    }

    /// Returns `true` if `self` and `other` overlap on a type that at least one of them writes.
    ///
    /// The relation is symmetric: `a.conflicts_with(&b) == b.conflicts_with(&a)`.
    /// Read/read overlap is not a conflict.
    pub fn conflicts_with(&self, other: &SystemAccess) -> bool {
        // write/write, then our-write vs their-read, then their-write vs our-read.
        !self.writes.is_disjoint(&other.writes)
            || !self.writes.is_disjoint(&other.reads)
            || !other.writes.is_disjoint(&self.reads)
    }
}

impl Default for SystemAccess {
    /// Same as [`SystemAccess::new`]: no declared reads or writes.
    fn default() -> Self {
        Self::new()
    }
}
/// Fluent builder for a [`SystemDescriptor`].
///
/// Starts from the defaults set by [`SystemBuilder::new`]: phase `Phase::Simulate`,
/// `every == 1`, no declared access and no budget. [`SystemBuilder::run`] consumes it and
/// attaches the system body. Every setter consumes and returns the builder, so ignoring a
/// return value silently discards the configuration; hence `#[must_use]`.
#[must_use = "a SystemBuilder does nothing until `.run(...)` turns it into a SystemDescriptor"]
pub struct SystemBuilder {
    name: String,
    phase: Phase,
    every: u64,
    access: SystemAccess,
    budget: Option<Duration>,
}

impl SystemBuilder {
    /// Starts a builder for a system named `name` with the defaults listed on the type.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            phase: Phase::Simulate,
            every: 1,
            access: SystemAccess::new(),
            budget: None,
        }
    }

    /// Sets the phase copied into [`SystemDescriptor::phase`].
    pub fn phase(mut self, phase: Phase) -> Self {
        self.phase = phase;
        self
    }

    /// Sets the interval copied into [`SystemDescriptor::every`].
    ///
    /// NOTE(review): presumably "run once every `every` ticks"; this builder does not
    /// validate the value, so confirm how the scheduler treats `0` before passing it.
    pub fn every(mut self, every: u64) -> Self {
        self.every = every;
        self
    }

    /// Declares that the system reads component type `T`.
    pub fn reads<T: crate::components::Component>(mut self) -> Self {
        self.access.reads.insert(TypeId::of::<T>());
        self
    }

    /// Declares that the system writes component type `T`.
    pub fn writes<T: crate::components::Component>(mut self) -> Self {
        self.access.writes.insert(TypeId::of::<T>());
        self
    }

    /// Sets the time budget copied into [`SystemDescriptor::budget`].
    ///
    /// Unlike [`SystemBuilder::budget_ms`], this accepts sub-millisecond precision.
    pub fn budget(mut self, budget: Duration) -> Self {
        self.budget = Some(budget);
        self
    }

    /// Sets a time budget of `ms` milliseconds; shorthand for
    /// `budget(Duration::from_millis(ms))`.
    pub fn budget_ms(self, ms: u64) -> Self {
        self.budget(Duration::from_millis(ms))
    }

    /// Finishes the builder, attaching `runner` as the system body.
    #[must_use = "the returned SystemDescriptor must be registered for the system to run"]
    pub fn run<F: FnMut(&mut dyn SystemContext) + Send + 'static>(
        self,
        runner: F,
    ) -> SystemDescriptor {
        SystemDescriptor {
            name: self.name,
            phase: self.phase,
            every: self.every,
            access: self.access,
            budget: self.budget,
            runner: Box::new(runner),
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the builder defaults/overrides and for `SystemAccess` conflict rules.
    use super::*;
    use crate::components::{Position, Velocity};

    #[test]
    fn system_builder_defaults() {
        let desc = SystemBuilder::new("test").run(|_ctx| {});
        assert_eq!(desc.name, "test");
        assert_eq!(desc.phase, Phase::Simulate);
        assert_eq!(desc.every, 1);
        assert!(desc.access.reads.is_empty());
        assert!(desc.access.writes.is_empty());
        assert_eq!(desc.budget, None);
    }

    #[test]
    fn system_builder_with_access() {
        let desc = SystemBuilder::new("physics")
            .phase(Phase::Simulate)
            .every(1)
            .reads::<Position>()
            .writes::<Position>()
            .writes::<Velocity>()
            .run(|_ctx| {});
        assert!(desc.access.reads.contains(&TypeId::of::<Position>()));
        assert!(desc.access.writes.contains(&TypeId::of::<Position>()));
        assert!(desc.access.writes.contains(&TypeId::of::<Velocity>()));
    }

    #[test]
    fn system_builder_every_and_budget() {
        let desc = SystemBuilder::new("timed")
            .every(3)
            .budget_ms(5)
            .run(|_ctx| {});
        assert_eq!(desc.every, 3);
        assert_eq!(desc.budget, Some(Duration::from_millis(5)));
    }

    #[test]
    fn access_conflict_detection() {
        let mut a = SystemAccess::new();
        a.writes.insert(TypeId::of::<Position>());
        let mut b = SystemAccess::new();
        b.reads.insert(TypeId::of::<Position>());
        assert!(a.conflicts_with(&b));
        assert!(b.conflicts_with(&a));
    }

    #[test]
    fn write_write_conflict() {
        let mut a = SystemAccess::new();
        a.writes.insert(TypeId::of::<Velocity>());
        let mut b = SystemAccess::new();
        b.writes.insert(TypeId::of::<Velocity>());
        assert!(a.conflicts_with(&b));
        assert!(b.conflicts_with(&a));
    }

    #[test]
    fn no_conflict_for_disjoint_access() {
        let mut a = SystemAccess::new();
        a.reads.insert(TypeId::of::<Position>());
        let mut b = SystemAccess::new();
        b.reads.insert(TypeId::of::<Velocity>());
        assert!(!a.conflicts_with(&b));
        assert!(!b.conflicts_with(&a));
    }

    #[test]
    fn no_conflict_for_disjoint_writes() {
        let mut a = SystemAccess::new();
        a.writes.insert(TypeId::of::<Position>());
        let mut b = SystemAccess::new();
        b.writes.insert(TypeId::of::<Velocity>());
        assert!(!a.conflicts_with(&b));
        assert!(!b.conflicts_with(&a));
    }

    #[test]
    fn read_read_no_conflict() {
        let mut a = SystemAccess::new();
        a.reads.insert(TypeId::of::<Position>());
        let mut b = SystemAccess::new();
        b.reads.insert(TypeId::of::<Position>());
        assert!(!a.conflicts_with(&b));
        assert!(!b.conflicts_with(&a));
    }

    #[test]
    fn empty_access_never_conflicts() {
        let empty = SystemAccess::default();
        let mut busy = SystemAccess::new();
        busy.reads.insert(TypeId::of::<Position>());
        busy.writes.insert(TypeId::of::<Velocity>());
        assert!(!empty.conflicts_with(&busy));
        assert!(!busy.conflicts_with(&empty));
        assert!(!empty.conflicts_with(&SystemAccess::new()));
    }
}