use super::{SparseBitVector, SparseProjection};
use crate::{NervousSystemError, Result};
/// Pattern-separation encoder modelled on the dentate gyrus.
///
/// Inputs are passed through a seeded [`SparseProjection`] into `output_dim`
/// units, and only the `k` most strongly activated units are kept
/// (k-winners-take-all), yielding a fixed-sparsity code.
#[derive(Clone, Debug)]
pub struct DentateGyrus {
    /// Seeded sparse projection from the input space onto `output_dim` units.
    projection: SparseProjection,
    /// Number of winning (active) units in every encoding.
    k: usize,
    /// Number of units in the encoded output space.
    output_dim: usize,
}
impl DentateGyrus {
    /// Creates an encoder mapping `input_dim`-dimensional inputs onto
    /// `output_dim` units, of which exactly `k` are active per encoding.
    ///
    /// The projection is built by [`SparseProjection::new`] with a fixed third
    /// argument of `0.15` (presumably the connection density — confirm against
    /// `SparseProjection::new`) and the given `seed`.
    ///
    /// # Panics
    ///
    /// Panics if `k == 0`, if `k > output_dim`, if `output_dim` exceeds
    /// `u16::MAX` (winner indices and the output capacity are stored as `u16`),
    /// or if the sparse projection cannot be created.
    pub fn new(input_dim: usize, output_dim: usize, k: usize, seed: u64) -> Self {
        assert!(k > 0, "k must be > 0");
        assert!(k <= output_dim, "k cannot exceed output_dim");
        // `k_winners_take_all` narrows indices and capacity to `u16`; anything
        // larger would be silently truncated into a corrupt code.
        assert!(
            output_dim <= usize::from(u16::MAX),
            "output_dim cannot exceed u16::MAX"
        );
        let projection = SparseProjection::new(input_dim, output_dim, 0.15, seed)
            .expect("Failed to create sparse projection");
        Self {
            projection,
            k,
            output_dim,
        }
    }

    /// Encodes `input` as a sparse binary code with exactly `k` active units.
    ///
    /// # Panics
    ///
    /// Panics if the projection fails (presumably when `input.len()` does not
    /// match `input_dim` — confirm against `SparseProjection::project`).
    pub fn encode(&self, input: &[f32]) -> SparseBitVector {
        let projected = self.projection.project(input).expect("Projection failed");
        self.k_winners_take_all(&projected)
    }

    /// Encodes `input` as a dense vector of length `output_dim` in which the
    /// `k` winning units keep their projected activation and all others are 0.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`DentateGyrus::encode`].
    pub fn encode_dense(&self, input: &[f32]) -> Vec<f32> {
        let projected = self.projection.project(input).expect("Projection failed");
        let sparse = self.k_winners_take_all(&projected);
        let mut dense = vec![0.0; self.output_dim];
        for &idx in &sparse.indices {
            let i = idx as usize;
            dense[i] = projected[i];
        }
        dense
    }

    /// Returns a code whose active bits are the `k` largest `activations`,
    /// with indices sorted ascending.
    fn k_winners_take_all(&self, activations: &[f32]) -> SparseBitVector {
        // NaN has no meaningful rank: demote it below every real activation so
        // it can only win when fewer than `k` finite-or-infinite values exist.
        let mut ranked: Vec<(usize, f32)> = activations
            .iter()
            .map(|&v| if v.is_nan() { f32::NEG_INFINITY } else { v })
            .enumerate()
            .collect();
        // `total_cmp` gives the total order `select_nth_unstable_by` requires.
        // Selecting position `k - 1` (not `k`) leaves the `k` largest entries in
        // `ranked[..k]` and stays in bounds when `k == activations.len()`.
        // `k >= 1` is guaranteed by `new`.
        ranked.select_nth_unstable_by(self.k - 1, |a, b| b.1.total_cmp(&a.1));
        // Indices are < output_dim <= u16::MAX (checked in `new`), so the
        // narrowing casts below are lossless.
        let mut winners: Vec<u16> = ranked[..self.k]
            .iter()
            .map(|&(i, _)| i as u16)
            .collect();
        winners.sort_unstable();
        SparseBitVector::from_indices(winners, self.output_dim as u16)
    }

    /// Dimensionality of the inputs accepted by the projection.
    pub fn input_dim(&self) -> usize {
        self.projection.input_dim()
    }

    /// Number of units in the output code.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }

    /// Number of active units in every code.
    pub fn k(&self) -> usize {
        self.k
    }

    /// Fraction of output units active per code, `k / output_dim`.
    pub fn sparsity(&self) -> f32 {
        self.k as f32 / self.output_dim as f32
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for comparing runtime-computed `f32` ratios.
    const EPS: f32 = 1e-6;

    /// Deterministic `sin`-valued input of the given length.
    fn sine_input(len: usize) -> Vec<f32> {
        (0..len).map(|i| (i as f32).sin()).collect()
    }

    #[test]
    fn test_dentate_gyrus_creation() {
        let dg = DentateGyrus::new(128, 10000, 200, 42);
        assert_eq!(dg.input_dim(), 128);
        assert_eq!(dg.output_dim(), 10000);
        assert_eq!(dg.k(), 200);
        assert!((dg.sparsity() - 0.02).abs() < EPS);
    }

    #[test]
    #[should_panic(expected = "k must be > 0")]
    fn test_invalid_k_zero() {
        DentateGyrus::new(128, 10000, 0, 42);
    }

    #[test]
    #[should_panic(expected = "k cannot exceed output_dim")]
    fn test_invalid_k_too_large() {
        DentateGyrus::new(128, 100, 200, 42);
    }

    #[test]
    fn test_encode_produces_sparse_output() {
        let dg = DentateGyrus::new(128, 10000, 200, 42);
        let sparse = dg.encode(&sine_input(128));
        assert_eq!(sparse.count(), 200, "Should have exactly k active neurons");
        assert_eq!(sparse.capacity(), 10000);
    }

    #[test]
    fn test_encode_deterministic() {
        let dg = DentateGyrus::new(128, 10000, 200, 42);
        let input = sine_input(128);
        let sparse1 = dg.encode(&input);
        let sparse2 = dg.encode(&input);
        assert_eq!(sparse1, sparse2, "Same input should produce same encoding");
    }

    #[test]
    fn test_encode_dense_has_k_nonzeros() {
        let dg = DentateGyrus::new(128, 10000, 200, 42);
        let dense = dg.encode_dense(&sine_input(128));
        let nonzero_count = dense.iter().filter(|&&x| x != 0.0).count();
        assert_eq!(
            nonzero_count, 200,
            "Should have exactly k non-zero elements"
        );
    }

    #[test]
    fn test_different_inputs_produce_different_outputs() {
        let dg = DentateGyrus::new(128, 10000, 200, 42);
        let input1 = sine_input(128);
        let input2: Vec<f32> = (0..128).map(|i| (i as f32).cos()).collect();
        let sparse1 = dg.encode(&input1);
        let sparse2 = dg.encode(&input2);
        assert_ne!(
            sparse1, sparse2,
            "Different inputs should produce different encodings"
        );
    }

    #[test]
    fn test_pattern_separation_property() {
        let dg = DentateGyrus::new(128, 10000, 200, 42);
        // Two inputs sharing 120 active positions, each with one private one.
        let mut input1 = vec![0.0; 128];
        let mut input2 = vec![0.0; 128];
        input1[..120].fill(1.0);
        input2[..120].fill(1.0);
        input1[125] = 1.0;
        input2[126] = 1.0;
        let sparse1 = dg.encode(&input1);
        let sparse2 = dg.encode(&input2);
        // Compare like with like: the Jaccard similarity of the active input
        // positions is |shared| / |union| = 120 / 122, the same metric used
        // for the outputs below.
        let input_similarity = 120.0 / 122.0;
        let output_similarity = sparse1.jaccard_similarity(&sparse2);
        assert!(
            output_similarity < input_similarity,
            "Output similarity ({}) should be less than input similarity ({})",
            output_similarity,
            input_similarity
        );
    }

    #[test]
    fn test_sparsity_levels() {
        let cases = [(10000, 200, 0.02), (10000, 300, 0.03), (10000, 500, 0.05)];
        let input = sine_input(128);
        for (output_dim, k, expected_sparsity) in cases {
            let dg = DentateGyrus::new(128, output_dim, k, 42);
            assert!((dg.sparsity() - expected_sparsity).abs() < EPS);
            let sparse = dg.encode(&input);
            assert_eq!(sparse.count(), k);
        }
    }

    #[test]
    fn test_zero_input() {
        let dg = DentateGyrus::new(128, 10000, 200, 42);
        let input = vec![0.0; 128];
        let sparse = dg.encode(&input);
        assert_eq!(sparse.count(), 200);
    }

    #[test]
    fn test_encode_performance_target() {
        let dg = DentateGyrus::new(512, 10000, 200, 42);
        let input = sine_input(512);
        let iterations: u32 = 100;
        let start = std::time::Instant::now();
        for _ in 0..iterations {
            let _ = dg.encode(&input);
        }
        let avg_time = start.elapsed() / iterations;
        println!("Average encoding time: {:?}", avg_time);
        assert!(
            avg_time < std::time::Duration::from_secs(2),
            "Average encoding time ({:?}) exceeds 2s target",
            avg_time
        );
    }
}