use crate::pipeline::StreamingPreprocessor;
use irithyll_core::snn::astrocyte::AstrocyteMode;
use irithyll_core::snn::lif::{f64_to_q14, Q14_HALF, Q14_ONE};
use irithyll_core::snn::network_fixed::{SpikeNetFixed, SpikeNetFixedConfig};
/// Streaming preprocessor that pushes quantized inputs through a fixed-point
/// spiking neural network (`SpikeNetFixed`) and exposes the hidden-layer
/// spikes plus readout activations as a feature vector.
pub struct SpikePreprocessor {
    // Number of hidden neurons in the spiking network.
    n_hidden: usize,
    // Number of readout (output) neurons in the spiking network.
    n_outputs: usize,
    // RNG seed forwarded to the network's weight initialization.
    seed: u64,
    // Lazily constructed network; stays `None` until the first
    // `update_and_transform` call reveals the input dimensionality.
    inner: Option<SpikeNetFixed>,
    // Per-input multiplier mapping raw feature values into Q14 range.
    input_scale: Vec<f64>,
    // Running per-input maximum absolute value (floored at 1.0).
    input_max_abs: Vec<f64>,
    // Reusable buffer holding the i16-quantized inputs for each step.
    quantized_buf: Vec<i16>,
    // Input dimensionality captured at lazy initialization.
    n_input: usize,
    // Samples processed since construction or the last `reset`.
    n_samples: u64,
}
impl SpikePreprocessor {
    /// Creates a preprocessor with `n_hidden` hidden neurons and a single
    /// readout neuron. The underlying network is built lazily, on the first
    /// `update_and_transform` call, once the input dimensionality is known.
    pub fn new(n_hidden: usize, seed: u64) -> Self {
        // Delegate to the general constructor instead of duplicating the
        // field initialization.
        Self::with_outputs(n_hidden, 1, seed)
    }

    /// Creates a preprocessor with an explicit number of readout neurons.
    pub fn with_outputs(n_hidden: usize, n_outputs: usize, seed: u64) -> Self {
        Self {
            n_hidden,
            n_outputs,
            seed,
            inner: None,
            input_scale: Vec::new(),
            input_max_abs: Vec::new(),
            quantized_buf: Vec::new(),
            n_input: 0,
            n_samples: 0,
        }
    }

    /// Builds the fixed-point spiking network and the scaling/quantization
    /// buffers for `n_input`-dimensional inputs.
    fn initialize(&mut self, n_input: usize) {
        self.n_input = n_input;
        let config = SpikeNetFixedConfig {
            n_input,
            n_hidden: self.n_hidden,
            n_output: self.n_outputs,
            alpha: f64_to_q14(0.95),
            kappa: f64_to_q14(0.99),
            kappa_out: f64_to_q14(0.9),
            // eta = 0: no online weight updates; the network is used purely
            // as a fixed feature expansion.
            eta: 0,
            v_thr: f64_to_q14(0.5),
            gamma: f64_to_q14(0.3),
            spike_threshold: f64_to_q14(0.05),
            seed: self.seed,
            weight_init_range: f64_to_q14(0.1),
            use_astrocyte: false,
            astrocyte_tau: 1000.0,
            astrocyte_mode: AstrocyteMode::WeightMod,
        };
        self.inner = Some(SpikeNetFixed::new(config));
        // Neutral starting scale: |x| <= 1.0 maps onto half the Q14 range.
        self.input_scale = vec![Q14_HALF as f64; n_input];
        self.input_max_abs = vec![1.0; n_input];
        self.quantized_buf = vec![0i16; n_input];
    }

    /// Updates the running per-input max-abs statistic and recomputes the
    /// quantization scale so the largest observed magnitude maps to `Q14_HALF`.
    ///
    /// Single merged pass: recomputing `input_scale[i]` when `input_max_abs[i]`
    /// is unchanged is idempotent, so this is equivalent to the two-pass form.
    fn update_scaling(&mut self, features: &[f64]) {
        for (i, &feat) in features.iter().enumerate().take(self.n_input) {
            let abs_val = feat.abs();
            if abs_val > self.input_max_abs[i] {
                self.input_max_abs[i] = abs_val;
            }
            // Guard against a (near-)zero denominator. `input_max_abs` starts
            // at 1.0 and only grows, but keep the check for robustness.
            if self.input_max_abs[i] > 1e-10 {
                self.input_scale[i] = Q14_HALF as f64 / self.input_max_abs[i];
            }
        }
    }

    /// Scales each feature and saturates it into the reusable `i16` buffer.
    fn quantize_input(&mut self, features: &[f64]) {
        for (i, &feat) in features.iter().enumerate().take(self.n_input) {
            let scaled = feat * self.input_scale[i];
            // Explicit clamp for clarity; the `as` cast itself saturates and
            // maps NaN to 0.
            self.quantized_buf[i] = scaled.clamp(i16::MIN as f64, i16::MAX as f64) as i16;
        }
    }

    /// Concatenates the hidden-layer spikes (as 0.0/1.0) and the readout
    /// values (converted from Q14 back to f64) into one feature vector.
    ///
    /// Panics if called before the network is initialized; all call sites
    /// check `self.inner` first.
    fn extract_features(&self) -> Vec<f64> {
        let inner = self
            .inner
            .as_ref()
            .expect("extract_features requires an initialized network");
        let mut out = Vec::with_capacity(self.n_hidden + self.n_outputs);
        out.extend(inner.hidden_spikes().iter().map(|&s| s as f64));
        // Undo the Q14 fixed-point encoding on the readout values.
        let scale = 1.0 / Q14_ONE as f64;
        out.extend(inner.predict_all_f64(scale));
        out
    }
}
impl StreamingPreprocessor for SpikePreprocessor {
    /// Advances the spiking network one timestep with `features` and returns
    /// the resulting feature vector (hidden spikes + readout values).
    fn update_and_transform(&mut self, features: &[f64]) -> Vec<f64> {
        // Lazy initialization: the first sample fixes the input dimension.
        if self.inner.is_none() {
            self.initialize(features.len());
        }
        self.update_scaling(features);
        self.quantize_input(features);
        // Disjoint field borrows (`inner` mutably, `quantized_buf` shared)
        // make the previous per-sample clone of the buffer unnecessary.
        self.inner
            .as_mut()
            .expect("initialized above")
            .forward(&self.quantized_buf);
        self.n_samples += 1;
        self.extract_features()
    }

    /// Returns the current feature vector WITHOUT advancing network state.
    /// The input is intentionally ignored: this preprocessor is stateful and
    /// `transform` only reads out the state left by the last update.
    fn transform(&self, _features: &[f64]) -> Vec<f64> {
        if self.inner.is_none() {
            // No state yet; return zeros of the eventual output dimension.
            return vec![0.0; self.n_hidden + self.n_outputs];
        }
        self.extract_features()
    }

    /// Output dimensionality (hidden + outputs), known only after the lazy
    /// initialization has happened.
    fn output_dim(&self) -> Option<usize> {
        self.inner.as_ref().map(|_| self.n_hidden + self.n_outputs)
    }

    /// Resets the network dynamics and the input-scaling statistics to their
    /// initial values and zeroes the sample counter.
    fn reset(&mut self) {
        if let Some(net) = self.inner.as_mut() {
            net.reset();
        }
        self.input_max_abs.fill(1.0);
        self.input_scale.fill(Q14_HALF as f64);
        self.n_samples = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The network must not exist until the first sample fixes the input size.
    #[test]
    fn lazy_initialization_on_first_update() {
        let mut pp = SpikePreprocessor::new(8, 42);
        assert!(pp.inner.is_none());
        assert_eq!(pp.output_dim(), None);
        let out = pp.update_and_transform(&[1.0, 2.0, 3.0]);
        assert!(pp.inner.is_some());
        // 8 hidden neurons + 1 default output = 9 features.
        assert_eq!(pp.output_dim(), Some(9));
        assert_eq!(out.len(), 9);
    }

    /// Before initialization, `transform` yields an all-zero vector of the
    /// eventual output dimensionality (4 hidden + 1 output = 5).
    #[test]
    fn transform_before_init_returns_zeros() {
        let pp = SpikePreprocessor::new(4, 42);
        let out = pp.transform(&[1.0, 2.0]);
        assert_eq!(out.len(), 5);
        assert!(out.iter().all(|&v| v == 0.0));
    }

    /// `with_outputs` controls the readout count: 16 hidden + 3 outputs = 19.
    #[test]
    fn output_dim_matches_hidden_plus_output() {
        let mut pp = SpikePreprocessor::with_outputs(16, 3, 42);
        pp.update_and_transform(&[1.0, 2.0]);
        assert_eq!(pp.output_dim(), Some(19));
    }

    /// The first `n_hidden` features are hidden spikes and must be 0.0 or 1.0.
    #[test]
    fn spike_features_are_binary() {
        let mut pp = SpikePreprocessor::new(8, 42);
        pp.update_and_transform(&[0.0, 0.0]);
        let out = pp.update_and_transform(&[10.0, -10.0]);
        for (i, &val) in out.iter().enumerate().take(8) {
            assert!(
                val == 0.0 || val == 1.0,
                "spike feature {} should be 0.0 or 1.0, got {}",
                i,
                val
            );
        }
    }

    /// `reset` must zero the sample counter (network state is also reset,
    /// which is exercised indirectly elsewhere).
    #[test]
    fn reset_clears_preprocessor_state() {
        let mut pp = SpikePreprocessor::new(4, 42);
        pp.update_and_transform(&[1.0, 2.0]);
        pp.update_and_transform(&[3.0, 4.0]);
        assert!(pp.n_samples > 0);
        pp.reset();
        assert_eq!(pp.n_samples, 0);
    }

    /// With oscillating inputs, at least one pair of consecutive timesteps
    /// must produce different feature vectors — i.e. the output actually
    /// depends on the evolving network state.
    #[test]
    fn multiple_timesteps_produce_different_features() {
        let mut pp = SpikePreprocessor::new(16, 42);
        let mut outputs = Vec::new();
        for step in 0..20 {
            let t = step as f64 * 0.5;
            let input = [t.sin() * 10.0, t.cos() * 10.0, (t * 2.0).sin() * 5.0];
            outputs.push(pp.update_and_transform(&input));
        }
        let any_diff = outputs
            .windows(2)
            .any(|w| w[0].iter().zip(&w[1]).any(|(a, b)| (a - b).abs() > 1e-15));
        assert!(
            any_diff,
            "features should change across timesteps with varying oscillating inputs"
        );
    }

    /// `transform` is read-only: calling it twice with the same input after
    /// some updates must return identical vectors.
    #[test]
    fn transform_does_not_advance_state() {
        let mut pp = SpikePreprocessor::new(8, 42);
        pp.update_and_transform(&[1.0, 2.0]);
        pp.update_and_transform(&[3.0, 4.0]);
        let t1 = pp.transform(&[5.0, 6.0]);
        let t2 = pp.transform(&[5.0, 6.0]);
        assert_eq!(t1, t2, "transform should not change state");
    }

    /// End-to-end smoke test: the preprocessor composes with the pipeline
    /// builder and a linear learner without producing NaN/inf predictions.
    #[test]
    fn works_in_pipeline() {
        use crate::learner::StreamingLearner;
        use crate::learners::StreamingLinearModel;
        use crate::pipeline::Pipeline;
        let mut pipeline = Pipeline::builder()
            .pipe(SpikePreprocessor::new(8, 42))
            .learner(StreamingLinearModel::new(0.01));
        pipeline.train(&[1.0, 2.0], 3.0);
        pipeline.train(&[4.0, 5.0], 6.0);
        let pred = pipeline.predict(&[2.5, 3.5]);
        assert!(pred.is_finite(), "pipeline prediction should be finite");
    }
}