use rustsim_core::soa::{self, SoaExtractable, SoaExtractableF64};
use rustsim_core::store::AgentStore;
use rustsim_core::types::AgentId;
use thiserror::Error;
/// Owned copy of a `DeviceSoaStore`'s host-side state.
///
/// Produced by `DeviceSoaStore::checkpoint` and accepted by
/// `DeviceSoaStore::from_checkpoint`, which validates it before use. Row `r`
/// of every column belongs to the agent `ids[r]`.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceSoaCheckpoint {
/// Agent id of each row, in row order.
pub ids: Vec<AgentId>,
/// One `f32` vector per column; `from_checkpoint` requires each to hold `ids.len()` values.
pub columns: Vec<Vec<f32>>,
/// Column names, one per entry of `columns`.
pub column_names: Vec<&'static str>,
/// Name and precision of each column; `from_checkpoint` requires one entry per column, all `F32`.
pub schema: Vec<SoaColumnSchema>,
}
impl DeviceSoaCheckpoint {
    /// Number of agent rows stored in the checkpoint.
    pub fn agent_count(&self) -> usize {
        Vec::len(&self.ids)
    }

    /// Number of `f32` data columns stored in the checkpoint.
    pub fn num_columns(&self) -> usize {
        Vec::len(&self.columns)
    }

    /// Payload size in bytes: the id vector plus every column's elements.
    ///
    /// Spare vector capacity, column names and the schema are not counted.
    pub fn resident_bytes(&self) -> usize {
        let id_size = std::mem::size_of::<AgentId>();
        let value_size = std::mem::size_of::<f32>();
        let mut total = self.ids.len() * id_size;
        for column in &self.columns {
            total += column.len() * value_size;
        }
        total
    }
}
/// Floating-point precision of one SoA column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SoaColumnType {
/// Values are stored as `f32` (used by `DeviceSoaStore`).
F32,
/// Values are stored as `f64` (used by `DeviceSoaStoreF64`).
F64,
}
/// Name and precision of a single SoA column.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SoaColumnSchema {
/// Column name.
pub name: &'static str,
/// Precision of the column's values.
pub column_type: SoaColumnType,
}
impl SoaColumnSchema {
pub const fn f32(name: &'static str) -> Self {
Self {
name,
column_type: SoaColumnType::F32,
}
}
pub const fn f64(name: &'static str) -> Self {
Self {
name,
column_type: SoaColumnType::F64,
}
}
}
/// Owned copy of a `DeviceSoaStoreF64`'s host-side state.
///
/// Produced by `DeviceSoaStoreF64::checkpoint` and accepted by
/// `DeviceSoaStoreF64::from_checkpoint`, which validates it before use. Row
/// `r` of every column belongs to the agent `ids[r]`.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceSoaCheckpointF64 {
/// Agent id of each row, in row order.
pub ids: Vec<AgentId>,
/// One `f64` vector per column; `from_checkpoint` requires each to hold `ids.len()` values.
pub columns: Vec<Vec<f64>>,
/// Column names, one per entry of `columns`.
pub column_names: Vec<&'static str>,
/// Name and precision of each column; `from_checkpoint` requires one entry per column, all `F64`.
pub schema: Vec<SoaColumnSchema>,
}
impl DeviceSoaCheckpointF64 {
pub fn agent_count(&self) -> usize {
self.ids.len()
}
pub fn num_columns(&self) -> usize {
self.columns.len()
}
pub fn resident_bytes(&self) -> usize {
let ids_bytes = self.ids.len() * std::mem::size_of::<AgentId>();
let cols_bytes: usize = self
.columns
.iter()
.map(|c| c.len() * std::mem::size_of::<f64>())
.sum();
ids_bytes + cols_bytes
}
}
/// Reasons `from_checkpoint` rejects a checkpoint.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum DeviceSoaRestoreError {
/// `column_names` and `columns` have different lengths.
#[error("checkpoint has {columns} columns but {names} column names")]
ColumnNameCountMismatch { columns: usize, names: usize },
/// Column `column` holds `len` values but there are `agent_count` ids.
#[error("checkpoint column {column} has {len} rows but ids contain {agent_count} agents")]
ColumnLengthMismatch {
column: usize,
len: usize,
agent_count: usize,
},
/// The same id appears in more than one row.
#[error("checkpoint contains duplicate agent id {0}")]
DuplicateAgentId(AgentId),
/// `schema` has a different number of entries than `columns`.
#[error("checkpoint schema has {schema} columns but data contains {columns} columns")]
SchemaColumnCountMismatch { columns: usize, schema: usize },
/// Schema entry `column` has a precision other than the store's.
#[error("checkpoint column {column} has precision {actual:?}, expected {expected:?}")]
SchemaTypeMismatch {
column: usize,
expected: SoaColumnType,
actual: SoaColumnType,
},
}
/// Reasons a validated `try_scatter_*` mutation is rejected.
///
/// Validation runs before any mutation, so the store is unchanged when one
/// of these is returned.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum DeviceSoaMutationError {
/// The number of supplied columns differs from the store's column count.
#[error("insert supplied {actual} columns, expected {expected}")]
ColumnCountMismatch { expected: usize, actual: usize },
/// A supplied column's length differs from the number of new ids.
#[error("insert column {column} has {actual} rows, expected {expected}")]
ColumnLengthMismatch {
column: usize,
expected: usize,
actual: usize,
},
/// An id appears more than once among the ids to insert.
#[error("insert batch contains duplicate agent id {0}")]
DuplicateInsertedId(AgentId),
/// An id to insert is already present (and, for a replace, is not being removed).
#[error("agent id {0} already exists in authoritative SoA store")]
ExistingAgentId(AgentId),
/// An id appears more than once among the ids to remove.
#[error("remove command contains duplicate agent id {0}")]
DuplicateRemovedId(AgentId),
/// An id to remove is not present in the store.
#[error("agent id {0} does not exist in authoritative SoA store")]
MissingAgentId(AgentId),
}
/// Host-side structure-of-arrays mirror of an agent store, using `f32`
/// columns, optionally backed by persistent CUDA buffers (`cuda` feature).
///
/// Row `r` of every column belongs to `ids[r]`. Data comes in via `upload`
/// or `from_checkpoint`, is stepped by CPU closures or CUDA kernels, and is
/// written back with `download`.
#[derive(Debug)]
pub struct DeviceSoaStore {
/// Agent id of each row.
ids: Vec<AgentId>,
/// One `f32` vector per column, each kept at `agent_count` rows.
columns: Vec<Vec<f32>>,
/// Column names in column order.
column_names: Vec<&'static str>,
/// Name and precision per column (always `F32` for this store).
schema: Vec<SoaColumnSchema>,
/// Number of rows; methods that add or remove rows keep it equal to `ids.len()`.
agent_count: usize,
/// Set by mutating methods, cleared by `download`.
dirty: bool,
/// Resident CUDA state created by `init_cuda`, if any.
#[cfg(feature = "cuda")]
cuda_resident: Option<CudaResident>,
}
/// Persistent CUDA state created by `DeviceSoaStore::init_cuda`.
///
/// Used by the `*_cuda_resident` methods and by `sync_to_host`.
#[cfg(feature = "cuda")]
struct CudaResident {
/// Context the streams and buffers came from. It is never read (hence
/// `dead_code`); presumably it is held to keep the context alive — confirm.
#[allow(dead_code)]
ctx: std::sync::Arc<cudarc::driver::CudaContext>,
/// Stream used for host<->device transfers.
copy_stream: std::sync::Arc<cudarc::driver::CudaStream>,
/// Stream that kernels are launched on.
compute_stream: std::sync::Arc<cudarc::driver::CudaStream>,
/// Module loaded from the PTX; `run_sequence_cuda_resident` looks up additional kernels in it.
#[allow(dead_code)]
module: std::sync::Arc<cudarc::driver::CudaModule>,
/// Kernel looked up by name in `init_cuda`.
func: cudarc::driver::CudaFunction,
/// Name of `func`, used to skip reloading it in `run_sequence_cuda_resident`.
kernel_name: String,
/// Device copy of each host column.
d_cols: Vec<cudarc::driver::CudaSlice<f32>>,
/// Pinned host staging buffer per column, allocated alongside the matching `d_cols` entry.
pinned: Vec<cudarc::driver::PinnedHostSlice<f32>>,
/// Set after resident kernel launches; cleared once `sync_to_host` copied results back.
device_dirty: bool,
/// Set by `mark_host_dirty`; the next resident step re-uploads the host columns and clears it.
host_dirty: bool,
}
#[cfg(feature = "cuda")]
impl std::fmt::Debug for CudaResident {
    /// Prints only the bookkeeping fields; CUDA handles and buffers are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let column_count = self.d_cols.len();
        let mut out = f.debug_struct("CudaResident");
        out.field("kernel_name", &self.kernel_name);
        out.field("num_columns", &column_count);
        out.field("device_dirty", &self.device_dirty);
        out.field("host_dirty", &self.host_dirty);
        out.finish()
    }
}
impl DeviceSoaStore {
/// Extracts every agent in `store` into host-side `f32` columns.
///
/// Column names come from `A::column_names()` and every schema entry is
/// `SoaColumnType::F32`. The new store starts clean (`is_dirty() == false`)
/// and, with the `cuda` feature, without resident CUDA state.
pub fn upload<A, S>(store: &S) -> Self
where
    A: SoaExtractable,
    S: AgentStore<A>,
{
    let (ids, columns) = soa::extract_soa::<A, S>(store);
    let column_names = A::column_names();
    let schema = schema_from_names(&column_names, SoaColumnType::F32);
    Self {
        agent_count: ids.len(),
        ids,
        columns,
        column_names,
        schema,
        dirty: false,
        #[cfg(feature = "cuda")]
        cuda_resident: None,
    }
}
pub fn from_checkpoint(checkpoint: DeviceSoaCheckpoint) -> Result<Self, DeviceSoaRestoreError> {
let agent_count = checkpoint.ids.len();
if checkpoint.column_names.len() != checkpoint.columns.len() {
return Err(DeviceSoaRestoreError::ColumnNameCountMismatch {
columns: checkpoint.columns.len(),
names: checkpoint.column_names.len(),
});
}
validate_schema(
&checkpoint.schema,
checkpoint.columns.len(),
SoaColumnType::F32,
)?;
validate_column_lengths(&checkpoint.columns, agent_count)?;
validate_unique_ids(&checkpoint.ids)?;
Ok(Self {
ids: checkpoint.ids,
columns: checkpoint.columns,
column_names: checkpoint.column_names,
agent_count,
dirty: false,
schema: checkpoint.schema,
#[cfg(feature = "cuda")]
cuda_resident: None,
})
}
/// Clones the host-side state into an owned [`DeviceSoaCheckpoint`].
///
/// Only the host columns are captured. With the `cuda` feature, results that
/// are still on the device are not included unless `sync_to_host` ran first.
pub fn checkpoint(&self) -> DeviceSoaCheckpoint {
    let ids = self.ids.clone();
    let columns = self.columns.clone();
    let column_names = self.column_names.clone();
    let schema = self.schema.clone();
    DeviceSoaCheckpoint {
        ids,
        columns,
        column_names,
        schema,
    }
}

/// Number of live agents (rows).
pub fn agent_count(&self) -> usize {
    self.agent_count
}

/// Number of `f32` data columns.
pub fn num_columns(&self) -> usize {
    Vec::len(&self.columns)
}

/// Bytes occupied by the host ids plus every host column's elements.
///
/// Spare capacity, column names and the schema are not counted.
pub fn resident_bytes(&self) -> usize {
    let id_bytes = std::mem::size_of_val(self.ids.as_slice());
    let column_bytes = self
        .columns
        .iter()
        .map(|column| std::mem::size_of_val(column.as_slice()))
        .sum::<usize>();
    id_bytes + column_bytes
}

/// Column names in column order.
pub fn column_names(&self) -> &[&'static str] {
    self.column_names.as_slice()
}

/// Name and precision of each column, in column order.
pub fn schema(&self) -> &[SoaColumnSchema] {
    self.schema.as_slice()
}

/// Read-only view of column `index`.
///
/// # Panics
/// Panics if `index >= self.num_columns()`.
pub fn column(&self, index: usize) -> &[f32] {
    self.columns[index].as_slice()
}

/// Mutable access to all columns; marks the store dirty unconditionally.
///
/// Callers must keep every column at `agent_count()` rows, because the row
/// count is not re-derived from the column lengths.
pub fn columns_mut(&mut self) -> &mut [Vec<f32>] {
    self.dirty = true;
    self.columns.as_mut_slice()
}

/// Agent id of each row, in row order.
pub fn ids(&self) -> &[AgentId] {
    self.ids.as_slice()
}

/// Row index holding `id`, found by a linear scan.
pub fn row_of(&self, id: AgentId) -> Option<usize> {
    self.ids.iter().position(|&candidate| candidate == id)
}

/// Whether some row holds `id`.
pub fn contains_id(&self, id: AgentId) -> bool {
    self.ids.contains(&id)
}

/// Smallest id, counting up from zero, that no row currently uses.
pub fn next_available_id(&self) -> AgentId {
    next_available_id(self.ids.as_slice())
}

/// Whether the host columns may differ from the agent store. Mutating
/// methods set the flag and `download` clears it.
pub fn is_dirty(&self) -> bool {
    self.dirty
}
/// Runs one CPU kernel over all columns and marks the store dirty.
///
/// `kernel` receives the columns and the number of live rows. Returns the
/// elapsed wall-clock time in microseconds.
pub fn step_cpu<F>(&mut self, mut kernel: F) -> u128
where
    F: FnMut(&mut [Vec<f32>], usize),
{
    let started = std::time::Instant::now();
    let rows = self.agent_count;
    kernel(self.columns.as_mut_slice(), rows);
    self.dirty = true;
    started.elapsed().as_micros()
}

/// Runs named CPU kernels in iteration order and times each of them.
///
/// Returns `(name, microseconds)` per kernel. The store is marked dirty only
/// when at least one kernel ran.
pub fn run_sequence_cpu<'a, I, K>(&mut self, kernels: I) -> Vec<(&'static str, u128)>
where
    I: IntoIterator<Item = (&'static str, K)>,
    K: FnOnce(&mut [Vec<f32>], usize) + 'a,
{
    let rows = self.agent_count;
    let columns = self.columns.as_mut_slice();
    let timings: Vec<(&'static str, u128)> = kernels
        .into_iter()
        .map(|(name, kernel)| {
            let started = std::time::Instant::now();
            kernel(&mut *columns, rows);
            (name, started.elapsed().as_micros())
        })
        .collect();
    self.dirty |= !timings.is_empty();
    timings
}

/// Writes the host columns back into `store` row by row and clears the
/// dirty flag.
pub fn download<A, S>(&mut self, store: &S)
where
    A: SoaExtractable,
    S: AgentStore<A>,
{
    soa::write_back_soa::<A, S>(store, &self.ids, &self.columns);
    self.dirty = false;
}
/// Runs `kernel_name` once on a freshly created CUDA context and copies the
/// results back into the host columns.
///
/// Each call does the following:
/// 1. Creates a context via `crate::cuda_context::new_context(0)` and loads
///    `ptx_source`.
/// 2. Uploads every column.
/// 3. Launches the kernel with one device buffer per column, followed by the
///    agent count as a `u32`.
/// 4. Synchronizes the stream and downloads every column.
///
/// Buffers created by [`Self::init_cuda`] are not touched. For repeated
/// stepping, prefer the resident API.
///
/// `_module_name` is accepted for API compatibility and ignored.
///
/// Returns the microseconds from just before the launch until the stream has
/// synchronized; upload and download are excluded. With no agents it
/// returns `Ok(0)` without initializing CUDA.
///
/// # Errors
/// Returns a message when `block_size` is zero, when the agent count does
/// not fit in a `u32`, or when any CUDA/PTX operation fails.
#[cfg(feature = "cuda")]
pub fn step_cuda(
    &mut self,
    ptx_source: &str,
    _module_name: &str,
    kernel_name: &str,
    block_size: u32,
) -> Result<u128, String> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    if block_size == 0 {
        return Err("block_size must be positive".to_string());
    }
    let n = self.agent_count;
    if n == 0 {
        return Ok(0);
    }
    // The kernel receives the count as a u32. An unchecked `as` cast would
    // silently truncate it and leave the tail of every column unprocessed.
    if n > u32::MAX as usize {
        return Err(format!("agent count {n} does not fit in u32"));
    }
    let n_u32 = n as u32;
    let ctx = crate::cuda_context::new_context(0)?;
    let stream = ctx.default_stream();
    let ptx = cudarc::nvrtc::Ptx::from_src(ptx_source);
    let module = ctx.load_module(ptx).map_err(|e| format!("PTX load: {e}"))?;
    let func = module
        .load_function(kernel_name)
        .map_err(|e| format!("kernel '{kernel_name}' not found: {e}"))?;
    let mut d_cols = Vec::with_capacity(self.columns.len());
    for col in &self.columns {
        let d = stream
            .clone_htod(col.as_slice())
            .map_err(|e| format!("htod: {e}"))?;
        d_cols.push(d);
    }
    // n <= u32::MAX and block_size >= 1, so the grid size also fits in u32.
    let grid_size = n.div_ceil(block_size as usize) as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: 0,
    };
    let t0 = std::time::Instant::now();
    // SAFETY: every device buffer was uploaded from a host column of `n`
    // values. The kernel is assumed to take one `float*` per column followed
    // by an unsigned 32-bit count; that signature cannot be checked here.
    unsafe {
        let mut builder = stream.launch_builder(&func);
        for d in d_cols.iter_mut() {
            builder.arg(d);
        }
        builder.arg(&n_u32);
        builder.launch(cfg).map_err(|e| format!("launch: {e}"))?;
    }
    stream.synchronize().map_err(|e| format!("sync: {e}"))?;
    let kernel_us = t0.elapsed().as_micros();
    for (i, d_col) in d_cols.iter().enumerate() {
        stream
            .memcpy_dtoh(d_col, &mut self.columns[i])
            .map_err(|e| format!("dtoh: {e}"))?;
    }
    self.dirty = true;
    Ok(kernel_us)
}
/// Like [`Self::step_cuda`], but stages data through page-locked (pinned)
/// host buffers and uses separate copy and compute streams.
///
/// Each call creates a context via `crate::cuda_context::new_context(0)` and
/// two streams, loads `ptx_source`, and allocates one pinned buffer per
/// column. It then:
/// 1. Uploads the columns on the copy stream; the compute stream waits for
///    that copy.
/// 2. Launches `kernel_name` with one device buffer per column, followed by
///    the agent count as `u32`; the copy stream waits for the kernel.
/// 3. Reads every column back into the pinned buffers and then into the host
///    columns.
///
/// Buffers created by [`Self::init_cuda`] are not touched.
///
/// `_module_name` is accepted for API compatibility and ignored.
///
/// Returns the microseconds from just before the launch until the
/// device-to-host copies have completed, so the figure includes readback.
/// With no agents it returns `Ok(0)` without initializing CUDA.
///
/// # Errors
/// Returns a message when `block_size` is zero, when the agent count does
/// not fit in a `u32`, or when any CUDA/PTX operation fails.
#[cfg(feature = "cuda")]
pub fn step_cuda_pinned(
    &mut self,
    ptx_source: &str,
    _module_name: &str,
    kernel_name: &str,
    block_size: u32,
) -> Result<u128, String> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    if block_size == 0 {
        return Err("block_size must be positive".to_string());
    }
    let n = self.agent_count;
    if n == 0 {
        return Ok(0);
    }
    // The kernel receives the count as a u32. An unchecked `as` cast would
    // silently truncate it and leave the tail of every column unprocessed.
    if n > u32::MAX as usize {
        return Err(format!("agent count {n} does not fit in u32"));
    }
    let n_u32 = n as u32;
    let ctx = crate::cuda_context::new_context(0)?;
    let copy_stream = ctx
        .new_stream()
        .map_err(|e| format!("copy stream init: {e}"))?;
    let compute_stream = ctx
        .new_stream()
        .map_err(|e| format!("compute stream init: {e}"))?;
    let ptx = cudarc::nvrtc::Ptx::from_src(ptx_source);
    let module = ctx.load_module(ptx).map_err(|e| format!("PTX load: {e}"))?;
    let func = module
        .load_function(kernel_name)
        .map_err(|e| format!("kernel '{kernel_name}' not found: {e}"))?;
    let mut pinned: Vec<cudarc::driver::PinnedHostSlice<f32>> =
        Vec::with_capacity(self.columns.len());
    for col in &self.columns {
        // SAFETY: the allocation is presumably uninitialized. It is fully
        // overwritten by `copy_from_slice` below before anything reads it.
        let mut p = unsafe { ctx.alloc_pinned::<f32>(col.len()) }
            .map_err(|e| format!("pinned alloc: {e}"))?;
        p.as_mut_slice()
            .map_err(|e| format!("pinned access: {e}"))?
            .copy_from_slice(col);
        pinned.push(p);
    }
    let mut d_cols: Vec<cudarc::driver::CudaSlice<f32>> = Vec::with_capacity(pinned.len());
    for p in &pinned {
        let d = copy_stream
            .clone_htod(p)
            .map_err(|e| format!("htod: {e}"))?;
        d_cols.push(d);
    }
    // Kernel work must not start before the uploads on the copy stream finish.
    compute_stream
        .join(&copy_stream)
        .map_err(|e| format!("compute.join(copy): {e}"))?;
    // n <= u32::MAX and block_size >= 1, so the grid size also fits in u32.
    let grid_size = n.div_ceil(block_size as usize) as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: 0,
    };
    let t0 = std::time::Instant::now();
    // SAFETY: every device buffer holds `n` values copied from a host column.
    // The kernel is assumed to take one `float*` per column followed by an
    // unsigned 32-bit count; that signature cannot be checked here.
    unsafe {
        let mut builder = compute_stream.launch_builder(&func);
        for d in d_cols.iter_mut() {
            builder.arg(d);
        }
        builder.arg(&n_u32);
        builder.launch(cfg).map_err(|e| format!("launch: {e}"))?;
    }
    // Readback on the copy stream must wait for the kernel.
    copy_stream
        .join(&compute_stream)
        .map_err(|e| format!("copy.join(compute): {e}"))?;
    for (i, d_col) in d_cols.iter().enumerate() {
        copy_stream
            .memcpy_dtoh(d_col, &mut pinned[i])
            .map_err(|e| format!("dtoh: {e}"))?;
    }
    copy_stream
        .synchronize()
        .map_err(|e| format!("sync: {e}"))?;
    let kernel_us = t0.elapsed().as_micros();
    for (i, p) in pinned.iter().enumerate() {
        self.columns[i]
            .copy_from_slice(p.as_slice().map_err(|e| format!("pinned readback: {e}"))?);
    }
    self.dirty = true;
    Ok(kernel_us)
}
/// Creates the persistent CUDA state used by `step_cuda_resident`,
/// `run_sequence_cuda_resident` and `sync_to_host`.
///
/// The setup creates a context via `crate::cuda_context::new_context(0)`,
/// separate copy and compute streams, loads `ptx_source` and looks up
/// `kernel_name`. It then stages each host column into a newly allocated
/// pinned buffer, copies it to a device buffer of the same length and
/// synchronizes the copy stream. Any previously resident state is replaced
/// and dropped.
///
/// The buffers are sized for the rows and columns present at this call. If
/// the row count later changes (e.g. via `scatter_*`), call this again.
///
/// # Errors
/// Returns a message if any CUDA/PTX step fails. `self.cuda_resident` is
/// only assigned after every step succeeded, so a failure leaves it unchanged.
#[cfg(feature = "cuda")]
pub fn init_cuda(&mut self, ptx_source: &str, kernel_name: &str) -> Result<(), String> {
let ctx = crate::cuda_context::new_context(0)?;
let copy_stream = ctx
.new_stream()
.map_err(|e| format!("copy stream init: {e}"))?;
let compute_stream = ctx
.new_stream()
.map_err(|e| format!("compute stream init: {e}"))?;
let ptx = cudarc::nvrtc::Ptx::from_src(ptx_source);
let module = ctx.load_module(ptx).map_err(|e| format!("PTX load: {e}"))?;
let func = module
.load_function(kernel_name)
.map_err(|e| format!("kernel '{kernel_name}' not found: {e}"))?;
let mut pinned: Vec<cudarc::driver::PinnedHostSlice<f32>> =
Vec::with_capacity(self.columns.len());
let mut d_cols: Vec<cudarc::driver::CudaSlice<f32>> =
Vec::with_capacity(self.columns.len());
// Pinned and device buffers are created in lockstep, one pair per column.
for col in &self.columns {
// SAFETY: the pinned allocation is presumably uninitialized; it is fully
// overwritten by `copy_from_slice` below before it is read or uploaded.
let mut p = unsafe { ctx.alloc_pinned::<f32>(col.len()) }
.map_err(|e| format!("pinned alloc: {e}"))?;
p.as_mut_slice()
.map_err(|e| format!("pinned access: {e}"))?
.copy_from_slice(col);
let d = copy_stream
.clone_htod(&p)
.map_err(|e| format!("htod: {e}"))?;
pinned.push(p);
d_cols.push(d);
}
// Wait for all uploads so the device buffers match the host before returning.
copy_stream
.synchronize()
.map_err(|e| format!("sync: {e}"))?;
self.cuda_resident = Some(CudaResident {
ctx,
copy_stream,
compute_stream,
module,
func,
kernel_name: kernel_name.to_string(),
d_cols,
pinned,
device_dirty: false,
host_dirty: false,
});
Ok(())
}
/// Launches the kernel selected in [`Self::init_cuda`] on the resident device
/// buffers without copying results back to the host.
///
/// If [`Self::mark_host_dirty`] was called since the last upload, the host
/// columns are first staged through the pinned buffers and copied to the
/// device on the copy stream, and the compute stream waits for that copy.
/// No synchronization happens before the timer is read, so the returned
/// microseconds measure the launch call rather than kernel execution (CUDA
/// launches are asynchronous). Call [`Self::sync_to_host`] to bring results
/// back into the host columns.
///
/// Returns `Ok(0)` when the store is empty.
///
/// # Errors
/// Returns a message in any of these cases:
/// - `block_size` is zero;
/// - `init_cuda` has not been called;
/// - the resident buffers no longer match the store's shape (e.g. after
///   `scatter_*` changed the row count);
/// - the agent count does not fit in a `u32`;
/// - a CUDA operation fails.
#[cfg(feature = "cuda")]
pub fn step_cuda_resident(&mut self, block_size: u32) -> Result<u128, String> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    if block_size == 0 {
        return Err("block_size must be positive".to_string());
    }
    let n = self.agent_count;
    if n == 0 {
        return Ok(0);
    }
    let resident = self
        .cuda_resident
        .as_mut()
        .ok_or_else(|| "init_cuda must be called before step_cuda_resident".to_string())?;
    // `init_cuda` allocated the pinned and device buffers in lockstep, sized
    // for the rows present at that time. Scatter operations change the row
    // count without resizing them. Launching with a larger `n` would let the
    // kernel index past the device allocations, so refuse instead.
    if resident.pinned.len() != self.columns.len() {
        return Err(format!(
            "resident CUDA buffers hold {} columns but the store has {}; call init_cuda again",
            resident.pinned.len(),
            self.columns.len()
        ));
    }
    for staged in &resident.pinned {
        let rows = staged
            .as_slice()
            .map_err(|e| format!("pinned access: {e}"))?
            .len();
        if rows != n {
            return Err(format!(
                "resident CUDA buffers hold {rows} rows but the store has {n} agents; call init_cuda again"
            ));
        }
    }
    // The kernel receives the count as a u32; never truncate it silently.
    if n > u32::MAX as usize {
        return Err(format!("agent count {n} does not fit in u32"));
    }
    let n_u32 = n as u32;
    if resident.host_dirty {
        for (i, col) in self.columns.iter().enumerate() {
            resident.pinned[i]
                .as_mut_slice()
                .map_err(|e| format!("pinned access: {e}"))?
                .copy_from_slice(col);
            resident
                .copy_stream
                .memcpy_htod(&resident.pinned[i], &mut resident.d_cols[i])
                .map_err(|e| format!("htod: {e}"))?;
        }
        // The kernel must see the freshly uploaded data.
        resident
            .compute_stream
            .join(&resident.copy_stream)
            .map_err(|e| format!("compute.join(copy): {e}"))?;
        resident.host_dirty = false;
    }
    // n <= u32::MAX and block_size >= 1, so the grid size also fits in u32.
    let grid_size = n.div_ceil(block_size as usize) as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: 0,
    };
    let t0 = std::time::Instant::now();
    // SAFETY: the checks above guarantee every device buffer holds exactly
    // `n` values. The kernel is assumed to take one `float*` per column
    // followed by an unsigned 32-bit count; that signature cannot be checked here.
    unsafe {
        let mut builder = resident.compute_stream.launch_builder(&resident.func);
        for d in resident.d_cols.iter_mut() {
            builder.arg(d);
        }
        builder.arg(&n_u32);
        builder.launch(cfg).map_err(|e| format!("launch: {e}"))?;
    }
    let kernel_us = t0.elapsed().as_micros();
    resident.device_dirty = true;
    self.dirty = true;
    Ok(kernel_us)
}
/// Launches each named kernel, in order, on the resident device buffers.
///
/// A name equal to the kernel chosen in [`Self::init_cuda`] reuses that
/// function. Any other name is looked up in the resident module. All names
/// are resolved before anything is launched, so an unknown name fails the
/// whole call without side effects.
///
/// Pending host changes (see [`Self::mark_host_dirty`]) are uploaded first,
/// exactly as in [`Self::step_cuda_resident`]. Each returned timing covers
/// only the launch call, because no synchronization happens between
/// launches. The device and store are marked dirty after every successful
/// launch, so a later failure still leaves `sync_to_host` able to read back
/// the kernels that did run. An empty `kernels` slice marks nothing dirty.
///
/// # Errors
/// Returns a message in any of these cases:
/// - `block_size` is zero;
/// - `init_cuda` has not been called;
/// - the resident buffers no longer match the store's shape;
/// - the agent count does not fit in a `u32`;
/// - a kernel name is unknown;
/// - a CUDA operation fails.
#[cfg(feature = "cuda")]
pub fn run_sequence_cuda_resident(
    &mut self,
    kernels: &[&str],
    block_size: u32,
) -> Result<Vec<(String, u128)>, String> {
    use cudarc::driver::{LaunchConfig, PushKernelArg};
    if block_size == 0 {
        return Err("block_size must be positive".to_string());
    }
    let n = self.agent_count;
    if n == 0 {
        return Ok(Vec::new());
    }
    let resident = self.cuda_resident.as_mut().ok_or_else(|| {
        "init_cuda must be called before run_sequence_cuda_resident".to_string()
    })?;
    // `init_cuda` allocated the pinned and device buffers in lockstep, sized
    // for the rows present at that time. Scatter operations change the row
    // count without resizing them. Launching with a larger `n` would let the
    // kernels index past the device allocations, so refuse instead.
    if resident.pinned.len() != self.columns.len() {
        return Err(format!(
            "resident CUDA buffers hold {} columns but the store has {}; call init_cuda again",
            resident.pinned.len(),
            self.columns.len()
        ));
    }
    for staged in &resident.pinned {
        let rows = staged
            .as_slice()
            .map_err(|e| format!("pinned access: {e}"))?
            .len();
        if rows != n {
            return Err(format!(
                "resident CUDA buffers hold {rows} rows but the store has {n} agents; call init_cuda again"
            ));
        }
    }
    // The kernels receive the count as a u32; never truncate it silently.
    if n > u32::MAX as usize {
        return Err(format!("agent count {n} does not fit in u32"));
    }
    let n_u32 = n as u32;
    // Resolve every kernel up front so an unknown name cannot abort the
    // sequence after earlier kernels already modified device data.
    let funcs = kernels
        .iter()
        .map(|name| {
            if *name == resident.kernel_name {
                Ok(resident.func.clone())
            } else {
                resident
                    .module
                    .load_function(name)
                    .map_err(|e| format!("kernel '{name}' not found: {e}"))
            }
        })
        .collect::<Result<Vec<_>, String>>()?;
    if resident.host_dirty {
        for (i, col) in self.columns.iter().enumerate() {
            resident.pinned[i]
                .as_mut_slice()
                .map_err(|e| format!("pinned access: {e}"))?
                .copy_from_slice(col);
            resident
                .copy_stream
                .memcpy_htod(&resident.pinned[i], &mut resident.d_cols[i])
                .map_err(|e| format!("htod: {e}"))?;
        }
        // Kernels must see the freshly uploaded data.
        resident
            .compute_stream
            .join(&resident.copy_stream)
            .map_err(|e| format!("compute.join(copy): {e}"))?;
        resident.host_dirty = false;
    }
    // n <= u32::MAX and block_size >= 1, so the grid size also fits in u32.
    let grid_size = n.div_ceil(block_size as usize) as u32;
    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (block_size, 1, 1),
        shared_mem_bytes: 0,
    };
    let mut timings = Vec::with_capacity(kernels.len());
    for (name, func) in kernels.iter().zip(&funcs) {
        let t0 = std::time::Instant::now();
        // SAFETY: the checks above guarantee every device buffer holds
        // exactly `n` values. Each kernel is assumed to take one `float*` per
        // column followed by an unsigned 32-bit count; that signature cannot
        // be checked here.
        unsafe {
            let mut builder = resident.compute_stream.launch_builder(func);
            for d in resident.d_cols.iter_mut() {
                builder.arg(d);
            }
            builder.arg(&n_u32);
            builder.launch(cfg).map_err(|e| format!("launch: {e}"))?;
        }
        resident.device_dirty = true;
        self.dirty = true;
        timings.push((name.to_string(), t0.elapsed().as_micros()));
    }
    Ok(timings)
}
/// Copies the resident device columns back into the host columns if a
/// resident kernel ran since the last readback.
///
/// The method waits for pending work on the compute stream, copies every
/// device column into its pinned staging buffer, synchronizes the copy
/// stream and then overwrites the host columns. It does nothing (and
/// succeeds) when no resident state exists or the device has no unsynced
/// changes.
///
/// # Errors
/// Returns a message if a CUDA operation fails. It also returns an error if
/// the host columns no longer have the shape the resident buffers were
/// created with; in that case no host column is modified and the device
/// stays marked dirty.
#[cfg(feature = "cuda")]
pub fn sync_to_host(&mut self) -> Result<(), String> {
    let resident = match self.cuda_resident.as_mut() {
        Some(r) => r,
        None => return Ok(()),
    };
    if !resident.device_dirty {
        return Ok(());
    }
    // The resident buffers keep the shape they had at `init_cuda`. If the
    // host columns were resized since (e.g. by `scatter_*`), copying back
    // would index past `self.columns`, or panic in `copy_from_slice` after
    // some columns were already overwritten. Reject before touching anything.
    if resident.pinned.len() != self.columns.len() {
        return Err(format!(
            "resident CUDA buffers hold {} columns but the store has {}; call init_cuda again",
            resident.pinned.len(),
            self.columns.len()
        ));
    }
    for (staged, host) in resident.pinned.iter().zip(&self.columns) {
        let rows = staged
            .as_slice()
            .map_err(|e| format!("pinned access: {e}"))?
            .len();
        if rows != host.len() {
            return Err(format!(
                "resident CUDA buffers hold {rows} rows but the host columns have {}; call init_cuda again",
                host.len()
            ));
        }
    }
    // Readback must wait for all kernels queued on the compute stream.
    resident
        .copy_stream
        .join(&resident.compute_stream)
        .map_err(|e| format!("copy.join(compute): {e}"))?;
    for (i, d_col) in resident.d_cols.iter().enumerate() {
        resident
            .copy_stream
            .memcpy_dtoh(d_col, &mut resident.pinned[i])
            .map_err(|e| format!("dtoh: {e}"))?;
    }
    resident
        .copy_stream
        .synchronize()
        .map_err(|e| format!("sync: {e}"))?;
    for (host, staged) in self.columns.iter_mut().zip(&resident.pinned) {
        host.copy_from_slice(
            staged
                .as_slice()
                .map_err(|e| format!("pinned readback: {e}"))?,
        );
    }
    resident.device_dirty = false;
    Ok(())
}
/// Flags the host columns as newer than the resident device buffers so the
/// next resident step re-uploads them. Does nothing without resident state.
#[cfg(feature = "cuda")]
pub fn mark_host_dirty(&mut self) {
    if let Some(resident) = &mut self.cuda_resident {
        resident.host_dirty = true;
    }
}
/// Drops the resident CUDA state, releasing this store's handles to its
/// streams, module and buffers.
#[cfg(feature = "cuda")]
pub fn release_cuda(&mut self) {
    drop(self.cuda_resident.take());
}
/// Whether resident CUDA state from `init_cuda` is currently held.
#[cfg(feature = "cuda")]
pub fn has_cuda_resident(&self) -> bool {
    matches!(self.cuda_resident, Some(_))
}
/// Removes every row whose id appears in `dead_ids`, compacting the
/// surviving rows in place while keeping their relative order.
///
/// Unknown ids are ignored and repeated ids are harmless. An empty
/// `dead_ids` is a no-op; otherwise the store is marked dirty even if no
/// row matched. With the `cuda` feature, resident buffers are not resized.
/// Use `try_scatter_remove` for validated removal.
pub fn scatter_remove(&mut self, dead_ids: &[AgentId]) {
    if dead_ids.is_empty() {
        return;
    }
    let doomed: std::collections::HashSet<AgentId> = dead_ids.iter().copied().collect();
    let mut kept = 0usize;
    for row in 0..self.agent_count {
        let id = self.ids[row];
        if doomed.contains(&id) {
            continue;
        }
        if kept != row {
            self.ids[kept] = id;
            for column in self.columns.iter_mut() {
                column[kept] = column[row];
            }
        }
        kept += 1;
    }
    self.ids.truncate(kept);
    self.columns
        .iter_mut()
        .for_each(|column| column.truncate(kept));
    self.agent_count = kept;
    self.dirty = true;
}
/// Validated form of `scatter_remove`.
///
/// # Errors
/// Returns `DuplicateRemovedId` or `MissingAgentId`, leaving the store
/// untouched, if `dead_ids` repeats an id or names an id that is not present.
pub fn try_scatter_remove(
    &mut self,
    dead_ids: &[AgentId],
) -> Result<(), DeviceSoaMutationError> {
    validate_remove_ids(dead_ids, &self.ids).map(|()| self.scatter_remove(dead_ids))
}
/// Appends rows like `try_scatter_insert`, but panics on invalid input.
///
/// # Panics
/// Panics with "valid SoA scatter insert" and the validation error if the
/// insert is rejected.
pub fn scatter_insert(&mut self, new_ids: &[AgentId], new_columns: &[&[f32]]) {
    let outcome = self.try_scatter_insert(new_ids, new_columns);
    outcome.expect("valid SoA scatter insert");
}
/// Appends one row per id in `new_ids`; `new_columns[c][k]` becomes the
/// value of column `c` for `new_ids[k]`.
///
/// An empty insert is a no-op and does not mark the store dirty.
///
/// # Errors
/// Returns an error without modifying the store when the column count or
/// any column length is wrong, an id repeats, or an id already exists.
pub fn try_scatter_insert(
    &mut self,
    new_ids: &[AgentId],
    new_columns: &[&[f32]],
) -> Result<(), DeviceSoaMutationError> {
    validate_insert_shape(new_ids, new_columns, self.columns.len())?;
    validate_insert_ids(new_ids, &self.ids)?;
    if new_ids.is_empty() {
        return Ok(());
    }
    append_rows(&mut self.ids, &mut self.columns, new_ids, new_columns);
    self.agent_count += new_ids.len();
    self.dirty = true;
    Ok(())
}
/// Removes `dead_ids` and then appends `new_ids` as one validated operation.
///
/// An id may be removed and re-inserted in the same call. The store is
/// marked dirty unless both id lists are empty.
///
/// # Errors
/// All checks run before any mutation, so on error the store is unchanged.
pub fn try_scatter_replace(
    &mut self,
    dead_ids: &[AgentId],
    new_ids: &[AgentId],
    new_columns: &[&[f32]],
) -> Result<(), DeviceSoaMutationError> {
    validate_remove_ids(dead_ids, &self.ids)?;
    validate_insert_shape(new_ids, new_columns, self.columns.len())?;
    validate_insert_ids_after_removal(new_ids, &self.ids, dead_ids)?;
    self.scatter_remove(dead_ids);
    append_rows(&mut self.ids, &mut self.columns, new_ids, new_columns);
    self.agent_count = self.ids.len();
    if !dead_ids.is_empty() || !new_ids.is_empty() {
        self.dirty = true;
    }
    Ok(())
}
}
/// Host-side structure-of-arrays mirror of an agent store with `f64`
/// columns (CPU only).
///
/// Row `r` of every column belongs to `ids[r]`.
#[derive(Debug)]
pub struct DeviceSoaStoreF64 {
/// Agent id of each row.
ids: Vec<AgentId>,
/// One `f64` vector per column, each kept at `agent_count` rows.
columns: Vec<Vec<f64>>,
/// Column names in column order.
column_names: Vec<&'static str>,
/// Name and precision per column (always `F64` for this store).
schema: Vec<SoaColumnSchema>,
/// Number of rows; methods that add or remove rows keep it equal to `ids.len()`.
agent_count: usize,
/// Set by mutating methods, cleared by `download`.
dirty: bool,
}
impl DeviceSoaStoreF64 {
/// Extracts every agent in `store` into host-side `f64` columns.
///
/// Column names come from `SoaExtractableF64::column_names` and every schema
/// entry is `SoaColumnType::F64`. The new store starts clean.
pub fn upload<A, S>(store: &S) -> Self
where
    A: SoaExtractableF64,
    S: AgentStore<A>,
{
    let (ids, columns) = soa::extract_soa_f64::<A, S>(store);
    let column_names = <A as SoaExtractableF64>::column_names();
    Self {
        schema: schema_from_names(&column_names, SoaColumnType::F64),
        agent_count: ids.len(),
        ids,
        columns,
        column_names,
        dirty: false,
    }
}
pub fn from_checkpoint(
checkpoint: DeviceSoaCheckpointF64,
) -> Result<Self, DeviceSoaRestoreError> {
let agent_count = checkpoint.ids.len();
if checkpoint.column_names.len() != checkpoint.columns.len() {
return Err(DeviceSoaRestoreError::ColumnNameCountMismatch {
columns: checkpoint.columns.len(),
names: checkpoint.column_names.len(),
});
}
validate_schema(
&checkpoint.schema,
checkpoint.columns.len(),
SoaColumnType::F64,
)?;
validate_column_lengths(&checkpoint.columns, agent_count)?;
validate_unique_ids(&checkpoint.ids)?;
Ok(Self {
ids: checkpoint.ids,
columns: checkpoint.columns,
column_names: checkpoint.column_names,
schema: checkpoint.schema,
agent_count,
dirty: false,
})
}
/// Snapshot of the host-side state as an owned [`DeviceSoaCheckpointF64`].
pub fn checkpoint(&self) -> DeviceSoaCheckpointF64 {
    DeviceSoaCheckpointF64 {
        ids: self.ids.to_vec(),
        columns: self.columns.to_vec(),
        column_names: self.column_names.to_vec(),
        schema: self.schema.to_vec(),
    }
}

/// Number of live agents (rows).
pub fn agent_count(&self) -> usize {
    self.agent_count
}

/// Number of `f64` data columns.
pub fn num_columns(&self) -> usize {
    self.columns.as_slice().len()
}

/// Bytes used by the ids plus all column values (capacity, names and
/// schema excluded).
pub fn resident_bytes(&self) -> usize {
    let value_size = std::mem::size_of::<f64>();
    let values = self
        .columns
        .iter()
        .fold(0usize, |acc, column| acc + column.len() * value_size);
    self.ids.len() * std::mem::size_of::<AgentId>() + values
}

/// Column names in column order.
pub fn column_names(&self) -> &[&'static str] {
    &self.column_names[..]
}

/// Name and precision of each column, in column order.
pub fn schema(&self) -> &[SoaColumnSchema] {
    &self.schema[..]
}

/// Read-only view of column `index`.
///
/// # Panics
/// Panics if `index >= self.num_columns()`.
pub fn column(&self, index: usize) -> &[f64] {
    &self.columns[index][..]
}

/// Mutable access to all columns; marks the store dirty unconditionally.
/// Callers must keep every column at `agent_count()` rows.
pub fn columns_mut(&mut self) -> &mut [Vec<f64>] {
    self.dirty = true;
    &mut self.columns[..]
}

/// Agent id of each row, in row order.
pub fn ids(&self) -> &[AgentId] {
    &self.ids[..]
}

/// Row index holding `id`, found by a linear scan.
pub fn row_of(&self, id: AgentId) -> Option<usize> {
    self.ids.iter().position(|existing| existing == &id)
}

/// Whether some row holds `id`.
pub fn contains_id(&self, id: AgentId) -> bool {
    self.ids.iter().any(|&known| known == id)
}

/// Smallest id, counting up from zero, that no row currently uses.
pub fn next_available_id(&self) -> AgentId {
    next_available_id(&self.ids[..])
}

/// Whether the host columns may differ from the agent store (set by
/// mutating methods, cleared by `download`).
pub fn is_dirty(&self) -> bool {
    self.dirty
}
/// Runs one CPU kernel over all columns and marks the store dirty.
///
/// `kernel` receives the columns and the number of live rows. Returns the
/// elapsed wall-clock time in microseconds.
pub fn step_cpu<F>(&mut self, mut kernel: F) -> u128
where
    F: FnMut(&mut [Vec<f64>], usize),
{
    let clock = std::time::Instant::now();
    let rows = self.agent_count;
    kernel(&mut self.columns[..], rows);
    self.dirty = true;
    clock.elapsed().as_micros()
}

/// Runs named CPU kernels in iteration order and times each of them.
///
/// Returns `(name, microseconds)` per kernel. The store is marked dirty only
/// when at least one kernel ran.
pub fn run_sequence_cpu<'a, I, K>(&mut self, kernels: I) -> Vec<(&'static str, u128)>
where
    I: IntoIterator<Item = (&'static str, K)>,
    K: FnOnce(&mut [Vec<f64>], usize) + 'a,
{
    let rows = self.agent_count;
    let mut timings = Vec::new();
    for (label, job) in kernels {
        let clock = std::time::Instant::now();
        job(&mut self.columns[..], rows);
        let micros = clock.elapsed().as_micros();
        timings.push((label, micros));
    }
    self.dirty = self.dirty || !timings.is_empty();
    timings
}

/// Writes the host columns back into `store` row by row and clears the
/// dirty flag.
pub fn download<A, S>(&mut self, store: &S)
where
    A: SoaExtractableF64,
    S: AgentStore<A>,
{
    soa::write_back_soa_f64::<A, S>(store, &self.ids, &self.columns);
    self.dirty = false;
}
/// Removes every row whose id appears in `dead_ids`, compacting surviving
/// rows in place while keeping their order.
///
/// Unknown or repeated ids are ignored. An empty `dead_ids` is a no-op;
/// otherwise the store is marked dirty even if nothing matched.
pub fn scatter_remove(&mut self, dead_ids: &[AgentId]) {
    if dead_ids.is_empty() {
        return;
    }
    let removed: std::collections::HashSet<AgentId> = dead_ids.iter().copied().collect();
    let mut dst = 0;
    let mut src = 0;
    while src < self.agent_count {
        let id = self.ids[src];
        if !removed.contains(&id) {
            if dst != src {
                self.ids[dst] = id;
                for values in &mut self.columns {
                    values[dst] = values[src];
                }
            }
            dst += 1;
        }
        src += 1;
    }
    self.agent_count = dst;
    self.ids.truncate(dst);
    for values in &mut self.columns {
        values.truncate(dst);
    }
    self.dirty = true;
}
/// Validated form of `scatter_remove`.
///
/// # Errors
/// Returns `DuplicateRemovedId` or `MissingAgentId`, leaving the store
/// untouched, if `dead_ids` repeats an id or names an id that is not present.
pub fn try_scatter_remove(
    &mut self,
    dead_ids: &[AgentId],
) -> Result<(), DeviceSoaMutationError> {
    let current = &self.ids;
    validate_remove_ids(dead_ids, current)?;
    self.scatter_remove(dead_ids);
    Ok(())
}
/// Appends one row per id in `new_ids`; `new_columns[c][k]` becomes the
/// value of column `c` for `new_ids[k]`. An empty insert changes nothing.
///
/// # Errors
/// Returns an error without modifying the store when the shape is wrong,
/// an id repeats, or an id already exists.
pub fn try_scatter_insert(
    &mut self,
    new_ids: &[AgentId],
    new_columns: &[&[f64]],
) -> Result<(), DeviceSoaMutationError> {
    validate_insert_shape(new_ids, new_columns, self.columns.len())?;
    validate_insert_ids(new_ids, &self.ids)?;
    let added = new_ids.len();
    if added > 0 {
        append_rows(&mut self.ids, &mut self.columns, new_ids, new_columns);
        self.agent_count += added;
        self.dirty = true;
    }
    Ok(())
}
/// Removes `dead_ids` and then appends `new_ids` as one validated operation.
///
/// An id may be removed and re-inserted in the same call. The store is
/// marked dirty unless both id lists are empty.
///
/// # Errors
/// All checks run before any mutation, so on error the store is unchanged.
pub fn try_scatter_replace(
    &mut self,
    dead_ids: &[AgentId],
    new_ids: &[AgentId],
    new_columns: &[&[f64]],
) -> Result<(), DeviceSoaMutationError> {
    validate_remove_ids(dead_ids, &self.ids)?;
    validate_insert_shape(new_ids, new_columns, self.columns.len())?;
    validate_insert_ids_after_removal(new_ids, &self.ids, dead_ids)?;
    let touched = !dead_ids.is_empty() || !new_ids.is_empty();
    self.scatter_remove(dead_ids);
    append_rows(&mut self.ids, &mut self.columns, new_ids, new_columns);
    self.agent_count = self.ids.len();
    if touched {
        self.dirty = true;
    }
    Ok(())
}
}
/// Builds a schema that assigns `column_type` to every name, in order.
fn schema_from_names(names: &[&'static str], column_type: SoaColumnType) -> Vec<SoaColumnSchema> {
    let mut schema = Vec::with_capacity(names.len());
    for &name in names {
        schema.push(SoaColumnSchema { name, column_type });
    }
    schema
}
/// Checks that `schema` has one entry per data column and that every entry
/// has the `expected` precision. Reports the first offending column.
fn validate_schema(
    schema: &[SoaColumnSchema],
    columns: usize,
    expected: SoaColumnType,
) -> Result<(), DeviceSoaRestoreError> {
    if schema.len() != columns {
        return Err(DeviceSoaRestoreError::SchemaColumnCountMismatch {
            columns,
            schema: schema.len(),
        });
    }
    match schema.iter().position(|entry| entry.column_type != expected) {
        Some(column) => Err(DeviceSoaRestoreError::SchemaTypeMismatch {
            column,
            expected,
            actual: schema[column].column_type,
        }),
        None => Ok(()),
    }
}
/// Checks that every column holds exactly `agent_count` values and reports
/// the first column that does not.
fn validate_column_lengths<T>(
    columns: &[Vec<T>],
    agent_count: usize,
) -> Result<(), DeviceSoaRestoreError> {
    let offender = columns
        .iter()
        .enumerate()
        .find(|(_, values)| values.len() != agent_count);
    match offender {
        Some((column, values)) => Err(DeviceSoaRestoreError::ColumnLengthMismatch {
            column,
            len: values.len(),
            agent_count,
        }),
        None => Ok(()),
    }
}
/// Rejects the first id that appears a second time in `ids`.
fn validate_unique_ids(ids: &[AgentId]) -> Result<(), DeviceSoaRestoreError> {
    let mut seen = std::collections::HashSet::with_capacity(ids.len());
    match ids.iter().copied().find(|&id| !seen.insert(id)) {
        Some(duplicate) => Err(DeviceSoaRestoreError::DuplicateAgentId(duplicate)),
        None => Ok(()),
    }
}
/// Checks that `new_columns` has one slice per store column and that each
/// slice has one value per new id. The column count is checked first.
fn validate_insert_shape<T>(
    new_ids: &[AgentId],
    new_columns: &[&[T]],
    expected_columns: usize,
) -> Result<(), DeviceSoaMutationError> {
    let supplied = new_columns.len();
    if supplied != expected_columns {
        return Err(DeviceSoaMutationError::ColumnCountMismatch {
            expected: expected_columns,
            actual: supplied,
        });
    }
    let rows = new_ids.len();
    for (column, values) in new_columns.iter().enumerate() {
        let actual = values.len();
        if actual != rows {
            return Err(DeviceSoaMutationError::ColumnLengthMismatch {
                column,
                expected: rows,
                actual,
            });
        }
    }
    Ok(())
}
/// Checks that `new_ids` are pairwise distinct and that none already exists.
fn validate_insert_ids(
    new_ids: &[AgentId],
    existing_ids: &[AgentId],
) -> Result<(), DeviceSoaMutationError> {
    // With nothing removed this is exactly the post-removal check.
    validate_insert_ids_after_removal(new_ids, existing_ids, &[])
}
/// Rejects the first id that appears a second time among `new_ids`.
fn validate_unique_insert_ids(new_ids: &[AgentId]) -> Result<(), DeviceSoaMutationError> {
    let mut seen = std::collections::HashSet::with_capacity(new_ids.len());
    new_ids.iter().try_for_each(|&id| {
        if seen.insert(id) {
            Ok(())
        } else {
            Err(DeviceSoaMutationError::DuplicateInsertedId(id))
        }
    })
}
/// Checks that `dead_ids` are pairwise distinct (checked first) and that
/// every one of them is present in `existing_ids`.
fn validate_remove_ids(
    dead_ids: &[AgentId],
    existing_ids: &[AgentId],
) -> Result<(), DeviceSoaMutationError> {
    let mut seen = std::collections::HashSet::with_capacity(dead_ids.len());
    if let Some(&repeated) = dead_ids.iter().find(|&&id| !seen.insert(id)) {
        return Err(DeviceSoaMutationError::DuplicateRemovedId(repeated));
    }
    let present: std::collections::HashSet<AgentId> = existing_ids.iter().copied().collect();
    match dead_ids.iter().find(|&&id| !present.contains(&id)) {
        Some(&missing) => Err(DeviceSoaMutationError::MissingAgentId(missing)),
        None => Ok(()),
    }
}
/// Like `validate_insert_ids`, but ids listed in `dead_ids` count as free
/// because they are removed before the insert.
fn validate_insert_ids_after_removal(
    new_ids: &[AgentId],
    existing_ids: &[AgentId],
    dead_ids: &[AgentId],
) -> Result<(), DeviceSoaMutationError> {
    validate_unique_insert_ids(new_ids)?;
    let removed: std::collections::HashSet<AgentId> = dead_ids.iter().copied().collect();
    let surviving: std::collections::HashSet<AgentId> = existing_ids
        .iter()
        .copied()
        .filter(|id| !removed.contains(id))
        .collect();
    match new_ids.iter().copied().find(|id| surviving.contains(id)) {
        Some(clash) => Err(DeviceSoaMutationError::ExistingAgentId(clash)),
        None => Ok(()),
    }
}
/// Appends `new_ids` and, for each column index `i`, the values in
/// `new_columns[i]`. Shapes are assumed to have been validated by the caller.
fn append_rows<T: Copy>(
    ids: &mut Vec<AgentId>,
    columns: &mut [Vec<T>],
    new_ids: &[AgentId],
    new_columns: &[&[T]],
) {
    ids.extend_from_slice(new_ids);
    columns
        .iter_mut()
        .enumerate()
        .for_each(|(index, column)| column.extend_from_slice(new_columns[index]));
}
fn next_available_id(ids: &[AgentId]) -> AgentId {
let existing: std::collections::HashSet<AgentId> = ids.iter().copied().collect();
let mut candidate = 0;
while existing.contains(&candidate) {
candidate = candidate.saturating_add(1);
}
candidate
}
#[cfg(test)]
mod tests {
    use super::*;
    use rustsim_core::prelude::*;

    /// Minimal agent with a position and a velocity column.
    #[derive(Debug, Clone)]
    struct TestAgent {
        id: AgentId,
        x: f32,
        vx: f32,
    }

    impl Agent for TestAgent {
        fn id(&self) -> AgentId {
            self.id
        }
    }

    impl SoaExtractable for TestAgent {
        fn num_columns() -> usize {
            2
        }
        fn column_names() -> Vec<&'static str> {
            vec!["x", "vx"]
        }
        fn extract_row(&self, columns: &mut [Vec<f32>]) {
            columns[0].push(self.x);
            columns[1].push(self.vx);
        }
        fn write_back_row(&mut self, columns: &[&[f32]], row: usize) {
            self.x = columns[0][row];
            self.vx = columns[1][row];
        }
    }

    #[test]
    fn persistent_soa_step() {
        let mut store = HashMapStore::new();
        for i in 1..=100 {
            store.insert(TestAgent {
                id: i,
                x: 0.0,
                vx: 1.0,
            });
        }
        let mut device = DeviceSoaStore::upload::<TestAgent, _>(&store);
        assert_eq!(device.agent_count(), 100);
        assert_eq!(device.num_columns(), 2);
        for _ in 0..10 {
            device.step_cpu(|columns, n| {
                let (x_col, rest) = columns.split_at_mut(1);
                let x = &mut x_col[0];
                let vx = &rest[0];
                for i in 0..n {
                    x[i] += vx[i];
                }
            });
        }
        assert!(device.is_dirty());
        for i in 0..device.agent_count() {
            assert!((device.column(0)[i] - 10.0).abs() < 1e-5);
        }
        device.download::<TestAgent, _>(&store);
        assert!(!device.is_dirty());
    }

    #[test]
    fn run_sequence_cpu_applies_kernels_in_order() {
        let mut store = HashMapStore::new();
        for i in 1..=4 {
            store.insert(TestAgent {
                id: i,
                x: 0.0,
                vx: 2.0,
            });
        }
        let mut device = DeviceSoaStore::upload::<TestAgent, _>(&store);
        let timings = device.run_sequence_cpu(vec![
            (
                "advance",
                Box::new(|cols: &mut [Vec<f32>], n: usize| {
                    let (x_col, rest) = cols.split_at_mut(1);
                    let x = &mut x_col[0];
                    let vx = &rest[0];
                    for i in 0..n {
                        x[i] += vx[i];
                    }
                }) as Box<dyn FnOnce(&mut [Vec<f32>], usize)>,
            ),
            (
                "scale_vx",
                Box::new(|cols: &mut [Vec<f32>], n: usize| {
                    for v in cols[1].iter_mut().take(n) {
                        *v *= 3.0;
                    }
                }),
            ),
            (
                "advance_again",
                Box::new(|cols: &mut [Vec<f32>], n: usize| {
                    let (x_col, rest) = cols.split_at_mut(1);
                    let x = &mut x_col[0];
                    let vx = &rest[0];
                    for i in 0..n {
                        x[i] += vx[i];
                    }
                }),
            ),
        ]);
        assert_eq!(timings.len(), 3);
        assert_eq!(timings[0].0, "advance");
        assert_eq!(timings[1].0, "scale_vx");
        assert_eq!(timings[2].0, "advance_again");
        assert!(device.is_dirty());
        for i in 0..device.agent_count() {
            assert!((device.column(0)[i] - 8.0).abs() < 1e-5);
            assert!((device.column(1)[i] - 6.0).abs() < 1e-5);
        }
    }

    #[test]
    fn scatter_remove_and_insert() {
        let mut store = HashMapStore::new();
        for i in 1..=10 {
            store.insert(TestAgent {
                id: i,
                x: i as f32,
                vx: 0.0,
            });
        }
        let mut device = DeviceSoaStore::upload::<TestAgent, _>(&store);
        assert_eq!(device.agent_count(), 10);
        device.scatter_remove(&[3, 7]);
        assert_eq!(device.agent_count(), 8);
        assert!(!device.ids().contains(&3));
        assert!(!device.ids().contains(&7));
        device.scatter_insert(&[11, 12], &[&[11.0, 12.0], &[0.0, 0.0]]);
        assert_eq!(device.agent_count(), 10);
        assert!(device.ids().contains(&11));
        assert!(device.ids().contains(&12));
    }

    /// A checkpoint taken from a live store restores to an identical, clean store.
    #[test]
    fn checkpoint_round_trip_preserves_state() {
        let mut store = HashMapStore::new();
        for i in 1..=5 {
            store.insert(TestAgent {
                id: i,
                x: i as f32,
                vx: 0.5,
            });
        }
        let device = DeviceSoaStore::upload::<TestAgent, _>(&store);
        let checkpoint = device.checkpoint();
        assert_eq!(checkpoint.agent_count(), 5);
        assert_eq!(checkpoint.num_columns(), 2);
        assert_eq!(
            checkpoint.schema,
            vec![SoaColumnSchema::f32("x"), SoaColumnSchema::f32("vx")]
        );
        let restored = DeviceSoaStore::from_checkpoint(checkpoint.clone())
            .expect("checkpoint taken from a live store is valid");
        assert_eq!(restored.checkpoint(), checkpoint);
        assert!(!restored.is_dirty());
    }

    /// Each validation step of `from_checkpoint` reports its specific error.
    #[test]
    fn from_checkpoint_rejects_malformed_input() {
        let missing_name = DeviceSoaCheckpoint {
            ids: vec![1],
            columns: vec![vec![0.0]],
            column_names: vec![],
            schema: vec![SoaColumnSchema::f32("x")],
        };
        assert_eq!(
            DeviceSoaStore::from_checkpoint(missing_name).unwrap_err(),
            DeviceSoaRestoreError::ColumnNameCountMismatch { columns: 1, names: 0 }
        );

        let wrong_precision = DeviceSoaCheckpoint {
            ids: vec![1],
            columns: vec![vec![0.0]],
            column_names: vec!["x"],
            schema: vec![SoaColumnSchema::f64("x")],
        };
        assert_eq!(
            DeviceSoaStore::from_checkpoint(wrong_precision).unwrap_err(),
            DeviceSoaRestoreError::SchemaTypeMismatch {
                column: 0,
                expected: SoaColumnType::F32,
                actual: SoaColumnType::F64,
            }
        );

        let short_column = DeviceSoaCheckpoint {
            ids: vec![1, 2],
            columns: vec![vec![0.0]],
            column_names: vec!["x"],
            schema: vec![SoaColumnSchema::f32("x")],
        };
        assert_eq!(
            DeviceSoaStore::from_checkpoint(short_column).unwrap_err(),
            DeviceSoaRestoreError::ColumnLengthMismatch {
                column: 0,
                len: 1,
                agent_count: 2,
            }
        );

        let duplicate_ids = DeviceSoaCheckpoint {
            ids: vec![1, 1],
            columns: vec![vec![0.0, 0.0]],
            column_names: vec!["x"],
            schema: vec![SoaColumnSchema::f32("x")],
        };
        assert_eq!(
            DeviceSoaStore::from_checkpoint(duplicate_ids).unwrap_err(),
            DeviceSoaRestoreError::DuplicateAgentId(1)
        );
    }

    /// Validated mutations reject bad ids without changing the store and
    /// allow re-inserting an id removed in the same replace.
    #[test]
    fn try_scatter_mutations_validate_ids() {
        let mut store = HashMapStore::new();
        for i in 1..=3 {
            store.insert(TestAgent {
                id: i,
                x: 0.0,
                vx: 0.0,
            });
        }
        let mut device = DeviceSoaStore::upload::<TestAgent, _>(&store);
        assert_eq!(
            device.try_scatter_insert(&[2], &[&[1.0], &[1.0]]),
            Err(DeviceSoaMutationError::ExistingAgentId(2))
        );
        assert_eq!(
            device.try_scatter_remove(&[9]),
            Err(DeviceSoaMutationError::MissingAgentId(9))
        );
        assert_eq!(
            device.try_scatter_remove(&[1, 1]),
            Err(DeviceSoaMutationError::DuplicateRemovedId(1))
        );
        assert_eq!(device.agent_count(), 3);
        assert!(!device.is_dirty());

        device
            .try_scatter_replace(&[2], &[2, 4], &[&[5.0, 6.0], &[0.0, 0.0]])
            .expect("replacing with a just-removed id is valid");
        assert_eq!(device.agent_count(), 4);
        let row = device.row_of(2).expect("re-inserted id is present");
        assert!((device.column(0)[row] - 5.0).abs() < 1e-6);
        assert!(device.contains_id(4));
        assert!(device.is_dirty());
    }

    /// The f64 store restores from a checkpoint, steps on the CPU and removes rows.
    #[test]
    fn f64_store_restores_steps_and_removes() {
        let checkpoint = DeviceSoaCheckpointF64 {
            ids: vec![1, 2],
            columns: vec![vec![1.0, 2.0]],
            column_names: vec!["x"],
            schema: vec![SoaColumnSchema::f64("x")],
        };
        let mut device =
            DeviceSoaStoreF64::from_checkpoint(checkpoint).expect("valid f64 checkpoint");
        device.step_cpu(|cols, n| {
            for v in cols[0].iter_mut().take(n) {
                *v *= 2.0;
            }
        });
        assert!(device.is_dirty());
        let doubled: &[f64] = &[2.0, 4.0];
        assert_eq!(device.column(0), doubled);

        device.try_scatter_remove(&[1]).expect("id 1 exists");
        assert_eq!(device.agent_count(), 1);
        assert!(!device.contains_id(1));
        assert!(device.contains_id(2));
        let remaining: &[f64] = &[4.0];
        assert_eq!(device.column(0), remaining);
    }
}