#[cfg(feature = "biological")]
use ruvector_verified::{ProofEnvironment, prove_dim_eq, proof_store::create_attestation, ProofAttestation};
#[cfg(feature = "biological")]
use crate::config::BiologicalConfig;
#[cfg(feature = "biological")]
use crate::error::{GraphTransformerError, Result};
#[cfg(feature = "biological")]
#[derive(Debug, Clone)]
/// Settings for a power-iteration estimate of a square weight matrix's
/// dominant eigenvalue magnitude.
pub struct EffectiveOperator {
    /// Upper limit on the number of power-iteration steps.
    pub num_iterations: usize,
    /// Multiplier applied to the standard deviation of the per-step estimates
    /// when forming the conservative bound.
    pub safety_margin: f32,
    /// Not read by any method of this type in this file.
    pub layerwise: bool,
}

#[cfg(feature = "biological")]
impl Default for EffectiveOperator {
    /// 20 iterations, a safety margin of 3 standard deviations, layerwise on.
    fn default() -> Self {
        EffectiveOperator {
            num_iterations: 20,
            safety_margin: 3.0,
            layerwise: true,
        }
    }
}

#[cfg(feature = "biological")]
impl EffectiveOperator {
    /// Estimates the spectral radius of `weights` by power iteration.
    ///
    /// `weights` is read as an `n x n` matrix with `n = weights.len()`; row
    /// entries beyond column `n` are ignored and missing entries count as 0.
    ///
    /// Returns `(estimated, conservative_bound)` where `estimated` is the
    /// absolute Rayleigh quotient from the last completed step and
    /// `conservative_bound = estimated + safety_margin * std_dev` over all
    /// recorded per-step estimates. Returns `(0.0, 0.0)` when the matrix is
    /// empty or no estimate could be recorded.
    pub fn estimate_spectral_radius(&self, weights: &[Vec<f32>]) -> (f32, f32) {
        let size = weights.len();
        if size == 0 {
            return (0.0, 0.0);
        }

        // Deterministic, strictly positive start vector.
        let mut current: Vec<f32> = (0..size)
            .map(|idx| (idx as f32 + 1.0).sin().abs() + 0.1)
            .collect();
        let mut history: Vec<f32> = Vec::with_capacity(self.num_iterations);

        for _step in 0..self.num_iterations {
            // product = W * current (zip bounds each row to the first `size` columns).
            let mut product: Vec<f32> = weights
                .iter()
                .map(|row| {
                    row.iter()
                        .zip(current.iter())
                        .fold(0.0f32, |acc, (&a, &x)| acc + a * x)
                })
                .collect();

            let magnitude = product.iter().map(|c| c * c).sum::<f32>().sqrt();
            if magnitude < 1e-12 {
                // The iterate collapsed to zero; nothing more to learn.
                break;
            }

            // Rayleigh quotient of the pre-update vector.
            let numerator: f32 = product
                .iter()
                .zip(current.iter())
                .map(|(p, c)| p * c)
                .sum();
            let denominator: f32 = current.iter().map(|c| c * c).sum();
            if denominator > 1e-12 {
                history.push((numerator / denominator).abs());
            }

            product.iter_mut().for_each(|c| *c /= magnitude);
            current = product;
        }

        let estimated = match history.last() {
            Some(&latest) => latest,
            None => return (0.0, 0.0),
        };

        let count = history.len() as f32;
        let mean = history.iter().sum::<f32>() / count;
        let variance = history
            .iter()
            .map(|e| (e - mean).powi(2))
            .sum::<f32>()
            / count;
        let conservative_bound = estimated + self.safety_margin * variance.sqrt();
        (estimated, conservative_bound)
    }
}
#[cfg(feature = "biological")]
#[derive(Debug, Clone)]
/// Post-threshold inhibition applied to a population of neurons.
pub enum InhibitionStrategy {
    /// Leave spikes and potentials untouched.
    None,
    /// Allow at most `k` spiking neurons to stay active.
    WinnerTakeAll {
        /// Maximum number of surviving spikes.
        k: usize,
    },
    /// When any neuron spikes, scale every non-spiking potential by
    /// `1.0 - strength`.
    Lateral {
        /// Fraction of the potential removed from non-spiking neurons.
        strength: f32,
    },
    /// Suppress spikes when the firing rate exceeds
    /// `ei_ratio / (1.0 + ei_ratio)`.
    BalancedEI {
        /// Ratio used to derive the target firing rate.
        ei_ratio: f32,
        /// When true, only even-indexed neurons may be suppressed.
        dale_law: bool,
    },
}

#[cfg(feature = "biological")]
impl Default for InhibitionStrategy {
    /// No inhibition.
    fn default() -> Self {
        InhibitionStrategy::None
    }
}

#[cfg(feature = "biological")]
impl InhibitionStrategy {
    /// Applies the strategy in place.
    ///
    /// * `potentials` - membrane potentials, one per neuron.
    /// * `spikes` - spike flags, one per neuron; suppressed spikes are cleared.
    /// * `threshold` - firing threshold; suppressed neurons are set to a
    ///   fixed fraction of it (0.5 for winner-take-all, 0.3 / 0.4 for the two
    ///   balanced-E/I modes).
    ///
    /// The slices are expected to have equal length. If they do not, entries
    /// without a counterpart are simply left untouched instead of panicking.
    pub fn apply(&self, potentials: &mut [f32], spikes: &mut [bool], threshold: f32) {
        match self {
            InhibitionStrategy::None => {}
            InhibitionStrategy::WinnerTakeAll { k } => {
                let mut spiking = Self::spiking_indices(spikes);
                if spiking.len() > *k {
                    // Rank by potential (missing or NaN potentials rank last).
                    let rank = |idx: usize| match potentials.get(idx) {
                        Some(&p) if !p.is_nan() => p,
                        _ => f32::NEG_INFINITY,
                    };
                    // Strongest first. Ties fall back to the higher index so
                    // that equal potentials (e.g. all reset to the same value
                    // before this call) select the same survivors as a pure
                    // index ordering would.
                    spiking.sort_by(|&a, &b| {
                        rank(b)
                            .partial_cmp(&rank(a))
                            .unwrap_or(std::cmp::Ordering::Equal)
                            .then_with(|| b.cmp(&a))
                    });
                    for &idx in &spiking[*k..] {
                        Self::suppress(potentials, spikes, idx, threshold * 0.5);
                    }
                }
            }
            InhibitionStrategy::Lateral { strength } => {
                if spikes.iter().any(|&fired| fired) {
                    let damping = 1.0 - *strength;
                    for (potential, &fired) in potentials.iter_mut().zip(spikes.iter()) {
                        if !fired {
                            *potential *= damping;
                        }
                    }
                }
            }
            InhibitionStrategy::BalancedEI { ei_ratio, dale_law } => {
                let total = spikes.len();
                if total == 0 {
                    return;
                }
                let spike_count = spikes.iter().filter(|&&fired| fired).count();
                let firing_rate = spike_count as f32 / total as f32;
                let target_rate = *ei_ratio / (1.0 + *ei_ratio);
                if firing_rate > target_rate {
                    let suppression = (firing_rate - target_rate) / firing_rate.max(1e-6);
                    if *dale_law {
                        // Only even-indexed neurons are eligible, and only
                        // when more than half of the activity is excess.
                        if suppression > 0.5 {
                            for idx in (0..total).step_by(2) {
                                if spikes[idx] {
                                    Self::suppress(potentials, spikes, idx, threshold * 0.3);
                                }
                            }
                        }
                    } else {
                        let suppress_count =
                            ((spike_count as f32 * suppression) as usize).min(spike_count);
                        let spiking = Self::spiking_indices(spikes);
                        // Highest indices are suppressed first.
                        for &idx in spiking.iter().rev().take(suppress_count) {
                            Self::suppress(potentials, spikes, idx, threshold * 0.4);
                        }
                    }
                }
            }
        }
    }

    /// Indices of all currently spiking neurons, in ascending order.
    fn spiking_indices(spikes: &[bool]) -> Vec<usize> {
        spikes
            .iter()
            .enumerate()
            .filter(|(_, &fired)| fired)
            .map(|(idx, _)| idx)
            .collect()
    }

    /// Clears the spike at `idx` (which must index `spikes`) and, when a
    /// matching potential exists, sets it to `residual`.
    fn suppress(potentials: &mut [f32], spikes: &mut [bool], idx: usize, residual: f32) {
        spikes[idx] = false;
        if let Some(potential) = potentials.get_mut(idx) {
            *potential = residual;
        }
    }
}
#[cfg(feature = "biological")]
#[derive(Debug, Clone)]
/// Upper bound on the norm of a weight vector, optionally measured with a
/// diagonal Fisher weighting.
pub struct HebbianNormBound {
    /// Largest norm that still satisfies the bound.
    pub threshold: f32,
    /// When true and a Fisher diagonal is supplied, each squared weight is
    /// multiplied by its Fisher entry (floored at `1e-8`).
    pub diagonal_fisher: bool,
    /// Not read by any method of this type in this file.
    pub layerwise: bool,
}

#[cfg(feature = "biological")]
impl Default for HebbianNormBound {
    /// Threshold 5.0, plain Euclidean norm, layerwise on.
    fn default() -> Self {
        HebbianNormBound {
            threshold: 5.0,
            diagonal_fisher: false,
            layerwise: true,
        }
    }
}

#[cfg(feature = "biological")]
impl HebbianNormBound {
    /// Norm used by both public methods.
    ///
    /// Fisher weighting is applied only when `diagonal_fisher` is set *and*
    /// a diagonal is given; otherwise this is the Euclidean norm. With a
    /// diagonal, only the overlapping prefix of `weights` and the diagonal
    /// contributes.
    fn measured_norm(&self, weights: &[f32], fisher_diag: Option<&[f32]>) -> f32 {
        let squared: f32 = match fisher_diag {
            Some(fisher) if self.diagonal_fisher => weights
                .iter()
                .zip(fisher)
                .map(|(&w, &f)| w * w * f.max(1e-8))
                .sum(),
            _ => weights.iter().map(|w| w * w).sum(),
        };
        squared.sqrt()
    }

    /// Returns true when the norm of `weights` is at most `threshold`.
    pub fn is_satisfied(&self, weights: &[f32], fisher_diag: Option<&[f32]>) -> bool {
        self.measured_norm(weights, fisher_diag) <= self.threshold
    }

    /// Rescales `weights` uniformly so their norm equals `threshold` when it
    /// currently exceeds it.
    ///
    /// Returns true if a rescale happened, false if the weights were left
    /// unchanged.
    pub fn project(&self, weights: &mut [f32], fisher_diag: Option<&[f32]>) -> bool {
        let norm = self.measured_norm(weights, fisher_diag);
        let exceeds = norm > self.threshold;
        if exceeds {
            let scale = self.threshold / norm;
            weights.iter_mut().for_each(|w| *w *= scale);
        }
        exceeds
    }
}
#[cfg(feature = "biological")]
#[derive(Debug, Clone)]
/// Local learning rule used to compute a single weight change.
pub enum HebbianRule {
    /// `lr * (pre * post - post^2 * w)`.
    Oja,
    /// `lr * pre * post * (post - theta_init)`.
    BCM {
        /// Fixed modification threshold subtracted from `post`.
        theta_init: f32,
    },
    /// Exponential spike-timing rule; falls back to `lr * pre * post` when no
    /// spike-time difference is supplied.
    STDP {
        /// Amplitude used when the time difference is positive.
        a_plus: f32,
        /// Amplitude used when the time difference is zero or negative.
        a_minus: f32,
        /// Time constant of the exponential window.
        tau: f32,
    },
}

#[cfg(feature = "biological")]
impl Default for HebbianRule {
    /// Oja's rule.
    fn default() -> Self {
        Self::Oja
    }
}

#[cfg(feature = "biological")]
impl HebbianRule {
    /// Computes the weight change for one synapse.
    ///
    /// * `pre`, `post` - pre- and post-synaptic activity.
    /// * `current_weight` - present weight (only Oja's rule reads it).
    /// * `lr` - learning rate; scales every rule's result.
    /// * `dt_spike` - spike-time difference, consulted only by `STDP`.
    ///   A strictly positive value takes the `a_plus` branch; zero or
    ///   negative takes the `a_minus` branch.
    pub fn compute_update(
        &self,
        pre: f32,
        post: f32,
        current_weight: f32,
        lr: f32,
        dt_spike: Option<f32>,
    ) -> f32 {
        match (self, dt_spike) {
            (HebbianRule::Oja, _) => {
                let correlation = pre * post;
                let forgetting = post * post * current_weight;
                lr * (correlation - forgetting)
            }
            (HebbianRule::BCM { theta_init }, _) => lr * pre * post * (post - theta_init),
            (HebbianRule::STDP { a_plus, tau, .. }, Some(dt)) if dt > 0.0 => {
                a_plus * (-dt / tau).exp() * lr
            }
            (HebbianRule::STDP { a_minus, tau, .. }, Some(dt)) => {
                -a_minus * (dt / tau).exp() * lr
            }
            (HebbianRule::STDP { .. }, None) => lr * pre * post,
        }
    }
}
#[cfg(feature = "biological")]
#[derive(Debug, Clone)]
/// A proof attestation paired with the label of the scope it was issued for.
pub struct ScopeTransitionAttestation {
    /// Attestation returned by `create_attestation`.
    pub attestation: ProofAttestation,
    /// Caller-supplied scope label.
    pub scope: String,
}

#[cfg(feature = "biological")]
impl ScopeTransitionAttestation {
    /// Builds an attestation for `scope`.
    ///
    /// Uses the environment's allocated-term count (at least 1) as both
    /// sides of a `prove_dim_eq` call and wraps the resulting proof.
    ///
    /// # Errors
    /// Propagates any error returned by `prove_dim_eq`.
    pub fn create(env: &mut ProofEnvironment, scope: &str) -> Result<Self> {
        let witness_dim = env.terms_allocated().max(1);
        let proof_id = prove_dim_eq(env, witness_dim, witness_dim)?;
        Ok(Self {
            attestation: create_attestation(env, proof_id),
            scope: scope.to_owned(),
        })
    }

    /// Returns true when the attestation carries a non-zero verification
    /// timestamp and its verifier version equals `0x00_01_00_00`.
    pub fn is_valid(&self) -> bool {
        let record = &self.attestation;
        let has_timestamp = record.verification_timestamp_ns > 0;
        let version_matches = record.verifier_version == 0x00_01_00_00;
        has_timestamp && version_matches
    }
}
#[cfg(feature = "biological")]
/// Applies spike-timing-dependent weight updates to an edge list and rewires
/// the topology by pruning weak edges and growing edges between active nodes.
pub struct StdpEdgeUpdater {
    /// Edges whose absolute weight is below this value are pruned.
    pub prune_threshold: f32,
    /// Nodes whose activity exceeds this value may receive new edges.
    pub growth_threshold: f32,
    /// Inclusive `(min, max)` range every updated weight is clamped to.
    pub weight_bounds: (f32, f32),
    /// Cap on the number of edges added by one `rewire_topology` call.
    pub max_new_edges_per_epoch: usize,
    tau: f32,
    a_plus: f32,
    a_minus: f32,
    env: ProofEnvironment,
}

#[cfg(feature = "biological")]
impl StdpEdgeUpdater {
    /// Creates an updater with the given thresholds and bounds, a time
    /// constant of 20.0, amplitudes 0.01 / 0.012 and a fresh proof
    /// environment.
    pub fn new(
        prune_threshold: f32,
        growth_threshold: f32,
        weight_bounds: (f32, f32),
        max_new_edges_per_epoch: usize,
    ) -> Self {
        Self {
            prune_threshold,
            growth_threshold,
            weight_bounds,
            max_new_edges_per_epoch,
            tau: 20.0,
            a_plus: 0.01,
            a_minus: 0.012,
            env: ProofEnvironment::new(),
        }
    }

    /// Weight change for a spike-time difference `dt = t_post - t_pre`.
    ///
    /// A non-finite `dt` (for example when both endpoints carry a
    /// `-infinity` "never spiked" marker, making the difference NaN) yields
    /// no change instead of poisoning the weight with NaN.
    fn stdp_delta(&self, dt: f32) -> f32 {
        if !dt.is_finite() {
            return 0.0;
        }
        if dt > 0.0 {
            self.a_plus * (-dt / self.tau).exp()
        } else {
            -self.a_minus * (dt / self.tau).exp()
        }
    }

    /// Updates `weights[i]` for every `edges[i] = (pre, post)` from the
    /// difference `spike_times[post] - spike_times[pre]`, clamping to
    /// `weight_bounds`. Edges referencing a node without a spike time are
    /// skipped.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` when `weights` and `edges` differ in
    /// length, and propagates any error from `prove_dim_eq`.
    pub fn update_weights(
        &mut self,
        edges: &[(usize, usize)],
        weights: &mut Vec<f32>,
        spike_times: &[f32],
    ) -> Result<ProofAttestation> {
        if weights.len() != edges.len() {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: edges.len(),
                actual: weights.len(),
            });
        }

        let (lower, upper) = self.weight_bounds;
        for (weight, &(pre, post)) in weights.iter_mut().zip(edges) {
            if let (Some(&t_pre), Some(&t_post)) = (spike_times.get(pre), spike_times.get(post)) {
                let dw = self.stdp_delta(t_post - t_pre);
                *weight = (*weight + dw).clamp(lower, upper);
            }
        }

        // NOTE(review): `as u32` truncates for more than u32::MAX edges.
        let n = edges.len() as u32;
        let proof_id = prove_dim_eq(&mut self.env, n, n)?;
        Ok(create_attestation(&self.env, proof_id))
    }

    /// Prunes edges with `|weight| < prune_threshold`, then adds up to
    /// `max_new_edges_per_epoch` edges between nodes whose activity exceeds
    /// `growth_threshold` (most active pairs first), skipping pairs that are
    /// already connected in either direction. New edges start at the
    /// midpoint of `weight_bounds`.
    ///
    /// Returns `(pruned, grown, attestation)`; `edges` and `weights` are
    /// rewritten in place.
    ///
    /// # Errors
    /// Returns `ProofGateViolation` when `scope_attestation` is not valid,
    /// `DimensionMismatch` when `weights` and `edges` differ in length (so
    /// edges can no longer be dropped silently or indexed out of bounds),
    /// and propagates any error from `prove_dim_eq`.
    pub fn rewire_topology(
        &mut self,
        edges: &mut Vec<(usize, usize)>,
        weights: &mut Vec<f32>,
        num_nodes: usize,
        node_activity: &[f32],
        scope_attestation: &ScopeTransitionAttestation,
    ) -> Result<(Vec<(usize, usize)>, Vec<(usize, usize)>, ProofAttestation)> {
        if !scope_attestation.is_valid() {
            return Err(GraphTransformerError::ProofGateViolation(
                "invalid ScopeTransitionAttestation for topology rewiring".to_string(),
            ));
        }
        if weights.len() != edges.len() {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: edges.len(),
                actual: weights.len(),
            });
        }

        // Pruning pass: split edges into dropped and retained.
        let mut pruned = Vec::new();
        let mut kept_edges = Vec::with_capacity(edges.len());
        let mut kept_weights = Vec::with_capacity(weights.len());
        for (&edge, &w) in edges.iter().zip(weights.iter()) {
            if w.abs() < self.prune_threshold {
                pruned.push(edge);
            } else {
                kept_edges.push(edge);
                kept_weights.push(w);
            }
        }
        *edges = kept_edges;
        *weights = kept_weights;

        // Growth pass over sufficiently active nodes, most active first.
        let mut grown = Vec::new();
        let existing: std::collections::HashSet<(usize, usize)> =
            edges.iter().cloned().collect();
        let mut active_nodes: Vec<(usize, f32)> = node_activity
            .iter()
            .enumerate()
            .filter(|(_, &a)| a > self.growth_threshold)
            .map(|(i, &a)| (i, a))
            .collect();
        active_nodes.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        let initial_weight = (self.weight_bounds.0 + self.weight_bounds.1) / 2.0;
        let mut added = 0;
        'outer: for i in 0..active_nodes.len() {
            for j in (i + 1)..active_nodes.len() {
                if added >= self.max_new_edges_per_epoch {
                    break 'outer;
                }
                let (ni, _) = active_nodes[i];
                let (nj, _) = active_nodes[j];
                if ni < num_nodes
                    && nj < num_nodes
                    && !existing.contains(&(ni, nj))
                    && !existing.contains(&(nj, ni))
                {
                    edges.push((ni, nj));
                    weights.push(initial_weight);
                    grown.push((ni, nj));
                    added += 1;
                }
            }
        }

        // NOTE(review): `as u32` truncates for more than u32::MAX edges.
        let n = edges.len() as u32;
        let proof_id = prove_dim_eq(&mut self.env, n, n)?;
        let attestation = create_attestation(&self.env, proof_id);
        Ok((pruned, grown, attestation))
    }
}
#[cfg(feature = "biological")]
#[derive(Debug, Clone)]
/// How input features are distributed across dendritic branches.
pub enum BranchAssignment {
    /// Feature `i` goes to branch `i % num_branches`.
    RoundRobin,
    /// Contiguous runs of features go to consecutive branches.
    FeatureClustered,
    /// Selectable as its own mode; how it distributes features is decided by
    /// the code that consumes this enum.
    Learned,
}

#[cfg(feature = "biological")]
impl Default for BranchAssignment {
    /// Round-robin assignment.
    fn default() -> Self {
        Self::RoundRobin
    }
}
#[cfg(feature = "biological")]
/// Attention layer that splits each node's features across dendritic
/// branches and gates the output on whether any branch reaches a plateau.
pub struct DendriticAttention {
    num_branches: usize,
    dim: usize,
    /// Strategy used to distribute features across branches.
    pub branch_assignment: BranchAssignment,
    /// A branch whose activation exceeds this value triggers a plateau.
    pub plateau_threshold: f32,
    branch_weights: Vec<Vec<f32>>,
    env: ProofEnvironment,
}

#[cfg(feature = "biological")]
#[derive(Debug)]
/// Output of [`DendriticAttention::forward`].
pub struct DendriticResult {
    /// One transformed feature vector per input node.
    pub output: Vec<Vec<f32>>,
    /// Per node: whether any branch exceeded the plateau threshold.
    pub plateaus: Vec<bool>,
    /// Attestation for the pass; `None` when the input was empty.
    pub attestation: Option<ProofAttestation>,
}

#[cfg(feature = "biological")]
impl DendriticAttention {
    /// Creates a layer with `num_branches` branches over `dim` features,
    /// all branch weights initialised to 1.0.
    ///
    /// A `num_branches` of zero is treated as a single branch; it would
    /// otherwise cause a division by zero when sizing the branches.
    pub fn new(
        num_branches: usize,
        dim: usize,
        branch_assignment: BranchAssignment,
        plateau_threshold: f32,
    ) -> Self {
        let num_branches = num_branches.max(1);
        let features_per_branch = (dim + num_branches - 1) / num_branches;
        let branch_weights = (0..num_branches)
            .map(|_| vec![1.0f32; features_per_branch])
            .collect();
        Self {
            num_branches,
            dim,
            branch_assignment,
            plateau_threshold,
            branch_weights,
            env: ProofEnvironment::new(),
        }
    }

    /// Runs the layer over `node_features`.
    ///
    /// For every node the features are split across branches, each branch
    /// activation is the weighted sum of its inputs, and the node output is
    /// the input scaled by
    /// * `tanh(mean_activation) * 1.5` when any branch exceeds the plateau
    ///   threshold, or
    /// * `min(|mean_activation|, 1.0)` otherwise.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` when any row's length differs from the
    /// configured `dim` (every row is checked, not only the first), and
    /// propagates any error from `prove_dim_eq`.
    pub fn forward(
        &mut self,
        node_features: &[Vec<f32>],
    ) -> Result<DendriticResult> {
        if node_features.is_empty() {
            return Ok(DendriticResult {
                output: vec![],
                plateaus: vec![],
                attestation: None,
            });
        }
        if let Some(bad_row) = node_features.iter().find(|row| row.len() != self.dim) {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: self.dim,
                actual: bad_row.len(),
            });
        }

        // `num_branches` is at least 1 (enforced by `new`).
        let features_per_branch = (self.dim + self.num_branches - 1) / self.num_branches;
        let branch_count = self.num_branches as f32;
        let mut output = Vec::with_capacity(node_features.len());
        let mut plateaus = Vec::with_capacity(node_features.len());

        for features in node_features {
            let branch_inputs = self.assign_to_branches(features, features_per_branch);
            let branch_activations: Vec<f32> = branch_inputs
                .iter()
                .zip(&self.branch_weights)
                .map(|(inputs, branch_w)| {
                    inputs
                        .iter()
                        .zip(branch_w)
                        .map(|(&x, &w)| x * w)
                        .sum::<f32>()
                })
                .collect();

            let any_plateau = branch_activations
                .iter()
                .any(|&activation| activation > self.plateau_threshold);
            let total_activation: f32 = branch_activations.iter().sum();
            let mean_activation = total_activation / branch_count;

            let soma_output: Vec<f32> = if any_plateau {
                let scale = mean_activation.tanh();
                features.iter().map(|&x| x * scale * 1.5).collect()
            } else {
                let scale = mean_activation.abs().min(1.0);
                features.iter().map(|&x| x * scale).collect()
            };
            output.push(soma_output);
            plateaus.push(any_plateau);
        }

        // NOTE(review): `as u32` truncates for dims above u32::MAX.
        let dim_u32 = self.dim as u32;
        let proof_id = prove_dim_eq(&mut self.env, dim_u32, dim_u32)?;
        let attestation = Some(create_attestation(&self.env, proof_id));
        Ok(DendriticResult {
            output,
            plateaus,
            attestation,
        })
    }

    /// Splits `features` into `num_branches` vectors, each zero-padded to at
    /// least `features_per_branch` entries.
    ///
    /// `Learned` currently distributes features exactly like `RoundRobin`.
    fn assign_to_branches(&self, features: &[f32], features_per_branch: usize) -> Vec<Vec<f32>> {
        let mut branches: Vec<Vec<f32>> = match &self.branch_assignment {
            BranchAssignment::RoundRobin | BranchAssignment::Learned => {
                let mut branches =
                    vec![Vec::with_capacity(features_per_branch); self.num_branches];
                for (i, &f) in features.iter().enumerate() {
                    branches[i % self.num_branches].push(f);
                }
                branches
            }
            BranchAssignment::FeatureClustered => (0..self.num_branches)
                .map(|b| {
                    // Clamp both ends so branches past the input are empty.
                    let start = (b * features_per_branch).min(features.len());
                    let end = (start + features_per_branch).min(features.len());
                    features[start..end].to_vec()
                })
                .collect(),
        };
        for branch in &mut branches {
            if branch.len() < features_per_branch {
                branch.resize(features_per_branch, 0.0);
            }
        }
        branches
    }

    /// Number of dendritic branches.
    pub fn num_branches(&self) -> usize {
        self.num_branches
    }
}
#[cfg(feature = "biological")]
/// Leaky integrate-and-fire style attention over graph nodes with an
/// STDP weight update and optional inhibition.
pub struct SpikingGraphAttention {
    config: BiologicalConfig,
    dim: usize,
    membrane_potentials: Vec<f32>,
    last_spike_times: Vec<f32>,
    current_time: f32,
    env: ProofEnvironment,
    /// Inhibition applied after thresholding in every step.
    pub inhibition: InhibitionStrategy,
}

#[cfg(feature = "biological")]
#[derive(Debug)]
/// Output of [`SpikingGraphAttention::step`].
pub struct SpikingStepResult {
    /// Per-node output features.
    pub features: Vec<Vec<f32>>,
    /// Per-node spike flags after inhibition.
    pub spikes: Vec<bool>,
    /// Updated copy of the input weight matrix.
    pub weights: Vec<Vec<f32>>,
    /// Present only when every updated weight is within `max_weight`.
    pub attestation: Option<ProofAttestation>,
}

#[cfg(feature = "biological")]
impl SpikingGraphAttention {
    /// Creates a layer for `num_nodes` nodes with no inhibition.
    pub fn new(num_nodes: usize, dim: usize, config: BiologicalConfig) -> Self {
        Self::with_inhibition(num_nodes, dim, config, InhibitionStrategy::None)
    }

    /// Creates a layer for `num_nodes` nodes with the given inhibition.
    ///
    /// Potentials start at 0.0 and every node's last spike time starts at
    /// negative infinity ("never spiked").
    pub fn with_inhibition(
        num_nodes: usize,
        dim: usize,
        config: BiologicalConfig,
        inhibition: InhibitionStrategy,
    ) -> Self {
        Self {
            config,
            dim,
            membrane_potentials: vec![0.0; num_nodes],
            last_spike_times: vec![f32::NEG_INFINITY; num_nodes],
            current_time: 0.0,
            env: ProofEnvironment::new(),
            inhibition,
        }
    }

    /// Advances the simulation by one time unit.
    ///
    /// 1. Integrates each node's mean input feature into its potential.
    /// 2. Marks nodes at or above the threshold as spiking and resets them.
    /// 3. Applies the configured inhibition.
    /// 4. Applies an STDP update to `weights[pre][post]` for every
    ///    `(pre, post)` in `adjacency` that is in range, clamped to
    ///    `[-max_weight, max_weight]`.
    /// 5. Scales spiking nodes' features by the threshold and the others by
    ///    `min(|potential / threshold|, 1.0)`.
    ///
    /// # Errors
    /// Returns `Config` when the number of feature rows differs from the
    /// node count given at construction, and propagates any error from
    /// `prove_dim_eq`.
    pub fn step(
        &mut self,
        node_features: &[Vec<f32>],
        weights: &[Vec<f32>],
        adjacency: &[(usize, usize)],
    ) -> Result<SpikingStepResult> {
        let n = node_features.len();
        if n != self.membrane_potentials.len() {
            return Err(GraphTransformerError::Config(format!(
                "node count mismatch: expected {}, got {}",
                self.membrane_potentials.len(),
                n,
            )));
        }

        let dt = 1.0;
        self.current_time += dt;

        // Leaky integration of the mean input.
        let tau = self.config.tau_membrane;
        for (potential, features) in self.membrane_potentials.iter_mut().zip(node_features) {
            let input: f32 = features.iter().sum::<f32>() / self.dim as f32;
            *potential += (-*potential / tau + input) * dt;
        }

        // Threshold, reset, and record spike times.
        // NOTE(review): spike times are recorded before inhibition runs, so
        // a neuron whose spike is subsequently suppressed still has its last
        // spike time advanced — confirm this is intended.
        let threshold = self.config.threshold;
        let mut spikes = vec![false; n];
        for i in 0..n {
            if self.membrane_potentials[i] >= threshold {
                spikes[i] = true;
                self.membrane_potentials[i] = 0.0;
                self.last_spike_times[i] = self.current_time;
            }
        }

        self.inhibition
            .apply(&mut self.membrane_potentials, &mut spikes, threshold);

        // STDP on every in-range edge.
        let max_weight = self.config.max_weight;
        let mut new_weights = weights.to_vec();
        for &(pre, post) in adjacency {
            if pre >= n || post >= n {
                continue;
            }
            let dw = self.stdp_update(self.last_spike_times[post] - self.last_spike_times[pre]);
            if let Some(w) = new_weights.get_mut(pre).and_then(|row| row.get_mut(post)) {
                *w = (*w + dw).clamp(-max_weight, max_weight);
            }
        }

        let membrane = &self.membrane_potentials;
        let output_features: Vec<Vec<f32>> = node_features
            .iter()
            .enumerate()
            .map(|(i, features)| {
                if spikes[i] {
                    features.iter().map(|&x| x * threshold).collect()
                } else {
                    let attenuation = membrane[i] / threshold;
                    features
                        .iter()
                        .map(|&x| x * attenuation.abs().min(1.0))
                        .collect()
                }
            })
            .collect();

        // Only attest when every weight (including untouched ones) is bounded.
        let all_bounded = new_weights
            .iter()
            .all(|row| row.iter().all(|&w| w.abs() <= max_weight));
        let attestation = if all_bounded {
            // NOTE(review): `as u32` truncates for dims above u32::MAX.
            let dim_u32 = self.dim as u32;
            let proof_id = prove_dim_eq(&mut self.env, dim_u32, dim_u32)?;
            Some(create_attestation(&self.env, proof_id))
        } else {
            None
        };

        Ok(SpikingStepResult {
            features: output_features,
            spikes,
            weights: new_weights,
            attestation,
        })
    }

    /// STDP weight change for `dt = t_post - t_pre` with a time constant of
    /// 20.0.
    ///
    /// A non-finite `dt` produces no change. This matters when both
    /// endpoints have never spiked: `-inf - -inf` is NaN, which previously
    /// turned the weight into NaN. (An infinite `dt` already evaluated to a
    /// zero change, so only the NaN case behaves differently.)
    fn stdp_update(&self, dt: f32) -> f32 {
        if !dt.is_finite() {
            return 0.0;
        }
        let rate = self.config.stdp_rate;
        let tau = 20.0;
        if dt > 0.0 {
            rate * (-dt / tau).exp()
        } else {
            -rate * (dt / tau).exp()
        }
    }

    /// Current membrane potential of every node.
    pub fn membrane_potentials(&self) -> &[f32] {
        &self.membrane_potentials
    }
}
#[cfg(feature = "biological")]
/// Element-wise Hebbian learner over `dim` weights with a hard weight clamp.
pub struct HebbianLayer {
    dim: usize,
    max_weight: f32,
    learning_rate: f32,
}

#[cfg(feature = "biological")]
impl HebbianLayer {
    /// Creates a layer for `dim` activities with the given learning rate;
    /// weights are clamped to `[-max_weight, max_weight]` after each update.
    pub fn new(dim: usize, learning_rate: f32, max_weight: f32) -> Self {
        Self {
            dim,
            max_weight,
            learning_rate,
        }
    }

    /// Checks that both activity slices have exactly `dim` entries.
    ///
    /// The error reports the length of the slice that is actually wrong
    /// (pre first, then post) rather than the smaller of the two, which
    /// could coincide with the expected value.
    fn check_activity_dims(&self, pre_activity: &[f32], post_activity: &[f32]) -> Result<()> {
        let offending = if pre_activity.len() != self.dim {
            Some(pre_activity.len())
        } else if post_activity.len() != self.dim {
            Some(post_activity.len())
        } else {
            None
        };
        match offending {
            Some(actual) => Err(GraphTransformerError::DimensionMismatch {
                expected: self.dim,
                actual,
            }),
            None => Ok(()),
        }
    }

    /// Applies `w += lr * (pre * post - 0.01 * w)` to the first
    /// `min(weights.len(), dim)` weights and clamps each result.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` when either activity slice does not have
    /// `dim` entries.
    pub fn update(
        &self,
        pre_activity: &[f32],
        post_activity: &[f32],
        weights: &mut [f32],
    ) -> Result<()> {
        self.check_activity_dims(pre_activity, post_activity)?;

        let decay = 0.01;
        for ((w, &pre), &post) in weights.iter_mut().zip(pre_activity).zip(post_activity) {
            let hebb = pre * post;
            *w += self.learning_rate * (hebb - decay * *w);
            *w = w.clamp(-self.max_weight, self.max_weight);
        }
        Ok(())
    }

    /// Applies `rule` to the first `min(weights.len(), dim)` weights, clamps
    /// each result, and then projects the whole slice onto `norm_bound` when
    /// one is given.
    ///
    /// The rule is always invoked without a spike-time difference.
    ///
    /// # Errors
    /// Returns `DimensionMismatch` when either activity slice does not have
    /// `dim` entries.
    pub fn update_with_rule(
        &self,
        pre_activity: &[f32],
        post_activity: &[f32],
        weights: &mut [f32],
        rule: &HebbianRule,
        norm_bound: Option<&HebbianNormBound>,
        fisher_diag: Option<&[f32]>,
    ) -> Result<()> {
        self.check_activity_dims(pre_activity, post_activity)?;

        for ((w, &pre), &post) in weights.iter_mut().zip(pre_activity).zip(post_activity) {
            let dw = rule.compute_update(pre, post, *w, self.learning_rate, None);
            *w += dw;
            *w = w.clamp(-self.max_weight, self.max_weight);
        }
        if let Some(bound) = norm_bound {
            bound.project(weights, fisher_diag);
        }
        Ok(())
    }

    /// Returns true when every weight lies within `[-max_weight, max_weight]`.
    pub fn verify_bounds(&self, weights: &[f32]) -> bool {
        weights.iter().all(|&w| w.abs() <= self.max_weight)
    }
}
#[cfg(test)]
#[cfg(feature = "biological")]
mod tests {
    use super::*;

    /// One step on a small graph returns per-node results and bounded weights.
    #[test]
    fn test_spiking_attention_step() {
        let config = BiologicalConfig {
            tau_membrane: 10.0,
            threshold: 0.5,
            stdp_rate: 0.01,
            max_weight: 5.0,
        };
        let mut sga = SpikingGraphAttention::new(3, 4, config);
        let features = vec![
            vec![0.8, 0.6, 0.4, 0.2],
            vec![0.1, 0.2, 0.3, 0.4],
            vec![0.9, 0.7, 0.5, 0.3],
        ];
        let weights = vec![
            vec![0.0, 0.5, 0.3],
            vec![0.5, 0.0, 0.2],
            vec![0.3, 0.2, 0.0],
        ];
        let adjacency = vec![(0, 1), (1, 2), (0, 2)];
        let result = sga.step(&features, &weights, &adjacency).unwrap();
        assert_eq!(result.features.len(), 3);
        assert_eq!(result.spikes.len(), 3);
        for row in &result.weights {
            for &w in row {
                assert!(w.abs() <= 5.0);
            }
        }
    }

    #[test]
    fn test_hebbian_update() {
        let hebb = HebbianLayer::new(4, 0.01, 5.0);
        let pre = vec![1.0, 0.5, 0.0, 0.3];
        let post = vec![0.5, 1.0, 0.2, 0.0];
        let mut weights = vec![0.0; 4];
        hebb.update(&pre, &post, &mut weights).unwrap();
        assert!(weights.iter().any(|&w| w != 0.0));
        assert!(hebb.verify_bounds(&weights));
    }

    /// A very large learning rate must still leave weights inside the clamp.
    #[test]
    fn test_weight_bounds_enforced() {
        let hebb = HebbianLayer::new(2, 10.0, 1.0);
        let pre = vec![1.0, 1.0];
        let post = vec![1.0, 1.0];
        let mut weights = vec![0.0; 2];
        for _ in 0..1000 {
            hebb.update(&pre, &post, &mut weights).unwrap();
        }
        assert!(hebb.verify_bounds(&weights));
    }

    #[test]
    fn test_spiking_attention_with_wta_inhibition() {
        let config = BiologicalConfig {
            tau_membrane: 5.0,
            threshold: 0.3,
            stdp_rate: 0.01,
            max_weight: 5.0,
        };
        let mut sga = SpikingGraphAttention::with_inhibition(
            10,
            4,
            config,
            InhibitionStrategy::WinnerTakeAll { k: 3 },
        );
        let features: Vec<Vec<f32>> = (0..10)
            .map(|i| vec![0.5 + 0.1 * i as f32; 4])
            .collect();
        let weights: Vec<Vec<f32>> = (0..10).map(|_| vec![0.1; 10]).collect();
        let adjacency: Vec<(usize, usize)> = (0..10)
            .flat_map(|i| (0..10).filter(move |&j| i != j).map(move |j| (i, j)))
            .collect();
        let mut total_spikes_per_step = Vec::new();
        let mut current_weights = weights;
        for _ in 0..20 {
            let result = sga.step(&features, &current_weights, &adjacency).unwrap();
            let spike_count = result.spikes.iter().filter(|&&s| s).count();
            total_spikes_per_step.push(spike_count);
            current_weights = result.weights;
        }
        for &count in &total_spikes_per_step {
            assert!(
                count <= 3,
                "WTA inhibition violated: {} neurons fired (max 3)",
                count,
            );
        }
    }

    #[test]
    fn test_spiking_attention_with_lateral_inhibition() {
        let config = BiologicalConfig {
            tau_membrane: 5.0,
            threshold: 0.3,
            stdp_rate: 0.01,
            max_weight: 5.0,
        };
        let mut sga = SpikingGraphAttention::with_inhibition(
            5,
            4,
            config,
            InhibitionStrategy::Lateral { strength: 0.8 },
        );
        let features: Vec<Vec<f32>> = (0..5).map(|_| vec![0.6; 4]).collect();
        let weights = vec![vec![0.1; 5]; 5];
        let adjacency = vec![(0, 1), (1, 2), (2, 3), (3, 4)];
        let result = sga.step(&features, &weights, &adjacency).unwrap();
        assert_eq!(result.features.len(), 5);
        for row in &result.weights {
            for &w in row {
                assert!(w.abs() <= 5.0);
            }
        }
    }

    #[test]
    fn test_spiking_attention_with_balanced_ei() {
        let config = BiologicalConfig {
            tau_membrane: 5.0,
            threshold: 0.3,
            stdp_rate: 0.01,
            max_weight: 5.0,
        };
        let mut sga = SpikingGraphAttention::with_inhibition(
            8,
            4,
            config,
            InhibitionStrategy::BalancedEI { ei_ratio: 0.5, dale_law: true },
        );
        let features: Vec<Vec<f32>> = (0..8)
            .map(|i| vec![0.4 + 0.05 * i as f32; 4])
            .collect();
        let weights = vec![vec![0.1; 8]; 8];
        let adjacency: Vec<(usize, usize)> = (0..8)
            .flat_map(|i| (0..8).filter(move |&j| i != j).map(move |j| (i, j)))
            .collect();
        let mut current_weights = weights;
        for _ in 0..10 {
            let result = sga.step(&features, &current_weights, &adjacency).unwrap();
            let spike_count = result.spikes.iter().filter(|&&s| s).count();
            assert!(
                spike_count <= 8,
                "balanced E/I produced unreasonable spike count: {}",
                spike_count,
            );
            current_weights = result.weights;
        }
    }

    #[test]
    fn test_stdp_edge_updater_weight_update() {
        let mut updater = StdpEdgeUpdater::new(0.001, 0.5, (-1.0, 1.0), 5);
        let edges = vec![(0, 1), (1, 2), (0, 2)];
        let mut weights = vec![0.5, 0.3, 0.1];
        let spike_times = vec![1.0, 2.0, 1.5];
        let att = updater
            .update_weights(&edges, &mut weights, &spike_times)
            .unwrap();
        assert!(weights[0] != 0.5 || weights[1] != 0.3 || weights[2] != 0.1);
        for &w in &weights {
            assert!(w >= -1.0 && w <= 1.0, "weight {} out of bounds [-1, 1]", w);
        }
        assert!(att.verification_timestamp_ns > 0);
    }

    #[test]
    fn test_stdp_edge_updater_rewire_topology() {
        let mut updater = StdpEdgeUpdater::new(0.05, 0.3, (-1.0, 1.0), 3);
        let mut edges = vec![(0, 1), (1, 2), (2, 3), (0, 3)];
        // Edges 1 and 3 are below the 0.05 prune threshold.
        let mut weights = vec![0.8, 0.02, 0.6, 0.01];
        // Nodes 0, 2, 3 and 4 exceed the 0.3 growth threshold.
        let node_activity = vec![0.9, 0.1, 0.8, 0.5, 0.7];
        let num_nodes = 5;
        let mut env = ProofEnvironment::new();
        let scope_att =
            ScopeTransitionAttestation::create(&mut env, "topology_rewire").unwrap();
        assert!(scope_att.is_valid());
        let (pruned, grown, att) = updater
            .rewire_topology(&mut edges, &mut weights, num_nodes, &node_activity, &scope_att)
            .unwrap();
        assert_eq!(pruned.len(), 2, "expected 2 pruned edges, got {}", pruned.len());
        assert!(pruned.contains(&(1, 2)));
        assert!(pruned.contains(&(0, 3)));
        assert!(!grown.is_empty(), "expected at least one new edge");
        assert!(grown.len() <= 3, "at most 3 new edges per epoch");
        assert!(att.verification_timestamp_ns > 0);
    }

    #[test]
    fn test_stdp_edge_updater_rewire_requires_attestation() {
        let mut updater = StdpEdgeUpdater::new(0.05, 0.3, (-1.0, 1.0), 3);
        let mut edges = vec![(0, 1)];
        let mut weights = vec![0.5];
        let node_activity = vec![0.5, 0.5];
        // NOTE(review): this hand-built attestation is never passed to
        // `rewire_topology`, so the rejection path is not exercised here;
        // whether `ProofAttestation::new` yields an invalid attestation
        // cannot be told from this file — confirm and extend the test.
        let _invalid_att = ScopeTransitionAttestation {
            attestation: ProofAttestation::new([0u8; 32], [0u8; 32], 0, 0),
            scope: "fake".to_string(),
        };
        let mut env = ProofEnvironment::new();
        let scope_att = ScopeTransitionAttestation::create(&mut env, "test_scope").unwrap();
        let result = updater.rewire_topology(
            &mut edges,
            &mut weights,
            2,
            &node_activity,
            &scope_att,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_hebbian_layer_with_norm_bound() {
        let hebb = HebbianLayer::new(4, 0.5, 10.0);
        let pre = vec![1.0, 0.8, 0.6, 0.4];
        let post = vec![0.9, 0.7, 0.5, 0.3];
        let mut weights = vec![0.0; 4];
        let norm_bound = HebbianNormBound {
            threshold: 1.0,
            diagonal_fisher: false,
            layerwise: true,
        };
        for _ in 0..100 {
            hebb.update_with_rule(
                &pre,
                &post,
                &mut weights,
                &HebbianRule::Oja,
                Some(&norm_bound),
                None,
            )
            .unwrap();
        }
        let norm: f32 = weights.iter().map(|w| w * w).sum::<f32>().sqrt();
        assert!(
            norm <= norm_bound.threshold + 1e-5,
            "norm {} exceeds threshold {}",
            norm,
            norm_bound.threshold,
        );
        assert!(norm_bound.is_satisfied(&weights, None));
    }

    #[test]
    fn test_hebbian_layer_with_fisher_norm_bound() {
        let hebb = HebbianLayer::new(4, 0.1, 10.0);
        let pre = vec![1.0, 1.0, 1.0, 1.0];
        let post = vec![1.0, 1.0, 1.0, 1.0];
        let mut weights = vec![0.0; 4];
        let norm_bound = HebbianNormBound {
            threshold: 2.0,
            diagonal_fisher: true,
            layerwise: true,
        };
        let fisher = vec![2.0, 0.5, 1.0, 0.1];
        for _ in 0..200 {
            hebb.update_with_rule(
                &pre,
                &post,
                &mut weights,
                &HebbianRule::BCM { theta_init: 0.5 },
                Some(&norm_bound),
                Some(&fisher),
            )
            .unwrap();
        }
        assert!(norm_bound.is_satisfied(&weights, Some(&fisher)));
    }

    #[test]
    fn test_dendritic_attention_basic_forward() {
        let mut da = DendriticAttention::new(3, 6, BranchAssignment::RoundRobin, 0.5);
        let features = vec![
            vec![0.8, 0.6, 0.4, 0.2, 0.1, 0.3],
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            vec![0.9, 0.7, 0.5, 0.3, 0.2, 0.1],
        ];
        let result = da.forward(&features).unwrap();
        assert_eq!(result.output.len(), 3);
        assert_eq!(result.plateaus.len(), 3);
        for feat in &result.output {
            assert_eq!(feat.len(), 6);
        }
        assert!(result.attestation.is_some());
    }

    #[test]
    fn test_dendritic_attention_feature_clustered() {
        let mut da = DendriticAttention::new(2, 4, BranchAssignment::FeatureClustered, 0.3);
        let features = vec![vec![1.0, 0.9, 0.1, 0.05]];
        let result = da.forward(&features).unwrap();
        assert_eq!(result.output.len(), 1);
        assert_eq!(result.output[0].len(), 4);
        assert!(result.plateaus[0], "expected plateau from high-valued features");
    }

    #[test]
    fn test_dendritic_attention_learned_assignment() {
        let mut da = DendriticAttention::new(4, 8, BranchAssignment::Learned, 0.4);
        let features = vec![vec![0.5; 8], vec![0.1; 8]];
        let result = da.forward(&features).unwrap();
        assert_eq!(result.output.len(), 2);
        assert_eq!(da.num_branches(), 4);
    }

    #[test]
    fn test_dendritic_attention_empty_input() {
        let mut da = DendriticAttention::new(2, 4, BranchAssignment::RoundRobin, 0.5);
        let result = da.forward(&[]).unwrap();
        assert!(result.output.is_empty());
        assert!(result.plateaus.is_empty());
        assert!(result.attestation.is_none());
    }

    #[test]
    fn test_dendritic_attention_dim_mismatch() {
        let mut da = DendriticAttention::new(2, 4, BranchAssignment::RoundRobin, 0.5);
        // Two features where four are expected.
        let features = vec![vec![1.0, 2.0]];
        let result = da.forward(&features);
        assert!(result.is_err());
    }

    #[test]
    fn test_effective_operator_spectral_radius() {
        let op = EffectiveOperator {
            num_iterations: 50,
            safety_margin: 3.0,
            layerwise: true,
        };
        let weights = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let (estimated, conservative) = op.estimate_spectral_radius(&weights);
        assert!(
            (estimated - 1.0).abs() < 0.2,
            "spectral radius of identity should be ~1.0, got {}",
            estimated,
        );
        assert!(
            conservative >= estimated,
            "conservative bound {} should be >= estimated {}",
            conservative,
            estimated,
        );
    }

    #[test]
    fn test_effective_operator_empty_matrix() {
        let op = EffectiveOperator::default();
        let (est, bound) = op.estimate_spectral_radius(&[]);
        assert_eq!(est, 0.0);
        assert_eq!(bound, 0.0);
    }

    #[test]
    fn test_inhibition_strategy_none_passthrough() {
        let strategy = InhibitionStrategy::None;
        let mut potentials = vec![0.0, 0.5, 0.8];
        let mut spikes = vec![false, true, true];
        strategy.apply(&mut potentials, &mut spikes, 0.5);
        assert_eq!(spikes, vec![false, true, true]);
    }

    #[test]
    fn test_hebbian_rule_oja() {
        let rule = HebbianRule::Oja;
        let dw = rule.compute_update(1.0, 0.5, 0.1, 0.01, None);
        assert!((dw - 0.00475).abs() < 1e-6, "Oja update = {}", dw);
    }

    #[test]
    fn test_hebbian_rule_bcm() {
        let rule = HebbianRule::BCM { theta_init: 0.3 };
        let dw = rule.compute_update(1.0, 0.5, 0.0, 0.01, None);
        assert!((dw - 0.001).abs() < 1e-6, "BCM update = {}", dw);
    }

    #[test]
    fn test_hebbian_rule_stdp() {
        let rule = HebbianRule::STDP {
            a_plus: 0.01,
            a_minus: 0.012,
            tau: 20.0,
        };
        let dw_ltp = rule.compute_update(0.0, 0.0, 0.0, 1.0, Some(5.0));
        assert!(dw_ltp > 0.0, "STDP LTP should be positive, got {}", dw_ltp);
        let dw_ltd = rule.compute_update(0.0, 0.0, 0.0, 1.0, Some(-5.0));
        assert!(dw_ltd < 0.0, "STDP LTD should be negative, got {}", dw_ltd);
    }

    #[test]
    fn test_scope_transition_attestation() {
        let mut env = ProofEnvironment::new();
        let att = ScopeTransitionAttestation::create(&mut env, "test_scope").unwrap();
        assert!(att.is_valid());
        assert_eq!(att.scope, "test_scope");
    }

    #[test]
    fn test_hebbian_norm_bound_project() {
        let bound = HebbianNormBound {
            threshold: 1.0,
            diagonal_fisher: false,
            layerwise: true,
        };
        // Norm 5.0, well above the threshold.
        let mut weights = vec![3.0, 4.0];
        let projected = bound.project(&mut weights, None);
        assert!(projected, "projection should have been needed");
        let norm: f32 = weights.iter().map(|w| w * w).sum::<f32>().sqrt();
        assert!(
            (norm - 1.0).abs() < 1e-5,
            "projected norm should be 1.0, got {}",
            norm,
        );
    }
}