#[cfg(feature = "temporal")]
use ruvector_attention::{ScaledDotProductAttention, Attention};
#[cfg(feature = "temporal")]
use ruvector_verified::{
ProofEnvironment,
proof_store::create_attestation,
gated::{route_proof, ProofKind},
};
#[cfg(feature = "temporal")]
use crate::config::TemporalConfig;
#[cfg(feature = "temporal")]
use crate::error::{GraphTransformerError, Result};
#[cfg(feature = "temporal")]
use crate::proof_gated::ProofGate;
#[cfg(feature = "temporal")]
/// Rule deciding which in-neighbours a node may attend to in
/// `CausalGraphTransformer::forward`.
///
/// Every variant rejects neighbours whose timestamp is later than the query
/// node's timestamp.
#[derive(Clone, Debug)]
pub enum MaskStrategy {
    /// Keep every neighbour `j` with `t_j <= t_i`.
    Strict,
    /// Keep neighbour `j` only when `t_j <= t_i` and `t_i - t_j <= window_size`.
    TimeWindow { window_size: f64 },
    /// Currently applies the same timestamp test as `Strict`.
    Topological,
}
#[cfg(feature = "temporal")]
/// Kind of change carried by a `TemporalEdgeEvent`.
#[derive(Clone, Debug, PartialEq)]
pub enum EdgeEventType {
    /// Insert the edge.
    Add,
    /// Delete the edge.
    Remove,
    /// Set a new weight on the edge.
    UpdateWeight(f32),
}
#[cfg(feature = "temporal")]
/// A timestamped change to the directed edge `source -> target`.
#[derive(Clone, Debug)]
pub struct TemporalEdgeEvent {
    /// Node the edge starts at.
    pub source: usize,
    /// Node the edge points to.
    pub target: usize,
    /// Time at which the change takes effect.
    pub timestamp: f64,
    /// What happens to the edge.
    pub event_type: EdgeEventType,
}
#[cfg(feature = "temporal")]
/// Output of the temporal attention routines.
#[derive(Debug)]
pub struct TemporalAttentionResult {
    /// One attended feature vector per input row.
    pub output: Vec<Vec<f32>>,
    /// Row `i` holds the decay factor applied to each key for query `i`
    /// (zero where a key was not attended).
    pub attention_weights: Vec<Vec<f32>>,
}
#[cfg(feature = "temporal")]
/// Graph transformer whose attention never looks at later timestamps.
pub struct CausalGraphTransformer {
    /// Lag, decay and Granger settings.
    config: TemporalConfig,
    /// Attention kernel built for `dim`-wide vectors.
    attention: ScaledDotProductAttention,
    /// Expected feature width.
    dim: usize,
    /// Neighbour filter used by `forward`.
    mask_strategy: MaskStrategy,
    /// Per-time-unit key decay used by `forward`.
    discount: f32,
    /// Proof environment used for dimension checks and attestations.
    env: ProofEnvironment,
}
#[cfg(feature = "temporal")]
impl CausalGraphTransformer {
    /// Creates a transformer for `dim`-wide node features.
    ///
    /// Defaults to [`MaskStrategy::Strict`] and a key discount of `0.9` per
    /// time unit; use [`with_strategy`](Self::with_strategy) to choose both.
    pub fn new(dim: usize, config: TemporalConfig) -> Self {
        let attention = ScaledDotProductAttention::new(dim);
        Self {
            config,
            attention,
            dim,
            mask_strategy: MaskStrategy::Strict,
            discount: 0.9,
            env: ProofEnvironment::new(),
        }
    }

    /// Creates a transformer with an explicit neighbour mask and key discount.
    ///
    /// `discount` is clamped to `[0.0, 1.0]`.
    pub fn with_strategy(
        dim: usize,
        config: TemporalConfig,
        mask_strategy: MaskStrategy,
        discount: f32,
    ) -> Self {
        let attention = ScaledDotProductAttention::new(dim);
        Self {
            config,
            attention,
            dim,
            mask_strategy,
            discount: discount.clamp(0.0, 1.0),
            env: ProofEnvironment::new(),
        }
    }

    /// Causally masked attention over a timestamped graph.
    ///
    /// Node `i` attends to itself and to every source `j` of an edge `(j, i)`
    /// accepted by the configured [`MaskStrategy`]. Keys are scaled by
    /// `discount^(t_i - t_j)`; values stay unscaled. Edges with an endpoint
    /// `>= features.len()` are ignored, and repeated edges yield repeated keys.
    ///
    /// `attention_weights[i][j]` in the result is the decay factor applied to
    /// key `j` (zero for keys not attended). It is not the softmax probability
    /// computed inside the attention kernel.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if `timestamps.len() != features.len()` or the
    ///   first feature row is not `dim` wide.
    /// * Errors from tiered proof verification or from the attention kernel.
    pub fn forward(
        &mut self,
        features: &[Vec<f32>],
        timestamps: &[f64],
        edges: &[(usize, usize)],
    ) -> Result<ProofGate<TemporalAttentionResult>> {
        let n = features.len();
        if n != timestamps.len() {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: n,
                actual: timestamps.len(),
            });
        }
        if n == 0 {
            let result = TemporalAttentionResult {
                output: Vec::new(),
                attention_weights: Vec::new(),
            };
            return Ok(ProofGate::new(result));
        }
        // Only the first row is checked against `dim`.
        let feat_dim = features[0].len();
        if feat_dim != self.dim {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: self.dim,
                actual: feat_dim,
            });
        }
        let decision = route_proof(
            ProofKind::DimensionEquality {
                expected: self.dim as u32,
                actual: feat_dim as u32,
            },
            &self.env,
        );
        let _proof_id = ruvector_verified::gated::verify_tiered(
            &mut self.env,
            self.dim as u32,
            feat_dim as u32,
            decision.tier,
        )?;
        // adj[t] lists the sources of edges pointing into node t.
        let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
        for &(src, tgt) in edges {
            if src < n && tgt < n {
                adj[tgt].push(src);
            }
        }
        let mut outputs = Vec::with_capacity(n);
        let mut all_weights = Vec::with_capacity(n);
        for i in 0..n {
            let t_i = timestamps[i];
            let candidates = self.causal_candidates(i, &adj[i], timestamps, t_i);
            if candidates.is_empty() {
                // Defensive only: `causal_candidates` always includes `i`.
                outputs.push(features[i].clone());
                let mut row = vec![0.0f32; n];
                row[i] = 1.0;
                all_weights.push(row);
                continue;
            }
            let query = &features[i];
            let keys: Vec<&[f32]> = candidates.iter().map(|&j| features[j].as_slice()).collect();
            // Candidates are never later than `t_i`; `max(0.0)` guards the
            // exponent anyway.
            let decay: Vec<f32> = candidates
                .iter()
                .map(|&j| {
                    let dt = (t_i - timestamps[j]) as f32;
                    self.discount.powf(dt.max(0.0))
                })
                .collect();
            // Decay scales the keys only; values remain the raw features.
            let scaled_keys: Vec<Vec<f32>> = keys
                .iter()
                .zip(decay.iter())
                .map(|(k, &w)| k.iter().map(|&x| x * w).collect())
                .collect();
            let scaled_refs: Vec<&[f32]> = scaled_keys.iter().map(|k| k.as_slice()).collect();
            let values: Vec<&[f32]> = keys.clone();
            let out = self
                .attention
                .compute(query, &scaled_refs, &values)
                .map_err(GraphTransformerError::Attention)?;
            let mut row = vec![0.0f32; n];
            for (idx, &j) in candidates.iter().enumerate() {
                row[j] = decay[idx];
            }
            outputs.push(out);
            all_weights.push(row);
        }
        let result = TemporalAttentionResult {
            output: outputs,
            attention_weights: all_weights,
        };
        let attestation_proof = self.env.alloc_term();
        self.env.stats.proofs_verified += 1;
        let _attestation = create_attestation(&self.env, attestation_proof);
        Ok(ProofGate::new(result))
    }

    /// Returns the keys query `i` may attend to.
    ///
    /// The result is `i` itself first, then each neighbour `j != i` that passes
    /// the mask strategy for query time `t_i`. Neighbours keep their
    /// `neighbors` order, and duplicates are kept.
    fn causal_candidates(
        &self,
        i: usize,
        neighbors: &[usize],
        timestamps: &[f64],
        t_i: f64,
    ) -> Vec<usize> {
        let mut cands = Vec::new();
        cands.push(i);
        for &j in neighbors {
            if j == i {
                continue;
            }
            let t_j = timestamps[j];
            let valid = match &self.mask_strategy {
                MaskStrategy::Strict => t_j <= t_i,
                MaskStrategy::TimeWindow { window_size } => {
                    t_j <= t_i && (t_i - t_j) <= *window_size
                }
                // Currently the same timestamp rule as `Strict`.
                MaskStrategy::Topological => t_j <= t_i,
            };
            if valid {
                cands.push(j);
            }
        }
        cands
    }

    /// Causal self-attention over a single sequence.
    ///
    /// Step `i` attends to the most recent `max_lag` steps, ending at and
    /// including itself. A `max_lag` of 0 is treated as 1, so a step always
    /// sees at least itself. Keys are scaled by `decay_rate^(i - j)`; values are
    /// unscaled. The returned `attention_weights` rows hold those decay factors.
    ///
    /// # Errors
    ///
    /// `DimensionMismatch` if the first element is not `dim` wide. Errors from
    /// the attention kernel are propagated.
    pub fn temporal_attention(
        &self,
        sequence: &[Vec<f32>],
    ) -> Result<TemporalAttentionResult> {
        let t = sequence.len();
        if t == 0 {
            return Ok(TemporalAttentionResult {
                output: Vec::new(),
                attention_weights: Vec::new(),
            });
        }
        let dim = sequence[0].len();
        if dim != self.dim {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: self.dim,
                actual: dim,
            });
        }
        let mut outputs = Vec::with_capacity(t);
        let mut all_weights = Vec::with_capacity(t);
        for i in 0..t {
            // Window length in [1, i + 1]. The lower bound keeps the key set
            // non-empty when `max_lag` is 0 (previously `start` became `i + 1`).
            let window = self.config.max_lag.clamp(1, i + 1);
            let start = i + 1 - window;
            let query = &sequence[i];
            let keys: Vec<&[f32]> = (start..=i)
                .map(|j| sequence[j].as_slice())
                .collect();
            let values: Vec<&[f32]> = keys.clone();
            let decay_weights: Vec<f32> = (start..=i)
                .map(|j| {
                    let dt = (i - j) as f32;
                    self.config.decay_rate.powf(dt)
                })
                .collect();
            let scaled_keys: Vec<Vec<f32>> = keys
                .iter()
                .zip(decay_weights.iter())
                .map(|(k, &w)| k.iter().map(|&x| x * w).collect())
                .collect();
            let scaled_refs: Vec<&[f32]> = scaled_keys
                .iter()
                .map(|k| k.as_slice())
                .collect();
            let out = self
                .attention
                .compute(query, &scaled_refs, &values)
                .map_err(GraphTransformerError::Attention)?;
            let mut step_weights = vec![0.0f32; t];
            for (idx, j) in (start..=i).enumerate() {
                step_weights[j] = decay_weights[idx];
            }
            outputs.push(out);
            all_weights.push(step_weights);
        }
        Ok(TemporalAttentionResult {
            output: outputs,
            attention_weights: all_weights,
        })
    }

    /// Pairwise Granger-causality test of column `source` on column `target`.
    ///
    /// `time_series[t][k]` is the value of series `k` at time `t`. The test uses
    /// `min(granger_lags, len - 1)` lags and compares the residual sums of
    /// squares of the restricted model (target lags only) and the unrestricted
    /// model (target and source lags). Causality is declared when the F
    /// statistic exceeds the fixed threshold 3.84, roughly the 5% critical value
    /// of F(1, ∞); the threshold does not depend on `lags` or the sample size.
    ///
    /// The statistic is reported as 0.0 in these cases:
    /// * there are no usable lags;
    /// * the unrestricted RSS is at or below `1e-10`;
    /// * the ratio is not finite.
    ///
    /// Negative values are clamped to 0.
    ///
    /// # Errors
    ///
    /// `Config` error when `source` or `target` is not a valid column in every row.
    pub fn granger_causality(
        &self,
        time_series: &[Vec<f32>],
        source: usize,
        target: usize,
    ) -> Result<GrangerCausalityResult> {
        let t = time_series.len();
        let lags = self.config.granger_lags.min(t.saturating_sub(1));
        if lags == 0 || t < lags + 1 {
            return Ok(GrangerCausalityResult {
                source,
                target,
                f_statistic: 0.0,
                is_causal: false,
                lags,
            });
        }
        // Check against the shortest row, not just row 0, so ragged input
        // yields an error instead of an index panic in `compute_var_rss`.
        let dim = time_series.iter().map(Vec::len).min().unwrap_or(0);
        if source >= dim || target >= dim {
            return Err(GraphTransformerError::Config(format!(
                "node index out of bounds: source={}, target={}, dim={}",
                source, target, dim,
            )));
        }
        let rss_restricted = compute_var_rss(time_series, target, &[target], lags);
        let rss_unrestricted = compute_var_rss(time_series, target, &[target, source], lags);
        let n = (t - lags) as f32;
        let p_restricted = lags as f32;
        let p_unrestricted = 2.0 * lags as f32;
        let df_diff = p_unrestricted - p_restricted;
        let df_denom = n - p_unrestricted;
        let f_stat = if rss_unrestricted > 1e-10 && df_denom > 0.0 && df_diff > 0.0 {
            let raw = ((rss_restricted - rss_unrestricted) / df_diff)
                / (rss_unrestricted / df_denom);
            if raw.is_finite() { raw.max(0.0) } else { 0.0 }
        } else {
            0.0
        };
        let is_causal = f_stat > 3.84;
        Ok(GrangerCausalityResult {
            source,
            target,
            f_statistic: f_stat,
            is_causal,
            lags,
        })
    }

    /// Feature width this transformer expects.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns `true` if no entry above the diagonal (`j > i`) has magnitude
    /// greater than `1e-8`.
    ///
    /// Index order is treated as time order, so the check is only meaningful
    /// when rows and columns are sorted by timestamp.
    pub fn verify_causal_ordering(&self, weights: &[Vec<f32>]) -> bool {
        for (i, row) in weights.iter().enumerate() {
            for (j, &w) in row.iter().enumerate() {
                if j > i && w.abs() > 1e-8 {
                    return false;
                }
            }
        }
        true
    }
}
#[cfg(feature = "temporal")]
/// Capability token that `RetrocausalAttention::forward` requires.
///
/// The field is private, so code outside this module can obtain a token only
/// through `BatchModeToken::new_batch`. The retrocausal pass reads rows with
/// later timestamps, and holding a token marks that the caller works on a
/// whole batch.
pub struct BatchModeToken {
    /// Zero-sized; exists only to block construction from other modules.
    _private: (),
}
#[cfg(feature = "temporal")]
impl BatchModeToken {
pub fn new_batch(window_size: usize) -> Option<Self> {
if window_size > 0 {
Some(BatchModeToken { _private: () })
} else {
None
}
}
}
#[cfg(feature = "temporal")]
/// Bidirectional smoother that blends a past-looking and a future-looking
/// averaging pass.
pub struct RetrocausalAttention {
    /// Feature width.
    dim: usize,
    /// Per-dimension weight of the forward pass; the backward pass gets `1 - g`.
    gate_weights: Vec<f32>,
    /// Proof environment updated by `forward`.
    env: ProofEnvironment,
}
#[cfg(feature = "temporal")]
/// Result of `RetrocausalAttention::forward`.
#[derive(Debug)]
pub struct SmoothedOutput {
    /// Gate-weighted blend of the forward and backward passes.
    pub features: Vec<Vec<f32>>,
    /// Per-row mean over rows whose timestamp is `<=` the row's own.
    pub forward_features: Vec<Vec<f32>>,
    /// Per-row mean over rows whose timestamp is `>=` the row's own.
    pub backward_features: Vec<Vec<f32>>,
}
#[cfg(feature = "temporal")]
impl RetrocausalAttention {
    /// Creates a smoother that weights both passes equally (gate `0.5` in
    /// every dimension).
    pub fn new(dim: usize) -> Self {
        let gate_weights = vec![0.5; dim];
        Self {
            dim,
            gate_weights,
            env: ProofEnvironment::new(),
        }
    }

    /// Creates a smoother with explicit per-dimension gates.
    ///
    /// Output dimension `d` is `g[d] * forward + (1 - g[d]) * backward`.
    ///
    /// # Panics
    ///
    /// Panics if `gate_weights.len() != dim`.
    pub fn with_gate(dim: usize, gate_weights: Vec<f32>) -> Self {
        assert_eq!(gate_weights.len(), dim);
        Self {
            dim,
            gate_weights,
            env: ProofEnvironment::new(),
        }
    }

    /// Smooths every row using both earlier and later timestamps.
    ///
    /// The forward pass averages all rows with `t_j <= t_i`. The backward pass
    /// averages all rows with `t_j >= t_i`. The two are blended per dimension
    /// by the gate. The `BatchModeToken` is not inspected; requiring it shows
    /// the caller holds the whole batch, since the backward pass reads later rows.
    ///
    /// # Errors
    ///
    /// `DimensionMismatch` if `timestamps.len() != features.len()` or if any
    /// feature row is not `dim` wide.
    pub fn forward(
        &mut self,
        features: &[Vec<f32>],
        timestamps: &[f64],
        _batch_token: &BatchModeToken,
    ) -> Result<ProofGate<SmoothedOutput>> {
        let n = features.len();
        // `causal_pass` indexes `timestamps[i]` and every row up to `dim`, so
        // malformed input must be rejected here instead of panicking there.
        if timestamps.len() != n {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: n,
                actual: timestamps.len(),
            });
        }
        if n == 0 {
            return Ok(ProofGate::new(SmoothedOutput {
                features: Vec::new(),
                forward_features: Vec::new(),
                backward_features: Vec::new(),
            }));
        }
        if let Some(bad_len) = features.iter().map(Vec::len).find(|&len| len != self.dim) {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: self.dim,
                actual: bad_len,
            });
        }
        let _decision = route_proof(ProofKind::Reflexivity, &self.env);
        self.env.stats.proofs_verified += 1;
        let _proof_id = self.env.alloc_term();
        let forward_feats = self.causal_pass(features, timestamps, true);
        let backward_feats = self.causal_pass(features, timestamps, false);
        let smoothed: Vec<Vec<f32>> = forward_feats
            .iter()
            .zip(&backward_feats)
            .map(|(fwd, bwd)| {
                self.gate_weights
                    .iter()
                    .zip(fwd.iter().zip(bwd))
                    .map(|(&g, (&f, &b))| g * f + (1.0 - g) * b)
                    .collect()
            })
            .collect();
        let output = SmoothedOutput {
            features: smoothed,
            forward_features: forward_feats,
            backward_features: backward_feats,
        };
        Ok(ProofGate::new(output))
    }

    /// Averages, for each row `i`, every row whose timestamp is `<= t_i`
    /// (`forward == true`) or `>= t_i` (`forward == false`).
    ///
    /// The row itself is included. A row with a NaN timestamp matches nothing
    /// and yields zeros. Expects `timestamps.len() == features.len()`.
    fn causal_pass(
        &self,
        features: &[Vec<f32>],
        timestamps: &[f64],
        forward: bool,
    ) -> Vec<Vec<f32>> {
        let dim = features.first().map_or(0, Vec::len);
        (0..features.len())
            .map(|i| {
                let t_i = timestamps[i];
                let mut sum = vec![0.0f32; dim];
                let mut count = 0u32;
                for (row, &t_j) in features.iter().zip(timestamps) {
                    let visible = if forward { t_j <= t_i } else { t_j >= t_i };
                    if visible {
                        for (acc, &x) in sum.iter_mut().zip(row) {
                            *acc += x;
                        }
                        count += 1;
                    }
                }
                if count > 0 {
                    let divisor = count as f32;
                    for acc in &mut sum {
                        *acc /= divisor;
                    }
                }
                sum
            })
            .collect()
    }
}
#[cfg(feature = "temporal")]
/// Event-driven ODE integrator for node states on a changing graph.
pub struct ContinuousTimeODE {
    /// Node state width.
    dim: usize,
    /// Absolute error tolerance.
    atol: f64,
    /// Relative error tolerance.
    rtol: f64,
    /// Upper bound on integration steps per `integrate` call.
    max_steps: usize,
    /// Proof environment updated after a successful integration.
    env: ProofEnvironment,
}
#[cfg(feature = "temporal")]
/// Result of `ContinuousTimeODE::integrate`.
#[derive(Debug)]
pub struct OdeOutput {
    /// Node states at the end of integration.
    pub features: Vec<Vec<f32>>,
    /// Number of integration steps performed.
    pub steps_taken: usize,
    /// Largest per-step error estimate observed.
    pub max_error: f64,
    /// Timestamps of the edge events that were applied, in order.
    pub event_times: Vec<f64>,
}
#[cfg(feature = "temporal")]
impl ContinuousTimeODE {
    /// Creates an integrator for `dim`-wide node states.
    ///
    /// The largest per-step error estimate must not exceed
    /// `atol + rtol * max(1, max |y|)`, where `y` is the final state.
    /// `max_steps` caps the number of integration steps.
    pub fn new(dim: usize, atol: f64, rtol: f64, max_steps: usize) -> Self {
        Self {
            dim,
            atol,
            rtol,
            max_steps,
            env: ProofEnvironment::new(),
        }
    }

    /// Integrates node states from `t_start` to `t_end` while applying edge events.
    ///
    /// The graph starts with no edges. Events with timestamps in
    /// `[t_start, t_end]` are applied in timestamp order. Between events:
    /// * a node whose in-edge weights sum to a positive value relaxes toward
    ///   the weighted mean of its in-neighbours
    ///   (`dh_i/dt = Σ w·h_j / Σ w − h_i`);
    /// * every other node stays constant.
    ///
    /// One second-order step spans each event-free segment, so long segments
    /// give large error estimates. Integration stops early, without error, once
    /// `max_steps` steps have been taken.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if any feature row is not `dim` wide.
    /// * `NumericalError` if the largest step error estimate exceeds the tolerance.
    pub fn integrate(
        &mut self,
        features: &[Vec<f32>],
        t_start: f64,
        t_end: f64,
        edge_events: &[TemporalEdgeEvent],
    ) -> Result<ProofGate<OdeOutput>> {
        let n = features.len();
        if n == 0 {
            return Ok(ProofGate::new(OdeOutput {
                features: Vec::new(),
                steps_taken: 0,
                max_error: 0.0,
                event_times: Vec::new(),
            }));
        }
        // Every row is used element-wise by the stepper, so check all of them.
        if let Some(bad_len) = features.iter().map(Vec::len).find(|&len| len != self.dim) {
            return Err(GraphTransformerError::DimensionMismatch {
                expected: self.dim,
                actual: bad_len,
            });
        }
        let mut sorted_events: Vec<&TemporalEdgeEvent> = edge_events
            .iter()
            .filter(|e| e.timestamp >= t_start && e.timestamp <= t_end)
            .collect();
        // The range filter above rejects NaN timestamps, so `partial_cmp` cannot fail.
        sorted_events.sort_by(|a, b| a.timestamp.partial_cmp(&b.timestamp).unwrap());
        let mut state: Vec<Vec<f32>> = features.to_vec();
        let mut t = t_start;
        let mut steps = 0usize;
        let mut max_error = 0.0f64;
        let mut event_times = Vec::new();
        let mut event_idx = 0;
        // adj[target] lists (source, weight) for every incoming edge.
        let mut adj: Vec<Vec<(usize, f32)>> = vec![Vec::new(); n];
        // Each pass advances to the next event (or `t_end`), then applies all
        // events at that instant. The loop exits only when `t >= t_end` or the
        // step budget is spent, so no trailing step is needed afterwards.
        while t < t_end && steps < self.max_steps {
            let t_next_event = sorted_events.get(event_idx).map_or(t_end, |e| e.timestamp);
            let t_step_end = t_next_event.min(t_end);
            if t_step_end > t {
                let (new_state, error) = self.dormand_prince_step(&state, &adj, t, t_step_end);
                max_error = max_error.max(error);
                state = new_state;
                t = t_step_end;
                steps += 1;
            }
            while let Some(&ev) = sorted_events
                .get(event_idx)
                .filter(|ev| (ev.timestamp - t).abs() < 1e-12)
            {
                event_times.push(ev.timestamp);
                Self::apply_event(&mut adj, ev);
                event_idx += 1;
            }
        }
        let y_scale: f64 = state
            .iter()
            .flat_map(|row| row.iter())
            .map(|&v| (v as f64).abs())
            .fold(0.0f64, f64::max)
            .max(1.0);
        let error_bound = self.atol + self.rtol * y_scale;
        let error_ok = max_error <= error_bound;
        if !error_ok {
            return Err(GraphTransformerError::NumericalError(format!(
                "ODE integration error {} exceeds tolerance (bound={}, atol={}, rtol={})",
                max_error, error_bound, self.atol, self.rtol,
            )));
        }
        let _proof_id = self.env.alloc_term();
        self.env.stats.proofs_verified += 1;
        let output = OdeOutput {
            features: state,
            steps_taken: steps,
            max_error,
            event_times,
        };
        Ok(ProofGate::new(output))
    }

    /// Applies one edge event to the in-neighbour lists `adj`, indexed by target.
    ///
    /// * `Add` appends `(source, 1.0)` when both endpoints are in range.
    /// * `Remove` drops every entry from `source`.
    /// * `UpdateWeight` changes the weight of existing entries from `source` only.
    fn apply_event(adj: &mut [Vec<(usize, f32)>], ev: &TemporalEdgeEvent) {
        let n = adj.len();
        match &ev.event_type {
            EdgeEventType::Add => {
                if ev.source < n && ev.target < n {
                    adj[ev.target].push((ev.source, 1.0));
                }
            }
            EdgeEventType::Remove => {
                if let Some(incoming) = adj.get_mut(ev.target) {
                    incoming.retain(|&(s, _)| s != ev.source);
                }
            }
            EdgeEventType::UpdateWeight(w) => {
                if let Some(incoming) = adj.get_mut(ev.target) {
                    for edge in incoming.iter_mut().filter(|e| e.0 == ev.source) {
                        edge.1 = *w;
                    }
                }
            }
        }
    }

    /// Right-hand side of the graph ODE.
    ///
    /// For node `i` whose in-edge weights in `adj[i]` sum to a positive value,
    /// the derivative is `Σ w·state[j] / Σ w − state[i]`. Otherwise it is zero:
    /// the node has no in-edges, or `UpdateWeight` drove the weight sum to zero
    /// or below. Previously that case left an un-normalised weighted sum.
    fn graph_derivative(state: &[Vec<f32>], adj: &[Vec<(usize, f32)>]) -> Vec<Vec<f32>> {
        let dim = state.first().map_or(0, Vec::len);
        state
            .iter()
            .zip(adj)
            .map(|(own, neighbors)| {
                let total_weight = neighbors.iter().fold(0.0f32, |acc, &(_, w)| acc + w);
                let mut dh = vec![0.0f32; dim];
                if total_weight > 0.0 {
                    for &(j, w) in neighbors {
                        for (acc, &x) in dh.iter_mut().zip(&state[j]) {
                            *acc += w * x;
                        }
                    }
                    for (acc, &x) in dh.iter_mut().zip(own) {
                        *acc = *acc / total_weight - x;
                    }
                }
                dh
            })
            .collect()
    }

    /// One explicit Heun (trapezoidal RK2) step from `t` to `t_end`.
    ///
    /// Despite the name, this is a second-order method, not the Dormand–Prince
    /// RK5(4) tableau. Returns the Heun result together with the largest
    /// absolute difference between it and the embedded forward-Euler result,
    /// which serves as the error estimate.
    fn dormand_prince_step(
        &self,
        state: &[Vec<f32>],
        adj: &[Vec<(usize, f32)>],
        t: f64,
        t_end: f64,
    ) -> (Vec<Vec<f32>>, f64) {
        // Step over the actual interval. A fixed `h = 1.0` made the result
        // independent of the elapsed time.
        let h = (t_end - t) as f32;
        let k1 = Self::graph_derivative(state, adj);
        // Forward-Euler predictor.
        let y1: Vec<Vec<f32>> = state
            .iter()
            .zip(&k1)
            .map(|(y, k)| y.iter().zip(k).map(|(&y_d, &k_d)| y_d + h * k_d).collect())
            .collect();
        let k2 = Self::graph_derivative(&y1, adj);
        let mut max_err = 0.0f64;
        let mut y_final: Vec<Vec<f32>> = Vec::with_capacity(state.len());
        for (i, y) in state.iter().enumerate() {
            let mut row = Vec::with_capacity(y.len());
            for (d, &y_d) in y.iter().enumerate() {
                let heun = y_d + 0.5 * h * (k1[i][d] + k2[i][d]);
                let err = (y1[i][d] - heun).abs() as f64;
                if err > max_err {
                    max_err = err;
                }
                row.push(heun);
            }
            y_final.push(row);
        }
        (y_final, max_err)
    }
}
#[cfg(feature = "temporal")]
/// Outcome of `CausalGraphTransformer::granger_causality`.
#[derive(Clone, Debug)]
pub struct GrangerCausalityResult {
    /// Candidate cause column.
    pub source: usize,
    /// Effect column.
    pub target: usize,
    /// F statistic comparing the restricted and unrestricted models.
    pub f_statistic: f32,
    /// Whether `f_statistic` exceeded the decision threshold.
    pub is_causal: bool,
    /// Number of lags actually used.
    pub lags: usize,
}
#[cfg(feature = "temporal")]
/// Directed edge `source -> target` in a causal graph.
#[derive(Clone, Debug)]
pub struct GrangerEdge {
    /// Originating node.
    pub source: usize,
    /// Destination node.
    pub target: usize,
    /// Averaged attention weight that produced the edge.
    pub weight: f64,
}
#[cfg(feature = "temporal")]
/// Directed causal graph extracted from attention history.
#[derive(Clone, Debug)]
pub struct GrangerGraph {
    /// Number of nodes.
    pub num_nodes: usize,
    /// Edges whose averaged weight passed the extraction threshold.
    pub edges: Vec<GrangerEdge>,
    /// Whether a topological order covering all nodes exists.
    pub is_acyclic: bool,
    /// Topological order (partial when the graph has a cycle).
    pub topological_order: Vec<usize>,
}
#[cfg(feature = "temporal")]
/// Attention matrix recorded at one point in time.
#[derive(Clone, Debug)]
pub struct AttentionSnapshot {
    /// Attention matrix; presumably square (`n x n`). The extractor sizes its
    /// graph from the first snapshot's row count.
    pub weights: Vec<Vec<f32>>,
    /// Time at which the snapshot was taken.
    pub timestamp: f64,
}
#[cfg(feature = "temporal")]
/// Builds a causal graph from a history of attention matrices.
pub struct GrangerCausalityExtractor {
    /// An averaged weight must be strictly greater than this to form an edge.
    threshold: f64,
    /// Minimum number of snapshots `extract` accepts.
    min_window: usize,
    /// Proof environment updated when an acyclic graph is extracted.
    env: ProofEnvironment,
}
#[cfg(feature = "temporal")]
impl GrangerCausalityExtractor {
    /// Creates an extractor.
    ///
    /// * `threshold`: an averaged attention weight must be strictly greater
    ///   than this to become an edge.
    /// * `min_window`: smallest number of snapshots `extract` accepts.
    pub fn new(threshold: f64, min_window: usize) -> Self {
        let env = ProofEnvironment::new();
        Self {
            threshold,
            min_window,
            env,
        }
    }

    /// Builds a directed graph from averaged attention snapshots.
    ///
    /// The node count comes from the first snapshot's row count. The weight of
    /// pair `(i, j)` is the mean of `weights[i][j]` over all snapshots, and
    /// cells missing from a snapshot count as zero. Every off-diagonal pair
    /// whose mean exceeds `threshold` becomes an edge `i -> j`. The graph is
    /// then topologically sorted, and a proof term is recorded only when it is
    /// acyclic.
    ///
    /// # Errors
    ///
    /// `Config` error if fewer than `min_window` snapshots are supplied.
    pub fn extract(
        &mut self,
        attention_history: &[AttentionSnapshot],
    ) -> Result<ProofGate<GrangerGraph>> {
        let history_len = attention_history.len();
        if history_len < self.min_window {
            return Err(GraphTransformerError::Config(format!(
                "attention history length {} < min_window {}",
                history_len, self.min_window,
            )));
        }

        let num_nodes = attention_history
            .first()
            .map_or(0, |snapshot| snapshot.weights.len());

        // Running mean per cell; anything outside num_nodes x num_nodes is ignored.
        let divisor = history_len as f64;
        let mut mean = vec![vec![0.0f64; num_nodes]; num_nodes];
        for snapshot in attention_history {
            for (mean_row, weight_row) in mean.iter_mut().zip(&snapshot.weights) {
                for (cell, &w) in mean_row.iter_mut().zip(weight_row) {
                    *cell += w as f64 / divisor;
                }
            }
        }

        let mut edges = Vec::new();
        let mut successors: Vec<Vec<usize>> = vec![Vec::new(); num_nodes];
        for (source, (mean_row, out)) in mean.iter().zip(successors.iter_mut()).enumerate() {
            for (target, &weight) in mean_row.iter().enumerate() {
                if source != target && weight > self.threshold {
                    edges.push(GrangerEdge {
                        source,
                        target,
                        weight,
                    });
                    out.push(target);
                }
            }
        }

        let (is_acyclic, topological_order) = topological_sort(num_nodes, &successors);
        if is_acyclic {
            let _proof_id = self.env.alloc_term();
            self.env.stats.proofs_verified += 1;
        }

        Ok(ProofGate::new(GrangerGraph {
            num_nodes,
            edges,
            is_acyclic,
            topological_order,
        }))
    }
}
#[cfg(feature = "temporal")]
/// Kahn's algorithm over adjacency lists (`v` in `adj[u]` means edge `u -> v`).
///
/// Returns `(is_acyclic, order)`. Ready nodes are kept on a LIFO stack, so
/// among simultaneously ready nodes the most recently pushed one is emitted
/// first. With a cycle, `order` contains only the nodes that could be placed.
/// Targets `>= num_nodes` are ignored. `adj[u]` is indexed for every emitted
/// node, so `adj` must have at least `num_nodes` entries.
fn topological_sort(num_nodes: usize, adj: &[Vec<usize>]) -> (bool, Vec<usize>) {
    let mut remaining_preds = vec![0usize; num_nodes];
    adj.iter()
        .flatten()
        .filter(|&&v| v < num_nodes)
        .for_each(|&v| remaining_preds[v] += 1);

    let mut ready: Vec<usize> = (0..num_nodes)
        .filter(|&node| remaining_preds[node] == 0)
        .collect();
    let mut order = Vec::with_capacity(num_nodes);

    while let Some(node) = ready.pop() {
        order.push(node);
        for &succ in adj[node].iter().filter(|&&v| v < num_nodes) {
            remaining_preds[succ] -= 1;
            if remaining_preds[succ] == 0 {
                ready.push(succ);
            }
        }
    }

    (order.len() == num_nodes, order)
}
#[cfg(feature = "temporal")]
/// Age bucket that `TemporalEmbeddingStore::compact` assigns to stored versions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StorageTier {
    /// Tier of freshly stored versions.
    Hot,
    /// Older than the warm threshold.
    Warm,
    /// Older than the cold threshold.
    Cold,
}
#[cfg(feature = "temporal")]
/// One stored version of a node embedding: a full snapshot or a sparse delta.
#[derive(Clone, Debug)]
struct DeltaEntry {
    /// Time this version was stored for.
    timestamp: f64,
    /// Full embedding when this entry is a snapshot, `None` for a delta.
    base: Option<Vec<f32>>,
    /// `(component, change)` pairs relative to the previous version; empty for snapshots.
    delta: Vec<(usize, f32)>,
    /// Age bucket assigned by `compact`.
    tier: StorageTier,
}
#[cfg(feature = "temporal")]
/// Per-node version history of embeddings, kept as snapshot + delta chains.
pub struct TemporalEmbeddingStore {
    /// Nominal embedding width; drives the snapshot-versus-delta decision.
    dim: usize,
    /// `chains[node]` holds that node's versions in insertion order.
    chains: Vec<Vec<DeltaEntry>>,
    /// Age above which `compact` tags a version `Warm`.
    warm_threshold: f64,
    /// Age above which `compact` tags a version `Cold`.
    cold_threshold: f64,
}
#[cfg(feature = "temporal")]
impl TemporalEmbeddingStore {
pub fn new(dim: usize, num_nodes: usize, warm_threshold: f64, cold_threshold: f64) -> Self {
Self {
dim,
chains: vec![Vec::new(); num_nodes],
warm_threshold,
cold_threshold,
}
}
pub fn store(&mut self, node: usize, time: f64, embedding: &[f32]) {
if node >= self.chains.len() {
self.chains.resize(node + 1, Vec::new());
}
let is_first = self.chains[node].is_empty();
if is_first {
self.chains[node].push(DeltaEntry {
timestamp: time,
base: Some(embedding.to_vec()),
delta: Vec::new(),
tier: StorageTier::Hot,
});
} else {
let prev = self.reconstruct_latest(node);
let delta: Vec<(usize, f32)> = embedding
.iter()
.enumerate()
.filter_map(|(i, &v)| {
let diff = v - prev.as_ref().map_or(0.0, |p| p[i]);
if diff.abs() > 1e-8 {
Some((i, diff))
} else {
None
}
})
.collect();
let is_base = delta.len() > self.dim / 2;
self.chains[node].push(DeltaEntry {
timestamp: time,
base: if is_base { Some(embedding.to_vec()) } else { None },
delta: if is_base { Vec::new() } else { delta },
tier: StorageTier::Hot,
});
}
}
pub fn retrieve(&self, node: usize, time: f64) -> Option<Vec<f32>> {
if node >= self.chains.len() {
return None;
}
let chain = &self.chains[node];
if chain.is_empty() {
return None;
}
let target_idx = chain
.iter()
.rposition(|e| e.timestamp <= time)?;
let base_idx = (0..=target_idx)
.rev()
.find(|&i| chain[i].base.is_some())?;
let mut embedding = chain[base_idx].base.as_ref().unwrap().clone();
for i in (base_idx + 1)..=target_idx {
if let Some(ref base) = chain[i].base {
embedding = base.clone();
} else {
for &(dim_idx, diff) in &chain[i].delta {
if dim_idx < embedding.len() {
embedding[dim_idx] += diff;
}
}
}
}
Some(embedding)
}
pub fn compact(&mut self, current_time: f64) {
for chain in &mut self.chains {
for entry in chain.iter_mut() {
let age = current_time - entry.timestamp;
if age > self.cold_threshold {
entry.tier = StorageTier::Cold;
} else if age > self.warm_threshold {
entry.tier = StorageTier::Warm;
}
}
}
}
pub fn chain_length(&self, node: usize) -> usize {
if node < self.chains.len() {
self.chains[node].len()
} else {
0
}
}
fn reconstruct_latest(&self, node: usize) -> Option<Vec<f32>> {
if node >= self.chains.len() {
return None;
}
let chain = &self.chains[node];
if chain.is_empty() {
return None;
}
self.retrieve(node, chain.last().unwrap().timestamp)
}
}
#[cfg(feature = "temporal")]
/// Residual sum of squares of an ordinary-least-squares autoregression.
///
/// Regresses `time_series[i][target]` on the values of every column in
/// `predictors` at lags `1..=lags`, with no intercept, for `i` in
/// `lags..len`. The coefficients are fitted by least squares, and the sum of
/// squared residuals is returned.
///
/// Both Granger models (`[target]` and `[target, source]`) are fitted this way,
/// so the models are nested and the unrestricted RSS never exceeds the
/// restricted one (up to rounding), as the F-test requires. The previous
/// implementation predicted with an unfitted equal-weight average of the lags,
/// which violated this.
///
/// Returns `0.0` when there are no observations (`len <= lags`). With no
/// regressors (empty `predictors` or `lags == 0`) the prediction is zero, and
/// the result is the sum of squared targets.
///
/// # Panics
///
/// Panics if a row is too short for `target` or one of the `predictors`.
fn compute_var_rss(
    time_series: &[Vec<f32>],
    target: usize,
    predictors: &[usize],
    lags: usize,
) -> f32 {
    /// Solves the (symmetric, positive semi-definite) normal equations
    /// `a * beta = b` by Gauss–Jordan elimination with partial pivoting.
    ///
    /// Columns whose pivot is numerically zero (collinear or all-zero
    /// regressors) get coefficient 0. Normal equations are always consistent,
    /// so this still yields a least-squares solution.
    fn solve_normal_equations(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> Vec<f64> {
        let p = b.len();
        let scale = (0..p).map(|k| a[k][k].abs()).fold(0.0f64, f64::max);
        let tol = scale * 1e-12;
        let mut pivots: Vec<(usize, usize)> = Vec::with_capacity(p);
        let mut row = 0;
        for col in 0..p {
            if row == p {
                break;
            }
            let best = (row..p)
                .max_by(|&x, &y| a[x][col].abs().total_cmp(&a[y][col].abs()))
                .unwrap_or(row);
            if a[best][col].abs() <= tol {
                continue;
            }
            a.swap(row, best);
            b.swap(row, best);
            let pivot_row = a[row].clone();
            let pivot_rhs = b[row];
            for r in 0..p {
                if r == row {
                    continue;
                }
                let factor = a[r][col] / pivot_row[col];
                for (x, &pv) in a[r].iter_mut().zip(&pivot_row) {
                    *x -= factor * pv;
                }
                b[r] -= factor * pivot_rhs;
            }
            pivots.push((row, col));
            row += 1;
        }
        let mut beta = vec![0.0f64; p];
        for (r, c) in pivots {
            beta[c] = b[r] / a[r][c];
        }
        beta
    }

    let t = time_series.len();
    if t <= lags {
        return 0.0;
    }
    let p = predictors.len() * lags;

    // Design matrix: one row per observation, with predictor-major lag columns.
    let design: Vec<Vec<f64>> = (lags..t)
        .map(|i| {
            let mut row = Vec::with_capacity(p);
            for &pred in predictors {
                for lag in 1..=lags {
                    row.push(f64::from(time_series[i - lag][pred]));
                }
            }
            row
        })
        .collect();
    let observed: Vec<f64> = (lags..t).map(|i| f64::from(time_series[i][target])).collect();

    let mut xtx = vec![vec![0.0f64; p]; p];
    let mut xty = vec![0.0f64; p];
    for (x, &y) in design.iter().zip(&observed) {
        for (r, &xr) in x.iter().enumerate() {
            xty[r] += xr * y;
            for (c, &xc) in x.iter().enumerate() {
                xtx[r][c] += xr * xc;
            }
        }
    }
    let beta = solve_normal_equations(xtx, xty);

    let rss: f64 = design
        .iter()
        .zip(&observed)
        .map(|(x, &y)| {
            let predicted: f64 = x.iter().zip(&beta).map(|(xv, bv)| xv * bv).sum();
            let residual = y - predicted;
            residual * residual
        })
        .sum();
    rss as f32
}
#[cfg(test)]
#[cfg(feature = "temporal")]
mod tests {
    use super::*;

    #[test]
    fn test_causal_temporal_attention() {
        let config = TemporalConfig {
            decay_rate: 0.9,
            max_lag: 5,
            granger_lags: 3,
        };
        let transformer = CausalGraphTransformer::new(4, config);
        let sequence = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
            vec![0.0, 0.0, 0.0, 1.0],
        ];
        let result = transformer.temporal_attention(&sequence).unwrap();
        assert_eq!(result.output.len(), 4);
        assert_eq!(result.attention_weights.len(), 4);
        assert!(transformer.verify_causal_ordering(&result.attention_weights));
    }

    #[test]
    fn test_causal_ordering_verification() {
        let config = TemporalConfig::default();
        let transformer = CausalGraphTransformer::new(4, config);
        let causal_weights = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.5, 0.5, 0.0],
            vec![0.3, 0.3, 0.4],
        ];
        assert!(transformer.verify_causal_ordering(&causal_weights));
        let non_causal = vec![
            vec![0.5, 0.5, 0.0], // row 0 puts weight on the later column 1
            vec![0.5, 0.5, 0.0],
            vec![0.3, 0.3, 0.4],
        ];
        assert!(!transformer.verify_causal_ordering(&non_causal));
    }

    #[test]
    fn test_granger_causality() {
        let config = TemporalConfig {
            decay_rate: 0.9,
            max_lag: 5,
            granger_lags: 2,
        };
        let transformer = CausalGraphTransformer::new(4, config);
        let mut series = Vec::new();
        for t in 0..20 {
            let x = (t as f32 * 0.1).sin();
            let y = if t > 0 { (((t - 1) as f32) * 0.1).sin() * 0.8 } else { 0.0 };
            series.push(vec![x, y, 0.0, 0.0]);
        }
        let result = transformer.granger_causality(&series, 0, 1).unwrap();
        assert_eq!(result.source, 0);
        assert_eq!(result.target, 1);
        assert_eq!(result.lags, 2);
        assert!(result.f_statistic >= 0.0);
    }

    #[test]
    fn test_temporal_attention_empty() {
        let config = TemporalConfig::default();
        let transformer = CausalGraphTransformer::new(4, config);
        let result = transformer.temporal_attention(&[]).unwrap();
        assert!(result.output.is_empty());
    }

    #[test]
    fn test_temporal_attention_single_step() {
        let config = TemporalConfig::default();
        let transformer = CausalGraphTransformer::new(4, config);
        let sequence = vec![vec![1.0, 2.0, 3.0, 4.0]];
        let result = transformer.temporal_attention(&sequence).unwrap();
        assert_eq!(result.output.len(), 1);
        assert_eq!(result.output[0].len(), 4);
    }

    #[test]
    fn test_causal_no_future_leakage() {
        let config = TemporalConfig {
            decay_rate: 0.9,
            max_lag: 10,
            granger_lags: 3,
        };
        let mut transformer = CausalGraphTransformer::with_strategy(
            4,
            config,
            MaskStrategy::Strict,
            0.9,
        );
        let features = vec![
            vec![1.0, 0.0, 0.0, 0.0], // t=0
            vec![0.0, 1.0, 0.0, 0.0], // t=1
            vec![0.0, 0.0, 1.0, 0.0], // t=2
            vec![0.0, 0.0, 0.0, 1.0], // t=3
        ];
        let timestamps = vec![0.0, 1.0, 2.0, 3.0];
        let edges: Vec<(usize, usize)> = vec![
            (0, 1), (0, 2), (0, 3),
            (1, 0), (1, 2), (1, 3),
            (2, 0), (2, 1), (2, 3),
            (3, 0), (3, 1), (3, 2),
        ];
        let result = transformer.forward(&features, &timestamps, &edges).unwrap();
        let weights = &result.read().attention_weights;
        assert!(
            weights[1][2].abs() < 1e-8,
            "node 1 (t=1) leaked to node 2 (t=2): weight={}",
            weights[1][2]
        );
        assert!(
            weights[1][3].abs() < 1e-8,
            "node 1 (t=1) leaked to node 3 (t=3): weight={}",
            weights[1][3]
        );
        assert!(weights[0][1].abs() < 1e-8, "node 0 (t=0) leaked to node 1 (t=1)");
        assert!(weights[0][2].abs() < 1e-8, "node 0 (t=0) leaked to node 2 (t=2)");
        assert!(weights[0][3].abs() < 1e-8, "node 0 (t=0) leaked to node 3 (t=3)");
        assert!(weights[3][3].abs() > 1e-8, "node 3 must see itself");
    }

    #[test]
    fn test_causal_time_window() {
        let config = TemporalConfig {
            decay_rate: 0.9,
            max_lag: 10,
            granger_lags: 3,
        };
        let mut transformer = CausalGraphTransformer::with_strategy(
            2,
            config,
            MaskStrategy::TimeWindow { window_size: 1.5 },
            0.9,
        );
        let features = vec![
            vec![1.0, 0.0], // t=0
            vec![0.0, 1.0], // t=1
            vec![1.0, 1.0], // t=2
            vec![0.5, 0.5], // t=3
        ];
        let timestamps = vec![0.0, 1.0, 2.0, 3.0];
        let edges: Vec<(usize, usize)> = vec![
            (0, 1), (0, 2), (0, 3),
            (1, 2), (1, 3),
            (2, 3),
        ];
        let result = transformer.forward(&features, &timestamps, &edges).unwrap();
        let weights = &result.read().attention_weights;
        assert!(weights[3][0].abs() < 1e-8, "node 3 should not see node 0 (outside window)");
        assert!(weights[3][1].abs() < 1e-8, "node 3 should not see node 1 (outside window)");
    }

    #[test]
    fn test_retrocausal_requires_batch_token() {
        let mut retro = RetrocausalAttention::new(4);
        let features = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
            vec![0.0, 0.0, 1.0, 0.0],
        ];
        let timestamps = vec![0.0, 1.0, 2.0];
        assert!(BatchModeToken::new_batch(0).is_none());
        let token = BatchModeToken::new_batch(3).expect("should create batch token");
        let result = retro.forward(&features, &timestamps, &token);
        assert!(result.is_ok());
        let gate = result.unwrap();
        let output = gate.read();
        assert_eq!(output.features.len(), 3);
        assert_eq!(output.forward_features.len(), 3);
        assert_eq!(output.backward_features.len(), 3);
        assert_eq!(output.features[0].len(), 4);
    }

    #[test]
    fn test_retrocausal_bidirectional() {
        let mut retro = RetrocausalAttention::new(2);
        let features = vec![
            vec![1.0, 0.0], // t=0
            vec![0.0, 1.0], // t=1
            vec![1.0, 1.0], // t=2
        ];
        let timestamps = vec![0.0, 1.0, 2.0];
        let token = BatchModeToken::new_batch(3).unwrap();
        let result = retro.forward(&features, &timestamps, &token).unwrap();
        let output = result.read();
        assert_ne!(output.forward_features[0], output.backward_features[0]);
    }

    #[test]
    fn test_ode_integration_3_events() {
        let mut ode = ContinuousTimeODE::new(2, 1.0, 0.5, 100);
        let features = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![0.5, 0.5],
        ];
        let events = vec![
            TemporalEdgeEvent {
                source: 0,
                target: 1,
                timestamp: 0.5,
                event_type: EdgeEventType::Add,
            },
            TemporalEdgeEvent {
                source: 1,
                target: 2,
                timestamp: 1.0,
                event_type: EdgeEventType::Add,
            },
            TemporalEdgeEvent {
                source: 0,
                target: 2,
                timestamp: 1.5,
                event_type: EdgeEventType::UpdateWeight(0.5),
            },
        ];
        let result = ode.integrate(&features, 0.0, 2.0, &events);
        assert!(result.is_ok(), "ODE integration should succeed");
        let gate = result.unwrap();
        let output = gate.read();
        assert_eq!(output.features.len(), 3);
        assert_eq!(output.features[0].len(), 2);
        assert!(output.steps_taken > 0, "should take at least one step");
        assert_eq!(output.event_times.len(), 3, "should process 3 events");
    }

    #[test]
    fn test_ode_empty() {
        let mut ode = ContinuousTimeODE::new(2, 1e-3, 1e-3, 100);
        let result = ode.integrate(&[], 0.0, 1.0, &[]).unwrap();
        assert!(result.read().features.is_empty());
    }

    #[test]
    fn test_granger_extract_dag_acyclic() {
        let mut extractor = GrangerCausalityExtractor::new(0.1, 2);
        let snapshot1 = AttentionSnapshot {
            weights: vec![
                vec![0.0, 0.4, 0.0],
                vec![0.0, 0.0, 0.5],
                vec![0.0, 0.0, 0.0],
            ],
            timestamp: 0.0,
        };
        let snapshot2 = AttentionSnapshot {
            weights: vec![
                vec![0.0, 0.6, 0.0],
                vec![0.0, 0.0, 0.3],
                vec![0.0, 0.0, 0.0],
            ],
            timestamp: 1.0,
        };
        let snapshot3 = AttentionSnapshot {
            weights: vec![
                vec![0.0, 0.5, 0.0],
                vec![0.0, 0.0, 0.4],
                vec![0.0, 0.0, 0.0],
            ],
            timestamp: 2.0,
        };
        let result = extractor.extract(&[snapshot1, snapshot2, snapshot3]);
        assert!(result.is_ok());
        let gate = result.unwrap();
        let graph = gate.read();
        assert_eq!(graph.num_nodes, 3);
        assert!(graph.is_acyclic, "graph should be acyclic");
        assert_eq!(graph.topological_order.len(), 3);
        assert!(graph.edges.len() >= 2, "should have at least 2 edges");
        let has_01 = graph.edges.iter().any(|e| e.source == 0 && e.target == 1);
        let has_12 = graph.edges.iter().any(|e| e.source == 1 && e.target == 2);
        assert!(has_01, "should have edge 0->1");
        assert!(has_12, "should have edge 1->2");
        let has_10 = graph.edges.iter().any(|e| e.source == 1 && e.target == 0);
        let has_20 = graph.edges.iter().any(|e| e.source == 2 && e.target == 0);
        let has_21 = graph.edges.iter().any(|e| e.source == 2 && e.target == 1);
        assert!(!has_10, "should not have edge 1->0");
        assert!(!has_20, "should not have edge 2->0");
        assert!(!has_21, "should not have edge 2->1");
    }

    #[test]
    fn test_granger_too_few_snapshots() {
        let mut extractor = GrangerCausalityExtractor::new(0.1, 5);
        let snapshot = AttentionSnapshot {
            weights: vec![vec![1.0]],
            timestamp: 0.0,
        };
        let result = extractor.extract(&[snapshot]);
        assert!(result.is_err());
    }

    #[test]
    fn test_temporal_store_retrieve() {
        let mut store = TemporalEmbeddingStore::new(4, 3, 10.0, 100.0);
        store.store(0, 0.0, &[1.0, 0.0, 0.0, 0.0]);
        store.store(0, 1.0, &[1.0, 0.1, 0.0, 0.0]);
        store.store(0, 2.0, &[1.0, 0.1, 0.2, 0.0]);
        store.store(0, 3.0, &[0.0, 0.0, 0.0, 1.0]);
        assert_eq!(store.chain_length(0), 4);
        let emb0 = store.retrieve(0, 0.0).expect("should find t=0");
        assert!((emb0[0] - 1.0).abs() < 1e-6);
        assert!((emb0[1] - 0.0).abs() < 1e-6);
        let emb1 = store.retrieve(0, 1.0).expect("should find t=1");
        assert!((emb1[0] - 1.0).abs() < 1e-6);
        assert!((emb1[1] - 0.1).abs() < 1e-6);
        let emb2 = store.retrieve(0, 2.0).expect("should find t=2");
        assert!((emb2[2] - 0.2).abs() < 1e-6);
        let emb3 = store.retrieve(0, 3.0).expect("should find t=3");
        assert!((emb3[3] - 1.0).abs() < 1e-6);
        assert!((emb3[0] - 0.0).abs() < 1e-6);
        let emb_half = store.retrieve(0, 0.5).expect("should find entry <= 0.5");
        assert!((emb_half[0] - 1.0).abs() < 1e-6);
        assert!(store.retrieve(0, -1.0).is_none());
        assert!(store.retrieve(99, 0.0).is_none());
    }

    #[test]
    fn test_temporal_store_compact() {
        let mut store = TemporalEmbeddingStore::new(2, 1, 5.0, 20.0);
        store.store(0, 0.0, &[1.0, 0.0]);
        store.store(0, 10.0, &[0.0, 1.0]);
        store.store(0, 25.0, &[0.5, 0.5]);
        store.compact(30.0);
        let emb = store.retrieve(0, 25.0).expect("should still retrieve after compaction");
        assert!((emb[0] - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_temporal_edge_event() {
        let event = TemporalEdgeEvent {
            source: 0,
            target: 1,
            timestamp: 42.0,
            event_type: EdgeEventType::Add,
        };
        assert_eq!(event.source, 0);
        assert_eq!(event.target, 1);
        assert!((event.timestamp - 42.0).abs() < 1e-10);
        assert_eq!(event.event_type, EdgeEventType::Add);
        let update = TemporalEdgeEvent {
            source: 2,
            target: 3,
            timestamp: 99.0,
            event_type: EdgeEventType::UpdateWeight(0.75),
        };
        assert_eq!(update.event_type, EdgeEventType::UpdateWeight(0.75));
        let remove = TemporalEdgeEvent {
            source: 0,
            target: 1,
            timestamp: 100.0,
            event_type: EdgeEventType::Remove,
        };
        assert_eq!(remove.event_type, EdgeEventType::Remove);
    }
}