1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8use crate::neuron::NeuronId;
9use crate::spike_train::SpikeTrain;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct PopulationActivity {
14 pub population_id: String,
16 pub activities: HashMap<NeuronId, f64>,
18 pub mean_activity: f64,
20 pub sparsity: f64,
22 pub population_vector: Vec<f64>,
24}
25
26impl PopulationActivity {
27 pub fn from_spike_trains(
29 population_id: String,
30 trains: &[SpikeTrain],
31 window: std::time::Duration,
32 max_rate: f64,
33 ) -> Self {
34 let mut activities = HashMap::new();
35 let mut total_activity = 0.0;
36 let mut active_count = 0;
37
38 for train in trains {
39 let rate = train.firing_rate(window);
40 let activity = (rate / max_rate).min(1.0);
41 activities.insert(train.neuron_id.clone(), activity);
42 total_activity += activity;
43
44 if activity > 0.1 {
45 active_count += 1;
46 }
47 }
48
49 let n = trains.len() as f64;
50 let mean_activity = if n > 0.0 { total_activity / n } else { 0.0 };
51 let sparsity = if n > 0.0 {
52 1.0 - (active_count as f64 / n)
53 } else {
54 1.0
55 };
56
57 Self {
58 population_id,
59 activities,
60 mean_activity,
61 sparsity,
62 population_vector: Vec::new(),
63 }
64 }
65
66 pub fn get_activity(&self, neuron_id: &NeuronId) -> f64 {
68 *self.activities.get(neuron_id).unwrap_or(&0.0)
69 }
70
71 pub fn top_active(&self, k: usize) -> Vec<(NeuronId, f64)> {
73 let mut sorted: Vec<_> = self.activities.iter().collect();
74 sorted.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
75 sorted
76 .into_iter()
77 .take(k)
78 .map(|(id, act)| (id.clone(), *act))
79 .collect()
80 }
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct SparseCode {
86 pub dimension: usize,
88 pub active: HashMap<usize, f64>,
90 pub target_sparsity: f64,
92}
93
94impl SparseCode {
95 pub fn new(dimension: usize, target_sparsity: f64) -> Self {
97 Self {
98 dimension,
99 active: HashMap::new(),
100 target_sparsity,
101 }
102 }
103
104 pub fn from_dense(values: &[f64], target_sparsity: f64) -> Self {
106 let dimension = values.len();
107 let k = ((1.0 - target_sparsity) * dimension as f64).ceil() as usize;
108
109 let mut indexed: Vec<_> = values.iter().enumerate().collect();
111 indexed.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
112
113 let active: HashMap<usize, f64> = indexed
114 .into_iter()
115 .take(k)
116 .filter(|(_, &v)| v > 0.0)
117 .map(|(i, &v)| (i, v))
118 .collect();
119
120 Self {
121 dimension,
122 active,
123 target_sparsity,
124 }
125 }
126
127 pub fn to_dense(&self) -> Vec<f64> {
129 let mut dense = vec![0.0; self.dimension];
130 for (&idx, &val) in &self.active {
131 if idx < self.dimension {
132 dense[idx] = val;
133 }
134 }
135 dense
136 }
137
138 pub fn sparsity(&self) -> f64 {
140 1.0 - (self.active.len() as f64 / self.dimension as f64)
141 }
142
143 pub fn l1_norm(&self) -> f64 {
145 self.active.values().map(|v| v.abs()).sum()
146 }
147
148 pub fn l2_norm(&self) -> f64 {
150 self.active.values().map(|v| v * v).sum::<f64>().sqrt()
151 }
152
153 pub fn dot(&self, other: &SparseCode) -> f64 {
155 let mut sum = 0.0;
156 for (&idx, &val) in &self.active {
157 if let Some(&other_val) = other.active.get(&idx) {
158 sum += val * other_val;
159 }
160 }
161 sum
162 }
163
164 pub fn cosine_similarity(&self, other: &SparseCode) -> f64 {
166 let dot = self.dot(other);
167 let norm_self = self.l2_norm();
168 let norm_other = other.l2_norm();
169
170 if norm_self == 0.0 || norm_other == 0.0 {
171 return 0.0;
172 }
173
174 dot / (norm_self * norm_other)
175 }
176
177 pub fn add(&self, other: &SparseCode) -> SparseCode {
179 let mut result = self.clone();
180
181 for (&idx, &val) in &other.active {
182 *result.active.entry(idx).or_insert(0.0) += val;
183 }
184
185 result
186 }
187
188 pub fn scale(&self, factor: f64) -> SparseCode {
190 let mut result = self.clone();
191 for val in result.active.values_mut() {
192 *val *= factor;
193 }
194 result
195 }
196}
197
198#[derive(Debug, Clone)]
200pub struct NeuralPopulation {
201 pub id: String,
203 pub neuron_ids: Vec<NeuronId>,
205 pub tuning_centers: HashMap<NeuronId, Vec<f64>>,
207 pub tuning_width: f64,
209}
210
211impl NeuralPopulation {
212 pub fn new(id: String, neuron_ids: Vec<NeuronId>) -> Self {
214 Self {
215 id,
216 neuron_ids,
217 tuning_centers: HashMap::new(),
218 tuning_width: 1.0,
219 }
220 }
221
222 pub fn with_uniform_tuning(
224 id: String,
225 size: usize,
226 stimulus_dim: usize,
227 stimulus_range: (f64, f64),
228 ) -> Self {
229 let mut neuron_ids = Vec::new();
230 let mut tuning_centers = HashMap::new();
231
232 for i in 0..size {
233 let neuron_id = format!("{}_{}", id, i);
234
235 let t = i as f64 / (size - 1).max(1) as f64;
237 let center: Vec<f64> = (0..stimulus_dim)
238 .map(|_| stimulus_range.0 + t * (stimulus_range.1 - stimulus_range.0))
239 .collect();
240
241 tuning_centers.insert(neuron_id.clone(), center);
242 neuron_ids.push(neuron_id);
243 }
244
245 Self {
246 id,
247 neuron_ids,
248 tuning_centers,
249 tuning_width: (stimulus_range.1 - stimulus_range.0) / size as f64,
250 }
251 }
252
253 pub fn encode(&self, stimulus: &[f64]) -> HashMap<NeuronId, f64> {
255 let mut activities = HashMap::new();
256
257 for (neuron_id, center) in &self.tuning_centers {
258 let dist_sq: f64 = stimulus
260 .iter()
261 .zip(center.iter())
262 .map(|(s, c)| (s - c).powi(2))
263 .sum();
264
265 let activity = (-dist_sq / (2.0 * self.tuning_width * self.tuning_width)).exp();
267 activities.insert(neuron_id.clone(), activity);
268 }
269
270 activities
271 }
272
273 pub fn decode(&self, activities: &HashMap<NeuronId, f64>) -> Vec<f64> {
275 if self.tuning_centers.is_empty() {
276 return Vec::new();
277 }
278
279 let dim = self.tuning_centers.values().next().map(|v| v.len()).unwrap_or(0);
280 let mut weighted_sum = vec![0.0; dim];
281 let mut total_weight = 0.0;
282
283 for (neuron_id, center) in &self.tuning_centers {
284 let activity = *activities.get(neuron_id).unwrap_or(&0.0);
285 total_weight += activity;
286
287 for (i, &c) in center.iter().enumerate() {
288 weighted_sum[i] += activity * c;
289 }
290 }
291
292 if total_weight > 0.0 {
293 weighted_sum.iter().map(|&s| s / total_weight).collect()
294 } else {
295 vec![0.0; dim]
296 }
297 }
298
299 pub fn size(&self) -> usize {
301 self.neuron_ids.len()
302 }
303}
304
/// Soft winner-take-all rule: keep the `k` strongest activities, damp the rest.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct WinnerTakeAll {
    /// Number of winners left unchanged.
    pub k: usize,
    /// Non-winners are multiplied by `1.0 - inhibition`.
    pub inhibition: f64,
}
313
314impl WinnerTakeAll {
315 pub fn new(k: usize, inhibition: f64) -> Self {
316 Self { k, inhibition }
317 }
318
319 pub fn apply(&self, activities: &mut HashMap<NeuronId, f64>) {
321 let mut sorted: Vec<_> = activities.iter().collect();
323 sorted.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
324
325 let winners: std::collections::HashSet<_> =
326 sorted.iter().take(self.k).map(|(id, _)| (*id).clone()).collect();
327
328 for (id, activity) in activities.iter_mut() {
330 if !winners.contains(id) {
331 *activity *= 1.0 - self.inhibition;
332 }
333 }
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten 1-D neurons tiling `[0.0, 1.0]`, shared by the encode/decode tests.
    fn unit_interval_population() -> NeuralPopulation {
        NeuralPopulation::with_uniform_tuning("test".to_string(), 10, 1, (0.0, 1.0))
    }

    #[test]
    fn test_sparse_code_creation() {
        let code = SparseCode::from_dense(&[0.1, 0.9, 0.2, 0.8, 0.3], 0.6);

        // At least 40% of the entries are zero, and the two largest values survive.
        assert!(code.sparsity() >= 0.4);
        assert!(code.active.contains_key(&1));
        assert!(code.active.contains_key(&3));
    }

    #[test]
    fn test_sparse_code_operations() {
        let left = SparseCode::from_dense(&[1.0, 0.0, 1.0, 0.0], 0.5);
        let right = SparseCode::from_dense(&[1.0, 1.0, 0.0, 0.0], 0.5);

        // The two codes overlap only at index 0, where both hold 1.0.
        let overlap = left.dot(&right);
        assert!((overlap - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_neural_population_encoding() {
        let responses = unit_interval_population().encode(&[0.5]);

        // At least one neuron responds to a stimulus inside the tuned range.
        assert!(responses.values().any(|&a| a > 0.0));
    }

    #[test]
    fn test_neural_population_decoding() {
        let population = unit_interval_population();
        let estimate = population.decode(&population.encode(&[0.5]));

        assert!((estimate[0] - 0.5).abs() < 0.1);
    }

    #[test]
    fn test_winner_take_all() {
        let mut activities: HashMap<NeuronId, f64> =
            [("n1", 0.9), ("n2", 0.8), ("n3", 0.3), ("n4", 0.2)]
                .iter()
                .map(|&(id, level)| (id.to_string(), level))
                .collect();

        WinnerTakeAll::new(2, 0.9).apply(&mut activities);

        // The two strongest neurons are untouched; the others are cut to 10%.
        assert!(activities["n1"] > 0.8);
        assert!(activities["n2"] > 0.7);
        assert!(activities["n3"] < 0.1);
        assert!(activities["n4"] < 0.1);
    }
}