Skip to main content

irithyll/
serde_support.rs

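//! Model serialization and deserialization support.
//!
//! Provides JSON serialization (the `serde-json` feature, enabled by default)
//! and optional compact bincode serialization (the `serde-bincode` feature)
//! for model persistence. The module exposes generic helpers for any
//! `Serialize`/`Deserialize` type, snapshot types capturing SGBT and
//! distributional SGBT model state, and `save_*`/`load_*` functions for
//! checkpointing and restoring models.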
6
7use crate::error::{IrithyllError, Result};
8use serde::{Deserialize, Serialize};
9
/// Encode `value` as a compact JSON string.
///
/// Only available with the `serde-json` feature (on by default).
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] carrying the `serde_json` error
/// message when `value` cannot be represented as JSON.
#[cfg(feature = "serde-json")]
pub fn to_json<T>(value: &T) -> Result<String>
where
    T: serde::Serialize,
{
    match serde_json::to_string(value) {
        Ok(json) => Ok(json),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
21
/// Encode `value` as human-readable, indented JSON.
///
/// Only available with the `serde-json` feature (on by default).
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] carrying the `serde_json` error
/// message when `value` cannot be represented as JSON.
#[cfg(feature = "serde-json")]
pub fn to_json_pretty<T: serde::Serialize>(value: &T) -> Result<String> {
    let rendered = serde_json::to_string_pretty(value)
        .map_err(|source| IrithyllError::Serialization(source.to_string()))?;
    Ok(rendered)
}
33
/// Decode a value of type `T` from a JSON string.
///
/// Only available with the `serde-json` feature (on by default).
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] carrying the `serde_json` error
/// message when `json` is malformed or does not match the shape of `T`.
#[cfg(feature = "serde-json")]
pub fn from_json<T>(json: &str) -> Result<T>
where
    T: serde::de::DeserializeOwned,
{
    serde_json::from_str::<T>(json)
        .map_err(|parse_err| IrithyllError::Serialization(parse_err.to_string()))
}
45
/// Encode `value` as compact JSON, returned as a UTF-8 byte buffer.
///
/// Only available with the `serde-json` feature (on by default).
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] carrying the `serde_json` error
/// message when `value` cannot be represented as JSON.
#[cfg(feature = "serde-json")]
pub fn to_json_bytes<T>(value: &T) -> Result<Vec<u8>>
where
    T: serde::Serialize,
{
    match serde_json::to_vec(value) {
        Ok(buffer) => Ok(buffer),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
57
/// Decode a value of type `T` from a JSON byte slice.
///
/// Only available with the `serde-json` feature (on by default).
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] carrying the `serde_json` error
/// message when `bytes` is not valid JSON for `T`.
#[cfg(feature = "serde-json")]
pub fn from_json_bytes<T: serde::de::DeserializeOwned>(bytes: &[u8]) -> Result<T> {
    let decoded = serde_json::from_slice::<T>(bytes)
        .map_err(|err| IrithyllError::Serialization(err.to_string()))?;
    Ok(decoded)
}
69
70// ---------------------------------------------------------------------------
71// Model checkpoint/restore serialization
72// ---------------------------------------------------------------------------
73
74#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
75use crate::ensemble::config::SGBTConfig;
76#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
77use crate::ensemble::SGBT;
78#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
79use crate::loss::Loss;
80
81// Re-export LossType from the loss module for backwards compatibility.
82#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
83pub use crate::loss::LossType;
84
/// Serializable snapshot of the tree arena structure.
///
/// Captures the minimal state needed to reconstruct a tree for prediction:
/// node topology, split decisions, and leaf values. Histogram accumulators
/// are NOT serialized -- they rebuild naturally from continued training.
///
/// The per-node `Vec` fields appear to be parallel arrays indexed by arena
/// node id (NOTE(review): inferred from the field layout -- confirm they are
/// always kept the same length).
///
/// Field order matters: bincode encodes struct fields positionally, so
/// reordering fields changes the binary format.
///
/// NOTE(review): `#[serde(default)]` lets JSON written before a field existed
/// still load. bincode is not self-describing, so older bincode blobs that
/// lack a trailing field will most likely fail to decode instead of falling
/// back to the default -- confirm before relying on bincode across versions.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeSnapshot {
    /// Split feature index per node (presumably unused for leaf nodes).
    pub feature_idx: Vec<u32>,
    /// Split threshold per node (presumably unused for leaf nodes).
    pub threshold: Vec<f64>,
    /// Left child node index per node (presumably unused for leaf nodes).
    pub left: Vec<u32>,
    /// Right child node index per node (presumably unused for leaf nodes).
    pub right: Vec<u32>,
    /// Output value per node (presumably only meaningful for leaves).
    pub leaf_value: Vec<f64>,
    /// Whether each node is a leaf.
    pub is_leaf: Vec<bool>,
    /// Depth of each node in the tree.
    pub depth: Vec<u16>,
    /// Per-node sample count (assumed: samples routed through the node -- confirm).
    pub sample_count: Vec<u64>,
    /// Number of input features, if known (`None` presumably means not yet
    /// observed -- confirm).
    pub n_features: Option<usize>,
    /// Samples this tree has been trained on (assumed from the name).
    pub samples_seen: u64,
    /// Saved state of this tree's random number generator (assumed: restored
    /// so continued training stays reproducible -- confirm).
    pub rng_state: u64,
    /// Categorical split bitmasks (v5+). `None` entries are continuous splits.
    #[serde(default)]
    pub categorical_mask: Vec<Option<u64>>,
}
108
/// Serializable snapshot of a single boosting step.
///
/// Holds the step's active tree, an optional alternate tree, and the drift
/// detector state for each. Field order is the bincode wire layout.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepSnapshot {
    /// The step's active tree.
    pub tree: TreeSnapshot,
    /// Alternate tree, present while an alternate tree is training.
    pub alternate_tree: Option<TreeSnapshot>,
    /// Drift detector accumulated state (preserves warmup across save/load).
    #[serde(default)]
    pub drift_state: Option<crate::drift::state::DriftDetectorState>,
    /// Alternate drift detector state (if an alternate tree is training).
    #[serde(default)]
    pub alt_drift_state: Option<crate::drift::state::DriftDetectorState>,
}
122
/// Complete serializable state of an SGBT model.
///
/// Captures everything needed to reconstruct a trained model for prediction
/// and continued training. The loss function is stored as a [`LossType`] tag;
/// the loaders in this module turn it back into a boxed loss at runtime.
///
/// Field order is the bincode wire layout. Fields marked `#[serde(default)]`
/// let older JSON checkpoints load.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelState {
    /// Ensemble configuration the model was built with.
    pub config: SGBTConfig,
    /// Tag identifying the loss function.
    pub loss_type: LossType,
    /// Base prediction value (assumed: the constant the boosting steps are
    /// added to -- confirm).
    pub base_prediction: f64,
    /// Whether `base_prediction` has been initialized yet (assumed from the name).
    pub base_initialized: bool,
    /// Targets buffered for base-prediction initialization (assumed -- confirm).
    pub initial_targets: Vec<f64>,
    /// Count paired with `initial_targets` (NOTE(review): its relationship to
    /// `initial_targets.len()` is not visible here).
    pub initial_target_count: usize,
    /// Training samples seen by the model (assumed from the name).
    pub samples_seen: u64,
    /// Saved model-level random number generator state.
    pub rng_state: u64,
    /// One snapshot per boosting step.
    pub steps: Vec<StepSnapshot>,
    /// Rolling mean absolute error for error-weighted sample importance (v6+).
    #[serde(default)]
    pub rolling_mean_error: f64,
    /// Per-step EWMA of contribution magnitude for quality pruning (v6+).
    #[serde(default)]
    pub contribution_ewma: Vec<f64>,
    /// Per-step consecutive low-contribution count for quality pruning (v6+).
    #[serde(default)]
    pub low_contrib_count: Vec<u64>,
}
149
/// Checkpoint an SGBT model as pretty-printed JSON.
///
/// The loss type tag is detected from the model's own loss function. Built-in
/// losses (Squared, Logistic, Huber, Softmax) are recognised automatically;
/// for custom losses call [`save_model_with`] and pass the tag explicitly.
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] when the loss type cannot be
/// auto-detected or when JSON encoding fails.
#[cfg(feature = "serde-json")]
pub fn save_model<L: Loss>(model: &SGBT<L>) -> Result<String> {
    to_json_pretty(&model.to_model_state()?)
}
165
/// Save an SGBT model to a JSON string with an explicit loss type tag.
///
/// Use this for custom loss functions that don't implement `loss_type()`.
/// The supplied `loss_type` is used instead of auto-detection, so, unlike
/// [`save_model`], there is no detection step that can fail.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if JSON serialization fails.
#[cfg(feature = "serde-json")]
pub fn save_model_with<L: Loss>(model: &SGBT<L>, loss_type: LossType) -> Result<String> {
    let state = model.to_model_state_with(loss_type);
    to_json_pretty(&state)
}
174
/// Restore an SGBT model from JSON produced by [`save_model`] or
/// [`save_model_with`].
///
/// The result is a [`DynSGBT`](crate::ensemble::DynSGBT)
/// (`SGBT<Box<dyn Loss>>`). The concrete loss is only known at runtime, from
/// the serialized loss-type tag.
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] when `json` does not decode into
/// a [`ModelState`].
#[cfg(feature = "serde-json")]
pub fn load_model(json: &str) -> Result<crate::ensemble::DynSGBT> {
    let decoded: Result<ModelState> = from_json(json);
    decoded.map(SGBT::from_model_state)
}
189
190// ---------------------------------------------------------------------------
191// Bincode serialization (compact binary format)
192// ---------------------------------------------------------------------------
193
/// Encode `value` with bincode's standard configuration.
///
/// Only available with the `serde-bincode` feature.
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] carrying the bincode error
/// message when encoding fails.
#[cfg(feature = "serde-bincode")]
pub fn to_bincode<T>(value: &T) -> Result<Vec<u8>>
where
    T: serde::Serialize,
{
    let config = bincode::config::standard();
    match bincode::serde::encode_to_vec(value, config) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
206
/// Deserialize a value from bincode bytes.
///
/// Requires the `serde-bincode` feature. Uses bincode's standard
/// configuration, matching [`to_bincode`].
///
/// The whole input must be consumed: bytes left over after the encoded value
/// are rejected rather than silently ignored. This mirrors the JSON loaders,
/// which reject trailing garbage, and catches mis-framed or concatenated
/// buffers.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if deserialization fails or if
/// `bytes` contains data beyond the encoded value.
#[cfg(feature = "serde-bincode")]
pub fn from_bincode<T: serde::de::DeserializeOwned>(bytes: &[u8]) -> Result<T> {
    let (val, consumed) = bincode::serde::decode_from_slice(bytes, bincode::config::standard())
        .map_err(|e| IrithyllError::Serialization(e.to_string()))?;
    // `decode_from_slice` stops after one value and reports how many bytes it
    // read; any remainder means the buffer was not exactly one encoded value.
    if consumed < bytes.len() {
        return Err(IrithyllError::Serialization(format!(
            "bincode: {} trailing byte(s) after decoded value ({} of {} bytes consumed)",
            bytes.len() - consumed,
            consumed,
            bytes.len()
        )));
    }
    Ok(val)
}
220
/// Checkpoint an SGBT model in bincode's compact binary format.
///
/// The binary form is typically 3-5x smaller than the JSON produced by
/// [`save_model`]. The loss type tag is auto-detected from the model's loss.
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] when the loss type cannot be
/// auto-detected or when bincode encoding fails.
#[cfg(feature = "serde-bincode")]
pub fn save_model_bincode<L: Loss>(model: &SGBT<L>) -> Result<Vec<u8>> {
    to_bincode(&model.to_model_state()?)
}
234
/// Restore an SGBT model from bytes produced by [`save_model_bincode`].
///
/// The result is a [`DynSGBT`](crate::ensemble::DynSGBT), because the
/// concrete loss is only known at runtime from the serialized tag.
///
/// # Errors
///
/// Yields [`IrithyllError::Serialization`] when `bytes` does not decode into
/// a [`ModelState`].
#[cfg(feature = "serde-bincode")]
pub fn load_model_bincode(bytes: &[u8]) -> Result<crate::ensemble::DynSGBT> {
    Ok(SGBT::from_model_state(from_bincode::<ModelState>(bytes)?))
}
248
249// ---------------------------------------------------------------------------
250// DistributionalSGBT serialization
251// ---------------------------------------------------------------------------
252
/// Serializable state for [`DistributionalSGBT`](crate::ensemble::distributional::DistributionalSGBT).
///
/// Like [`ModelState`], but with two sets of boosting steps (location and
/// scale), each with its own base value. Field order is the bincode wire
/// layout.
///
/// NOTE(review): unlike `ModelState`, no field here carries
/// `#[serde(default)]`. Any field added later should be marked that way, or
/// previously saved JSON will fail to load.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionalModelState {
    /// Ensemble configuration the model was built with.
    pub config: SGBTConfig,
    /// Boosting steps for the location parameter (presumably the predicted mean).
    pub location_steps: Vec<StepSnapshot>,
    /// Boosting steps for the scale parameter (presumably the predicted spread).
    pub scale_steps: Vec<StepSnapshot>,
    /// Base value for the location ensemble.
    pub location_base: f64,
    /// Base value for the scale ensemble.
    pub scale_base: f64,
    /// Whether the base values have been initialized yet (assumed from the name).
    pub base_initialized: bool,
    /// Targets buffered for base initialization (assumed -- confirm).
    pub initial_targets: Vec<f64>,
    /// Count paired with `initial_targets` (NOTE(review): its relationship to
    /// `initial_targets.len()` is not visible here).
    pub initial_target_count: usize,
    /// Training samples seen by the model (assumed from the name).
    pub samples_seen: u64,
    /// Saved model-level random number generator state.
    pub rng_state: u64,
    /// Flag for uncertainty-modulated learning rate (assumed from the name -- confirm).
    pub uncertainty_modulated_lr: bool,
    /// Rolling mean of the predicted sigma (assumed from the name -- confirm).
    pub rolling_sigma_mean: f64,
}
270
/// Serialize a [`DistributionalModelState`] to pretty-printed JSON.
///
/// Requires the `serde-json` feature (enabled by default).
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if serialization fails.
#[cfg(feature = "serde-json")]
pub fn save_distributional_model(state: &DistributionalModelState) -> Result<String> {
    to_json_pretty(state)
}
276
/// Deserialize a [`DistributionalModelState`] from JSON.
///
/// Requires the `serde-json` feature (enabled by default).
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `json` is malformed or does
/// not match the [`DistributionalModelState`] layout.
#[cfg(feature = "serde-json")]
pub fn load_distributional_model(json: &str) -> Result<DistributionalModelState> {
    from_json(json)
}
282
/// Serialize a [`DistributionalModelState`] to bincode bytes.
///
/// Requires the `serde-bincode` feature.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if serialization fails.
#[cfg(feature = "serde-bincode")]
pub fn save_distributional_model_bincode(state: &DistributionalModelState) -> Result<Vec<u8>> {
    to_bincode(state)
}
288
/// Deserialize a [`DistributionalModelState`] from bincode bytes.
///
/// Requires the `serde-bincode` feature.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `bytes` does not decode into
/// a [`DistributionalModelState`].
#[cfg(feature = "serde-bincode")]
pub fn load_distributional_model_bincode(bytes: &[u8]) -> Result<DistributionalModelState> {
    from_bincode(bytes)
}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sample::Sample;

    /// Compact JSON round-trips a sample's features and target.
    #[cfg(feature = "serde-json")]
    #[test]
    fn json_round_trip_sample() {
        let original = Sample::new(vec![1.0, 2.0, 3.0], 4.0);
        let encoded = to_json(&original).unwrap();
        let decoded: Sample = from_json(&encoded).unwrap();

        assert_eq!(decoded.features, original.features);
        assert!((decoded.target - original.target).abs() < f64::EPSILON);
    }

    /// Pretty output spans multiple lines and still round-trips the weight.
    #[cfg(feature = "serde-json")]
    #[test]
    fn json_pretty_round_trip() {
        let pretty = to_json_pretty(&Sample::weighted(vec![1.0], 2.0, 0.5)).unwrap();
        assert!(pretty.contains('\n'));

        let decoded = from_json::<Sample>(&pretty).unwrap();
        assert!((decoded.weight - 0.5).abs() < f64::EPSILON);
    }

    /// The byte-oriented JSON helpers round-trip the feature vector.
    #[cfg(feature = "serde-json")]
    #[test]
    fn json_bytes_round_trip() {
        let original = Sample::new(vec![10.0, 20.0], 30.0);
        let raw = to_json_bytes(&original).unwrap();
        let decoded = from_json_bytes::<Sample>(&raw).unwrap();
        assert_eq!(decoded.features, original.features);
    }

    /// Malformed input surfaces as a non-empty `Serialization` error.
    #[cfg(feature = "serde-json")]
    #[test]
    fn json_invalid_input_returns_error() {
        let outcome = from_json::<Sample>("not valid json");
        assert!(outcome.is_err());
        if let Err(err) = outcome {
            match err {
                IrithyllError::Serialization(msg) => assert!(!msg.is_empty()),
                other => panic!("expected Serialization error, got {:?}", other),
            }
        }
    }

    /// A batch of samples survives a JSON round trip with its length intact.
    #[cfg(feature = "serde-json")]
    #[test]
    fn json_batch_samples() {
        let batch = vec![Sample::new(vec![1.0], 2.0), Sample::new(vec![3.0], 4.0)];
        let encoded = to_json(&batch).unwrap();
        let decoded: Vec<Sample> = from_json(&encoded).unwrap();
        assert_eq!(decoded.len(), 2);
    }
}