use eenn::FunctionRegistry;
use eenn::models::{NeuronDef, StageDef};
/// Round-trips a `NeuronDef` through `to_bytes`/`from_bytes`, rehydrates it
/// against a `FunctionRegistry`, and checks the rebuilt neuron's output.
///
/// The definition uses one stage of each visible kind:
/// - `StageDef::Named("relu")`, resolved through `register_fn`;
/// - `StageDef::Factory { name: "linear", params }`, built through
///   `register_factory`;
/// - `StageDef::Bias { b: 0.5 }`, passed as the second `NeuronDef::new` argument.
#[test]
fn neuron_serialization_roundtrip_and_rehydrate() {
    let mut reg = FunctionRegistry::empty();
    // NOTE(review): the third argument ("ReLU") looks like a display name/label
    // for the function; confirm against `FunctionRegistry::register_fn`.
    reg.register_fn("relu", eenn::relu, "ReLU");
    // Factory for an affine map `w * x + b`. A missing weight defaults to 1.0
    // and a missing offset defaults to 0.0, so `[]` yields the identity.
    reg.register_factory("linear", |params: &[f32]| {
        let w = params.first().copied().unwrap_or(1.0);
        let b = params.get(1).copied().unwrap_or(0.0);
        std::sync::Arc::new(move |x: f32| w * x + b)
    });

    let def = NeuronDef::new(
        vec![
            StageDef::Named("relu".to_string()),
            StageDef::Factory {
                name: "linear".to_string(),
                params: vec![2.0, 1.0],
            },
        ],
        StageDef::Bias { b: 0.5 },
    );

    let bytes = def.to_bytes().expect("serialize");
    let def2 = NeuronDef::from_bytes(&bytes).expect("deserialize");
    // Rehydration resolves the named/factory stages by looking them up in `reg`.
    let neuron = def2.to_neuron(&reg).expect("rehydrate");

    // Expected derivation, assuming stages run in order and the bias adds `b`
    // (TODO confirm against `NeuronDef::to_neuron`):
    // relu(-1.0) = 0.0 -> linear: 2.0 * 0.0 + 1.0 = 1.0 -> bias: 1.0 + 0.5 = 1.5
    let expected = 1.5_f32;
    let out = neuron.forward(-1.0);
    assert!(
        (out - expected).abs() < 1e-6,
        "unexpected output: {} (expected {})",
        out,
        expected
    );
}