Skip to main content

irithyll_core/snn/
spike_encoding.rs

//! Delta spike encoding for converting continuous features to spike trains.
//!
//! The delta encoder compares each feature's current value against its previous
//! value and emits an UP spike when the change is strictly greater than that
//! feature's threshold, or a DOWN spike when it is strictly less than the negated
//! threshold. Each raw feature maps to two encoding channels (positive and negative
//! change), so N raw features produce 2*N spike channels. The first input after
//! construction or reset only establishes the baseline and produces no spikes.
//!
//! # Encoding Rule
//!
//! ```text
//! spike_pos[i] = 1 if (x[i] - x_prev[i]) >  threshold[i]
//! spike_neg[i] = 1 if (x[i] - x_prev[i]) < -threshold[i]
//! ```
//!
//! The output buffer has layout `[pos_0, neg_0, pos_1, neg_1, ..., pos_{N-1}, neg_{N-1}]`,
//! where even indices are positive (UP) spikes and odd indices are negative (DOWN) spikes.
17
18use alloc::vec;
19use alloc::vec::Vec;
20
/// Delta spike encoder for fixed-point inputs.
///
/// Converts continuous-valued features (Q1.14 `i16`) into binary spike trains
/// by detecting temporal changes that strictly exceed per-feature thresholds.
/// The encoder is stateful: it remembers the previous input sample, so each
/// call to `encode` is compared against the sample passed to the preceding call.
///
/// # Layout
///
/// For `n_features` input features, the output has `2 * n_features` channels:
/// - Channel `2*i` = positive spike (feature increased)
/// - Channel `2*i + 1` = negative spike (feature decreased)
#[derive(Debug, Clone)]
pub struct DeltaEncoderFixed {
    /// Previous input values for delta computation.
    prev: Vec<i16>,
    /// Per-feature threshold for spike emission (one entry per feature).
    threshold: Vec<i16>,
    /// Whether we have seen at least one input (first input produces no spikes).
    initialized: bool,
}
39
40impl DeltaEncoderFixed {
41    /// Create a new delta encoder with a uniform threshold for all features.
42    ///
43    /// # Arguments
44    ///
45    /// * `n_features` -- number of raw input features
46    /// * `threshold` -- uniform spike threshold in Q1.14
47    pub fn new(n_features: usize, threshold: i16) -> Self {
48        Self {
49            prev: vec![0; n_features],
50            threshold: vec![threshold; n_features],
51            initialized: false,
52        }
53    }
54
55    /// Create a new delta encoder with per-feature thresholds.
56    ///
57    /// # Arguments
58    ///
59    /// * `thresholds` -- one threshold per input feature, in Q1.14
60    pub fn new_per_feature(thresholds: Vec<i16>) -> Self {
61        let n = thresholds.len();
62        Self {
63            prev: vec![0; n],
64            threshold: thresholds,
65            initialized: false,
66        }
67    }
68
69    /// Number of raw input features.
70    #[inline]
71    pub fn n_features(&self) -> usize {
72        self.prev.len()
73    }
74
75    /// Number of output spike channels (always `2 * n_features`).
76    #[inline]
77    pub fn n_output_channels(&self) -> usize {
78        self.prev.len() * 2
79    }
80
81    /// Encode an input vector into a spike buffer.
82    ///
83    /// The `out_spikes` buffer must have length `>= 2 * n_features`. It will be
84    /// filled with 0s and 1s: even indices for positive spikes, odd for negative.
85    ///
86    /// On the very first call, no spikes are emitted (the encoder needs a
87    /// previous value to compute deltas). The previous values are stored for
88    /// the next call.
89    ///
90    /// # Panics
91    ///
92    /// Panics if `input.len() != n_features` or `out_spikes.len() < 2 * n_features`.
93    pub fn encode(&mut self, input: &[i16], out_spikes: &mut [u8]) {
94        let n = self.prev.len();
95        assert_eq!(
96            input.len(),
97            n,
98            "input length {} does not match encoder n_features {}",
99            input.len(),
100            n
101        );
102        assert!(
103            out_spikes.len() >= 2 * n,
104            "out_spikes length {} too small for {} channels",
105            out_spikes.len(),
106            2 * n
107        );
108
109        if !self.initialized {
110            // First call: store input, emit no spikes
111            self.prev.copy_from_slice(input);
112            for spike in out_spikes[..2 * n].iter_mut() {
113                *spike = 0;
114            }
115            self.initialized = true;
116            return;
117        }
118
119        for i in 0..n {
120            let delta = input[i] as i32 - self.prev[i] as i32;
121            let thr = self.threshold[i] as i32;
122
123            // Positive spike: feature increased past threshold
124            out_spikes[2 * i] = if delta > thr { 1 } else { 0 };
125            // Negative spike: feature decreased past threshold
126            out_spikes[2 * i + 1] = if delta < -thr { 1 } else { 0 };
127        }
128
129        // Update previous values
130        self.prev.copy_from_slice(input);
131    }
132
133    /// Reset the encoder to its initial state.
134    ///
135    /// The next call to `encode` will produce no spikes (warmup step).
136    pub fn reset(&mut self) {
137        for v in self.prev.iter_mut() {
138            *v = 0;
139        }
140        self.initialized = false;
141    }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn first_call_produces_no_spikes() {
        let mut enc = DeltaEncoderFixed::new(3, 100);
        let input = [500_i16, -200, 1000];
        let mut spikes = vec![0u8; 6];
        enc.encode(&input, &mut spikes);
        assert!(
            spikes.iter().all(|&s| s == 0),
            "first call should produce no spikes"
        );
    }

    #[test]
    fn positive_spike_on_increase() {
        let mut enc = DeltaEncoderFixed::new(2, 100);
        let mut spikes = vec![0u8; 4];

        // First call: baseline
        enc.encode(&[0, 0], &mut spikes);

        // Second call: feature 0 increases by 200 (> threshold 100)
        enc.encode(&[200, 0], &mut spikes);
        assert_eq!(spikes[0], 1, "feature 0 should have positive spike");
        assert_eq!(spikes[1], 0, "feature 0 should not have negative spike");
        assert_eq!(spikes[2], 0, "feature 1 should have no positive spike");
        assert_eq!(spikes[3], 0, "feature 1 should have no negative spike");
    }

    #[test]
    fn negative_spike_on_decrease() {
        let mut enc = DeltaEncoderFixed::new(2, 100);
        let mut spikes = vec![0u8; 4];

        enc.encode(&[500, 500], &mut spikes);
        enc.encode(&[500, 200], &mut spikes); // feature 1 decreases by 300

        assert_eq!(spikes[0], 0, "feature 0 pos should be 0");
        assert_eq!(spikes[1], 0, "feature 0 neg should be 0");
        assert_eq!(spikes[2], 0, "feature 1 pos should be 0");
        assert_eq!(spikes[3], 1, "feature 1 should have negative spike");
    }

    #[test]
    fn no_spike_within_threshold() {
        let mut enc = DeltaEncoderFixed::new(1, 500);
        let mut spikes = vec![0u8; 2];

        enc.encode(&[1000], &mut spikes);
        // Change of 100 is within threshold of 500
        enc.encode(&[1100], &mut spikes);

        assert_eq!(spikes[0], 0, "should not spike for small increase");
        assert_eq!(spikes[1], 0, "should not spike for small change");
    }

    #[test]
    fn both_spikes_in_same_step() {
        let mut enc = DeltaEncoderFixed::new(2, 100);
        let mut spikes = vec![0u8; 4];

        enc.encode(&[0, 1000], &mut spikes);
        // Feature 0 increases, feature 1 decreases
        enc.encode(&[500, 500], &mut spikes);

        assert_eq!(spikes[0], 1, "feature 0 should have positive spike");
        assert_eq!(spikes[1], 0, "feature 0 should not have negative spike");
        assert_eq!(spikes[2], 0, "feature 1 should not have positive spike");
        assert_eq!(spikes[3], 1, "feature 1 should have negative spike");
    }

    #[test]
    fn per_feature_thresholds() {
        let thresholds = vec![100_i16, 1000];
        let mut enc = DeltaEncoderFixed::new_per_feature(thresholds);
        let mut spikes = vec![0u8; 4];

        enc.encode(&[0, 0], &mut spikes);
        // Feature 0: change 200 > threshold 100 -> spike
        // Feature 1: change 200 < threshold 1000 -> no spike
        enc.encode(&[200, 200], &mut spikes);

        assert_eq!(spikes[0], 1, "feature 0 should spike (200 > 100)");
        assert_eq!(spikes[1], 0, "feature 0 should not have negative spike");
        assert_eq!(spikes[2], 0, "feature 1 should not spike (200 < 1000)");
        assert_eq!(spikes[3], 0, "feature 1 should not have negative spike");
    }

    #[test]
    fn reset_clears_state() {
        let mut enc = DeltaEncoderFixed::new(2, 100);
        let mut spikes = vec![0u8; 4];

        enc.encode(&[1000, 2000], &mut spikes);
        enc.encode(&[2000, 3000], &mut spikes);
        assert!(
            spikes.iter().any(|&s| s == 1),
            "should have spikes before reset"
        );

        enc.reset();
        // After reset, first call should produce no spikes
        enc.encode(&[5000, 5000], &mut spikes);
        assert!(
            spikes.iter().all(|&s| s == 0),
            "after reset, first call should produce no spikes"
        );
    }

    #[test]
    fn sequential_encoding_uses_updated_prev() {
        let mut enc = DeltaEncoderFixed::new(1, 100);
        let mut spikes = vec![0u8; 2];

        enc.encode(&[0], &mut spikes);
        enc.encode(&[500], &mut spikes); // delta = 500
        assert_eq!(spikes[0], 1, "first increase should spike");

        enc.encode(&[600], &mut spikes); // delta = 100, NOT > threshold
        assert_eq!(spikes[0], 0, "small subsequent change should not spike");

        enc.encode(&[1200], &mut spikes); // delta = 600 from 600
        assert_eq!(spikes[0], 1, "large subsequent change should spike");
    }

    #[test]
    fn output_channel_count() {
        let enc = DeltaEncoderFixed::new(5, 100);
        assert_eq!(enc.n_features(), 5);
        assert_eq!(enc.n_output_channels(), 10);
    }

    #[test]
    fn delta_equal_to_threshold_does_not_spike() {
        let mut enc = DeltaEncoderFixed::new(1, 100);
        let mut spikes = vec![0u8; 2];

        enc.encode(&[0], &mut spikes);
        // Comparison is strict: a change of exactly +threshold does not spike.
        enc.encode(&[100], &mut spikes);
        assert_eq!(spikes, vec![0u8, 0], "+threshold exactly should not spike");

        // ... and neither does a change of exactly -threshold.
        enc.encode(&[0], &mut spikes);
        assert_eq!(spikes, vec![0u8, 0], "-threshold exactly should not spike");
    }

    #[test]
    fn full_range_delta_does_not_overflow() {
        let mut enc = DeltaEncoderFixed::new(1, 0);
        let mut spikes = vec![0u8; 2];

        enc.encode(&[i16::MIN], &mut spikes);
        // i16::MAX - i16::MIN does not fit in i16; the encoder must widen.
        enc.encode(&[i16::MAX], &mut spikes);
        assert_eq!(spikes, vec![1u8, 0]);

        enc.encode(&[i16::MIN], &mut spikes);
        assert_eq!(spikes, vec![0u8, 1]);
    }

    #[test]
    fn extra_output_channels_are_untouched() {
        let mut enc = DeltaEncoderFixed::new(1, 10);
        let mut spikes = vec![7u8; 4];

        enc.encode(&[0], &mut spikes);
        assert_eq!(spikes, vec![0u8, 0, 7, 7], "warmup must only clear owned channels");

        enc.encode(&[50], &mut spikes);
        assert_eq!(spikes, vec![1u8, 0, 7, 7], "encode must only write owned channels");
    }

    #[test]
    fn zero_features_is_a_no_op() {
        let mut enc = DeltaEncoderFixed::new(0, 10);
        let mut spikes: Vec<u8> = Vec::new();

        enc.encode(&[], &mut spikes);
        enc.encode(&[], &mut spikes);
        assert_eq!(enc.n_features(), 0);
        assert_eq!(enc.n_output_channels(), 0);
    }

    #[test]
    #[should_panic(expected = "does not match encoder n_features")]
    fn mismatched_input_length_panics() {
        let mut enc = DeltaEncoderFixed::new(2, 10);
        let mut spikes = vec![0u8; 4];
        enc.encode(&[1, 2, 3], &mut spikes);
    }

    #[test]
    #[should_panic(expected = "too small")]
    fn short_output_buffer_panics() {
        let mut enc = DeltaEncoderFixed::new(2, 10);
        let mut spikes = vec![0u8; 3];
        enc.encode(&[1, 2], &mut spikes);
    }
}