use crate::apr_transformer::{AprTransformer, ForwardTrace};
use crate::error::{RealizarError, Result};
use crate::inference_trace::save_tensor::WHOLE_MODEL_LAYER;
use crate::inference_trace::save_tensor_emit::maybe_save_stage;
use crate::inference_trace::save_tensor_plan::SaveTensorPlan;
use crate::inference_trace::save_tensor_stage::SaveTensorStage;
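
/// Errors surfaced while emitting save-tensor files during a traced forward pass.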
#[derive(Debug, thiserror::Error)]
pub enum SaveTensorEmitError {
#[error("save-tensor: failed to create output dir: {0}")]
CreateDir(std::io::Error),
#[error("save-tensor: failed to write tensor: {0}")]
Write(std::io::Error),
#[error("save-tensor: failed to flush: {0}")]
Flush(std::io::Error),
}
impl AprTransformer {
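    /// Runs a traced forward pass and saves the tensors selected by `plan`.
    ///
    /// Thin convenience wrapper over `forward_traced_with_plan` for callers that
    /// already hold a `SaveTensorPlan`. A minimal usage sketch (the stage name,
    /// layer range, and bindings such as `out_dir`, `transformer`, and
    /// `token_ids` below are illustrative only):
    ///
    /// ```ignore
    /// let plan = SaveTensorPlan::from_cli("lm_head", "0..1", out_dir)?;
    /// let trace = transformer.forward_traced_with_save_tensor(&token_ids, &plan)?;
    /// ```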
    pub fn forward_traced_with_save_tensor(
        &self,
        token_ids: &[u32],
        plan: &SaveTensorPlan,
    ) -> Result<ForwardTrace> {
        self.forward_traced_with_plan(token_ids, Some(plan))
    }
}

#[cfg(test)]
mod traced_save_tensor_step2_tests {
    use super::{SaveTensorPlan, SaveTensorStage, WHOLE_MODEL_LAYER};
    use crate::inference_trace::save_tensor::write_tensor_file;
    use crate::inference_trace::save_tensor_paths::ensure_layer_dir;
    use std::io::{BufWriter, Write};
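
    /// Mirrors the step-2 `LmHead` branch of the traced forward pass: return
    /// `None` when the plan does not select the stage, otherwise write the
    /// logits under the whole-model sentinel layer and return the output path.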
    fn simulate_step2_lm_head_branch(
        plan: &SaveTensorPlan,
        logits: &[f32],
    ) -> std::io::Result<Option<std::path::PathBuf>> {
        if !plan.should_save(SaveTensorStage::LmHead, WHOLE_MODEL_LAYER) {
            return Ok(None);
        }
        let path = plan.stage_path(SaveTensorStage::LmHead, WHOLE_MODEL_LAYER);
        ensure_layer_dir(&plan.output_dir, WHOLE_MODEL_LAYER)?;
        let file = std::fs::File::create(&path)?;
        let mut writer = BufWriter::new(file);
        write_tensor_file(&mut writer, WHOLE_MODEL_LAYER, logits)?;
        writer.flush()?;
        Ok(Some(path))
    }
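
    // Whole-model stages like `lm_head` must land directly in the output root,
    // never inside a per-layer `layer-N` subdirectory.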
    #[test]
    fn step2_lm_head_writes_to_output_root_not_per_layer_dir() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli("lm_head", "0..1", tmp.path().to_path_buf()).unwrap();
        let logits = vec![1.5_f32, 2.5, -3.5, 4.5];
        let written = simulate_step2_lm_head_branch(&plan, &logits)
            .unwrap()
            .expect("LmHead is selected — should write");
        assert_eq!(written, tmp.path().join("lm_head.bin"));
        assert!(written.exists());
        assert!(!tmp.path().join("layer-0").join("lm_head.bin").exists());
    }
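
    // The on-disk header is 12 bytes: the 4-byte magic "APRT", then the layer
    // index as little-endian u32 (the whole-model sentinel here), then the
    // element count as little-endian u32; the f32 payload follows.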
    #[test]
    fn step2_lm_head_header_uses_whole_model_sentinel() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli("lm_head", "0..1", tmp.path().to_path_buf()).unwrap();
        let logits = vec![0.1_f32, 0.2, 0.3, 0.4, 0.5];
        let path = simulate_step2_lm_head_branch(&plan, &logits)
            .unwrap()
            .unwrap();
        let bytes = std::fs::read(&path).unwrap();
        assert!(bytes.len() >= 12, "header is 12 bytes minimum");
        assert_eq!(&bytes[0..4], b"APRT", "magic must be APRT");
        let layer = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
        assert_eq!(
            layer, WHOLE_MODEL_LAYER,
            "whole-model stages must use WHOLE_MODEL_LAYER sentinel"
        );
        let dim = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
        assert_eq!(dim, logits.len() as u32);
        assert_eq!(
            bytes.len(),
            12 + logits.len() * 4,
            "total file = 12-byte header + 4 bytes per f32"
        );
    }
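
    // Selecting a different stage makes the LmHead branch a pure no-op: no file
    // is written.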
    #[test]
    fn step2_lm_head_skipped_when_plan_does_not_select_it() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli("embedding", "0..1", tmp.path().to_path_buf()).unwrap();
        let logits = vec![9.9_f32];
        let written = simulate_step2_lm_head_branch(&plan, &logits).unwrap();
        assert!(
            written.is_none(),
            "LmHead branch must be a no-op when not selected"
        );
        assert!(!tmp.path().join("lm_head.bin").exists());
    }
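
    // The payload must round-trip bit-for-bit, including NaN, negative zero,
    // and infinity.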
    #[test]
    fn step2_lm_head_writes_logits_bytes_verbatim() {
        let tmp = tempfile::tempdir().unwrap();
        let plan = SaveTensorPlan::from_cli("lm_head", "0..1", tmp.path().to_path_buf()).unwrap();
        let logits = vec![1.0_f32, f32::NAN, -0.0, f32::INFINITY];
        let path = simulate_step2_lm_head_branch(&plan, &logits)
            .unwrap()
            .unwrap();
        let bytes = std::fs::read(&path).unwrap();
        let body = &bytes[12..];
        assert_eq!(body.len(), logits.len() * 4);
        for (i, expected) in logits.iter().enumerate() {
            let read = f32::from_le_bytes([
                body[i * 4],
                body[i * 4 + 1],
                body[i * 4 + 2],
                body[i * 4 + 3],
            ]);
            assert_eq!(
                read.to_bits(),
                expected.to_bits(),
                "bit-identical round-trip required at index {i}"
            );
        }
    }
}