use candle_core::{Result as CandleResult, Tensor, D};
use candle_nn::{linear_no_bias, Linear, Module, VarBuilder};
use crate::inference::mamba::{MambaConfig, RmsNorm};
/// Element-wise logistic function `1 / (1 + exp(-xs))`.
///
/// # Errors
/// Propagates any error from the underlying tensor ops.
fn sigmoid(xs: &Tensor) -> CandleResult<Tensor> {
    let exp_neg = xs.neg()?.exp()?;
    // `affine(1.0, 1.0)` computes `exp_neg * 1.0 + 1.0`, i.e. the denominator `1 + e^-x`.
    let denominator = exp_neg.affine(1.0, 1.0)?;
    denominator.recip()
}
/// Default width of the cognitive bottleneck (`cog_signal`) read out of the SSM state.
const DEFAULT_COGNITIVE_DIM: usize = 64;
/// Initial bias of the blend-gate head. `sigmoid(-3.0) ≈ 0.047`, so the gate is
/// biased toward "closed" at initialization.
const GATE_INIT_BIAS: f64 = -3.0;
/// Fraction of `n_layer` (before rounding) used as the default gate position.
const DEFAULT_GATE_DEPTH_FRACTION: f64 = 0.5;
/// Lower bound of `delta_gain`, reached when the delta head's sigmoid output is 0.
pub(crate) const GATE_GAIN_MIN: f64 = 0.5;
/// Upper bound of `delta_gain`, reached when the delta head's sigmoid output is 1.
pub(crate) const GATE_GAIN_MAX: f64 = 2.0;
/// Dimensions and placement of a [`CognitiveGate`].
#[derive(Debug, Clone)]
pub struct CognitiveGateConfig {
    /// Width of the bottleneck projection `w_read: d_inner -> cognitive_dim`.
    pub cognitive_dim: usize,
    /// Where the gate sits in the layer stack; presumably a layer index — confirm
    /// against the model integration that consumes `CognitiveGate::position`.
    pub gate_position: usize,
    /// Width of the hidden (residual) state the gate writes into.
    pub d_model: usize,
    /// Width of the SSM state the gate reads from.
    pub d_inner: usize,
}

impl CognitiveGateConfig {
    /// Derives a gate configuration from a Mamba model configuration.
    ///
    /// `gate_position` is `n_layer * DEFAULT_GATE_DEPTH_FRACTION` rounded to the nearest
    /// integer, `cognitive_dim` is `DEFAULT_COGNITIVE_DIM`, and `d_model` / `d_inner`
    /// are copied from `config`.
    pub fn from_mamba_config(config: &MambaConfig) -> Self {
        let scaled_depth = config.n_layer as f64 * DEFAULT_GATE_DEPTH_FRACTION;
        Self {
            cognitive_dim: DEFAULT_COGNITIVE_DIM,
            gate_position: scaled_depth.round() as usize,
            d_model: config.d_model,
            d_inner: config.d_inner(),
        }
    }

    /// Returns this configuration with `gate_position` replaced by `position`.
    pub fn with_position(self, position: usize) -> Self {
        Self {
            gate_position: position,
            ..self
        }
    }

    /// Returns this configuration with `cognitive_dim` replaced by `dim`.
    pub fn with_cognitive_dim(self, dim: usize) -> Self {
        Self {
            cognitive_dim: dim,
            ..self
        }
    }
}
/// Result of one [`CognitiveGate::forward`] call.
#[derive(Debug, Clone)]
pub struct CognitiveGateOutput {
    /// Hidden state blended with the cognitive contribution:
    /// `(1 - alpha) * hidden_state + alpha * w_write(cog_signal)`.
    pub modulated: Tensor,
    /// Mean delta-head activation rescaled into `[GATE_GAIN_MIN, GATE_GAIN_MAX]`.
    pub delta_gain: f64,
    /// Mean of the blend weights `alpha` over all elements, in `[0, 1]`.
    pub gate_alpha: f64,
    /// Bottleneck reading `w_read(norm(ssm_state))`.
    pub cog_signal: Tensor,
}
/// Reads an SSM state through a low-dimensional bottleneck, derives two scalar
/// control signals from it, and blends a projection of it into the hidden state.
pub struct CognitiveGate {
    /// Projection `d_inner -> cognitive_dim` applied to the normalized SSM state.
    w_read: Linear,
    /// Scalar head on `cog_signal`; its sigmoid output drives `delta_gain`.
    w_delta: Linear,
    /// Scalar head on `cog_signal`; its sigmoid output is the blend weight `alpha`.
    w_gate: Linear,
    /// Projection `cognitive_dim -> d_model` producing the blended-in contribution.
    w_write: Linear,
    /// RMS normalization over `d_inner`, applied to the SSM state before `w_read`.
    norm: RmsNorm,
    /// Configuration this gate was built from.
    config: CognitiveGateConfig,
}
impl CognitiveGate {
pub fn new(config: CognitiveGateConfig, vb: VarBuilder) -> CandleResult<Self> {
let d_model = config.d_model;
let d_inner = config.d_inner;
let cog_dim = config.cognitive_dim;
let w_read = linear_no_bias(d_inner, cog_dim, vb.pp("w_read"))?;
let w_delta_weight = vb.pp("w_delta").get_with_hints(
(1, cog_dim),
"weight",
candle_nn::init::DEFAULT_KAIMING_UNIFORM,
)?;
let w_delta_bias = vb.pp("w_delta").get_with_hints(
1,
"bias",
candle_nn::init::Init::Const(0.0),
)?;
let w_delta = Linear::new(w_delta_weight, Some(w_delta_bias));
let w_gate_weight = vb.pp("w_gate").get_with_hints(
(1, cog_dim),
"weight",
candle_nn::init::DEFAULT_KAIMING_UNIFORM,
)?;
let w_gate_bias = vb.pp("w_gate").get_with_hints(
1,
"bias",
candle_nn::init::Init::Const(GATE_INIT_BIAS),
)?;
let w_gate = Linear::new(w_gate_weight, Some(w_gate_bias));
let w_write = linear_no_bias(cog_dim, d_model, vb.pp("w_write"))?;
let norm = RmsNorm::new(d_inner, vb.pp("norm"))?;
Ok(Self {
w_read,
w_delta,
w_gate,
w_write,
norm,
config,
})
}
pub fn forward(
&self,
ssm_state: &Tensor,
hidden_state: &Tensor,
) -> CandleResult<CognitiveGateOutput> {
let normed = self.norm.forward(ssm_state)?;
let cog_signal = self.w_read.forward(&normed)?;
let delta_raw = sigmoid(&self.w_delta.forward(&cog_signal)?)?;
let alpha_raw = sigmoid(&self.w_gate.forward(&cog_signal)?)?;
let cog_contribution = self.w_write.forward(&cog_signal)?;
let one_minus_alpha = (Tensor::ones_like(&alpha_raw)? - &alpha_raw)?;
let modulated = (hidden_state.broadcast_mul(&one_minus_alpha)?
+ cog_contribution.broadcast_mul(&alpha_raw)?)?;
let delta_scalar = delta_raw
.flatten_all()?
.mean(D::Minus1)?
.to_scalar::<f32>()? as f64;
let alpha_scalar = alpha_raw
.flatten_all()?
.mean(D::Minus1)?
.to_scalar::<f32>()? as f64;
let delta_gain = GATE_GAIN_MIN + delta_scalar * (GATE_GAIN_MAX - GATE_GAIN_MIN);
Ok(CognitiveGateOutput {
modulated,
delta_gain,
gate_alpha: alpha_scalar,
cog_signal,
})
}
pub fn position(&self) -> usize {
self.config.gate_position
}
pub fn cognitive_dim(&self) -> usize {
self.config.cognitive_dim
}
pub fn w_read_weights(&self) -> &Linear {
&self.w_read
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use candle_core::{DType, Device};

    /// Builds a freshly initialized F32 CPU gate with `gate_position = 2`.
    fn make_test_gate(d_model: usize, d_inner: usize, cognitive_dim: usize) -> CognitiveGate {
        let config = CognitiveGateConfig {
            cognitive_dim,
            gate_position: 2,
            d_model,
            d_inner,
        };
        let device = Device::Cpu;
        let varmap = candle_nn::VarMap::new();
        let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
        CognitiveGate::new(config, vb.pp("gate")).expect("Gate construction should succeed")
    }

    /// A `(rows, cols)` F32 tensor of ones on the CPU.
    fn ones(rows: usize, cols: usize) -> Tensor {
        Tensor::ones((rows, cols), DType::F32, &Device::Cpu).unwrap()
    }

    /// A `(1, vals.len())` F32 row tensor on the CPU.
    fn row(vals: &[f32]) -> Tensor {
        Tensor::new(vals, &Device::Cpu).unwrap().unsqueeze(0).unwrap()
    }

    /// Copies every element of `t` into a flat host vector.
    fn flat(t: &Tensor) -> Vec<f32> {
        t.flatten_all().unwrap().to_vec1().unwrap()
    }

    /// At most the first four elements of `v`, to keep assertion messages short.
    fn head(v: &[f32]) -> &[f32] {
        &v[..v.len().min(4)]
    }

    #[test]
    fn gate_passthrough_with_default_init() {
        let gate = make_test_gate(32, 64, 8);
        let ssm_state = ones(1, 64);
        let hidden = ones(1, 32);
        let output = gate.forward(&ssm_state, &hidden).unwrap();
        assert!(
            output.gate_alpha < 0.15,
            "Gate alpha should be near 0 at init, got {}",
            output.gate_alpha
        );
        let hidden_vec = flat(&hidden);
        let output_vec = flat(&output.modulated);
        assert_eq!(hidden_vec.len(), output_vec.len());
        let max_diff = hidden_vec
            .iter()
            .zip(&output_vec)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_diff < 2.0,
            "Near-passthrough gate should not dramatically change residual, max_diff={}",
            max_diff
        );
    }

    #[test]
    fn gate_output_shapes_correct() {
        let gate = make_test_gate(32, 64, 8);
        let ssm_state = Tensor::randn(0f32, 1f32, (1, 64), &Device::Cpu).unwrap();
        let hidden = Tensor::randn(0f32, 1f32, (1, 32), &Device::Cpu).unwrap();
        let output = gate.forward(&ssm_state, &hidden).unwrap();
        assert_eq!(output.modulated.dims(), hidden.dims());
        assert_eq!(output.cog_signal.dims(), &[1, 8]);
    }

    #[test]
    fn gate_config_from_mamba_config() {
        let mamba_cfg = MambaConfig::mamba_130m();
        let gate_cfg = CognitiveGateConfig::from_mamba_config(&mamba_cfg);
        assert_eq!(gate_cfg.d_model, 768);
        assert_eq!(gate_cfg.d_inner, 1536);
        assert_eq!(gate_cfg.cognitive_dim, 64);
        assert_eq!(gate_cfg.gate_position, 12);
    }

    #[test]
    fn gate_config_custom_position() {
        let mamba_cfg = MambaConfig::mamba_130m();
        let gate_cfg = CognitiveGateConfig::from_mamba_config(&mamba_cfg)
            .with_position(8)
            .with_cognitive_dim(32);
        assert_eq!(gate_cfg.gate_position, 8);
        assert_eq!(gate_cfg.cognitive_dim, 32);
        assert_eq!(gate_cfg.d_model, 768);
        assert_eq!(gate_cfg.d_inner, 1536);
    }

    #[test]
    fn different_ssm_states_produce_different_readings() {
        let gate = make_test_gate(32, 64, 8);
        let hidden = ones(1, 32);
        let vals_a: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
        let vals_b: Vec<f32> = (0..64).map(|i| 1.0 - i as f32 / 64.0).collect();
        let out_a = gate.forward(&row(&vals_a), &hidden).unwrap();
        let out_b = gate.forward(&row(&vals_b), &hidden).unwrap();
        let sig_a = flat(&out_a.cog_signal);
        let sig_b = flat(&out_b.cog_signal);
        let a_nonzero = sig_a.iter().any(|v| v.abs() > 1e-10);
        let b_nonzero = sig_b.iter().any(|v| v.abs() > 1e-10);
        assert!(a_nonzero, "cog_signal A should be non-zero, got: {:?}", head(&sig_a));
        assert!(b_nonzero, "cog_signal B should be non-zero, got: {:?}", head(&sig_b));
        let differs = sig_a
            .iter()
            .zip(&sig_b)
            .any(|(a, b)| (a - b).abs() > 1e-6);
        assert!(
            differs,
            "Different SSM states should produce different cognitive signals. A: {:?}, B: {:?}",
            head(&sig_a),
            head(&sig_b)
        );
    }

    #[test]
    fn gate_delta_gain_in_safe_range() {
        let gate = make_test_gate(32, 64, 8);
        let hidden = ones(1, 32);
        for scale in [0.01, 0.1, 1.0, 5.0, 10.0] {
            let ssm = (ones(1, 64) * scale).unwrap();
            let output = gate.forward(&ssm, &hidden).unwrap();
            assert!(
                output.delta_gain >= GATE_GAIN_MIN && output.delta_gain <= GATE_GAIN_MAX,
                "delta_gain {} outside [{}, {}] for scale {}",
                output.delta_gain,
                GATE_GAIN_MIN,
                GATE_GAIN_MAX,
                scale
            );
        }
    }

    #[test]
    fn gate_alpha_in_unit_range() {
        let gate = make_test_gate(32, 64, 8);
        let hidden = ones(1, 32);
        for scale in [0.01, 1.0, 10.0] {
            let ssm = (ones(1, 64) * scale).unwrap();
            let output = gate.forward(&ssm, &hidden).unwrap();
            assert!(
                output.gate_alpha >= 0.0 && output.gate_alpha <= 1.0,
                "gate_alpha {} not in [0, 1] for scale {}",
                output.gate_alpha,
                scale
            );
        }
    }

    #[test]
    fn gate_delta_gain_range_constants() {
        assert_relative_eq!(GATE_GAIN_MIN + 0.0 * (GATE_GAIN_MAX - GATE_GAIN_MIN), 0.5);
        assert_relative_eq!(GATE_GAIN_MIN + 1.0 * (GATE_GAIN_MAX - GATE_GAIN_MIN), 2.0);
        assert_relative_eq!(GATE_GAIN_MIN + 0.5 * (GATE_GAIN_MAX - GATE_GAIN_MIN), 1.25);
    }
}