Skip to main content

cubek_test_utils/test_tensor/
io.rs

1//! On-disk format for [`HostData`].
2//!
3//! Used by the tuner to compare the output of two metal runs (current commit
4//! vs. trusted reference) without materializing both at the same time. Files
5//! are intended to live in a temp directory — there's no forward-compatibility
6//! story beyond a version byte that lets a reader reject an unknown format.
7//!
8//! Layout (little-endian throughout):
9//!
10//! ```text
11//!   offset  size   field
12//!   ------  ----   -----
13//!     0     4      magic = "CKHD"
14//!     4     1      version (currently 1)
15//!     5     1      dtype tag (0=F32, 1=I32, 2=Bool)
16//!     6     4      rank
17//!    10     8*rank shape
18//!    +     8*rank  strides
//!    +     8       element count (`count`)
//!    +     n       packed element bytes, n = count * 4 (F32/I32) or count (Bool)
21//! ```
22//!
23//! Booleans are written as one byte each (0/1). The element count is the
24//! length of the packed data array (not `shape.product()` — strides may make
25//! the physical extent larger than the logical one).
26use std::fs::File;
27use std::io::{self, BufReader, BufWriter, Read, Write};
28use std::path::Path;
29
30use cubecl::zspace::{Shape, Strides};
31
32use crate::test_tensor::host_data::{HostData, HostDataVec};
33
/// Four-byte file signature stored at offset 0 (ASCII `CKHD`).
const MAGIC: &[u8; 4] = &[b'C', b'K', b'H', b'D'];
/// Format revision stored at offset 4; the reader rejects any other value.
const VERSION: u8 = 1;

// dtype tags stored at offset 5. These values are part of the on-disk format.
/// Tag for `HostDataVec::F32` payloads (4 little-endian bytes per element).
const TAG_F32: u8 = 0;
/// Tag for `HostDataVec::I32` payloads (4 little-endian bytes per element).
const TAG_I32: u8 = 1;
/// Tag for `HostDataVec::Bool` payloads (one byte per element, 0 or 1).
const TAG_BOOL: u8 = 2;
40
41/// Write `data` to `path` in the binary format documented at the module level.
42///
43/// Truncates any existing file. Returns the number of bytes written so callers
44/// can surface a "wrote N MiB" log line.
45pub fn write_host_data(path: &Path, data: &HostData) -> io::Result<u64> {
46    let f = File::create(path)?;
47    let mut w = BufWriter::new(f);
48
49    w.write_all(MAGIC)?;
50    w.write_all(&[VERSION])?;
51
52    let (tag, elem_count) = match &data.data {
53        HostDataVec::F32(v) => (TAG_F32, v.len()),
54        HostDataVec::I32(v) => (TAG_I32, v.len()),
55        HostDataVec::Bool(v) => (TAG_BOOL, v.len()),
56    };
57    w.write_all(&[tag])?;
58
59    let rank = data.shape.as_slice().len();
60    w.write_all(&(rank as u32).to_le_bytes())?;
61    for d in data.shape.as_slice() {
62        w.write_all(&(*d as u64).to_le_bytes())?;
63    }
64    let strides_slice: &[usize] = &data.strides;
65    if strides_slice.len() != rank {
66        return Err(io::Error::new(
67            io::ErrorKind::InvalidInput,
68            format!(
69                "strides rank {} != shape rank {}",
70                strides_slice.len(),
71                rank,
72            ),
73        ));
74    }
75    for s in strides_slice {
76        w.write_all(&(*s as u64).to_le_bytes())?;
77    }
78    w.write_all(&(elem_count as u64).to_le_bytes())?;
79
80    match &data.data {
81        HostDataVec::F32(v) => w.write_all(bytemuck::cast_slice(v))?,
82        HostDataVec::I32(v) => w.write_all(bytemuck::cast_slice(v))?,
83        HostDataVec::Bool(v) => {
84            // One byte per bool — keeps reads alignment-free and rare enough
85            // not to be worth bit-packing.
86            for b in v {
87                w.write_all(&[u8::from(*b)])?;
88            }
89        }
90    }
91
92    w.flush()?;
93    Ok(w.into_inner()
94        .map_err(|e| e.into_error())?
95        .metadata()?
96        .len())
97}
98
99/// Read a [`HostData`] previously produced by [`write_host_data`].
100///
101/// Errors with `InvalidData` for any header/version/tag mismatch — these
102/// usually mean the file came from a different cubek version and should be
103/// regenerated.
104pub fn read_host_data(path: &Path) -> io::Result<HostData> {
105    let f = File::open(path)?;
106    let mut r = BufReader::new(f);
107
108    let mut magic = [0u8; 4];
109    r.read_exact(&mut magic)?;
110    if &magic != MAGIC {
111        return Err(invalid("wrong magic — file is not a HostData blob"));
112    }
113    let version = read_u8(&mut r)?;
114    if version != VERSION {
115        return Err(invalid(format!(
116            "unsupported HostData file version: {version} (expected {VERSION})"
117        )));
118    }
119    let tag = read_u8(&mut r)?;
120    let rank = read_u32(&mut r)? as usize;
121
122    let mut shape_dims = Vec::with_capacity(rank);
123    for _ in 0..rank {
124        shape_dims.push(read_u64(&mut r)? as usize);
125    }
126    let mut stride_dims = Vec::with_capacity(rank);
127    for _ in 0..rank {
128        stride_dims.push(read_u64(&mut r)? as usize);
129    }
130    let elem_count = read_u64(&mut r)? as usize;
131
132    let data = match tag {
133        TAG_F32 => {
134            let mut buf = vec![0u8; elem_count * std::mem::size_of::<f32>()];
135            r.read_exact(&mut buf)?;
136            // Guaranteed-aligned re-cast: build the Vec<f32> from the byte
137            // chunks rather than transmuting the buffer in place.
138            let mut v = Vec::with_capacity(elem_count);
139            for chunk in buf.chunks_exact(4) {
140                v.push(f32::from_le_bytes(chunk.try_into().unwrap()));
141            }
142            HostDataVec::F32(v)
143        }
144        TAG_I32 => {
145            let mut buf = vec![0u8; elem_count * std::mem::size_of::<i32>()];
146            r.read_exact(&mut buf)?;
147            let mut v = Vec::with_capacity(elem_count);
148            for chunk in buf.chunks_exact(4) {
149                v.push(i32::from_le_bytes(chunk.try_into().unwrap()));
150            }
151            HostDataVec::I32(v)
152        }
153        TAG_BOOL => {
154            let mut buf = vec![0u8; elem_count];
155            r.read_exact(&mut buf)?;
156            HostDataVec::Bool(buf.into_iter().map(|b| b != 0).collect())
157        }
158        other => return Err(invalid(format!("unknown HostData dtype tag: {other}"))),
159    };
160
161    Ok(HostData {
162        data,
163        shape: Shape::from(shape_dims),
164        strides: Strides::new(&stride_dims),
165    })
166}
167
/// Read exactly `N` bytes from `r`, failing with `UnexpectedEof` on a short read.
fn read_array<R: Read, const N: usize>(r: &mut R) -> io::Result<[u8; N]> {
    let mut bytes = [0u8; N];
    r.read_exact(&mut bytes)?;
    Ok(bytes)
}

/// Read a single byte.
fn read_u8<R: Read>(r: &mut R) -> io::Result<u8> {
    read_array(r).map(u8::from_le_bytes)
}

/// Read a little-endian `u32`.
fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
    read_array(r).map(u32::from_le_bytes)
}

/// Read a little-endian `u64`.
fn read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
    read_array(r).map(u64::from_le_bytes)
}
184}
185
/// Build an `InvalidData` I/O error whose description is `msg`.
fn invalid<E: Into<String>>(msg: E) -> io::Error {
    let description: String = msg.into();
    io::Error::new(io::ErrorKind::InvalidData, description)
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process scratch path for `label`, so concurrent test binaries don't collide.
    fn scratch_path(label: &str) -> std::path::PathBuf {
        let dir =
            std::env::temp_dir().join(format!("cubek-test-utils-iotest-{}", std::process::id(),));
        std::fs::create_dir_all(&dir).unwrap();
        dir.join(format!("blob-{label}.bin"))
    }

    /// Write `data`, read it back, and assert shape, strides, and payload survive.
    fn round_trip(label: &str, data: HostData) {
        let path = scratch_path(label);
        write_host_data(&path, &data).unwrap();
        let read_back = read_host_data(&path).unwrap();
        assert_eq!(data.shape, read_back.shape);
        assert_eq!(data.strides, read_back.strides);
        match (&data.data, &read_back.data) {
            (HostDataVec::F32(a), HostDataVec::F32(b)) => assert_eq!(a, b),
            (HostDataVec::I32(a), HostDataVec::I32(b)) => assert_eq!(a, b),
            (HostDataVec::Bool(a), HostDataVec::Bool(b)) => assert_eq!(a, b),
            _ => panic!("dtype mismatch on round-trip"),
        }
        let _ = std::fs::remove_file(&path);
    }

    /// Raw blob: magic, `version`, `tag`, `rank`, then `rest` verbatim.
    fn blob(version: u8, tag: u8, rank: u32, rest: &[u8]) -> Vec<u8> {
        let mut out = MAGIC.to_vec();
        out.push(version);
        out.push(tag);
        out.extend_from_slice(&rank.to_le_bytes());
        out.extend_from_slice(rest);
        out
    }

    /// Write `bytes` verbatim and return the error `read_host_data` reports for them.
    fn read_error(label: &str, bytes: &[u8]) -> io::Error {
        let path = scratch_path(label);
        std::fs::write(&path, bytes).unwrap();
        let err = read_host_data(&path)
            .err()
            .expect("malformed blob must not parse");
        let _ = std::fs::remove_file(&path);
        err
    }

    #[test]
    fn round_trip_f32() {
        round_trip(
            "f32",
            HostData {
                data: HostDataVec::F32(vec![1.0, -2.0, std::f32::consts::PI, 0.5, 0.0]),
                shape: Shape::from(vec![5]),
                strides: Strides::new(&[1]),
            },
        );
    }

    #[test]
    fn round_trip_i32_2d() {
        round_trip(
            "i32",
            HostData {
                data: HostDataVec::I32(vec![1, 2, 3, 4, 5, 6]),
                shape: Shape::from(vec![2, 3]),
                strides: Strides::new(&[3, 1]),
            },
        );
    }

    #[test]
    fn round_trip_bool() {
        round_trip(
            "bool",
            HostData {
                data: HostDataVec::Bool(vec![true, false, true, true, false]),
                shape: Shape::from(vec![5]),
                strides: Strides::new(&[1]),
            },
        );
    }

    #[test]
    fn round_trip_strided_f32() {
        // The physical extent (6 elements) exceeds the logical size (2x2 = 4).
        // The stored element count must be the packed length, not shape.product().
        round_trip(
            "f32-strided",
            HostData {
                data: HostDataVec::F32(vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0]),
                shape: Shape::from(vec![2, 2]),
                strides: Strides::new(&[4, 1]),
            },
        );
    }

    #[test]
    fn rejects_wrong_magic() {
        let mut bytes = blob(VERSION, TAG_F32, 0, &0u64.to_le_bytes());
        bytes[..4].copy_from_slice(b"NOPE");
        assert_eq!(
            read_error("bad-magic", &bytes).kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn rejects_unknown_version() {
        let bytes = blob(VERSION + 1, TAG_F32, 0, &0u64.to_le_bytes());
        assert_eq!(
            read_error("bad-version", &bytes).kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn rejects_unknown_dtype_tag() {
        let bytes = blob(VERSION, 9, 0, &0u64.to_le_bytes());
        assert_eq!(
            read_error("bad-tag", &bytes).kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn rejects_truncated_payload() {
        // Rank 0, three f32 elements promised, only one present.
        let mut rest = 3u64.to_le_bytes().to_vec();
        rest.extend_from_slice(&1.0f32.to_le_bytes());
        let bytes = blob(VERSION, TAG_F32, 0, &rest);
        assert_eq!(
            read_error("truncated", &bytes).kind(),
            io::ErrorKind::UnexpectedEof
        );
    }
}