// omega_snn/population.rs

1//! Neural Population Coding
2//!
3//! Implements population-level representations and sparse coding.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use crate::neuron::NeuronId;
9use crate::spike_train::SpikeTrain;
10
/// Activity state of a neural population
///
/// Snapshot of per-neuron activity levels over a time window, together
/// with summary statistics (mean activity and sparsity).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PopulationActivity {
    /// Population identifier
    pub population_id: String,
    /// Activity level per neuron (0.0 to 1.0)
    pub activities: HashMap<NeuronId, f64>,
    /// Population-level metrics: mean of all per-neuron activities
    pub mean_activity: f64,
    /// Sparsity (fraction of neurons with low activity; 1.0 = all quiet)
    pub sparsity: f64,
    /// Population vector (centroid of activity)
    ///
    /// NOTE(review): `from_spike_trains` leaves this empty — presumably
    /// populated elsewhere; confirm against callers.
    pub population_vector: Vec<f64>,
}
25
26impl PopulationActivity {
27    /// Create from spike trains over a time window
28    pub fn from_spike_trains(
29        population_id: String,
30        trains: &[SpikeTrain],
31        window: std::time::Duration,
32        max_rate: f64,
33    ) -> Self {
34        let mut activities = HashMap::new();
35        let mut total_activity = 0.0;
36        let mut active_count = 0;
37
38        for train in trains {
39            let rate = train.firing_rate(window);
40            let activity = (rate / max_rate).min(1.0);
41            activities.insert(train.neuron_id.clone(), activity);
42            total_activity += activity;
43
44            if activity > 0.1 {
45                active_count += 1;
46            }
47        }
48
49        let n = trains.len() as f64;
50        let mean_activity = if n > 0.0 { total_activity / n } else { 0.0 };
51        let sparsity = if n > 0.0 {
52            1.0 - (active_count as f64 / n)
53        } else {
54            1.0
55        };
56
57        Self {
58            population_id,
59            activities,
60            mean_activity,
61            sparsity,
62            population_vector: Vec::new(),
63        }
64    }
65
66    /// Get activity for specific neuron
67    pub fn get_activity(&self, neuron_id: &NeuronId) -> f64 {
68        *self.activities.get(neuron_id).unwrap_or(&0.0)
69    }
70
71    /// Get most active neurons
72    pub fn top_active(&self, k: usize) -> Vec<(NeuronId, f64)> {
73        let mut sorted: Vec<_> = self.activities.iter().collect();
74        sorted.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
75        sorted
76            .into_iter()
77            .take(k)
78            .map(|(id, act)| (id.clone(), *act))
79            .collect()
80    }
81}
82
/// Sparse code representation
///
/// Stores only the non-zero entries of a `dimension`-length vector as an
/// index -> value map, plus the sparsity level the code was built for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseCode {
    /// Dimension of the code (length of the equivalent dense vector)
    pub dimension: usize,
    /// Active indices and their values
    pub active: HashMap<usize, f64>,
    /// Target sparsity level (fraction of entries intended to be zero)
    pub target_sparsity: f64,
}
93
94impl SparseCode {
95    /// Create empty sparse code
96    pub fn new(dimension: usize, target_sparsity: f64) -> Self {
97        Self {
98            dimension,
99            active: HashMap::new(),
100            target_sparsity,
101        }
102    }
103
104    /// Create from dense vector
105    pub fn from_dense(values: &[f64], target_sparsity: f64) -> Self {
106        let dimension = values.len();
107        let k = ((1.0 - target_sparsity) * dimension as f64).ceil() as usize;
108
109        // Get top-k indices
110        let mut indexed: Vec<_> = values.iter().enumerate().collect();
111        indexed.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
112
113        let active: HashMap<usize, f64> = indexed
114            .into_iter()
115            .take(k)
116            .filter(|(_, &v)| v > 0.0)
117            .map(|(i, &v)| (i, v))
118            .collect();
119
120        Self {
121            dimension,
122            active,
123            target_sparsity,
124        }
125    }
126
127    /// Convert to dense vector
128    pub fn to_dense(&self) -> Vec<f64> {
129        let mut dense = vec![0.0; self.dimension];
130        for (&idx, &val) in &self.active {
131            if idx < self.dimension {
132                dense[idx] = val;
133            }
134        }
135        dense
136    }
137
138    /// Get actual sparsity
139    pub fn sparsity(&self) -> f64 {
140        1.0 - (self.active.len() as f64 / self.dimension as f64)
141    }
142
143    /// Get L1 norm
144    pub fn l1_norm(&self) -> f64 {
145        self.active.values().map(|v| v.abs()).sum()
146    }
147
148    /// Get L2 norm
149    pub fn l2_norm(&self) -> f64 {
150        self.active.values().map(|v| v * v).sum::<f64>().sqrt()
151    }
152
153    /// Dot product with another sparse code
154    pub fn dot(&self, other: &SparseCode) -> f64 {
155        let mut sum = 0.0;
156        for (&idx, &val) in &self.active {
157            if let Some(&other_val) = other.active.get(&idx) {
158                sum += val * other_val;
159            }
160        }
161        sum
162    }
163
164    /// Cosine similarity with another sparse code
165    pub fn cosine_similarity(&self, other: &SparseCode) -> f64 {
166        let dot = self.dot(other);
167        let norm_self = self.l2_norm();
168        let norm_other = other.l2_norm();
169
170        if norm_self == 0.0 || norm_other == 0.0 {
171            return 0.0;
172        }
173
174        dot / (norm_self * norm_other)
175    }
176
177    /// Add two sparse codes
178    pub fn add(&self, other: &SparseCode) -> SparseCode {
179        let mut result = self.clone();
180
181        for (&idx, &val) in &other.active {
182            *result.active.entry(idx).or_insert(0.0) += val;
183        }
184
185        result
186    }
187
188    /// Scale sparse code
189    pub fn scale(&self, factor: f64) -> SparseCode {
190        let mut result = self.clone();
191        for val in result.active.values_mut() {
192            *val *= factor;
193        }
194        result
195    }
196}
197
/// Neural population with encoding/decoding capabilities
///
/// Each neuron has a Gaussian tuning curve centered on its preferred
/// stimulus; `tuning_width` is shared across the whole population.
#[derive(Debug, Clone)]
pub struct NeuralPopulation {
    /// Population ID
    pub id: String,
    /// Neuron IDs in this population
    pub neuron_ids: Vec<NeuronId>,
    /// Preferred stimuli for each neuron (tuning curves)
    pub tuning_centers: HashMap<NeuronId, Vec<f64>>,
    /// Tuning width (selectivity; smaller = sharper tuning)
    pub tuning_width: f64,
}
210
211impl NeuralPopulation {
212    /// Create a new population
213    pub fn new(id: String, neuron_ids: Vec<NeuronId>) -> Self {
214        Self {
215            id,
216            neuron_ids,
217            tuning_centers: HashMap::new(),
218            tuning_width: 1.0,
219        }
220    }
221
222    /// Create population with uniformly spaced tuning curves
223    pub fn with_uniform_tuning(
224        id: String,
225        size: usize,
226        stimulus_dim: usize,
227        stimulus_range: (f64, f64),
228    ) -> Self {
229        let mut neuron_ids = Vec::new();
230        let mut tuning_centers = HashMap::new();
231
232        for i in 0..size {
233            let neuron_id = format!("{}_{}", id, i);
234
235            // Create tuning center
236            let t = i as f64 / (size - 1).max(1) as f64;
237            let center: Vec<f64> = (0..stimulus_dim)
238                .map(|_| stimulus_range.0 + t * (stimulus_range.1 - stimulus_range.0))
239                .collect();
240
241            tuning_centers.insert(neuron_id.clone(), center);
242            neuron_ids.push(neuron_id);
243        }
244
245        Self {
246            id,
247            neuron_ids,
248            tuning_centers,
249            tuning_width: (stimulus_range.1 - stimulus_range.0) / size as f64,
250        }
251    }
252
253    /// Encode stimulus into population activity
254    pub fn encode(&self, stimulus: &[f64]) -> HashMap<NeuronId, f64> {
255        let mut activities = HashMap::new();
256
257        for (neuron_id, center) in &self.tuning_centers {
258            // Compute distance from tuning center
259            let dist_sq: f64 = stimulus
260                .iter()
261                .zip(center.iter())
262                .map(|(s, c)| (s - c).powi(2))
263                .sum();
264
265            // Gaussian tuning curve
266            let activity = (-dist_sq / (2.0 * self.tuning_width * self.tuning_width)).exp();
267            activities.insert(neuron_id.clone(), activity);
268        }
269
270        activities
271    }
272
273    /// Decode stimulus from population activity (population vector)
274    pub fn decode(&self, activities: &HashMap<NeuronId, f64>) -> Vec<f64> {
275        if self.tuning_centers.is_empty() {
276            return Vec::new();
277        }
278
279        let dim = self.tuning_centers.values().next().map(|v| v.len()).unwrap_or(0);
280        let mut weighted_sum = vec![0.0; dim];
281        let mut total_weight = 0.0;
282
283        for (neuron_id, center) in &self.tuning_centers {
284            let activity = *activities.get(neuron_id).unwrap_or(&0.0);
285            total_weight += activity;
286
287            for (i, &c) in center.iter().enumerate() {
288                weighted_sum[i] += activity * c;
289            }
290        }
291
292        if total_weight > 0.0 {
293            weighted_sum.iter().map(|&s| s / total_weight).collect()
294        } else {
295            vec![0.0; dim]
296        }
297    }
298
299    /// Get population size
300    pub fn size(&self) -> usize {
301        self.neuron_ids.len()
302    }
303}
304
/// Winner-take-all circuit for competitive inhibition
///
/// Keeps the `k` strongest responses and suppresses the rest by a
/// configurable factor.
#[derive(Debug, Clone)]
pub struct WinnerTakeAll {
    /// Number of winners to select
    pub k: usize,
    /// Inhibition strength (0.0 = no suppression, 1.0 = full suppression)
    pub inhibition: f64,
}
313
314impl WinnerTakeAll {
315    pub fn new(k: usize, inhibition: f64) -> Self {
316        Self { k, inhibition }
317    }
318
319    /// Apply WTA to activities
320    pub fn apply(&self, activities: &mut HashMap<NeuronId, f64>) {
321        // Find top-k
322        let mut sorted: Vec<_> = activities.iter().collect();
323        sorted.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
324
325        let winners: std::collections::HashSet<_> =
326            sorted.iter().take(self.k).map(|(id, _)| (*id).clone()).collect();
327
328        // Inhibit non-winners
329        for (id, activity) in activities.iter_mut() {
330            if !winners.contains(id) {
331                *activity *= 1.0 - self.inhibition;
332            }
333        }
334    }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    // Top-k selection from a dense vector: with target sparsity 0.6 over
    // 5 entries, ceil(0.4 * 5) = 2 entries may stay active.
    #[test]
    fn test_sparse_code_creation() {
        let dense = vec![0.1, 0.9, 0.2, 0.8, 0.5];
        let sparse = SparseCode::from_dense(&dense, 0.6);

        assert!(sparse.sparsity() >= 0.4); // At most 60% active
        assert!(sparse.active.contains_key(&1)); // 0.9 should be active
        assert!(sparse.active.contains_key(&3)); // 0.8 should be active
    }

    // Sparse dot product counts only indices active in BOTH codes.
    #[test]
    fn test_sparse_code_operations() {
        let a = SparseCode::from_dense(&[1.0, 0.0, 1.0, 0.0], 0.5);
        let b = SparseCode::from_dense(&[1.0, 1.0, 0.0, 0.0], 0.5);

        let dot = a.dot(&b);
        assert!((dot - 1.0).abs() < 0.01); // Only overlap at index 0
    }

    // Gaussian tuning curves: a stimulus inside the range must excite at
    // least one neuron above zero.
    #[test]
    fn test_neural_population_encoding() {
        let pop = NeuralPopulation::with_uniform_tuning(
            "test".to_string(),
            10,
            1,
            (0.0, 1.0),
        );

        let activities = pop.encode(&[0.5]);

        // Middle neuron should be most active
        let max_activity = activities.values().cloned().fold(0.0, f64::max);
        assert!(max_activity > 0.0);
    }

    // Encode-then-decode round trip: the population-vector readout should
    // recover the stimulus to within one tuning-center spacing.
    #[test]
    fn test_neural_population_decoding() {
        let pop = NeuralPopulation::with_uniform_tuning(
            "test".to_string(),
            10,
            1,
            (0.0, 1.0),
        );

        let stimulus = vec![0.5];
        let activities = pop.encode(&stimulus);
        let decoded = pop.decode(&activities);

        assert!((decoded[0] - 0.5).abs() < 0.1);
    }

    // k = 2 winners keep their activity; losers are scaled by
    // 1.0 - 0.9 = 0.1, so 0.3 -> 0.03 and 0.2 -> 0.02.
    #[test]
    fn test_winner_take_all() {
        let mut activities: HashMap<NeuronId, f64> = HashMap::new();
        activities.insert("n1".to_string(), 0.9);
        activities.insert("n2".to_string(), 0.8);
        activities.insert("n3".to_string(), 0.3);
        activities.insert("n4".to_string(), 0.2);

        let wta = WinnerTakeAll::new(2, 0.9);
        wta.apply(&mut activities);

        assert!(*activities.get("n1").unwrap() > 0.8);
        assert!(*activities.get("n2").unwrap() > 0.7);
        assert!(*activities.get("n3").unwrap() < 0.1);
        assert!(*activities.get("n4").unwrap() < 0.1);
    }
}