use super::{
network::{LayerConfig, NetworkConfig, SpikingNetwork},
neuron::{LIFNeuron, NeuronConfig},
SimTime, Spike, Vector,
};
use crate::graph::{DynamicGraph, VertexId};
use std::f64::consts::PI;
/// Configuration for [`TimeCrystalCPG`].
#[derive(Debug, Clone)]
pub struct CPGConfig {
    /// Number of phases, i.e. the number of oscillators and of phase-specific
    /// graph topologies.
    pub num_phases: usize,
    /// Oscillator frequency in Hz (converted to radians per millisecond by
    /// [`OscillatorNeuron::new`]).
    pub frequency: f64,
    /// Coupling strength between each oscillator and its two ring neighbours.
    pub coupling: f64,
    /// Relative tolerance on how far the active graph's min-cut estimate may
    /// deviate from the current phase's reference value before the crystal is
    /// repaired.
    pub stability_threshold: f64,
    /// Integration time step, in milliseconds (the unit implied by the Hz to
    /// rad/ms conversion of `frequency`).
    pub dt: f64,
    /// Not read anywhere in this module.
    /// NOTE(review): presumably reserved for threshold-based phase transitions —
    /// confirm before relying on it.
    pub transition_threshold: f64,
}

impl Default for CPGConfig {
    /// Four phases at 10 Hz with moderate coupling and a 1 ms time step.
    fn default() -> Self {
        Self {
            num_phases: 4,
            frequency: 10.0,
            coupling: 0.3,
            stability_threshold: 0.1,
            dt: 1.0,
            transition_threshold: 0.8,
        }
    }
}
/// Phase oscillator whose activity is `amplitude * cos(phase)`.
#[derive(Debug, Clone)]
pub struct OscillatorNeuron {
    /// Identifier (the phase index when created by `TimeCrystalCPG`).
    pub id: usize,
    /// Current phase in radians, kept in `[0, 2π)` (for finite values) by this
    /// type's methods.
    pub phase: f64,
    /// Natural angular frequency in radians per millisecond.
    pub omega: f64,
    /// Peak activity; scales `cos(phase)`.
    pub amplitude: f64,
    /// Cached `amplitude * cos(phase)`, refreshed whenever the phase changes.
    activity: f64,
}

impl OscillatorNeuron {
    /// Creates a unit-amplitude oscillator at `frequency_hz` whose phase starts at
    /// `phase_offset` radians (wrapped into `[0, 2π)`).
    pub fn new(id: usize, frequency_hz: f64, phase_offset: f64) -> Self {
        // Hz -> rad/ms: the simulation time step is expressed in milliseconds.
        let omega = 2.0 * PI * frequency_hz / 1000.0;
        let amplitude = 1.0;
        let phase = Self::wrap_phase(phase_offset);
        Self {
            id,
            phase,
            omega,
            amplitude,
            activity: amplitude * phase.cos(),
        }
    }

    /// Advances the phase by `(omega + coupling_input) * dt` (Euler step), wraps it
    /// into `[0, 2π)` and refreshes the activity.
    ///
    /// Wrapping is O(1) regardless of step size and terminates for non-finite
    /// phases (which become NaN instead of looping forever).
    pub fn integrate(&mut self, dt: f64, coupling_input: f64) {
        let d_phase = self.omega + coupling_input;
        let next_phase = self.phase + d_phase * dt;
        self.set_phase(next_phase);
    }

    /// Current activity, `amplitude * cos(phase)`.
    pub fn activity(&self) -> f64 {
        self.activity
    }

    /// Jumps to `phase` (wrapped into `[0, 2π)`) and refreshes the activity using
    /// the current amplitude, consistently with [`Self::integrate`].
    pub fn reset(&mut self, phase: f64) {
        self.set_phase(phase);
    }

    /// Stores a wrapped phase and recomputes the cached activity from it.
    fn set_phase(&mut self, phase: f64) {
        self.phase = Self::wrap_phase(phase);
        self.activity = self.amplitude * self.phase.cos();
    }

    /// Maps `phase` into `[0, 2π)`.
    fn wrap_phase(phase: f64) -> f64 {
        let full_turn = 2.0 * PI;
        let wrapped = phase.rem_euclid(full_turn);
        // `rem_euclid` may round up to exactly `full_turn` for tiny negative inputs.
        if wrapped >= full_turn {
            0.0
        } else {
            wrapped
        }
    }
}
/// A phase-specific variant of a graph, with a reference min-cut value and a
/// short list of high-degree entry vertices.
#[derive(Clone)]
pub struct PhaseTopology {
    /// Index of the phase this topology belongs to.
    pub phase_id: usize,
    /// The phase's graph.
    pub graph: DynamicGraph,
    /// Heuristic reference value computed by [`Self::from_graph`] as total edge
    /// weight divided by vertex count (0.0 for [`Self::new`]).
    pub expected_mincut: f64,
    /// Highest-degree vertices, recomputed by [`Self::update_entry_points`].
    entry_points: Vec<VertexId>,
}

impl PhaseTopology {
    /// Maximum number of vertices kept by [`Self::update_entry_points`].
    const MAX_ENTRY_POINTS: usize = 5;

    /// Creates an empty topology for `phase_id`.
    pub fn new(phase_id: usize) -> Self {
        Self {
            phase_id,
            graph: DynamicGraph::new(),
            expected_mincut: 0.0,
            entry_points: Vec::new(),
        }
    }

    /// Builds the topology for `phase_id` from a copy of `base`, rescaling every
    /// edge weight by `(|sin(phase_id·π/2)| + 0.5) · (1 + modulation)`. Even phases
    /// therefore get a factor of ≈0.5 and odd phases 1.5, before modulation.
    ///
    /// Entry points are computed immediately, so [`Self::entry_points`] is usable
    /// without a separate call to [`Self::update_entry_points`].
    pub fn from_graph(phase_id: usize, base: &DynamicGraph, modulation: f64) -> Self {
        // NOTE(review): weights are updated through a shared reference, so the
        // graph uses interior mutability. This assumes `clone()` makes an
        // independent copy; if it shares storage with `base`, `base` is rescaled
        // as well — confirm.
        let graph = base.clone();
        let phase_factor = (phase_id as f64 * PI / 2.0).sin().abs() + 0.5;
        for edge in graph.edges() {
            let new_weight = edge.weight * phase_factor * (1.0 + modulation);
            // Deliberately ignored: the edge was just enumerated from this graph.
            let _ = graph.update_edge_weight(edge.source, edge.target, new_weight);
        }
        let expected_mincut = graph.edges().iter().map(|e| e.weight).sum::<f64>()
            / graph.num_vertices().max(1) as f64;
        let mut topology = Self {
            phase_id,
            graph,
            expected_mincut,
            entry_points: Vec::new(),
        };
        topology.update_entry_points();
        topology
    }

    /// The highest-degree vertices, as of the last [`Self::update_entry_points`].
    pub fn entry_points(&self) -> &[VertexId] {
        &self.entry_points
    }

    /// Recomputes the entry points: up to five vertices of highest degree.
    pub fn update_entry_points(&mut self) {
        let mut ranked: Vec<_> = self
            .graph
            .vertices()
            .iter()
            .map(|&v| (v, self.graph.degree(v)))
            .collect();
        // Stable sort: vertices with equal degree keep the order of `vertices()`.
        ranked.sort_by_key(|&(_, degree)| std::cmp::Reverse(degree));
        self.entry_points = ranked
            .into_iter()
            .take(Self::MAX_ENTRY_POINTS)
            .map(|(v, _)| v)
            .collect();
    }

    /// Reference min-cut value (see the `expected_mincut` field).
    pub fn expected_mincut(&self) -> f64 {
        self.expected_mincut
    }
}
/// Central pattern generator that steps a graph through phase-specific
/// topologies, driven by a ring of coupled phase oscillators.
///
/// Every [`tick`](Self::tick) integrates the oscillators, makes the most active
/// one the current phase, swaps in that phase's topology on a phase change, and
/// repairs the oscillators/graph if the active graph has drifted away from the
/// current phase's topology.
pub struct TimeCrystalCPG {
    /// One oscillator per phase, initially spread evenly over one turn.
    oscillators: Vec<OscillatorNeuron>,
    /// `coupling[i][j]` is how strongly oscillator `j` pulls oscillator `i`.
    coupling: Vec<Vec<f64>>,
    /// Graph variant used while each phase is active.
    phase_topologies: Vec<PhaseTopology>,
    /// Index of the currently winning oscillator.
    current_phase: usize,
    config: CPGConfig,
    /// Accumulated simulation time (sum of `config.dt` over all ticks).
    time: SimTime,
    /// Most recent phases entered through a transition, oldest first.
    phase_history: Vec<usize>,
    /// Graph of the current phase.
    active_graph: DynamicGraph,
}

impl TimeCrystalCPG {
    /// Maximum number of transitions retained in `phase_history`.
    const MAX_PHASE_HISTORY: usize = 1000;

    /// Builds a CPG with `config.num_phases` oscillators coupled to their ring
    /// neighbours and one [`PhaseTopology`] per phase derived from `base_graph`.
    ///
    /// Starts in phase 0 with phase 0's topology as the active graph, which is the
    /// same state [`reset`](Self::reset) restores. With zero phases the base graph
    /// is used as-is.
    pub fn new(base_graph: DynamicGraph, config: CPGConfig) -> Self {
        let n = config.num_phases;
        let oscillators: Vec<_> = (0..n)
            .map(|i| OscillatorNeuron::new(i, config.frequency, Self::splay_phase(i, n)))
            .collect();

        // Nearest-neighbour ring coupling (for n <= 2 both neighbours coincide).
        let mut coupling = vec![vec![0.0; n]; n];
        for (i, row) in coupling.iter_mut().enumerate() {
            row[(i + n - 1) % n] = config.coupling;
            row[(i + 1) % n] = config.coupling;
        }

        let phase_topologies: Vec<_> = (0..n)
            .map(|i| {
                let modulation = 0.1 * i as f64;
                PhaseTopology::from_graph(i, &base_graph, modulation)
            })
            .collect();
        let active_graph = match phase_topologies.first() {
            Some(topology) => topology.graph.clone(),
            None => base_graph,
        };

        Self {
            oscillators,
            coupling,
            phase_topologies,
            current_phase: 0,
            config,
            time: 0.0,
            phase_history: Vec::new(),
            active_graph,
        }
    }

    /// Advances the simulation by one `config.dt` step.
    ///
    /// Returns `Some(new_phase)` when the most active oscillator changed, `None`
    /// otherwise.
    pub fn tick(&mut self) -> Option<usize> {
        let dt = self.config.dt;
        self.time += dt;

        let coupling_inputs = self.coupling_inputs();
        for (osc, &input) in self.oscillators.iter_mut().zip(&coupling_inputs) {
            osc.integrate(dt, input);
        }

        let winner = self.most_active_oscillator();
        let transition = if winner != self.current_phase {
            let old_phase = self.current_phase;
            self.transition_topology(old_phase, winner);
            self.current_phase = winner;
            self.phase_history.push(winner);
            if self.phase_history.len() > Self::MAX_PHASE_HISTORY {
                self.phase_history.remove(0);
            }
            Some(winner)
        } else {
            None
        };

        if self.active_graph_drifted() {
            self.repair_crystal();
        }
        transition
    }

    /// Kuramoto-style input for each oscillator `i`: `Σ_j K[i][j]·sin(θ_j − θ_i)`
    /// over all `j != i` with non-zero coupling.
    fn coupling_inputs(&self) -> Vec<f64> {
        self.oscillators
            .iter()
            .enumerate()
            .map(|(i, target)| {
                let mut input = 0.0;
                for (j, (source, &strength)) in
                    self.oscillators.iter().zip(&self.coupling[i]).enumerate()
                {
                    if i != j && strength != 0.0 {
                        input += strength * (source.phase - target.phase).sin();
                    }
                }
                input
            })
            .collect()
    }

    /// Index of the oscillator with the highest activity (0 if there are none).
    fn most_active_oscillator(&self) -> usize {
        self.oscillators
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| {
                a.activity()
                    .partial_cmp(&b.activity())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Whether the active graph's min-cut estimate deviates from the estimate of
    /// the current phase's topology by more than `stability_threshold` (relative).
    ///
    /// Both sides use the same estimator. `PhaseTopology::expected_mincut` is a
    /// weighted mean (total weight / vertex count) and is not comparable with the
    /// unweighted minimum degree. Mixing the two makes the check fire on ordinary
    /// graphs (e.g. 0.5 vs 2.0 for a unit-weight ring), which would reset the
    /// oscillators on every tick.
    fn active_graph_drifted(&self) -> bool {
        match self.phase_topologies.get(self.current_phase) {
            Some(topology) => {
                let reference = Self::min_degree(&topology.graph);
                let actual = self.estimate_mincut();
                (reference - actual).abs() > self.config.stability_threshold * reference
            }
            None => false,
        }
    }

    /// Makes the topology of phase `to` the active graph.
    fn transition_topology(&mut self, _from: usize, to: usize) {
        if let Some(to_topo) = self.phase_topologies.get(to) {
            self.active_graph = to_topo.graph.clone();
        }
    }

    /// Cheap upper bound on the active graph's min cut: its minimum vertex degree.
    fn estimate_mincut(&self) -> f64 {
        Self::min_degree(&self.active_graph)
    }

    /// Minimum unweighted vertex degree of `graph` (0.0 for an empty graph).
    fn min_degree(graph: &DynamicGraph) -> f64 {
        if graph.num_vertices() == 0 {
            return 0.0;
        }
        graph
            .vertices()
            .iter()
            .map(|&v| graph.degree(v) as f64)
            .fold(f64::INFINITY, f64::min)
    }

    /// Restores the evenly spread oscillator phases and the current phase's graph.
    fn repair_crystal(&mut self) {
        self.reset_oscillator_phases();
        if let Some(topology) = self.phase_topologies.get(self.current_phase) {
            self.active_graph = topology.graph.clone();
        }
    }

    /// Puts oscillator `i` back at phase `2π·i/n`.
    fn reset_oscillator_phases(&mut self) {
        let n = self.oscillators.len();
        for (i, osc) in self.oscillators.iter_mut().enumerate() {
            osc.reset(Self::splay_phase(i, n));
        }
    }

    /// Evenly spread ("splay") phase of oscillator `i` out of `n`: `2π·i/n`.
    fn splay_phase(i: usize, n: usize) -> f64 {
        2.0 * PI * i as f64 / n as f64
    }

    /// Index of the current phase.
    pub fn current_phase(&self) -> usize {
        self.current_phase
    }

    /// Current phase (radians) of every oscillator.
    pub fn phases(&self) -> Vec<f64> {
        self.oscillators.iter().map(|o| o.phase).collect()
    }

    /// Current activity of every oscillator.
    pub fn activities(&self) -> Vec<f64> {
        self.oscillators.iter().map(|o| o.activity()).collect()
    }

    /// Graph of the current phase.
    pub fn active_graph(&self) -> &DynamicGraph {
        &self.active_graph
    }

    /// Entry points of the current phase's topology (empty if there is none).
    pub fn phase_aware_entry_points(&self) -> Vec<VertexId> {
        self.phase_topologies
            .get(self.current_phase)
            .map(|t| t.entry_points().to_vec())
            .unwrap_or_default()
    }

    /// `true` when the last `num_phases` transitions repeat the `num_phases`
    /// transitions before them. Always `false` with zero phases or fewer than
    /// `2 * num_phases` recorded transitions.
    pub fn is_stable(&self) -> bool {
        let period = self.config.num_phases;
        let len = self.phase_history.len();
        if period == 0 || len < period * 2 {
            return false;
        }
        let window = &self.phase_history[len - 2 * period..];
        window[..period] == window[period..]
    }

    /// Smallest period `p` in `1..=num_phases` (and shorter than the recorded
    /// history) with which the whole transition history repeats. Returns 0 if
    /// there is none or fewer than 10 transitions are recorded.
    pub fn periodicity(&self) -> usize {
        let history = &self.phase_history;
        if history.len() < 10 {
            return 0;
        }
        // A period as long as the history would be vacuously periodic, and a
        // longer one would underflow `len - period`.
        let max_period = self.config.num_phases.min(history.len() - 1);
        (1..=max_period)
            .find(|&period| history.iter().zip(&history[period..]).all(|(a, b)| a == b))
            .unwrap_or(0)
    }

    /// Runs `duration / dt` ticks (truncated) and returns the phases entered.
    ///
    /// Returns an empty `Vec` when `dt` is zero, where the step count would be
    /// infinite.
    pub fn run(&mut self, duration: f64) -> Vec<usize> {
        let dt = self.config.dt;
        if dt == 0.0 {
            return Vec::new();
        }
        let steps = (duration / dt) as usize;
        (0..steps).filter_map(|_| self.tick()).collect()
    }

    /// Returns to the initial state: splay phases, phase 0, zero time, empty
    /// history and phase 0's topology as the active graph.
    pub fn reset(&mut self) {
        self.reset_oscillator_phases();
        self.current_phase = 0;
        self.time = 0.0;
        self.phase_history.clear();
        if let Some(topology) = self.phase_topologies.first() {
            self.active_graph = topology.graph.clone();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_oscillator_neuron() {
        let mut osc = OscillatorNeuron::new(0, 10.0, 0.0);
        let initial_activity = osc.activity();
        for _ in 0..100 {
            osc.integrate(1.0, 0.0);
        }
        assert!(osc.activity() != initial_activity || osc.phase != 0.0);
        // The phase is kept within one turn.
        assert!(osc.phase >= 0.0 && osc.phase < 2.0 * PI);
    }

    #[test]
    fn test_phase_topology() {
        let graph = DynamicGraph::new();
        graph.insert_edge(0, 1, 1.0).unwrap();
        graph.insert_edge(1, 2, 1.0).unwrap();
        let mut topology = PhaseTopology::from_graph(0, &graph, 0.1);
        assert_eq!(topology.phase_id, 0);
        assert!(topology.expected_mincut >= 0.0);
        assert_eq!(topology.expected_mincut(), topology.expected_mincut);

        // Vertex 1 is the only vertex of degree 2, so it ranks first.
        topology.update_entry_points();
        assert_eq!(topology.entry_points().first(), Some(&1));
    }

    #[test]
    fn test_time_crystal_cpg() {
        let graph = DynamicGraph::new();
        for i in 0..10 {
            graph.insert_edge(i, (i + 1) % 10, 1.0).unwrap();
        }
        let config = CPGConfig::default();
        let num_phases = config.num_phases;
        let mut cpg = TimeCrystalCPG::new(graph, config);
        let transitions = cpg.run(1000.0);
        assert!(cpg.time > 0.0);
        assert!(transitions.iter().all(|&phase| phase < num_phases));
        assert!(cpg.current_phase() < num_phases);
        assert_eq!(cpg.phases().len(), num_phases);
        assert_eq!(cpg.activities().len(), num_phases);

        cpg.reset();
        assert_eq!(cpg.current_phase(), 0);
        assert_eq!(cpg.time, 0.0);
    }

    #[test]
    fn test_phase_aware_entry() {
        let graph = DynamicGraph::new();
        for i in 0..5 {
            for j in (i + 1)..5 {
                graph.insert_edge(i, j, 1.0).unwrap();
            }
        }
        let mut config = CPGConfig::default();
        config.num_phases = 2;
        let cpg = TimeCrystalCPG::new(graph, config);
        let entry_points = cpg.phase_aware_entry_points();
        assert!(entry_points.len() <= 5);
    }
}