use crate::variables::{VarId, core::Var};
use crate::variables::domain::{
sparse_set::SparseSetState,
};
/// Saved copy of one variable's domain, as recorded on a [`Trail`].
///
/// A snapshot is taken with [`VarTrail::save_snapshot`] and handed back to
/// [`VarTrail::restore_snapshot`] to put the domain back.
#[derive(Debug, Clone)]
pub enum DomainSnapshot {
    /// State of an integer (sparse-set) domain, as returned by its `save_state`.
    IntDomain(SparseSetState),
    /// Bounds of a floating-point interval domain.
    FloatDomain(FloatDomainState),
}
/// Saved bounds of a floating-point interval domain.
///
/// Both fields are plain `f64`s, so the state is `Copy` and can be compared
/// directly (useful when checking what a trail entry restores).
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct FloatDomainState {
    /// The interval's `min` at the time the snapshot was taken.
    pub min: f64,
    /// The interval's `max` at the time the snapshot was taken.
    pub max: f64,
}
/// One undo record: which variable changed and what its domain looked like
/// before the change.
#[derive(Debug, Clone)]
pub struct TrailEntry {
    /// The variable whose domain was modified.
    pub var_id: VarId,
    /// The domain state to restore for `var_id` when this entry is undone.
    pub old_state: DomainSnapshot,
}
/// Undo log of domain changes, partitioned by nested checkpoints.
///
/// [`Trail::push_checkpoint`] remembers how many entries exist at that moment;
/// [`Trail::pop_checkpoint`] hands back every entry recorded after it.
#[derive(Debug, Clone)]
pub struct Trail {
    /// Recorded changes, oldest first.
    entries: Vec<TrailEntry>,
    /// For each open checkpoint, the length of `entries` when it was pushed.
    checkpoints: Vec<usize>,
}
impl Trail {
    /// Number of entries preallocated by [`Trail::new`].
    const DEFAULT_ENTRIES_CAPACITY: usize = 1024;
    /// Number of checkpoints preallocated by [`Trail::new`].
    const DEFAULT_CHECKPOINTS_CAPACITY: usize = 64;

    /// Creates an empty trail with room for 1024 entries and 64 checkpoints
    /// before the first reallocation.
    pub fn new() -> Self {
        Self::with_capacity(
            Self::DEFAULT_ENTRIES_CAPACITY,
            Self::DEFAULT_CHECKPOINTS_CAPACITY,
        )
    }

    /// Creates an empty trail with the given preallocated capacities.
    ///
    /// * `entries_capacity` - number of [`TrailEntry`] slots to reserve.
    /// * `checkpoints_capacity` - number of checkpoint slots to reserve.
    pub fn with_capacity(entries_capacity: usize, checkpoints_capacity: usize) -> Self {
        Trail {
            entries: Vec::with_capacity(entries_capacity),
            checkpoints: Vec::with_capacity(checkpoints_capacity),
        }
    }

    /// Opens a new checkpoint at the current end of the trail.
    ///
    /// Returns the number of entries recorded before the checkpoint, i.e. the
    /// position that [`Trail::pop_checkpoint`] will rewind to.
    #[inline]
    pub fn push_checkpoint(&mut self) -> usize {
        let checkpoint = self.entries.len();
        self.checkpoints.push(checkpoint);
        checkpoint
    }

    /// Records that `var_id` is about to change and that `old_state` is the
    /// domain to restore when the enclosing checkpoint is popped.
    #[inline]
    pub fn push_change(&mut self, var_id: VarId, old_state: DomainSnapshot) {
        self.entries.push(TrailEntry { var_id, old_state });
    }

    /// Total number of entries currently on the trail (across all checkpoints).
    #[inline]
    pub fn level(&self) -> usize {
        self.entries.len()
    }

    /// Number of checkpoints currently open.
    #[inline]
    pub fn checkpoint_depth(&self) -> usize {
        self.checkpoints.len()
    }

    /// Closes the most recent checkpoint and removes every entry recorded
    /// since it was opened.
    ///
    /// The removed entries are yielded newest first, which is the order in
    /// which they should be undone. Returns `None` (leaving the trail
    /// untouched) when no checkpoint is open. Dropping the iterator before it
    /// is exhausted still removes the remaining entries (see `Vec::drain`).
    pub fn pop_checkpoint(&mut self) -> Option<impl Iterator<Item = TrailEntry> + '_> {
        let checkpoint = self.checkpoints.pop()?;
        // Checkpoints are only ever <= entries.len(): entries shrink solely via
        // this drain (which removes only entries past the popped checkpoint)
        // or `clear` (which also clears all checkpoints).
        Some(self.entries.drain(checkpoint..).rev())
    }

    /// Removes all entries and checkpoints, keeping the allocated capacity.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.checkpoints.clear();
    }

    /// Approximate memory footprint of the trail in bytes.
    ///
    /// Counts the reserved (not just used) capacity of both internal vectors
    /// plus the `Trail` struct itself. Sizes come from `size_of`, so the
    /// estimate tracks layout changes of `TrailEntry` and the target's pointer
    /// width. Any heap memory owned by individual snapshots is not included.
    pub fn memory_bytes(&self) -> usize {
        let entries_bytes = self.entries.capacity() * std::mem::size_of::<TrailEntry>();
        let checkpoints_bytes = self.checkpoints.capacity() * std::mem::size_of::<usize>();
        entries_bytes + checkpoints_bytes + std::mem::size_of::<Self>()
    }
}
impl Default for Trail {
fn default() -> Self {
Self::new()
}
}
/// Capture and restore a variable's domain as a [`DomainSnapshot`], for
/// recording on a [`Trail`].
pub trait VarTrail {
    /// Returns a copy of the variable's current domain.
    fn save_snapshot(&self) -> DomainSnapshot;

    /// Overwrites the variable's domain with a previously saved `snapshot`.
    fn restore_snapshot(&mut self, snapshot: &DomainSnapshot);
}
impl VarTrail for Var {
    /// Copies the variable's current domain into a [`DomainSnapshot`].
    ///
    /// Integer variables delegate to the sparse set's `save_state`; float
    /// variables record their `min` and `max` bounds.
    fn save_snapshot(&self) -> DomainSnapshot {
        match self {
            Var::VarI(sparse_set) => DomainSnapshot::IntDomain(sparse_set.save_state()),
            Var::VarF(interval) => {
                let bounds = FloatDomainState {
                    min: interval.min,
                    max: interval.max,
                };
                DomainSnapshot::FloatDomain(bounds)
            }
        }
    }

    /// Puts back a domain previously captured by [`VarTrail::save_snapshot`].
    ///
    /// A snapshot of the wrong kind (an int snapshot for a float variable or
    /// vice versa) trips a `debug_assert!` in debug builds and is ignored in
    /// release builds.
    fn restore_snapshot(&mut self, snapshot: &DomainSnapshot) {
        match (self, snapshot) {
            (Var::VarI(sparse_set), DomainSnapshot::IntDomain(saved)) => {
                sparse_set.restore_state(saved);
            }
            (Var::VarF(interval), DomainSnapshot::FloatDomain(saved)) => {
                interval.min = saved.min;
                interval.max = saved.max;
            }
            // Mismatches are listed explicitly instead of `_` so that adding a
            // variable or snapshot kind forces this match to be revisited.
            (Var::VarI(_), DomainSnapshot::FloatDomain(_))
            | (Var::VarF(_), DomainSnapshot::IntDomain(_)) => {
                debug_assert!(false, "Domain snapshot type mismatch");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::variables::core::Vars;
    use crate::variables::Val;

    /// Builds a float-domain snapshot with the given bounds.
    fn float_snapshot(min: f64, max: f64) -> DomainSnapshot {
        DomainSnapshot::FloatDomain(FloatDomainState { min, max })
    }

    /// Extracts the upper bound of a float snapshot; panics on an int snapshot.
    fn float_max(snapshot: &DomainSnapshot) -> f64 {
        match snapshot {
            DomainSnapshot::FloatDomain(state) => state.max,
            DomainSnapshot::IntDomain(_) => panic!("expected a float snapshot"),
        }
    }

    #[test]
    fn test_trail_basic() {
        let mut trail = Trail::new();
        assert_eq!(trail.level(), 0);
        assert_eq!(trail.checkpoint_depth(), 0);

        assert_eq!(trail.push_checkpoint(), 0);
        assert_eq!(trail.checkpoint_depth(), 1);

        trail.push_change(VarId::from_index(0), float_snapshot(0.0, 10.0));
        assert_eq!(trail.level(), 1);

        let changes: Vec<_> = trail.pop_checkpoint().unwrap().collect();
        assert_eq!(changes.len(), 1);
        assert_eq!(float_max(&changes[0].old_state), 10.0);
        assert_eq!(trail.level(), 0);
        assert_eq!(trail.checkpoint_depth(), 0);
    }

    #[test]
    fn test_trail_nested_checkpoints() {
        let mut trail = Trail::new();
        trail.push_checkpoint();
        trail.push_change(VarId::from_index(0), float_snapshot(0.0, 10.0));
        trail.push_checkpoint();
        trail.push_change(VarId::from_index(1), float_snapshot(0.0, 5.0));
        assert_eq!(trail.level(), 2);
        assert_eq!(trail.checkpoint_depth(), 2);

        // The inner checkpoint only owns the second change.
        let changes: Vec<_> = trail.pop_checkpoint().unwrap().collect();
        assert_eq!(changes.len(), 1);
        assert_eq!(float_max(&changes[0].old_state), 5.0);
        assert_eq!(trail.level(), 1);

        let changes: Vec<_> = trail.pop_checkpoint().unwrap().collect();
        assert_eq!(changes.len(), 1);
        assert_eq!(float_max(&changes[0].old_state), 10.0);
        assert_eq!(trail.level(), 0);
    }

    #[test]
    fn test_trail_pop_yields_newest_first() {
        let mut trail = Trail::new();
        trail.push_checkpoint();
        trail.push_change(VarId::from_index(0), float_snapshot(0.0, 1.0));
        trail.push_change(VarId::from_index(1), float_snapshot(0.0, 2.0));
        trail.push_change(VarId::from_index(2), float_snapshot(0.0, 3.0));

        let maxes: Vec<f64> = trail
            .pop_checkpoint()
            .unwrap()
            .map(|entry| float_max(&entry.old_state))
            .collect();
        assert_eq!(maxes, vec![3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_trail_pop_without_checkpoint() {
        let mut trail = Trail::new();
        trail.push_change(VarId::from_index(0), float_snapshot(0.0, 1.0));
        assert!(trail.pop_checkpoint().is_none());
        // Entries outside any checkpoint are left in place.
        assert_eq!(trail.level(), 1);
    }

    #[test]
    fn test_trail_clear() {
        let mut trail = Trail::new();
        trail.push_checkpoint();
        trail.push_change(VarId::from_index(0), float_snapshot(0.0, 1.0));
        trail.clear();
        assert_eq!(trail.level(), 0);
        assert_eq!(trail.checkpoint_depth(), 0);
        assert!(trail.pop_checkpoint().is_none());
    }

    #[test]
    fn test_var_snapshot_restore() {
        let mut vars = Vars::new();
        let var_id = vars.new_var_with_bounds(Val::ValI(0), Val::ValI(10));
        let var = &mut vars[var_id];
        let snapshot = var.save_snapshot();

        // `match` with a panicking arm (rather than `if let`) so the test
        // fails loudly instead of passing vacuously on a non-integer var.
        match var {
            Var::VarI(sparse_set) => {
                sparse_set.remove(5);
                sparse_set.remove(6);
                assert_eq!(sparse_set.size(), 9);
            }
            Var::VarF(_) => panic!("expected an integer variable"),
        }

        var.restore_snapshot(&snapshot);

        match var {
            Var::VarI(sparse_set) => {
                assert_eq!(sparse_set.size(), 11);
                assert!(sparse_set.contains(5));
                assert!(sparse_set.contains(6));
            }
            Var::VarF(_) => panic!("expected an integer variable"),
        }
    }

    #[test]
    fn test_trail_memory_estimate() {
        let trail = Trail::with_capacity(100, 10);
        let mem = trail.memory_bytes();
        assert!((3000..=4000).contains(&mem), "unexpected estimate: {mem}");
    }
}