use crate::booster::GBDTModel;
use crate::{Result, TreeBoostError};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;
pub fn save_model_bincode(model: &GBDTModel, path: impl AsRef<Path>) -> Result<()> {
let file = File::create(path)?;
let writer = BufWriter::new(file);
bincode::serialize_into(writer, model).map_err(|e| {
TreeBoostError::Serialization(format!("Failed to serialize bincode: {}", e))
})?;
Ok(())
}
/// Reads a bincode-encoded [`GBDTModel`] from the file at `path`.
///
/// # Errors
///
/// Returns an error if the file cannot be opened. Returns
/// `TreeBoostError::Serialization` if its bytes cannot be decoded as a
/// `GBDTModel`.
pub fn load_model_bincode(path: impl AsRef<Path>) -> Result<GBDTModel> {
    let reader = BufReader::new(File::open(path)?);
    let decoded: std::result::Result<GBDTModel, _> = bincode::deserialize_from(reader);
    let model = decoded.map_err(|err| {
        TreeBoostError::Serialization(format!("Failed to deserialize bincode: {}", err))
    })?;
    Ok(model)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::booster::GBDTConfig;
    use crate::dataset::{BinnedDataset, FeatureInfo, FeatureType};
    use tempfile::tempdir;

    /// Builds a small synthetic dataset.
    ///
    /// It has 100 rows and two numeric features whose bin values cycle through
    /// bytes. The target increases linearly (`row * 0.1`).
    fn create_test_dataset() -> BinnedDataset {
        let num_rows = 100;
        let num_features = 2;
        let features: Vec<u8> = (0..num_rows * num_features)
            .map(|i| (i % 256) as u8)
            .collect();
        let targets: Vec<f32> = (0..num_rows).map(|i| (i as f32) * 0.1).collect();
        let feature_info = vec![
            FeatureInfo {
                name: "f0".to_string(),
                feature_type: FeatureType::Numeric,
                num_bins: 255,
                bin_boundaries: vec![],
            },
            FeatureInfo {
                name: "f1".to_string(),
                feature_type: FeatureType::Numeric,
                num_bins: 255,
                bin_boundaries: vec![],
            },
        ];
        BinnedDataset::new(num_rows, features, targets, feature_info)
    }

    /// Trains a tiny model (5 rounds, depth 3) on `dataset`.
    fn train_test_model(dataset: &BinnedDataset) -> GBDTModel {
        let config = GBDTConfig::new().with_num_rounds(5).with_max_depth(3);
        GBDTModel::train_binned(dataset, config).unwrap()
    }

    #[test]
    fn test_save_load_model_bincode() {
        let dataset = create_test_dataset();
        let model = train_test_model(&dataset);
        let dir = tempdir().unwrap();
        let path = dir.path().join("model.bincode");
        save_model_bincode(&model, &path).unwrap();
        let loaded = load_model_bincode(&path).unwrap();
        assert_eq!(loaded.num_trees(), model.num_trees());
        assert_eq!(loaded.base_prediction(), model.base_prediction());
        let orig_preds = model.predict(&dataset);
        let loaded_preds = loaded.predict(&dataset);
        // `zip` stops at the shorter side, so compare lengths first; otherwise a
        // truncated or empty prediction vector would pass vacuously.
        assert_eq!(orig_preds.len(), loaded_preds.len());
        for (i, (a, b)) in orig_preds.iter().zip(loaded_preds.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-6,
                "prediction {} differs after round-trip: {} vs {}",
                i,
                a,
                b
            );
        }
    }

    #[test]
    fn test_load_model_bincode_missing_file() {
        let dir = tempdir().unwrap();
        assert!(load_model_bincode(dir.path().join("missing.bincode")).is_err());
    }

    #[test]
    fn test_load_model_bincode_truncated_file() {
        let dataset = create_test_dataset();
        let model = train_test_model(&dataset);
        let dir = tempdir().unwrap();
        let path = dir.path().join("model.bincode");
        save_model_bincode(&model, &path).unwrap();
        // Keep only a strict prefix of the encoding. Decoding must run out of
        // bytes and fail rather than return a partially-populated model.
        let bytes = std::fs::read(&path).unwrap();
        std::fs::write(&path, &bytes[..bytes.len() / 2]).unwrap();
        assert!(load_model_bincode(&path).is_err());
    }
}