use candle_core::{Device, Tensor};
use candle_nn::optim::{AdamW, Optimizer, ParamsAdamW};
use noos::errors::NousResult;
use noos::inference::cognitive_gate::CognitiveGateConfig;
use noos::inference::cognitive_model::CognitiveModel;
use noos::inference::mamba::{CognitiveMambaWithGate, HfTokenizer, MambaConfig};
use noos::inference::model::LocalModel;
use noos::inference::tokenizer::NousTokenizer;
use noos::types::intervention::DeltaModulation;
/// Fixed training corpus for the gate-training demo.
///
/// A short English passage of neuroscience statements interleaved with a few
/// questions and first-person remarks. Every line ends in a `\` string
/// continuation, so the literal is a single paragraph without newlines: the
/// line break and the next line's leading whitespace are dropped, and the
/// space written before each `\` keeps adjacent words separated.
const SAMPLE_TEXT: &str = "\
The brain processes information through billions of interconnected neurons. \
Each neuron communicates via electrochemical signals, transmitting data across \
synapses at speeds of up to 120 meters per second. The prefrontal cortex handles \
executive functions like planning and decision-making. How do you feel about that? \
I think it's fascinating how the brain can reorganize itself through neuroplasticity. \
The hippocampus plays a crucial role in forming new memories. Scientists have \
discovered that sleep is essential for memory consolidation. During deep sleep, \
the brain replays experiences and strengthens important connections. What patterns \
do you notice in your own thinking? The amygdala processes emotions rapidly, \
sometimes before conscious awareness. This fast pathway evolved as a survival \
mechanism. Fear responses can be triggered in milliseconds. Meanwhile, the \
default mode network activates during rest and mind-wandering. Creative insights \
often emerge from this wandering state. The locus coeruleus modulates attention \
through norepinephrine release. Phasic bursts sharpen focus on specific stimuli, \
while tonic activity enables broad environmental monitoring.";
/// Wraps a lower-level error (candle op, mutex poisoning, ...) as
/// `NousError::Internal`, formatted as `"<context>: <error>"`.
fn internal(context: &str, err: impl std::fmt::Display) -> noos::NousError {
    noos::NousError::Internal(format!("{context}: {err}"))
}

/// Returns `max - min` over `values` (`-inf` for an empty slice, as the
/// underlying min/max folds start from `+inf` / `-inf`).
fn spread(values: &[f64]) -> f64 {
    let min = values.iter().copied().fold(f64::INFINITY, f64::min);
    let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    max - min
}

/// Trains the CognitiveGate attached to a pretrained Mamba-130m and reports
/// whether it learned.
///
/// 1. Loads `state-spaces/mamba-130m-hf` with a gate attached, plus its tokenizer.
/// 2. Runs `num_steps` AdamW steps on next-token cross-entropy over fixed-length,
///    non-overlapping windows of [`SAMPLE_TEXT`]. Only the variables of the gate's
///    var map are handed to the optimizer.
/// 3. Reports the final gate alpha / delta gain and whether the W_gate bias moved.
/// 4. Runs four short prompts and prints the P4 / P5 / P7 principle checks.
///
/// # Errors
/// Returns an error if model or tokenizer loading, tokenization, a forward pass,
/// the loss computation or an optimizer step fails, or if the sample text yields
/// fewer than `seq_len + 1` tokens.
fn main() -> NousResult<()> {
    // Gate alpha expected right after initialization; "moved" means > 0.01 away.
    const INIT_ALPHA: f64 = 0.047;
    // Value the W_gate bias is expected to start at.
    const INIT_GATE_BIAS: f32 = -3.0;
    // A training row is printed every REPORT_EVERY steps, plus the final step.
    const REPORT_EVERY: usize = 5;

    println!("=== Nous Tầng 3: CognitiveGate Training ===\n");
    let model_id = "state-spaces/mamba-130m-hf";
    let learning_rate = 1e-3;
    let num_steps = 50;
    let seq_len = 8;
    let config = MambaConfig::mamba_130m();
    let gate_config = CognitiveGateConfig::from_mamba_config(&config);

    println!("Model: {model_id}");
    println!("Gate position: layer {} / {}", gate_config.gate_position, config.n_layer);
    println!("Gate cognitive dim: {}", gate_config.cognitive_dim);
    println!("Learning rate: {learning_rate}");
    println!("Steps: {num_steps}, seq_len: {seq_len}\n");

    println!("Loading model (this may download ~500MB on first run)...");
    let (mut model, gate_varmap) =
        CognitiveMambaWithGate::from_pretrained_with_gate(model_id, config.clone(), gate_config)?;
    let tokenizer = HfTokenizer::from_pretrained(model_id)?;

    let gate_param_count: usize = gate_varmap
        .data()
        .lock()
        .map_err(|e| internal("Lock error", e))?
        .values()
        .map(|v| v.elem_count())
        .sum();
    println!("Gate params: {gate_param_count} (trainable)");
    println!("Base model: frozen\n");

    let tokens = tokenizer.encode(SAMPLE_TEXT, false)?;
    println!("Tokenized: {} tokens from sample text\n", tokens.len());
    // Need at least one full input window plus its one-token-shifted target.
    if tokens.len() < seq_len + 1 {
        return Err(noos::NousError::Internal(
            "Sample text too short for training".into(),
        ));
    }
    // Number of non-overlapping windows; at least 1 thanks to the check above,
    // so the modulo in the training loop cannot divide by zero.
    let num_sequences = (tokens.len() - 1) / seq_len;

    let params = ParamsAdamW {
        lr: learning_rate,
        ..ParamsAdamW::default()
    };
    // Only the gate's variables are optimized; no base-model variables are passed in.
    let mut optimizer = AdamW::new(gate_varmap.all_vars(), params)
        .map_err(|e| internal("Optimizer init error", e))?;

    // Default modulation, only ever passed by shared reference, so one instance
    // serves every inspection pass below.
    let dm = DeltaModulation::default();

    println!("── Training ──\n");
    println!("{:>5} {:>10} {:>10} {:>12}", "Step", "Loss", "Alpha", "Delta Gain");
    println!("{}", "-".repeat(45));
    for step in 0..num_steps {
        let start = (step % num_sequences) * seq_len;
        let input_tokens = &tokens[start..start + seq_len];
        // Targets are the same window shifted one token to the right.
        let target_tokens = &tokens[start + 1..start + seq_len + 1];

        model.reset_cache();
        let logits = model
            .forward_train(input_tokens)
            .map_err(|e| internal("Forward error", e))?;
        let targets = Tensor::new(target_tokens, &Device::Cpu)
            .map_err(|e| internal("Target tensor error", e))?;
        let loss = candle_nn::loss::cross_entropy(&logits, &targets)
            .map_err(|e| internal("Loss error", e))?;
        let loss_val = loss
            .to_scalar::<f32>()
            .map_err(|e| internal("Loss scalar error", e))?;
        optimizer
            .backward_step(&loss)
            .map_err(|e| internal("Backward step error", e))?;

        // The inspection pass below only feeds the printed row; skip it on silent steps.
        if step % REPORT_EVERY != 0 && step != num_steps - 1 {
            continue;
        }
        model.reset_cache();
        let inspect = model.forward_cognitive(&[input_tokens[0]], 0, &dm)?;
        let gate_alpha = inspect.gate_alpha.unwrap_or(0.0);
        let delta_gain = inspect.gate_delta_gain.unwrap_or(1.0);
        println!(
            "{:>5} {:>10.4} {:>10.4} {:>12.4}",
            step, loss_val, gate_alpha, delta_gain
        );
    }

    println!("\n── Training complete ──\n");
    model.reset_cache();
    let final_result = model.forward_cognitive(&[tokens[0]], 0, &dm)?;
    let final_alpha = final_result.gate_alpha.unwrap_or(0.0);
    let final_delta_gain = final_result.gate_delta_gain.unwrap_or(1.0);
    println!("Final gate_alpha: {final_alpha:.4} (init was ~0.05)");
    println!("Final delta_gain: {final_delta_gain:.4} (range [0.5, 2.0])");
    if (final_alpha - INIT_ALPHA).abs() > 0.01 {
        println!("\nGate alpha has moved from initialization → gate is LEARNING.");
    } else {
        println!("\nGate alpha near initial value — may need more steps or different data.");
    }

    // Compare only the W_gate *bias* against its initial value. Matching on
    // "w_gate" alone also selects any W_gate weight tensor, whose elements are
    // not all INIT_GATE_BIAS, which made this check succeed unconditionally.
    // NOTE(review): assumes the bias variable's name contains both "w_gate" and
    // "bias" (e.g. candle_nn's `<prefix>.bias`) — confirm against the gate's paths.
    let gate_changed = gate_varmap
        .data()
        .lock()
        .map_err(|e| internal("Lock error", e))?
        .iter()
        .filter(|(name, _)| name.contains("w_gate") && name.contains("bias"))
        .any(|(_, var)| {
            var.flatten_all()
                .and_then(|t| t.to_vec1::<f32>())
                .map(|bias| bias.iter().any(|v| (v - INIT_GATE_BIAS).abs() > 0.01))
                .unwrap_or(false)
        });
    if gate_changed {
        println!("W_gate bias moved from -3.0 → gate learning confirmed.\n");
    } else {
        println!("W_gate bias still at -3.0 → no bias update detected.\n");
    }

    println!("=== Phase 3.4: Principle Verification ===\n");
    let test_prompts = [
        ("Factual", "The hippocampus is a brain region involved in memory"),
        ("Emotional", "I feel terrified and my heart is racing with fear"),
        ("Question", "How do neurons communicate across synapses"),
        ("Creative", "Imagine a world where thoughts flow like rivers of light"),
    ];
    println!("{:<12} {:>10} {:>12} Principle check", "Type", "Alpha", "Delta Gain");
    println!("{}", "-".repeat(55));
    let mut alphas = Vec::with_capacity(test_prompts.len());
    let mut deltas = Vec::with_capacity(test_prompts.len());
    for (label, prompt) in test_prompts {
        model.reset_cache();
        let prompt_tokens = tokenizer.encode(prompt, false)?;
        // Feed at most one training window's worth of tokens.
        let n = prompt_tokens.len().min(seq_len);
        let result = model.forward_cognitive(&prompt_tokens[..n], 0, &dm)?;
        let alpha = result.gate_alpha.unwrap_or(0.0);
        let delta = result.gate_delta_gain.unwrap_or(1.0);
        alphas.push(alpha);
        deltas.push(delta);
        println!("{:<12} {:>10.4} {:>12.4}", label, alpha, delta);
    }

    let alpha_spread = spread(&alphas);
    let delta_spread = spread(&deltas);

    println!("\n── Principle Verification ──\n");
    if alpha_spread > 0.001 {
        println!("P4 (Affect): PASS — alpha varies across contexts (spread: {alpha_spread:.4})");
    } else {
        println!("P4 (Affect): WEAK — alpha barely varies (spread: {alpha_spread:.6}). More training needed.");
    }
    if delta_spread > 0.001 {
        println!("P5 (Classification): PASS — delta_gain differs by text type (spread: {delta_spread:.4})");
    } else {
        println!("P5 (Classification): WEAK — delta_gain uniform (spread: {delta_spread:.6}). More training needed.");
    }
    let non_unity = deltas.iter().any(|d| (d - 1.0).abs() > 0.05);
    if non_unity {
        println!("P7 (Multi-Timescale): PASS — gate produces non-unity delta_gain (modulating SSM timescale)");
    } else {
        println!("P7 (Multi-Timescale): WEAK — delta_gain near 1.0. Gate not yet modulating timescale.");
    }
    println!("\n=== Tầng 3 Complete: Phase 3.1-3.4 ===");
    Ok(())
}