optirs_core/neuromorphic/
spike_based.rs

// Spike-Based Optimization Algorithms
//
// This module implements optimization algorithms that operate on spike trains
// and temporal spike patterns, designed for neuromorphic computing platforms.
5
6use super::{
7    EventPriority, MembraneDynamicsConfig, NeuromorphicEvent, NeuromorphicMetrics, PlasticityModel,
8    STDPConfig, Spike, SpikeTrain,
9};
10
11// SciRS2 Integration - CRITICAL for neuromorphic computing
12use scirs2_neural::activations_minimal::Activation;
13use scirs2_neural::layers::Layer;
14use scirs2_stats::distributions;
15
16use crate::error::Result;
17use crate::optimizers::Optimizer;
18use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, DataMut, Dimension};
19use scirs2_core::numeric::Float;
20use scirs2_core::random::{thread_rng, Rng};
21use std::collections::{HashMap, VecDeque};
22use std::fmt::Debug;
23use std::time::Instant;
24
25/// Spike-based optimization configuration
26#[derive(Debug, Clone)]
27pub struct SpikingConfig<T: Float + Debug + Send + Sync + 'static> {
28    /// Simulation time step (ms)
29    pub time_step: T,
30
31    /// Total simulation time (ms)
32    pub simulation_time: T,
33
34    /// Encoding method for input data
35    pub encoding_method: SpikeEncodingMethod,
36
37    /// Decoding method for output spikes
38    pub decoding_method: SpikeDecodingMethod,
39
40    /// Spike train learning rate
41    pub spike_learning_rate: T,
42
43    /// Temporal window for spike correlation (ms)
44    pub temporal_window: T,
45
46    /// Enable lateral inhibition
47    pub lateral_inhibition: bool,
48
49    /// Homeostatic scaling parameters
50    pub homeostatic_config: HomeostaticConfig<T>,
51
52    /// Noise parameters for spike generation
53    pub noise_config: SpikeNoiseConfig<T>,
54}
55
/// Strategies for converting continuous values into spike trains.
///
/// In `SpikingOptimizer::encode_input`, `PhaseCoding`, `BurstCoding` and
/// `RankOrderCoding` fall back to rate coding, and `PopulationVectorCoding`
/// delegates to the rate encoder as well.
#[derive(Debug, Clone, Copy)]
pub enum SpikeEncodingMethod {
    /// Firing rate is proportional to the value.
    RateCoding,

    /// Spike time depends on the value.
    TemporalCoding,

    /// Population vector coding.
    PopulationVectorCoding,

    /// Sparse coding.
    SparseCoding,

    /// Phase coding.
    PhaseCoding,

    /// Burst coding.
    BurstCoding,

    /// Rank order coding.
    RankOrderCoding,
}
80
/// Strategies for converting spike trains back into continuous values.
///
/// In `SpikingOptimizer::decode_output`, `PopulationVectorDecoding`,
/// `MovingAverageFilter` and `ExponentialDecayFilter` fall back to rate
/// decoding.
#[derive(Debug, Clone, Copy)]
pub enum SpikeDecodingMethod {
    /// Spike count within a time window.
    RateDecoding,

    /// Time of the first spike.
    TemporalDecoding,

    /// Population vector decoding.
    PopulationVectorDecoding,

    /// Spike count with per-spike weights.
    WeightedSpikeCount,

    /// Moving average filter.
    MovingAverageFilter,

    /// Exponential decay filter.
    ExponentialDecayFilter,
}
102
103/// Homeostatic plasticity configuration
104#[derive(Debug, Clone)]
105pub struct HomeostaticConfig<T: Float + Debug + Send + Sync + 'static> {
106    /// Enable homeostatic scaling
107    pub enable_homeostatic_scaling: bool,
108
109    /// Target firing rate (Hz)
110    pub target_firing_rate: T,
111
112    /// Scaling time constant (ms)
113    pub scaling_time_constant: T,
114
115    /// Scaling factor
116    pub scaling_factor: T,
117
118    /// Enable intrinsic plasticity
119    pub enable_intrinsic_plasticity: bool,
120
121    /// Threshold adaptation rate
122    pub threshold_adaptation_rate: T,
123}
124
125/// Spike noise configuration
126#[derive(Debug, Clone)]
127pub struct SpikeNoiseConfig<T: Float + Debug + Send + Sync + 'static> {
128    /// Background firing rate (Hz)
129    pub background_rate: T,
130
131    /// Jitter standard deviation (ms)
132    pub jitter_std: T,
133
134    /// Enable Poisson noise
135    pub poisson_noise: bool,
136
137    /// Noise amplitude
138    pub noise_amplitude: T,
139
140    /// Correlation noise
141    pub correlation_noise: T,
142}
143
144impl<T: Float + Debug + Send + Sync + 'static> Default for SpikingConfig<T> {
145    fn default() -> Self {
146        Self {
147            time_step: T::from(0.1).unwrap_or_else(|| T::zero()),
148            simulation_time: T::from(1000.0).unwrap_or_else(|| T::zero()),
149            encoding_method: SpikeEncodingMethod::RateCoding,
150            decoding_method: SpikeDecodingMethod::RateDecoding,
151            spike_learning_rate: T::from(0.01).unwrap_or_else(|| T::zero()),
152            temporal_window: T::from(20.0).unwrap_or_else(|| T::zero()),
153            lateral_inhibition: false,
154            homeostatic_config: HomeostaticConfig::default(),
155            noise_config: SpikeNoiseConfig::default(),
156        }
157    }
158}
159
160impl<T: Float + Debug + Send + Sync + 'static> Default for HomeostaticConfig<T> {
161    fn default() -> Self {
162        Self {
163            enable_homeostatic_scaling: false,
164            target_firing_rate: T::from(10.0).unwrap_or_else(|| T::zero()),
165            scaling_time_constant: T::from(1000.0).unwrap_or_else(|| T::zero()),
166            scaling_factor: T::from(0.01).unwrap_or_else(|| T::zero()),
167            enable_intrinsic_plasticity: false,
168            threshold_adaptation_rate: T::from(0.001).unwrap_or_else(|| T::zero()),
169        }
170    }
171}
172
173impl<T: Float + Debug + Send + Sync + 'static> Default for SpikeNoiseConfig<T> {
174    fn default() -> Self {
175        Self {
176            background_rate: T::from(1.0).unwrap_or_else(|| T::zero()),
177            jitter_std: T::from(0.5).unwrap_or_else(|| T::zero()),
178            poisson_noise: false,
179            noise_amplitude: T::from(0.1).unwrap_or_else(|| T::zero()),
180            correlation_noise: T::zero(),
181        }
182    }
183}
184
185/// Spike-based optimizer
186pub struct SpikingOptimizer<
187    T: Float + Debug + Send + Sync + scirs2_core::ndarray::ScalarOperand + 'static,
188> {
189    /// Configuration
190    config: SpikingConfig<T>,
191
192    /// STDP configuration
193    stdp_config: STDPConfig<T>,
194
195    /// Membrane dynamics configuration
196    membrane_config: MembraneDynamicsConfig<T>,
197
198    /// Current simulation time
199    current_time: T,
200
201    /// Spike trains for each neuron
202    spike_trains: HashMap<usize, SpikeTrain<T>>,
203
204    /// Current membrane potentials
205    membrane_potentials: Array1<T>,
206
207    /// Synaptic weights
208    synaptic_weights: Array2<T>,
209
210    /// Last spike times for each neuron
211    last_spike_times: Array1<T>,
212
213    /// Refractory state
214    refractory_until: Array1<T>,
215
216    /// Homeostatic scaling factors
217    homeostatic_scales: Array1<T>,
218
219    /// Spike buffer for temporal processing
220    spike_buffer: VecDeque<Spike<T>>,
221
222    /// Performance metrics
223    metrics: NeuromorphicMetrics<T>,
224
225    /// Plasticity model
226    plasticity_model: PlasticityModel,
227}
228
229impl<
230        T: Float
231            + Debug
232            + Send
233            + Sync
234            + scirs2_core::ndarray::ScalarOperand
235            + 'static
236            + std::iter::Sum,
237    > SpikingOptimizer<T>
238{
239    /// Create a new spiking optimizer
240    pub fn new(
241        config: SpikingConfig<T>,
242        stdp_config: STDPConfig<T>,
243        membrane_config: MembraneDynamicsConfig<T>,
244        num_neurons: usize,
245    ) -> Self {
246        let resting_potential = membrane_config.resting_potential;
247        Self {
248            config,
249            stdp_config,
250            membrane_config,
251            current_time: T::zero(),
252            spike_trains: HashMap::new(),
253            membrane_potentials: Array1::from_elem(num_neurons, resting_potential),
254            synaptic_weights: Array2::ones((num_neurons, num_neurons))
255                * T::from(0.1).unwrap_or_else(|| T::zero()),
256            last_spike_times: Array1::from_elem(
257                num_neurons,
258                T::from(-1000.0).unwrap_or_else(|| T::zero()),
259            ),
260            refractory_until: Array1::zeros(num_neurons),
261            homeostatic_scales: Array1::ones(num_neurons),
262            spike_buffer: VecDeque::new(),
263            metrics: NeuromorphicMetrics::default(),
264            plasticity_model: PlasticityModel::STDP,
265        }
266    }
267
268    /// Encode continuous input as spike trains
269    pub fn encode_input(&self, input: &Array1<T>) -> Result<Vec<SpikeTrain<T>>> {
270        let mut spike_trains = Vec::new();
271
272        for (neuron_id, &value) in input.iter().enumerate() {
273            let spike_train = match self.config.encoding_method {
274                SpikeEncodingMethod::RateCoding => self.rate_encode(neuron_id, value)?,
275                SpikeEncodingMethod::TemporalCoding => self.temporal_encode(neuron_id, value)?,
276                SpikeEncodingMethod::PopulationVectorCoding => {
277                    self.population_vector_encode(neuron_id, value)?
278                }
279                SpikeEncodingMethod::SparseCoding => self.sparse_encode(neuron_id, value)?,
280                _ => {
281                    // Fallback to rate coding
282                    self.rate_encode(neuron_id, value)?
283                }
284            };
285
286            spike_trains.push(spike_train);
287        }
288
289        Ok(spike_trains)
290    }
291
292    /// Rate encoding: firing rate proportional to input value
293    fn rate_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
294        let max_rate = T::from(100.0).unwrap_or_else(|| T::zero()); // 100 Hz max
295        let firing_rate = value.abs() * max_rate;
296
297        let mut spike_times = Vec::new();
298        let dt = self.config.time_step;
299        let total_time = self.config.simulation_time;
300
301        let mut time = T::zero();
302        while time < total_time {
303            // Poisson process: probability of spike in dt
304            let spike_prob = firing_rate * dt / T::from(1000.0).unwrap_or_else(|| T::zero());
305
306            if thread_rng().random::<f64>() < spike_prob.to_f64().unwrap_or(0.0) {
307                spike_times.push(time);
308            }
309
310            time = time + dt;
311        }
312
313        Ok(SpikeTrain::new(neuron_id, spike_times))
314    }
315
316    /// Temporal encoding: spike time inversely proportional to input value
317    fn temporal_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
318        let max_delay = T::from(20.0).unwrap_or_else(|| T::zero()); // 20 ms max delay
319        let spike_time = if value > T::zero() {
320            max_delay * (T::one() - value.min(T::one()))
321        } else {
322            max_delay // No spike for negative values
323        };
324
325        let spike_times = if spike_time < max_delay {
326            vec![spike_time]
327        } else {
328            Vec::new()
329        };
330
331        Ok(SpikeTrain::new(neuron_id, spike_times))
332    }
333
334    /// Population vector encoding
335    fn population_vector_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
336        // Simplified population vector encoding
337        self.rate_encode(neuron_id, value)
338    }
339
340    /// Sparse encoding: only strong inputs generate spikes
341    fn sparse_encode(&self, neuron_id: usize, value: T) -> Result<SpikeTrain<T>> {
342        let threshold = T::from(0.5).unwrap_or_else(|| T::zero());
343
344        if value.abs() > threshold {
345            self.rate_encode(neuron_id, value)
346        } else {
347            Ok(SpikeTrain::new(neuron_id, Vec::new()))
348        }
349    }
350
351    /// Decode spike trains to continuous output
352    pub fn decode_output(&self, spike_trains: &[SpikeTrain<T>]) -> Result<Array1<T>> {
353        let mut output = Array1::zeros(spike_trains.len());
354
355        for (i, spike_train) in spike_trains.iter().enumerate() {
356            output[i] = match self.config.decoding_method {
357                SpikeDecodingMethod::RateDecoding => self.rate_decode(spike_train)?,
358                SpikeDecodingMethod::TemporalDecoding => self.temporal_decode(spike_train)?,
359                SpikeDecodingMethod::WeightedSpikeCount => {
360                    self.weighted_spike_count_decode(spike_train)?
361                }
362                _ => {
363                    // Fallback to rate decoding
364                    self.rate_decode(spike_train)?
365                }
366            };
367        }
368
369        Ok(output)
370    }
371
372    /// Rate decoding: spike count normalized by time window
373    fn rate_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
374        let window_duration = self.config.temporal_window;
375        let spike_count = T::from(spike_train.spike_count).unwrap_or_else(|| T::zero());
376        let rate = spike_count / (window_duration / T::from(1000.0).unwrap_or_else(|| T::zero()));
377        Ok(rate / T::from(100.0).unwrap_or_else(|| T::zero())) // Normalize by max expected rate
378    }
379
380    /// Temporal decoding: use first spike time
381    fn temporal_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
382        if spike_train.spike_times.is_empty() {
383            Ok(T::zero())
384        } else {
385            let first_spike = spike_train.spike_times[0];
386            let max_delay = T::from(20.0).unwrap_or_else(|| T::zero());
387            Ok(T::one() - (first_spike / max_delay).min(T::one()))
388        }
389    }
390
391    /// Weighted spike count decoding
392    fn weighted_spike_count_decode(&self, spike_train: &SpikeTrain<T>) -> Result<T> {
393        if spike_train.spike_times.is_empty() {
394            return Ok(T::zero());
395        }
396
397        let mut weighted_sum = T::zero();
398        let current_time = self.current_time;
399
400        for &spike_time in &spike_train.spike_times {
401            let time_diff = current_time - spike_time;
402            let weight = (-time_diff / T::from(10.0).unwrap_or_else(|| T::zero())).exp(); // Exponential decay
403            weighted_sum = weighted_sum + weight;
404        }
405
406        Ok(weighted_sum)
407    }
408
409    /// Simulate membrane dynamics for one time step
410    pub fn simulate_step(&mut self, input_spikes: &[Spike<T>]) -> Result<Vec<Spike<T>>> {
411        let mut output_spikes = Vec::new();
412        let dt = self.config.time_step;
413
414        // Process input _spikes
415        for spike in input_spikes {
416            self.process_input_spike(spike)?;
417        }
418
419        // Update membrane potentials
420        for neuron_id in 0..self.membrane_potentials.len() {
421            if self.current_time >= self.refractory_until[neuron_id] {
422                self.update_membrane_potential(neuron_id, dt)?;
423
424                // Check for spike threshold
425                if self.membrane_potentials[neuron_id] >= self.membrane_config.threshold_potential {
426                    let spike = self.generate_spike(neuron_id)?;
427                    output_spikes.push(spike);
428                }
429            }
430        }
431
432        // Apply plasticity updates
433        self.update_plasticity(&output_spikes)?;
434
435        // Update homeostatic mechanisms
436        if self.config.homeostatic_config.enable_homeostatic_scaling {
437            self.update_homeostatic_scaling()?;
438        }
439
440        self.current_time = self.current_time + dt;
441
442        Ok(output_spikes)
443    }
444
445    /// Process an input spike
446    fn process_input_spike(&mut self, spike: &Spike<T>) -> Result<()> {
447        let target_neuron = spike.postsynaptic_id.unwrap_or(spike.neuron_id);
448
449        if target_neuron < self.membrane_potentials.len() {
450            // Add synaptic current
451            let synaptic_current = spike.weight * spike.amplitude;
452            self.membrane_potentials[target_neuron] =
453                self.membrane_potentials[target_neuron] + synaptic_current;
454        }
455
456        Ok(())
457    }
458
459    /// Update membrane potential using leaky integrate-and-fire model
460    fn update_membrane_potential(&mut self, neuron_id: usize, dt: T) -> Result<()> {
461        let v = self.membrane_potentials[neuron_id];
462        let v_rest = self.membrane_config.resting_potential;
463        let tau = self.membrane_config.tau_membrane;
464
465        // Leaky integration: dV/dt = (V_rest - V) / tau
466        let dv_dt = (v_rest - v) / tau;
467        let new_v = v + dv_dt * dt;
468
469        self.membrane_potentials[neuron_id] = new_v;
470
471        Ok(())
472    }
473
474    /// Generate a spike when threshold is reached
475    fn generate_spike(&mut self, neuron_id: usize) -> Result<Spike<T>> {
476        // Reset membrane potential
477        self.membrane_potentials[neuron_id] = self.membrane_config.reset_potential;
478
479        // Set refractory period
480        self.refractory_until[neuron_id] =
481            self.current_time + self.membrane_config.refractory_period;
482
483        // Update last spike time
484        self.last_spike_times[neuron_id] = self.current_time;
485
486        // Create spike
487        let spike = Spike {
488            neuron_id,
489            time: self.current_time,
490            amplitude: T::from(1.0).unwrap_or_else(|| T::zero()),
491            width: Some(T::from(1.0).unwrap_or_else(|| T::zero())),
492            weight: T::one(),
493            presynaptic_id: None,
494            postsynaptic_id: None,
495        };
496
497        // Update spike train
498        if let Some(spike_train) = self.spike_trains.get_mut(&neuron_id) {
499            spike_train.spike_times.push(self.current_time);
500            spike_train.spike_count += 1;
501        } else {
502            let spike_train = SpikeTrain::new(neuron_id, vec![self.current_time]);
503            self.spike_trains.insert(neuron_id, spike_train);
504        }
505
506        // Update metrics
507        self.metrics.total_spikes += 1;
508
509        Ok(spike)
510    }
511
512    /// Update synaptic plasticity
513    fn update_plasticity(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
514        match self.plasticity_model {
515            PlasticityModel::STDP => {
516                self.update_stdp(output_spikes)?;
517            }
518            PlasticityModel::Hebbian => {
519                self.update_hebbian(output_spikes)?;
520            }
521            _ => {
522                // Default to STDP
523                self.update_stdp(output_spikes)?;
524            }
525        }
526
527        Ok(())
528    }
529
530    /// Update STDP (Spike Timing Dependent Plasticity)
531    fn update_stdp(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
532        for spike in output_spikes {
533            let post_id = spike.neuron_id;
534            let post_time = spike.time;
535
536            // Check all presynaptic connections
537            for pre_id in 0..self.last_spike_times.len() {
538                if pre_id != post_id {
539                    let pre_time = self.last_spike_times[pre_id];
540
541                    if pre_time > T::from(-1000.0).unwrap_or_else(|| T::zero()) {
542                        // Valid spike time
543                        let dt = post_time - pre_time;
544                        let weight_change = self.compute_stdp_update(dt);
545
546                        // Update synaptic weight
547                        self.synaptic_weights[[pre_id, post_id]] =
548                            (self.synaptic_weights[[pre_id, post_id]] + weight_change)
549                                .max(self.stdp_config.weight_min)
550                                .min(self.stdp_config.weight_max);
551                    }
552                }
553            }
554        }
555
556        Ok(())
557    }
558
559    /// Compute STDP weight update
560    fn compute_stdp_update(&self, dt: T) -> T {
561        if dt > T::zero() {
562            // Post-before-pre: LTP (potentiation)
563            let exp_arg = -dt / self.stdp_config.tau_pot;
564            self.stdp_config.learning_rate_pot * exp_arg.exp()
565        } else {
566            // Pre-before-post: LTD (depression)
567            let exp_arg = dt / self.stdp_config.tau_dep;
568            -self.stdp_config.learning_rate_dep * exp_arg.exp()
569        }
570    }
571
572    /// Update Hebbian plasticity
573    fn update_hebbian(&mut self, output_spikes: &[Spike<T>]) -> Result<()> {
574        // Simplified Hebbian learning
575        for spike in output_spikes {
576            let post_id = spike.neuron_id;
577
578            for pre_id in 0..self.membrane_potentials.len() {
579                if pre_id != post_id {
580                    let pre_activity =
581                        self.membrane_potentials[pre_id] / self.membrane_config.threshold_potential;
582
583                    let weight_change = self.stdp_config.learning_rate_pot * pre_activity;
584
585                    self.synaptic_weights[[pre_id, post_id]] =
586                        (self.synaptic_weights[[pre_id, post_id]] + weight_change)
587                            .max(self.stdp_config.weight_min)
588                            .min(self.stdp_config.weight_max);
589                }
590            }
591        }
592
593        Ok(())
594    }
595
596    /// Update homeostatic scaling
597    fn update_homeostatic_scaling(&mut self) -> Result<()> {
598        let target_rate = self.config.homeostatic_config.target_firing_rate;
599        let time_constant = self.config.homeostatic_config.scaling_time_constant;
600        let dt = self.config.time_step;
601
602        for neuron_id in 0..self.homeostatic_scales.len() {
603            if let Some(spike_train) = self.spike_trains.get(&neuron_id) {
604                let current_rate = spike_train.firing_rate;
605                let rate_error = target_rate - current_rate;
606
607                // Exponential approach to target
608                let scale_change = rate_error * dt / time_constant;
609                self.homeostatic_scales[neuron_id] =
610                    self.homeostatic_scales[neuron_id] + scale_change;
611
612                // Apply scaling to synaptic weights
613                for pre_id in 0..self.synaptic_weights.nrows() {
614                    self.synaptic_weights[[pre_id, neuron_id]] = self.synaptic_weights
615                        [[pre_id, neuron_id]]
616                        * self.homeostatic_scales[neuron_id];
617                }
618            }
619        }
620
621        Ok(())
622    }
623
624    /// Get current neuromorphic metrics
625    pub fn get_metrics(&self) -> &NeuromorphicMetrics<T> {
626        &self.metrics
627    }
628
629    /// Reset the optimizer state
630    pub fn reset(&mut self) {
631        self.current_time = T::zero();
632        self.membrane_potentials
633            .fill(self.membrane_config.resting_potential);
634        self.last_spike_times
635            .fill(T::from(-1000.0).unwrap_or_else(|| T::zero()));
636        self.refractory_until.fill(T::zero());
637        self.spike_trains.clear();
638        self.spike_buffer.clear();
639        self.metrics = NeuromorphicMetrics::default();
640    }
641}
642
643/// Spike train optimizer for temporal pattern learning
644pub struct SpikeTrainOptimizer<
645    T: Float + Debug + scirs2_core::ndarray::ScalarOperand + std::fmt::Debug + Send + Sync,
646> {
647    /// Configuration
648    config: SpikingConfig<T>,
649
650    /// Spike pattern templates
651    pattern_templates: Vec<SpikePattern<T>>,
652
653    /// Pattern matching threshold
654    matching_threshold: T,
655
656    /// Learning rate for pattern adaptation
657    pattern_learning_rate: T,
658
659    /// Temporal kernel for pattern comparison
660    temporal_kernel: TemporalKernel<T>,
661}
662
663/// Spike pattern template
664#[derive(Debug, Clone)]
665pub struct SpikePattern<T: Float + Debug + Send + Sync + 'static> {
666    /// Pattern ID
667    pub pattern_id: usize,
668
669    /// Spike times relative to pattern start
670    pub relative_spike_times: Vec<T>,
671
672    /// Pattern duration
673    pub duration: T,
674
675    /// Pattern weight/importance
676    pub weight: T,
677
678    /// Number of times pattern was observed
679    pub observation_count: usize,
680}
681
682/// Temporal kernel for pattern matching
683#[derive(Debug, Clone)]
684pub struct TemporalKernel<T: Float + Debug + Send + Sync + 'static> {
685    /// Kernel type
686    pub kernel_type: TemporalKernelType,
687
688    /// Kernel width (ms)
689    pub width: T,
690
691    /// Kernel parameters
692    pub parameters: Vec<T>,
693}
694
/// Available temporal kernel shapes.
///
/// `SpikeTrainOptimizer::compute_spike_correlation` currently evaluates a
/// Gaussian from the kernel width regardless of this value.
#[derive(Debug, Clone, Copy)]
pub enum TemporalKernelType {
    /// Gaussian kernel.
    Gaussian,

    /// Exponential kernel.
    Exponential,

    /// Alpha function kernel.
    Alpha,

    /// Rectangular kernel.
    Rectangular,
}
710
711impl<T: Float + Debug + Send + Sync + scirs2_core::ndarray::ScalarOperand + std::fmt::Debug>
712    SpikeTrainOptimizer<T>
713{
714    /// Create a new spike train optimizer
715    pub fn new(config: SpikingConfig<T>) -> Self {
716        Self {
717            config,
718            pattern_templates: Vec::new(),
719            matching_threshold: T::from(0.8).unwrap_or_else(|| T::zero()),
720            pattern_learning_rate: T::from(0.1).unwrap_or_else(|| T::zero()),
721            temporal_kernel: TemporalKernel {
722                kernel_type: TemporalKernelType::Gaussian,
723                width: T::from(5.0).unwrap_or_else(|| T::zero()),
724                parameters: vec![T::one()],
725            },
726        }
727    }
728
729    /// Learn spike patterns from training data
730    pub fn learn_patterns(&mut self, spike_trains: &[SpikeTrain<T>]) -> Result<()> {
731        for spike_train in spike_trains {
732            self.extract_and_learn_patterns(spike_train)?;
733        }
734
735        Ok(())
736    }
737
738    /// Extract patterns from a spike train
739    fn extract_and_learn_patterns(&mut self, spike_train: &SpikeTrain<T>) -> Result<()> {
740        let window_size = T::from(50.0).unwrap_or_else(|| T::zero()); // 50 ms windows
741        let step_size = T::from(10.0).unwrap_or_else(|| T::zero()); // 10 ms steps
742
743        let mut window_start = T::zero();
744
745        while window_start < spike_train.duration {
746            let window_end = window_start + window_size;
747
748            // Extract spikes in current window
749            let window_spikes: Vec<T> = spike_train
750                .spike_times
751                .iter()
752                .filter(|&&t| t >= window_start && t < window_end)
753                .map(|&t| t - window_start) // Make relative to window start
754                .collect();
755
756            if !window_spikes.is_empty() {
757                let pattern = SpikePattern {
758                    pattern_id: self.pattern_templates.len(),
759                    relative_spike_times: window_spikes,
760                    duration: window_size,
761                    weight: T::one(),
762                    observation_count: 1,
763                };
764
765                // Check if similar pattern exists
766                if let Some(similar_pattern_id) = self.find_similar_pattern(&pattern) {
767                    self.update_pattern(similar_pattern_id, &pattern)?;
768                } else {
769                    self.pattern_templates.push(pattern);
770                }
771            }
772
773            window_start = window_start + step_size;
774        }
775
776        Ok(())
777    }
778
779    /// Find similar existing pattern
780    fn find_similar_pattern(&self, new_pattern: &SpikePattern<T>) -> Option<usize> {
781        for (i, existing_pattern) in self.pattern_templates.iter().enumerate() {
782            let similarity = self.compute_pattern_similarity(new_pattern, existing_pattern);
783            if similarity > self.matching_threshold {
784                return Some(i);
785            }
786        }
787
788        None
789    }
790
791    /// Compute similarity between two spike patterns
792    fn compute_pattern_similarity(
793        &self,
794        pattern1: &SpikePattern<T>,
795        pattern2: &SpikePattern<T>,
796    ) -> T {
797        // Use Victor-Purpura distance or similar metric
798        let max_spikes = pattern1
799            .relative_spike_times
800            .len()
801            .max(pattern2.relative_spike_times.len());
802        if max_spikes == 0 {
803            return T::one();
804        }
805
806        // Simplified similarity based on spike count and timing
807        let count_diff = (pattern1.relative_spike_times.len() as i32
808            - pattern2.relative_spike_times.len() as i32)
809            .abs() as f64;
810        let count_similarity =
811            T::one() - T::from(count_diff / max_spikes as f64).unwrap_or_else(|| T::zero());
812
813        // Add temporal similarity if both patterns have spikes
814        if !pattern1.relative_spike_times.is_empty() && !pattern2.relative_spike_times.is_empty() {
815            let temporal_similarity = self.compute_temporal_similarity(
816                &pattern1.relative_spike_times,
817                &pattern2.relative_spike_times,
818            );
819            (count_similarity + temporal_similarity) / T::from(2.0).unwrap_or_else(|| T::zero())
820        } else {
821            count_similarity
822        }
823    }
824
825    /// Compute temporal similarity between spike time sequences
826    fn compute_temporal_similarity(&self, spikes1: &[T], spikes2: &[T]) -> T {
827        // Use cross-correlation or DTW-like measure
828        let mut max_correlation = T::zero();
829        let max_shift = T::from(10.0).unwrap_or_else(|| T::zero()); // 10 ms max shift
830        let shift_step = T::from(1.0).unwrap_or_else(|| T::zero());
831
832        let mut shift = -max_shift;
833        while shift <= max_shift {
834            let correlation = self.compute_spike_correlation(spikes1, spikes2, shift);
835            max_correlation = max_correlation.max(correlation);
836            shift = shift + shift_step;
837        }
838
839        max_correlation
840    }
841
842    /// Compute spike correlation with time shift
843    fn compute_spike_correlation(&self, spikes1: &[T], spikes2: &[T], shift: T) -> T {
844        let mut correlation = T::zero();
845        let kernel_width = self.temporal_kernel.width;
846
847        for &t1 in spikes1 {
848            for &t2 in spikes2 {
849                let dt = (t1 - (t2 + shift)).abs();
850                let kernel_value = (-dt * dt
851                    / (T::from(2.0).unwrap_or_else(|| T::zero()) * kernel_width * kernel_width))
852                    .exp();
853                correlation = correlation + kernel_value;
854            }
855        }
856
857        // Normalize by number of spike pairs
858        if !spikes1.is_empty() && !spikes2.is_empty() {
859            correlation / T::from(spikes1.len() * spikes2.len()).unwrap()
860        } else {
861            T::zero()
862        }
863    }
864
865    /// Update existing pattern with new observation
866    fn update_pattern(&mut self, pattern_id: usize, new_pattern: &SpikePattern<T>) -> Result<()> {
867        if let Some(existing_pattern) = self.pattern_templates.get_mut(pattern_id) {
868            // Update _pattern using exponential moving average
869            let alpha = self.pattern_learning_rate;
870
871            // Update spike times (simplified)
872            if existing_pattern.relative_spike_times.len() == new_pattern.relative_spike_times.len()
873            {
874                for (existing_time, &new_time) in existing_pattern
875                    .relative_spike_times
876                    .iter_mut()
877                    .zip(new_pattern.relative_spike_times.iter())
878                {
879                    *existing_time = *existing_time * (T::one() - alpha) + new_time * alpha;
880                }
881            }
882
883            existing_pattern.observation_count += 1;
884            existing_pattern.weight =
885                existing_pattern.weight * (T::one() - alpha) + new_pattern.weight * alpha;
886        }
887
888        Ok(())
889    }
890
891    /// Recognize patterns in new spike train
892    pub fn recognize_patterns(&self, spike_train: &SpikeTrain<T>) -> Result<Vec<(usize, T, T)>> {
893        let mut recognized_patterns = Vec::new();
894        let window_size = T::from(50.0).unwrap_or_else(|| T::zero());
895        let step_size = T::from(5.0).unwrap_or_else(|| T::zero());
896
897        let mut window_start = T::zero();
898
899        while window_start < spike_train.duration {
900            let window_end = window_start + window_size;
901
902            let window_spikes: Vec<T> = spike_train
903                .spike_times
904                .iter()
905                .filter(|&&t| t >= window_start && t < window_end)
906                .map(|&t| t - window_start)
907                .collect();
908
909            if !window_spikes.is_empty() {
910                let test_pattern = SpikePattern {
911                    pattern_id: 0,
912                    relative_spike_times: window_spikes,
913                    duration: window_size,
914                    weight: T::one(),
915                    observation_count: 1,
916                };
917
918                // Find best matching pattern
919                let mut best_match = (0, T::zero());
920                for (i, template) in self.pattern_templates.iter().enumerate() {
921                    let similarity = self.compute_pattern_similarity(&test_pattern, template);
922                    if similarity > best_match.1 {
923                        best_match = (i, similarity);
924                    }
925                }
926
927                if best_match.1 > self.matching_threshold {
928                    recognized_patterns.push((best_match.0, window_start, best_match.1));
929                }
930            }
931
932            window_start = window_start + step_size;
933        }
934
935        Ok(recognized_patterns)
936    }
937
938    /// Get learned patterns
939    pub fn get_patterns(&self) -> &[SpikePattern<T>] {
940        &self.pattern_templates
941    }
942}