#[cfg(any(feature = "postcard", feature = "serde_json"))]
use crate::error::{RcfError, RcfResult};
use crate::forest::RandomCutForest;
#[cfg(feature = "serde")]
use crate::thresholded::ThresholdedForest;
/// Format version stamped on every serialized `RandomCutForest`; loaders
/// reject payloads carrying any other value.
pub const PERSISTENCE_VERSION: u32 = 4;
/// Format version stamped on every serialized `ThresholdedForest`; loaders
/// reject payloads carrying any other value.
pub const THRESHOLDED_PERSISTENCE_VERSION: u32 = 4;
/// Width in bytes of the little-endian `u32` version tag that leads every
/// binary payload.
pub const VERSION_PREFIX_BYTES: usize = core::mem::size_of::<u32>();
/// Default upper bound (256 MiB) on the size of a binary payload accepted by
/// the `from_bytes` loaders.
pub const MAX_DESERIALIZE_BYTES: usize = 256 << 20;
/// Default upper bound (1 GiB) on the size of a JSON document accepted by the
/// `from_json` loaders.
pub const MAX_JSON_BYTES: usize = 1 << 30;
/// Rejects a payload whose byte length is above the permitted maximum.
///
/// `len` is the size of the input about to be decoded, `max` is the largest
/// size that is still accepted (a payload of exactly `max` bytes passes), and
/// `kind` is a label placed at the front of the error message.
///
/// # Errors
///
/// Returns [`RcfError::DeserializationFailed`] when `len > max`.
#[cfg(any(feature = "postcard", feature = "serde_json"))]
fn enforce_size_cap(len: usize, max: usize, kind: &'static str) -> RcfResult<()> {
    if len <= max {
        return Ok(());
    }
    let message = format!(
        "{kind} payload {len} byte(s) exceeds cap {max} (caller-controlled OOM guard) — \
         use the `*_with_max_size` variant to opt into a larger bound"
    );
    Err(RcfError::DeserializationFailed(message))
}
/// Decodes the little-endian `u32` version tag stored in the first
/// [`VERSION_PREFIX_BYTES`] bytes of `bytes`.
///
/// # Errors
///
/// Returns [`RcfError::DeserializationFailed`] when `bytes` is shorter than
/// the prefix.
#[cfg(feature = "postcard")]
fn read_version_prefix(bytes: &[u8]) -> RcfResult<u32> {
    match bytes.get(..VERSION_PREFIX_BYTES) {
        Some(prefix) => {
            let mut raw = [0_u8; VERSION_PREFIX_BYTES];
            raw.copy_from_slice(prefix);
            Ok(u32::from_le_bytes(raw))
        }
        None => Err(RcfError::DeserializationFailed(format!(
            "payload too short: {} byte(s), need at least {VERSION_PREFIX_BYTES}",
            bytes.len()
        ))),
    }
}
/// Filesystem helpers behind the `*_path` persistence methods.
///
/// Writes are staged through a sibling temp file and renamed into place;
/// reads are bounded by the same default size caps the in-memory loaders
/// enforce, so an oversized file is rejected without being loaded in full.
#[cfg(all(feature = "std", any(feature = "postcard", feature = "serde_json")))]
mod atomic {
    use std::ffi::OsString;
    use std::fs::{File, remove_file, rename};
    use std::io::{Read, Take, Write};
    use std::path::{Path, PathBuf};

    use crate::error::{RcfError, RcfResult};

    /// Returns the staging path for `path`: the full path with `.tmp`
    /// appended (so `model.bin` stages through `model.bin.tmp`).
    pub(super) fn tmp_path(path: &Path) -> PathBuf {
        let mut s: OsString = path.as_os_str().to_owned();
        s.push(".tmp");
        PathBuf::from(s)
    }

    /// Writes `bytes` to the staging file, fsyncs it, then renames it over
    /// `path`.
    ///
    /// If any step after the staging file is created fails, the staging file
    /// is removed on a best-effort basis so failed saves do not accumulate
    /// stray `.tmp` files. The parent directory is not fsynced, and the
    /// staging name is fixed, so concurrent writers targeting the same `path`
    /// share one staging file.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] naming the step (`create`,
    /// `write`, `fsync`, `rename`) that failed.
    pub(super) fn write_atomic(path: &Path, bytes: &[u8]) -> RcfResult<()> {
        let tmp = tmp_path(path);
        let mut f = File::create(&tmp)
            .map_err(|e| RcfError::SerializationFailed(format!("create {}: {e}", tmp.display())))?;
        let staged = f
            .write_all(bytes)
            .map_err(|e| RcfError::SerializationFailed(format!("write {}: {e}", tmp.display())))
            .and_then(|()| {
                f.sync_all().map_err(|e| {
                    RcfError::SerializationFailed(format!("fsync {}: {e}", tmp.display()))
                })
            });
        // Close the handle before the staging file is renamed or deleted.
        drop(f);
        let published = staged.and_then(|()| {
            rename(&tmp, path).map_err(|e| {
                RcfError::SerializationFailed(format!(
                    "rename {} -> {}: {e}",
                    tmp.display(),
                    path.display()
                ))
            })
        });
        if published.is_err() {
            // Best effort: the original failure is what the caller needs to see.
            let _ = remove_file(&tmp);
        }
        published
    }

    /// Wraps an I/O failure encountered while loading `path`.
    fn read_failed(path: &Path, err: &std::io::Error) -> RcfError {
        RcfError::DeserializationFailed(format!("read {}: {err}", path.display()))
    }

    /// Builds the error returned when the file at `path` is larger than `max`.
    fn over_cap(path: &Path, max: usize) -> RcfError {
        RcfError::DeserializationFailed(format!(
            "read {}: file exceeds cap {max} byte(s) (OOM guard) — \
             load the contents yourself and use the `*_with_max_size` variant \
             to opt into a larger bound",
            path.display()
        ))
    }

    /// Opens `path` for a read that can never buffer more than `max + 1`
    /// bytes.
    ///
    /// A file whose reported size already exceeds `max` is rejected before
    /// anything is read. The returned reader is additionally limited to
    /// `max + 1` bytes so inputs that under-report their size stay bounded;
    /// callers treat a result longer than `max` as over the cap.
    fn open_capped(path: &Path, max: usize) -> RcfResult<Take<File>> {
        let file = File::open(path).map_err(|e| read_failed(path, &e))?;
        let limit = max as u64;
        // A metadata failure is not fatal here: the bounded read below still applies.
        let reported = file.metadata().map_or(0, |m| m.len());
        if reported > limit {
            return Err(over_cap(path, max));
        }
        Ok(file.take(limit.saturating_add(1)))
    }

    /// Reads the file at `path` into memory, refusing anything larger than
    /// `MAX_DESERIALIZE_BYTES` — the cap `from_bytes` applies to the result.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DeserializationFailed`] on I/O failure or when the
    /// file exceeds the cap.
    #[cfg(feature = "postcard")]
    pub(super) fn read_all(path: &Path) -> RcfResult<Vec<u8>> {
        let max = super::MAX_DESERIALIZE_BYTES;
        let mut buf = Vec::new();
        open_capped(path, max)?
            .read_to_end(&mut buf)
            .map_err(|e| read_failed(path, &e))?;
        if buf.len() > max {
            return Err(over_cap(path, max));
        }
        Ok(buf)
    }

    /// Reads the file at `path` as UTF-8 text, refusing anything larger than
    /// `MAX_JSON_BYTES` — the cap `from_json` applies to the result.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DeserializationFailed`] on I/O failure, invalid
    /// UTF-8, or when the file exceeds the cap.
    #[cfg(feature = "serde_json")]
    pub(super) fn read_all_string(path: &Path) -> RcfResult<String> {
        let max = super::MAX_JSON_BYTES;
        let mut text = String::new();
        open_capped(path, max)?
            .read_to_string(&mut text)
            .map_err(|e| read_failed(path, &e))?;
        if text.len() > max {
            return Err(over_cap(path, max));
        }
        Ok(text)
    }
}
impl<const D: usize> RandomCutForest<D> {
    /// Encodes the forest as a little-endian [`PERSISTENCE_VERSION`] tag
    /// followed by its postcard representation.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when postcard cannot encode
    /// the forest.
    #[cfg(feature = "postcard")]
    pub fn to_bytes(&self) -> RcfResult<Vec<u8>> {
        let mut framed = Vec::with_capacity(VERSION_PREFIX_BYTES + 4096);
        framed.extend_from_slice(&PERSISTENCE_VERSION.to_le_bytes());
        let body = postcard::to_allocvec(self)
            .map_err(|err| RcfError::SerializationFailed(err.to_string()))?;
        framed.extend_from_slice(&body);
        Ok(framed)
    }

    /// Decodes a forest produced by [`Self::to_bytes`], accepting at most
    /// [`MAX_DESERIALIZE_BYTES`] bytes of input.
    ///
    /// # Errors
    ///
    /// Same as [`Self::from_bytes_with_max_size`].
    #[cfg(feature = "postcard")]
    pub fn from_bytes(bytes: &[u8]) -> RcfResult<Self> {
        Self::from_bytes_with_max_size(bytes, MAX_DESERIALIZE_BYTES)
    }

    /// Decodes a forest produced by [`Self::to_bytes`], accepting at most
    /// `max` bytes of input.
    ///
    /// The checks run in order: size cap, version tag, postcard body.
    ///
    /// # Errors
    ///
    /// * [`RcfError::DeserializationFailed`] — input longer than `max`,
    ///   shorter than the version tag, or not a valid postcard body.
    /// * [`RcfError::IncompatibleVersion`] — version tag differs from
    ///   [`PERSISTENCE_VERSION`].
    #[cfg(feature = "postcard")]
    pub fn from_bytes_with_max_size(bytes: &[u8], max: usize) -> RcfResult<Self> {
        enforce_size_cap(bytes.len(), max, "RandomCutForest postcard")?;
        match read_version_prefix(bytes)? {
            PERSISTENCE_VERSION => postcard::from_bytes(&bytes[VERSION_PREFIX_BYTES..])
                .map_err(|err| RcfError::DeserializationFailed(err.to_string())),
            found => Err(RcfError::IncompatibleVersion {
                found,
                expected: PERSISTENCE_VERSION,
            }),
        }
    }

    /// Serializes with [`Self::to_bytes`] and stores the result at `path`
    /// via `atomic::write_atomic`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when encoding or any
    /// filesystem step fails.
    #[cfg(all(feature = "postcard", feature = "std"))]
    pub fn to_path(&self, path: impl AsRef<std::path::Path>) -> RcfResult<()> {
        self.to_bytes()
            .and_then(|encoded| atomic::write_atomic(path.as_ref(), &encoded))
    }

    /// Reads the file at `path` and decodes it with [`Self::from_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DeserializationFailed`] when the file cannot be
    /// read, plus every error of [`Self::from_bytes`].
    #[cfg(all(feature = "postcard", feature = "std"))]
    pub fn from_path(path: impl AsRef<std::path::Path>) -> RcfResult<Self> {
        atomic::read_all(path.as_ref()).and_then(|raw| Self::from_bytes(&raw))
    }

    /// Serializes the forest as a JSON object with a `version` field and a
    /// `forest` field.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when `serde_json` cannot
    /// encode the forest.
    #[cfg(feature = "serde_json")]
    pub fn to_json(&self) -> RcfResult<String> {
        serde_json::to_string(&JsonEnvelope {
            version: PERSISTENCE_VERSION,
            forest: self,
        })
        .map_err(|err| RcfError::SerializationFailed(err.to_string()))
    }

    /// Parses a document produced by [`Self::to_json`], accepting at most
    /// [`MAX_JSON_BYTES`] bytes of input.
    ///
    /// # Errors
    ///
    /// Same as [`Self::from_json_with_max_size`].
    #[cfg(feature = "serde_json")]
    pub fn from_json(json: &str) -> RcfResult<Self> {
        Self::from_json_with_max_size(json, MAX_JSON_BYTES)
    }

    /// Parses a document produced by [`Self::to_json`], accepting at most
    /// `max` bytes of input.
    ///
    /// The whole envelope is parsed before its `version` field is compared,
    /// so a document that does not match the current schema is reported as
    /// a parse failure rather than a version mismatch.
    ///
    /// # Errors
    ///
    /// * [`RcfError::DeserializationFailed`] — input longer than `max` or
    ///   not parseable as the envelope.
    /// * [`RcfError::IncompatibleVersion`] — `version` differs from
    ///   [`PERSISTENCE_VERSION`].
    #[cfg(feature = "serde_json")]
    pub fn from_json_with_max_size(json: &str, max: usize) -> RcfResult<Self> {
        enforce_size_cap(json.len(), max, "RandomCutForest JSON")?;
        let parsed: JsonEnvelopeOwned<D> = serde_json::from_str(json)
            .map_err(|err| RcfError::DeserializationFailed(err.to_string()))?;
        match parsed.version {
            PERSISTENCE_VERSION => Ok(parsed.forest),
            found => Err(RcfError::IncompatibleVersion {
                found,
                expected: PERSISTENCE_VERSION,
            }),
        }
    }

    /// Serializes with [`Self::to_json`] and stores the text at `path` via
    /// `atomic::write_atomic`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when encoding or any
    /// filesystem step fails.
    #[cfg(all(feature = "serde_json", feature = "std"))]
    pub fn to_json_path(&self, path: impl AsRef<std::path::Path>) -> RcfResult<()> {
        self.to_json()
            .and_then(|text| atomic::write_atomic(path.as_ref(), text.as_bytes()))
    }

    /// Reads the file at `path` as text and parses it with
    /// [`Self::from_json`].
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DeserializationFailed`] when the file cannot be
    /// read, plus every error of [`Self::from_json`].
    #[cfg(all(feature = "serde_json", feature = "std"))]
    pub fn from_json_path(path: impl AsRef<std::path::Path>) -> RcfResult<Self> {
        atomic::read_all_string(path.as_ref()).and_then(|text| Self::from_json(&text))
    }
}
// Gated on `serde` to match the import of `ThresholdedForest` at the top of
// this file; without the gate a build that lacks the feature cannot resolve
// the type named here.
#[cfg(feature = "serde")]
impl<const D: usize> ThresholdedForest<D> {
    /// Encodes the detector as a little-endian
    /// [`THRESHOLDED_PERSISTENCE_VERSION`] tag followed by its postcard
    /// representation.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when postcard cannot encode
    /// the detector.
    #[cfg(feature = "postcard")]
    pub fn to_bytes(&self) -> RcfResult<Vec<u8>> {
        let mut out = Vec::with_capacity(VERSION_PREFIX_BYTES + 4096);
        out.extend_from_slice(&THRESHOLDED_PERSISTENCE_VERSION.to_le_bytes());
        let payload = postcard::to_allocvec(self)
            .map_err(|e| RcfError::SerializationFailed(e.to_string()))?;
        out.extend_from_slice(&payload);
        Ok(out)
    }

    /// Decodes a detector produced by [`Self::to_bytes`], accepting at most
    /// [`MAX_DESERIALIZE_BYTES`] bytes of input.
    ///
    /// # Errors
    ///
    /// Same as [`Self::from_bytes_with_max_size`].
    #[cfg(feature = "postcard")]
    pub fn from_bytes(bytes: &[u8]) -> RcfResult<Self> {
        Self::from_bytes_with_max_size(bytes, MAX_DESERIALIZE_BYTES)
    }

    /// Decodes a detector produced by [`Self::to_bytes`], accepting at most
    /// `max` bytes of input.
    ///
    /// The checks run in order: size cap, version tag, postcard body.
    ///
    /// # Errors
    ///
    /// * [`RcfError::DeserializationFailed`] — input longer than `max`,
    ///   shorter than the version tag, or not a valid postcard body.
    /// * [`RcfError::IncompatibleVersion`] — version tag differs from
    ///   [`THRESHOLDED_PERSISTENCE_VERSION`].
    #[cfg(feature = "postcard")]
    pub fn from_bytes_with_max_size(bytes: &[u8], max: usize) -> RcfResult<Self> {
        enforce_size_cap(bytes.len(), max, "ThresholdedForest postcard")?;
        let version = read_version_prefix(bytes)?;
        if version != THRESHOLDED_PERSISTENCE_VERSION {
            return Err(RcfError::IncompatibleVersion {
                found: version,
                expected: THRESHOLDED_PERSISTENCE_VERSION,
            });
        }
        let detector: Self = postcard::from_bytes(&bytes[VERSION_PREFIX_BYTES..])
            .map_err(|e| RcfError::DeserializationFailed(e.to_string()))?;
        Ok(detector)
    }

    /// Serializes with [`Self::to_bytes`] and stores the result at `path`
    /// via `atomic::write_atomic`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when encoding or any
    /// filesystem step fails.
    #[cfg(all(feature = "postcard", feature = "std"))]
    pub fn to_path(&self, path: impl AsRef<std::path::Path>) -> RcfResult<()> {
        let bytes = self.to_bytes()?;
        atomic::write_atomic(path.as_ref(), &bytes)
    }

    /// Reads the file at `path` and decodes it with [`Self::from_bytes`].
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DeserializationFailed`] when the file cannot be
    /// read, plus every error of [`Self::from_bytes`].
    #[cfg(all(feature = "postcard", feature = "std"))]
    pub fn from_path(path: impl AsRef<std::path::Path>) -> RcfResult<Self> {
        let bytes = atomic::read_all(path.as_ref())?;
        Self::from_bytes(&bytes)
    }

    /// Serializes the detector as a JSON object with a `version` field and a
    /// `detector` field.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when `serde_json` cannot
    /// encode the detector.
    #[cfg(feature = "serde_json")]
    pub fn to_json(&self) -> RcfResult<String> {
        let envelope = ThresholdedJsonEnvelope {
            version: THRESHOLDED_PERSISTENCE_VERSION,
            detector: self,
        };
        serde_json::to_string(&envelope).map_err(|e| RcfError::SerializationFailed(e.to_string()))
    }

    /// Parses a document produced by [`Self::to_json`], accepting at most
    /// [`MAX_JSON_BYTES`] bytes of input.
    ///
    /// # Errors
    ///
    /// Same as [`Self::from_json_with_max_size`].
    #[cfg(feature = "serde_json")]
    pub fn from_json(json: &str) -> RcfResult<Self> {
        Self::from_json_with_max_size(json, MAX_JSON_BYTES)
    }

    /// Parses a document produced by [`Self::to_json`], accepting at most
    /// `max` bytes of input.
    ///
    /// The whole envelope is parsed before its `version` field is compared,
    /// so a document that does not match the current schema is reported as
    /// a parse failure rather than a version mismatch.
    ///
    /// # Errors
    ///
    /// * [`RcfError::DeserializationFailed`] — input longer than `max` or
    ///   not parseable as the envelope.
    /// * [`RcfError::IncompatibleVersion`] — `version` differs from
    ///   [`THRESHOLDED_PERSISTENCE_VERSION`].
    #[cfg(feature = "serde_json")]
    pub fn from_json_with_max_size(json: &str, max: usize) -> RcfResult<Self> {
        enforce_size_cap(json.len(), max, "ThresholdedForest JSON")?;
        let envelope: ThresholdedJsonEnvelopeOwned<D> = serde_json::from_str(json)
            .map_err(|e| RcfError::DeserializationFailed(e.to_string()))?;
        if envelope.version != THRESHOLDED_PERSISTENCE_VERSION {
            return Err(RcfError::IncompatibleVersion {
                found: envelope.version,
                expected: THRESHOLDED_PERSISTENCE_VERSION,
            });
        }
        Ok(envelope.detector)
    }

    /// Serializes with [`Self::to_json`] and stores the text at `path` via
    /// `atomic::write_atomic`.
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::SerializationFailed`] when encoding or any
    /// filesystem step fails.
    #[cfg(all(feature = "serde_json", feature = "std"))]
    pub fn to_json_path(&self, path: impl AsRef<std::path::Path>) -> RcfResult<()> {
        let json = self.to_json()?;
        atomic::write_atomic(path.as_ref(), json.as_bytes())
    }

    /// Reads the file at `path` as text and parses it with
    /// [`Self::from_json`].
    ///
    /// # Errors
    ///
    /// Returns [`RcfError::DeserializationFailed`] when the file cannot be
    /// read, plus every error of [`Self::from_json`].
    #[cfg(all(feature = "serde_json", feature = "std"))]
    pub fn from_json_path(path: impl AsRef<std::path::Path>) -> RcfResult<Self> {
        let json = atomic::read_all_string(path.as_ref())?;
        Self::from_json(&json)
    }
}
/// Borrowing wrapper serialized by `RandomCutForest::to_json`: a format
/// version next to a reference to the forest, so no copy is made to write it.
///
/// Field order is the order the keys appear in the emitted JSON.
#[cfg(feature = "serde_json")]
#[derive(serde::Serialize)]
struct JsonEnvelope<'a, const D: usize> {
// Format version; `to_json` sets this to `PERSISTENCE_VERSION`.
version: u32,
// The forest being written.
forest: &'a RandomCutForest<D>,
}
/// Owning counterpart of `JsonEnvelope`, produced when
/// `RandomCutForest::from_json_with_max_size` parses a document.
#[cfg(feature = "serde_json")]
#[derive(serde::Deserialize)]
struct JsonEnvelopeOwned<const D: usize> {
// Version found in the document; the loader compares it with
// `PERSISTENCE_VERSION` after parsing.
version: u32,
// The decoded forest, handed to the caller once the version matches.
forest: RandomCutForest<D>,
}
/// Borrowing wrapper serialized by `ThresholdedForest::to_json`: a format
/// version next to a reference to the detector.
///
/// Field order is the order the keys appear in the emitted JSON.
#[cfg(feature = "serde_json")]
#[derive(serde::Serialize)]
struct ThresholdedJsonEnvelope<'a, const D: usize> {
// Format version; `to_json` sets this to `THRESHOLDED_PERSISTENCE_VERSION`.
version: u32,
// The detector being written.
detector: &'a ThresholdedForest<D>,
}
/// Owning counterpart of `ThresholdedJsonEnvelope`, produced when
/// `ThresholdedForest::from_json_with_max_size` parses a document.
#[cfg(feature = "serde_json")]
#[derive(serde::Deserialize)]
struct ThresholdedJsonEnvelopeOwned<const D: usize> {
// Version found in the document; the loader compares it with
// `THRESHOLDED_PERSISTENCE_VERSION` after parsing.
version: u32,
// The decoded detector, handed to the caller once the version matches.
detector: ThresholdedForest<D>,
}
/// Round-trip and rejection tests for the postcard (`to_bytes` /
/// `from_bytes`) format of `RandomCutForest`.
#[cfg(all(test, feature = "postcard"))]
// Tests compare scores for exact equality and build inputs from loop
// counters, which is intentional here.
#[allow(clippy::float_cmp, clippy::cast_precision_loss, clippy::cast_lossless)]
mod binary_tests {
    use super::*;
    use crate::ForestBuilder;

    /// Builds a 2-D forest (50 trees, sample size 16) with the given seed and
    /// feeds it `updates` points along a line.
    fn trained_forest(seed: u64, updates: usize) -> RandomCutForest<2> {
        let mut f = ForestBuilder::<2>::new()
            .num_trees(50)
            .sample_size(16)
            .seed(seed)
            .build()
            .unwrap();
        for i in 0..updates {
            let v = i as f64 * 0.01;
            f.update([v, v + 0.5]).unwrap();
        }
        f
    }

    /// The payload starts with the little-endian persistence version.
    #[test]
    fn version_prefix_present() {
        let f = trained_forest(2026, 10);
        let bytes = f.to_bytes().unwrap();
        assert!(bytes.len() >= VERSION_PREFIX_BYTES);
        // Slice by the shared constant so the test follows the format
        // definition instead of repeating its width as a literal.
        let mut v = [0_u8; VERSION_PREFIX_BYTES];
        v.copy_from_slice(&bytes[..VERSION_PREFIX_BYTES]);
        assert_eq!(u32::from_le_bytes(v), PERSISTENCE_VERSION);
    }

    /// A forest that has seen no updates keeps its configuration.
    #[test]
    fn empty_forest_roundtrip() {
        let f = ForestBuilder::<4>::new()
            .num_trees(50)
            .sample_size(16)
            .seed(1)
            .build()
            .unwrap();
        let bytes = f.to_bytes().unwrap();
        let back = RandomCutForest::<4>::from_bytes(&bytes).unwrap();
        assert_eq!(back.num_trees(), f.num_trees());
        assert_eq!(back.sample_size(), f.sample_size());
        assert_eq!(back.dimension(), f.dimension());
    }

    /// A restored forest scores a probe point exactly like the original.
    #[test]
    fn trained_forest_score_roundtrip() {
        let f = trained_forest(7, 200);
        let bytes = f.to_bytes().unwrap();
        let back = RandomCutForest::<2>::from_bytes(&bytes).unwrap();
        let probe = [1.5_f64, 2.0];
        let s1: f64 = f.score(&probe).unwrap().into();
        let s2: f64 = back.score(&probe).unwrap().into();
        assert_eq!(s1, s2);
    }

    /// The time-decay setting and resulting scores survive a round trip.
    #[test]
    fn time_decay_roundtrip() {
        let mut f = ForestBuilder::<2>::new()
            .num_trees(50)
            .sample_size(16)
            .time_decay(0.05)
            .seed(11)
            .build()
            .unwrap();
        for i in 0..100 {
            let v = i as f64;
            f.update([v, v]).unwrap();
        }
        let bytes = f.to_bytes().unwrap();
        let back = RandomCutForest::<2>::from_bytes(&bytes).unwrap();
        assert_eq!(f.config().time_decay, back.config().time_decay);
        let probe = [10.0_f64, 10.0];
        assert_eq!(
            f64::from(f.score(&probe).unwrap()),
            f64::from(back.score(&probe).unwrap())
        );
    }

    /// Input shorter than the version prefix is a deserialization error.
    #[test]
    fn truncated_bytes_rejected() {
        let bytes = [0_u8; 2];
        let err = RandomCutForest::<2>::from_bytes(&bytes).unwrap_err();
        assert!(matches!(err, RcfError::DeserializationFailed(_)));
    }

    /// A foreign version tag yields `IncompatibleVersion` with both versions.
    #[test]
    fn version_mismatch_rejected() {
        let f = trained_forest(2026, 5);
        let mut bytes = f.to_bytes().unwrap();
        let bogus_version = (PERSISTENCE_VERSION + 99).to_le_bytes();
        bytes[..VERSION_PREFIX_BYTES].copy_from_slice(&bogus_version);
        let err = RandomCutForest::<2>::from_bytes(&bytes).unwrap_err();
        match err {
            RcfError::IncompatibleVersion { found, expected } => {
                assert_eq!(found, PERSISTENCE_VERSION + 99);
                assert_eq!(expected, PERSISTENCE_VERSION);
            }
            other => panic!("expected IncompatibleVersion, got {other:?}"),
        }
    }

    /// A correct version tag followed by garbage is a deserialization error.
    #[test]
    fn malformed_payload_rejected() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&PERSISTENCE_VERSION.to_le_bytes());
        bytes.extend_from_slice(&[0xFF; 16]);
        let err = RandomCutForest::<2>::from_bytes(&bytes).unwrap_err();
        assert!(matches!(err, RcfError::DeserializationFailed(_)));
    }

    /// One byte over the default cap is rejected by `from_bytes`.
    #[test]
    fn oversize_payload_rejected_by_default_cap() {
        let mut bytes = Vec::with_capacity(MAX_DESERIALIZE_BYTES + 16);
        bytes.extend_from_slice(&PERSISTENCE_VERSION.to_le_bytes());
        bytes.resize(MAX_DESERIALIZE_BYTES + 1, 0xAA);
        let err = RandomCutForest::<2>::from_bytes(&bytes).unwrap_err();
        assert!(matches!(err, RcfError::DeserializationFailed(_)));
    }

    /// A caller-supplied cap above the default is accepted.
    #[test]
    fn from_bytes_with_max_size_accepts_higher_cap() {
        let f = trained_forest(7, 50);
        let bytes = f.to_bytes().unwrap();
        // Pass a bound that really is above the default, as the name promises.
        let higher_cap = MAX_DESERIALIZE_BYTES * 2;
        let back = RandomCutForest::<2>::from_bytes_with_max_size(&bytes, higher_cap).unwrap();
        assert_eq!(back.updates_seen(), f.updates_seen());
    }

    /// A cap one byte below the payload size rejects the payload.
    #[test]
    fn from_bytes_with_max_size_rejects_below_payload_size() {
        let f = trained_forest(7, 50);
        let bytes = f.to_bytes().unwrap();
        let too_tight = bytes.len() - 1;
        let err = RandomCutForest::<2>::from_bytes_with_max_size(&bytes, too_tight).unwrap_err();
        assert!(matches!(err, RcfError::DeserializationFailed(_)));
    }

    /// The update counter is part of the persisted state.
    #[test]
    fn updates_seen_counter_roundtrips() {
        let f = trained_forest(42, 75);
        let before = f.updates_seen();
        let bytes = f.to_bytes().unwrap();
        let back = RandomCutForest::<2>::from_bytes(&bytes).unwrap();
        assert_eq!(back.updates_seen(), before);
    }
}
/// Round-trip and rejection tests for the JSON (`to_json` / `from_json`)
/// format of `RandomCutForest`.
#[cfg(all(test, feature = "serde_json"))]
// Tests compare scores for exact equality and build inputs from loop
// counters, which is intentional here.
#[allow(clippy::float_cmp, clippy::cast_precision_loss, clippy::cast_lossless)]
mod json_tests {
    use super::*;
    use crate::ForestBuilder;

    /// Builds a 2-D forest (50 trees, sample size 8, seed 2026) trained on 30
    /// points.
    fn small_trained() -> RandomCutForest<2> {
        let mut f = ForestBuilder::<2>::new()
            .num_trees(50)
            .sample_size(8)
            .seed(2026)
            .build()
            .unwrap();
        for i in 0..30 {
            let v = i as f64;
            f.update([v, v + 1.0]).unwrap();
        }
        f
    }

    /// A forest restored from JSON scores a probe exactly like the original.
    #[test]
    fn json_roundtrip_preserves_score() {
        let f = small_trained();
        let json = f.to_json().unwrap();
        let back = RandomCutForest::<2>::from_json(&json).unwrap();
        let probe = [3.0_f64, 4.0];
        let s1: f64 = f.score(&probe).unwrap().into();
        let s2: f64 = back.score(&probe).unwrap().into();
        assert_eq!(s1, s2);
    }

    /// The envelope carries `"version":<PERSISTENCE_VERSION>`.
    #[test]
    fn json_envelope_carries_version_field() {
        let f = small_trained();
        let json = f.to_json().unwrap();
        // Match key and value together: a bare `:<version>` would also be
        // satisfied by any other field whose value starts with that digit.
        assert!(json.contains(&format!("\"version\":{PERSISTENCE_VERSION}")));
    }

    /// A rewritten version field yields `IncompatibleVersion`.
    #[test]
    fn json_version_mismatch_rejected() {
        let f = small_trained();
        let json = f.to_json().unwrap();
        let bogus = json.replace(
            &format!("\"version\":{PERSISTENCE_VERSION}"),
            &format!("\"version\":{}", PERSISTENCE_VERSION + 99),
        );
        let err = RandomCutForest::<2>::from_json(&bogus).unwrap_err();
        assert!(matches!(err, RcfError::IncompatibleVersion { .. }));
    }

    /// Text that is not JSON is a deserialization error.
    #[test]
    fn json_malformed_rejected() {
        assert!(matches!(
            RandomCutForest::<2>::from_json("not json").unwrap_err(),
            RcfError::DeserializationFailed(_)
        ));
    }

    /// A document larger than the cap is rejected.
    ///
    /// Despite the name, this uses an explicit cap one byte below the
    /// document size rather than the 1 GiB default, which keeps the test
    /// cheap while exercising the same check.
    #[test]
    fn json_oversize_payload_rejected_by_default_cap() {
        let f = small_trained();
        let json = f.to_json().unwrap();
        let err = RandomCutForest::<2>::from_json_with_max_size(&json, json.len() - 1).unwrap_err();
        assert!(matches!(err, RcfError::DeserializationFailed(_)));
    }

    /// Passing the default cap explicitly behaves like `from_json`.
    #[test]
    fn json_with_max_size_round_trips_at_default_cap() {
        let f = small_trained();
        let json = f.to_json().unwrap();
        let back = RandomCutForest::<2>::from_json_with_max_size(&json, MAX_JSON_BYTES).unwrap();
        let probe = [3.0_f64, 4.0];
        let s1: f64 = f.score(&probe).unwrap().into();
        let s2: f64 = back.score(&probe).unwrap().into();
        assert_eq!(s1, s2);
    }
}