use super::adapter::CoherenceCommand;
use std::collections::{HashMap, HashSet};
/// State vector recorded for a single node.
#[derive(Debug, Clone)]
pub struct NodeState {
    /// Identifier of the node this entry describes.
    pub node_id: u64,
    /// The node's state vector.
    pub state: Vec<f32>,
    /// Applied index of the state machine at the time the state was last written.
    pub last_update: u64,
}
/// Current energy of the edge between `source` and `target`, together with a
/// short window of the most recent energy samples.
#[derive(Debug, Clone)]
pub struct EdgeEnergy {
    /// First endpoint of the edge.
    pub source: u64,
    /// Second endpoint of the edge.
    pub target: u64,
    /// Most recently recorded energy.
    pub energy: f32,
    /// Recent energy samples, oldest first.
    pub history: Vec<f32>,
}

impl EdgeEnergy {
    /// Number of samples `update` lets the history window hold.
    const HISTORY_WINDOW: usize = 10;

    /// Creates an edge whose history contains only the initial `energy`.
    pub fn new(source: u64, target: u64, energy: f32) -> Self {
        let history = vec![energy];
        Self {
            source,
            target,
            energy,
            history,
        }
    }

    /// Records `energy` as the current value and appends it to the history.
    ///
    /// If the window already holds ten or more samples, the single oldest
    /// sample is dropped first, so a window that starts within the limit
    /// never grows past ten entries.
    pub fn update(&mut self, energy: f32) {
        if self.history.len() >= Self::HISTORY_WINDOW {
            self.history.remove(0);
        }
        self.history.push(energy);
        self.energy = energy;
    }

    /// Returns the mean of the newer half of the history minus the mean of
    /// the older half.
    ///
    /// With fewer than two samples the trend is `0.0`. When the number of
    /// samples is odd, the newer half receives the extra sample.
    pub fn trend(&self) -> f32 {
        let samples = self.history.as_slice();
        if samples.len() < 2 {
            return 0.0;
        }
        let (older, newer) = samples.split_at(samples.len() / 2);
        let mean = |window: &[f32]| window.iter().sum::<f32>() / window.len() as f32;
        let older_mean = mean(older);
        let newer_mean = mean(newer);
        newer_mean - older_mean
    }

    /// Returns `true` when the population standard deviation of the history
    /// is strictly below `threshold`.
    ///
    /// A history with fewer than two samples is always considered stable.
    pub fn is_stable(&self, threshold: f32) -> bool {
        let count = self.history.len();
        if count < 2 {
            return true;
        }
        let divisor = count as f32;
        let mean: f32 = self.history.iter().sum::<f32>() / divisor;
        let variance: f32 = self
            .history
            .iter()
            .map(|&sample| (sample - mean).powi(2))
            .sum::<f32>()
            / divisor;
        variance.sqrt() < threshold
    }
}
/// A set of nodes that has been flagged as incoherent.
#[derive(Debug, Clone)]
pub struct IncoherentRegion {
    /// Identifier of the region.
    pub region_id: u64,
    /// Identifiers of the nodes that belong to the region.
    pub nodes: HashSet<u64>,
    /// Applied index of the state machine when the region was last marked.
    pub marked_at: u64,
    /// `true` while the region is marked incoherent; `false` once cleared.
    pub active: bool,
}
/// Record produced when a checkpoint command is applied.
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Applied index of the state machine when the checkpoint was taken.
    pub index: u64,
    /// Total energy as supplied by the checkpoint command.
    pub total_energy: f32,
    /// Timestamp as supplied by the checkpoint command.
    pub timestamp: u64,
    /// Number of edges tracked when the checkpoint was taken.
    pub num_edges: usize,
    /// Number of active incoherent regions when the checkpoint was taken.
    pub num_incoherent: usize,
}
/// State machine that tracks node states, edge energies, incoherent regions
/// and checkpoints as commands are applied to it.
#[derive(Debug)]
pub struct CoherenceStateMachine {
    /// Latest state per node, keyed by node id.
    node_states: HashMap<u64, NodeState>,
    /// Energy record per edge, keyed by the edge's `(u64, u64)` id.
    edge_energies: HashMap<(u64, u64), EdgeEnergy>,
    /// Every region that has ever been marked, keyed by region id; cleared
    /// regions stay in the map with `active == false`.
    incoherent_regions: HashMap<u64, IncoherentRegion>,
    /// Checkpoints in the order they were created, oldest first.
    checkpoints: Vec<Checkpoint>,
    /// Index passed to the most recent `apply` call (or restored from a snapshot).
    applied_index: u64,
    /// Maximum number of entries kept in a node's state vector.
    dimension: usize,
}
impl CoherenceStateMachine {
    /// Maximum number of checkpoints retained; the oldest is dropped beyond this.
    const MAX_CHECKPOINTS: usize = 100;

    /// Creates an empty state machine.
    ///
    /// `dimension` is the maximum number of entries kept in any node's state
    /// vector; longer vectors are truncated when they are applied.
    pub fn new(dimension: usize) -> Self {
        Self {
            node_states: HashMap::new(),
            edge_energies: HashMap::new(),
            incoherent_regions: HashMap::new(),
            checkpoints: Vec::new(),
            applied_index: 0,
            dimension,
        }
    }

    /// Applies `command` and returns a description of what changed.
    ///
    /// `index` is stored as the applied index before the command is
    /// dispatched, so it is recorded even when the command turns out to be a
    /// no-op (for example clearing an unknown region).
    pub fn apply(&mut self, command: &CoherenceCommand, index: u64) -> ApplyResult {
        self.applied_index = index;
        match command {
            CoherenceCommand::UpdateEnergy { edge_id, energy } => {
                self.apply_update_energy(*edge_id, *energy)
            }
            CoherenceCommand::SetNodeState { node_id, state } => {
                self.apply_set_node_state(*node_id, state.clone())
            }
            CoherenceCommand::Checkpoint {
                total_energy,
                timestamp,
            } => self.apply_checkpoint(*total_energy, *timestamp),
            CoherenceCommand::MarkIncoherent { region_id, nodes } => {
                self.apply_mark_incoherent(*region_id, nodes.clone())
            }
            CoherenceCommand::ClearIncoherent { region_id } => {
                self.apply_clear_incoherent(*region_id)
            }
        }
    }

    /// Sets the energy of `edge_id`, creating the edge if it is not yet tracked.
    ///
    /// A newly created edge starts from an energy of `0.0`, so `old_energy`
    /// is reported as `0.0` and the edge's history begins with that sample.
    fn apply_update_energy(&mut self, edge_id: (u64, u64), energy: f32) -> ApplyResult {
        let edge = self
            .edge_energies
            .entry(edge_id)
            .or_insert_with(|| EdgeEnergy::new(edge_id.0, edge_id.1, 0.0));
        let old_energy = edge.energy;
        edge.update(energy);
        ApplyResult::EnergyUpdated {
            edge_id,
            old_energy,
            new_energy: energy,
        }
    }

    /// Replaces the state of `node_id`, creating the node if needed.
    ///
    /// At most `dimension` leading entries of `state` are kept. A shorter
    /// vector is stored as given; it is not padded.
    fn apply_set_node_state(&mut self, node_id: u64, mut state: Vec<f32>) -> ApplyResult {
        state.truncate(self.dimension);
        let applied_index = self.applied_index;
        // The placeholder state is overwritten immediately below, so an empty
        // vector avoids allocating a buffer that would be thrown away.
        let node = self.node_states.entry(node_id).or_insert_with(|| NodeState {
            node_id,
            state: Vec::new(),
            last_update: 0,
        });
        node.state = state;
        node.last_update = applied_index;
        ApplyResult::NodeStateSet { node_id }
    }

    /// Appends a checkpoint describing the current state.
    ///
    /// `total_energy` and `timestamp` are stored exactly as supplied; the
    /// edge and active-region counts are taken from the state machine. Only
    /// the most recent `MAX_CHECKPOINTS` checkpoints are retained.
    fn apply_checkpoint(&mut self, total_energy: f32, timestamp: u64) -> ApplyResult {
        let checkpoint = Checkpoint {
            index: self.applied_index,
            total_energy,
            timestamp,
            num_edges: self.edge_energies.len(),
            num_incoherent: self.num_incoherent_regions(),
        };
        self.checkpoints.push(checkpoint.clone());
        if self.checkpoints.len() > Self::MAX_CHECKPOINTS {
            self.checkpoints.remove(0);
        }
        ApplyResult::CheckpointCreated { checkpoint }
    }

    /// Marks `region_id` as incoherent with exactly the given `nodes`.
    ///
    /// Any node set previously stored for the region is replaced, duplicate
    /// ids in `nodes` collapse into one, and the region becomes active.
    fn apply_mark_incoherent(&mut self, region_id: u64, nodes: Vec<u64>) -> ApplyResult {
        let applied_index = self.applied_index;
        let region = self
            .incoherent_regions
            .entry(region_id)
            .or_insert_with(|| IncoherentRegion {
                region_id,
                nodes: HashSet::new(),
                marked_at: applied_index,
                active: false,
            });
        region.nodes = nodes.into_iter().collect();
        region.marked_at = applied_index;
        region.active = true;
        ApplyResult::RegionMarkedIncoherent {
            region_id,
            node_count: region.nodes.len(),
        }
    }

    /// Deactivates `region_id` if it is known.
    ///
    /// The region and its node set stay in the map; only `active` is reset.
    /// Returns `RegionNotFound` when the region was never marked.
    fn apply_clear_incoherent(&mut self, region_id: u64) -> ApplyResult {
        if let Some(region) = self.incoherent_regions.get_mut(&region_id) {
            region.active = false;
            ApplyResult::RegionCleared { region_id }
        } else {
            ApplyResult::RegionNotFound { region_id }
        }
    }

    /// Returns the stored state of `node_id`, if any.
    pub fn get_node_state(&self, node_id: u64) -> Option<&NodeState> {
        self.node_states.get(&node_id)
    }

    /// Returns the current energy of `edge_id`, if the edge is tracked.
    pub fn get_edge_energy(&self, edge_id: (u64, u64)) -> Option<f32> {
        self.edge_energies.get(&edge_id).map(|e| e.energy)
    }

    /// Returns the sum of the current energies of all tracked edges.
    pub fn total_energy(&self) -> f32 {
        self.edge_energies.values().map(|e| e.energy).sum()
    }

    /// Returns the number of regions that are currently active.
    pub fn num_incoherent_regions(&self) -> usize {
        self.incoherent_regions.values().filter(|r| r.active).count()
    }

    /// Returns the union of the node sets of all active regions.
    pub fn incoherent_nodes(&self) -> HashSet<u64> {
        self.incoherent_regions
            .values()
            .filter(|r| r.active)
            .flat_map(|r| r.nodes.iter().copied())
            .collect()
    }

    /// Returns `true` when `node_id` belongs to at least one active region.
    pub fn is_node_incoherent(&self, node_id: u64) -> bool {
        self.incoherent_regions
            .values()
            .any(|r| r.active && r.nodes.contains(&node_id))
    }

    /// Returns the most recently created checkpoint, if any.
    pub fn latest_checkpoint(&self) -> Option<&Checkpoint> {
        self.checkpoints.last()
    }

    /// Returns aggregate counters describing the current state.
    pub fn summary(&self) -> StateSummary {
        StateSummary {
            applied_index: self.applied_index,
            num_nodes: self.node_states.len(),
            num_edges: self.edge_energies.len(),
            total_energy: self.total_energy(),
            num_incoherent_regions: self.num_incoherent_regions(),
            num_checkpoints: self.checkpoints.len(),
        }
    }

    /// Returns a deep copy of the applied index, node states, edge energies
    /// and incoherent regions.
    ///
    /// Checkpoints and `dimension` are not part of the snapshot.
    pub fn snapshot(&self) -> StateSnapshot {
        StateSnapshot {
            applied_index: self.applied_index,
            node_states: self.node_states.clone(),
            edge_energies: self.edge_energies.clone(),
            incoherent_regions: self.incoherent_regions.clone(),
        }
    }

    /// Replaces the applied index, node states, edge energies and incoherent
    /// regions with the contents of `snapshot`.
    ///
    /// NOTE(review): the snapshot carries no checkpoints or dimension, so any
    /// checkpoints already held by this instance, and its `dimension`, are
    /// left untouched. Confirm that keeping pre-restore checkpoints is intended.
    pub fn restore(&mut self, snapshot: StateSnapshot) {
        self.applied_index = snapshot.applied_index;
        self.node_states = snapshot.node_states;
        self.edge_energies = snapshot.edge_energies;
        self.incoherent_regions = snapshot.incoherent_regions;
    }
}
/// Outcome of applying a single command to the state machine.
#[derive(Debug, Clone)]
pub enum ApplyResult {
    /// An edge's energy was set.
    EnergyUpdated {
        /// Id of the edge that was updated.
        edge_id: (u64, u64),
        /// Energy the edge held before the update.
        old_energy: f32,
        /// Energy the edge holds after the update.
        new_energy: f32,
    },
    /// A node's state vector was replaced.
    NodeStateSet {
        /// Id of the node that was written.
        node_id: u64,
    },
    /// A checkpoint was recorded.
    CheckpointCreated {
        /// Copy of the checkpoint that was stored.
        checkpoint: Checkpoint,
    },
    /// A region was marked incoherent.
    RegionMarkedIncoherent {
        /// Id of the region that was marked.
        region_id: u64,
        /// Number of distinct nodes now in the region.
        node_count: usize,
    },
    /// A known region was deactivated.
    RegionCleared {
        /// Id of the region that was cleared.
        region_id: u64,
    },
    /// A clear was requested for a region that is not known.
    RegionNotFound {
        /// Id of the region that was requested.
        region_id: u64,
    },
}
/// Aggregate counters describing a state machine at one point in time.
#[derive(Debug, Clone)]
pub struct StateSummary {
    /// Index of the most recently applied command.
    pub applied_index: u64,
    /// Number of nodes with a stored state.
    pub num_nodes: usize,
    /// Number of tracked edges.
    pub num_edges: usize,
    /// Sum of the current energies of all tracked edges.
    pub total_energy: f32,
    /// Number of regions that are currently active.
    pub num_incoherent_regions: usize,
    /// Number of checkpoints currently retained.
    pub num_checkpoints: usize,
}
/// Owned copy of the restorable part of a state machine.
///
/// Checkpoints and the configured dimension are not included.
#[derive(Debug, Clone)]
pub struct StateSnapshot {
    /// Applied index at the time the snapshot was taken.
    pub applied_index: u64,
    /// Node states keyed by node id.
    pub node_states: HashMap<u64, NodeState>,
    /// Edge energy records keyed by edge id.
    pub edge_energies: HashMap<(u64, u64), EdgeEnergy>,
    /// Incoherent regions (active or not) keyed by region id.
    pub incoherent_regions: HashMap<u64, IncoherentRegion>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` energies.
    const TOLERANCE: f32 = 1e-6;

    /// Returns `true` when `actual` lies within `TOLERANCE` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// A fresh machine has no energy and no incoherent regions.
    #[test]
    fn test_state_machine_creation() {
        let machine = CoherenceStateMachine::new(64);
        assert_eq!(machine.total_energy(), 0.0);
        assert_eq!(machine.num_incoherent_regions(), 0);
    }

    /// Updating one edge is visible per edge and in the total.
    #[test]
    fn test_update_energy() {
        let mut machine = CoherenceStateMachine::new(64);
        let update = CoherenceCommand::UpdateEnergy {
            edge_id: (1, 2),
            energy: 0.5,
        };
        machine.apply(&update, 1);

        let stored = machine.get_edge_energy((1, 2)).unwrap();
        assert!(close(stored, 0.5));
        assert!(close(machine.total_energy(), 0.5));
    }

    /// A state vector of exactly `dimension` entries is stored unchanged.
    #[test]
    fn test_set_node_state() {
        let mut machine = CoherenceStateMachine::new(4);
        let set_state = CoherenceCommand::SetNodeState {
            node_id: 1,
            state: vec![1.0, 2.0, 3.0, 4.0],
        };
        machine.apply(&set_state, 1);

        let node = machine.get_node_state(1).unwrap();
        assert_eq!(node.state, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Marking a region makes exactly its nodes incoherent.
    #[test]
    fn test_mark_incoherent() {
        let mut machine = CoherenceStateMachine::new(64);
        let mark = CoherenceCommand::MarkIncoherent {
            region_id: 1,
            nodes: vec![10, 20, 30],
        };
        machine.apply(&mark, 1);

        assert_eq!(machine.num_incoherent_regions(), 1);
        assert!(machine.is_node_incoherent(10));
        assert!(machine.is_node_incoherent(20));
        assert!(!machine.is_node_incoherent(40));
    }

    /// Clearing a marked region removes it from the active count.
    #[test]
    fn test_clear_incoherent() {
        let mut machine = CoherenceStateMachine::new(64);
        let mark = CoherenceCommand::MarkIncoherent {
            region_id: 1,
            nodes: vec![10],
        };
        machine.apply(&mark, 1);
        assert_eq!(machine.num_incoherent_regions(), 1);

        let clear = CoherenceCommand::ClearIncoherent { region_id: 1 };
        machine.apply(&clear, 2);
        assert_eq!(machine.num_incoherent_regions(), 0);
    }

    /// A checkpoint stores the supplied energy and timestamp.
    #[test]
    fn test_checkpoint() {
        let mut machine = CoherenceStateMachine::new(64);
        let take_checkpoint = CoherenceCommand::Checkpoint {
            total_energy: 1.5,
            timestamp: 1000,
        };
        machine.apply(&take_checkpoint, 1);

        let latest = machine.latest_checkpoint().unwrap();
        assert!(close(latest.total_energy, 1.5));
        assert_eq!(latest.timestamp, 1000);
    }

    /// Steadily increasing samples yield a positive trend.
    #[test]
    fn test_edge_energy_trend() {
        let mut edge = EdgeEnergy::new(1, 2, 1.0);
        for sample in [1.1, 1.2, 1.3, 1.4] {
            edge.update(sample);
        }
        let trend = edge.trend();
        assert!(trend > 0.0, "Trend should be positive for increasing energy");
    }

    /// Edge energies and node states survive a snapshot/restore round trip.
    #[test]
    fn test_snapshot_restore() {
        let mut original = CoherenceStateMachine::new(64);
        let update = CoherenceCommand::UpdateEnergy {
            edge_id: (1, 2),
            energy: 0.5,
        };
        original.apply(&update, 1);
        let set_state = CoherenceCommand::SetNodeState {
            node_id: 1,
            state: vec![1.0; 64],
        };
        original.apply(&set_state, 2);

        let snapshot = original.snapshot();
        let mut restored = CoherenceStateMachine::new(64);
        restored.restore(snapshot);

        assert!(close(restored.get_edge_energy((1, 2)).unwrap(), 0.5));
        assert!(restored.get_node_state(1).is_some());
    }
}