use ruvector_dag::dag::{OperatorNode, OperatorType, QueryDag};
use ruvector_dag::sona::*;
/// Checks that a single `adapt` call changes what `MicroLoRA::forward` returns.
///
/// The same input is passed through `forward` before and after adapting with a
/// constant gradient; the summed absolute element-wise difference between the
/// two outputs must be strictly positive.
#[test]
fn test_micro_lora_adaptation() {
    let mut lora = MicroLoRA::new(MicroLoRAConfig::default(), 256);
    let input = ndarray::Array1::from_vec(vec![0.1; 256]);
    let gradient = ndarray::Array1::from_vec(vec![0.01; 256]);

    let before = lora.forward(&input);
    lora.adapt(&gradient, 0.1);
    let after = lora.forward(&input);

    // Accumulate |before[i] - after[i]| over all paired elements.
    let mut total_change: f32 = 0.0;
    for (lhs, rhs) in before.iter().zip(after.iter()) {
        total_change += (lhs - rhs).abs();
    }

    assert!(total_change > 0.0, "Output should change after adaptation");
}
/// Verifies that `DagTrajectoryBuffer` stays within its size limit and can be drained.
///
/// Fifteen trajectories are pushed into a buffer constructed with `10`; the
/// test then requires `len() <= 10`, a non-empty `drain()` result, and an
/// empty buffer once draining is done.
#[test]
fn test_trajectory_buffer() {
    let buffer = DagTrajectoryBuffer::new(10);

    // Deliberately push more entries than the constructor argument.
    (0..15).for_each(|i| {
        let trajectory = DagTrajectory::new(
            i as u64,
            vec![0.1; 256],
            "topological".to_string(),
            100.0,
            150.0,
        );
        buffer.push(trajectory);
    });

    assert!(buffer.len() <= 10);

    let drained = buffer.drain();
    assert!(!drained.is_empty());
    assert!(buffer.is_empty());
}
/// Exercises storing, clustering, and querying patterns in a `DagReasoningBank`.
///
/// Fifty 256-element sine-derived patterns are stored, the pattern count is
/// checked to be exactly 50, clusters are recomputed, and a similarity query
/// asking for 5 results must return at most 5.
#[test]
fn test_reasoning_bank_clustering() {
    let config = ReasoningBankConfig {
        num_clusters: 5,
        pattern_dim: 256,
        max_patterns: 100,
        similarity_threshold: 0.5,
    };
    let mut bank = DagReasoningBank::new(config);

    for row in 0..50 {
        // Each pattern samples sin() at a distinct offset so the rows differ.
        let mut pattern: Vec<f32> = Vec::with_capacity(256);
        for col in 0..256 {
            pattern.push(((row * 256 + col) as f32 / 1000.0).sin());
        }
        bank.store_pattern(pattern, 0.8);
    }
    assert_eq!(bank.pattern_count(), 50);

    bank.recompute_clusters();

    // The query uses the same formula as the pattern stored for row 0.
    let query: Vec<f32> = (0..256).map(|col| (col as f32 / 1000.0).sin()).collect();
    let results = bank.query_similar(&query, 5);
    assert!(results.len() <= 5);
}
/// Checks that the `EwcPlusPlus` penalty grows when parameters move away from
/// the consolidated ones.
///
/// After consolidating `params1` with `fisher1`, the penalty evaluated at
/// `params1` itself must be below `0.001`, and the penalty evaluated at a
/// different parameter vector (`params2`) must be strictly larger.
#[test]
fn test_ewc_prevents_forgetting() {
    let mut ewc = EwcPlusPlus::new(EwcConfig::default());

    let params1 = ndarray::Array1::from_vec(vec![1.0; 256]);
    let fisher1 = ndarray::Array1::from_vec(vec![0.1; 256]);
    ewc.consolidate(&params1, &fisher1);

    // Evaluating at the consolidated parameters: expected to be near zero.
    let penalty0 = ewc.penalty(&params1);
    assert!(penalty0 < 0.001);

    // Evaluating at shifted parameters: expected to cost more.
    let params2 = ndarray::Array1::from_vec(vec![2.0; 256]);
    let penalty1 = ewc.penalty(&params2);
    assert!(penalty1 > penalty0);
}
/// Checks that `DagTrajectoryBuffer::drain` yields trajectories in push order.
///
/// Ten trajectories with `query_hash` values `0..10` are pushed into a buffer
/// constructed with `100`; every drained trajectory's `query_hash` must equal
/// its position in the drained sequence.
#[test]
fn test_trajectory_buffer_ordering() {
    let buffer = DagTrajectoryBuffer::new(100);

    for seq in 0..10 {
        let trajectory = DagTrajectory::new(
            seq as u64,
            vec![0.1; 256],
            "test".to_string(),
            100.0,
            150.0,
        );
        buffer.push(trajectory);
    }

    let trajectories = buffer.drain();

    // Pair each drained entry with the hash expected at that position.
    for (traj, expected_hash) in trajectories.iter().zip(0u64..) {
        assert_eq!(traj.query_hash, expected_hash);
    }
}
/// Checks that a `MicroLoRA` built with an explicit config keeps the output length.
///
/// The adapter is created with `rank: 8`, `alpha: 16.0`, `dropout: 0.1` and a
/// second constructor argument of 256; a 256-element input must produce an
/// output whose `len()` is 256.
#[test]
fn test_lora_rank_adaptation() {
    let lora = MicroLoRA::new(
        MicroLoRAConfig {
            rank: 8,
            alpha: 16.0,
            dropout: 0.1,
        },
        256,
    );

    let input = ndarray::Array1::from_vec(vec![0.5; 256]);
    let output = lora.forward(&input);

    assert_eq!(output.len(), 256);
}
/// Checks that querying with a stored pattern returns at least one result when
/// `similarity_threshold` is set to `0.9`.
///
/// The same all-ones 64-element pattern is stored ten times and then used as
/// the query; the result set must not be empty.
#[test]
fn test_reasoning_bank_similarity_threshold() {
    let mut bank = DagReasoningBank::new(ReasoningBankConfig {
        num_clusters: 3,
        pattern_dim: 64,
        max_patterns: 50,
        similarity_threshold: 0.9,
    });

    let pattern = vec![1.0; 64];

    // Store ten identical copies of the pattern.
    (0..10).for_each(|_| {
        bank.store_pattern(pattern.clone(), 0.8);
    });

    let results = bank.query_similar(&pattern, 5);
    assert!(!results.is_empty());
}
/// Checks that `EwcPlusPlus` still reports a positive penalty after two
/// successive consolidations.
///
/// The instance is configured with `lambda: 1000.0`, `decay: 0.9`,
/// `online: true`. Two parameter/Fisher pairs are consolidated in turn, and
/// the penalty evaluated at a third parameter vector must be strictly positive.
#[test]
fn test_ewc_consolidation_updates() {
    let mut ewc = EwcPlusPlus::new(EwcConfig {
        lambda: 1000.0,
        decay: 0.9,
        online: true,
    });

    // First consolidation.
    let params1 = ndarray::Array1::from_vec(vec![1.0; 256]);
    let fisher1 = ndarray::Array1::from_vec(vec![0.5; 256]);
    ewc.consolidate(&params1, &fisher1);

    // Second consolidation with different parameters and Fisher values.
    let params2 = ndarray::Array1::from_vec(vec![1.5; 256]);
    let fisher2 = ndarray::Array1::from_vec(vec![0.3; 256]);
    ewc.consolidate(&params2, &fisher2);

    // Parameters that match neither consolidated vector.
    let params3 = ndarray::Array1::from_vec(vec![2.0; 256]);
    let penalty = ewc.penalty(&params3);
    assert!(penalty > 0.0);
}
/// Checks which entries remain in a `DagTrajectoryBuffer` after pushing past
/// its size limit.
///
/// Ten trajectories with `query_hash` values `0..10` are pushed into a buffer
/// constructed with `5`. The buffer must then report exactly 5 entries, drain
/// exactly 5, and the drained hashes must include both `5` and `9`.
#[test]
fn test_trajectory_buffer_capacity() {
    let buffer = DagTrajectoryBuffer::new(5);

    (0..10).for_each(|seq| {
        let trajectory = DagTrajectory::new(
            seq as u64,
            vec![0.1; 256],
            "test".to_string(),
            100.0,
            150.0,
        );
        buffer.push(trajectory);
    });
    assert_eq!(buffer.len(), 5);

    let trajectories = buffer.drain();
    assert_eq!(trajectories.len(), 5);

    // Gather the surviving hashes so membership can be checked.
    let mut ids: Vec<u64> = Vec::with_capacity(trajectories.len());
    for traj in trajectories.iter() {
        ids.push(traj.query_hash);
    }
    assert!(ids.contains(&5));
    assert!(ids.contains(&9));
}
/// Checks that recomputing clusters never yields more than `num_clusters`.
///
/// Twenty 128-element sine-derived patterns are stored in a bank configured
/// with `num_clusters: 4`; after `recompute_clusters`, `cluster_count()` must
/// be at most 4.
#[test]
fn test_reasoning_bank_cluster_count() {
    let mut bank = DagReasoningBank::new(ReasoningBankConfig {
        num_clusters: 4,
        pattern_dim: 128,
        max_patterns: 100,
        similarity_threshold: 0.5,
    });

    for row in 0..20 {
        // Shift the sine phase by the row index so each pattern differs.
        let mut pattern: Vec<f32> = Vec::with_capacity(128);
        for col in 0..128 {
            pattern.push(((row + col) as f32 / 10.0).sin());
        }
        bank.store_pattern(pattern, 0.7);
    }

    bank.recompute_clusters();

    assert!(bank.cluster_count() <= 4);
}