use crate::error::ConfigError;
/// Plasticity rule selected for training, stored in `SpikeNetConfig::learning_rule`.
///
/// `#[non_exhaustive]` forces `match`es outside this crate to carry a wildcard arm,
/// so further rules can be added without a breaking change. `Stdp` is the default.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum LearningRule {
    /// Default rule. NOTE(review): name suggests spike-timing-dependent plasticity — confirm
    /// against the training code.
    #[default]
    Stdp,
    /// Alternative rule; its semantics are defined by the training code, not here.
    PpProp,
}
/// Hyper-parameters for the spiking network.
///
/// All fields are public, so a value built with a struct literal is not checked;
/// call [`SpikeNetConfig::validate`] before use, or construct it through
/// [`SpikeNetConfig::builder`], whose `build` runs `validate`.
/// Defaults below are the values produced by `SpikeNetConfig::default()`.
#[derive(Debug, Clone)]
pub struct SpikeNetConfig {
/// Hidden-layer size (presumably the hidden neuron count — confirm). Must be > 0. Default: 64.
pub n_hidden: usize,
/// Number of outputs. Must be > 0. Default: 1.
pub n_outputs: usize,
/// Must lie in [0.0, 1.0]. Default: 0.95.
/// NOTE(review): likely a per-step decay factor of the hidden state — confirm.
pub alpha: f64,
/// Must lie in [0.0, 1.0]. Default: 0.99.
/// NOTE(review): likely a decay/filter factor — confirm.
pub kappa: f64,
/// Must lie in [0.0, 1.0]. Default: 0.90.
/// NOTE(review): by name, the output-side counterpart of `kappa` — confirm.
pub kappa_out: f64,
/// Learning rate. Must lie in (0.0, 1.0]. Default: 0.001.
pub learning_rate: f64,
/// Must be > 0.0. Default: 0.50. NOTE(review): presumably the firing threshold — confirm.
pub v_thr: f64,
/// Must lie in [0.0, 1.0]. Default: 0.30.
pub gamma: f64,
/// Must be > 0.0. Default: 0.05.
pub spike_threshold: f64,
/// Not validated. Default: 42. NOTE(review): presumably seeds the RNG — confirm.
pub seed: u64,
/// Must lie in (0.0, 1.9]. Default: 0.10.
/// `validate` attributes the 1.9 upper bound to a Q1.14 limit.
pub weight_init_range: f64,
/// Enables the astrocyte option. Not validated. Default: false.
pub astrocyte: bool,
/// Must be > 0.0. Default: 1000.0.
/// NOTE(review): presumably a time constant used when `astrocyte` is enabled — confirm.
pub astrocyte_tau: f64,
/// Plasticity rule. Default: [`LearningRule::Stdp`].
pub learning_rule: LearningRule,
}
impl Default for SpikeNetConfig {
    /// Baseline configuration; every value lies inside the ranges enforced by
    /// [`SpikeNetConfig::validate`].
    fn default() -> Self {
        Self {
            // Layer sizes.
            n_hidden: 64,
            n_outputs: 1,
            // Factors that `validate` restricts to [0.0, 1.0].
            alpha: 0.95,
            kappa: 0.99,
            kappa_out: 0.9,
            gamma: 0.3,
            // Strictly positive thresholds.
            v_thr: 0.5,
            spike_threshold: 0.05,
            // Learning.
            learning_rate: 1e-3,
            learning_rule: LearningRule::Stdp,
            // Initialisation.
            seed: 42,
            weight_init_range: 0.1,
            // Astrocyte option (off by default).
            astrocyte: false,
            astrocyte_tau: 1_000.0,
        }
    }
}
impl SpikeNetConfig {
    /// Returns a [`SpikeNetConfigBuilder`] that starts from [`SpikeNetConfig::default`].
    pub fn builder() -> SpikeNetConfigBuilder {
        SpikeNetConfigBuilder::new()
    }

    /// Checks every field against its allowed range.
    ///
    /// Fields are checked in declaration order, and the first violation is returned.
    /// `seed`, `astrocyte` and `learning_rule` are unconstrained.
    ///
    /// Each floating-point check accepts a value only when it is in range. Every
    /// comparison involving `NaN` is `false`, so a `NaN` field is rejected instead of
    /// slipping through. Positive infinity is still accepted for the fields without
    /// an upper bound (`v_thr`, `spike_threshold`, `astrocyte_tau`).
    ///
    /// # Errors
    ///
    /// Returns the error produced by `ConfigError::out_of_range`. It is called with the
    /// offending field name, a description of the violated constraint, and the
    /// rejected value.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.n_hidden == 0 {
            return Err(ConfigError::out_of_range(
                "n_hidden",
                "must be > 0",
                self.n_hidden,
            ));
        }
        if self.n_outputs == 0 {
            return Err(ConfigError::out_of_range(
                "n_outputs",
                "must be > 0",
                self.n_outputs,
            ));
        }
        Self::ensure_unit_interval("alpha", self.alpha)?;
        Self::ensure_unit_interval("kappa", self.kappa)?;
        Self::ensure_unit_interval("kappa_out", self.kappa_out)?;
        Self::ensure_positive("learning_rate", self.learning_rate)?;
        // NaN was already rejected by `ensure_positive`, so a plain `>` is sufficient here.
        if self.learning_rate > 1.0 {
            return Err(ConfigError::out_of_range(
                "learning_rate",
                "must be <= 1.0",
                self.learning_rate,
            ));
        }
        Self::ensure_positive("v_thr", self.v_thr)?;
        Self::ensure_unit_interval("gamma", self.gamma)?;
        Self::ensure_positive("spike_threshold", self.spike_threshold)?;
        Self::ensure_positive("weight_init_range", self.weight_init_range)?;
        // NaN was already rejected by `ensure_positive`, so a plain `>` is sufficient here.
        if self.weight_init_range > 1.9 {
            return Err(ConfigError::out_of_range(
                "weight_init_range",
                "must be <= 1.9 (Q1.14 limit)",
                self.weight_init_range,
            ));
        }
        Self::ensure_positive("astrocyte_tau", self.astrocyte_tau)?;
        Ok(())
    }

    /// Accepts `value` only if it lies in the closed interval `[0.0, 1.0]`.
    /// `NaN` is rejected because `RangeInclusive::contains` is false for it.
    fn ensure_unit_interval(field: &'static str, value: f64) -> Result<(), ConfigError> {
        if (0.0..=1.0).contains(&value) {
            Ok(())
        } else {
            Err(ConfigError::out_of_range(
                field,
                "must be in [0.0, 1.0]",
                value,
            ))
        }
    }

    /// Accepts `value` only if it is strictly greater than zero.
    /// `NaN` is rejected because `NaN > 0.0` is false.
    fn ensure_positive(field: &'static str, value: f64) -> Result<(), ConfigError> {
        if value > 0.0 {
            Ok(())
        } else {
            Err(ConfigError::out_of_range(field, "must be > 0.0", value))
        }
    }
}
/// Fluent builder for [`SpikeNetConfig`]; obtain one via [`SpikeNetConfig::builder`].
///
/// Setters store values without checking them; all checking happens in `build`,
/// which runs [`SpikeNetConfig::validate`].
#[derive(Debug, Clone)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct SpikeNetConfigBuilder {
    /// In-progress configuration, seeded from `SpikeNetConfig::default()` by `new`.
    config: SpikeNetConfig,
}
impl SpikeNetConfigBuilder {
    /// Starts a builder seeded with [`SpikeNetConfig::default`].
    pub fn new() -> Self {
        let config = SpikeNetConfig::default();
        Self { config }
    }

    /// Applies `edit` to the in-progress config and hands the builder back.
    fn set(mut self, edit: impl FnOnce(&mut SpikeNetConfig)) -> Self {
        edit(&mut self.config);
        self
    }

    /// Sets `n_hidden` (checked by `build`).
    pub fn n_hidden(self, n: usize) -> Self {
        self.set(|c| c.n_hidden = n)
    }

    /// Sets `n_outputs` (checked by `build`).
    pub fn n_outputs(self, n: usize) -> Self {
        self.set(|c| c.n_outputs = n)
    }

    /// Sets `alpha` (checked by `build`).
    pub fn alpha(self, alpha: f64) -> Self {
        self.set(|c| c.alpha = alpha)
    }

    /// Sets `kappa` (checked by `build`).
    pub fn kappa(self, kappa: f64) -> Self {
        self.set(|c| c.kappa = kappa)
    }

    /// Sets `kappa_out` (checked by `build`).
    pub fn kappa_out(self, kappa_out: f64) -> Self {
        self.set(|c| c.kappa_out = kappa_out)
    }

    /// Sets `learning_rate` (checked by `build`).
    pub fn learning_rate(self, lr: f64) -> Self {
        self.set(|c| c.learning_rate = lr)
    }

    /// Sets `v_thr` (checked by `build`).
    pub fn v_thr(self, v_thr: f64) -> Self {
        self.set(|c| c.v_thr = v_thr)
    }

    /// Sets `gamma` (checked by `build`).
    pub fn gamma(self, gamma: f64) -> Self {
        self.set(|c| c.gamma = gamma)
    }

    /// Sets `spike_threshold` (checked by `build`).
    pub fn spike_threshold(self, threshold: f64) -> Self {
        self.set(|c| c.spike_threshold = threshold)
    }

    /// Sets `seed`.
    pub fn seed(self, seed: u64) -> Self {
        self.set(|c| c.seed = seed)
    }

    /// Sets `weight_init_range` (checked by `build`).
    pub fn weight_init_range(self, range: f64) -> Self {
        self.set(|c| c.weight_init_range = range)
    }

    /// Sets the `astrocyte` flag.
    pub fn astrocyte(self, enabled: bool) -> Self {
        self.set(|c| c.astrocyte = enabled)
    }

    /// Sets `astrocyte_tau` (checked by `build`).
    pub fn astrocyte_tau(self, tau: f64) -> Self {
        self.set(|c| c.astrocyte_tau = tau)
    }

    /// Sets `learning_rule`.
    pub fn learning_rule(self, rule: LearningRule) -> Self {
        self.set(|c| c.learning_rule = rule)
    }

    /// Validates the accumulated values and returns the finished config.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`SpikeNetConfig::validate`].
    pub fn build(self) -> Result<SpikeNetConfig, ConfigError> {
        self.config.validate().map(|()| self.config)
    }
}
impl Default for SpikeNetConfigBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Signature shared by every `f64` setter on [`SpikeNetConfigBuilder`].
    type F64Setter = fn(SpikeNetConfigBuilder, f64) -> SpikeNetConfigBuilder;

    #[test]
    fn pp_prop_variant_constructible() {
        let config = SpikeNetConfig::builder()
            .n_hidden(16)
            .learning_rate(0.01)
            .learning_rule(LearningRule::PpProp)
            .build()
            .unwrap();
        assert_eq!(
            config.learning_rule,
            LearningRule::PpProp,
            "builder must store PpProp, got {:?}",
            config.learning_rule
        );
        let default_config = SpikeNetConfig::default();
        assert_eq!(
            default_config.learning_rule,
            LearningRule::Stdp,
            "default learning_rule must be Stdp, got {:?}",
            default_config.learning_rule
        );
        assert_ne!(
            LearningRule::Stdp,
            LearningRule::PpProp,
            "Stdp and PpProp must be distinct variants"
        );
    }

    #[test]
    fn default_config_is_valid() {
        let config = SpikeNetConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn builder_produces_valid_config() {
        let config = SpikeNetConfig::builder()
            .n_hidden(32)
            .n_outputs(2)
            .learning_rate(0.005)
            .alpha(0.9)
            .build()
            .unwrap();
        assert_eq!(config.n_hidden, 32);
        assert_eq!(config.n_outputs, 2);
        assert!((config.learning_rate - 0.005).abs() < 1e-10);
        assert!((config.alpha - 0.9).abs() < 1e-10);
    }

    #[test]
    fn zero_hidden_rejected() {
        let result = SpikeNetConfig::builder().n_hidden(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn zero_outputs_rejected() {
        let result = SpikeNetConfig::builder().n_outputs(0).build();
        assert!(result.is_err());
    }

    #[test]
    fn negative_eta_rejected() {
        let result = SpikeNetConfig::builder().learning_rate(-0.01).build();
        assert!(result.is_err());
    }

    #[test]
    fn learning_rate_bounds() {
        // Valid range is (0.0, 1.0]: zero is excluded, one is included.
        assert!(SpikeNetConfig::builder().learning_rate(0.0).build().is_err());
        assert!(SpikeNetConfig::builder().learning_rate(1.0).build().is_ok());
        assert!(SpikeNetConfig::builder().learning_rate(1.5).build().is_err());
    }

    #[test]
    fn alpha_out_of_range_rejected() {
        assert!(SpikeNetConfig::builder().alpha(1.5).build().is_err());
        assert!(SpikeNetConfig::builder().alpha(-0.1).build().is_err());
    }

    #[test]
    fn unit_interval_params_enforce_closed_range() {
        let setters: [(&str, F64Setter); 4] = [
            ("alpha", SpikeNetConfigBuilder::alpha),
            ("kappa", SpikeNetConfigBuilder::kappa),
            ("kappa_out", SpikeNetConfigBuilder::kappa_out),
            ("gamma", SpikeNetConfigBuilder::gamma),
        ];
        for (name, set) in setters {
            for ok in [0.0, 0.5, 1.0] {
                assert!(
                    set(SpikeNetConfig::builder(), ok).build().is_ok(),
                    "{name}={ok} should be accepted"
                );
            }
            for bad in [-0.01, 1.01] {
                assert!(
                    set(SpikeNetConfig::builder(), bad).build().is_err(),
                    "{name}={bad} should be rejected"
                );
            }
        }
    }

    #[test]
    fn strictly_positive_params_reject_zero_and_negative() {
        let setters: [(&str, F64Setter); 4] = [
            ("v_thr", SpikeNetConfigBuilder::v_thr),
            ("spike_threshold", SpikeNetConfigBuilder::spike_threshold),
            ("weight_init_range", SpikeNetConfigBuilder::weight_init_range),
            ("astrocyte_tau", SpikeNetConfigBuilder::astrocyte_tau),
        ];
        for (name, set) in setters {
            assert!(
                set(SpikeNetConfig::builder(), 0.25).build().is_ok(),
                "{name}=0.25 should be accepted"
            );
            for bad in [0.0, -1.0] {
                assert!(
                    set(SpikeNetConfig::builder(), bad).build().is_err(),
                    "{name}={bad} should be rejected"
                );
            }
        }
    }

    #[test]
    fn weight_range_too_large_rejected() {
        assert!(SpikeNetConfig::builder()
            .weight_init_range(2.0)
            .build()
            .is_err());
    }

    #[test]
    fn weight_range_upper_limit_is_inclusive() {
        assert!(SpikeNetConfig::builder()
            .weight_init_range(1.9)
            .build()
            .is_ok());
    }

    #[test]
    fn all_builder_methods_chain() {
        let config = SpikeNetConfig::builder()
            .n_hidden(16)
            .n_outputs(3)
            .alpha(0.85)
            .kappa(0.95)
            .kappa_out(0.8)
            .learning_rate(0.01)
            .v_thr(0.3)
            .gamma(0.5)
            .spike_threshold(0.1)
            .seed(999)
            .weight_init_range(0.2)
            .astrocyte(true)
            .astrocyte_tau(500.0)
            .learning_rule(LearningRule::PpProp)
            .build()
            .unwrap();
        assert_eq!(config.n_hidden, 16);
        assert_eq!(config.n_outputs, 3);
        assert!((config.alpha - 0.85).abs() < 1e-10);
        assert!((config.kappa - 0.95).abs() < 1e-10);
        assert!((config.kappa_out - 0.8).abs() < 1e-10);
        assert!((config.learning_rate - 0.01).abs() < 1e-10);
        assert!((config.v_thr - 0.3).abs() < 1e-10);
        assert!((config.gamma - 0.5).abs() < 1e-10);
        assert!((config.spike_threshold - 0.1).abs() < 1e-10);
        assert_eq!(config.seed, 999);
        assert!((config.weight_init_range - 0.2).abs() < 1e-10);
        assert!(config.astrocyte);
        assert!((config.astrocyte_tau - 500.0).abs() < 1e-10);
        assert_eq!(config.learning_rule, LearningRule::PpProp);
    }
}