use crate::columnar_runtime::{ColumnarPhaseBackend, ColumnarRuntime, ColumnarRuntimeError};
use crate::device_store::{DeviceSoaCheckpoint, DeviceSoaRestoreError, DeviceSoaStore};
use std::collections::VecDeque;
use std::sync::Arc;
use thiserror::Error;
/// A fingerprint value tagged with the step index it is associated with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReplayFingerprint {
/// Step index the fingerprint belongs to.
pub step_index: u64,
/// Opaque fingerprint value; replay verification compares it for equality.
pub value: u64,
}
/// Selects what the control plane does with replay fingerprints after each step.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum ReplayMode {
/// No fingerprint is captured (the default).
#[default]
Off,
/// A fingerprint is captured after each step and appended to the recorded replay.
Record,
/// A fingerprint is captured after each step and compared with the next entry of
/// `expected`; a missing or differing entry is reported as an error.
Verify { expected: Vec<ReplayFingerprint> },
}
/// Policy deciding when the control plane takes automatic checkpoints.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CheckpointPolicy {
    /// Never checkpoint automatically (the default).
    #[default]
    Disabled,
    /// Checkpoint whenever the step index is a positive multiple of `interval_steps`.
    EverySteps { interval_steps: u64 },
}

impl CheckpointPolicy {
    /// Returns the policy that never takes automatic checkpoints.
    pub const fn disabled() -> Self {
        CheckpointPolicy::Disabled
    }

    /// Returns the policy that checkpoints every `interval_steps` steps.
    pub const fn every_steps(interval_steps: u64) -> Self {
        CheckpointPolicy::EverySteps { interval_steps }
    }
}
/// Configuration consumed by `EngineControlPlane`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ControlPlaneConfig {
/// Replay behaviour applied after each step.
pub replay_mode: ReplayMode,
/// When automatic checkpoints are taken.
pub checkpoint_policy: CheckpointPolicy,
/// Upper bound on retained checkpoints; the oldest are dropped first.
pub max_retained_checkpoints: usize,
/// How rows are divided into partitions.
pub partition_plan: PartitionPlan,
/// Workers that assignment worker labels are resolved against.
pub partition_workers: Vec<PartitionWorker>,
}
impl Default for ControlPlaneConfig {
    /// Replay off, checkpoints disabled, four retained checkpoints, a single-process
    /// partition plan and no registered workers.
    fn default() -> Self {
        let partition_plan = PartitionPlan::single_process();
        let partition_workers = Vec::new();
        Self {
            replay_mode: ReplayMode::Off,
            checkpoint_policy: CheckpointPolicy::Disabled,
            max_retained_checkpoints: 4,
            partition_plan,
            partition_workers,
        }
    }
}
/// How a `PartitionPlan` divides rows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartitionMode {
/// No explicit assignments; all rows are treated as one partition.
SingleProcess,
/// Rows are split into explicit, contiguous row ranges.
StaticRowRanges,
}
/// Numeric identifier of a partition.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PartitionId(pub u32);
/// A half-open row range `row_start..row_end` owned by one partition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionAssignment {
/// Identifier of the partition that owns the range.
pub partition_id: PartitionId,
/// First row of the range (inclusive).
pub row_start: usize,
/// End of the range (exclusive).
pub row_end: usize,
/// Label of the worker the partition is assigned to, if any.
pub worker: Option<String>,
}
impl PartitionAssignment {
    /// Number of rows in the range; a reversed range (`row_end < row_start`) counts as zero.
    pub fn len(&self) -> usize {
        let Self {
            row_start, row_end, ..
        } = self;
        row_end.saturating_sub(*row_start)
    }

    /// Returns `true` when the range contains no rows.
    pub fn is_empty(&self) -> bool {
        self.row_end <= self.row_start
    }
}
/// A partitioning mode together with the row assignments it uses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionPlan {
// Partitioning strategy; determines how `assignments` is interpreted.
mode: PartitionMode,
// Explicit row ranges; expected to be empty in `SingleProcess` mode.
assignments: Vec<PartitionAssignment>,
}
impl PartitionPlan {
    /// Creates a plan that keeps every row in one partition, with no explicit assignments.
    pub fn single_process() -> Self {
        Self {
            mode: PartitionMode::SingleProcess,
            assignments: Vec::new(),
        }
    }

    /// Splits `agent_count` rows into contiguous, non-empty row ranges.
    ///
    /// Rows are spread as evenly as possible: the first `agent_count % partition_count`
    /// partitions receive one extra row. Partitions that would be empty are omitted, so the
    /// plan holds `min(agent_count, partition_count)` assignments. When `agent_count` is zero
    /// the plan holds a single empty assignment for partition 0.
    ///
    /// # Errors
    ///
    /// Returns [`PartitionPlanError::InvalidPartitionCount`] when `partition_count` is zero.
    pub fn row_ranges(
        agent_count: usize,
        partition_count: usize,
    ) -> Result<Self, PartitionPlanError> {
        if partition_count == 0 {
            return Err(PartitionPlanError::InvalidPartitionCount);
        }
        if agent_count == 0 {
            return Ok(Self {
                mode: PartitionMode::StaticRowRanges,
                assignments: vec![PartitionAssignment {
                    partition_id: PartitionId(0),
                    row_start: 0,
                    row_end: 0,
                    worker: None,
                }],
            });
        }
        let base = agent_count / partition_count;
        let remainder = agent_count % partition_count;
        // Only the leading `min(partition_count, agent_count)` partitions ever receive rows
        // (when there are more partitions than rows, `base` is 0 and exactly `remainder ==
        // agent_count` partitions get one row each). Bounding the loop avoids spinning over
        // empty trailing partitions when `partition_count` is very large.
        let populated = partition_count.min(agent_count);
        let mut assignments = Vec::with_capacity(populated);
        let mut row_start = 0usize;
        for partition in 0..populated {
            let len = base + usize::from(partition < remainder);
            let row_end = row_start + len;
            assignments.push(PartitionAssignment {
                // NOTE(review): ids wrap if more than `u32::MAX` partitions are populated.
                partition_id: PartitionId(partition as u32),
                row_start,
                row_end,
                worker: None,
            });
            row_start = row_end;
        }
        Ok(Self {
            mode: PartitionMode::StaticRowRanges,
            assignments,
        })
    }

    /// Attaches one worker label to each assignment, in order.
    ///
    /// # Errors
    ///
    /// Returns [`PartitionPlanError::WorkerCountMismatch`] when the number of labels differs
    /// from the number of assignments.
    pub fn with_workers(mut self, workers: &[impl AsRef<str>]) -> Result<Self, PartitionPlanError> {
        if self.assignments.len() != workers.len() {
            return Err(PartitionPlanError::WorkerCountMismatch {
                assignments: self.assignments.len(),
                workers: workers.len(),
            });
        }
        for (assignment, worker) in self.assignments.iter_mut().zip(workers) {
            assignment.worker = Some(worker.as_ref().to_string());
        }
        Ok(self)
    }

    /// Returns the partitioning mode.
    pub fn mode(&self) -> PartitionMode {
        self.mode
    }

    /// Returns the explicit row assignments (empty for a single-process plan).
    pub fn assignments(&self) -> &[PartitionAssignment] {
        &self.assignments
    }

    /// Number of partitions: always 1 in single-process mode, otherwise the assignment count.
    pub fn partition_count(&self) -> usize {
        match self.mode {
            PartitionMode::SingleProcess => 1,
            PartitionMode::StaticRowRanges => self.assignments.len(),
        }
    }

    /// Returns `true` when at least one assignment names a worker.
    pub fn has_worker_assignments(&self) -> bool {
        self.assignments
            .iter()
            .any(|assignment| assignment.worker.is_some())
    }

    /// Checks that the plan is consistent with `agent_count` rows.
    ///
    /// A single-process plan must carry no assignments; a static plan is checked by
    /// `validate_static_assignments`.
    fn validate(&self, agent_count: usize) -> Result<(), PartitionPlanError> {
        match self.mode {
            PartitionMode::SingleProcess => {
                if self.assignments.is_empty() {
                    Ok(())
                } else {
                    Err(PartitionPlanError::SingleProcessHasAssignments)
                }
            }
            PartitionMode::StaticRowRanges => {
                validate_static_assignments(&self.assignments, agent_count)
            }
        }
    }
}
impl Default for PartitionPlan {
fn default() -> Self {
Self::single_process()
}
}
/// Where a partition's phases are executed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartitionExecutionBackend {
/// Execute on the local CPU; also used when an assignment names no worker.
LocalCpu,
/// Execute on the CUDA device with the given ordinal.
CudaDevice { device_ordinal: u32 },
/// Execute through an external endpoint.
External { endpoint: String },
}
/// A named worker and the backend it executes on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionWorker {
/// Label that partition assignments refer to.
pub label: String,
/// Backend used for partitions assigned to this worker.
pub backend: PartitionExecutionBackend,
}
impl PartitionWorker {
    /// Builds a worker named `label` that uses the `LocalCpu` backend.
    pub fn local_cpu(label: impl Into<String>) -> Self {
        let backend = PartitionExecutionBackend::LocalCpu;
        Self {
            label: label.into(),
            backend,
        }
    }

    /// Builds a worker named `label` bound to the CUDA device `device_ordinal`.
    pub fn cuda_device(label: impl Into<String>, device_ordinal: u32) -> Self {
        let backend = PartitionExecutionBackend::CudaDevice { device_ordinal };
        Self {
            label: label.into(),
            backend,
        }
    }

    /// Builds a worker named `label` that targets the external `endpoint`.
    pub fn external(label: impl Into<String>, endpoint: impl Into<String>) -> Self {
        let label = label.into();
        let backend = PartitionExecutionBackend::External {
            endpoint: endpoint.into(),
        };
        Self { label, backend }
    }
}
/// Per-partition information handed to every partition phase.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionExecutionContext {
/// Partition being executed.
pub partition_id: PartitionId,
/// First row of the partition's range (inclusive).
pub row_start: usize,
/// End of the partition's range (exclusive).
pub row_end: usize,
/// Worker label from the assignment, if any.
pub worker: Option<String>,
/// Backend resolved for this partition.
pub backend: PartitionExecutionBackend,
/// Runtime step index at the time the partitioned step started.
pub step_index: u64,
}
/// Timing report for one phase executed on one partition.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PartitionPhaseReport {
/// Partition the phase ran on.
pub partition_id: PartitionId,
/// Worker label of the partition, if any.
pub worker: Option<String>,
/// Phase index; this file computes it as a phase base plus the local phase position.
pub phase_index: usize,
/// Name the phase was registered under.
pub name: &'static str,
/// Backend the partition was resolved to.
pub backend: PartitionExecutionBackend,
/// Number of rows in the partition.
pub rows: usize,
/// Elapsed time of the phase in microseconds.
pub elapsed_us: u128,
}
/// Outcome of one partitioned step.
#[derive(Debug, Clone)]
pub struct PartitionedStep {
/// Runtime step index before the step ran.
pub step_index: u64,
/// Agent count of the runtime after the step.
pub agent_count: usize,
/// One report per executed (partition, phase) pair.
pub phases: Vec<PartitionPhaseReport>,
/// Fingerprint captured after the step, when replay is enabled.
pub fingerprint: Option<ReplayFingerprint>,
/// Checkpoint taken after the step, when the checkpoint policy triggered.
pub checkpoint: Option<EngineCheckpoint>,
/// Elapsed time in microseconds, measured up to the point partitions were restored.
pub total_us: u128,
}
/// Signature of an in-process partition phase; `FnMut`, so it may mutate captured state.
type PartitionCpuPhaseFn =
dyn FnMut(&PartitionExecutionContext, &mut [Vec<f32>], usize) -> Result<(), String>;
/// Signature of a partition phase handed to executors; `Fn + Send + Sync`, so it can be
/// shared across threads.
type PartitionExecutablePhaseFn =
dyn Fn(&PartitionExecutionContext, &mut [Vec<f32>], usize) -> Result<(), String> + Send + Sync;
/// A named in-process partition phase.
struct PartitionCpuPhase {
// Name reported in phase reports and errors.
name: &'static str,
// Boxed phase callback.
function: Box<PartitionCpuPhaseFn>,
}
/// A named partition phase that can be cloned into executor tasks (the callback is shared
/// through an `Arc`).
#[derive(Clone)]
pub struct PartitionExecutablePhase {
// Name reported in phase reports and errors.
name: &'static str,
// Shared phase callback.
function: Arc<PartitionExecutablePhaseFn>,
}
impl PartitionExecutablePhase {
    /// Returns the name the phase was registered under.
    pub fn name(&self) -> &'static str {
        self.name
    }

    /// Invokes the phase callback with `context`, the partition's `columns` and its row
    /// count, returning whatever the callback returns.
    pub fn run(
        &self,
        context: &PartitionExecutionContext,
        columns: &mut [Vec<f32>],
        rows: usize,
    ) -> Result<(), String> {
        let callback = &*self.function;
        callback(context, columns, rows)
    }
}
/// Everything an executor needs to run the registered phases on one partition.
#[derive(Clone)]
pub struct PartitionTask {
// Execution context passed to every phase.
context: PartitionExecutionContext,
// Partition state the phases operate on.
checkpoint: PartitionCheckpoint,
// Offset added to local phase positions when building phase reports.
phase_base: usize,
// Phases to run, in order.
phases: Vec<PartitionExecutablePhase>,
}
impl PartitionTask {
    /// Execution context passed to every phase of this task.
    pub fn context(&self) -> &PartitionExecutionContext {
        &self.context
    }

    /// Partition state the phases operate on.
    pub fn checkpoint(&self) -> &PartitionCheckpoint {
        &self.checkpoint
    }

    /// Offset added to a phase's local position to form its reported phase index.
    pub fn phase_base(&self) -> usize {
        self.phase_base
    }

    /// Phases to run, in order.
    pub fn phases(&self) -> &[PartitionExecutablePhase] {
        &self.phases
    }

    /// Consumes the task, yielding `(context, checkpoint, phase_base, phases)`.
    pub fn into_parts(
        self,
    ) -> (
        PartitionExecutionContext,
        PartitionCheckpoint,
        usize,
        Vec<PartitionExecutablePhase>,
    ) {
        let Self {
            context,
            checkpoint,
            phase_base,
            phases,
        } = self;
        (context, checkpoint, phase_base, phases)
    }
}
/// Result of running a `PartitionTask`.
#[derive(Debug, Clone)]
pub struct PartitionTaskResult {
/// Partition state after the phases ran.
pub checkpoint: PartitionCheckpoint,
/// Reports for the phases that ran.
pub phases: Vec<PartitionPhaseReport>,
}
impl PartitionTaskResult {
pub fn completed(
context: &PartitionExecutionContext,
mut checkpoint: PartitionCheckpoint,
phases: Vec<PartitionPhaseReport>,
) -> Self {
checkpoint.step_index = context.step_index.saturating_add(1);
checkpoint.fingerprint = fingerprint_checkpoint(checkpoint.step_index, &checkpoint.state);
checkpoint.resident_bytes = checkpoint.state.resident_bytes();
Self { checkpoint, phases }
}
}
/// Errors produced by `PartitionExecutor` implementations.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum PartitionExecutorError {
/// The worker running a partition failed outside of a specific phase.
#[error("partition {partition:?} worker failed: {message}")]
Worker {
partition: PartitionId,
message: String,
},
/// A phase returned an error for a partition.
#[error("partition {partition:?} phase {phase} failed: {message}")]
Phase {
partition: PartitionId,
phase: &'static str,
message: String,
},
/// A partition has no external endpoint (not constructed in the visible part of this file).
#[error("partition {partition:?} has no external endpoint")]
MissingEndpoint {
partition: PartitionId,
},
/// The executor returned a different number of results than tasks it was given.
#[error("partition executor returned {actual} results, expected {expected}")]
ResultCountMismatch {
expected: usize,
actual: usize,
},
}
/// Strategy for running a batch of partition tasks.
pub trait PartitionExecutor {
/// Runs `tasks` and returns their results.
///
/// The control plane rejects a result list whose length differs from the number of tasks.
fn execute(
&mut self,
tasks: Vec<PartitionTask>,
) -> Result<Vec<PartitionTaskResult>, PartitionExecutorError>;
}
/// Stateless executor that runs partition tasks on local threads.
#[derive(Debug, Default, Clone, Copy)]
pub struct LocalThreadedPartitionExecutor;

impl LocalThreadedPartitionExecutor {
    /// Creates the executor; it carries no state.
    pub const fn new() -> Self {
        LocalThreadedPartitionExecutor
    }
}
impl PartitionExecutor for LocalThreadedPartitionExecutor {
    /// Runs every task on its own scoped thread and returns the results in task order.
    ///
    /// # Errors
    ///
    /// Returns the first failure in task order: the task's own error if it returned one, or
    /// [`PartitionExecutorError::Worker`] naming the partition whose thread panicked.
    fn execute(
        &mut self,
        tasks: Vec<PartitionTask>,
    ) -> Result<Vec<PartitionTaskResult>, PartitionExecutorError> {
        std::thread::scope(|scope| {
            // Remember which partition each handle belongs to so a panic can be attributed.
            let handles = tasks
                .into_iter()
                .map(|task| {
                    let partition = task.context().partition_id;
                    let handle = scope.spawn(move || execute_partition_task(task));
                    (partition, handle)
                })
                .collect::<Vec<_>>();
            // Join every handle before inspecting any outcome. `thread::scope` panics if a
            // thread it joins implicitly has panicked, so returning on the first error while
            // later handles are still unjoined could turn a reportable error into a panic.
            let outcomes = handles
                .into_iter()
                .map(|(partition, handle)| (partition, handle.join()))
                .collect::<Vec<_>>();
            let mut results = Vec::with_capacity(outcomes.len());
            for (partition, outcome) in outcomes {
                match outcome {
                    Ok(result) => results.push(result?),
                    Err(_) => {
                        return Err(PartitionExecutorError::Worker {
                            partition,
                            message: "partition worker thread panicked".to_string(),
                        });
                    }
                }
            }
            Ok(results)
        })
    }
}
/// Client capable of executing a partition task at an external endpoint.
pub trait ExternalPartitionClient {
/// Executes `task` at `endpoint`, returning the task result or an error message.
fn execute_partition(
&mut self,
endpoint: &str,
task: PartitionTask,
) -> Result<PartitionTaskResult, String>;
}
/// Executor that forwards partition tasks to a client of type `C`.
pub struct EndpointPartitionExecutor<C> {
    client: C,
}

impl<C> EndpointPartitionExecutor<C> {
    /// Wraps `client` in an executor.
    pub fn new(client: C) -> Self {
        EndpointPartitionExecutor { client }
    }

    /// Shared access to the wrapped client.
    pub fn client(&self) -> &C {
        &self.client
    }

    /// Exclusive access to the wrapped client.
    pub fn client_mut(&mut self) -> &mut C {
        &mut self.client
    }
}
impl<C> PartitionExecutor for EndpointPartitionExecutor<C>
where
    C: ExternalPartitionClient,
{
    /// Sends each task, in order, to the endpoint named by its `External` backend.
    ///
    /// Stops at the first failure: a task whose backend is not `External`, or a client error,
    /// is reported as [`PartitionExecutorError::Worker`] for that task's partition.
    fn execute(
        &mut self,
        tasks: Vec<PartitionTask>,
    ) -> Result<Vec<PartitionTaskResult>, PartitionExecutorError> {
        let mut completed = Vec::with_capacity(tasks.len());
        for task in tasks {
            let partition = task.context.partition_id;
            let PartitionExecutionBackend::External { endpoint } = &task.context.backend else {
                return Err(PartitionExecutorError::Worker {
                    partition,
                    message: "endpoint executor can only run External partition backends"
                        .to_string(),
                });
            };
            // Own the endpoint so `task` can be moved into the client call.
            let endpoint = endpoint.clone();
            match self.client.execute_partition(&endpoint, task) {
                Ok(result) => completed.push(result),
                Err(message) => {
                    return Err(PartitionExecutorError::Worker { partition, message });
                }
            }
        }
        Ok(completed)
    }
}
/// Errors raised while building or validating a `PartitionPlan`.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum PartitionPlanError {
/// `partition_count` was zero.
#[error("partition_count must be greater than zero")]
InvalidPartitionCount,
/// A single-process plan carried explicit assignments.
#[error("single-process partition plan must not contain row assignments")]
SingleProcessHasAssignments,
/// The number of worker labels differs from the number of assignments.
#[error("worker count {workers} does not match assignment count {assignments}")]
WorkerCountMismatch { assignments: usize, workers: usize },
/// An assignment does not start where the previous one ended.
#[error("partition {partition:?} starts at row {row_start}, expected {expected_start}")]
NonContiguousRange {
partition: PartitionId,
row_start: usize,
expected_start: usize,
},
/// An assignment ends beyond the number of rows.
#[error("partition {partition:?} ends at row {row_end}, beyond agent count {agent_count}")]
RangeOutOfBounds {
partition: PartitionId,
row_end: usize,
agent_count: usize,
},
/// An assignment covers no rows although there are rows to cover.
#[error("partition {partition:?} has an empty row range")]
EmptyRange { partition: PartitionId },
/// The assignments together do not cover exactly `agent_count` rows.
#[error("static partition plan covers {covered_rows} rows, expected {agent_count}")]
CoverageMismatch {
covered_rows: usize,
agent_count: usize,
},
}
/// Accumulated timing for one phase index.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnginePhaseMetric {
/// Phase index the metric is keyed by.
pub phase_index: usize,
/// Phase name taken from the first report seen for this index.
pub name: &'static str,
/// Backend taken from the first report seen for this index.
pub backend: ColumnarPhaseBackend,
/// Number of reports accumulated for this index.
pub executions: u64,
/// Elapsed microseconds of the most recent report.
pub last_elapsed_us: u128,
/// Sum of elapsed microseconds over all reports.
pub total_elapsed_us: u128,
}
/// Counters and gauges maintained by the control plane.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct EngineMetrics {
/// Current runtime step index.
pub current_step: u64,
/// Current number of agents (rows).
pub agent_count: usize,
/// Number of phases (runtime phases, plus partition phases once metrics are refreshed).
pub phase_count: usize,
/// Resident bytes reported by the device store.
pub resident_bytes: usize,
/// Steps completed through the control plane, partitioned or not.
pub steps_completed: u64,
/// Duration of the most recent step in microseconds.
pub last_step_us: u128,
/// Sum of all step durations in microseconds.
pub total_step_us: u128,
/// Number of checkpoints created.
pub checkpoints_created: u64,
/// Number of replay fingerprints captured.
pub replay_fingerprints: u64,
/// Partition count of the configured plan.
pub partition_count: usize,
/// Partitioned steps completed.
pub partition_steps_completed: u64,
/// Rows summed over every partition phase report.
pub partition_rows_processed: u64,
/// Per-phase accumulated timings.
pub phases: Vec<EnginePhaseMetric>,
}
/// Snapshot of the whole engine state.
#[derive(Debug, Clone, PartialEq)]
pub struct EngineCheckpoint {
/// 1-based creation sequence number.
pub sequence: u64,
/// Optional caller-supplied label.
pub label: Option<String>,
/// Runtime step index when the checkpoint was taken.
pub step_index: u64,
/// Fingerprint of the runtime when the checkpoint was taken.
pub fingerprint: ReplayFingerprint,
/// Resident bytes reported by `state`.
pub resident_bytes: usize,
/// Captured device-store state.
pub state: DeviceSoaCheckpoint,
}
/// Snapshot of the rows belonging to one partition.
#[derive(Debug, Clone, PartialEq)]
pub struct PartitionCheckpoint {
/// Row range and worker this snapshot belongs to.
pub assignment: PartitionAssignment,
/// Step index the snapshot corresponds to.
pub step_index: u64,
/// Fingerprint of `state` at `step_index`.
pub fingerprint: ReplayFingerprint,
/// Resident bytes reported by `state`.
pub resident_bytes: usize,
/// Captured state of the partition's rows.
pub state: DeviceSoaCheckpoint,
}
/// Outcome of one non-partitioned step.
#[derive(Debug, Clone)]
pub struct ControlledStep {
/// Timing reported by the runtime for the step.
pub timing: crate::columnar_runtime::ColumnarStepTiming,
/// Fingerprint captured after the step, when replay is enabled.
pub fingerprint: Option<ReplayFingerprint>,
/// Checkpoint taken after the step, when the checkpoint policy triggered.
pub checkpoint: Option<EngineCheckpoint>,
}
/// Errors surfaced by `EngineControlPlane`.
#[derive(Debug, Error)]
pub enum EngineControlPlaneError {
/// The configuration or the current setup is not usable.
#[error("invalid control-plane configuration: {0}")]
InvalidConfig(String),
/// An error from the underlying columnar runtime.
#[error("runtime error: {0}")]
Runtime(#[from] ColumnarRuntimeError),
/// Rebuilding a device store from a checkpoint failed.
#[error("checkpoint restore failed: {0}")]
Restore(#[from] DeviceSoaRestoreError),
/// The fingerprint value differs from the expected one for the same step.
#[error("replay diverged at step {step_index}: expected {expected:#018x}, got {actual:#018x}")]
ReplayDivergence {
step_index: u64,
expected: u64,
actual: u64,
},
/// The next expected fingerprint belongs to a different step.
#[error("replay expected step {expected_step} but runtime produced step {actual_step}")]
ReplayStepMismatch {
expected_step: u64,
actual_step: u64,
},
/// The expected fingerprint list has no entry left for this step.
#[error("replay has no expected fingerprint for step {step_index}")]
ReplayExhausted { step_index: u64 },
/// The partition plan is invalid.
#[error("partition plan error: {0}")]
Partition(#[from] PartitionPlanError),
/// Slicing or assembling partition checkpoints failed.
#[error("partition checkpoint error: {0}")]
PartitionCheckpoint(String),
/// An assignment names a worker that is not registered.
#[error("partition {partition:?} references missing worker {worker}")]
MissingPartitionWorker {
partition: PartitionId,
worker: String,
},
/// An in-process partition phase returned an error.
#[error("partition {partition:?} phase {phase} failed: {message}")]
PartitionExecution {
partition: PartitionId,
phase: &'static str,
message: String,
},
/// A partition executor reported an error.
#[error("partition executor error: {0}")]
PartitionExecutor(#[from] PartitionExecutorError),
}
/// Wraps a `ColumnarRuntime` with metrics, checkpoints, replay and partitioned stepping.
pub struct EngineControlPlane {
// The wrapped runtime.
runtime: ColumnarRuntime,
// Active configuration.
config: ControlPlaneConfig,
// Metrics updated as steps and checkpoints happen.
metrics: EngineMetrics,
// Retained checkpoints, oldest at the front.
retained_checkpoints: VecDeque<EngineCheckpoint>,
// Fingerprints captured so far (in Record and Verify modes).
recorded_replay: Vec<ReplayFingerprint>,
// Index of the next expected fingerprint in Verify mode.
verify_cursor: usize,
// In-process partition phases, in registration order.
partition_phases: Vec<PartitionCpuPhase>,
// Executor partition phases, in registration order.
executable_partition_phases: Vec<PartitionExecutablePhase>,
}
impl EngineControlPlane {
/// Wraps `runtime` using the default configuration; no validation is performed.
pub fn new(runtime: ColumnarRuntime) -> Self {
    let config = ControlPlaneConfig::default();
    Self {
        // Computed before `runtime` and `config` are moved into the struct.
        metrics: Self::initial_metrics(&runtime, &config),
        runtime,
        config,
        retained_checkpoints: VecDeque::new(),
        recorded_replay: Vec::new(),
        verify_cursor: 0,
        partition_phases: Vec::new(),
        executable_partition_phases: Vec::new(),
    }
}
/// Wraps `runtime` using `config`.
///
/// # Errors
///
/// Returns the error from `validate_config` when the configuration is not valid for the
/// runtime's agent count.
pub fn with_config(
    runtime: ColumnarRuntime,
    config: ControlPlaneConfig,
) -> Result<Self, EngineControlPlaneError> {
    validate_config(&config, runtime.agent_count())?;
    Ok(Self {
        // Computed before `runtime` and `config` are moved into the struct.
        metrics: Self::initial_metrics(&runtime, &config),
        runtime,
        config,
        retained_checkpoints: VecDeque::new(),
        recorded_replay: Vec::new(),
        verify_cursor: 0,
        partition_phases: Vec::new(),
        executable_partition_phases: Vec::new(),
    })
}
/// Shared access to the wrapped runtime.
pub fn runtime(&self) -> &ColumnarRuntime {
&self.runtime
}
/// Exclusive access to the wrapped runtime; changes made through it are not reflected in
/// the metrics until the next operation that refreshes them.
pub fn runtime_mut(&mut self) -> &mut ColumnarRuntime {
&mut self.runtime
}
/// The active configuration.
pub fn config(&self) -> &ControlPlaneConfig {
&self.config
}
/// Returns a copy of the current metrics.
pub fn metrics(&self) -> EngineMetrics {
self.metrics.clone()
}
/// Iterates over the retained checkpoints from oldest to newest.
pub fn checkpoints(&self) -> impl DoubleEndedIterator<Item = &EngineCheckpoint> {
self.retained_checkpoints.iter()
}
/// The most recently retained checkpoint, if any.
pub fn latest_checkpoint(&self) -> Option<&EngineCheckpoint> {
self.retained_checkpoints.back()
}
/// Fingerprints captured so far, in capture order.
pub fn recorded_replay(&self) -> &[ReplayFingerprint] {
&self.recorded_replay
}
/// The configured partition plan.
pub fn partition_plan(&self) -> &PartitionPlan {
&self.config.partition_plan
}
/// Total number of registered partition phases (in-process plus executor phases).
pub fn partition_phase_count(&self) -> usize {
self.partition_phases.len() + self.executable_partition_phases.len()
}
pub fn add_partition_cpu_phase(
&mut self,
name: &'static str,
function: impl FnMut(&PartitionExecutionContext, &mut [Vec<f32>], usize) -> Result<(), String>
+ 'static,
) {
self.partition_phases.push(PartitionCpuPhase {
name,
function: Box::new(function),
});
self.refresh_live_metrics();
}
pub fn add_partition_executor_phase(
&mut self,
name: &'static str,
function: impl Fn(&PartitionExecutionContext, &mut [Vec<f32>], usize) -> Result<(), String>
+ Send
+ Sync
+ 'static,
) {
self.executable_partition_phases
.push(PartitionExecutablePhase {
name,
function: Arc::new(function),
});
self.refresh_live_metrics();
}
/// Runs every registered in-process partition phase over each partition and advances the
/// runtime by one step.
///
/// The runtime state is split into partition checkpoints, each phase is applied to each
/// partition in order, and the results are assembled back into the runtime at
/// `step_index + 1`. Afterwards metrics are updated, a replay fingerprint is captured when
/// replay is enabled, and the checkpoint policy is applied.
///
/// # Errors
///
/// Returns `InvalidConfig` when no in-process partition phase is registered,
/// `MissingPartitionWorker` when an assignment names an unknown worker,
/// `PartitionExecution` when a phase fails, and any error from checkpoint slicing,
/// restoring, replay verification or checkpointing.
pub fn step_partitions(&mut self) -> Result<PartitionedStep, EngineControlPlaneError> {
    if self.partition_phases.is_empty() {
        return Err(EngineControlPlaneError::InvalidConfig(
            "at least one partition phase must be registered".to_string(),
        ));
    }
    let total_start = std::time::Instant::now();
    let source_step = self.runtime.step_index();
    let phase_base = self.runtime.phase_count();
    let mut partitions = self.partition_checkpoints()?;
    // Borrow the worker list instead of cloning it every step; it lives in a different
    // field than the phases mutated below.
    let workers = &self.config.partition_workers;
    let mut phase_reports = Vec::with_capacity(partitions.len() * self.partition_phases.len());
    for partition in &mut partitions {
        let backend = resolve_partition_backend(&partition.assignment, workers)?;
        let context = PartitionExecutionContext {
            partition_id: partition.assignment.partition_id,
            row_start: partition.assignment.row_start,
            row_end: partition.assignment.row_end,
            worker: partition.assignment.worker.clone(),
            backend,
            step_index: source_step,
        };
        let rows = partition.state.agent_count();
        for (local_phase_index, phase) in self.partition_phases.iter_mut().enumerate() {
            let phase_start = std::time::Instant::now();
            (phase.function)(&context, &mut partition.state.columns, rows).map_err(
                |message| EngineControlPlaneError::PartitionExecution {
                    partition: context.partition_id,
                    phase: phase.name,
                    message,
                },
            )?;
            phase_reports.push(PartitionPhaseReport {
                partition_id: context.partition_id,
                worker: context.worker.clone(),
                phase_index: phase_base + local_phase_index,
                name: phase.name,
                backend: context.backend.clone(),
                rows,
                elapsed_us: phase_start.elapsed().as_micros(),
            });
        }
        partition.step_index = source_step.saturating_add(1);
        partition.fingerprint = fingerprint_checkpoint(partition.step_index, &partition.state);
        partition.resident_bytes = partition.state.resident_bytes();
    }
    // `restore_partitions` repositions the verify cursor *past* the restored step index, as
    // if the caller had rewound to it. Here the restore merely installs the result of the
    // step being executed, whose fingerprint has not been verified yet, so the cursor must
    // keep pointing at the expected entry for this step. (Assumes `fingerprint_runtime`
    // reports the runtime's post-step index, which is what `restore` relies on as well.)
    let verify_cursor = self.verify_cursor;
    self.restore_partitions(&partitions)?;
    self.verify_cursor = verify_cursor;
    let total_us = total_start.elapsed().as_micros();
    self.update_partition_metrics(total_us, &phase_reports);
    let fingerprint = self.capture_replay_fingerprint()?;
    let checkpoint = self.maybe_checkpoint()?;
    self.refresh_live_metrics();
    Ok(PartitionedStep {
        step_index: source_step,
        agent_count: self.runtime.agent_count(),
        phases: phase_reports,
        fingerprint,
        checkpoint,
        total_us,
    })
}
/// Runs `steps` in-process partitioned steps, stopping at the first error.
pub fn run_partitions(
    &mut self,
    steps: usize,
) -> Result<Vec<PartitionedStep>, EngineControlPlaneError> {
    let mut completed = Vec::with_capacity(steps);
    for _step in 0..steps {
        let outcome = self.step_partitions()?;
        completed.push(outcome);
    }
    Ok(completed)
}
/// Runs the registered executor partition phases through `executor` and advances the
/// runtime by one step.
///
/// One task is built per partition, the executor's results are assembled back into the
/// runtime, metrics are updated, a replay fingerprint is captured when replay is enabled,
/// and the checkpoint policy is applied.
///
/// # Errors
///
/// Returns `InvalidConfig` when no executor partition phase is registered, the executor's
/// own error, `ResultCountMismatch` when the executor returns a different number of results
/// than tasks, and any error from task building, restoring, replay verification or
/// checkpointing.
pub fn step_partitions_with_executor(
    &mut self,
    executor: &mut impl PartitionExecutor,
) -> Result<PartitionedStep, EngineControlPlaneError> {
    if self.executable_partition_phases.is_empty() {
        return Err(EngineControlPlaneError::InvalidConfig(
            "at least one executor partition phase must be registered".to_string(),
        ));
    }
    let total_start = std::time::Instant::now();
    let source_step = self.runtime.step_index();
    // Executor phases are numbered after the runtime phases and the in-process phases.
    let phase_base = self.runtime.phase_count() + self.partition_phases.len();
    let tasks = self.build_partition_tasks(source_step, phase_base)?;
    let expected_results = tasks.len();
    let results = executor.execute(tasks)?;
    if results.len() != expected_results {
        return Err(PartitionExecutorError::ResultCountMismatch {
            expected: expected_results,
            actual: results.len(),
        }
        .into());
    }
    let mut partitions = Vec::with_capacity(results.len());
    let mut phase_reports = Vec::new();
    for result in results {
        phase_reports.extend(result.phases);
        partitions.push(result.checkpoint);
    }
    // `restore_partitions` repositions the verify cursor *past* the restored step index, as
    // if the caller had rewound to it. Here the restore merely installs the result of the
    // step being executed, whose fingerprint has not been verified yet, so the cursor must
    // keep pointing at the expected entry for this step. (Assumes `fingerprint_runtime`
    // reports the runtime's post-step index, which is what `restore` relies on as well.)
    let verify_cursor = self.verify_cursor;
    self.restore_partitions(&partitions)?;
    self.verify_cursor = verify_cursor;
    let total_us = total_start.elapsed().as_micros();
    self.update_partition_metrics(total_us, &phase_reports);
    let fingerprint = self.capture_replay_fingerprint()?;
    let checkpoint = self.maybe_checkpoint()?;
    self.refresh_live_metrics();
    Ok(PartitionedStep {
        step_index: source_step,
        agent_count: self.runtime.agent_count(),
        phases: phase_reports,
        fingerprint,
        checkpoint,
        total_us,
    })
}
/// Runs `steps` executor-driven partitioned steps, stopping at the first error.
pub fn run_partitions_with_executor(
    &mut self,
    steps: usize,
    executor: &mut impl PartitionExecutor,
) -> Result<Vec<PartitionedStep>, EngineControlPlaneError> {
    let mut completed = Vec::with_capacity(steps);
    for _step in 0..steps {
        let outcome = self.step_partitions_with_executor(executor)?;
        completed.push(outcome);
    }
    Ok(completed)
}
/// Splits the current runtime state into one checkpoint per partition.
///
/// The runtime is synced to the host first and the partition plan is validated against
/// the current agent count. Each partition checkpoint carries the current step index and
/// a fingerprint of its slice.
pub fn partition_checkpoints(
    &mut self,
) -> Result<Vec<PartitionCheckpoint>, EngineControlPlaneError> {
    self.runtime.sync_to_host()?;
    let agent_count = self.runtime.agent_count();
    self.config.partition_plan.validate(agent_count)?;
    let source = self.runtime.device_store().checkpoint();
    let step_index = self.runtime.step_index();
    materialized_assignments(&self.config.partition_plan, source.agent_count())
        .into_iter()
        .map(
            |assignment| -> Result<PartitionCheckpoint, EngineControlPlaneError> {
                let state =
                    slice_checkpoint(&source, assignment.row_start, assignment.row_end)?;
                Ok(PartitionCheckpoint {
                    fingerprint: fingerprint_checkpoint(step_index, &state),
                    resident_bytes: state.resident_bytes(),
                    assignment,
                    step_index,
                    state,
                })
            },
        )
        .collect()
}
/// Assembles `partitions` into a full state and installs it in the runtime.
///
/// The runtime's step index becomes the step index of the first partition. Metrics and
/// the verify cursor are updated, then the partition plan is re-validated against the
/// restored agent count.
pub fn restore_partitions(
    &mut self,
    partitions: &[PartitionCheckpoint],
) -> Result<(), EngineControlPlaneError> {
    let assembled = assemble_partition_checkpoints(partitions, &self.config.partition_plan)?;
    // Indexing is safe: assembling an empty slice has already returned an error above.
    let restored_step = partitions[0].step_index;
    let store = DeviceSoaStore::from_checkpoint(assembled)?;
    self.runtime.restore_device_store(store, restored_step);
    let agent_count = self.runtime.agent_count();
    self.metrics.current_step = restored_step;
    self.metrics.agent_count = agent_count;
    self.metrics.resident_bytes = self.runtime.device_store().resident_bytes();
    self.verify_cursor = self.next_verify_cursor_after(restored_step);
    self.config.partition_plan.validate(agent_count)?;
    self.refresh_live_metrics();
    Ok(())
}
/// Advances the runtime by one step, then updates metrics, captures a replay fingerprint
/// when replay is enabled, and applies the checkpoint policy.
pub fn step(&mut self) -> Result<ControlledStep, EngineControlPlaneError> {
    let timing = self.runtime.step()?;
    self.update_metrics(&timing);
    let fingerprint = self.capture_replay_fingerprint()?;
    let checkpoint = self.maybe_checkpoint()?;
    self.refresh_live_metrics();
    let outcome = ControlledStep {
        timing,
        fingerprint,
        checkpoint,
    };
    Ok(outcome)
}
/// Runs `steps` steps, stopping at the first error.
pub fn run(&mut self, steps: usize) -> Result<Vec<ControlledStep>, EngineControlPlaneError> {
    let mut completed = Vec::with_capacity(steps);
    for _step in 0..steps {
        let outcome = self.step()?;
        completed.push(outcome);
    }
    Ok(completed)
}
/// Takes a checkpoint of the current runtime state, retains it, and returns a copy.
///
/// The runtime is synced to the host first. The checkpoint's sequence number is the
/// number of checkpoints created so far plus one.
pub fn checkpoint(
    &mut self,
    label: Option<String>,
) -> Result<EngineCheckpoint, EngineControlPlaneError> {
    self.runtime.sync_to_host()?;
    let fingerprint = fingerprint_runtime(&self.runtime);
    let state = self.runtime.device_store().checkpoint();
    let sequence = self.metrics.checkpoints_created + 1;
    let snapshot = EngineCheckpoint {
        sequence,
        label,
        step_index: self.runtime.step_index(),
        fingerprint,
        resident_bytes: state.resident_bytes(),
        state,
    };
    self.metrics.checkpoints_created = sequence;
    self.retain_checkpoint(snapshot.clone());
    self.refresh_live_metrics();
    Ok(snapshot)
}
pub fn restore(
&mut self,
checkpoint: &EngineCheckpoint,
) -> Result<(), EngineControlPlaneError> {
let store = DeviceSoaStore::from_checkpoint(checkpoint.state.clone())?;
self.runtime
.restore_device_store(store, checkpoint.step_index);
self.metrics.current_step = checkpoint.step_index;
self.metrics.agent_count = self.runtime.agent_count();
self.metrics.resident_bytes = self.runtime.device_store().resident_bytes();
self.verify_cursor = self.next_verify_cursor_after(checkpoint.step_index);
self.config
.partition_plan
.validate(self.runtime.agent_count())?;
Ok(())
}
/// Builds the starting metrics from the runtime's current state and the configured plan;
/// every counter not listed here starts at its default.
fn initial_metrics(runtime: &ColumnarRuntime, config: &ControlPlaneConfig) -> EngineMetrics {
    let partition_count = config.partition_plan.partition_count();
    let resident_bytes = runtime.device_store().resident_bytes();
    EngineMetrics {
        current_step: runtime.step_index(),
        agent_count: runtime.agent_count(),
        phase_count: runtime.phase_count(),
        resident_bytes,
        partition_count,
        ..EngineMetrics::default()
    }
}
/// Folds the timing of one runtime step into the metrics.
///
/// Per-phase metrics are keyed by phase index: an existing entry is accumulated, a new
/// one is created on first sight.
fn update_metrics(&mut self, timing: &crate::columnar_runtime::ColumnarStepTiming) {
    let current_step = self.runtime.step_index();
    let phase_count = self.runtime.phase_count();
    let metrics = &mut self.metrics;
    metrics.steps_completed += 1;
    metrics.current_step = current_step;
    metrics.agent_count = timing.agent_count;
    metrics.phase_count = phase_count;
    metrics.last_step_us = timing.total_us;
    metrics.total_step_us += timing.total_us;
    for phase in &timing.phases {
        let existing = metrics
            .phases
            .iter_mut()
            .find(|metric| metric.phase_index == phase.phase_index);
        match existing {
            Some(metric) => {
                metric.executions += 1;
                metric.last_elapsed_us = phase.elapsed_us;
                metric.total_elapsed_us += phase.elapsed_us;
            }
            None => metrics.phases.push(EnginePhaseMetric {
                phase_index: phase.phase_index,
                name: phase.name,
                backend: phase.backend,
                executions: 1,
                last_elapsed_us: phase.elapsed_us,
                total_elapsed_us: phase.elapsed_us,
            }),
        }
    }
}
/// Re-reads the gauges (step, agent count, phase count, resident bytes, partition count)
/// from the runtime and configuration.
fn refresh_live_metrics(&mut self) {
    let phase_count = self.runtime.phase_count() + self.partition_phase_count();
    let partition_count = self.config.partition_plan.partition_count();
    let resident_bytes = self.runtime.device_store().resident_bytes();
    let metrics = &mut self.metrics;
    metrics.current_step = self.runtime.step_index();
    metrics.agent_count = self.runtime.agent_count();
    metrics.phase_count = phase_count;
    metrics.resident_bytes = resident_bytes;
    metrics.partition_count = partition_count;
}
/// Folds one partitioned step into the metrics.
///
/// Every report adds its row count to `partition_rows_processed` and is accumulated into
/// the per-phase metric with the same phase index (created on first sight).
fn update_partition_metrics(&mut self, total_us: u128, reports: &[PartitionPhaseReport]) {
    let current_step = self.runtime.step_index();
    let agent_count = self.runtime.agent_count();
    let phase_count = self.runtime.phase_count() + self.partition_phase_count();
    let metrics = &mut self.metrics;
    metrics.steps_completed += 1;
    metrics.partition_steps_completed += 1;
    metrics.current_step = current_step;
    metrics.agent_count = agent_count;
    metrics.phase_count = phase_count;
    metrics.last_step_us = total_us;
    metrics.total_step_us += total_us;
    for report in reports {
        metrics.partition_rows_processed += report.rows as u64;
        let existing = metrics
            .phases
            .iter_mut()
            .find(|metric| metric.phase_index == report.phase_index);
        match existing {
            Some(metric) => {
                metric.executions += 1;
                metric.last_elapsed_us = report.elapsed_us;
                metric.total_elapsed_us += report.elapsed_us;
            }
            None => metrics.phases.push(EnginePhaseMetric {
                phase_index: report.phase_index,
                name: report.name,
                backend: columnar_backend_for_partition(&report.backend),
                executions: 1,
                last_elapsed_us: report.elapsed_us,
                total_elapsed_us: report.elapsed_us,
            }),
        }
    }
}
/// Builds one executor task per partition of the current runtime state.
///
/// Each task carries the partition's checkpoint, its resolved backend, `phase_base`, and
/// a clone of the registered executor phases.
fn build_partition_tasks(
    &mut self,
    source_step: u64,
    phase_base: usize,
) -> Result<Vec<PartitionTask>, EngineControlPlaneError> {
    let partitions = self.partition_checkpoints()?;
    let workers = &self.config.partition_workers;
    let phases = &self.executable_partition_phases;
    partitions
        .into_iter()
        .map(
            |checkpoint| -> Result<PartitionTask, EngineControlPlaneError> {
                let backend = resolve_partition_backend(&checkpoint.assignment, workers)?;
                let context = PartitionExecutionContext {
                    partition_id: checkpoint.assignment.partition_id,
                    row_start: checkpoint.assignment.row_start,
                    row_end: checkpoint.assignment.row_end,
                    worker: checkpoint.assignment.worker.clone(),
                    backend,
                    step_index: source_step,
                };
                Ok(PartitionTask {
                    context,
                    checkpoint,
                    phase_base,
                    phases: phases.clone(),
                })
            },
        )
        .collect()
}
/// Captures a fingerprint of the runtime according to the replay mode.
///
/// * `Off`: nothing is captured and `None` is returned.
/// * `Record`: the fingerprint is appended to the recorded replay.
/// * `Verify`: the fingerprint must match the expected entry at the verify cursor (same
///   step index, same value); on success the cursor advances and the fingerprint is
///   appended to the recorded replay.
fn capture_replay_fingerprint(
    &mut self,
) -> Result<Option<ReplayFingerprint>, EngineControlPlaneError> {
    let expected = match &self.config.replay_mode {
        ReplayMode::Off => return Ok(None),
        ReplayMode::Record => None,
        ReplayMode::Verify { expected } => Some(expected),
    };
    self.runtime.sync_to_host()?;
    let fingerprint = fingerprint_runtime(&self.runtime);
    if let Some(expected) = expected {
        let Some(wanted) = expected.get(self.verify_cursor) else {
            return Err(EngineControlPlaneError::ReplayExhausted {
                step_index: fingerprint.step_index,
            });
        };
        if wanted.step_index != fingerprint.step_index {
            return Err(EngineControlPlaneError::ReplayStepMismatch {
                expected_step: wanted.step_index,
                actual_step: fingerprint.step_index,
            });
        }
        if wanted.value != fingerprint.value {
            return Err(EngineControlPlaneError::ReplayDivergence {
                step_index: fingerprint.step_index,
                expected: wanted.value,
                actual: fingerprint.value,
            });
        }
        self.verify_cursor += 1;
    }
    self.recorded_replay.push(fingerprint);
    self.metrics.replay_fingerprints += 1;
    Ok(Some(fingerprint))
}
/// Takes an unlabeled checkpoint when the policy says one is due.
///
/// With `EverySteps`, a checkpoint is due when the runtime's step index is greater than
/// zero and a multiple of the interval.
fn maybe_checkpoint(&mut self) -> Result<Option<EngineCheckpoint>, EngineControlPlaneError> {
    let interval_steps = match self.config.checkpoint_policy {
        CheckpointPolicy::Disabled => return Ok(None),
        CheckpointPolicy::EverySteps { interval_steps } => interval_steps,
    };
    let step_index = self.runtime.step_index();
    let due = step_index > 0 && step_index.is_multiple_of(interval_steps);
    if due {
        self.checkpoint(None).map(Some)
    } else {
        Ok(None)
    }
}
/// Appends `checkpoint` to the retained queue, then drops the oldest entries until the
/// queue is within `max_retained_checkpoints`.
fn retain_checkpoint(&mut self, checkpoint: EngineCheckpoint) {
    self.retained_checkpoints.push_back(checkpoint);
    let limit = self.config.max_retained_checkpoints;
    let excess = self.retained_checkpoints.len().saturating_sub(limit);
    for _ in 0..excess {
        self.retained_checkpoints.pop_front();
    }
}
/// Index of the first expected fingerprint whose step index is greater than `step_index`
/// (the length of the expected list when there is none); always 0 outside `Verify` mode.
fn next_verify_cursor_after(&self, step_index: u64) -> usize {
    let expected = match &self.config.replay_mode {
        ReplayMode::Verify { expected } => expected,
        ReplayMode::Off | ReplayMode::Record => return 0,
    };
    // Counting the leading entries at or before `step_index` yields the index of the first
    // later entry, or the list length when every entry qualifies.
    expected
        .iter()
        .take_while(|fingerprint| fingerprint.step_index <= step_index)
        .count()
}
}
/// Checks that `config` is usable for a runtime with `agent_count` rows.
///
/// Rejects a zero checkpoint retention limit, a zero checkpoint interval, and a partition
/// plan that does not validate against `agent_count`.
fn validate_config(
    config: &ControlPlaneConfig,
    agent_count: usize,
) -> Result<(), EngineControlPlaneError> {
    if config.max_retained_checkpoints == 0 {
        return Err(EngineControlPlaneError::InvalidConfig(
            "max_retained_checkpoints must be greater than zero".to_string(),
        ));
    }
    if let CheckpointPolicy::EverySteps { interval_steps: 0 } = config.checkpoint_policy {
        return Err(EngineControlPlaneError::InvalidConfig(
            "checkpoint interval_steps must be greater than zero".to_string(),
        ));
    }
    config
        .partition_plan
        .validate(agent_count)
        .map_err(EngineControlPlaneError::Partition)
}
/// Resolves the backend for `assignment`.
///
/// An assignment without a worker runs on `LocalCpu`; otherwise the backend of the first
/// worker with a matching label is used, and an unknown label is an error.
fn resolve_partition_backend(
    assignment: &PartitionAssignment,
    workers: &[PartitionWorker],
) -> Result<PartitionExecutionBackend, EngineControlPlaneError> {
    match assignment.worker.as_deref() {
        None => Ok(PartitionExecutionBackend::LocalCpu),
        Some(label) => workers
            .iter()
            .find(|candidate| candidate.label.as_str() == label)
            .map(|candidate| candidate.backend.clone())
            .ok_or_else(|| EngineControlPlaneError::MissingPartitionWorker {
                partition: assignment.partition_id,
                worker: label.to_string(),
            }),
    }
}
/// Runs every phase of `task` in order on the task's partition state.
///
/// Produces one report per phase, with phase index `phase_base + position`, and finalises
/// the partition through `PartitionTaskResult::completed`. Stops at the first failing
/// phase.
fn execute_partition_task(
    task: PartitionTask,
) -> Result<PartitionTaskResult, PartitionExecutorError> {
    let (context, mut checkpoint, phase_base, phases) = task.into_parts();
    let rows = checkpoint.state.agent_count();
    let mut reports = Vec::with_capacity(phases.len());
    for (offset, phase) in phases.iter().enumerate() {
        let started = std::time::Instant::now();
        if let Err(message) = phase.run(&context, &mut checkpoint.state.columns, rows) {
            return Err(PartitionExecutorError::Phase {
                partition: context.partition_id,
                phase: phase.name(),
                message,
            });
        }
        reports.push(PartitionPhaseReport {
            partition_id: context.partition_id,
            worker: context.worker.clone(),
            phase_index: phase_base + offset,
            name: phase.name(),
            backend: context.backend.clone(),
            rows,
            elapsed_us: started.elapsed().as_micros(),
        });
    }
    let finished = PartitionTaskResult::completed(&context, checkpoint, reports);
    Ok(finished)
}
/// Maps a partition backend to the `ColumnarPhaseBackend` recorded in phase metrics.
///
/// `CudaDevice` maps to `CudaResident` only when the `cuda` feature is enabled; without
/// the feature, and for every other backend, the result is `Cpu`.
fn columnar_backend_for_partition(backend: &PartitionExecutionBackend) -> ColumnarPhaseBackend {
match backend {
#[cfg(feature = "cuda")]
PartitionExecutionBackend::CudaDevice { .. } => ColumnarPhaseBackend::CudaResident,
#[cfg(not(feature = "cuda"))]
PartitionExecutionBackend::CudaDevice { .. } => ColumnarPhaseBackend::Cpu,
PartitionExecutionBackend::LocalCpu | PartitionExecutionBackend::External { .. } => {
ColumnarPhaseBackend::Cpu
}
}
}
/// Checks that static row-range assignments tile `0..agent_count` exactly once.
///
/// Assignments are examined in slice order. Each must start where the previous one ended
/// (the first at row 0), must not extend past `agent_count`, and must hold at least one
/// row whenever there are agents to cover.
///
/// # Errors
///
/// * `CoverageMismatch` when `assignments` is empty, or when the ranges end before
///   `agent_count`.
/// * `NonContiguousRange` when an assignment does not start at the expected row.
/// * `RangeOutOfBounds` when an assignment ends past `agent_count`.
/// * `EmptyRange` when an assignment holds no rows while `agent_count > 0`. This includes
///   inverted ranges (`row_end < row_start`), whose length saturates to zero.
fn validate_static_assignments(
    assignments: &[PartitionAssignment],
    agent_count: usize,
) -> Result<(), PartitionPlanError> {
    if assignments.is_empty() {
        return Err(PartitionPlanError::CoverageMismatch {
            covered_rows: 0,
            agent_count,
        });
    }
    let mut expected_start = 0usize;
    for assignment in assignments {
        if assignment.row_start != expected_start {
            return Err(PartitionPlanError::NonContiguousRange {
                partition: assignment.partition_id,
                row_start: assignment.row_start,
                expected_start,
            });
        }
        if assignment.row_end > agent_count {
            return Err(PartitionPlanError::RangeOutOfBounds {
                partition: assignment.partition_id,
                row_end: assignment.row_end,
                agent_count,
            });
        }
        // `<=` rather than `==`: an inverted range would otherwise move
        // `expected_start` backwards and let the next assignment re-cover rows that an
        // earlier partition already owns.
        if assignment.row_end <= assignment.row_start && agent_count > 0 {
            return Err(PartitionPlanError::EmptyRange {
                partition: assignment.partition_id,
            });
        }
        expected_start = assignment.row_end;
    }
    if expected_start != agent_count {
        return Err(PartitionPlanError::CoverageMismatch {
            covered_rows: expected_start,
            agent_count,
        });
    }
    Ok(())
}
/// Produces the concrete list of assignments a plan implies for `agent_count` rows.
///
/// A single-process plan yields one worker-less assignment, partition 0, spanning
/// `0..agent_count`. A static row-range plan yields a copy of its stored assignments;
/// `agent_count` is not consulted in that case.
fn materialized_assignments(plan: &PartitionPlan, agent_count: usize) -> Vec<PartitionAssignment> {
    match plan.mode() {
        PartitionMode::StaticRowRanges => plan.assignments().iter().cloned().collect(),
        PartitionMode::SingleProcess => {
            let whole_population = PartitionAssignment {
                partition_id: PartitionId(0),
                row_start: 0,
                row_end: agent_count,
                worker: None,
            };
            vec![whole_population]
        }
    }
}
/// Copies rows `row_start..row_end` of a checkpoint into a new, smaller checkpoint.
///
/// Ids and every column are cut to the requested rows; column names and schema are
/// cloned unchanged.
///
/// # Errors
///
/// Returns `EngineControlPlaneError::PartitionCheckpoint` when the range is inverted,
/// when it ends past `checkpoint.agent_count()`, or when the id list or any column holds
/// fewer rows than the range needs. The last case is reported as an error instead of
/// panicking on an out-of-bounds slice.
fn slice_checkpoint(
    checkpoint: &DeviceSoaCheckpoint,
    row_start: usize,
    row_end: usize,
) -> Result<DeviceSoaCheckpoint, EngineControlPlaneError> {
    if row_start > row_end || row_end > checkpoint.agent_count() {
        return Err(EngineControlPlaneError::PartitionCheckpoint(format!(
            "invalid checkpoint slice {row_start}..{row_end} for {} rows",
            checkpoint.agent_count()
        )));
    }
    // Checked slicing: the guard above only consults `agent_count()`, so a malformed
    // checkpoint whose id list or columns are shorter than that must not panic here.
    let ids = checkpoint
        .ids
        .get(row_start..row_end)
        .ok_or_else(|| {
            EngineControlPlaneError::PartitionCheckpoint(format!(
                "checkpoint holds {} ids, cannot slice rows {row_start}..{row_end}",
                checkpoint.ids.len()
            ))
        })?
        .to_vec();
    let columns = checkpoint
        .columns
        .iter()
        .enumerate()
        .map(|(column_index, column)| {
            column
                .get(row_start..row_end)
                .map(|rows| rows.to_vec())
                .ok_or_else(|| {
                    EngineControlPlaneError::PartitionCheckpoint(format!(
                        "checkpoint column {column_index} holds {} rows, cannot slice rows {row_start}..{row_end}",
                        column.len()
                    ))
                })
        })
        .collect::<Result<Vec<_>, _>>()?;
    Ok(DeviceSoaCheckpoint {
        ids,
        columns,
        column_names: checkpoint.column_names.clone(),
        schema: checkpoint.schema.clone(),
    })
}
/// Stitches per-partition checkpoints back into one checkpoint covering every row.
///
/// Partitions are consumed in slice order. Before a partition's rows are appended it
/// must, in this order: share the first partition's step index, start at the row where
/// the previous partition ended, contain as many rows as its assignment spans, carry
/// the same column names and schema as the first partition, and carry a fingerprint
/// equal to one recomputed from its state. After all rows are appended, the partitions'
/// assignments must equal the assignments the active plan materializes for the total
/// row count.
///
/// # Errors
///
/// Returns `EngineControlPlaneError::PartitionCheckpoint` when `partitions` is empty or
/// when any of the checks above fails; the message names the offending partition where
/// one applies.
fn assemble_partition_checkpoints(
    partitions: &[PartitionCheckpoint],
    plan: &PartitionPlan,
) -> Result<DeviceSoaCheckpoint, EngineControlPlaneError> {
    let reference = match partitions.first() {
        Some(partition) => partition,
        None => {
            return Err(EngineControlPlaneError::PartitionCheckpoint(
                "at least one partition checkpoint is required".to_string(),
            ));
        }
    };
    let step_index = reference.step_index;
    let column_names = &reference.state.column_names;
    let schema = &reference.state.schema;
    let mut ids = Vec::new();
    let mut columns: Vec<Vec<f32>> = vec![Vec::new(); reference.state.num_columns()];
    for partition in partitions {
        let partition_id = partition.assignment.partition_id;
        let state = &partition.state;
        // The first failing check wins; later checks are not evaluated.
        let rejection = if partition.step_index != step_index {
            Some(format!(
                "partition {:?} captured step {}, expected {}",
                partition_id, partition.step_index, step_index
            ))
        } else if partition.assignment.row_start != ids.len() {
            Some(format!(
                "partition {:?} starts at row {}, expected {}",
                partition_id,
                partition.assignment.row_start,
                ids.len()
            ))
        } else if partition.assignment.len() != state.agent_count() {
            Some(format!(
                "partition {:?} has assignment length {} but checkpoint contains {} rows",
                partition_id,
                partition.assignment.len(),
                state.agent_count()
            ))
        } else if state.column_names != *column_names || state.schema != *schema {
            Some(format!(
                "partition {:?} schema does not match the first partition",
                partition_id
            ))
        } else if partition.fingerprint != fingerprint_checkpoint(partition.step_index, state) {
            Some(format!(
                "partition {:?} fingerprint mismatch",
                partition_id
            ))
        } else {
            None
        };
        if let Some(message) = rejection {
            return Err(EngineControlPlaneError::PartitionCheckpoint(message));
        }
        ids.extend(state.ids.iter().copied());
        for (merged, piece) in columns.iter_mut().zip(&state.columns) {
            merged.extend(piece.iter().copied());
        }
    }
    // The plan can only be materialized once the total row count is known.
    let expected = materialized_assignments(plan, ids.len());
    let follows_plan = expected
        .iter()
        .eq(partitions.iter().map(|partition| &partition.assignment));
    if !follows_plan {
        return Err(EngineControlPlaneError::PartitionCheckpoint(
            "partition checkpoints do not match the active partition plan".to_string(),
        ));
    }
    Ok(DeviceSoaCheckpoint {
        ids,
        columns,
        column_names: column_names.clone(),
        schema: schema.clone(),
    })
}
/// Computes a replay fingerprint of the runtime's current device store.
///
/// The hash starts at `FNV_OFFSET` and absorbs, in order: the step index, the agent
/// count, the column count, every id, every column name followed by a `0xff` separator
/// byte, and the bit pattern of every value of every column in column-index order.
/// The returned fingerprint pairs that hash with the runtime's step index.
fn fingerprint_runtime(runtime: &ColumnarRuntime) -> ReplayFingerprint {
    let step_index = runtime.step_index();
    let store = runtime.device_store();
    let column_total = store.num_columns();

    let mut digest = fnv_u64(FNV_OFFSET, step_index);
    digest = fnv_usize(digest, store.agent_count());
    digest = fnv_usize(digest, column_total);
    for id in store.ids() {
        digest = fnv_u64(digest, *id);
    }
    for name in store.column_names() {
        // The separator keeps adjacent names from hashing like their concatenation.
        digest = fnv_byte(fnv_bytes(digest, name.as_bytes()), 0xff);
    }
    for column_index in 0..column_total {
        for value in store.column(column_index) {
            digest = fnv_u32(digest, value.to_bits());
        }
    }
    ReplayFingerprint {
        step_index,
        value: digest,
    }
}
/// Computes a replay fingerprint of a checkpoint as captured at `step_index`.
///
/// The hash starts at `FNV_OFFSET` and absorbs, in order: the step index, the agent
/// count, the column count, every id, every column name followed by a `0xff` separator
/// byte, and the bit pattern of every value of every column in column order.
fn fingerprint_checkpoint(step_index: u64, checkpoint: &DeviceSoaCheckpoint) -> ReplayFingerprint {
    let header = fnv_usize(
        fnv_usize(fnv_u64(FNV_OFFSET, step_index), checkpoint.agent_count()),
        checkpoint.num_columns(),
    );
    let mut digest = checkpoint
        .ids
        .iter()
        .fold(header, |state, id| fnv_u64(state, *id));
    for name in &checkpoint.column_names {
        // The separator keeps adjacent names from hashing like their concatenation.
        digest = fnv_byte(fnv_bytes(digest, name.as_bytes()), 0xff);
    }
    let value = checkpoint
        .columns
        .iter()
        .flatten()
        .fold(digest, |state, sample| fnv_u32(state, sample.to_bits()));
    ReplayFingerprint { step_index, value }
}
/// Initial hash state for the fingerprint functions (the 64-bit FNV offset basis).
const FNV_OFFSET: u64 = 0xcbf29ce484222325;
/// Multiplier applied after each byte is XORed into the hash (the 64-bit FNV prime).
const FNV_PRIME: u64 = 0x100000001b3;
/// Absorbs one byte into a running hash: XOR the byte in, then multiply by `FNV_PRIME`
/// with wrapping arithmetic.
fn fnv_byte(hash: u64, byte: u8) -> u64 {
    let mixed = hash ^ u64::from(byte);
    mixed.wrapping_mul(FNV_PRIME)
}
/// Absorbs a byte slice into a running hash, one byte at a time in slice order.
///
/// An empty slice returns `hash` unchanged.
fn fnv_bytes(hash: u64, bytes: &[u8]) -> u64 {
    bytes
        .iter()
        .fold(hash, |state, &byte| fnv_byte(state, byte))
}
/// Absorbs a `u32` into a running hash as its four little-endian bytes.
fn fnv_u32(hash: u64, value: u32) -> u64 {
    let little_endian = value.to_le_bytes();
    fnv_bytes(hash, &little_endian)
}
/// Absorbs a `u64` into a running hash as its eight little-endian bytes.
fn fnv_u64(hash: u64, value: u64) -> u64 {
    let little_endian = value.to_le_bytes();
    fnv_bytes(hash, &little_endian)
}
/// Absorbs a `usize` into a running hash by casting it to `u64` and hashing that, so
/// the bytes fed to the hash do not depend on the platform's pointer width.
fn fnv_usize(hash: u64, value: usize) -> u64 {
    let widened = value as u64;
    fnv_u64(hash, widened)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten rows split three ways: the expected bounds are 0..4, 4..7 and 7..10, and the
    /// resulting plan must validate against ten agents.
    #[test]
    fn row_range_partitioning_covers_rows_once() {
        let plan = PartitionPlan::row_ranges(10, 3).unwrap();
        assert_eq!(plan.mode(), PartitionMode::StaticRowRanges);

        let expected_bounds = [(0usize, 4usize), (4, 7), (7, 10)];
        for (index, &(row_start, row_end)) in expected_bounds.iter().enumerate() {
            let assignment = &plan.assignments()[index];
            assert_eq!(assignment.row_start, row_start);
            assert_eq!(assignment.row_end, row_end);
        }

        plan.validate(10).unwrap();
    }
}