irithyll/
serde_support.rs1use crate::error::{IrithyllError, Result};
8use serde::{Deserialize, Serialize};
9
#[cfg(feature = "serde-json")]
/// Encodes `value` as a JSON string using `serde_json::to_string`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] holding the encoder's error
/// message if `value` cannot be represented as JSON.
pub fn to_json<T: serde::Serialize>(value: &T) -> Result<String> {
    match serde_json::to_string(value) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
21
#[cfg(feature = "serde-json")]
/// Encodes `value` as human-readable (indented, multi-line) JSON using
/// `serde_json::to_string_pretty`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] holding the encoder's error
/// message if `value` cannot be represented as JSON.
pub fn to_json_pretty<T: serde::Serialize>(value: &T) -> Result<String> {
    // Flatten the serde_json error into the crate's string-carrying variant.
    let as_crate_error = |err: serde_json::Error| IrithyllError::Serialization(err.to_string());
    serde_json::to_string_pretty(value).map_err(as_crate_error)
}
33
#[cfg(feature = "serde-json")]
/// Decodes a value of type `T` from the JSON text in `json`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] holding the parser's error
/// message if `json` is malformed or does not match `T`. Note that decode
/// failures are reported through the same `Serialization` variant as encode
/// failures.
pub fn from_json<T: serde::de::DeserializeOwned>(json: &str) -> Result<T> {
    let parsed: std::result::Result<T, serde_json::Error> = serde_json::from_str(json);
    parsed.map_err(|err| IrithyllError::Serialization(err.to_string()))
}
45
#[cfg(feature = "serde-json")]
/// Encodes `value` as JSON and returns the raw bytes, using
/// `serde_json::to_vec`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] holding the encoder's error
/// message if `value` cannot be represented as JSON.
pub fn to_json_bytes<T: serde::Serialize>(value: &T) -> Result<Vec<u8>> {
    match serde_json::to_vec(value) {
        Ok(buffer) => Ok(buffer),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
57
#[cfg(feature = "serde-json")]
/// Decodes a value of type `T` from a byte slice containing JSON.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] holding the parser's error
/// message if `bytes` is not valid JSON for `T`.
pub fn from_json_bytes<T: serde::de::DeserializeOwned>(bytes: &[u8]) -> Result<T> {
    let parsed: std::result::Result<T, serde_json::Error> = serde_json::from_slice(bytes);
    parsed.map_err(|err| IrithyllError::Serialization(err.to_string()))
}
69
70#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
75use crate::ensemble::config::SGBTConfig;
76#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
77use crate::ensemble::SGBT;
78#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
79use crate::loss::Loss;
80
81#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
83pub use crate::loss::LossType;
84
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
/// Serializable snapshot of one tree, stored as flat vectors.
///
/// NOTE(review): the per-field descriptions below are inferred from field
/// names and types only — confirm against the code that fills this struct.
///
/// Keep the field order stable: positional encodings such as bincode write
/// struct fields in declaration order, so reordering changes the wire format.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TreeSnapshot {
    /// Presumably the split feature index for each node.
    pub feature_idx: Vec<u32>,
    /// Presumably the split threshold for each node.
    pub threshold: Vec<f64>,
    /// Presumably the left-child index for each node.
    pub left: Vec<u32>,
    /// Presumably the right-child index for each node.
    pub right: Vec<u32>,
    /// Presumably the prediction value stored at each node.
    pub leaf_value: Vec<f64>,
    /// Presumably whether each node is a leaf.
    pub is_leaf: Vec<bool>,
    /// Presumably the depth of each node.
    pub depth: Vec<u16>,
    /// Presumably the number of samples routed through each node.
    pub sample_count: Vec<u64>,
    /// Feature count, if recorded; `None` otherwise.
    pub n_features: Option<usize>,
    /// Sample counter captured with the tree.
    pub samples_seen: u64,
    /// Captured random-number-generator state.
    pub rng_state: u64,
    /// Optional per-entry 64-bit masks (presumably for categorical splits).
    ///
    /// `#[serde(default)]`: deserializes to an empty vector when the field is
    /// absent from the input.
    #[serde(default)]
    pub categorical_mask: Vec<Option<u64>>,
}
108
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
/// Serializable snapshot of one ensemble step: its tree, an optional
/// alternate tree, and optional drift-detector state for each.
///
/// Keep the field order stable: positional encodings such as bincode write
/// struct fields in declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StepSnapshot {
    /// Snapshot of the step's primary tree.
    pub tree: TreeSnapshot,
    /// Snapshot of the alternate tree, when one exists.
    pub alternate_tree: Option<TreeSnapshot>,
    /// Drift-detector state; deserializes to `None` when the field is absent
    /// from the input (`#[serde(default)]`).
    #[serde(default)]
    pub drift_state: Option<crate::drift::state::DriftDetectorState>,
    /// Second drift-detector state (presumably the alternate tree's — confirm);
    /// deserializes to `None` when absent from the input.
    #[serde(default)]
    pub alt_drift_state: Option<crate::drift::state::DriftDetectorState>,
}
122
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
/// Serializable state of an `SGBT` model.
///
/// This is the type that `load_model` / `load_model_bincode` decode and hand
/// to `SGBT::from_model_state`.
///
/// Fields marked `#[serde(default)]` fall back to their `Default` value when
/// missing from the input. Keep the field order stable: positional encodings
/// such as bincode write struct fields in declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelState {
    /// Configuration the model was built with.
    pub config: SGBTConfig,
    /// Tag identifying the loss used by the model.
    pub loss_type: LossType,
    /// Base prediction value.
    pub base_prediction: f64,
    /// Whether the base prediction has been initialized.
    pub base_initialized: bool,
    /// Buffered initial targets (presumably used to initialize the base
    /// prediction — confirm).
    pub initial_targets: Vec<f64>,
    /// Count associated with `initial_targets`.
    pub initial_target_count: usize,
    /// Sample counter captured with the model.
    pub samples_seen: u64,
    /// Captured random-number-generator state.
    pub rng_state: u64,
    /// One snapshot per ensemble step.
    pub steps: Vec<StepSnapshot>,
    /// Rolling mean error; `0.0` when absent from the input.
    #[serde(default)]
    pub rolling_mean_error: f64,
    /// Contribution EWMA values (presumably one per step — confirm); empty
    /// when absent from the input.
    #[serde(default)]
    pub contribution_ewma: Vec<f64>,
    /// Low-contribution counters (presumably one per step — confirm); empty
    /// when absent from the input.
    #[serde(default)]
    pub low_contrib_count: Vec<u64>,
}
149
#[cfg(feature = "serde-json")]
/// Serializes `model` to pretty-printed JSON.
///
/// The state is obtained from `model.to_model_state()` and encoded with
/// [`to_json_pretty`].
///
/// # Errors
///
/// Propagates any error returned by `to_model_state`, and returns
/// [`IrithyllError::Serialization`] if JSON encoding fails.
pub fn save_model<L: Loss>(model: &SGBT<L>) -> Result<String> {
    to_json_pretty(&model.to_model_state()?)
}
165
#[cfg(feature = "serde-json")]
/// Serializes `model` to pretty-printed JSON, with the caller supplying the
/// `loss_type` tag.
///
/// The state is obtained from `model.to_model_state_with(loss_type)` and
/// encoded with [`to_json_pretty`].
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if JSON encoding fails.
pub fn save_model_with<L: Loss>(model: &SGBT<L>, loss_type: LossType) -> Result<String> {
    to_json_pretty(&model.to_model_state_with(loss_type))
}
174
#[cfg(feature = "serde-json")]
/// Rebuilds a model from JSON text.
///
/// Parses `json` into a [`ModelState`] and passes it to
/// `SGBT::from_model_state`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `json` cannot be parsed as a
/// [`ModelState`].
pub fn load_model(json: &str) -> Result<crate::ensemble::DynSGBT> {
    let restored = from_json::<ModelState>(json)?;
    Ok(SGBT::from_model_state(restored))
}
189
#[cfg(feature = "serde-bincode")]
/// Encodes `value` to bytes with bincode's serde bridge, using
/// `bincode::config::standard()`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] holding the encoder's error
/// message if encoding fails.
pub fn to_bincode<T: serde::Serialize>(value: &T) -> Result<Vec<u8>> {
    let config = bincode::config::standard();
    match bincode::serde::encode_to_vec(value, config) {
        Ok(buffer) => Ok(buffer),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
206
#[cfg(feature = "serde-bincode")]
/// Decodes a value of type `T` from bincode bytes, using
/// `bincode::config::standard()`.
///
/// Only the decoded value is returned; the second element of the decoder's
/// result (the consumed-byte count) is discarded, so bytes following a
/// complete value in `bytes` are not reported as an error.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] holding the decoder's error
/// message if `bytes` cannot be decoded as `T`.
pub fn from_bincode<T: serde::de::DeserializeOwned>(bytes: &[u8]) -> Result<T> {
    let config = bincode::config::standard();
    match bincode::serde::decode_from_slice(bytes, config) {
        Ok((decoded, _consumed)) => Ok(decoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
220
#[cfg(feature = "serde-bincode")]
/// Serializes `model` to bincode bytes.
///
/// The state is obtained from `model.to_model_state()` and encoded with
/// [`to_bincode`].
///
/// # Errors
///
/// Propagates any error returned by `to_model_state`, and returns
/// [`IrithyllError::Serialization`] if bincode encoding fails.
pub fn save_model_bincode<L: Loss>(model: &SGBT<L>) -> Result<Vec<u8>> {
    to_bincode(&model.to_model_state()?)
}
234
#[cfg(feature = "serde-bincode")]
/// Rebuilds a model from bincode bytes.
///
/// Decodes `bytes` into a [`ModelState`] and passes it to
/// `SGBT::from_model_state`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `bytes` cannot be decoded as a
/// [`ModelState`].
pub fn load_model_bincode(bytes: &[u8]) -> Result<crate::ensemble::DynSGBT> {
    let restored = from_bincode::<ModelState>(bytes)?;
    Ok(SGBT::from_model_state(restored))
}
248
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
/// Serializable state of a distributional model, holding separate step lists
/// and base values for location and scale.
///
/// NOTE(review): field descriptions are inferred from names and types —
/// confirm against the code that fills this struct.
///
/// No field carries `#[serde(default)]`, so every field must be present in
/// the input. Keep the field order stable: positional encodings such as
/// bincode write struct fields in declaration order.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct DistributionalModelState {
    /// Configuration the model was built with.
    pub config: SGBTConfig,
    /// Step snapshots for the location ensemble.
    pub location_steps: Vec<StepSnapshot>,
    /// Step snapshots for the scale ensemble.
    pub scale_steps: Vec<StepSnapshot>,
    /// Base value for the location output.
    pub location_base: f64,
    /// Base value for the scale output.
    pub scale_base: f64,
    /// Whether the base values have been initialized.
    pub base_initialized: bool,
    /// Buffered initial targets (presumably used to initialize the base
    /// values — confirm).
    pub initial_targets: Vec<f64>,
    /// Count associated with `initial_targets`.
    pub initial_target_count: usize,
    /// Sample counter captured with the model.
    pub samples_seen: u64,
    /// Captured random-number-generator state.
    pub rng_state: u64,
    /// Flag for uncertainty-modulated learning rate.
    pub uncertainty_modulated_lr: bool,
    /// Rolling mean of sigma.
    pub rolling_sigma_mean: f64,
}
270
#[cfg(feature = "serde-json")]
/// Serializes a [`DistributionalModelState`] to pretty-printed JSON.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if JSON encoding fails.
pub fn save_distributional_model(state: &DistributionalModelState) -> Result<String> {
    to_json_pretty::<DistributionalModelState>(state)
}
276
#[cfg(feature = "serde-json")]
/// Parses a [`DistributionalModelState`] from JSON text.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `json` cannot be parsed as a
/// [`DistributionalModelState`].
pub fn load_distributional_model(json: &str) -> Result<DistributionalModelState> {
    from_json::<DistributionalModelState>(json)
}
282
#[cfg(feature = "serde-bincode")]
/// Serializes a [`DistributionalModelState`] to bincode bytes.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if bincode encoding fails.
pub fn save_distributional_model_bincode(state: &DistributionalModelState) -> Result<Vec<u8>> {
    to_bincode::<DistributionalModelState>(state)
}
288
#[cfg(feature = "serde-bincode")]
/// Decodes a [`DistributionalModelState`] from bincode bytes.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `bytes` cannot be decoded as a
/// [`DistributionalModelState`].
pub fn load_distributional_model_bincode(bytes: &[u8]) -> Result<DistributionalModelState> {
    from_bincode::<DistributionalModelState>(bytes)
}
294
// Every test in this module exercises the JSON helpers, so the module as a
// whole is gated on `serde-json`. Gating only the individual tests would leave
// the `use` lines below unused (and warned about) when tests are built without
// that feature.
#[cfg(all(test, feature = "serde-json"))]
mod tests {
    use super::*;
    use crate::sample::Sample;

    /// A sample survives a compact-JSON round trip with features and target intact.
    #[test]
    fn json_round_trip_sample() {
        let sample = Sample::new(vec![1.0, 2.0, 3.0], 4.0);
        let json = to_json(&sample).unwrap();
        let restored: Sample = from_json(&json).unwrap();
        assert_eq!(restored.features, sample.features);
        assert!((restored.target - sample.target).abs() < f64::EPSILON);
    }

    /// Pretty output is multi-line and still parses back, preserving the weight.
    #[test]
    fn json_pretty_round_trip() {
        let sample = Sample::weighted(vec![1.0], 2.0, 0.5);
        let json = to_json_pretty(&sample).unwrap();
        assert!(json.contains('\n'));
        let restored: Sample = from_json(&json).unwrap();
        assert!((restored.weight - 0.5).abs() < f64::EPSILON);
    }

    /// The byte-oriented helpers round-trip both features and target.
    #[test]
    fn json_bytes_round_trip() {
        let sample = Sample::new(vec![10.0, 20.0], 30.0);
        let bytes = to_json_bytes(&sample).unwrap();
        let restored: Sample = from_json_bytes(&bytes).unwrap();
        assert_eq!(restored.features, sample.features);
        assert!((restored.target - sample.target).abs() < f64::EPSILON);
    }

    /// Malformed input is reported as a non-empty `Serialization` error.
    #[test]
    fn json_invalid_input_returns_error() {
        let result = from_json::<Sample>("not valid json");
        assert!(result.is_err());
        match result.unwrap_err() {
            IrithyllError::Serialization(msg) => {
                assert!(!msg.is_empty());
            }
            other => panic!("expected Serialization error, got {:?}", other),
        }
    }

    /// A batch of samples round-trips with length, order, and contents preserved.
    #[test]
    fn json_batch_samples() {
        let samples = vec![Sample::new(vec![1.0], 2.0), Sample::new(vec![3.0], 4.0)];
        let json = to_json(&samples).unwrap();
        let restored: Vec<Sample> = from_json(&json).unwrap();
        assert_eq!(restored.len(), 2);
        for (got, want) in restored.iter().zip(samples.iter()) {
            assert_eq!(got.features, want.features);
            assert!((got.target - want.target).abs() < f64::EPSILON);
        }
    }
}