1use std::io::{self, Read, Write};
40
/// Element type tag for the bytes held in a tensor's data buffer.
///
/// The explicit `u32` discriminants are the values written to and read from
/// the serialized tensor header, so they must not be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum DType {
    /// Tag `0`; 4 bytes per element.
    F32 = 0,
    /// Tag `1`; 8 bytes per element.
    F64 = 1,
    /// Tag `2`; 4 bytes per element.
    I32 = 2,
    /// Tag `3`; 8 bytes per element.
    I64 = 3,
    /// Tag `4`; 1 byte per element.
    U8 = 4,
}

impl DType {
    /// Returns the width, in bytes, of a single element of this type.
    #[must_use]
    pub const fn size(&self) -> usize {
        match *self {
            Self::U8 => 1,
            Self::F32 | Self::I32 => 4,
            Self::F64 | Self::I64 => 8,
        }
    }

    /// Decodes a numeric tag into a `DType`.
    ///
    /// Returns `None` when `v` is not one of the known discriminants (`0..=4`).
    #[must_use]
    pub const fn from_u32(v: u32) -> Option<Self> {
        let decoded = match v {
            0 => Self::F32,
            1 => Self::F64,
            2 => Self::I32,
            3 => Self::I64,
            4 => Self::U8,
            _ => return None,
        };
        Some(decoded)
    }
}
83
84#[derive(Debug, Clone)]
86pub struct Tensor {
87 pub name: String,
89 pub dtype: DType,
91 pub shape: Vec<u32>,
93 pub data: Vec<u8>,
95}
96
97impl Tensor {
98 #[must_use]
100 pub fn new(name: impl Into<String>, dtype: DType, shape: Vec<u32>, data: Vec<u8>) -> Self {
101 Self {
102 name: name.into(),
103 dtype,
104 shape,
105 data,
106 }
107 }
108
109 #[must_use]
111 pub fn numel(&self) -> usize {
112 self.shape.iter().map(|&d| d as usize).product()
113 }
114
115 #[must_use]
117 pub fn expected_size(&self) -> usize {
118 self.numel() * self.dtype.size()
119 }
120
121 #[must_use]
123 pub fn is_valid(&self) -> bool {
124 self.data.len() == self.expected_size()
125 }
126
127 pub fn to_f32_vec(&self) -> Option<Vec<f32>> {
129 if self.dtype != DType::F32 {
130 return None;
131 }
132 let floats: Vec<f32> = self
133 .data
134 .chunks_exact(4)
135 .map(|chunk| {
136 let arr: [u8; 4] = chunk.try_into().expect("chunk size");
137 f32::from_le_bytes(arr)
138 })
139 .collect();
140 Some(floats)
141 }
142
143 #[must_use]
145 pub fn from_f32(name: impl Into<String>, shape: Vec<u32>, data: &[f32]) -> Self {
146 let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
147 Self::new(name, DType::F32, shape, bytes)
148 }
149}
150
151#[derive(Debug, Clone)]
153pub struct AldDataset {
154 pub version: u32,
156 pub tensors: Vec<Tensor>,
158}
159
/// Four-byte signature that every serialized ALD dataset starts with.
const ALD_MAGIC: &[u8; 4] = b"ALD\0";

/// Version stamped on newly created datasets; streams declaring a higher
/// version are rejected when reading.
const ALD_VERSION: u32 = 1;
165
166impl AldDataset {
167 #[must_use]
169 pub fn new() -> Self {
170 Self {
171 version: ALD_VERSION,
172 tensors: Vec::new(),
173 }
174 }
175
176 pub fn add_tensor(&mut self, tensor: Tensor) {
178 self.tensors.push(tensor);
179 }
180
181 #[must_use]
183 pub fn get(&self, name: &str) -> Option<&Tensor> {
184 self.tensors.iter().find(|t| t.name == name)
185 }
186
187 pub fn load(data: &[u8]) -> Result<Self, FormatError> {
193 let mut cursor = io::Cursor::new(data);
194 Self::read_from(&mut cursor)
195 }
196
197 pub fn read_from<R: Read>(reader: &mut R) -> Result<Self, FormatError> {
203 let mut magic = [0u8; 4];
205 reader.read_exact(&mut magic)?;
206 if &magic != ALD_MAGIC {
207 return Err(FormatError::InvalidMagic);
208 }
209
210 let version = read_u32(reader)?;
212 if version > ALD_VERSION {
213 return Err(FormatError::UnsupportedVersion(version));
214 }
215
216 let num_tensors = read_u32(reader)?;
218 let mut tensors = Vec::with_capacity(num_tensors as usize);
219
220 for _ in 0..num_tensors {
221 let tensor = read_tensor(reader)?;
222 tensors.push(tensor);
223 }
224
225 Ok(Self { version, tensors })
226 }
227
228 #[must_use]
230 pub fn save(&self) -> Vec<u8> {
231 let mut data = Vec::new();
232 self.write_to(&mut data).expect("write to vec");
233 data
234 }
235
236 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
242 writer.write_all(ALD_MAGIC)?;
244
245 write_u32(writer, self.version)?;
247
248 write_u32(writer, self.tensors.len() as u32)?;
250
251 for tensor in &self.tensors {
253 write_tensor(writer, tensor)?;
254 }
255
256 Ok(())
257 }
258}
259
260impl Default for AldDataset {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
266#[derive(Debug, Clone)]
268pub struct AprModel {
269 pub version: u32,
271 pub model_type: String,
273 pub layers: Vec<ModelLayer>,
275 pub metadata: std::collections::HashMap<String, String>,
277}
278
279#[derive(Debug, Clone)]
281pub struct ModelLayer {
282 pub layer_type: String,
284 pub parameters: Vec<Tensor>,
286}
287
/// Four-byte signature that every serialized APR model starts with.
const APR_MAGIC: &[u8; 4] = b"APR\0";

/// Version stamped on newly created models; streams declaring a higher
/// version are rejected when reading.
const APR_VERSION: u32 = 1;
293
294impl AprModel {
295 #[must_use]
297 pub fn new(model_type: impl Into<String>) -> Self {
298 Self {
299 version: APR_VERSION,
300 model_type: model_type.into(),
301 layers: Vec::new(),
302 metadata: std::collections::HashMap::new(),
303 }
304 }
305
306 pub fn add_layer(&mut self, layer: ModelLayer) {
308 self.layers.push(layer);
309 }
310
311 #[must_use]
313 pub fn param_count(&self) -> usize {
314 self.layers
315 .iter()
316 .flat_map(|l| &l.parameters)
317 .map(Tensor::numel)
318 .sum()
319 }
320
321 pub fn load(data: &[u8]) -> Result<Self, FormatError> {
327 let mut cursor = io::Cursor::new(data);
328 Self::read_from(&mut cursor)
329 }
330
331 pub fn read_from<R: Read>(reader: &mut R) -> Result<Self, FormatError> {
337 let mut magic = [0u8; 4];
339 reader.read_exact(&mut magic)?;
340 if &magic != APR_MAGIC {
341 return Err(FormatError::InvalidMagic);
342 }
343
344 let version = read_u32(reader)?;
346 if version > APR_VERSION {
347 return Err(FormatError::UnsupportedVersion(version));
348 }
349
350 let model_type = read_string(reader)?;
352
353 let num_layers = read_u32(reader)?;
355 let mut layers = Vec::with_capacity(num_layers as usize);
356
357 for _ in 0..num_layers {
358 let layer_type = read_string(reader)?;
359 let num_params = read_u32(reader)?;
360 let mut parameters = Vec::with_capacity(num_params as usize);
361
362 for _ in 0..num_params {
363 let tensor = read_tensor(reader)?;
364 parameters.push(tensor);
365 }
366
367 layers.push(ModelLayer {
368 layer_type,
369 parameters,
370 });
371 }
372
373 let mut metadata = std::collections::HashMap::new();
375 while let Ok(key) = read_string(reader) {
376 if let Ok(value) = read_string(reader) {
377 metadata.insert(key, value);
378 } else {
379 break;
380 }
381 }
382
383 Ok(Self {
384 version,
385 model_type,
386 layers,
387 metadata,
388 })
389 }
390
391 #[must_use]
393 pub fn save(&self) -> Vec<u8> {
394 let mut data = Vec::new();
395 self.write_to(&mut data).expect("write to vec");
396 data
397 }
398
399 pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
405 writer.write_all(APR_MAGIC)?;
407
408 write_u32(writer, self.version)?;
410
411 write_string(writer, &self.model_type)?;
413
414 write_u32(writer, self.layers.len() as u32)?;
416 for layer in &self.layers {
417 write_string(writer, &layer.layer_type)?;
418 write_u32(writer, layer.parameters.len() as u32)?;
419 for param in &layer.parameters {
420 write_tensor(writer, param)?;
421 }
422 }
423
424 for (key, value) in &self.metadata {
426 write_string(writer, key)?;
427 write_string(writer, value)?;
428 }
429
430 Ok(())
431 }
432}
433
/// Errors reported while decoding ALD / APR streams.
#[derive(Clone, Debug, PartialEq)]
pub enum FormatError {
    /// The stream does not start with the expected signature bytes.
    InvalidMagic,
    /// The stream declares a format version newer than is supported; carries
    /// the declared version.
    UnsupportedVersion(u32),
    /// A tensor header carries an unknown dtype tag; carries the tag.
    InvalidDType(u32),
    /// The input ended before a complete value could be read.
    TruncatedData,
    /// Any other failure, carried as its rendered message.
    IoError(String),
}

impl std::fmt::Display for FormatError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidMagic => f.write_str("Invalid file magic bytes"),
            Self::UnsupportedVersion(v) => write!(f, "Unsupported format version: {v}"),
            Self::InvalidDType(d) => write!(f, "Invalid dtype: {d}"),
            Self::TruncatedData => f.write_str("Truncated data"),
            Self::IoError(e) => write!(f, "IO error: {e}"),
        }
    }
}

impl std::error::Error for FormatError {}

impl From<io::Error> for FormatError {
    /// Maps an unexpected end of input to [`FormatError::TruncatedData`] and
    /// every other I/O error to [`FormatError::IoError`] with its message.
    fn from(e: io::Error) -> Self {
        match e.kind() {
            io::ErrorKind::UnexpectedEof => Self::TruncatedData,
            _ => Self::IoError(e.to_string()),
        }
    }
}
472
473fn read_u32<R: Read>(reader: &mut R) -> Result<u32, FormatError> {
478 let mut buf = [0u8; 4];
479 reader.read_exact(&mut buf)?;
480 Ok(u32::from_le_bytes(buf))
481}
482
/// Writes `v` to `writer` as four little-endian bytes.
fn write_u32<W: Write>(writer: &mut W, v: u32) -> io::Result<()> {
    let encoded: [u8; 4] = v.to_le_bytes();
    writer.write_all(&encoded)
}
486
487fn read_string<R: Read>(reader: &mut R) -> Result<String, FormatError> {
488 let len = read_u32(reader)? as usize;
489 let mut buf = vec![0u8; len];
490 reader.read_exact(&mut buf)?;
491 String::from_utf8(buf).map_err(|e| FormatError::IoError(e.to_string()))
492}
493
494fn write_string<W: Write>(writer: &mut W, s: &str) -> io::Result<()> {
495 write_u32(writer, s.len() as u32)?;
496 writer.write_all(s.as_bytes())
497}
498
499fn read_tensor<R: Read>(reader: &mut R) -> Result<Tensor, FormatError> {
500 let name = read_string(reader)?;
501 let dtype_u32 = read_u32(reader)?;
502 let dtype = DType::from_u32(dtype_u32).ok_or(FormatError::InvalidDType(dtype_u32))?;
503
504 let num_dims = read_u32(reader)? as usize;
505 let mut shape = Vec::with_capacity(num_dims);
506 for _ in 0..num_dims {
507 shape.push(read_u32(reader)?);
508 }
509
510 let numel: usize = shape.iter().map(|&d| d as usize).product();
511 let data_size = numel * dtype.size();
512 let mut data = vec![0u8; data_size];
513 reader.read_exact(&mut data)?;
514
515 Ok(Tensor {
516 name,
517 dtype,
518 shape,
519 data,
520 })
521}
522
523fn write_tensor<W: Write>(writer: &mut W, tensor: &Tensor) -> io::Result<()> {
524 write_string(writer, &tensor.name)?;
525 write_u32(writer, tensor.dtype as u32)?;
526 write_u32(writer, tensor.shape.len() as u32)?;
527 for &dim in &tensor.shape {
528 write_u32(writer, dim)?;
529 }
530 writer.write_all(&tensor.data)
531}
532
/// Unit tests for dtype helpers, tensor helpers, both container formats and
/// the error type.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- DType -----------------------------------------------------------

    #[test]
    fn test_dtype_size() {
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::U8.size(), 1);
    }

    #[test]
    fn test_dtype_from_u32() {
        assert_eq!(DType::from_u32(0), Some(DType::F32));
        assert_eq!(DType::from_u32(1), Some(DType::F64));
        assert_eq!(DType::from_u32(2), Some(DType::I32));
        assert_eq!(DType::from_u32(3), Some(DType::I64));
        assert_eq!(DType::from_u32(4), Some(DType::U8));
        assert_eq!(DType::from_u32(5), None);
    }

    // ---- Tensor ----------------------------------------------------------

    #[test]
    fn test_tensor_numel() {
        let t = Tensor::new("test", DType::F32, vec![2, 3, 4], vec![0; 96]);
        assert_eq!(t.numel(), 24);
    }

    #[test]
    fn test_tensor_expected_size() {
        let t = Tensor::new("test", DType::F32, vec![2, 3], vec![]);
        // 2 * 3 elements, 4 bytes each.
        assert_eq!(t.expected_size(), 24);
    }

    #[test]
    fn test_tensor_is_valid() {
        let valid = Tensor::new("test", DType::F32, vec![2, 3], vec![0; 24]);
        assert!(valid.is_valid());

        let invalid = Tensor::new("test", DType::F32, vec![2, 3], vec![0; 10]);
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_tensor_from_f32() {
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let t = Tensor::from_f32("weights", vec![2, 2], &data);

        assert_eq!(t.name, "weights");
        assert_eq!(t.dtype, DType::F32);
        assert_eq!(t.shape, vec![2, 2]);
        assert_eq!(t.data.len(), 16);

        let vec = t.to_f32_vec().unwrap();
        assert_eq!(vec, data.to_vec());
    }

    #[test]
    fn test_tensor_to_f32_vec_wrong_dtype() {
        let t = Tensor::new("bytes", DType::U8, vec![4], vec![0; 4]);
        assert!(t.to_f32_vec().is_none());
    }

    // ---- AldDataset ------------------------------------------------------

    #[test]
    fn test_ald_new() {
        let ds = AldDataset::new();
        assert_eq!(ds.version, ALD_VERSION);
        assert!(ds.tensors.is_empty());
    }

    #[test]
    fn test_ald_add_get() {
        let mut ds = AldDataset::new();
        ds.add_tensor(Tensor::from_f32("x", vec![10], &[0.0; 10]));
        ds.add_tensor(Tensor::from_f32("y", vec![5], &[0.0; 5]));

        assert!(ds.get("x").is_some());
        assert!(ds.get("y").is_some());
        assert!(ds.get("z").is_none());
    }

    #[test]
    fn test_ald_roundtrip() {
        let mut ds = AldDataset::new();
        ds.add_tensor(Tensor::from_f32("weights", vec![3, 3], &[1.0; 9]));
        ds.add_tensor(Tensor::from_f32("bias", vec![3], &[0.5; 3]));

        let bytes = ds.save();
        let loaded = AldDataset::load(&bytes).unwrap();

        assert_eq!(loaded.version, ds.version);
        assert_eq!(loaded.tensors.len(), 2);

        // Shapes AND payloads must survive the round trip.
        let weights = loaded.get("weights").unwrap();
        assert_eq!(weights.shape, vec![3, 3]);
        assert_eq!(weights.to_f32_vec(), Some(vec![1.0f32; 9]));

        let bias = loaded.get("bias").unwrap();
        assert_eq!(bias.shape, vec![3]);
        assert_eq!(bias.to_f32_vec(), Some(vec![0.5f32; 3]));
    }

    #[test]
    fn test_ald_invalid_magic() {
        let result = AldDataset::load(b"BAAD");
        assert!(matches!(result, Err(FormatError::InvalidMagic)));
    }

    #[test]
    fn test_ald_truncated() {
        let result = AldDataset::load(b"ALD\0");
        assert!(matches!(result, Err(FormatError::TruncatedData)));
    }

    #[test]
    fn test_ald_unsupported_version() {
        let too_new = ALD_VERSION + 1;
        let mut bytes = b"ALD\0".to_vec();
        bytes.extend_from_slice(&too_new.to_le_bytes());
        bytes.extend_from_slice(&0u32.to_le_bytes());

        let result = AldDataset::load(&bytes);
        assert!(matches!(result, Err(FormatError::UnsupportedVersion(v)) if v == too_new));
    }

    #[test]
    fn test_ald_invalid_dtype() {
        let mut bytes = b"ALD\0".to_vec();
        bytes.extend_from_slice(&ALD_VERSION.to_le_bytes());
        bytes.extend_from_slice(&1u32.to_le_bytes()); // one tensor
        bytes.extend_from_slice(&1u32.to_le_bytes()); // name length
        bytes.push(b'x'); // name
        bytes.extend_from_slice(&99u32.to_le_bytes()); // unknown dtype tag

        let result = AldDataset::load(&bytes);
        assert!(matches!(result, Err(FormatError::InvalidDType(99))));
    }

    // ---- AprModel --------------------------------------------------------

    #[test]
    fn test_apr_new() {
        let model = AprModel::new("mlp");
        assert_eq!(model.version, APR_VERSION);
        assert_eq!(model.model_type, "mlp");
        assert!(model.layers.is_empty());
    }

    #[test]
    fn test_apr_param_count() {
        let mut model = AprModel::new("test");
        model.add_layer(ModelLayer {
            layer_type: "dense".to_string(),
            parameters: vec![
                Tensor::from_f32("w", vec![10, 5], &[0.0; 50]),
                Tensor::from_f32("b", vec![5], &[0.0; 5]),
            ],
        });

        assert_eq!(model.param_count(), 55);
    }

    #[test]
    fn test_apr_roundtrip() {
        let mut model = AprModel::new("classifier");
        model.add_layer(ModelLayer {
            layer_type: "dense".to_string(),
            parameters: vec![
                Tensor::from_f32("weight", vec![4, 2], &[1.0; 8]),
                Tensor::from_f32("bias", vec![2], &[0.1, 0.2]),
            ],
        });
        model
            .metadata
            .insert("trained_epochs".to_string(), "100".to_string());

        let bytes = model.save();
        let loaded = AprModel::load(&bytes).unwrap();

        assert_eq!(loaded.version, model.version);
        assert_eq!(loaded.model_type, "classifier");
        assert_eq!(loaded.layers.len(), 1);
        assert_eq!(loaded.layers[0].layer_type, "dense");
        assert_eq!(loaded.layers[0].parameters.len(), 2);

        // Parameter names, shapes and payloads must survive the round trip.
        let weight = &loaded.layers[0].parameters[0];
        assert_eq!(weight.name, "weight");
        assert_eq!(weight.shape, vec![4, 2]);
        assert_eq!(weight.to_f32_vec(), Some(vec![1.0f32; 8]));

        let bias = &loaded.layers[0].parameters[1];
        assert_eq!(bias.name, "bias");
        assert_eq!(bias.to_f32_vec(), Some(vec![0.1f32, 0.2]));

        // Metadata must survive too.
        assert_eq!(loaded.metadata.len(), 1);
        assert_eq!(
            loaded.metadata.get("trained_epochs").map(String::as_str),
            Some("100")
        );
    }

    #[test]
    fn test_apr_roundtrip_without_metadata() {
        let model = AprModel::new("empty");
        let loaded = AprModel::load(&model.save()).unwrap();

        assert_eq!(loaded.model_type, "empty");
        assert!(loaded.layers.is_empty());
        assert!(loaded.metadata.is_empty());
    }

    #[test]
    fn test_apr_invalid_magic() {
        let result = AprModel::load(b"NOPE");
        assert!(matches!(result, Err(FormatError::InvalidMagic)));
    }

    #[test]
    fn test_apr_truncated() {
        let result = AprModel::load(b"APR\0");
        assert!(matches!(result, Err(FormatError::TruncatedData)));
    }

    // ---- FormatError -----------------------------------------------------

    #[test]
    fn test_format_error_display() {
        assert!(FormatError::InvalidMagic.to_string().contains("magic"));
        assert!(FormatError::UnsupportedVersion(99)
            .to_string()
            .contains("99"));
        assert!(FormatError::InvalidDType(255).to_string().contains("255"));
        assert!(FormatError::TruncatedData.to_string().contains("Truncated"));
        assert!(FormatError::IoError("boom".to_string())
            .to_string()
            .contains("boom"));
    }

    #[test]
    fn test_format_error_from_io() {
        let eof = io::Error::new(io::ErrorKind::UnexpectedEof, "eof");
        assert_eq!(FormatError::from(eof), FormatError::TruncatedData);

        let other = io::Error::new(io::ErrorKind::Other, "boom");
        assert_eq!(
            FormatError::from(other),
            FormatError::IoError("boom".to_string())
        );
    }
}