/// Flat buffer of the state vectors recorded for a single stage.
///
/// States are stored back to back in one `Vec<f64>`: state `i` occupies
/// `data[i * state_dimension..(i + 1) * state_dimension]`. New states are
/// pushed onto the end, so the front of the buffer always holds the oldest
/// states.
#[derive(Debug, Clone)]
pub struct StageStates {
    /// Concatenated state values; expected to hold `count * state_dimension` entries.
    data: Vec<f64>,
    /// Number of states currently stored.
    count: usize,
    /// Number of `f64` values that make up one state.
    state_dimension: usize,
}

impl StageStates {
    /// Creates an empty buffer with storage reserved for `capacity_states`
    /// states of `state_dimension` values each.
    #[must_use]
    pub fn new(state_dimension: usize, capacity_states: usize) -> Self {
        let reserved_values = capacity_states * state_dimension;
        Self {
            state_dimension,
            count: 0,
            data: Vec::with_capacity(reserved_values),
        }
    }

    /// Returns the number of states currently stored.
    #[must_use]
    pub fn count(&self) -> usize {
        self.count
    }

    /// Returns the number of `f64` values per state.
    #[must_use]
    pub fn state_dimension(&self) -> usize {
        self.state_dimension
    }

    /// Appends `total_fwd` states, given as one flat slice, to the end of the buffer.
    ///
    /// `gathered` must contain exactly `total_fwd * state_dimension` values.
    /// This is only verified in builds with debug assertions enabled.
    pub fn append(&mut self, gathered: &[f64], total_fwd: usize) {
        debug_assert_eq!(gathered.len(), total_fwd * self.state_dimension);
        self.data.extend_from_slice(gathered);
        self.count += total_fwd;
    }

    /// Returns all stored states as one flat slice, oldest first.
    ///
    /// # Panics
    ///
    /// Panics if `count * state_dimension` exceeds the length of the
    /// underlying storage.
    #[must_use]
    pub fn states(&self) -> &[f64] {
        let live_values = self.count * self.state_dimension;
        &self.data[..live_values]
    }

    /// Discards the oldest states so that at most `window_states` remain.
    ///
    /// Does nothing when the buffer already holds `window_states` states or fewer.
    pub fn trim_to_window(&mut self, window_states: usize) {
        let surplus_states = match self.count.checked_sub(window_states) {
            // Already within the window: nothing to discard.
            None | Some(0) => return,
            Some(surplus) => surplus,
        };
        let drain_len = surplus_states * self.state_dimension;
        debug_assert!(drain_len <= self.data.len());
        // The oldest states sit at the front, so removing the prefix keeps
        // the most recent `window_states` states.
        self.data.drain(..drain_len);
        self.count = window_states;
        debug_assert_eq!(self.data.len(), self.count * self.state_dimension);
    }
}
/// Collection of per-stage state buffers, one [`StageStates`] per stage.
#[derive(Debug, Clone)]
pub struct VisitedStatesArchive {
    /// One state buffer per stage, indexed by stage number.
    stages: Vec<StageStates>,
    /// Multiplier used to convert an iteration count into a state count
    /// (states contributed to each stage per iteration).
    total_forward_passes: usize,
}

impl VisitedStatesArchive {
    /// Upper bound on the number of states pre-reserved per stage at construction.
    const MAX_INITIAL_CAPACITY: usize = 4096;

    /// Converts an iteration count into a state count.
    ///
    /// An iteration count that does not fit in `usize` is clamped to
    /// `usize::MAX` first, and the multiplication saturates instead of
    /// overflowing.
    fn states_in(iterations: u64, total_forward_passes: usize) -> usize {
        usize::try_from(iterations)
            .unwrap_or(usize::MAX)
            .saturating_mul(total_forward_passes)
    }

    /// Creates an archive with `num_stages` empty buffers.
    ///
    /// Each buffer holds states of `state_dimension` values and pre-reserves
    /// room for `max_iterations * total_forward_passes` states, capped at
    /// [`Self::MAX_INITIAL_CAPACITY`] states.
    #[must_use]
    pub fn new(
        num_stages: usize,
        state_dimension: usize,
        max_iterations: u64,
        total_forward_passes: usize,
    ) -> Self {
        let capacity_per_stage = Self::states_in(max_iterations, total_forward_passes)
            .min(Self::MAX_INITIAL_CAPACITY);
        let stages =
            std::iter::repeat_with(|| StageStates::new(state_dimension, capacity_per_stage))
                .take(num_stages)
                .collect();
        Self {
            stages,
            total_forward_passes,
        }
    }

    /// Returns the number of stages in the archive.
    #[must_use]
    pub fn num_stages(&self) -> usize {
        self.stages.len()
    }

    /// Returns the buffer of the given stage.
    ///
    /// # Panics
    ///
    /// Panics if `stage` is not smaller than [`Self::num_stages`].
    #[must_use]
    pub fn stage(&self, stage: usize) -> &StageStates {
        &self.stages[stage]
    }

    /// Returns a mutable reference to the buffer of the given stage.
    ///
    /// # Panics
    ///
    /// Panics if `stage` is not smaller than [`Self::num_stages`].
    pub fn stage_mut(&mut self, stage: usize) -> &mut StageStates {
        &mut self.stages[stage]
    }

    /// Appends `total_fwd` states, given as one flat slice, to the buffer of `stage`.
    ///
    /// # Panics
    ///
    /// Panics if `stage` is not smaller than [`Self::num_stages`].
    pub fn archive_gathered_states(&mut self, stage: usize, gathered: &[f64], total_fwd: usize) {
        self.stage_mut(stage).append(gathered, total_fwd);
    }

    /// Returns every state stored for `stage` as one flat slice.
    ///
    /// # Panics
    ///
    /// Panics if `stage` is not smaller than [`Self::num_stages`].
    #[must_use]
    pub fn states_for_stage(&self, stage: usize) -> &[f64] {
        self.stage(stage).states()
    }

    /// Returns the number of states stored for `stage`.
    ///
    /// # Panics
    ///
    /// Panics if `stage` is not smaller than [`Self::num_stages`].
    #[must_use]
    pub fn count(&self, stage: usize) -> usize {
        self.stage(stage).count()
    }

    /// Trims every stage so that it keeps at most
    /// `window_iterations * total_forward_passes` states.
    pub fn trim_to_window(&mut self, window_iterations: u64) {
        let window_states = Self::states_in(window_iterations, self.total_forward_passes);
        self.stages
            .iter_mut()
            .for_each(|buffer| buffer.trim_to_window(window_states));
    }
}
#[cfg(test)]
mod tests {
    use super::{StageStates, VisitedStatesArchive};

    /// Builds a flat batch of `total_fwd` states of `state_dim` values each,
    /// filled with `base, base + 1, base + 2, ...`.
    // The index-to-float cast is exact for the small sizes used in these tests.
    #[allow(clippy::cast_precision_loss)]
    fn make_gathered(state_dim: usize, total_fwd: usize, base: f64) -> Vec<f64> {
        let value_count = total_fwd * state_dim;
        (0..value_count).map(|offset| base + offset as f64).collect()
    }

    /// A new buffer is empty but has its storage reserved.
    #[test]
    fn stage_states_new_preallocates() {
        let buffer = StageStates::new(4, 100);
        assert!(buffer.states().is_empty());
        assert_eq!(buffer.count(), 0);
        assert_eq!(buffer.state_dimension(), 4);
        assert!(buffer.data.capacity() >= 400);
    }

    /// One appended batch is returned unchanged.
    #[test]
    fn stage_states_append_single_batch() {
        let batch: [f64; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mut buffer = StageStates::new(2, 10);
        buffer.append(&batch, 3);
        assert_eq!(buffer.count(), 3);
        assert_eq!(buffer.states(), &batch);
    }

    /// Consecutive batches are concatenated in append order.
    #[test]
    fn stage_states_append_multiple_batches() {
        let mut buffer = StageStates::new(2, 10);
        buffer.append(&make_gathered(2, 3, 0.0), 3);
        buffer.append(&make_gathered(2, 2, 100.0), 2);
        assert_eq!(buffer.count(), 5);
        let expected: [f64; 10] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 100.0, 101.0, 102.0, 103.0];
        assert_eq!(buffer.states(), &expected);
    }

    /// An untouched buffer exposes an empty slice.
    #[test]
    fn stage_states_empty_states() {
        let buffer = StageStates::new(3, 50);
        assert_eq!(buffer.states(), &[] as &[f64]);
        assert_eq!(buffer.count(), 0);
    }

    /// A new archive has the requested number of empty stages.
    #[test]
    fn archive_new_creates_correct_stages() {
        let archive = VisitedStatesArchive::new(5, 4, 10, 20);
        assert_eq!(archive.num_stages(), 5);
        for stage in 0..5 {
            assert_eq!(archive.count(stage), 0);
            assert!(archive.states_for_stage(stage).is_empty());
        }
    }

    /// Archiving into one stage leaves the other stages untouched.
    #[test]
    fn archive_gathered_states_delegates() {
        let mut archive = VisitedStatesArchive::new(4, 3, 10, 10);
        archive.archive_gathered_states(2, &make_gathered(3, 3, 1.0), 3);
        assert_eq!(archive.count(2), 3);
        for untouched in [0, 1, 3] {
            assert_eq!(archive.count(untouched), 0);
        }
    }

    /// Repeated archiving into the same stage adds up.
    #[test]
    fn archive_accumulates_across_iterations() {
        let mut archive = VisitedStatesArchive::new(3, 2, 10, 5);
        for base in [0.0, 100.0] {
            archive.archive_gathered_states(1, &make_gathered(2, 5, base), 5);
        }
        assert_eq!(archive.count(1), 10);
    }

    /// The archive exposes a stage's states as one flat slice.
    #[test]
    fn archive_states_for_stage_returns_flat_slice() {
        let mut archive = VisitedStatesArchive::new(3, 2, 10, 10);
        let batch: [f64; 4] = [10.0, 20.0, 30.0, 40.0];
        archive.archive_gathered_states(1, &batch, 2);
        assert_eq!(archive.states_for_stage(1), &batch);
        assert!(archive.states_for_stage(0).is_empty());
        assert!(archive.states_for_stage(2).is_empty());
    }

    /// Trimming removes states from the front and keeps the newest ones.
    #[test]
    fn stage_states_trim_to_window_drops_oldest_states() {
        let (state_dim, total) = (2, 100);
        let mut buffer = StageStates::new(state_dim, total);
        buffer.append(&make_gathered(state_dim, total, 0.0), total);
        assert_eq!(buffer.count(), 100);
        assert_eq!(buffer.states().len(), 200);

        buffer.trim_to_window(30);

        assert_eq!(buffer.count(), 30);
        assert_eq!(buffer.states().len(), 60);
        let expected: Vec<f64> = (140..200).map(f64::from).collect();
        assert_eq!(buffer.states(), expected.as_slice());
    }

    /// A window larger than the stored count changes nothing.
    #[test]
    fn stage_states_trim_to_window_noop_when_count_below_window() {
        let (state_dim, total) = (2, 20);
        let mut buffer = StageStates::new(state_dim, total);
        buffer.append(&make_gathered(state_dim, total, 0.0), total);
        assert_eq!(buffer.count(), 20);
        let snapshot = buffer.states().to_vec();

        buffer.trim_to_window(50);

        assert_eq!(buffer.count(), 20);
        assert_eq!(buffer.states(), snapshot.as_slice());
    }

    /// A window equal to the stored count changes nothing.
    #[test]
    fn stage_states_trim_to_window_count_equals_window_is_noop() {
        let (state_dim, total) = (3, 7);
        let mut buffer = StageStates::new(state_dim, total);
        buffer.append(&make_gathered(state_dim, total, 10.0), total);
        let snapshot = buffer.states().to_vec();

        buffer.trim_to_window(7);

        assert_eq!(buffer.count(), 7);
        assert_eq!(buffer.states(), snapshot.as_slice());
    }

    /// A zero window empties the buffer.
    #[test]
    fn stage_states_trim_to_window_to_zero_clears_buffer() {
        let (state_dim, total) = (2, 5);
        let mut buffer = StageStates::new(state_dim, total);
        buffer.append(&make_gathered(state_dim, total, 0.0), total);

        buffer.trim_to_window(0);

        assert_eq!(buffer.count(), 0);
        assert!(buffer.states().is_empty());
    }

    /// Appending after a trim lands directly behind the retained states.
    #[test]
    fn stage_states_trim_then_append_preserves_data() {
        let state_dim = 2;
        let mut buffer = StageStates::new(state_dim, 100);
        buffer.append(&make_gathered(state_dim, 10, 0.0), 10);

        buffer.trim_to_window(4);
        assert_eq!(buffer.count(), 4);
        let retained: Vec<f64> = (12..20).map(f64::from).collect();
        assert_eq!(buffer.states(), retained.as_slice());

        buffer.append(&make_gathered(state_dim, 3, 100.0), 3);
        assert_eq!(buffer.count(), 7);
        let expected: Vec<f64> = retained
            .into_iter()
            .chain((100..106).map(f64::from))
            .collect();
        assert_eq!(buffer.states(), expected.as_slice());
    }

    /// The archive-level trim applies the same window to every stage.
    #[test]
    fn archive_trim_to_window_trims_each_stage() {
        let num_stages = 5;
        let state_dim = 2;
        let total_fwd = 10;
        let mut archive = VisitedStatesArchive::new(num_stages, state_dim, 10, total_fwd);
        for t in 0..num_stages {
            let stage_base = f64::from(i32::try_from(t).unwrap()) * 1000.0;
            for it in 0..10_i32 {
                let base = stage_base + f64::from(it) * 100.0;
                archive.archive_gathered_states(
                    t,
                    &make_gathered(state_dim, total_fwd, base),
                    total_fwd,
                );
            }
            assert_eq!(archive.count(t), 100);
        }

        archive.trim_to_window(3);

        for t in 0..num_stages {
            assert!(
                archive.count(t) <= 30,
                "stage {t} has count {}",
                archive.count(t)
            );
            assert_eq!(archive.count(t), 30);
            assert_eq!(archive.states_for_stage(t).len(), 30 * state_dim);
        }
    }

    /// Stages already inside the window are left untouched.
    #[test]
    fn archive_trim_to_window_noop_when_within_window() {
        let total_fwd = 10;
        let mut archive = VisitedStatesArchive::new(2, 2, 10, total_fwd);
        for it in 0..2_i32 {
            let batch = make_gathered(2, total_fwd, f64::from(it) * 100.0);
            for stage in 0..2 {
                archive.archive_gathered_states(stage, &batch, total_fwd);
            }
        }
        let snapshot_0 = archive.states_for_stage(0).to_vec();
        let snapshot_1 = archive.states_for_stage(1).to_vec();

        archive.trim_to_window(5);

        assert_eq!(archive.count(0), 20);
        assert_eq!(archive.count(1), 20);
        assert_eq!(archive.states_for_stage(0), snapshot_0.as_slice());
        assert_eq!(archive.states_for_stage(1), snapshot_1.as_slice());
    }

    /// Only the states of the most recent iterations survive a trim.
    #[test]
    fn archive_trim_to_window_retains_most_recent() {
        let state_dim = 2;
        let total_fwd = 5;
        let mut archive = VisitedStatesArchive::new(1, state_dim, 10, total_fwd);
        for it in 0..4_i32 {
            let batch = make_gathered(state_dim, total_fwd, f64::from(it) * 100.0);
            archive.archive_gathered_states(0, &batch, total_fwd);
        }
        assert_eq!(archive.count(0), 20);

        archive.trim_to_window(2);

        assert_eq!(archive.count(0), 10);
        let expected: Vec<f64> = (200..210).chain(300..310).map(f64::from).collect();
        assert_eq!(archive.states_for_stage(0), expected.as_slice());
    }

    /// A zero-iteration window empties every stage.
    #[test]
    fn archive_trim_to_window_zero_clears_all_stages() {
        let total_fwd = 4;
        let mut archive = VisitedStatesArchive::new(3, 2, 10, total_fwd);
        for stage in 0..3 {
            archive.archive_gathered_states(stage, &make_gathered(2, total_fwd, 0.0), total_fwd);
        }

        archive.trim_to_window(0);

        for stage in 0..3 {
            assert_eq!(archive.count(stage), 0);
            assert!(archive.states_for_stage(stage).is_empty());
        }
    }
}