use std::io::{BufWriter, Write};
use std::path::Path;
use crate::inference_trace::save_tensor::{write_tensor_file, WHOLE_MODEL_LAYER};
use crate::inference_trace::save_tensor_paths::ensure_layer_dir;
use crate::inference_trace::save_tensor_plan::SaveTensorPlan;
use crate::inference_trace::save_tensor_stage::SaveTensorStage;
/// Saves `values` for `stage`/`layer` if the optional plan selects them.
///
/// A `None` plan, or a stage/layer combination the plan does not select,
/// is a silent no-op returning `Ok(())`. Otherwise the tensor is written
/// via [`write_stage_file`] under `plan.output_dir`.
pub fn maybe_save_stage(
    plan: Option<&SaveTensorPlan>,
    stage: SaveTensorStage,
    layer: u32,
    values: &[f32],
) -> std::io::Result<()> {
    match plan {
        Some(selected) if selected.should_save(stage, layer) => {
            write_stage_file(&selected.output_dir, stage, layer, values)
        }
        // No plan, or this stage/layer is filtered out: nothing to do.
        _ => Ok(()),
    }
}
/// Writes `values` as a tensor file for `stage` under `output_dir`.
///
/// Per-layer stages land in a `layer-<N>` subdirectory; whole-model stages
/// use the `WHOLE_MODEL_LAYER` sentinel for both the directory choice and
/// the file header, regardless of the caller-supplied `layer`. Any missing
/// directories are created, and the buffered writer is flushed explicitly
/// so write errors surface instead of being swallowed on drop.
pub fn write_stage_file(
    output_dir: &Path,
    stage: SaveTensorStage,
    layer: u32,
    values: &[f32],
) -> std::io::Result<()> {
    // Whole-model stages ignore the caller's layer and use the sentinel.
    let effective_layer = if stage.is_per_layer() {
        layer
    } else {
        WHOLE_MODEL_LAYER
    };
    ensure_layer_dir(output_dir, effective_layer)?;
    let target = output_path_for(output_dir, stage, layer);
    let mut writer = BufWriter::new(std::fs::File::create(&target)?);
    write_tensor_file(&mut writer, effective_layer, values)?;
    writer.flush()
}
/// Resolves the destination path for a stage's tensor file.
///
/// Mirrors `write_stage_file`'s layer handling: whole-model stages are
/// addressed with the `WHOLE_MODEL_LAYER` sentinel instead of the
/// caller-supplied `layer`.
fn output_path_for(
    output_dir: &Path,
    stage: SaveTensorStage,
    layer: u32,
) -> std::path::PathBuf {
    use crate::inference_trace::save_tensor_paths::output_path;
    let effective_layer = match stage.is_per_layer() {
        true => layer,
        false => WHOLE_MODEL_LAYER,
    };
    output_path(output_dir, effective_layer, stage.canonical_name())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // Reads a written tensor file back as raw bytes; panics if missing.
    fn read_back(path: &Path) -> Vec<u8> {
        std::fs::read(path).expect("file exists")
    }

    // Decodes the 12-byte file header: b"APRT" magic, then little-endian
    // u32 layer id and little-endian u32 element count. Returns (layer, dim).
    fn parse_header_bytes(bytes: &[u8]) -> (u32, u32) {
        assert!(bytes.len() >= 12, "header is 12 bytes");
        assert_eq!(&bytes[0..4], b"APRT", "magic must be APRT");
        let layer = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
        let dim = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
        (layer, dim)
    }

    // A `None` plan must short-circuit before touching the filesystem.
    #[test]
    fn maybe_save_with_none_plan_is_noop() {
        let tmp = tempfile::tempdir().unwrap();
        let values = vec![1.0_f32, 2.0, 3.0];
        let r = maybe_save_stage(None, SaveTensorStage::Embedding, 0, &values);
        assert!(r.is_ok());
        let entries: Vec<_> = std::fs::read_dir(tmp.path()).unwrap().collect();
        assert!(entries.is_empty(), "None plan must not create any files");
    }

    // A stage the plan did not select (plan asks for "embedding", call is
    // FfnGate) must produce no output file.
    #[test]
    fn maybe_save_with_unselected_stage_is_noop() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli(
            "embedding",
            "0..1",
            tmp.path().to_path_buf(),
        )
        .unwrap();
        let values = vec![9.9_f32];
        let r = maybe_save_stage(Some(&plan), SaveTensorStage::FfnGate, 0, &values);
        assert!(r.is_ok());
        assert!(!tmp.path().join("layer-0").join("ffn_gate.bin").exists());
    }

    // Per-layer stages write to `layer-<N>/<stage>.bin` and record the real
    // layer number plus the element count in the header.
    #[test]
    fn maybe_save_per_layer_writes_to_layer_subdir() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli(
            "embedding",
            "0..1",
            tmp.path().to_path_buf(),
        )
        .unwrap();
        let values = vec![1.5_f32, 2.5, 3.5, -4.0];
        maybe_save_stage(Some(&plan), SaveTensorStage::Embedding, 0, &values).unwrap();
        let path = tmp.path().join("layer-0").join("embedding.bin");
        assert!(path.exists(), "per-layer stage lands in layer-N/<stage>.bin");
        let bytes = read_back(&path);
        let (layer, dim) = parse_header_bytes(&bytes);
        assert_eq!(layer, 0);
        assert_eq!(dim, values.len() as u32);
    }

    // Whole-model stages ignore the caller's `layer` argument entirely:
    // the file goes to the output root, no `layer-42` directory appears,
    // and the header carries the WHOLE_MODEL_LAYER sentinel.
    #[test]
    fn maybe_save_whole_model_writes_to_root_with_sentinel() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli(
            "lm_head",
            "0..1",
            tmp.path().to_path_buf(),
        )
        .unwrap();
        let values = vec![0.1_f32, 0.2, 0.3];
        maybe_save_stage(Some(&plan), SaveTensorStage::LmHead, 42, &values).unwrap();
        let path = tmp.path().join("lm_head.bin");
        assert!(path.exists(), "whole-model stage lands at <root>/<stage>.bin");
        assert!(!tmp.path().join("layer-42").exists());
        let bytes = read_back(&path);
        let (layer, dim) = parse_header_bytes(&bytes);
        assert_eq!(
            layer, WHOLE_MODEL_LAYER,
            "whole-model header carries the 0xFFFFFFFF sentinel regardless of caller's `layer` arg"
        );
        assert_eq!(dim, values.len() as u32);
    }

    // The plan's layer range filter ("0..2") must drop layer 5 and accept
    // layer 1.
    #[test]
    fn maybe_save_layer_filter_excludes_out_of_range() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli(
            "ffn_gate",
            "0..2",
            tmp.path().to_path_buf(),
        )
        .unwrap();
        let values = vec![1.0_f32];
        maybe_save_stage(Some(&plan), SaveTensorStage::FfnGate, 5, &values).unwrap();
        assert!(!tmp.path().join("layer-5").join("ffn_gate.bin").exists());
        maybe_save_stage(Some(&plan), SaveTensorStage::FfnGate, 1, &values).unwrap();
        assert!(tmp.path().join("layer-1").join("ffn_gate.bin").exists());
    }

    // The f32 payload must round-trip bit-identically, including NaN,
    // negative zero, and infinity — so we compare via `to_bits()`, not `==`
    // (NaN != NaN, and -0.0 == 0.0 under float comparison).
    #[test]
    fn maybe_save_writes_bytes_verbatim_including_nans() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli(
            "embedding",
            "0..1",
            tmp.path().to_path_buf(),
        )
        .unwrap();
        let values = vec![1.0_f32, f32::NAN, -0.0, f32::INFINITY];
        maybe_save_stage(Some(&plan), SaveTensorStage::Embedding, 0, &values).unwrap();
        let bytes = read_back(&tmp.path().join("layer-0").join("embedding.bin"));
        // Body starts after the 12-byte header; 4 little-endian bytes per f32.
        let body = &bytes[12..];
        assert_eq!(body.len(), values.len() * 4);
        for (i, expected) in values.iter().enumerate() {
            let read = f32::from_le_bytes([
                body[i * 4],
                body[i * 4 + 1],
                body[i * 4 + 2],
                body[i * 4 + 3],
            ]);
            assert_eq!(
                read.to_bits(),
                expected.to_bits(),
                "bit-identical round-trip required at index {i}"
            );
        }
    }

    // `write_stage_file` must create the full missing directory chain
    // (a/b/c/layer-7), not just the leaf.
    #[test]
    fn write_stage_file_creates_missing_parent_dirs() {
        let tmp = tempfile::tempdir().unwrap();
        let nested: PathBuf = tmp.path().join("a").join("b").join("c");
        let values = vec![1.0_f32];
        write_stage_file(&nested, SaveTensorStage::Embedding, 7, &values).unwrap();
        assert!(nested.join("layer-7").join("embedding.bin").exists());
    }
}