// rust_mlp/serde_model.rs
1//! Model serialization/deserialization (feature: `serde`).
2//!
3//! This module defines a versioned, stable on-disk format for `Mlp`.
4//!
5//! Design notes:
6//! - We do NOT directly serialize internal `Mlp`/`Layer` structs, to keep the
7//!   file format stable even if internal representation changes.
8//! - All deserialization validates dimensions, parameter lengths, and that
9//!   all parameters are finite.
10
11#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14use crate::{Activation, Error, Layer, Mlp, Result};
15
16#[cfg(feature = "serde")]
17use std::path::Path;
18
/// Current on-disk format version; bump whenever the serialized layout changes
/// incompatibly. Checked by [`SerializedMlp::validate`].
pub const MODEL_FORMAT_VERSION: u32 = 1;
20
/// Versioned, on-disk representation of an `Mlp`.
///
/// Deliberately decoupled from the internal `Mlp`/`Layer` structs so the file
/// format stays stable even if the in-memory representation changes.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SerializedMlp {
    // Must equal MODEL_FORMAT_VERSION; rejected otherwise on load.
    pub format_version: u32,
    // Layers in forward order; adjacent dims must chain (checked in validate()).
    pub layers: Vec<SerializedLayer>,
}
27
/// On-disk representation of a single dense layer.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SerializedLayer {
    // Input width; must equal the previous layer's out_dim.
    pub in_dim: usize,
    // Output width; also the required length of `biases`.
    pub out_dim: usize,
    // Activation applied to this layer's output.
    pub activation: SerializedActivation,
    /// Row-major (out_dim, in_dim).
    pub weights: Vec<f32>,
    // One bias per output unit (length out_dim).
    pub biases: Vec<f32>,
}
38
/// On-disk activation tag.
///
/// Serialized as an internally tagged enum with snake_case variant names,
/// e.g. `{"kind":"tanh"}` or `{"kind":"leaky_relu","alpha":0.01}`.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "kind", rename_all = "snake_case"))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SerializedActivation {
    Tanh,
    Relu,
    LeakyRelu { alpha: f32 },
    Sigmoid,
    Identity,
}
49
50impl From<Activation> for SerializedActivation {
51    fn from(value: Activation) -> Self {
52        match value {
53            Activation::Tanh => SerializedActivation::Tanh,
54            Activation::ReLU => SerializedActivation::Relu,
55            Activation::LeakyReLU { alpha } => SerializedActivation::LeakyRelu { alpha },
56            Activation::Sigmoid => SerializedActivation::Sigmoid,
57            Activation::Identity => SerializedActivation::Identity,
58        }
59    }
60}
61
62impl SerializedActivation {
63    fn into_activation(self) -> Activation {
64        match self {
65            SerializedActivation::Tanh => Activation::Tanh,
66            SerializedActivation::Relu => Activation::ReLU,
67            SerializedActivation::LeakyRelu { alpha } => Activation::LeakyReLU { alpha },
68            SerializedActivation::Sigmoid => Activation::Sigmoid,
69            SerializedActivation::Identity => Activation::Identity,
70        }
71    }
72}
73
74impl SerializedMlp {
75    pub fn validate(&self) -> Result<()> {
76        if self.format_version != MODEL_FORMAT_VERSION {
77            return Err(Error::InvalidData(format!(
78                "unsupported model format_version {}; expected {}",
79                self.format_version, MODEL_FORMAT_VERSION
80            )));
81        }
82        if self.layers.is_empty() {
83            return Err(Error::InvalidData(
84                "serialized model must have at least one layer".to_owned(),
85            ));
86        }
87
88        for (i, layer) in self.layers.iter().enumerate() {
89            layer.validate()?;
90
91            if i > 0 {
92                let prev_out = self.layers[i - 1].out_dim;
93                if layer.in_dim != prev_out {
94                    return Err(Error::InvalidData(format!(
95                        "layer {i} in_dim {} does not match previous out_dim {}",
96                        layer.in_dim, prev_out
97                    )));
98                }
99            }
100        }
101
102        Ok(())
103    }
104}
105
106impl SerializedLayer {
107    fn validate(&self) -> Result<()> {
108        if self.in_dim == 0 || self.out_dim == 0 {
109            return Err(Error::InvalidData(format!(
110                "layer dims must be > 0, got in_dim={} out_dim={}",
111                self.in_dim, self.out_dim
112            )));
113        }
114
115        let expected_w = self
116            .in_dim
117            .checked_mul(self.out_dim)
118            .ok_or_else(|| Error::InvalidData("layer weight shape overflow".to_owned()))?;
119        if self.weights.len() != expected_w {
120            return Err(Error::InvalidData(format!(
121                "weights length {} does not match out_dim * in_dim ({} * {})",
122                self.weights.len(),
123                self.out_dim,
124                self.in_dim
125            )));
126        }
127        if self.biases.len() != self.out_dim {
128            return Err(Error::InvalidData(format!(
129                "biases length {} does not match out_dim {}",
130                self.biases.len(),
131                self.out_dim
132            )));
133        }
134
135        let act = self.activation.into_activation();
136        act.validate()
137            .map_err(|e| Error::InvalidData(format!("invalid activation: {e}")))?;
138
139        if self.weights.iter().any(|v| !v.is_finite()) {
140            return Err(Error::InvalidData(
141                "weights must contain only finite values".to_owned(),
142            ));
143        }
144        if self.biases.iter().any(|v| !v.is_finite()) {
145            return Err(Error::InvalidData(
146                "biases must contain only finite values".to_owned(),
147            ));
148        }
149
150        Ok(())
151    }
152}
153
154impl From<&Mlp> for SerializedMlp {
155    fn from(model: &Mlp) -> Self {
156        let mut layers = Vec::with_capacity(model.num_layers());
157        for i in 0..model.num_layers() {
158            let layer = model.layer(i).expect("layer idx must be valid");
159            layers.push(SerializedLayer::from(layer));
160        }
161        Self {
162            format_version: MODEL_FORMAT_VERSION,
163            layers,
164        }
165    }
166}
167
168impl From<&Layer> for SerializedLayer {
169    fn from(layer: &Layer) -> Self {
170        Self {
171            in_dim: layer.in_dim(),
172            out_dim: layer.out_dim(),
173            activation: SerializedActivation::from(layer.activation()),
174            weights: layer.weights().to_vec(),
175            biases: layer.biases().to_vec(),
176        }
177    }
178}
179
180impl TryFrom<SerializedMlp> for Mlp {
181    type Error = Error;
182
183    fn try_from(value: SerializedMlp) -> std::result::Result<Self, Self::Error> {
184        value.validate()?;
185
186        let mut layers = Vec::with_capacity(value.layers.len());
187        for (i, layer) in value.layers.into_iter().enumerate() {
188            let act = layer.activation.into_activation();
189
190            // Layer::from_parts performs shape validation and finiteness checks.
191            let l = Layer::from_parts(
192                layer.in_dim,
193                layer.out_dim,
194                act,
195                layer.weights,
196                layer.biases,
197            )
198            .map_err(|e| Error::InvalidData(format!("layer {i} invalid: {e}")))?;
199            layers.push(l);
200        }
201
202        Ok(Mlp::from_layers(layers))
203    }
204}
205
#[cfg(feature = "serde")]
impl Mlp {
    /// Serialize the model to a pretty-printed JSON string.
    ///
    /// # Errors
    /// Returns [`Error::InvalidData`] if serialization fails (not expected for
    /// a well-formed model).
    pub fn to_json_string_pretty(&self) -> Result<String> {
        let ser = SerializedMlp::from(self);
        serde_json::to_string_pretty(&ser)
            .map_err(|e| Error::InvalidData(format!("failed to serialize model: {e}")))
    }

    /// Serialize the model to a compact JSON string.
    ///
    /// # Errors
    /// Returns [`Error::InvalidData`] if serialization fails.
    pub fn to_json_string(&self) -> Result<String> {
        let ser = SerializedMlp::from(self);
        serde_json::to_string(&ser)
            .map_err(|e| Error::InvalidData(format!("failed to serialize model: {e}")))
    }

    /// Parse a model from a JSON string.
    ///
    /// # Errors
    /// Returns [`Error::InvalidData`] if the JSON is malformed or the decoded
    /// model fails validation (version, shapes, finiteness).
    pub fn from_json_str(s: &str) -> Result<Self> {
        let ser: SerializedMlp = serde_json::from_str(s)
            .map_err(|e| Error::InvalidData(format!("failed to parse model json: {e}")))?;
        ser.try_into()
    }

    /// Save the model to a JSON file (pretty-printed).
    ///
    /// # Errors
    /// Returns [`Error::InvalidData`] if serialization or the filesystem write
    /// fails; the error message includes the target path.
    pub fn save_json<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let mut s = self.to_json_string_pretty()?;
        // End the file with a newline so it is a well-formed POSIX text file;
        // serde_json ignores trailing whitespace on load.
        s.push('\n');
        let p = path.as_ref();
        std::fs::write(p, s)
            .map_err(|e| Error::InvalidData(format!("failed to write {}: {e}", p.display())))
    }

    /// Load a model from a JSON file.
    ///
    /// # Errors
    /// Returns [`Error::InvalidData`] if the file cannot be read, parsed, or
    /// validated; the error message includes the source path for read failures.
    pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Self> {
        let p = path.as_ref();
        let s = std::fs::read_to_string(p)
            .map_err(|e| Error::InvalidData(format!("failed to read {}: {e}", p.display())))?;
        Self::from_json_str(&s)
    }
}
246
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;

    /// The pretty JSON must match the checked-in golden file byte-for-byte
    /// (modulo trailing whitespace) and survive a load/save round trip.
    /// This pins the on-disk format against accidental changes.
    #[test]
    fn golden_json_is_stable_and_roundtrips() {
        let l1 = Layer::from_parts(
            2,
            3,
            Activation::Tanh,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![0.1, 0.2, 0.3],
        )
        .unwrap();
        let l2 =
            Layer::from_parts(3, 1, Activation::Identity, vec![7.0, 8.0, 9.0], vec![0.4]).unwrap();

        let mlp = Mlp::from_layers(vec![l1, l2]);
        let json = mlp.to_json_string_pretty().unwrap();

        // trim_end tolerates a trailing newline in the golden fixture.
        let golden = include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/golden/mlp_v1.json"
        ))
        .trim_end();
        assert_eq!(json, golden);

        // Round-trip via JSON.
        let loaded = Mlp::from_json_str(golden).unwrap();
        let json2 = loaded.to_json_string_pretty().unwrap();
        assert_eq!(json2, golden);
    }

    /// An unknown format_version must be rejected with a message naming it,
    /// before any layer-level validation runs.
    #[test]
    fn rejects_unknown_version() {
        let bad = r#"{"format_version":999,"layers":[]}"#;
        let err = Mlp::from_json_str(bad).unwrap_err();
        assert!(format!("{err}").contains("format_version"));
    }
}