1use std::io::{Read, Seek, Write};
15
16use diskann_wide::{LoHi, SplitJoin};
17use thiserror::Error;
18
19use crate::views::{Matrix, MatrixView};
20
21pub fn read_bin<T>(reader: &mut (impl Read + Seek)) -> Result<Matrix<T>, ReadBinError>
25where
26 T: bytemuck::Pod,
27{
28 let metadata = Metadata::read(reader)?;
29 let (npoints, ndims) = (metadata.npoints(), metadata.ndims());
30 let type_size = std::mem::size_of::<T>();
31
32 let expected_bytes = npoints
33 .checked_mul(ndims)
34 .and_then(|n| n.checked_mul(type_size))
35 .ok_or(ReadBinError::Overflow {
36 npoints: metadata.npoints_u32(),
37 ndims: metadata.ndims_u32(),
38 type_size,
39 })?;
40
41 let data_start = reader.stream_position()?;
42 let end = reader.seek(std::io::SeekFrom::End(0))?;
43 let available = end - data_start;
44 reader.seek(std::io::SeekFrom::Start(data_start))?;
45
46 if available < expected_bytes as u64 {
47 return Err(ReadBinError::SizeMismatch {
48 expected: expected_bytes as u64,
49 available,
50 npoints: metadata.npoints_u32(),
51 ndims: metadata.ndims_u32(),
52 type_size,
53 });
54 }
55
56 let mut data = Matrix::new(<T as bytemuck::Zeroable>::zeroed(), npoints, ndims);
57
58 reader.read_exact(bytemuck::must_cast_slice_mut::<T, u8>(data.as_mut_slice()))?;
59 Ok(data)
60}
61
62pub fn write_bin<T>(data: MatrixView<'_, T>, writer: &mut impl Write) -> Result<usize, SaveBinError>
66where
67 T: bytemuck::Pod,
68{
69 let metadata =
70 Metadata::new(data.nrows(), data.ncols()).map_err(|_| SaveBinError::DimensionOverflow {
71 nrows: data.nrows(),
72 ncols: data.ncols(),
73 })?;
74 let bytes = metadata.write(writer)?;
75 writer.write_all(bytemuck::must_cast_slice::<T, u8>(data.as_slice()))?;
76 Ok(bytes + std::mem::size_of_val(data.as_slice()))
77}
78
/// Header that precedes the matrix payload handled by `read_bin` and `write_bin`.
///
/// It holds two `u32` counts: the number of points and the number of dimensions per point.
/// `write_bin` stores the matrix row count as `npoints` and the column count as `ndims`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Metadata {
    /// Number of points (matrix rows).
    npoints: u32,
    /// Number of dimensions per point (matrix columns).
    ndims: u32,
}
85
86impl Metadata {
87 pub fn new<T, U>(npoints: T, ndims: U) -> Result<Self, MetadataError<T::Error, U::Error>>
89 where
90 T: TryInto<u32>,
91 U: TryInto<u32>,
92 {
93 Ok(Self {
94 npoints: npoints.try_into().map_err(MetadataError::NumPoints)?,
95 ndims: ndims.try_into().map_err(MetadataError::Dim)?,
96 })
97 }
98
99 pub fn npoints(&self) -> usize {
101 self.npoints as usize
102 }
103
104 pub fn npoints_u32(&self) -> u32 {
106 self.npoints
107 }
108
109 pub fn ndims(&self) -> usize {
111 self.ndims as usize
112 }
113
114 pub fn ndims_u32(&self) -> u32 {
116 self.ndims
117 }
118
119 pub fn into_dims(&self) -> (usize, usize) {
121 (self.npoints(), self.ndims())
122 }
123
124 pub fn read<R>(reader: &mut R) -> std::io::Result<Self>
126 where
127 R: Read,
128 {
129 let mut bytes = [0u8; 8];
130 reader.read_exact(&mut bytes)?;
131
132 let LoHi {
133 lo: npts_bytes,
134 hi: ndims_bytes,
135 } = bytes.split();
136
137 let npoints = u32::from_le_bytes(npts_bytes);
138 let ndims = u32::from_le_bytes(ndims_bytes);
139 Ok(Metadata { npoints, ndims })
140 }
141
142 pub fn write<W>(&self, writer: &mut W) -> std::io::Result<usize>
144 where
145 W: Write,
146 {
147 let bytes: [u8; 8] = LoHi::new(self.npoints.to_le_bytes(), self.ndims.to_le_bytes()).join();
148 writer.write_all(&bytes)?;
149 Ok(2 * std::mem::size_of::<u32>())
150 }
151}
152
/// Error returned by [`Metadata::new`] when an argument cannot be converted to `u32`.
///
/// `T` is the conversion error type of the point-count argument and `U` that of the
/// dimension argument.
#[derive(Debug, Error)]
pub enum MetadataError<T, U> {
    /// The number of points could not be converted to `u32`.
    #[error("num points conversion")]
    NumPoints(#[source] T),
    /// The number of dimensions could not be converted to `u32`.
    #[error("dim conversion")]
    Dim(#[source] U),
}
160
/// Errors produced by [`read_bin`].
#[derive(Debug, Error)]
pub enum ReadBinError {
    /// The stream holds fewer bytes after the header than the header's dimensions require.
    #[error(
        "binary data too short: header declares {npoints} points × {ndims} dims × {type_size} bytes = \
         {expected} bytes, but only {available} bytes available"
    )]
    SizeMismatch {
        /// Payload size in bytes implied by the header and the element size.
        expected: u64,
        /// Bytes remaining in the stream after the header.
        available: u64,
        /// Point count as stored in the header.
        npoints: u32,
        /// Dimension count as stored in the header.
        ndims: u32,
        /// Size in bytes of one element of the requested type.
        type_size: usize,
    },

    /// The payload size `npoints * ndims * type_size` does not fit in a `usize`.
    #[error(
        "header dimensions overflow: {npoints} points × {ndims} dims × {type_size} bytes overflows"
    )]
    Overflow {
        /// Point count as stored in the header.
        npoints: u32,
        /// Dimension count as stored in the header.
        ndims: u32,
        /// Size in bytes of one element of the requested type.
        type_size: usize,
    },

    /// An underlying read or seek failed; displayed and sourced as the wrapped I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
191
/// Errors produced by [`write_bin`].
#[derive(Debug, Error)]
pub enum SaveBinError {
    /// At least one matrix dimension cannot be stored in the header's `u32` fields.
    #[error("dimensions overflow u32: {nrows} rows × {ncols} cols")]
    DimensionOverflow { nrows: usize, ncols: usize },

    /// An underlying write failed; displayed and sourced as the wrapped I/O error.
    #[error(transparent)]
    Io(#[from] std::io::Error),
}
203
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use crate::views::Init;

    use super::*;

    /// Serializes a header declaring `npoints` × `ndims` into a fresh buffer.
    fn header_bytes(npoints: u32, ndims: u32) -> Vec<u8> {
        let mut buf = Vec::new();
        let metadata = Metadata::new(npoints, ndims).unwrap();
        metadata.write(&mut buf).unwrap();
        buf
    }

    /// Writing a matrix and reading it back yields the same shape and contents.
    #[test]
    fn round_trip_f32() {
        // Fill a 3 × 4 matrix with 1.0, 2.0, ..., 12.0.
        let mut last = 0.0f32;
        let matrix = Matrix::<f32>::new(
            Init(|| {
                last += 1.0;
                last
            }),
            3,
            4,
        );

        let expected: Vec<f32> = (1..=12).map(|i| i as f32).collect();
        assert_eq!(matrix.as_slice(), expected.as_slice());

        let mut buf = Vec::new();
        let written = write_bin(matrix.as_view(), &mut buf).unwrap();
        assert_eq!(written, 8 + 3 * 4 * std::mem::size_of::<f32>());

        let mut cursor = Cursor::new(&buf);
        let loaded = read_bin::<f32>(&mut cursor).unwrap();
        assert_eq!((loaded.nrows(), loaded.ncols()), (3, 4));
        assert_eq!(loaded.as_slice(), matrix.as_slice());
    }

    /// A payload shorter than the header promises is reported as `SizeMismatch`.
    #[test]
    fn read_bin_size_mismatch() {
        // Header declares 10 × 4 f32 values (160 bytes) but only 8 payload bytes follow.
        let mut buf = header_bytes(10, 4);
        buf.extend_from_slice(&[0u8; 8]);

        let mut cursor = Cursor::new(&buf);
        match read_bin::<f32>(&mut cursor).unwrap_err() {
            ReadBinError::SizeMismatch {
                expected,
                available,
                npoints,
                ndims,
                type_size,
            } => {
                assert_eq!(expected, 10 * 4 * 4);
                assert_eq!(available, 8);
                assert_eq!(npoints, 10);
                assert_eq!(ndims, 4);
                assert_eq!(type_size, 4);
            }
            other => panic!("expected SizeMismatch, got: {other}"),
        }
    }

    /// Header dimensions whose byte count overflows are reported as `Overflow`.
    #[test]
    fn read_bin_overflow() {
        // Raw header bytes: both counts set to `u32::MAX`.
        let buf: Vec<u8> = [u32::MAX.to_le_bytes(), u32::MAX.to_le_bytes()].concat();

        let mut cursor = Cursor::new(&buf);
        match read_bin::<f32>(&mut cursor).unwrap_err() {
            ReadBinError::Overflow {
                npoints,
                ndims,
                type_size,
            } => {
                assert_eq!(npoints, u32::MAX);
                assert_eq!(ndims, u32::MAX);
                assert_eq!(type_size, 4);
            }
            other => panic!("expected Overflow, got: {other}"),
        }
    }

    /// The size-mismatch message names the dimensions and both byte counts.
    #[test]
    fn read_bin_error_message_is_informative() {
        // Header only, no payload at all.
        let buf = header_bytes(100, 32);
        let mut cursor = Cursor::new(&buf);
        let msg = read_bin::<f32>(&mut cursor).unwrap_err().to_string();

        let required = [
            ("100 points", "npoints"),
            ("32 dims", "ndims"),
            ("12800 bytes", "expected"),
            ("0 bytes available", "available"),
        ];
        for (needle, what) in required {
            assert!(msg.contains(needle), "missing {what}: {msg}");
        }
    }

    /// A header survives a write followed by a read.
    #[test]
    fn metadata_read_write_round_trip() {
        let original = Metadata::new(200u32, 128u32).unwrap();
        let mut buf = Vec::new();
        original.write(&mut buf).unwrap();

        let decoded = Metadata::read(&mut Cursor::new(&buf)).unwrap();
        assert_eq!(decoded, original);
    }
}