use crate::{Result, TreeBoostError};
use fs4::fs_std::FileExt;
use serde::{Deserialize, Serialize};
use std::fs::{File, OpenOptions};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::Path;
#[cfg(feature = "mmap")]
use memmap2::Mmap;
/// Magic bytes at offset 0 identifying a TRB file (revision tag "TRB1").
pub const TRB_MAGIC: &[u8; 4] = b"TRB1";
/// Format revision recorded in the JSON file header.
pub const FORMAT_VERSION: u32 = 1;
/// Blob offsets are padded to this alignment; `archived_model` relies on
/// the base blob starting on an 8-byte boundary for rkyv access.
const RKYV_ALIGNMENT: usize = 8;
/// File-level header, stored as length-prefixed JSON immediately after
/// the 4-byte magic. Describes the base model blob that follows.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrbHeader {
    /// On-disk format revision; see [`FORMAT_VERSION`].
    pub format_version: u32,
    /// Free-form model kind tag (e.g. "universal" in the tests).
    pub model_type: String,
    /// Creation time, seconds since the Unix epoch.
    pub created_at: u64,
    /// Free-form boosting mode tag (e.g. "PureTree" in the tests).
    pub boosting_mode: String,
    /// Number of input features the stored model expects.
    pub num_features: usize,
    /// Byte length of the base blob; overwritten by `TrbWriter::new`.
    pub base_blob_size: u64,
    /// Optional free-form metadata; `#[serde(default)]` keeps older
    /// files without this field readable.
    #[serde(default)]
    pub metadata: String,
}
/// Kind of payload carried by an appended update segment.
/// Serialized in snake_case (e.g. "trees") inside the update header JSON.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum UpdateType {
    Linear,
    Trees,
    Preprocessor,
    Snapshot,
}
/// Per-update header, stored as zero-padded length-prefixed JSON at the
/// start of each appended update segment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrbUpdateHeader {
    /// What kind of payload the segment's blob contains.
    pub update_type: UpdateType,
    /// Creation time, seconds since the Unix epoch.
    pub created_at: u64,
    /// Number of rows this incremental update was trained on.
    pub rows_trained: usize,
    /// Optional free-form description; defaults to empty when absent.
    #[serde(default)]
    pub description: String,
}
/// A decoded segment as returned by `TrbReader::load_all_segments`:
/// either the initial base model or one appended update.
#[derive(Debug)]
pub enum TrbSegment {
    /// The base segment: the file header plus the base model blob.
    Base { header: TrbHeader, blob: Vec<u8> },
    /// One appended update segment: its header plus payload blob.
    Update {
        header: TrbUpdateHeader,
        blob: Vec<u8>,
    },
}
/// Writer that creates TRB files and appends update segments while
/// holding an exclusive file lock (released in `Drop`).
pub struct TrbWriter {
    // Exclusively-locked handle; unlocked on drop.
    file: File,
    // Header written by `new`, or re-parsed by `open_for_append`.
    header: TrbHeader,
}
impl TrbWriter {
    /// Creates a fresh TRB file at `path` containing the base segment:
    /// magic bytes, length-prefixed JSON header, zero padding up to an
    /// 8-byte boundary, the base blob, and a CRC32 of the blob.
    ///
    /// `header.base_blob_size` is overwritten with `base_blob.len()`.
    ///
    /// The exclusive lock is acquired BEFORE the file is truncated, so a
    /// failed lock (another process holds the file) no longer destroys
    /// existing contents the way the previous `truncate(true)`-at-open
    /// sequence did.
    pub fn new(path: impl AsRef<Path>, mut header: TrbHeader, base_blob: &[u8]) -> Result<Self> {
        let mut file = OpenOptions::new()
            .write(true)
            .create(true)
            .open(path.as_ref())?;
        file.try_lock_exclusive().map_err(|e| {
            TreeBoostError::Serialization(format!("Failed to acquire file lock: {}", e))
        })?;
        // Safe to discard prior contents now that the lock is held.
        file.set_len(0)?;
        file.seek(SeekFrom::Start(0))?;
        header.base_blob_size = base_blob.len() as u64;
        file.write_all(TRB_MAGIC)?;
        let header_json = serde_json::to_vec(&header).map_err(|e| {
            TreeBoostError::Serialization(format!("Failed to serialize header: {}", e))
        })?;
        file.write_all(&(header_json.len() as u64).to_le_bytes())?;
        file.write_all(&header_json)?;
        // Pad so the base blob starts 8-byte aligned (rkyv requirement).
        let current_pos = 4 + 8 + header_json.len();
        let padding = alignment_padding(current_pos);
        if padding > 0 {
            file.write_all(&vec![0u8; padding])?;
        }
        file.write_all(base_blob)?;
        let crc = crc32fast::hash(base_blob);
        file.write_all(&crc.to_le_bytes())?;
        file.flush()?;
        Ok(Self { file, header })
    }

    /// Appends one update segment at EOF with the layout
    /// `[total_size: u64][header_size: u64][header JSON + zero padding]
    /// [blob][crc32: u32]`, where `total_size` counts everything after
    /// the `total_size` field itself and `header_size` is the padded
    /// JSON length.
    ///
    /// NOTE(review): the padding is computed from the segment-relative
    /// offset, so the update blob is only 8-byte aligned if the segment
    /// itself starts on an 8-byte boundary — confirm whether update
    /// blobs need rkyv alignment.
    pub fn append_update(
        &mut self,
        update_header: &TrbUpdateHeader,
        update_blob: &[u8],
    ) -> Result<()> {
        self.file.seek(SeekFrom::End(0))?;
        let header_json = serde_json::to_vec(update_header).map_err(|e| {
            TreeBoostError::Serialization(format!("Failed to serialize update header: {}", e))
        })?;
        // 8 (total_size) + 8 (header_size) + JSON, padded to 8 bytes.
        let header_section_size = 8 + 8 + header_json.len();
        let padding = alignment_padding(header_section_size);
        let padded_header_json_len = header_json.len() + padding;
        let total_size = 8 + padded_header_json_len + update_blob.len() + 4;
        self.file.write_all(&(total_size as u64).to_le_bytes())?;
        self.file
            .write_all(&(padded_header_json_len as u64).to_le_bytes())?;
        self.file.write_all(&header_json)?;
        if padding > 0 {
            self.file.write_all(&vec![0u8; padding])?;
        }
        self.file.write_all(update_blob)?;
        let crc = crc32fast::hash(update_blob);
        self.file.write_all(&crc.to_le_bytes())?;
        self.file.flush()?;
        Ok(())
    }

    /// The header written at creation (or parsed by `open_for_append`).
    pub fn header(&self) -> &TrbHeader {
        &self.header
    }
}
impl Drop for TrbWriter {
    fn drop(&mut self) {
        // Best-effort release of the exclusive lock; the error is
        // ignored because drop cannot propagate it, and the OS releases
        // the lock when the handle closes anyway.
        let _ = self.file.unlock();
    }
}
/// Buffered (non-mmap) TRB reader holding a shared file lock for its
/// lifetime (released in `Drop`).
pub struct TrbReader {
    // Shared-locked read handle.
    file: File,
    // Parsed copy of the file header.
    header: TrbHeader,
    // Absolute byte offset of the (8-byte-aligned) base blob.
    base_blob_offset: u64,
}
impl TrbReader {
    /// Opens a TRB file read-only under a shared lock, validates the
    /// magic bytes, parses the JSON header and records where the
    /// (8-byte-aligned) base blob starts.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let mut file = File::open(path.as_ref())?;
        file.try_lock_shared().map_err(|e| {
            TreeBoostError::Serialization(format!("Failed to acquire file lock: {}", e))
        })?;
        let mut magic = [0u8; 4];
        file.read_exact(&mut magic)?;
        if &magic != TRB_MAGIC {
            return Err(TreeBoostError::Serialization(format!(
                "Invalid TRB magic: expected {:?}, got {:?}",
                TRB_MAGIC, magic
            )));
        }
        let mut header_size_bytes = [0u8; 8];
        file.read_exact(&mut header_size_bytes)?;
        let header_size = u64::from_le_bytes(header_size_bytes) as usize;
        let mut header_json = vec![0u8; header_size];
        file.read_exact(&mut header_json)?;
        let header: TrbHeader = serde_json::from_slice(&header_json)
            .map_err(|e| TreeBoostError::Serialization(format!("Failed to parse header: {}", e)))?;
        // The base blob begins after magic + length field + JSON, padded
        // to an 8-byte boundary (mirrors TrbWriter::new).
        let current_pos = 4 + 8 + header_size;
        let padding = alignment_padding(current_pos);
        let base_blob_offset = (current_pos + padding) as u64;
        Ok(Self {
            file,
            header,
            base_blob_offset,
        })
    }

    /// The parsed file header.
    pub fn header(&self) -> &TrbHeader {
        &self.header
    }

    /// Reads the base blob into memory and verifies its trailing CRC32.
    pub fn read_base_blob(&mut self) -> Result<Vec<u8>> {
        let blob_size = self.header.base_blob_size as usize;
        self.file.seek(SeekFrom::Start(self.base_blob_offset))?;
        let mut blob = vec![0u8; blob_size];
        self.file.read_exact(&mut blob)?;
        let mut crc_bytes = [0u8; 4];
        self.file.read_exact(&mut crc_bytes)?;
        let stored_crc = u32::from_le_bytes(crc_bytes);
        let computed_crc = crc32fast::hash(&blob);
        if stored_crc != computed_crc {
            return Err(TreeBoostError::Serialization(format!(
                "Base blob CRC mismatch: stored={:#x}, computed={:#x}",
                stored_crc, computed_crc
            )));
        }
        Ok(blob)
    }

    /// Scans the update segments that follow the base blob.
    ///
    /// Truncated, size-corrupt or CRC-corrupt segments terminate the
    /// scan with an eprintln warning rather than an error, so a file
    /// damaged mid-append still yields everything written before the
    /// damage. A segment whose header JSON fails to parse is skipped
    /// individually, since its declared size still lets us find the
    /// next segment.
    pub fn iter_updates(&mut self) -> Result<Vec<(TrbUpdateHeader, Vec<u8>)>> {
        let mut updates = Vec::new();
        // The first update starts right after the base blob + its CRC.
        let mut pos = self.base_blob_offset + self.header.base_blob_size + 4;
        self.file.seek(SeekFrom::Start(pos))?;
        let file_size = self.file.seek(SeekFrom::End(0))?;
        self.file.seek(SeekFrom::Start(pos))?;
        let mut segment_index = 0;
        while pos < file_size {
            if pos + 8 > file_size {
                eprintln!(
                    "Warning: Incomplete update segment {} at offset {} (truncated total_size)",
                    segment_index, pos
                );
                break;
            }
            let mut total_size_bytes = [0u8; 8];
            if self.file.read_exact(&mut total_size_bytes).is_err() {
                eprintln!(
                    "Warning: Failed to read update segment {} at offset {}",
                    segment_index, pos
                );
                break;
            }
            let total_size = u64::from_le_bytes(total_size_bytes) as usize;
            // Compare as `total_size > remaining` (pos + 8 <= file_size is
            // guaranteed above) so a corrupt, huge total_size cannot
            // overflow the addition.
            if total_size as u64 > file_size - pos - 8 {
                eprintln!(
                    "Warning: Incomplete update segment {} at offset {} (expected {} bytes, have {})",
                    segment_index, pos, total_size, file_size - pos - 8
                );
                break;
            }
            let mut header_size_bytes = [0u8; 8];
            self.file.read_exact(&mut header_size_bytes)?;
            let header_size = u64::from_le_bytes(header_size_bytes) as usize;
            // A segment must at least hold its header-size field (8) and
            // CRC (4), and the declared header must fit inside it.
            // Without this check corrupt sizes underflow the `blob_size`
            // subtraction below (usize panic) and a corrupt header_size
            // could trigger a huge allocation.
            if total_size < 8 + 4 || header_size > total_size - 8 - 4 {
                eprintln!(
                    "Warning: Corrupt update segment {} at offset {} (header size {} does not fit in segment size {})",
                    segment_index, pos, header_size, total_size
                );
                break;
            }
            let mut header_json = vec![0u8; header_size];
            self.file.read_exact(&mut header_json)?;
            // The header JSON is zero-padded for alignment; trim the
            // trailing NULs before parsing.
            let json_end = header_json
                .iter()
                .rposition(|&b| b != 0)
                .map(|i| i + 1)
                .unwrap_or(0);
            let header_json_trimmed = &header_json[..json_end];
            let update_header: TrbUpdateHeader = match serde_json::from_slice(header_json_trimmed) {
                Ok(h) => h,
                Err(e) => {
                    eprintln!(
                        "Warning: Failed to parse update header at segment {}: {}",
                        segment_index, e
                    );
                    // Skip just this segment; later ones may be fine.
                    pos += 8 + total_size as u64;
                    self.file.seek(SeekFrom::Start(pos))?;
                    segment_index += 1;
                    continue;
                }
            };
            // Safe: validated above that total_size >= 8 + header_size + 4.
            let blob_size = total_size - 8 - header_size - 4;
            let mut blob = vec![0u8; blob_size];
            self.file.read_exact(&mut blob)?;
            let mut crc_bytes = [0u8; 4];
            self.file.read_exact(&mut crc_bytes)?;
            let stored_crc = u32::from_le_bytes(crc_bytes);
            let computed_crc = crc32fast::hash(&blob);
            if stored_crc != computed_crc {
                eprintln!(
                    "Warning: Update segment {} CRC mismatch (stored={:#x}, computed={:#x})",
                    segment_index, stored_crc, computed_crc
                );
                break;
            }
            updates.push((update_header, blob));
            pos += 8 + total_size as u64;
            segment_index += 1;
        }
        Ok(updates)
    }

    /// Loads the base segment plus every valid update segment, in file
    /// order.
    pub fn load_all_segments(&mut self) -> Result<Vec<TrbSegment>> {
        let mut segments = Vec::new();
        let base_blob = self.read_base_blob()?;
        segments.push(TrbSegment::Base {
            header: self.header.clone(),
            blob: base_blob,
        });
        for (update_header, blob) in self.iter_updates()? {
            segments.push(TrbSegment::Update {
                header: update_header,
                blob,
            });
        }
        Ok(segments)
    }
}
impl Drop for TrbReader {
    fn drop(&mut self) {
        // Best-effort release of the shared lock taken in `open`; the
        // error is ignored because drop cannot propagate it.
        let _ = self.file.unlock();
    }
}
/// Reopens an existing TRB file so update segments can be appended.
///
/// Takes the exclusive lock, verifies the magic bytes and re-parses the
/// JSON header so the returned writer carries the metadata the file was
/// created with. `TrbWriter::append_update` seeks to EOF itself, so the
/// cursor position on return does not matter.
pub fn open_for_append(path: impl AsRef<Path>) -> Result<TrbWriter> {
    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .open(path.as_ref())?;
    file.try_lock_exclusive().map_err(|e| {
        TreeBoostError::Serialization(format!("Failed to acquire file lock: {}", e))
    })?;
    let mut magic_buf = [0u8; 4];
    file.read_exact(&mut magic_buf)?;
    if &magic_buf != TRB_MAGIC {
        return Err(TreeBoostError::Serialization(format!(
            "Invalid TRB magic: expected {:?}, got {:?}",
            TRB_MAGIC, magic_buf
        )));
    }
    let mut len_buf = [0u8; 8];
    file.read_exact(&mut len_buf)?;
    let mut header_buf = vec![0u8; u64::from_le_bytes(len_buf) as usize];
    file.read_exact(&mut header_buf)?;
    let header = serde_json::from_slice::<TrbHeader>(&header_buf)
        .map_err(|e| TreeBoostError::Serialization(format!("Failed to parse header: {}", e)))?;
    Ok(TrbWriter { file, header })
}
/// Number of zero bytes needed to advance `current_pos` to the next
/// multiple of `RKYV_ALIGNMENT` (zero when already aligned).
fn alignment_padding(current_pos: usize) -> usize {
    // `(A - pos % A) % A` maps aligned positions to 0 and everything
    // else to the distance to the next multiple of A.
    (RKYV_ALIGNMENT - current_pos % RKYV_ALIGNMENT) % RKYV_ALIGNMENT
}
/// Zero-copy TRB reader backed by a memory map; holds a shared file
/// lock for the lifetime of the mapping.
#[cfg(feature = "mmap")]
pub struct MmapTrbReader {
    // Read-only mapping of the whole file.
    mmap: Mmap,
    // Parsed copy of the file header.
    header: TrbHeader,
    // Absolute byte offset of the (8-byte-aligned) base blob.
    base_blob_offset: usize,
    // Kept so the mapping and the shared lock stay valid; declared
    // after `mmap` so the map is dropped before the file handle.
    _file: File,
}
#[cfg(feature = "mmap")]
impl std::fmt::Debug for MmapTrbReader {
    // Manual impl that reports the mapped size rather than dumping the
    // mapping's raw bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MmapTrbReader")
            .field("header", &self.header)
            .field("base_blob_offset", &self.base_blob_offset)
            .field("mapped_size", &self.mmap.len())
            .finish()
    }
}
#[cfg(feature = "mmap")]
impl MmapTrbReader {
    /// Memory-maps a TRB file under a shared lock, validates the magic
    /// and parses the JSON header, bounds-checking every read against
    /// the mapped size.
    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
        let file = File::open(path.as_ref())?;
        file.try_lock_shared().map_err(|e| {
            TreeBoostError::Serialization(format!("Failed to acquire file lock: {}", e))
        })?;
        // SAFETY: mapping is unsound if another process mutates the file
        // concurrently; the shared lock reduces (but cannot eliminate)
        // that risk — standard mmap caveat accepted by callers.
        let mmap = unsafe {
            Mmap::map(&file).map_err(|e| {
                TreeBoostError::Serialization(format!("Failed to memory map file: {}", e))
            })?
        };
        if mmap.len() < 4 {
            return Err(TreeBoostError::Serialization(
                "File too small for TRB format".to_string(),
            ));
        }
        if &mmap[0..4] != TRB_MAGIC {
            return Err(TreeBoostError::Serialization(format!(
                "Invalid TRB magic: expected {:?}, got {:?}",
                TRB_MAGIC,
                &mmap[0..4]
            )));
        }
        if mmap.len() < 12 {
            return Err(TreeBoostError::Serialization(
                "File too small for header size".to_string(),
            ));
        }
        let header_size =
            u64::from_le_bytes(mmap[4..12].try_into().expect("bounds checked above")) as usize;
        // Written as `len - 12 < header_size` (len >= 12 is guaranteed
        // above) so a corrupt, huge header_size cannot overflow a
        // `12 + header_size` sum.
        if mmap.len() - 12 < header_size {
            return Err(TreeBoostError::Serialization(
                "File too small for header".to_string(),
            ));
        }
        let header: TrbHeader = serde_json::from_slice(&mmap[12..12 + header_size])
            .map_err(|e| TreeBoostError::Serialization(format!("Failed to parse header: {}", e)))?;
        // Base blob starts after magic + length field + JSON, padded to
        // an 8-byte boundary (mirrors TrbWriter::new).
        let current_pos = 4 + 8 + header_size;
        let padding = alignment_padding(current_pos);
        let base_blob_offset = current_pos + padding;
        Ok(Self {
            mmap,
            header,
            base_blob_offset,
            _file: file,
        })
    }

    /// The parsed file header.
    pub fn header(&self) -> &TrbHeader {
        &self.header
    }

    /// Returns the base blob as a zero-copy slice into the mapping after
    /// verifying its trailing CRC32.
    pub fn base_blob_bytes(&self) -> Result<&[u8]> {
        let blob_size = self.header.base_blob_size as usize;
        let blob_end = self.base_blob_offset + blob_size;
        let crc_end = blob_end + 4;
        if self.mmap.len() < crc_end {
            return Err(TreeBoostError::Serialization(
                "File truncated: missing base blob or CRC".to_string(),
            ));
        }
        let blob = &self.mmap[self.base_blob_offset..blob_end];
        let stored_crc = u32::from_le_bytes(
            self.mmap[blob_end..crc_end]
                .try_into()
                .expect("crc_end bounds checked above"),
        );
        let computed_crc = crc32fast::hash(blob);
        if stored_crc != computed_crc {
            return Err(TreeBoostError::Serialization(format!(
                "Base blob CRC mismatch: stored={:#x}, computed={:#x}",
                stored_crc, computed_crc
            )));
        }
        Ok(blob)
    }

    /// Zero-copy access to the archived base model.
    ///
    /// Skips rkyv validation; integrity rests on the CRC check inside
    /// `base_blob_bytes` and on the writer producing a valid archive.
    pub fn archived_model(&self) -> Result<&rkyv::Archived<crate::model::UniversalModel>> {
        let blob = self.base_blob_bytes()?;
        debug_assert_eq!(
            blob.as_ptr() as usize % 8,
            0,
            "rkyv blob must be 8-byte aligned"
        );
        // SAFETY: the blob was produced by our own writer and its CRC
        // has been verified; `access_unchecked` trusts that it is a
        // well-formed archive.
        let archived =
            unsafe { rkyv::access_unchecked::<rkyv::Archived<crate::model::UniversalModel>>(blob) };
        Ok(archived)
    }

    /// Fully deserializes the base model into an owned value.
    pub fn load_model(&self) -> Result<crate::model::UniversalModel> {
        use rkyv::rancor::Error as RkyvError;
        let blob = self.base_blob_bytes()?;
        let model: crate::model::UniversalModel =
            rkyv::from_bytes::<_, RkyvError>(blob).map_err(|e| {
                TreeBoostError::Serialization(format!("Failed to deserialize model: {}", e))
            })?;
        Ok(model)
    }

    /// Total number of mapped bytes.
    pub fn mapped_size(&self) -> usize {
        self.mmap.len()
    }

    /// Absolute offset of the base blob inside the mapping.
    pub fn base_blob_offset(&self) -> usize {
        self.base_blob_offset
    }

    /// Scans update segments in the mapping, returning zero-copy blob
    /// slices. Like `TrbReader::iter_updates`, damage terminates the
    /// scan with a warning instead of an error.
    pub fn iter_updates(&self) -> Result<Vec<(TrbUpdateHeader, &[u8])>> {
        let mut updates = Vec::new();
        // The first update starts right after the base blob + its CRC.
        let mut pos = self.base_blob_offset + self.header.base_blob_size as usize + 4;
        let file_size = self.mmap.len();
        let mut segment_index = 0;
        while pos < file_size {
            if pos + 8 > file_size {
                eprintln!(
                    "Warning: Incomplete update segment {} at offset {} (truncated total_size)",
                    segment_index, pos
                );
                break;
            }
            let total_size = u64::from_le_bytes(
                self.mmap[pos..pos + 8]
                    .try_into()
                    .expect("pos+8 bounds checked above"),
            ) as usize;
            // Compared as `total_size > remaining` (pos + 8 <= file_size
            // holds above) so a corrupt total_size cannot overflow.
            if total_size > file_size - pos - 8 {
                eprintln!(
                    "Warning: Incomplete update segment {} at offset {} (expected {} bytes, have {})",
                    segment_index, pos, total_size, file_size - pos - 8
                );
                break;
            }
            // A segment must at least hold the header-size field (8) and
            // CRC (4); otherwise the slices below go out of bounds.
            if total_size < 8 + 4 {
                eprintln!(
                    "Warning: Corrupt update segment {} at offset {} (total_size {} too small)",
                    segment_index, pos, total_size
                );
                break;
            }
            let header_size = u64::from_le_bytes(
                self.mmap[pos + 8..pos + 16]
                    .try_into()
                    .expect("pos+16 within total_size bounds"),
            ) as usize;
            // The declared header must fit in the segment; without this
            // check the header slice panics and the blob_size
            // subtraction below underflows on corrupt data.
            if header_size > total_size - 8 - 4 {
                eprintln!(
                    "Warning: Corrupt update segment {} at offset {} (header size {} does not fit in segment size {})",
                    segment_index, pos, header_size, total_size
                );
                break;
            }
            let header_start = pos + 16;
            let header_bytes = &self.mmap[header_start..header_start + header_size];
            // Trim the zero padding appended for 8-byte alignment.
            let json_end = header_bytes
                .iter()
                .rposition(|&b| b != 0)
                .map(|i| i + 1)
                .unwrap_or(0);
            let header_json_trimmed = &header_bytes[..json_end];
            let update_header: TrbUpdateHeader = match serde_json::from_slice(header_json_trimmed) {
                Ok(h) => h,
                Err(e) => {
                    eprintln!(
                        "Warning: Failed to parse update header at segment {}: {}",
                        segment_index, e
                    );
                    // Skip just this segment; later ones may be fine.
                    pos += 8 + total_size;
                    segment_index += 1;
                    continue;
                }
            };
            let blob_start = header_start + header_size;
            // Safe: validated above that total_size >= 8 + header_size + 4.
            let blob_size = total_size - 8 - header_size - 4;
            let blob_end = blob_start + blob_size;
            let crc_end = blob_end + 4;
            let blob = &self.mmap[blob_start..blob_end];
            let stored_crc = u32::from_le_bytes(
                self.mmap[blob_end..crc_end]
                    .try_into()
                    .expect("crc_end within total_size bounds"),
            );
            let computed_crc = crc32fast::hash(blob);
            if stored_crc != computed_crc {
                eprintln!(
                    "Warning: Update segment {} CRC mismatch (stored={:#x}, computed={:#x})",
                    segment_index, stored_crc, computed_crc
                );
                break;
            }
            updates.push((update_header, blob));
            pos += 8 + total_size;
            segment_index += 1;
        }
        Ok(updates)
    }
}
#[cfg(feature = "mmap")]
impl Drop for MmapTrbReader {
    fn drop(&mut self) {
        // Explicitly release the shared lock, mirroring TrbReader and
        // TrbWriter (previously this Drop was empty and relied solely on
        // the OS releasing the lock at handle close). Field declaration
        // order guarantees `mmap` is dropped before `_file`.
        let _ = self._file.unlock();
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::{SystemTime, UNIX_EPOCH};
use tempfile::tempdir;
/// Current wall-clock time as whole seconds since the Unix epoch.
fn current_timestamp() -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
// Builds a TrbHeader with fixed test values; base_blob_size is left 0
// because TrbWriter::new overwrites it with the real blob length.
fn create_test_header() -> TrbHeader {
    TrbHeader {
        format_version: FORMAT_VERSION,
        model_type: "universal".to_string(),
        created_at: current_timestamp(),
        boosting_mode: "PureTree".to_string(),
        num_features: 10,
        base_blob_size: 0,
        metadata: "Test model".to_string(),
    }
}
#[test]
fn test_trb_write_and_read_base() {
    // Round-trip: write a base segment, then read it back and verify
    // header fields and blob contents survive intact.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header.clone(), &base_blob).unwrap();
    drop(writer); // releases the exclusive lock so the reader can open
    let mut reader = TrbReader::open(&path).unwrap();
    assert_eq!(reader.header().format_version, FORMAT_VERSION);
    assert_eq!(reader.header().model_type, "universal");
    assert_eq!(reader.header().num_features, 10);
    // `TrbWriter::new` fills in base_blob_size from the blob it was given.
    assert_eq!(reader.header().base_blob_size, base_blob.len() as u64);
    let loaded_blob = reader.read_base_blob().unwrap();
    assert_eq!(loaded_blob, base_blob);
}
#[test]
fn test_trb_append_update() {
    // Appending an update after the base segment must grow the file and
    // yield exactly [Base, Update] from load_all_segments.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer); // release the exclusive lock before reopening
    let initial_size = std::fs::metadata(&path).unwrap().len();
    let mut writer = open_for_append(&path).unwrap();
    let update_header = TrbUpdateHeader {
        update_type: UpdateType::Trees,
        created_at: current_timestamp(),
        rows_trained: 500,
        description: "Update 1".to_string(),
    };
    let update_blob = vec![10u8, 20, 30, 40];
    writer.append_update(&update_header, &update_blob).unwrap();
    drop(writer);
    // The appended segment must have grown the file.
    let new_size = std::fs::metadata(&path).unwrap().len();
    assert!(new_size > initial_size);
    let mut reader = TrbReader::open(&path).unwrap();
    let segments = reader.load_all_segments().unwrap();
    assert_eq!(segments.len(), 2);
    assert!(
        matches!(&segments[0], TrbSegment::Base { .. }),
        "Expected base segment at index 0"
    );
    if let TrbSegment::Base { blob, .. } = &segments[0] {
        assert_eq!(blob, &base_blob);
    }
    assert!(
        matches!(&segments[1], TrbSegment::Update { .. }),
        "Expected update segment at index 1"
    );
    if let TrbSegment::Update { header, blob } = &segments[1] {
        assert_eq!(header.update_type, UpdateType::Trees);
        assert_eq!(header.rows_trained, 500);
        assert_eq!(blob, &update_blob);
    }
}
#[test]
fn test_trb_corrupt_recovery() {
    // Simulates a crash mid-append: the file is truncated inside the
    // last update segment. The base blob must still load and the
    // truncated update must be skipped, not reported as an error.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    let mut writer = open_for_append(&path).unwrap();
    let update_header = TrbUpdateHeader {
        update_type: UpdateType::Trees,
        created_at: current_timestamp(),
        rows_trained: 500,
        description: "Update 1".to_string(),
    };
    let update_blob = vec![10u8, 20, 30, 40, 50, 60, 70, 80];
    writer.append_update(&update_header, &update_blob).unwrap();
    drop(writer);
    // Chop 10 bytes off the end, leaving a partial update segment.
    let file_size = std::fs::metadata(&path).unwrap().len();
    let file = OpenOptions::new().write(true).open(&path).unwrap();
    file.set_len(file_size - 10).unwrap();
    drop(file);
    let mut reader = TrbReader::open(&path).unwrap();
    let base = reader.read_base_blob().unwrap();
    assert_eq!(base, base_blob);
    let updates = reader.iter_updates().unwrap();
    assert!(updates.is_empty(), "Truncated update should be ignored");
}
#[test]
fn test_trb_crc_detects_corruption() {
    // Flipping one byte inside the base blob must make read_base_blob
    // fail its CRC check.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .open(&path)
        .unwrap();
    // Recompute the blob offset the same way the reader does:
    // magic (4) + length field (8) + header JSON + alignment padding.
    file.seek(SeekFrom::Start(4)).unwrap();
    let mut header_size_bytes = [0u8; 8];
    file.read_exact(&mut header_size_bytes).unwrap();
    let header_size = u64::from_le_bytes(header_size_bytes) as usize;
    let current_pos = 4 + 8 + header_size;
    let padding = alignment_padding(current_pos);
    let blob_offset = current_pos + padding;
    // Invert the first blob byte in place.
    file.seek(SeekFrom::Start(blob_offset as u64)).unwrap();
    let mut byte = [0u8; 1];
    file.read_exact(&mut byte).unwrap();
    byte[0] ^= 0xFF;
    file.seek(SeekFrom::Start(blob_offset as u64)).unwrap();
    file.write_all(&byte).unwrap();
    drop(file);
    let mut reader = TrbReader::open(&path).unwrap();
    let result = reader.read_base_blob();
    assert!(result.is_err());
    let err = result.unwrap_err();
    assert!(err.to_string().contains("CRC mismatch"));
}
#[test]
fn test_trb_multiple_updates() {
    // Five appends, each through a fresh open_for_append, must all be
    // recovered in order with their headers and blobs intact.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    for i in 0..5 {
        let mut writer = open_for_append(&path).unwrap();
        let update_header = TrbUpdateHeader {
            update_type: UpdateType::Trees,
            created_at: current_timestamp(),
            rows_trained: (i + 1) * 100,
            description: format!("Update {}", i + 1),
        };
        let update_blob = vec![(i + 10) as u8; 8];
        writer.append_update(&update_header, &update_blob).unwrap();
        drop(writer);
    }
    let mut reader = TrbReader::open(&path).unwrap();
    let segments = reader.load_all_segments().unwrap();
    assert_eq!(segments.len(), 6); // base + 5 updates
    // Skip index 0 (the base segment); update i sits at index i + 1.
    for (i, segment) in segments.iter().enumerate().skip(1) {
        assert!(
            matches!(segment, TrbSegment::Update { .. }),
            "Expected update segment at index {}",
            i
        );
        if let TrbSegment::Update { header, blob } = segment {
            assert_eq!(header.rows_trained, i * 100);
            assert_eq!(blob, &vec![(i + 9) as u8; 8]);
        }
    }
}
#[test]
fn test_trb_update_crc_validation() {
    // Flipping a byte inside the FIRST update's blob must make
    // iter_updates stop at that segment, dropping it and everything
    // after it (the chain breaks at the first CRC failure).
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    for i in 0..2 {
        let mut writer = open_for_append(&path).unwrap();
        let update_header = TrbUpdateHeader {
            update_type: UpdateType::Trees,
            created_at: current_timestamp(),
            rows_trained: (i + 1) * 100,
            description: format!("U{}", i + 1),
        };
        let update_blob = vec![(i + 10) as u8; 128];
        writer.append_update(&update_header, &update_blob).unwrap();
        drop(writer);
    }
    // Locate the first update segment: right after base blob + its CRC.
    let reader = TrbReader::open(&path).unwrap();
    let base_end = reader.base_blob_offset + reader.header.base_blob_size + 4;
    drop(reader);
    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .open(&path)
        .unwrap();
    file.seek(SeekFrom::Start(base_end)).unwrap();
    let mut total_size_bytes = [0u8; 8];
    file.read_exact(&mut total_size_bytes).unwrap();
    let mut header_size_bytes = [0u8; 8];
    file.read_exact(&mut header_size_bytes).unwrap();
    let header_size = u64::from_le_bytes(header_size_bytes) as u64;
    let blob_start = base_end + 8 + 8 + header_size;
    // Corrupt a byte in the middle of the first update's 128-byte blob.
    let corrupt_offset = blob_start + 64;
    file.seek(SeekFrom::Start(corrupt_offset)).unwrap();
    let mut byte = [0u8; 1];
    file.read_exact(&mut byte).unwrap();
    byte[0] ^= 0xFF;
    file.seek(SeekFrom::Start(corrupt_offset)).unwrap();
    file.write_all(&byte).unwrap();
    drop(file);
    let mut reader = TrbReader::open(&path).unwrap();
    let base = reader.read_base_blob().unwrap();
    assert_eq!(base, base_blob);
    let updates = reader.iter_updates().unwrap();
    assert!(
        updates.is_empty(),
        "Corrupted update should break the chain"
    );
}
#[test]
fn test_trb_unknown_json_fields_ignored() {
    // Hand-writes a TRB file whose header JSON carries extra fields a
    // future format version might add; serde must ignore them.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let mut file = File::create(&path).unwrap();
    file.write_all(TRB_MAGIC).unwrap();
    let header_json = serde_json::json!({
        "format_version": FORMAT_VERSION,
        "model_type": "universal",
        "created_at": current_timestamp(),
        "boosting_mode": "PureTree",
        "num_features": 10,
        "base_blob_size": base_blob.len(),
        "metadata": "Test",
        "future_field": "some_value",
        "another_future": 42
    });
    let header_bytes = serde_json::to_vec(&header_json).unwrap();
    file.write_all(&(header_bytes.len() as u64).to_le_bytes())
        .unwrap();
    file.write_all(&header_bytes).unwrap();
    // Mirror the writer's padding so the blob lands 8-byte aligned.
    let current_pos = 4 + 8 + header_bytes.len();
    let padding = alignment_padding(current_pos);
    if padding > 0 {
        file.write_all(&vec![0u8; padding]).unwrap();
    }
    file.write_all(&base_blob).unwrap();
    let crc = crc32fast::hash(&base_blob);
    file.write_all(&crc.to_le_bytes()).unwrap();
    drop(file);
    let mut reader = TrbReader::open(&path).unwrap();
    assert_eq!(reader.header().num_features, 10);
    let blob = reader.read_base_blob().unwrap();
    assert_eq!(blob, base_blob);
}
#[test]
fn test_trb_rkyv_alignment() {
    // Uses deliberately short header strings so the JSON length is
    // unlikely to land on an 8-byte boundary by accident; the computed
    // base blob offset must still come out aligned.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let header = TrbHeader {
        format_version: FORMAT_VERSION,
        model_type: "u".to_string(),
        created_at: 12345,
        boosting_mode: "P".to_string(),
        num_features: 1,
        base_blob_size: 0,
        metadata: "".to_string(),
    };
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    let mut reader = TrbReader::open(&path).unwrap();
    assert_eq!(
        reader.base_blob_offset % 8,
        0,
        "Base blob should be 8-byte aligned"
    );
    let blob = reader.read_base_blob().unwrap();
    assert_eq!(blob, base_blob);
}
#[cfg(feature = "mmap")]
mod mmap_tests {
use super::*;
#[test]
fn test_mmap_reader_basic() {
    // The mmap reader must expose the same header fields and base blob
    // the file was written with, at an 8-byte-aligned offset.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header.clone(), &base_blob).unwrap();
    drop(writer); // release the exclusive lock
    let reader = MmapTrbReader::open(&path).unwrap();
    assert_eq!(reader.header().format_version, FORMAT_VERSION);
    assert_eq!(reader.header().model_type, "universal");
    assert_eq!(reader.header().num_features, 10);
    let blob = reader.base_blob_bytes().unwrap();
    assert_eq!(blob, base_blob.as_slice());
    assert!(reader.mapped_size() > 0);
    assert_eq!(reader.base_blob_offset() % 8, 0);
}
#[test]
fn test_mmap_reader_crc_validation() {
    // Same corruption scenario as test_trb_crc_detects_corruption, but
    // through the mmap path: a flipped blob byte must fail the CRC.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .open(&path)
        .unwrap();
    // Recompute the blob offset the same way the readers do.
    file.seek(SeekFrom::Start(4)).unwrap();
    let mut header_size_bytes = [0u8; 8];
    file.read_exact(&mut header_size_bytes).unwrap();
    let header_size = u64::from_le_bytes(header_size_bytes) as usize;
    let current_pos = 4 + 8 + header_size;
    let padding = alignment_padding(current_pos);
    let blob_offset = current_pos + padding;
    // Invert the first blob byte in place.
    file.seek(SeekFrom::Start(blob_offset as u64)).unwrap();
    let mut byte = [0u8; 1];
    file.read_exact(&mut byte).unwrap();
    byte[0] ^= 0xFF;
    file.seek(SeekFrom::Start(blob_offset as u64)).unwrap();
    file.write_all(&byte).unwrap();
    drop(file);
    let reader = MmapTrbReader::open(&path).unwrap();
    let result = reader.base_blob_bytes();
    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("CRC mismatch"));
}
#[test]
fn test_mmap_reader_with_updates() {
    // Three appended updates must be recovered in order as zero-copy
    // slices via the mmap reader.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    for i in 0..3 {
        let mut writer = open_for_append(&path).unwrap();
        let update_header = TrbUpdateHeader {
            update_type: UpdateType::Trees,
            created_at: current_timestamp(),
            rows_trained: (i + 1) * 100,
            description: format!("Update {}", i + 1),
        };
        let update_blob = vec![(i + 10) as u8; 16];
        writer.append_update(&update_header, &update_blob).unwrap();
        drop(writer);
    }
    let reader = MmapTrbReader::open(&path).unwrap();
    let blob = reader.base_blob_bytes().unwrap();
    assert_eq!(blob, base_blob.as_slice());
    let updates = reader.iter_updates().unwrap();
    assert_eq!(updates.len(), 3);
    for (i, (header, blob)) in updates.iter().enumerate() {
        assert_eq!(header.update_type, UpdateType::Trees);
        assert_eq!(header.rows_trained, (i + 1) * 100);
        assert_eq!(blob.len(), 16);
        assert_eq!(blob[0], (i + 10) as u8);
    }
}
#[test]
fn test_mmap_reader_invalid_magic() {
    // A file that does not start with the TRB magic must be rejected.
    let dir = tempdir().unwrap();
    let path = dir.path().join("bad.trb");
    std::fs::write(&path, b"BAD!invalid data").unwrap();
    let result = MmapTrbReader::open(&path);
    assert!(result.is_err());
    assert!(result
        .unwrap_err()
        .to_string()
        .contains("Invalid TRB magic"));
}
#[test]
fn test_mmap_reader_truncated_file() {
    // A file holding only the 4 magic bytes (no header length) must be
    // rejected as too small, not panic on an out-of-range read.
    let dir = tempdir().unwrap();
    let path = dir.path().join("truncated.trb");
    std::fs::write(&path, TRB_MAGIC).unwrap();
    let result = MmapTrbReader::open(&path);
    assert!(result.is_err());
    assert!(result.unwrap_err().to_string().contains("too small"));
}
#[test]
fn test_mmap_vs_standard_reader_equivalence() {
    // Both reader implementations must decode the exact same base blob
    // and update list from the same file.
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.trb");
    let base_blob = vec![1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let header = create_test_header();
    let writer = TrbWriter::new(&path, header, &base_blob).unwrap();
    drop(writer);
    let mut writer = open_for_append(&path).unwrap();
    let update_header = TrbUpdateHeader {
        update_type: UpdateType::Trees,
        created_at: current_timestamp(),
        rows_trained: 500,
        description: "Test update".to_string(),
    };
    let update_blob = vec![20u8; 32];
    writer.append_update(&update_header, &update_blob).unwrap();
    drop(writer);
    // Both readers can hold their shared locks simultaneously.
    let mut std_reader = TrbReader::open(&path).unwrap();
    let std_base = std_reader.read_base_blob().unwrap();
    let std_updates = std_reader.iter_updates().unwrap();
    let mmap_reader = MmapTrbReader::open(&path).unwrap();
    let mmap_base = mmap_reader.base_blob_bytes().unwrap();
    let mmap_updates = mmap_reader.iter_updates().unwrap();
    assert_eq!(std_base, mmap_base);
    assert_eq!(std_updates.len(), mmap_updates.len());
    for ((std_hdr, std_blob), (mmap_hdr, mmap_blob)) in
        std_updates.iter().zip(mmap_updates.iter())
    {
        assert_eq!(std_hdr.update_type, mmap_hdr.update_type);
        assert_eq!(std_hdr.rows_trained, mmap_hdr.rows_trained);
        assert_eq!(std_blob.as_slice(), *mmap_blob);
    }
}
}
}