1use super::{
7 EventPriority, MembraneDynamicsConfig, NeuromorphicEvent, NeuromorphicMetrics, PlasticityModel,
8 STDPConfig, Spike, SpikeTrain,
9};
10
11use scirs2_neural::activations_minimal::Activation;
13use scirs2_neural::layers::Layer;
14use scirs2_stats::distributions;
15
16use crate::error::Result;
17use crate::optimizers::Optimizer;
18use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, DataMut, Dimension};
19use scirs2_core::numeric::Float;
20use scirs2_core::random::{thread_rng, Rng};
21use std::collections::{HashMap, VecDeque};
22use std::fmt::Debug;
23use std::time::Instant;
24
25#[derive(Debug, Clone)]
27pub struct SpikingConfig<T: Float + Debug + Send + Sync + 'static> {
28 pub time_step: T,
30
31 pub simulation_time: T,
33
34 pub encoding_method: SpikeEncodingMethod,
36
37 pub decoding_method: SpikeDecodingMethod,
39
40 pub spike_learning_rate: T,
42
43 pub temporal_window: T,
45
46 pub lateral_inhibition: bool,
48
49 pub homeostatic_config: HomeostaticConfig<T>,
51
52 pub noise_config: SpikeNoiseConfig<T>,
54}
55
/// Schemes for converting a scalar input value into a spike train.
///
/// `PartialEq`, `Eq` and `Hash` are derived so variants can be compared
/// directly (for example in tests) and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpikeEncodingMethod {
    /// Rate coding.
    RateCoding,

    /// Temporal (latency) coding.
    TemporalCoding,

    /// Population vector coding.
    PopulationVectorCoding,

    /// Sparse coding.
    SparseCoding,

    /// Phase coding.
    PhaseCoding,

    /// Burst coding.
    BurstCoding,

    /// Rank-order coding.
    RankOrderCoding,
}
80
/// Schemes for converting a spike train back into a scalar value.
///
/// `PartialEq`, `Eq` and `Hash` are derived so variants can be compared
/// directly (for example in tests) and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpikeDecodingMethod {
    /// Rate decoding.
    RateDecoding,

    /// Temporal (first-spike latency) decoding.
    TemporalDecoding,

    /// Population vector decoding.
    PopulationVectorDecoding,

    /// Weighted spike count.
    WeightedSpikeCount,

    /// Moving-average filter.
    MovingAverageFilter,

    /// Exponential-decay filter.
    ExponentialDecayFilter,
}
102
103#[derive(Debug, Clone)]
105pub struct HomeostaticConfig<T: Float + Debug + Send + Sync + 'static> {
106 pub enable_homeostatic_scaling: bool,
108
109 pub target_firing_rate: T,
111
112 pub scaling_time_constant: T,
114
115 pub scaling_factor: T,
117
118 pub enable_intrinsic_plasticity: bool,
120
121 pub threshold_adaptation_rate: T,
123}
124
125#[derive(Debug, Clone)]
127pub struct SpikeNoiseConfig<T: Float + Debug + Send + Sync + 'static> {
128 pub background_rate: T,
130
131 pub jitter_std: T,
133
134 pub poisson_noise: bool,
136
137 pub noise_amplitude: T,
139
140 pub correlation_noise: T,
142}
143
144impl<T: Float + Debug + Send + Sync + 'static> Default for SpikingConfig<T> {
145 fn default() -> Self {
146 Self {
147 time_step: T::from(0.1).unwrap_or_else(|| T::zero()),
148 simulation_time: T::from(1000.0).unwrap_or_else(|| T::zero()),
149 encoding_method: SpikeEncodingMethod::RateCoding,
150 decoding_method: SpikeDecodingMethod::RateDecoding,
151 spike_learning_rate: T::from(0.01).unwrap_or_else(|| T::zero()),
152 temporal_window: T::from(20.0).unwrap_or_else(|| T::zero()),
153 lateral_inhibition: false,
154 homeostatic_config: HomeostaticConfig::default(),
155 noise_config: SpikeNoiseConfig::default(),
156 }
157 }
158}
159
160impl<T: Float + Debug + Send + Sync + 'static> Default for HomeostaticConfig<T> {
161 fn default() -> Self {
162 Self {
163 enable_homeostatic_scaling: false,
164 target_firing_rate: T::from(10.0).unwrap_or_else(|| T::zero()),
165 scaling_time_constant: T::from(1000.0).unwrap_or_else(|| T::zero()),
166 scaling_factor: T::from(0.01).unwrap_or_else(|| T::zero()),
167 enable_intrinsic_plasticity: false,
168 threshold_adaptation_rate: T::from(0.001).unwrap_or_else(|| T::zero()),
169 }
170 }
171}
172
173impl<T: Float + Debug + Send + Sync + 'static> Default for SpikeNoiseConfig<T> {
174 fn default() -> Self {
175 Self {
176 background_rate: T::from(1.0).unwrap_or_else(|| T::zero()),
177 jitter_std: T::from(0.5).unwrap_or_else(|| T::zero()),
178 poisson_noise: false,
179 noise_amplitude: T::from(0.1).unwrap_or_else(|| T::zero()),
180 correlation_noise: T::zero(),
181 }
182 }
183}
184
185pub struct SpikingOptimizer<
187 T: Float + Debug + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
188> {
189 config: SpikingConfig<T>,
191
192 stdp_config: STDPConfig<T>,
194
195 membrane_config: MembraneDynamicsConfig<T>,
197
198 current_time: T,
200
201 spike_trains: HashMap<usize, SpikeTrain<T>>,
203
204 membrane_potentials: Array1<T>,
206
207 synaptic_weights: Array2<T>,
209
210 last_spike_times: Array1<T>,
212
213 refractory_until: Array1<T>,
215
216 homeostatic_scales: Array1<T>,
218
219 spike_buffer: VecDeque<Spike<T>>,
221
222 metrics: NeuromorphicMetrics<T>,
224
225 plasticity_model: PlasticityModel,
227}
228
229impl<
230 T: Float
231 + Debug
232 + Send
233 + Sync
234 + scirs2_core::ndarray::ScalarOperand
235 + 'static
236 + std::iter::Sum,
237 > SpikingOptimizer<T>
238{
239 pub fn new(
241 config: SpikingConfig<T>,
242 stdp_config: STDPConfig<T>,
243 membrane_config: MembraneDynamicsConfig<T>,
244 num_neurons: usize,
245 ) -> Self {
246 let resting_potential = membrane_config.resting_potential;
247 Self {
248 config,
249 stdp_config,
250 membrane_config,
251 current_time: T::zero(),
252 spike_trains: HashMap::new(),
253 membrane_potentials: Array1::from_elem(num_neurons, resting_potential),
254 synaptic_weights: Array2::ones((num_neurons, num_neurons))
255 * T::from(0.1).unwrap_or_else(|| T::zero()),
256 last_spike_times: Array1::from_elem(
257 num_neurons,
258 T::from(-1000.0).unwrap_or_else(|| T::zero()),
259 ),
260 refractory_until: Array1::zeros(num_neurons),
261 homeostatic_scales: Array1::ones(num_neurons),
262 spike_buffer: VecDeque::new(),
263 metrics: NeuromorphicMetrics::default(),
264 plasticity_model: PlasticityModel::STDP,
265 }
266 }
267
268 pub fn encode_input(&self, input: &Array1<T>) -> Result<Vec<SpikeTrain<T>>> {
270 let mut spike_trains = Vec::new();
271
272 for (neuron_id, &value) in input.iter().enumerate() {
273 let spike_train = match self.config.encoding_method {
274 SpikeEncodingMethod::RateCoding => self.rate_encode(neuron_id, value)?,
275 SpikeEncodingMethod::TemporalCoding => self.temporal_encode(neuron_id, value)?,
276 SpikeEncodingMethod::PopulationVectorCoding => {
277 self.population_vector_encode(neuron_id, value)?
278 }
279 SpikeEncodingMethod::SparseCoding => self.sparse_encode(neuron_id, value)?,
280 _ => {
281 self.rate_encode(neuron_id, value)?
283 }
284 };
285
286 spike_trains.push(spike_train);
287 }
288
289 Ok(spike_trains)
290 }
291
292 fn rate_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
294 let max_rate = T::from(100.0).unwrap_or_else(|| T::zero()); let firing_rate = value.abs() * max_rate;
296
297 let mut spike_times = Vec::new();
298 let dt = self.config.time_step;
299 let total_time = self.config.simulation_time;
300
301 let mut time = T::zero();
302 while time < total_time {
303 let spike_prob = firing_rate * dt / T::from(1000.0).unwrap_or_else(|| T::zero());
305
306 if thread_rng().random::<f64>() < spike_prob.to_f64().unwrap_or(0.0) {
307 spike_times.push(time);
308 }
309
310 time = time + dt;
311 }
312
313 Ok(SpikeTrain::new(neuron_id, spike_times))
314 }
315
316 fn temporal_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
318 let max_delay = T::from(20.0).unwrap_or_else(|| T::zero()); let spike_time = if value > T::zero() {
320 max_delay * (T::one() - value.min(T::one()))
321 } else {
322 max_delay };
324
325 let spike_times = if spike_time < max_delay {
326 vec![spike_time]
327 } else {
328 Vec::new()
329 };
330
331 Ok(SpikeTrain::new(neuron_id, spike_times))
332 }
333
334 fn population_vector_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
336 self.rate_encode(neuron_id, value)
338 }
339
340 fn sparse_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
342 let threshold = T::from(0.5).unwrap_or_else(|| T::zero());
343
344 if value.abs() > threshold {
345 self.rate_encode(neuron_id, value)
346 } else {
347 Ok(SpikeTrain::new(neuron_id, Vec::new()))
348 }
349 }
350
351 pub fn decode_output(&self, spike_trains: &[SpikeTrain<T>]) -> Result<Array1<T>> {
353 let mut output = Array1::zeros(spike_trains.len());
354
355 for (i, spike_train) in spike_trains.iter().enumerate() {
356 output[i] = match self.config.decoding_method {
357 SpikeDecodingMethod::RateDecoding => self.rate_decode(spike_train)?,
358 SpikeDecodingMethod::TemporalDecoding => self.temporal_decode(spike_train)?,
359 SpikeDecodingMethod::WeightedSpikeCount => {
360 self.weighted_spike_count_decode(spike_train)?
361 }
362 _ => {
363 self.rate_decode(spike_train)?
365 }
366 };
367 }
368
369 Ok(output)
370 }
371
372 fn rate_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
374 let window_duration = self.config.temporal_window;
375 let spike_count = T::from(spike_train.spike_count).unwrap_or_else(|| T::zero());
376 let rate = spike_count / (window_duration / T::from(1000.0).unwrap_or_else(|| T::zero()));
377 Ok(rate / T::from(100.0).unwrap_or_else(|| T::zero())) }
379
380 fn temporal_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
382 if spike_train.spike_times.is_empty() {
383 Ok(T::zero())
384 } else {
385 let first_spike = spike_train.spike_times[0];
386 let max_delay = T::from(20.0).unwrap_or_else(|| T::zero());
387 Ok(T::one() - (first_spike / max_delay).min(T::one()))
388 }
389 }
390
391 fn weighted_spike_count_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
393 if spike_train.spike_times.is_empty() {
394 return Ok(T::zero());
395 }
396
397 let mut weighted_sum = T::zero();
398 let current_time = self.current_time;
399
400 for &spike_time in &spike_train.spike_times {
401 let time_diff = current_time - spike_time;
402 let weight = (-time_diff / T::from(10.0).unwrap_or_else(|| T::zero())).exp(); weighted_sum = weighted_sum + weight;
404 }
405
406 Ok(weighted_sum)
407 }
408
409 pub fn simulate_step(&mut self, input_spikes: &[Spike<T>]) -> Result<Vec<Spike<T>>> {
411 let mut output_spikes = Vec::new();
412 let dt = self.config.time_step;
413
414 for spike in input_spikes {
416 self.process_input_spike(spike)?;
417 }
418
419 for neuron_id in 0..self.membrane_potentials.len() {
421 if self.current_time >= self.refractory_until[neuron_id] {
422 self.update_membrane_potential(neuron_id, dt)?;
423
424 if self.membrane_potentials[neuron_id] >= self.membrane_config.threshold_potential {
426 let spike = self.generate_spike(neuron_id)?;
427 output_spikes.push(spike);
428 }
429 }
430 }
431
432 self.update_plasticity(&output_spikes)?;
434
435 if self.config.homeostatic_config.enable_homeostatic_scaling {
437 self.update_homeostatic_scaling()?;
438 }
439
440 self.current_time = self.current_time + dt;
441
442 Ok(output_spikes)
443 }
444
445 fn process_input_spike(&mut self, spike: &Spike<T>) -> Result<()> {
447 let target_neuron = spike.postsynaptic_id.unwrap_or(spike.neuron_id);
448
449 if target_neuron < self.membrane_potentials.len() {
450 let synaptic_current = spike.weight * spike.amplitude;
452 self.membrane_potentials[target_neuron] =
453 self.membrane_potentials[target_neuron] + synaptic_current;
454 }
455
456 Ok(())
457 }
458
459 fn update_membrane_potential(&mut self, neuron_id: usize, dt: T) -> Result<()> {
461 let v = self.membrane_potentials[neuron_id];
462 let v_rest = self.membrane_config.resting_potential;
463 let tau = self.membrane_config.tau_membrane;
464
465 let dv_dt = (v_rest - v) / tau;
467 let new_v = v + dv_dt * dt;
468
469 self.membrane_potentials[neuron_id] = new_v;
470
471 Ok(())
472 }
473
474 fn generate_spike(&mut self, neuron_id: usize) -> Result<Spike<T>> {
476 self.membrane_potentials[neuron_id] = self.membrane_config.reset_potential;
478
479 self.refractory_until[neuron_id] =
481 self.current_time + self.membrane_config.refractory_period;
482
483 self.last_spike_times[neuron_id] = self.current_time;
485
486 let spike = Spike {
488 neuron_id,
489 time: self.current_time,
490 amplitude: T::from(1.0).unwrap_or_else(|| T::zero()),
491 width: Some(T::from(1.0).unwrap_or_else(|| T::zero())),
492 weight: T::one(),
493 presynaptic_id: None,
494 postsynaptic_id: None,
495 };
496
497 if let Some(spike_train) = self.spike_trains.get_mut(&neuron_id) {
499 spike_train.spike_times.push(self.current_time);
500 spike_train.spike_count += 1;
501 } else {
502 let spike_train = SpikeTrain::new(neuron_id, vec![self.current_time]);
503 self.spike_trains.insert(neuron_id, spike_train);
504 }
505
506 self.metrics.total_spikes += 1;
508
509 Ok(spike)
510 }
511
512 fn update_plasticity(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
514 match self.plasticity_model {
515 PlasticityModel::STDP => {
516 self.update_stdp(output_spikes)?;
517 }
518 PlasticityModel::Hebbian => {
519 self.update_hebbian(output_spikes)?;
520 }
521 _ => {
522 self.update_stdp(output_spikes)?;
524 }
525 }
526
527 Ok(())
528 }
529
530 fn update_stdp(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
532 for spike in output_spikes {
533 let post_id = spike.neuron_id;
534 let post_time = spike.time;
535
536 for pre_id in 0..self.last_spike_times.len() {
538 if pre_id != post_id {
539 let pre_time = self.last_spike_times[pre_id];
540
541 if pre_time > T::from(-1000.0).unwrap_or_else(|| T::zero()) {
542 let dt = post_time - pre_time;
544 let weight_change = self.compute_stdp_update(dt);
545
546 self.synaptic_weights[[pre_id, post_id]] =
548 (self.synaptic_weights[[pre_id, post_id]] + weight_change)
549 .max(self.stdp_config.weight_min)
550 .min(self.stdp_config.weight_max);
551 }
552 }
553 }
554 }
555
556 Ok(())
557 }
558
559 fn compute_stdp_update(&self, dt: T) -> T {
561 if dt > T::zero() {
562 let exp_arg = -dt / self.stdp_config.tau_pot;
564 self.stdp_config.learning_rate_pot * exp_arg.exp()
565 } else {
566 let exp_arg = dt / self.stdp_config.tau_dep;
568 -self.stdp_config.learning_rate_dep * exp_arg.exp()
569 }
570 }
571
572 fn update_hebbian(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
574 for spike in output_spikes {
576 let post_id = spike.neuron_id;
577
578 for pre_id in 0..self.membrane_potentials.len() {
579 if pre_id != post_id {
580 let pre_activity =
581 self.membrane_potentials[pre_id] / self.membrane_config.threshold_potential;
582
583 let weight_change = self.stdp_config.learning_rate_pot * pre_activity;
584
585 self.synaptic_weights[[pre_id, post_id]] =
586 (self.synaptic_weights[[pre_id, post_id]] + weight_change)
587 .max(self.stdp_config.weight_min)
588 .min(self.stdp_config.weight_max);
589 }
590 }
591 }
592
593 Ok(())
594 }
595
596 fn update_homeostatic_scaling(&mut self) -> Result<()> {
598 let target_rate = self.config.homeostatic_config.target_firing_rate;
599 let time_constant = self.config.homeostatic_config.scaling_time_constant;
600 let dt = self.config.time_step;
601
602 for neuron_id in 0..self.homeostatic_scales.len() {
603 if let Some(spike_train) = self.spike_trains.get(&neuron_id) {
604 let current_rate = spike_train.firing_rate;
605 let rate_error = target_rate - current_rate;
606
607 let scale_change = rate_error * dt / time_constant;
609 self.homeostatic_scales[neuron_id] =
610 self.homeostatic_scales[neuron_id] + scale_change;
611
612 for pre_id in 0..self.synaptic_weights.nrows() {
614 self.synaptic_weights[[pre_id, neuron_id]] = self.synaptic_weights
615 [[pre_id, neuron_id]]
616 * self.homeostatic_scales[neuron_id];
617 }
618 }
619 }
620
621 Ok(())
622 }
623
624 pub fn get_metrics(&self) -> &NeuromorphicMetrics<T> {
626 &self.metrics
627 }
628
629 pub fn reset(&mut self) {
631 self.current_time = T::zero();
632 self.membrane_potentials
633 .fill(self.membrane_config.resting_potential);
634 self.last_spike_times
635 .fill(T::from(-1000.0).unwrap_or_else(|| T::zero()));
636 self.refractory_until.fill(T::zero());
637 self.spike_trains.clear();
638 self.spike_buffer.clear();
639 self.metrics = NeuromorphicMetrics::default();
640 }
641}
642
643pub struct SpikeTrainOptimizer<
645 T: Float + Debug + scirs2_core::ndarray::ScalarOperand + std::fmt::Debug + Send + Sync,
646> {
647 config: SpikingConfig<T>,
649
650 pattern_templates: Vec<SpikePattern<T>>,
652
653 matching_threshold: T,
655
656 pattern_learning_rate: T,
658
659 temporal_kernel: TemporalKernel<T>,
661}
662
663#[derive(Debug, Clone)]
665pub struct SpikePattern<T: Float + Debug + Send + Sync + 'static> {
666 pub pattern_id: usize,
668
669 pub relative_spike_times: Vec<T>,
671
672 pub duration: T,
674
675 pub weight: T,
677
678 pub observation_count: usize,
680}
681
682#[derive(Debug, Clone)]
684pub struct TemporalKernel<T: Float + Debug + Send + Sync + 'static> {
685 pub kernel_type: TemporalKernelType,
687
688 pub width: T,
690
691 pub parameters: Vec<T>,
693}
694
/// Shapes available for a [`TemporalKernel`].
///
/// `PartialEq`, `Eq` and `Hash` are derived so variants can be compared
/// directly (for example in tests) and used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TemporalKernelType {
    /// Gaussian kernel.
    Gaussian,

    /// Exponential kernel.
    Exponential,

    /// Alpha-function kernel.
    Alpha,

    /// Rectangular kernel.
    Rectangular,
}
710
711impl<T: Float + Debug + Send + Sync + scirs2_core::ndarray::ScalarOperand + std::fmt::Debug>
712 SpikeTrainOptimizer<T>
713{
714 pub fn new(config: SpikingConfig<T>) -> Self {
716 Self {
717 config,
718 pattern_templates: Vec::new(),
719 matching_threshold: T::from(0.8).unwrap_or_else(|| T::zero()),
720 pattern_learning_rate: T::from(0.1).unwrap_or_else(|| T::zero()),
721 temporal_kernel: TemporalKernel {
722 kernel_type: TemporalKernelType::Gaussian,
723 width: T::from(5.0).unwrap_or_else(|| T::zero()),
724 parameters: vec![T::one()],
725 },
726 }
727 }
728
729 pub fn learn_patterns(&mut self, spike_trains: &[SpikeTrain<T>]) -> Result<()> {
731 for spike_train in spike_trains {
732 self.extract_and_learn_patterns(spike_train)?;
733 }
734
735 Ok(())
736 }
737
738 fn extract_and_learn_patterns(&mut self, spike_train: &SpikeTrain<T>) -> Result<()> {
740 let window_size = T::from(50.0).unwrap_or_else(|| T::zero()); let step_size = T::from(10.0).unwrap_or_else(|| T::zero()); let mut window_start = T::zero();
744
745 while window_start < spike_train.duration {
746 let window_end = window_start + window_size;
747
748 let window_spikes: Vec<T> = spike_train
750 .spike_times
751 .iter()
752 .filter(|&&t| t >= window_start && t < window_end)
753 .map(|&t| t - window_start) .collect();
755
756 if !window_spikes.is_empty() {
757 let pattern = SpikePattern {
758 pattern_id: self.pattern_templates.len(),
759 relative_spike_times: window_spikes,
760 duration: window_size,
761 weight: T::one(),
762 observation_count: 1,
763 };
764
765 if let Some(similar_pattern_id) = self.find_similar_pattern(&pattern) {
767 self.update_pattern(similar_pattern_id, &pattern)?;
768 } else {
769 self.pattern_templates.push(pattern);
770 }
771 }
772
773 window_start = window_start + step_size;
774 }
775
776 Ok(())
777 }
778
779 fn find_similar_pattern(&self, new_pattern: &SpikePattern<T>) -> Option<usize> {
781 for (i, existing_pattern) in self.pattern_templates.iter().enumerate() {
782 let similarity = self.compute_pattern_similarity(new_pattern, existing_pattern);
783 if similarity > self.matching_threshold {
784 return Some(i);
785 }
786 }
787
788 None
789 }
790
791 fn compute_pattern_similarity(
793 &self,
794 pattern1: &SpikePattern<T>,
795 pattern2: &SpikePattern<T>,
796 ) -> T {
797 let max_spikes = pattern1
799 .relative_spike_times
800 .len()
801 .max(pattern2.relative_spike_times.len());
802 if max_spikes == 0 {
803 return T::one();
804 }
805
806 let count_diff = (pattern1.relative_spike_times.len() as i32
808 - pattern2.relative_spike_times.len() as i32)
809 .abs() as f64;
810 let count_similarity =
811 T::one() - T::from(count_diff / max_spikes as f64).unwrap_or_else(|| T::zero());
812
813 if !pattern1.relative_spike_times.is_empty() && !pattern2.relative_spike_times.is_empty() {
815 let temporal_similarity = self.compute_temporal_similarity(
816 &pattern1.relative_spike_times,
817 &pattern2.relative_spike_times,
818 );
819 (count_similarity + temporal_similarity) / T::from(2.0).unwrap_or_else(|| T::zero())
820 } else {
821 count_similarity
822 }
823 }
824
825 fn compute_temporal_similarity(&self, spikes1: &[T], spikes2: &[T]) -> T {
827 let mut max_correlation = T::zero();
829 let max_shift = T::from(10.0).unwrap_or_else(|| T::zero()); let shift_step = T::from(1.0).unwrap_or_else(|| T::zero());
831
832 let mut shift = -max_shift;
833 while shift <= max_shift {
834 let correlation = self.compute_spike_correlation(spikes1, spikes2, shift);
835 max_correlation = max_correlation.max(correlation);
836 shift = shift + shift_step;
837 }
838
839 max_correlation
840 }
841
842 fn compute_spike_correlation(&self, spikes1: &[T], spikes2: &[T], shift: T) -> T {
844 let mut correlation = T::zero();
845 let kernel_width = self.temporal_kernel.width;
846
847 for &t1 in spikes1 {
848 for &t2 in spikes2 {
849 let dt = (t1 - (t2 + shift)).abs();
850 let kernel_value = (-dt * dt
851 / (T::from(2.0).unwrap_or_else(|| T::zero()) * kernel_width * kernel_width))
852 .exp();
853 correlation = correlation + kernel_value;
854 }
855 }
856
857 if !spikes1.is_empty() && !spikes2.is_empty() {
859 correlation / T::from(spikes1.len() * spikes2.len()).unwrap()
860 } else {
861 T::zero()
862 }
863 }
864
865 fn update_pattern(&mut self, pattern_id: usize, new_pattern: &SpikePattern<T>) -> Result<()> {
867 if let Some(existing_pattern) = self.pattern_templates.get_mut(pattern_id) {
868 let alpha = self.pattern_learning_rate;
870
871 if existing_pattern.relative_spike_times.len() == new_pattern.relative_spike_times.len()
873 {
874 for (existing_time, &new_time) in existing_pattern
875 .relative_spike_times
876 .iter_mut()
877 .zip(new_pattern.relative_spike_times.iter())
878 {
879 *existing_time = *existing_time * (T::one() - alpha) + new_time * alpha;
880 }
881 }
882
883 existing_pattern.observation_count += 1;
884 existing_pattern.weight =
885 existing_pattern.weight * (T::one() - alpha) + new_pattern.weight * alpha;
886 }
887
888 Ok(())
889 }
890
891 pub fn recognize_patterns(&self, spike_train: &SpikeTrain<T>) -> Result<Vec<(usize, T, T)>> {
893 let mut recognized_patterns = Vec::new();
894 let window_size = T::from(50.0).unwrap_or_else(|| T::zero());
895 let step_size = T::from(5.0).unwrap_or_else(|| T::zero());
896
897 let mut window_start = T::zero();
898
899 while window_start < spike_train.duration {
900 let window_end = window_start + window_size;
901
902 let window_spikes: Vec<T> = spike_train
903 .spike_times
904 .iter()
905 .filter(|&&t| t >= window_start && t < window_end)
906 .map(|&t| t - window_start)
907 .collect();
908
909 if !window_spikes.is_empty() {
910 let test_pattern = SpikePattern {
911 pattern_id: 0,
912 relative_spike_times: window_spikes,
913 duration: window_size,
914 weight: T::one(),
915 observation_count: 1,
916 };
917
918 let mut best_match = (0, T::zero());
920 for (i, template) in self.pattern_templates.iter().enumerate() {
921 let similarity = self.compute_pattern_similarity(&test_pattern, template);
922 if similarity > best_match.1 {
923 best_match = (i, similarity);
924 }
925 }
926
927 if best_match.1 > self.matching_threshold {
928 recognized_patterns.push((best_match.0, window_start, best_match.1));
929 }
930 }
931
932 window_start = window_start + step_size;
933 }
934
935 Ok(recognized_patterns)
936 }
937
938 pub fn get_patterns(&self) -> &[SpikePattern<T>] {
940 &self.pattern_templates
941 }
942}