use std::sync::Arc;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use rustsim_core::prelude::AgentId;
use rustsim_crowd::Pedestrian;
use rustsim_io::arrow::ArrowValue;
use rustsim_io::bridge::{BridgeError, CollectArrowBridge};
/// Batch-size argument to hand to `CollectArrowBridge::new` when recording
/// crowd rows (presumably the number of rows per emitted record batch —
/// TODO confirm against the bridge API).
pub const CROWD_BATCH_SIZE: usize = 8_192;
/// Number of columns in the schema returned by `crowd_arrow_schema`, and
/// therefore the number of values in every row built by `push_crowd_row`.
pub const CROWD_NUM_COLUMNS: usize = 8;
/// Returns the Arrow schema used for per-agent crowd observations.
///
/// Columns, in order:
///
/// | name            | type      |
/// |-----------------|-----------|
/// | `tick`          | `Int64`   |
/// | `agent_id`      | `Int64`   |
/// | `pos_x`         | `Float64` |
/// | `pos_y`         | `Float64` |
/// | `vel_x`         | `Float64` |
/// | `vel_y`         | `Float64` |
/// | `radius`        | `Float64` |
/// | `desired_speed` | `Float64` |
///
/// Every column is declared non-nullable. A fresh `Arc<Schema>` is built on
/// each call.
pub fn crowd_arrow_schema() -> SchemaRef {
    // Small builders so the column list below reads as a plain table.
    let int_col = |name: &str| Field::new(name, DataType::Int64, false);
    let float_col = |name: &str| Field::new(name, DataType::Float64, false);

    let columns = vec![
        int_col("tick"),
        int_col("agent_id"),
        float_col("pos_x"),
        float_col("pos_y"),
        float_col("vel_x"),
        float_col("vel_y"),
        float_col("radius"),
        float_col("desired_speed"),
    ];

    Arc::new(Schema::new(columns))
}
/// Appends one observation row for a single pedestrian to `bridge`.
///
/// The row holds eight values in this order: `tick`, `id`, position x/y,
/// velocity x/y, `radius`, `desired_speed` — the same order as the columns
/// declared by `crowd_arrow_schema`.
///
/// # Errors
///
/// Returns whatever error `CollectArrowBridge::push_row` reports for the row.
pub fn push_crowd_row(
    bridge: &mut CollectArrowBridge,
    tick: i64,
    id: AgentId,
    ped: &Pedestrian,
) -> Result<(), BridgeError> {
    // NOTE(review): plain `as` cast. If `AgentId` is an unsigned 64-bit
    // integer, ids above `i64::MAX` would wrap to negative values — confirm
    // the id range used by callers.
    let agent_id = id as i64;

    let pos_x = ped.pos[0];
    let pos_y = ped.pos[1];
    let vel_x = ped.vel[0];
    let vel_y = ped.vel[1];

    let row = [
        ArrowValue::Int64(tick),
        ArrowValue::Int64(agent_id),
        ArrowValue::Float64(pos_x),
        ArrowValue::Float64(pos_y),
        ArrowValue::Float64(vel_x),
        ArrowValue::Float64(vel_y),
        ArrowValue::Float64(ped.radius),
        ArrowValue::Float64(ped.desired_speed),
    ];

    bridge.push_row(&row)
}
/// Builds an observer closure that records every `(AgentId, &Pedestrian)` it
/// is called with as one row in `bridge`, stamped with the given `tick`.
///
/// The closure mutably borrows `bridge` for `'a`, so the bridge cannot be
/// used elsewhere until the closure is dropped.
///
/// # Panics
///
/// The returned closure panics if `push_crowd_row` returns an error. The
/// panic message attributes this to a bridge built with a schema other than
/// `crowd_arrow_schema`.
pub fn crowd_observer<'a>(
    bridge: &'a mut CollectArrowBridge,
    tick: i64,
) -> impl FnMut(AgentId, &Pedestrian) + 'a {
    const SCHEMA_MISMATCH: &str = "crowd_observer: bridge schema does not match crowd_arrow_schema; \
        this is a programmer error, fix the bridge construction";

    move |agent: AgentId, pedestrian: &Pedestrian| {
        let pushed = push_crowd_row(bridge, tick, agent, pedestrian);
        pushed.expect(SCHEMA_MISMATCH);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustsim_core::prelude::VecStore;
    use rustsim_core::store::AgentStore;
    use rustsim_crowd::common::Pedestrian;
    use rustsim_crowd::prelude::{
        recommended_cell_size, social_force, step_scratch_store_observed, CrowdAgent, Scratch,
        SocialForceModel,
    };

    /// Builds a store of `n` agents laid out on a square-ish grid with 1.5
    /// spacing, plus a single wall segment along `y = -1`.
    ///
    /// Agent ids are `1..=n`. Every pedestrian starts with zero velocity and
    /// a destination 30 units further along +x from its start position.
    fn fixture(n: usize) -> (VecStore<CrowdAgent>, Vec<rustsim_crowd::WallSegment>) {
        let mut store = VecStore::<CrowdAgent>::new();
        // Grid width: smallest column count whose square holds `n` agents.
        let cols = (n as f64).sqrt().ceil() as usize;
        for i in 0..n {
            let r = i / cols;
            let c = i % cols;
            let id = (i as u64) + 1;
            let pos = [c as f64 * 1.5, r as f64 * 1.5];
            let dest = [pos[0] + 30.0, pos[1]];
            let ped = Pedestrian::new(pos, [0.0, 0.0], 0.25, 1.34, dest);
            store.insert(CrowdAgent { id, ped });
        }
        let walls = vec![rustsim_crowd::WallSegment {
            a: [-10.0, -1.0],
            b: [200.0, -1.0],
        }];
        (store, walls)
    }

    /// Pins the schema contract: column count, column names in order, and
    /// non-nullability of every column.
    #[test]
    fn schema_is_stable_eight_columns() {
        let s = crowd_arrow_schema();
        assert_eq!(s.fields().len(), CROWD_NUM_COLUMNS);
        let names: Vec<&str> = s.fields().iter().map(|f| f.name().as_str()).collect();
        assert_eq!(
            names,
            vec![
                "tick",
                "agent_id",
                "pos_x",
                "pos_y",
                "vel_x",
                "vel_y",
                "radius",
                "desired_speed"
            ]
        );
        for f in s.fields() {
            assert!(!f.is_nullable(), "{} must be non-nullable", f.name());
        }
    }

    /// Steps `N` agents for `TICKS` ticks with a fresh `crowd_observer` per
    /// tick, then checks the collected batches:
    ///
    /// * total row count is `N * TICKS`,
    /// * every batch carries the schema from `crowd_arrow_schema`,
    /// * the `tick` column never decreases across batches read in order,
    /// * each agent id appears exactly `TICKS` times.
    #[test]
    fn end_to_end_observed_drive_emits_one_row_per_agent_per_tick() {
        const N: usize = 200;
        const TICKS: i64 = 60;
        const DT: f64 = 0.05;

        let (mut store, walls) = fixture(N);
        let params = social_force::Params::default();
        let cell = recommended_cell_size(social_force::neighbor_cutoff(&params));
        let mut scratch = Scratch::new(cell);
        let mut peds_buf: Vec<Pedestrian> = Vec::new();
        let mut bridge = CollectArrowBridge::new(crowd_arrow_schema(), CROWD_BATCH_SIZE).unwrap();

        for tick in 0..TICKS {
            // The observer holds `&mut bridge`; scoping it to the loop body
            // releases the borrow before the next tick and before
            // `take_batches` below.
            let mut obs = crowd_observer(&mut bridge, tick);
            step_scratch_store_observed(
                &SocialForceModel,
                &mut store,
                &walls,
                &params,
                DT,
                &mut scratch,
                &mut peds_buf,
                &mut obs,
            );
        }

        let batches = bridge.take_batches().unwrap();
        let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum();
        assert_eq!(
            total_rows,
            N * TICKS as usize,
            "expected one row per agent per tick"
        );

        let expected = crowd_arrow_schema();
        for b in &batches {
            assert_eq!(b.schema(), expected);
        }

        // Starts below the first tick (0) so the first row always passes.
        let mut last_tick: i64 = -1;
        let mut id_counts: std::collections::HashMap<i64, usize> =
            std::collections::HashMap::with_capacity(N);
        for b in &batches {
            let tick_col = b
                .column_by_name("tick")
                .unwrap()
                .as_any()
                .downcast_ref::<arrow_array::Int64Array>()
                .unwrap();
            let id_col = b
                .column_by_name("agent_id")
                .unwrap()
                .as_any()
                .downcast_ref::<arrow_array::Int64Array>()
                .unwrap();
            for i in 0..b.num_rows() {
                let t = tick_col.value(i);
                assert!(t >= last_tick, "tick column must be monotone");
                last_tick = t;
                *id_counts.entry(id_col.value(i)).or_default() += 1;
            }
        }

        assert_eq!(id_counts.len(), N, "every agent must be observed");
        for (id, count) in &id_counts {
            assert_eq!(*count, TICKS as usize, "agent {id} appeared {count} times");
        }
    }
}