#[cfg(feature = "alloc")]
use crate::math;
#[cfg(feature = "alloc")]
use alloc::vec;
#[cfg(feature = "alloc")]
use alloc::vec::Vec;
/// Margin applied to both ends of the gate's output range.
///
/// [`GateHead::forward`] clamps its result to `[GATE_EPS, 1.0 - GATE_EPS]`,
/// so the returned gate never reaches exactly `0.0` or `1.0`.
#[cfg(feature = "alloc")]
const GATE_EPS: f64 = 1e-15;
/// A single-output gate: a weight vector and a scalar bias whose affine
/// combination with the input is passed through a sigmoid.
///
/// Both parameters start at zero (see [`GateHead::new`]), are read by
/// [`GateHead::forward`], and are adjusted in place by [`GateHead::update`].
///
/// `Debug` and `Clone` are derived so the head can be logged and snapshotted
/// (e.g. to keep a copy of the parameters before a round of updates).
#[cfg(feature = "alloc")]
#[derive(Debug, Clone)]
pub struct GateHead {
    /// One coefficient per input feature; its length is fixed at construction.
    weights: Vec<f64>,
    /// Scalar offset added to the weighted sum before the sigmoid.
    bias: f64,
}
#[cfg(feature = "alloc")]
impl GateHead {
    /// Creates a gate for `d_in` input features with all weights and the bias
    /// set to zero.
    pub fn new(d_in: usize) -> Self {
        let weights = vec![0.0; d_in];
        Self { weights, bias: 0.0 }
    }

    /// Evaluates the gate on `x`.
    ///
    /// Computes `bias + Σ weights[i] * x[i]`, passes it through
    /// `math::sigmoid`, and clamps the result to
    /// `[GATE_EPS, 1.0 - GATE_EPS]`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `x.len()` differs from the number of
    /// weights. In release builds the check is compiled out and the sum only
    /// covers the shorter of the two sequences.
    #[inline]
    pub fn forward(&self, x: &[f64]) -> f64 {
        debug_assert_eq!(
            x.len(),
            self.weights.len(),
            "GateHead: input len {} != weight len {}",
            x.len(),
            self.weights.len()
        );

        // Seed the accumulator with the bias and add the products left to
        // right, so the floating-point summation order is fixed.
        let logit = self
            .weights
            .iter()
            .zip(x)
            .fold(self.bias, |acc, (&w, &xi)| acc + w * xi);

        let upper = 1.0 - GATE_EPS;
        math::sigmoid(logit).clamp(GATE_EPS, upper)
    }

    /// Moves the parameters one step so that the gate output for `x` shifts
    /// toward `target_gate`.
    ///
    /// With `g = self.forward(x)`, the step is
    /// `lr * (g - target_gate) * g * (1 - g)`; each weight is reduced by
    /// `step * x[i]` and the bias by `step`. This matches gradient descent on
    /// `½ (g - target_gate)²`, assuming `math::sigmoid` is the standard
    /// logistic function (its definition is not visible here).
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `x.len()` differs from the number of
    /// weights.
    pub fn update(&mut self, x: &[f64], target_gate: f64, lr: f64) {
        debug_assert_eq!(x.len(), self.weights.len());

        let gate = self.forward(x);
        // Error term times the sigmoid slope `g * (1 - g)`, scaled by `lr`.
        let step = lr * ((gate - target_gate) * gate * (1.0 - gate));

        self.weights
            .iter_mut()
            .zip(x)
            .for_each(|(w, &xi)| *w -= step * xi);
        self.bias -= step;
    }
}
#[cfg(all(test, feature = "alloc"))]
mod tests {
    use super::*;

    /// A freshly constructed head has zero weights and bias, so its output
    /// must be the sigmoid midpoint regardless of the input.
    #[test]
    fn forward_returns_value_in_unit_interval() {
        let head = GateHead::new(4);
        let g = head.forward(&[1.0, -2.0, 3.0, -4.0]);
        assert!(g > 0.0 && g < 1.0, "gate output {g} must be in (0, 1)");
        assert!((g - 0.5).abs() < 1e-12, "zero-init gate should equal 0.5");
    }

    /// Repeated updates on a fixed input must bring the gate closer to the
    /// target than it was at initialization.
    #[test]
    fn update_reduces_gate_target_distance() {
        let x = [1.0, 0.5, -0.5, 0.0];
        let target = 0.9;
        let lr = 0.5;
        let mut head = GateHead::new(4);
        let initial_dist = (head.forward(&x) - target).abs();
        for _ in 0..200 {
            head.update(&x, target, lr);
        }
        let final_dist = (head.forward(&x) - target).abs();
        assert!(
            final_dist < initial_dist,
            "gate should converge toward target: initial_dist={initial_dist}, final_dist={final_dist}"
        );
    }

    /// Even after training on large-magnitude inputs and evaluating on larger
    /// ones, the clamp must keep the output strictly between 0 and 1.
    #[test]
    fn forward_bounded_on_extreme_inputs() {
        let mut head = GateHead::new(3);
        for _ in 0..1000 {
            head.update(&[100.0, -100.0, 50.0], 1.0, 0.1);
        }
        let g = head.forward(&[1000.0, -1000.0, 500.0]);
        assert!(g > 0.0 && g < 1.0, "gate {g} must stay in (0, 1)");
    }

    /// When the target equals the current gate value the error term is zero,
    /// so neither the weights nor the bias may move.
    #[test]
    fn update_to_current_gate_value_is_zero_gradient() {
        let mut head = GateHead::new(2);
        head.update(&[1.0, 2.0], 0.5, 1.0);
        for &w in &head.weights {
            assert!(
                w.abs() < 1e-15,
                "weight should not change when gate == target"
            );
        }
        // `update` applies the same step to the bias, so it must stay put too.
        assert!(
            head.bias.abs() < 1e-15,
            "bias should not change when gate == target"
        );
    }
}