use rustsim_core::types::AgentId;
/// A function paired with a static name.
///
/// NOTE(review): nothing in the visible code constructs or consumes `LayerFunction`;
/// `LayerExecutor`'s `add_*_layer` methods take the name and the function as separate
/// arguments instead. Confirm whether other modules rely on this type before changing it.
pub struct LayerFunction<F> {
/// Label for the function (presumably the same kind of name reported in layer timings —
/// TODO confirm with callers).
pub name: &'static str,
/// The wrapped value; `F` carries no trait bounds here.
pub function: F,
}
/// Wall-clock timing for one layer of a single [`LayerExecutor::execute`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LayerTiming {
    /// Zero-based position of the layer in execution order.
    pub layer_index: usize,
    /// Names of the functions that ran in this layer, in the order they ran.
    pub function_names: Vec<&'static str>,
    /// Time spent running the layer, truncated to whole microseconds.
    pub elapsed_us: u128,
}

/// Timing report for one complete [`LayerExecutor::execute`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StepTiming {
    /// Per-layer timings, in execution order.
    pub layers: Vec<LayerTiming>,
    /// Time spent on the whole step, truncated to whole microseconds.
    pub total_us: u128,
}
/// Runs registered simulation functions as an ordered sequence of layers.
///
/// Each `add_*_layer` call appends a new layer holding one function, and
/// [`execute`](Self::execute) runs every layer once, in insertion order, timing each one.
///
/// Type parameters:
/// - `S`: value handed mutably to agent and model functions (the space).
/// - `A`: not used by any stored function; it is kept only as a type-level marker.
/// - `Props`: value handed mutably to agent and model functions (the properties).
pub struct LayerExecutor<S, A, Props> {
    /// Layers in execution order. Each inner `Vec` holds the functions of one layer; the
    /// `add_*_layer` methods always create single-function layers.
    layers: Vec<Vec<LayerEntry<S, Props>>>,
    /// Keeps `A` part of the type. Holding the marker here, rather than in a never-built
    /// enum variant, keeps `LayerEntry` free of an impossible case.
    _marker: std::marker::PhantomData<A>,
}

/// Boxed batch function: receives every data column and the `agent_count` given to
/// [`LayerExecutor::execute`].
type BatchFn = Box<dyn FnMut(&mut [Vec<f32>], usize)>;

/// Boxed per-agent function: called once per agent id with the space and properties.
type AgentFn<S, Props> = Box<dyn FnMut(AgentId, &mut S, &mut Props)>;

/// Boxed model-level function: called once per step with the space and properties.
type ModelFn<S, Props> = Box<dyn FnMut(&mut S, &mut Props)>;

/// One registered function plus the name reported for it in [`LayerTiming`].
enum LayerEntry<S, Props> {
    /// Operates on the raw column storage.
    BatchStep { name: &'static str, function: BatchFn },
    /// Invoked once for every id in the `agent_ids` slice passed to `execute`.
    AgentStep { name: &'static str, function: AgentFn<S, Props> },
    /// Invoked exactly once per `execute` call.
    ModelStep { name: &'static str, function: ModelFn<S, Props> },
}

impl<S, Props> LayerEntry<S, Props> {
    /// The name this entry was registered under.
    fn name(&self) -> &'static str {
        match self {
            Self::BatchStep { name, .. }
            | Self::AgentStep { name, .. }
            | Self::ModelStep { name, .. } => *name,
        }
    }
}

impl<S, A, Props> LayerExecutor<S, A, Props> {
    /// Creates an executor with no layers.
    pub fn new() -> Self {
        Self {
            layers: Vec::new(),
            _marker: std::marker::PhantomData,
        }
    }

    /// Appends a layer holding a single batch function.
    ///
    /// During [`execute`](Self::execute) the function is called once with the full
    /// `columns` slice and the `agent_count` argument.
    pub fn add_batch_layer(
        &mut self,
        name: &'static str,
        function: impl FnMut(&mut [Vec<f32>], usize) + 'static,
    ) {
        self.push_layer(LayerEntry::BatchStep {
            name,
            function: Box::new(function),
        });
    }

    /// Appends a layer holding a single per-agent function.
    ///
    /// During [`execute`](Self::execute) the function is called once for each id in
    /// `agent_ids`, in slice order, with the shared space and properties.
    pub fn add_agent_layer(
        &mut self,
        name: &'static str,
        function: impl FnMut(AgentId, &mut S, &mut Props) + 'static,
    ) {
        self.push_layer(LayerEntry::AgentStep {
            name,
            function: Box::new(function),
        });
    }

    /// Appends a layer holding a single model-level function.
    ///
    /// During [`execute`](Self::execute) the function is called exactly once with the
    /// space and properties.
    pub fn add_model_layer(
        &mut self,
        name: &'static str,
        function: impl FnMut(&mut S, &mut Props) + 'static,
    ) {
        self.push_layer(LayerEntry::ModelStep {
            name,
            function: Box::new(function),
        });
    }

    /// Number of layers registered so far.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Runs every layer once, in the order the layers were added.
    ///
    /// - Batch functions receive `columns` and `agent_count`.
    /// - Agent functions are called once per id in `agent_ids`, in slice order.
    /// - Model functions are called once.
    ///
    /// `agent_count` and `agent_ids.len()` are passed through as given and are not checked
    /// against each other.
    ///
    /// Returns wall-clock timings, measured with `std::time::Instant` and truncated to whole
    /// microseconds, for each layer and for the whole step.
    pub fn execute(
        &mut self,
        columns: &mut [Vec<f32>],
        agent_count: usize,
        agent_ids: &[AgentId],
        space: &mut S,
        properties: &mut Props,
    ) -> StepTiming {
        let step_start = std::time::Instant::now();
        let mut layer_timings = Vec::with_capacity(self.layers.len());

        for (layer_index, layer) in self.layers.iter_mut().enumerate() {
            // Gather names first (exactly-sized Vec) so the layer timer measures only the
            // registered functions themselves.
            let function_names: Vec<&'static str> = layer.iter().map(LayerEntry::name).collect();
            let layer_start = std::time::Instant::now();

            for entry in layer.iter_mut() {
                match entry {
                    LayerEntry::BatchStep { function, .. } => function(columns, agent_count),
                    LayerEntry::AgentStep { function, .. } => {
                        for &id in agent_ids {
                            function(id, space, properties);
                        }
                    }
                    LayerEntry::ModelStep { function, .. } => function(space, properties),
                }
            }

            layer_timings.push(LayerTiming {
                layer_index,
                function_names,
                elapsed_us: layer_start.elapsed().as_micros(),
            });
        }

        StepTiming {
            layers: layer_timings,
            total_us: step_start.elapsed().as_micros(),
        }
    }

    /// Appends `entry` as a new single-function layer.
    fn push_layer(&mut self, entry: LayerEntry<S, Props>) {
        self.layers.push(vec![entry]);
    }
}
impl<S, A, Props> Default for LayerExecutor<S, A, Props> {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two batch layers run in order: (0 + 1) * 2 == 2 for every element.
    #[test]
    fn basic_layer_execution() {
        let mut executor: LayerExecutor<(), (), ()> = LayerExecutor::new();
        executor.add_batch_layer("increment_x", |columns, n| {
            for v in columns[0].iter_mut().take(n) {
                *v += 1.0;
            }
        });
        executor.add_batch_layer("double_x", |columns, n| {
            for v in columns[0].iter_mut().take(n) {
                *v *= 2.0;
            }
        });
        let mut columns = vec![vec![0.0f32; 10]];
        let ids: Vec<AgentId> = (0..10).collect();
        let mut space = ();
        let mut props = ();
        let timing = executor.execute(&mut columns, 10, &ids, &mut space, &mut props);
        assert_eq!(timing.layers.len(), 2);
        for &v in &columns[0] {
            assert!((v - 2.0).abs() < 1e-5);
        }
    }

    /// An agent layer is invoked once per id, in the order the ids are given.
    #[test]
    fn agent_layer_visits_every_id_in_order() {
        let mut executor: LayerExecutor<usize, (), Vec<AgentId>> = LayerExecutor::new();
        executor.add_agent_layer("record", |id, calls, seen| {
            *calls += 1;
            seen.push(id);
        });
        let mut columns: Vec<Vec<f32>> = Vec::new();
        let ids: Vec<AgentId> = (0..5).collect();
        let mut calls = 0usize;
        let mut seen = Vec::new();
        executor.execute(&mut columns, 0, &ids, &mut calls, &mut seen);
        assert_eq!(calls, 5);
        assert_eq!(seen, ids);
    }

    /// Model layers run once per step, layers keep insertion order, and the timing report
    /// carries each layer's index and function name.
    #[test]
    fn model_layers_run_once_and_report_metadata() {
        let mut executor: LayerExecutor<Vec<&'static str>, (), u32> = LayerExecutor::new();
        executor.add_model_layer("first", |log, steps| {
            log.push("first");
            *steps += 1;
        });
        executor.add_agent_layer("per_agent", |_id, log, _steps| log.push("agent"));
        executor.add_model_layer("last", |log, _steps| log.push("last"));
        assert_eq!(executor.num_layers(), 3);

        let mut columns: Vec<Vec<f32>> = Vec::new();
        let ids: Vec<AgentId> = (0..2).collect();
        let mut log = Vec::new();
        let mut steps = 0u32;
        let timing = executor.execute(&mut columns, 0, &ids, &mut log, &mut steps);

        assert_eq!(log, ["first", "agent", "agent", "last"]);
        assert_eq!(steps, 1);
        let names: Vec<Vec<&str>> = timing
            .layers
            .iter()
            .map(|layer| layer.function_names.clone())
            .collect();
        assert_eq!(names, [vec!["first"], vec!["per_agent"], vec!["last"]]);
        let indices: Vec<usize> = timing.layers.iter().map(|layer| layer.layer_index).collect();
        assert_eq!(indices, [0, 1, 2]);
    }

    /// A default-constructed executor has no layers and reports none.
    #[test]
    fn empty_executor_reports_no_layers() {
        let mut executor: LayerExecutor<(), (), ()> = LayerExecutor::default();
        assert_eq!(executor.num_layers(), 0);
        let mut columns: Vec<Vec<f32>> = Vec::new();
        let ids: Vec<AgentId> = Vec::new();
        let mut space = ();
        let mut props = ();
        let timing = executor.execute(&mut columns, 0, &ids, &mut space, &mut props);
        assert!(timing.layers.is_empty());
    }
}