use crate::error::{IrithyllError, Result};
use serde::{Deserialize, Serialize};
/// Serializes `value` into a compact JSON string.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] carrying the text of the
/// underlying `serde_json` error when encoding fails.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn to_json<T: serde::Serialize>(value: &T) -> Result<String> {
    match serde_json::to_string(value) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
/// Serializes `value` into a pretty-printed (multi-line, indented) JSON string.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] carrying the text of the
/// underlying `serde_json` error when encoding fails.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn to_json_pretty<T: serde::Serialize>(value: &T) -> Result<String> {
    match serde_json::to_string_pretty(value) {
        Ok(encoded) => Ok(encoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
/// Parses a JSON string into an owned value of type `T`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] carrying the text of the
/// underlying `serde_json` error when `json` is malformed or does not match
/// the shape of `T`.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn from_json<T: serde::de::DeserializeOwned>(json: &str) -> Result<T> {
    match serde_json::from_str(json) {
        Ok(decoded) => Ok(decoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
/// Serializes `value` into a byte vector containing compact JSON.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] carrying the text of the
/// underlying `serde_json` error when encoding fails.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn to_json_bytes<T: serde::Serialize>(value: &T) -> Result<Vec<u8>> {
    match serde_json::to_vec(value) {
        Ok(buffer) => Ok(buffer),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
/// Parses a byte slice containing JSON into an owned value of type `T`.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] carrying the text of the
/// underlying `serde_json` error when `bytes` is not valid JSON or does not
/// match the shape of `T`.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn from_json_bytes<T: serde::de::DeserializeOwned>(bytes: &[u8]) -> Result<T> {
    match serde_json::from_slice(bytes) {
        Ok(decoded) => Ok(decoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
use crate::ensemble::config::SGBTConfig;
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
use crate::ensemble::SGBT;
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
use crate::loss::Loss;
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
pub use crate::loss::LossType;
/// Serializable snapshot of a single tree.
///
/// Node data is stored as several parallel vectors plus a few scalar
/// counters. NOTE(review): the per-field descriptions below are inferred from
/// the field names and types only — confirm against the code that builds and
/// restores this snapshot.
///
/// NOTE(review): positional formats such as bincode presumably encode fields
/// in declaration order, so reordering fields would likely invalidate
/// previously saved data — confirm before changing the layout.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "serde-json", feature = "serde-bincode")))
)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeSnapshot {
/// Split feature indices (presumably one entry per node — TODO confirm).
pub feature_idx: Vec<u32>,
/// Split thresholds (presumably one entry per node — TODO confirm).
pub threshold: Vec<f64>,
/// Left-child links (presumably node indices — TODO confirm).
pub left: Vec<u32>,
/// Right-child links (presumably node indices — TODO confirm).
pub right: Vec<u32>,
/// Leaf output values (presumably one entry per node — TODO confirm).
pub leaf_value: Vec<f64>,
/// Leaf flags (presumably `true` where the node is a leaf — TODO confirm).
pub is_leaf: Vec<bool>,
/// Node depths (presumably one entry per node — TODO confirm).
pub depth: Vec<u16>,
/// Sample counters (presumably per node — TODO confirm).
pub sample_count: Vec<u64>,
/// Number of input features, if known; `None` otherwise.
pub n_features: Option<usize>,
/// Sample counter stored with the tree.
pub samples_seen: u64,
/// Saved random-number-generator state for the tree.
pub rng_state: u64,
/// Optional per-entry categorical bitmasks. Because of `#[serde(default)]`
/// this becomes an empty vector when the field is absent from the input.
#[serde(default)]
pub categorical_mask: Vec<Option<u64>>,
}
/// Serializable snapshot of one boosting step: its tree, an optional
/// alternate tree, and optional drift-detector states.
///
/// NOTE(review): the roles of the alternate tree and of the two drift states
/// are inferred from field names — confirm against the step implementation.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "serde-json", feature = "serde-bincode")))
)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepSnapshot {
/// Snapshot of the step's primary tree.
pub tree: TreeSnapshot,
/// Snapshot of the alternate tree, when one exists.
pub alternate_tree: Option<TreeSnapshot>,
/// Saved drift-detector state (presumably for `tree` — TODO confirm).
/// `#[serde(default)]` yields `None` when the field is absent from the input.
#[serde(default)]
pub drift_state: Option<crate::drift::state::DriftDetectorState>,
/// Saved drift-detector state (presumably for `alternate_tree` — TODO
/// confirm). `#[serde(default)]` yields `None` when the field is absent.
#[serde(default)]
pub alt_drift_state: Option<crate::drift::state::DriftDetectorState>,
}
/// Serializable state of an `SGBT` model.
///
/// This is the value produced by `SGBT::to_model_state` /
/// `SGBT::to_model_state_with` and consumed by `SGBT::from_model_state`, as
/// used by the `save_model*` / `load_model*` helpers in this module.
///
/// Fields marked `#[serde(default)]` fall back to their type's `Default`
/// value when absent from the input (useful for self-describing formats such
/// as JSON). NOTE(review): the per-field meanings below are inferred from
/// names and types — confirm against the ensemble implementation.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "serde-json", feature = "serde-bincode")))
)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelState {
/// Configuration the model was built with.
pub config: SGBTConfig,
/// Tag identifying the loss function to use on restore.
pub loss_type: LossType,
/// Stored base prediction value.
pub base_prediction: f64,
/// Whether the base prediction has been initialized.
pub base_initialized: bool,
/// Buffered initial targets (presumably used to initialize the base
/// prediction — TODO confirm).
pub initial_targets: Vec<f64>,
/// Stored initial-target count (presumably the number of targets to
/// buffer before initialization — TODO confirm).
pub initial_target_count: usize,
/// Sample counter stored with the model.
pub samples_seen: u64,
/// Saved random-number-generator state for the model.
pub rng_state: u64,
/// Snapshots of the boosting steps, in stored order.
pub steps: Vec<StepSnapshot>,
/// Rolling mean error statistic; `0.0` when absent from the input.
#[serde(default)]
pub rolling_mean_error: f64,
/// Contribution EWMA values (presumably one per step — TODO confirm);
/// empty when absent from the input.
#[serde(default)]
pub contribution_ewma: Vec<f64>,
/// Low-contribution counters (presumably one per step — TODO confirm);
/// empty when absent from the input.
#[serde(default)]
pub low_contrib_count: Vec<u64>,
/// Rolling contribution sigma statistic; `0.0` when absent from the input.
#[serde(default)]
pub rolling_contribution_sigma: f64,
}
/// Captures the state of `model` and encodes it as pretty-printed JSON.
///
/// # Errors
///
/// Propagates any error from `SGBT::to_model_state`, and returns
/// [`IrithyllError::Serialization`] if JSON encoding fails.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn save_model<L: Loss>(model: &SGBT<L>) -> Result<String> {
    // Snapshot first (fallible), then hand the snapshot to the JSON encoder.
    to_json_pretty(&model.to_model_state()?)
}
/// Captures the state of `model`, tagging it with the caller-supplied
/// `loss_type`, and encodes it as pretty-printed JSON.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if JSON encoding fails.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn save_model_with<L: Loss>(model: &SGBT<L>, loss_type: LossType) -> Result<String> {
    to_json_pretty(&model.to_model_state_with(loss_type))
}
/// Rebuilds a dynamically-dispatched model from JSON produced by
/// [`save_model`] or [`save_model_with`].
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `json` cannot be decoded into
/// a [`ModelState`].
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn load_model(json: &str) -> Result<crate::ensemble::DynSGBT> {
    let decoded = from_json::<ModelState>(json)?;
    Ok(SGBT::from_model_state(decoded))
}
/// Serializable state of a multiclass model.
///
/// In this module's tests it is obtained from
/// `MulticlassSGBT::to_multiclass_state` and turned back into a model with
/// `MulticlassSGBT::from_multiclass_state`.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "serde-json", feature = "serde-bincode")))
)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MulticlassModelState {
/// Number of classes.
pub n_classes: usize,
/// Saved states of the committee models (presumably one per class —
/// TODO confirm).
pub committees: Vec<ModelState>,
/// Sample counter stored with the multiclass model.
pub samples_seen: u64,
}
/// Encodes a [`MulticlassModelState`] as pretty-printed JSON.
///
/// Delegates to [`to_json_pretty`].
///
/// # Errors
///
/// Returns whatever error [`to_json_pretty`] reports.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn save_multiclass_model(state: &MulticlassModelState) -> Result<String> {
to_json_pretty(state)
}
/// Decodes a [`MulticlassModelState`] from a JSON string.
///
/// Delegates to [`from_json`]. The result is the raw state, not a model.
///
/// # Errors
///
/// Returns whatever error [`from_json`] reports.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn load_multiclass_model(json: &str) -> Result<MulticlassModelState> {
from_json(json)
}
/// Encodes a [`MulticlassModelState`] as bincode bytes.
///
/// Delegates to [`to_bincode`].
///
/// # Errors
///
/// Returns whatever error [`to_bincode`] reports.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn save_multiclass_model_bincode(state: &MulticlassModelState) -> Result<Vec<u8>> {
to_bincode(state)
}
/// Decodes a [`MulticlassModelState`] from bincode bytes.
///
/// Delegates to [`from_bincode`]. The result is the raw state, not a model.
///
/// # Errors
///
/// Returns whatever error [`from_bincode`] reports.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn load_multiclass_model_bincode(bytes: &[u8]) -> Result<MulticlassModelState> {
from_bincode(bytes)
}
/// Serializable state of a bagged ensemble.
///
/// In this module's tests it is obtained from `BaggedSGBT::to_bagged_state`
/// and turned back into a model with `BaggedSGBT::from_bagged_state`, which
/// additionally takes the loss instance to use.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "serde-json", feature = "serde-bincode")))
)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BaggedModelState {
/// Number of bags.
pub n_bags: usize,
/// Saved states of the individual bag models (presumably `n_bags`
/// entries — TODO confirm).
pub bags: Vec<ModelState>,
/// Sample counter stored with the bagged ensemble.
pub samples_seen: u64,
/// Saved random-number-generator state of the bagged ensemble.
pub rng_state: u64,
/// Stored seed value (presumably the seed the ensemble was created with —
/// TODO confirm).
pub seed: u64,
}
/// Encodes a [`BaggedModelState`] as pretty-printed JSON.
///
/// Delegates to [`to_json_pretty`].
///
/// # Errors
///
/// Returns whatever error [`to_json_pretty`] reports.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn save_bagged_model(state: &BaggedModelState) -> Result<String> {
to_json_pretty(state)
}
/// Decodes a [`BaggedModelState`] from a JSON string.
///
/// Delegates to [`from_json`]. The result is the raw state, not a model.
///
/// # Errors
///
/// Returns whatever error [`from_json`] reports.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn load_bagged_model(json: &str) -> Result<BaggedModelState> {
from_json(json)
}
/// Encodes a [`BaggedModelState`] as bincode bytes.
///
/// Delegates to [`to_bincode`].
///
/// # Errors
///
/// Returns whatever error [`to_bincode`] reports.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn save_bagged_model_bincode(state: &BaggedModelState) -> Result<Vec<u8>> {
to_bincode(state)
}
/// Decodes a [`BaggedModelState`] from bincode bytes.
///
/// Delegates to [`from_bincode`]. The result is the raw state, not a model.
///
/// # Errors
///
/// Returns whatever error [`from_bincode`] reports.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn load_bagged_model_bincode(bytes: &[u8]) -> Result<BaggedModelState> {
from_bincode(bytes)
}
/// Serializes `value` into bincode bytes using bincode's standard
/// configuration.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] carrying the text of the
/// underlying bincode error when encoding fails.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn to_bincode<T: serde::Serialize>(value: &T) -> Result<Vec<u8>> {
    let config = bincode::config::standard();
    match bincode::serde::encode_to_vec(value, config) {
        Ok(buffer) => Ok(buffer),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
/// Deserializes an owned value of type `T` from bincode bytes using bincode's
/// standard configuration.
///
/// The decoder also returns a second tuple element (the number of bytes
/// read, per bincode's documentation); it is discarded here, so callers are
/// not told whether `bytes` contained data beyond the decoded value.
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] carrying the text of the
/// underlying bincode error when decoding fails.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn from_bincode<T: serde::de::DeserializeOwned>(bytes: &[u8]) -> Result<T> {
    let config = bincode::config::standard();
    match bincode::serde::decode_from_slice(bytes, config) {
        Ok((decoded, _bytes_read)) => Ok(decoded),
        Err(err) => Err(IrithyllError::Serialization(err.to_string())),
    }
}
/// Captures the state of `model` and encodes it as bincode bytes.
///
/// # Errors
///
/// Propagates any error from `SGBT::to_model_state`, and returns
/// [`IrithyllError::Serialization`] if bincode encoding fails.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn save_model_bincode<L: Loss>(model: &SGBT<L>) -> Result<Vec<u8>> {
    // Snapshot first (fallible), then hand the snapshot to the bincode encoder.
    to_bincode(&model.to_model_state()?)
}
/// Rebuilds a dynamically-dispatched model from bincode bytes produced by
/// [`save_model_bincode`].
///
/// # Errors
///
/// Returns [`IrithyllError::Serialization`] if `bytes` cannot be decoded into
/// a [`ModelState`].
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn load_model_bincode(bytes: &[u8]) -> Result<crate::ensemble::DynSGBT> {
    let decoded = from_bincode::<ModelState>(bytes)?;
    Ok(SGBT::from_model_state(decoded))
}
/// Serializable state of a distributional model, which keeps separate step
/// lists and base values for location and scale.
///
/// In this module's tests it is obtained from
/// `DistributionalSGBT::to_distributional_state` and turned back into a model
/// with `DistributionalSGBT::from_distributional_state`.
///
/// NOTE(review): the per-field meanings below are inferred from names and
/// types — confirm against the distributional ensemble implementation.
#[cfg(any(feature = "serde-json", feature = "serde-bincode"))]
#[cfg_attr(
docsrs,
doc(cfg(any(feature = "serde-json", feature = "serde-bincode")))
)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributionalModelState {
/// Configuration the model was built with.
pub config: SGBTConfig,
/// Snapshots of the location-ensemble steps.
pub location_steps: Vec<StepSnapshot>,
/// Snapshots of the scale-ensemble steps.
pub scale_steps: Vec<StepSnapshot>,
/// Stored base value for the location ensemble.
pub location_base: f64,
/// Stored base value for the scale ensemble.
pub scale_base: f64,
/// Whether the base values have been initialized.
pub base_initialized: bool,
/// Buffered initial targets (presumably used to initialize the base
/// values — TODO confirm).
pub initial_targets: Vec<f64>,
/// Stored initial-target count (presumably the number of targets to
/// buffer before initialization — TODO confirm).
pub initial_target_count: usize,
/// Sample counter stored with the model.
pub samples_seen: u64,
/// Saved random-number-generator state for the model.
pub rng_state: u64,
/// Stored flag for the uncertainty-modulated learning-rate option.
pub uncertainty_modulated_lr: bool,
/// Stored rolling mean of sigma.
pub rolling_sigma_mean: f64,
/// Stored EWMA of the squared error. When absent from the input, serde
/// fills it by calling `default_ewma_sq_err`.
#[serde(default = "default_ewma_sq_err")]
pub ewma_sq_err: f64,
/// Stored rolling mean of the "honest" sigma; `0.0` when absent from the
/// input.
#[serde(default)]
pub rolling_honest_sigma_mean: f64,
}
/// Serde fallback for `DistributionalModelState::ewma_sq_err`, used when the
/// field is missing from the serialized input.
fn default_ewma_sq_err() -> f64 {
    const FALLBACK_EWMA_SQ_ERR: f64 = 1.0;
    FALLBACK_EWMA_SQ_ERR
}
/// Encodes a [`DistributionalModelState`] as pretty-printed JSON.
///
/// Delegates to [`to_json_pretty`].
///
/// # Errors
///
/// Returns whatever error [`to_json_pretty`] reports.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn save_distributional_model(state: &DistributionalModelState) -> Result<String> {
to_json_pretty(state)
}
/// Decodes a [`DistributionalModelState`] from a JSON string.
///
/// Delegates to [`from_json`]. The result is the raw state, not a model.
///
/// # Errors
///
/// Returns whatever error [`from_json`] reports.
#[cfg(feature = "serde-json")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-json")))]
pub fn load_distributional_model(json: &str) -> Result<DistributionalModelState> {
from_json(json)
}
/// Encodes a [`DistributionalModelState`] as bincode bytes.
///
/// Delegates to [`to_bincode`].
///
/// # Errors
///
/// Returns whatever error [`to_bincode`] reports.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn save_distributional_model_bincode(state: &DistributionalModelState) -> Result<Vec<u8>> {
to_bincode(state)
}
/// Decodes a [`DistributionalModelState`] from bincode bytes.
///
/// Delegates to [`from_bincode`]. The result is the raw state, not a model.
///
/// # Errors
///
/// Returns whatever error [`from_bincode`] reports.
#[cfg(feature = "serde-bincode")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde-bincode")))]
pub fn load_distributional_model_bincode(bytes: &[u8]) -> Result<DistributionalModelState> {
from_bincode(bytes)
}
// Round-trip tests for the JSON helpers and the JSON model-state helpers.
// Every test here is gated on the `serde-json` feature; no test in this
// module exercises the bincode helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::sample::Sample;
// A plain sample survives compact-JSON encode/decode: features compare
// equal and the target matches to within `f64::EPSILON`.
#[cfg(feature = "serde-json")]
#[test]
fn json_round_trip_sample() {
let sample = Sample::new(vec![1.0, 2.0, 3.0], 4.0);
let json = to_json(&sample).unwrap();
let restored: Sample = from_json(&json).unwrap();
assert_eq!(restored.features, sample.features);
assert!((restored.target - sample.target).abs() < f64::EPSILON);
}
// Pretty-printed output contains at least one newline and still decodes;
// the weight of a weighted sample is preserved.
#[cfg(feature = "serde-json")]
#[test]
fn json_pretty_round_trip() {
let sample = Sample::weighted(vec![1.0], 2.0, 0.5);
let json = to_json_pretty(&sample).unwrap();
assert!(json.contains('\n'));
let restored: Sample = from_json(&json).unwrap();
assert!((restored.weight - 0.5).abs() < f64::EPSILON);
}
// The byte-oriented helpers round-trip a sample's features.
#[cfg(feature = "serde-json")]
#[test]
fn json_bytes_round_trip() {
let sample = Sample::new(vec![10.0, 20.0], 30.0);
let bytes = to_json_bytes(&sample).unwrap();
let restored: Sample = from_json_bytes(&bytes).unwrap();
assert_eq!(restored.features, sample.features);
}
// Malformed JSON must surface as `IrithyllError::Serialization` with a
// non-empty message; any other variant fails the test.
#[cfg(feature = "serde-json")]
#[test]
fn json_invalid_input_returns_error() {
let result = from_json::<Sample>("not valid json");
assert!(result.is_err());
match result.unwrap_err() {
IrithyllError::Serialization(msg) => {
assert!(!msg.is_empty());
}
other => panic!("expected Serialization error, got {:?}", other),
}
}
// A vector of samples round-trips with its length intact.
#[cfg(feature = "serde-json")]
#[test]
fn json_batch_samples() {
let samples = vec![Sample::new(vec![1.0], 2.0), Sample::new(vec![3.0], 4.0)];
let json = to_json(&samples).unwrap();
let restored: Vec<Sample> = from_json(&json).unwrap();
assert_eq!(restored.len(), 2);
}
// Trains a 3-class model on 60 samples, saves its state as JSON, restores
// a model from the loaded state, and checks that predicted probabilities
// agree to within 1e-10 and that class/sample counts match.
#[cfg(feature = "serde-json")]
#[test]
fn multiclass_model_json_roundtrip() {
use crate::ensemble::multiclass::MulticlassSGBT;
use crate::SGBTConfig;
let config = SGBTConfig::builder()
.n_steps(5)
.learning_rate(0.1)
.grace_period(10)
.max_depth(3)
.initial_target_count(5)
.build()
.unwrap();
let mut model = MulticlassSGBT::new(config, 3).unwrap();
// Class labels cycle 0, 1, 2 via `i % 3`.
for i in 0..60 {
let x = i as f64 * 0.1;
let class = (i % 3) as f64;
model.train_one(&Sample::new(vec![x, x * 2.0], class));
}
let state = model.to_multiclass_state();
let json = save_multiclass_model(&state).unwrap();
let loaded_state = load_multiclass_model(&json).unwrap();
let restored = MulticlassSGBT::from_multiclass_state(loaded_state);
let test_features = vec![vec![0.5, 1.0], vec![1.0, 2.0], vec![2.0, 4.0]];
for features in &test_features {
let orig_proba = model.predict_proba(features);
let rest_proba = restored.predict_proba(features);
assert_eq!(
orig_proba.len(),
rest_proba.len(),
"probability vector lengths should match"
);
for (c, (o, r)) in orig_proba.iter().zip(rest_proba.iter()).enumerate() {
assert!(
(o - r).abs() < 1e-10,
"multiclass JSON round-trip mismatch at class {}: {} vs {}",
c,
o,
r
);
}
}
assert_eq!(model.n_classes(), restored.n_classes());
assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
}
// Trains a 3-bag model on 100 samples of y = 2x + 1, saves its state as
// JSON, restores it with `SquaredLoss`, and checks that predictions agree
// to within 1e-10 and that bag/sample counts match.
#[cfg(feature = "serde-json")]
#[test]
fn bagged_model_json_roundtrip() {
use crate::ensemble::bagged::BaggedSGBT;
use crate::loss::squared::SquaredLoss;
use crate::SGBTConfig;
let config = SGBTConfig::builder()
.n_steps(5)
.learning_rate(0.1)
.grace_period(10)
.initial_target_count(5)
.build()
.unwrap();
let mut model = BaggedSGBT::new(config, 3).unwrap();
for i in 0..100 {
let x = i as f64 * 0.1;
model.train_one(&Sample::new(vec![x], x * 2.0 + 1.0));
}
let state = model.to_bagged_state().unwrap();
let json = save_bagged_model(&state).unwrap();
let loaded_state = load_bagged_model(&json).unwrap();
let restored = BaggedSGBT::from_bagged_state(loaded_state, SquaredLoss);
let test_points = [0.5, 1.0, 2.0, 3.0];
for &x in &test_points {
let orig = model.predict(&[x]);
let rest = restored.predict(&[x]);
assert!(
(orig - rest).abs() < 1e-10,
"bagged JSON round-trip mismatch at x={}: {} vs {}",
x,
orig,
rest
);
}
assert_eq!(model.n_bags(), restored.n_bags());
assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
}
// Trains a distributional model on 100 `(vec![x], x.sin())` tuples, saves
// its state as JSON, restores it, and checks that both `mu` and `sigma`
// of the predictions agree to within 1e-10 and that sample counts match.
#[cfg(feature = "serde-json")]
#[test]
fn distributional_model_json_roundtrip() {
use crate::ensemble::distributional::DistributionalSGBT;
use crate::SGBTConfig;
let config = SGBTConfig::builder()
.n_steps(5)
.learning_rate(0.1)
.grace_period(10)
.max_depth(3)
.initial_target_count(10)
.build()
.unwrap();
let mut model = DistributionalSGBT::new(config);
for i in 0..100 {
let x = i as f64 * 0.1;
model.train_one(&(vec![x], x.sin()));
}
let state = model.to_distributional_state();
let json = save_distributional_model(&state).unwrap();
let loaded_state = load_distributional_model(&json).unwrap();
let restored = DistributionalSGBT::from_distributional_state(loaded_state);
let test_points = [0.5, 1.0, 2.0, 3.0];
for &x in &test_points {
let orig = model.predict(&[x]);
let rest = restored.predict(&[x]);
assert!(
(orig.mu - rest.mu).abs() < 1e-10,
"distributional JSON round-trip mu mismatch at x={}: {} vs {}",
x,
orig.mu,
rest.mu
);
assert!(
(orig.sigma - rest.sigma).abs() < 1e-10,
"distributional JSON round-trip sigma mismatch at x={}: {} vs {}",
x,
orig.sigma,
rest.sigma
);
}
assert_eq!(model.n_samples_seen(), restored.n_samples_seen());
}
}