1#[cfg(feature = "serde")]
12use serde::{Deserialize, Serialize};
13
14use crate::{Activation, Error, Layer, Mlp, Result};
15
16#[cfg(feature = "serde")]
17use std::path::Path;
18
/// Version tag of the serialized model layout handled by this module.
///
/// It is written into `SerializedMlp::format_version` whenever an `Mlp` is
/// converted for serialization, and it is the only value that
/// `SerializedMlp::validate` accepts when a model is loaded back.
pub const MODEL_FORMAT_VERSION: u32 = 1;
20
/// Plain-data snapshot of a whole model, used as the (de)serialization form of `Mlp`.
///
/// `Serialize`/`Deserialize` are derived only when the `serde` feature is enabled.
/// A value of this type is not guaranteed to be consistent on its own; call
/// `validate` (or convert with `Mlp::try_from`, which validates first) before
/// trusting its contents.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SerializedMlp {
    /// Layout version of this snapshot; validation rejects any value other than
    /// `MODEL_FORMAT_VERSION`.
    pub format_version: u32,
    /// Layers in order. Validation requires at least one layer and that every
    /// layer's `in_dim` equals the `out_dim` of the layer before it.
    pub layers: Vec<SerializedLayer>,
}
27
/// Plain-data snapshot of a single layer: its shape, activation and parameters.
///
/// `Serialize`/`Deserialize` are derived only when the `serde` feature is enabled.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SerializedLayer {
    /// Input width of the layer; validation requires it to be non-zero.
    pub in_dim: usize,
    /// Output width of the layer; validation requires it to be non-zero.
    pub out_dim: usize,
    /// Activation attached to this layer.
    pub activation: SerializedActivation,
    /// Flat weight buffer; validation requires exactly `in_dim * out_dim` finite values.
    pub weights: Vec<f32>,
    /// Bias buffer; validation requires exactly `out_dim` finite values.
    pub biases: Vec<f32>,
}
38
/// Serializable mirror of the runtime `Activation` enum.
///
/// With the `serde` feature enabled the enum is internally tagged: the variant
/// name is stored in snake_case under the `kind` key (for example
/// `{"kind": "leaky_relu", "alpha": ...}`), and variant fields sit next to it.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(tag = "kind", rename_all = "snake_case"))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SerializedActivation {
    /// Counterpart of `Activation::Tanh`.
    Tanh,
    /// Counterpart of `Activation::ReLU`.
    Relu,
    /// Counterpart of `Activation::LeakyReLU`; `alpha` is carried over unchanged.
    LeakyRelu { alpha: f32 },
    /// Counterpart of `Activation::Sigmoid`.
    Sigmoid,
    /// Counterpart of `Activation::Identity`.
    Identity,
}
49
50impl From<Activation> for SerializedActivation {
51 fn from(value: Activation) -> Self {
52 match value {
53 Activation::Tanh => SerializedActivation::Tanh,
54 Activation::ReLU => SerializedActivation::Relu,
55 Activation::LeakyReLU { alpha } => SerializedActivation::LeakyRelu { alpha },
56 Activation::Sigmoid => SerializedActivation::Sigmoid,
57 Activation::Identity => SerializedActivation::Identity,
58 }
59 }
60}
61
62impl SerializedActivation {
63 fn into_activation(self) -> Activation {
64 match self {
65 SerializedActivation::Tanh => Activation::Tanh,
66 SerializedActivation::Relu => Activation::ReLU,
67 SerializedActivation::LeakyRelu { alpha } => Activation::LeakyReLU { alpha },
68 SerializedActivation::Sigmoid => Activation::Sigmoid,
69 SerializedActivation::Identity => Activation::Identity,
70 }
71 }
72}
73
74impl SerializedMlp {
75 pub fn validate(&self) -> Result<()> {
76 if self.format_version != MODEL_FORMAT_VERSION {
77 return Err(Error::InvalidData(format!(
78 "unsupported model format_version {}; expected {}",
79 self.format_version, MODEL_FORMAT_VERSION
80 )));
81 }
82 if self.layers.is_empty() {
83 return Err(Error::InvalidData(
84 "serialized model must have at least one layer".to_owned(),
85 ));
86 }
87
88 for (i, layer) in self.layers.iter().enumerate() {
89 layer.validate()?;
90
91 if i > 0 {
92 let prev_out = self.layers[i - 1].out_dim;
93 if layer.in_dim != prev_out {
94 return Err(Error::InvalidData(format!(
95 "layer {i} in_dim {} does not match previous out_dim {}",
96 layer.in_dim, prev_out
97 )));
98 }
99 }
100 }
101
102 Ok(())
103 }
104}
105
106impl SerializedLayer {
107 fn validate(&self) -> Result<()> {
108 if self.in_dim == 0 || self.out_dim == 0 {
109 return Err(Error::InvalidData(format!(
110 "layer dims must be > 0, got in_dim={} out_dim={}",
111 self.in_dim, self.out_dim
112 )));
113 }
114
115 let expected_w = self
116 .in_dim
117 .checked_mul(self.out_dim)
118 .ok_or_else(|| Error::InvalidData("layer weight shape overflow".to_owned()))?;
119 if self.weights.len() != expected_w {
120 return Err(Error::InvalidData(format!(
121 "weights length {} does not match out_dim * in_dim ({} * {})",
122 self.weights.len(),
123 self.out_dim,
124 self.in_dim
125 )));
126 }
127 if self.biases.len() != self.out_dim {
128 return Err(Error::InvalidData(format!(
129 "biases length {} does not match out_dim {}",
130 self.biases.len(),
131 self.out_dim
132 )));
133 }
134
135 let act = self.activation.into_activation();
136 act.validate()
137 .map_err(|e| Error::InvalidData(format!("invalid activation: {e}")))?;
138
139 if self.weights.iter().any(|v| !v.is_finite()) {
140 return Err(Error::InvalidData(
141 "weights must contain only finite values".to_owned(),
142 ));
143 }
144 if self.biases.iter().any(|v| !v.is_finite()) {
145 return Err(Error::InvalidData(
146 "biases must contain only finite values".to_owned(),
147 ));
148 }
149
150 Ok(())
151 }
152}
153
154impl From<&Mlp> for SerializedMlp {
155 fn from(model: &Mlp) -> Self {
156 let mut layers = Vec::with_capacity(model.num_layers());
157 for i in 0..model.num_layers() {
158 let layer = model.layer(i).expect("layer idx must be valid");
159 layers.push(SerializedLayer::from(layer));
160 }
161 Self {
162 format_version: MODEL_FORMAT_VERSION,
163 layers,
164 }
165 }
166}
167
168impl From<&Layer> for SerializedLayer {
169 fn from(layer: &Layer) -> Self {
170 Self {
171 in_dim: layer.in_dim(),
172 out_dim: layer.out_dim(),
173 activation: SerializedActivation::from(layer.activation()),
174 weights: layer.weights().to_vec(),
175 biases: layer.biases().to_vec(),
176 }
177 }
178}
179
180impl TryFrom<SerializedMlp> for Mlp {
181 type Error = Error;
182
183 fn try_from(value: SerializedMlp) -> std::result::Result<Self, Self::Error> {
184 value.validate()?;
185
186 let mut layers = Vec::with_capacity(value.layers.len());
187 for (i, layer) in value.layers.into_iter().enumerate() {
188 let act = layer.activation.into_activation();
189
190 let l = Layer::from_parts(
192 layer.in_dim,
193 layer.out_dim,
194 act,
195 layer.weights,
196 layer.biases,
197 )
198 .map_err(|e| Error::InvalidData(format!("layer {i} invalid: {e}")))?;
199 layers.push(l);
200 }
201
202 Ok(Mlp::from_layers(layers))
203 }
204}
205
#[cfg(feature = "serde")]
impl Mlp {
    /// Serializes the model to a pretty-printed JSON string.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidData` if `serde_json` fails to serialize the snapshot.
    pub fn to_json_string_pretty(&self) -> Result<String> {
        serde_json::to_string_pretty(&SerializedMlp::from(self))
            .map_err(|e| Error::InvalidData(format!("failed to serialize model: {e}")))
    }

    /// Serializes the model to a compact JSON string.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidData` if `serde_json` fails to serialize the snapshot.
    pub fn to_json_string(&self) -> Result<String> {
        serde_json::to_string(&SerializedMlp::from(self))
            .map_err(|e| Error::InvalidData(format!("failed to serialize model: {e}")))
    }

    /// Parses a model from a JSON string and validates it while converting.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidData` if the text is not a parseable snapshot, and
    /// passes on whatever error the snapshot-to-model conversion reports.
    pub fn from_json_str(s: &str) -> Result<Self> {
        serde_json::from_str::<SerializedMlp>(s)
            .map_err(|e| Error::InvalidData(format!("failed to parse model json: {e}")))
            .and_then(|snapshot| snapshot.try_into())
    }

    /// Writes the model as pretty-printed JSON to `path`.
    ///
    /// The model is serialized before the file is touched, so a serialization
    /// failure does not write anything.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidData` if serialization or the file write fails;
    /// write failures mention the target path.
    pub fn save_json<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = self.to_json_string_pretty()?;
        let target = path.as_ref();
        std::fs::write(target, json).map_err(|e| {
            Error::InvalidData(format!("failed to write {}: {e}", target.display()))
        })
    }

    /// Reads the file at `path` as text and parses it with `from_json_str`.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidData` if the file cannot be read (the message
    /// mentions the path), or whatever `from_json_str` reports for its contents.
    pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Self> {
        let source = path.as_ref();
        match std::fs::read_to_string(source) {
            Ok(text) => Self::from_json_str(&text),
            Err(e) => Err(Error::InvalidData(format!(
                "failed to read {}: {e}",
                source.display()
            ))),
        }
    }
}
246
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;

    /// A fixed two-layer model must serialize to exactly the checked-in golden
    /// file (ignoring trailing whitespace in that file), and loading the golden
    /// text and serializing again must reproduce it.
    #[test]
    fn golden_json_is_stable_and_roundtrips() {
        let first = Layer::from_parts(
            2,
            3,
            Activation::Tanh,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![0.1, 0.2, 0.3],
        )
        .unwrap();
        let second =
            Layer::from_parts(3, 1, Activation::Identity, vec![7.0, 8.0, 9.0], vec![0.4]).unwrap();

        let serialized = Mlp::from_layers(vec![first, second])
            .to_json_string_pretty()
            .unwrap();

        let golden = include_str!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/tests/golden/mlp_v1.json"
        ))
        .trim_end();
        assert_eq!(serialized, golden);

        let reserialized = Mlp::from_json_str(golden)
            .unwrap()
            .to_json_string_pretty()
            .unwrap();
        assert_eq!(reserialized, golden);
    }

    /// A snapshot with an unsupported `format_version` must be rejected with a
    /// message that names the field.
    #[test]
    fn rejects_unknown_version() {
        let bad = r#"{"format_version":999,"layers":[]}"#;
        let err = Mlp::from_json_str(bad).unwrap_err();
        assert!(err.to_string().contains("format_version"));
    }
}