cubek_test_utils/test_tensor/
io.rs1use std::fs::File;
27use std::io::{self, BufReader, BufWriter, Read, Write};
28use std::path::Path;
29
30use cubecl::zspace::{Shape, Strides};
31
32use crate::test_tensor::host_data::{HostData, HostDataVec};
33
/// Four-byte signature written at the very start of every HostData blob.
const MAGIC: &[u8; 4] = b"CKHD";
/// Format version byte stored right after the signature; readers reject any other value.
const VERSION: u8 = 1;

/// Dtype tag for an `f32` payload.
const TAG_F32: u8 = 0;
/// Dtype tag for an `i32` payload.
const TAG_I32: u8 = 1;
/// Dtype tag for a `bool` payload (stored as one byte per element).
const TAG_BOOL: u8 = 2;
40
41pub fn write_host_data(path: &Path, data: &HostData) -> io::Result<u64> {
46 let f = File::create(path)?;
47 let mut w = BufWriter::new(f);
48
49 w.write_all(MAGIC)?;
50 w.write_all(&[VERSION])?;
51
52 let (tag, elem_count) = match &data.data {
53 HostDataVec::F32(v) => (TAG_F32, v.len()),
54 HostDataVec::I32(v) => (TAG_I32, v.len()),
55 HostDataVec::Bool(v) => (TAG_BOOL, v.len()),
56 };
57 w.write_all(&[tag])?;
58
59 let rank = data.shape.as_slice().len();
60 w.write_all(&(rank as u32).to_le_bytes())?;
61 for d in data.shape.as_slice() {
62 w.write_all(&(*d as u64).to_le_bytes())?;
63 }
64 let strides_slice: &[usize] = &data.strides;
65 if strides_slice.len() != rank {
66 return Err(io::Error::new(
67 io::ErrorKind::InvalidInput,
68 format!(
69 "strides rank {} != shape rank {}",
70 strides_slice.len(),
71 rank,
72 ),
73 ));
74 }
75 for s in strides_slice {
76 w.write_all(&(*s as u64).to_le_bytes())?;
77 }
78 w.write_all(&(elem_count as u64).to_le_bytes())?;
79
80 match &data.data {
81 HostDataVec::F32(v) => w.write_all(bytemuck::cast_slice(v))?,
82 HostDataVec::I32(v) => w.write_all(bytemuck::cast_slice(v))?,
83 HostDataVec::Bool(v) => {
84 for b in v {
87 w.write_all(&[u8::from(*b)])?;
88 }
89 }
90 }
91
92 w.flush()?;
93 Ok(w.into_inner()
94 .map_err(|e| e.into_error())?
95 .metadata()?
96 .len())
97}
98
99pub fn read_host_data(path: &Path) -> io::Result<HostData> {
105 let f = File::open(path)?;
106 let mut r = BufReader::new(f);
107
108 let mut magic = [0u8; 4];
109 r.read_exact(&mut magic)?;
110 if &magic != MAGIC {
111 return Err(invalid("wrong magic — file is not a HostData blob"));
112 }
113 let version = read_u8(&mut r)?;
114 if version != VERSION {
115 return Err(invalid(format!(
116 "unsupported HostData file version: {version} (expected {VERSION})"
117 )));
118 }
119 let tag = read_u8(&mut r)?;
120 let rank = read_u32(&mut r)? as usize;
121
122 let mut shape_dims = Vec::with_capacity(rank);
123 for _ in 0..rank {
124 shape_dims.push(read_u64(&mut r)? as usize);
125 }
126 let mut stride_dims = Vec::with_capacity(rank);
127 for _ in 0..rank {
128 stride_dims.push(read_u64(&mut r)? as usize);
129 }
130 let elem_count = read_u64(&mut r)? as usize;
131
132 let data = match tag {
133 TAG_F32 => {
134 let mut buf = vec![0u8; elem_count * std::mem::size_of::<f32>()];
135 r.read_exact(&mut buf)?;
136 let mut v = Vec::with_capacity(elem_count);
139 for chunk in buf.chunks_exact(4) {
140 v.push(f32::from_le_bytes(chunk.try_into().unwrap()));
141 }
142 HostDataVec::F32(v)
143 }
144 TAG_I32 => {
145 let mut buf = vec![0u8; elem_count * std::mem::size_of::<i32>()];
146 r.read_exact(&mut buf)?;
147 let mut v = Vec::with_capacity(elem_count);
148 for chunk in buf.chunks_exact(4) {
149 v.push(i32::from_le_bytes(chunk.try_into().unwrap()));
150 }
151 HostDataVec::I32(v)
152 }
153 TAG_BOOL => {
154 let mut buf = vec![0u8; elem_count];
155 r.read_exact(&mut buf)?;
156 HostDataVec::Bool(buf.into_iter().map(|b| b != 0).collect())
157 }
158 other => return Err(invalid(format!("unknown HostData dtype tag: {other}"))),
159 };
160
161 Ok(HostData {
162 data,
163 shape: Shape::from(shape_dims),
164 strides: Strides::new(&stride_dims),
165 })
166}
167
/// Pulls exactly one byte from `r`; fails with the reader's error (e.g.
/// `UnexpectedEof`) when no byte is available.
fn read_u8<R: Read>(r: &mut R) -> io::Result<u8> {
    let mut byte = 0u8;
    r.read_exact(std::slice::from_mut(&mut byte))?;
    Ok(byte)
}
173
/// Pulls four bytes from `r` and decodes them as a little-endian `u32`.
fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
    let mut raw = [0u8; std::mem::size_of::<u32>()];
    r.read_exact(&mut raw).map(|()| u32::from_le_bytes(raw))
}
179
/// Pulls eight bytes from `r` and decodes them as a little-endian `u64`.
fn read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    r.read_exact(&mut raw).map(|()| u64::from_le_bytes(raw))
}
185
/// Wraps `msg` in an [`io::Error`] of kind `InvalidData`.
fn invalid<E: Into<String>>(msg: E) -> io::Error {
    let text: String = msg.into();
    io::Error::new(io::ErrorKind::InvalidData, text)
}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes `data` to a per-process scratch file, reads it back and asserts
    /// that shape, strides and element values are unchanged. `label` keeps the
    /// file names of the individual tests apart.
    fn round_trip(label: &str, data: HostData) {
        let scratch_dir = std::env::temp_dir()
            .join(format!("cubek-test-utils-iotest-{}", std::process::id()));
        std::fs::create_dir_all(&scratch_dir).unwrap();
        let blob_path = scratch_dir.join(format!("blob-{label}.bin"));

        write_host_data(&blob_path, &data).unwrap();
        let restored = read_host_data(&blob_path).unwrap();

        assert_eq!(data.shape, restored.shape);
        assert_eq!(data.strides, restored.strides);
        match (&data.data, &restored.data) {
            (HostDataVec::F32(expected), HostDataVec::F32(actual)) => {
                assert_eq!(expected, actual)
            }
            (HostDataVec::I32(expected), HostDataVec::I32(actual)) => {
                assert_eq!(expected, actual)
            }
            (HostDataVec::Bool(expected), HostDataVec::Bool(actual)) => {
                assert_eq!(expected, actual)
            }
            _ => panic!("dtype mismatch on round-trip"),
        }

        // Best-effort cleanup; a leftover scratch file is harmless.
        let _ = std::fs::remove_file(&blob_path);
    }

    /// A rank-1 `f32` tensor survives a write/read cycle.
    #[test]
    fn round_trip_f32() {
        let tensor = HostData {
            data: HostDataVec::F32(vec![1.0, -2.0, std::f32::consts::PI, 0.5, 0.0]),
            shape: Shape::from(vec![5]),
            strides: Strides::new(&[1]),
        };
        round_trip("f32", tensor);
    }

    /// A rank-2 `i32` tensor survives a write/read cycle.
    #[test]
    fn round_trip_i32_2d() {
        let tensor = HostData {
            data: HostDataVec::I32(vec![1, 2, 3, 4, 5, 6]),
            shape: Shape::from(vec![2, 3]),
            strides: Strides::new(&[3, 1]),
        };
        round_trip("i32", tensor);
    }

    /// A rank-1 `bool` tensor survives a write/read cycle.
    #[test]
    fn round_trip_bool() {
        let tensor = HostData {
            data: HostDataVec::Bool(vec![true, false, true, true, false]),
            shape: Shape::from(vec![5]),
            strides: Strides::new(&[1]),
        };
        round_trip("bool", tensor);
    }
}