use crate::agent::Agent;
use crate::messaging::{BruteForceMessages, SpatialMessages2D, SpatialMessages3D};
use crate::store::AgentStore;
use crate::types::AgentId;
/// Timing and size summary returned by each `two_phase_*` helper in this module.
///
/// All durations are whole microseconds, taken from `Instant::elapsed().as_micros()`.
/// `Default` yields an all-zero result; `PartialEq`/`Eq` allow results to be compared
/// directly (for example in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TwoPhaseResult {
    /// Microseconds spent running the output callback over the agents.
    pub output_us: u128,
    /// Microseconds spent inside the message container's `finalize()` call.
    pub finalize_us: u128,
    /// Microseconds spent running the input callback over the agents.
    pub input_us: u128,
    /// Value of `messages.len()` sampled right after finalization.
    pub message_count: usize,
    /// Length of the id slice that was passed in (ids that could not be resolved in the
    /// store are still counted).
    pub agent_count: usize,
}
/// Runs one output → finalize → input step over `ids` using a brute-force message list.
///
/// 1. **Output** – for every id in `ids` that `store.get` resolves, `output_fn` is called
///    with the agent and `messages`.
/// 2. **Finalize** – `messages.finalize()` is called once, then the message count is
///    sampled from `messages.len()`.
/// 3. **Input** – for every id that `store.get_mut` resolves, `input_fn` receives the
///    agent mutably together with the full slice returned by `messages.read_all()`.
///
/// `messages` is cleared before returning. Ids that are not present in the store are
/// skipped silently, but `agent_count` in the result is always `ids.len()`.
///
/// Each phase is timed separately; see [`TwoPhaseResult`].
pub fn two_phase_brute_force<A, S, M, FOut, FIn>(
    store: &S,
    ids: &[AgentId],
    messages: &mut BruteForceMessages<M>,
    mut output_fn: FOut,
    mut input_fn: FIn,
) -> TwoPhaseResult
where
    A: Agent,
    M: Clone,
    S: AgentStore<A>,
    FOut: FnMut(&A, &mut BruteForceMessages<M>),
    FIn: FnMut(&mut A, &[M]),
{
    use std::time::Instant;

    let agent_count = ids.len();

    // Phase 1: every resolvable agent emits its messages.
    let output_clock = Instant::now();
    for agent in ids.iter().filter_map(|&id| store.get(id)) {
        output_fn(&*agent, messages);
    }
    let output_us = output_clock.elapsed().as_micros();

    // Phase 2: finalize the message container.
    let finalize_clock = Instant::now();
    messages.finalize();
    let finalize_us = finalize_clock.elapsed().as_micros();

    // Sampled outside both timers so it is not attributed to either phase.
    let message_count = messages.len();

    // Phase 3: every resolvable agent sees the complete message slice.
    // Fetching the slice is deliberately inside the input timer.
    let input_clock = Instant::now();
    let inbox = messages.read_all();
    for mut agent in ids.iter().filter_map(|&id| store.get_mut(id)) {
        input_fn(&mut *agent, inbox);
    }
    let input_us = input_clock.elapsed().as_micros();

    // Leave the container empty for the next step.
    messages.clear();

    TwoPhaseResult {
        output_us,
        finalize_us,
        input_us,
        message_count,
        agent_count,
    }
}
/// Runs one output → finalize → input step over `ids` using a `SpatialMessages2D` container.
///
/// 1. **Output** – for every id in `ids` that `store.get` resolves, `output_fn` is called
///    with the agent and `messages`.
/// 2. **Finalize** – `messages.finalize()` is called once, then the message count is
///    sampled from `messages.len()`.
/// 3. **Input** – for every id that `store.get_mut` resolves, `input_fn` receives the
///    agent mutably and a shared reference to the whole container, so the callback
///    decides which messages to read.
///
/// `messages` is cleared before returning. Ids that are not present in the store are
/// skipped silently, but `agent_count` in the result is always `ids.len()`.
///
/// Each phase is timed separately; see [`TwoPhaseResult`].
pub fn two_phase_spatial_2d<A, S, M, FOut, FIn>(
    store: &S,
    ids: &[AgentId],
    messages: &mut SpatialMessages2D<M>,
    mut output_fn: FOut,
    mut input_fn: FIn,
) -> TwoPhaseResult
where
    A: Agent,
    M: Clone,
    S: AgentStore<A>,
    FOut: FnMut(&A, &mut SpatialMessages2D<M>),
    FIn: FnMut(&mut A, &SpatialMessages2D<M>),
{
    use std::time::Instant;

    let agent_count = ids.len();

    // Phase 1: every resolvable agent emits its messages.
    let output_clock = Instant::now();
    for agent in ids.iter().filter_map(|&id| store.get(id)) {
        output_fn(&*agent, messages);
    }
    let output_us = output_clock.elapsed().as_micros();

    // Phase 2: finalize the message container.
    let finalize_clock = Instant::now();
    messages.finalize();
    let finalize_us = finalize_clock.elapsed().as_micros();

    // Sampled outside both timers so it is not attributed to either phase.
    let message_count = messages.len();

    // Phase 3: every resolvable agent reads from the finalized container.
    let input_clock = Instant::now();
    for mut agent in ids.iter().filter_map(|&id| store.get_mut(id)) {
        input_fn(&mut *agent, messages);
    }
    let input_us = input_clock.elapsed().as_micros();

    // Leave the container empty for the next step.
    messages.clear();

    TwoPhaseResult {
        output_us,
        finalize_us,
        input_us,
        message_count,
        agent_count,
    }
}
/// Runs one output → finalize → input step over `ids` using a `SpatialMessages3D` container.
///
/// 1. **Output** – for every id in `ids` that `store.get` resolves, `output_fn` is called
///    with the agent and `messages`.
/// 2. **Finalize** – `messages.finalize()` is called once, then the message count is
///    sampled from `messages.len()`.
/// 3. **Input** – for every id that `store.get_mut` resolves, `input_fn` receives the
///    agent mutably and a shared reference to the whole container, so the callback
///    decides which messages to read.
///
/// `messages` is cleared before returning. Ids that are not present in the store are
/// skipped silently, but `agent_count` in the result is always `ids.len()`.
///
/// Each phase is timed separately; see [`TwoPhaseResult`].
pub fn two_phase_spatial_3d<A, S, M, FOut, FIn>(
    store: &S,
    ids: &[AgentId],
    messages: &mut SpatialMessages3D<M>,
    mut output_fn: FOut,
    mut input_fn: FIn,
) -> TwoPhaseResult
where
    A: Agent,
    M: Clone,
    S: AgentStore<A>,
    FOut: FnMut(&A, &mut SpatialMessages3D<M>),
    FIn: FnMut(&mut A, &SpatialMessages3D<M>),
{
    use std::time::Instant;

    let agent_count = ids.len();

    // Phase 1: every resolvable agent emits its messages.
    let output_clock = Instant::now();
    for agent in ids.iter().filter_map(|&id| store.get(id)) {
        output_fn(&*agent, messages);
    }
    let output_us = output_clock.elapsed().as_micros();

    // Phase 2: finalize the message container.
    let finalize_clock = Instant::now();
    messages.finalize();
    let finalize_us = finalize_clock.elapsed().as_micros();

    // Sampled outside both timers so it is not attributed to either phase.
    let message_count = messages.len();

    // Phase 3: every resolvable agent reads from the finalized container.
    let input_clock = Instant::now();
    for mut agent in ids.iter().filter_map(|&id| store.get_mut(id)) {
        input_fn(&mut *agent, messages);
    }
    let input_us = input_clock.elapsed().as_micros();

    // Leave the container empty for the next step.
    messages.clear();

    TwoPhaseResult {
        output_us,
        finalize_us,
        input_us,
        message_count,
        agent_count,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;

    /// Minimal 2D agent used to drive the two-phase helpers.
    #[derive(Debug, Clone)]
    struct Boid {
        id: AgentId,
        x: f32,
        y: f32,
        vx: f32,
        vy: f32,
    }

    impl Boid {
        /// Builds a boid at `(id * spacing, id * spacing)` with a small fixed velocity.
        fn on_diagonal(id: AgentId, spacing: f32) -> Self {
            let coord = (id as f32) * spacing;
            Self {
                id,
                x: coord,
                y: coord,
                vx: 0.001,
                vy: 0.001,
            }
        }

        /// Copies the boid's current state into a message.
        fn snapshot(&self) -> BoidMessage {
            BoidMessage {
                id: self.id,
                x: self.x,
                y: self.y,
                vx: self.vx,
                vy: self.vy,
            }
        }
    }

    impl Agent for Boid {
        fn id(&self) -> AgentId {
            self.id
        }
    }

    /// State broadcast by each boid during the output phase.
    #[derive(Debug, Clone)]
    struct BoidMessage {
        id: AgentId,
        x: f32,
        y: f32,
        // Velocity is sent along but never read back by these tests.
        #[allow(dead_code)]
        vx: f32,
        #[allow(dead_code)]
        vy: f32,
    }

    #[test]
    fn two_phase_brute_force_boids() {
        let mut store = HashMapStore::new();
        for boid in (1..=100).map(|i| Boid::on_diagonal(i, 0.01)) {
            store.insert(boid);
        }
        let agent_ids: Vec<AgentId> = store.iter_ids();
        let mut mailbox = BruteForceMessages::with_capacity(100);

        let outcome = two_phase_brute_force(
            &store,
            &agent_ids,
            &mut mailbox,
            |boid: &Boid, outbox| {
                outbox.output(boid.snapshot());
            },
            |boid: &mut Boid, inbox| {
                // Accumulate the positions of every other boid, in message order.
                let own_id = boid.id;
                let (sum_x, sum_y, others) = inbox
                    .iter()
                    .filter(|msg| msg.id != own_id)
                    .fold((0.0_f32, 0.0_f32, 0_i32), |(sx, sy, n), msg| {
                        (sx + msg.x, sy + msg.y, n + 1)
                    });

                // Steer slightly toward the centroid of the others, then move.
                if others > 0 {
                    let centroid_x = sum_x / others as f32;
                    let centroid_y = sum_y / others as f32;
                    boid.vx += (centroid_x - boid.x) * 0.01;
                    boid.vy += (centroid_y - boid.y) * 0.01;
                }
                boid.x += boid.vx;
                boid.y += boid.vy;
            },
        );

        assert_eq!(outcome.message_count, 100);
        assert_eq!(outcome.agent_count, 100);
    }

    #[test]
    fn two_phase_spatial_2d_boids() {
        let mut store = HashMapStore::new();
        for boid in (1..=50).map(|i| Boid::on_diagonal(i, 0.1)) {
            store.insert(boid);
        }
        let agent_ids: Vec<AgentId> = store.iter_ids();
        let mut grid = SpatialMessages2D::new(1.0).unwrap();

        let outcome = two_phase_spatial_2d(
            &store,
            &agent_ids,
            &mut grid,
            |boid: &Boid, outbox| {
                outbox.output(boid.snapshot(), boid.x, boid.y);
            },
            |boid: &mut Boid, inbox| {
                // Only the x coordinate of nearby boids feeds back into the velocity.
                let own_id = boid.id;
                let mut sum_x = 0.0;
                let mut others = 0;
                for (msg, _dist_sq) in inbox.read_nearby(boid.x, boid.y, 1.0) {
                    if msg.id != own_id {
                        sum_x += msg.x;
                        others += 1;
                    }
                }

                if others > 0 {
                    let mean_x = sum_x / others as f32;
                    boid.vx += (mean_x - boid.x) * 0.01;
                }
                boid.x += boid.vx;
                boid.y += boid.vy;
            },
        );

        assert_eq!(outcome.message_count, 50);
        assert_eq!(outcome.agent_count, 50);
    }
}