Skip to main content

presentar_yaml/
formats.rs

//! File format loaders for Aprender (.apr) and Alimentar (.ald) files.
//!
//! # File Formats
//!
//! Every integer header field below is a little-endian `u32`, and every
//! string is a `u32` byte length followed by that many bytes of UTF-8.
//!
//! ## Alimentar Dataset (.ald)
//!
//! Binary format for tensor datasets:
//! ```text
//! [4 bytes] Magic: "ALD\0"
//! [4 bytes] Version (u32 LE)
//! [4 bytes] Num tensors (u32 LE)
//! For each tensor:
//!   [4 bytes] Name length L (u32 LE)
//!   [L bytes] Name (UTF-8)
//!   [4 bytes] Dtype (u32 LE): 0=f32, 1=f64, 2=i32, 3=i64, 4=u8
//!   [4 bytes] Num dimensions D (u32 LE)
//!   [4*D bytes] Shape (D x u32 LE)
//!   [S bytes] Data (raw element bytes), where S = product(shape) * dtype size
//!             (no length is stored; an empty shape is a one-element scalar)
//! ```
//!
//! ## Aprender Model (.apr)
//!
//! Binary format for trained models:
//! ```text
//! [4 bytes] Magic: "APR\0"
//! [4 bytes] Version (u32 LE)
//! [4 bytes] Model type length M (u32 LE)
//! [M bytes] Model type (UTF-8)
//! [4 bytes] Num layers (u32 LE)
//! For each layer:
//!   [4 bytes] Layer type length T (u32 LE)
//!   [T bytes] Layer type (UTF-8)
//!   [4 bytes] Num parameters (u32 LE)
//!   For each parameter:
//!     [Tensor record, same layout as an ALD tensor above]
//! Metadata, running to the end of input (no entry count is stored):
//!   Zero or more key/value pairs, each written as two length-prefixed
//!   UTF-8 strings (key first, then value)
//! ```
38
39use std::io::{self, Read, Write};
40
/// Element type of a tensor, stored on disk as a `u32` code.
///
/// The explicit discriminants double as the on-disk codes: a tensor's dtype is
/// written as `dtype as u32`, and [`DType::from_u32`] maps a code back.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    /// IEEE-754 single-precision float (`f32`)
    F32 = 0,
    /// IEEE-754 double-precision float (`f64`)
    F64 = 1,
    /// Signed 32-bit integer (`i32`)
    I32 = 2,
    /// Signed 64-bit integer (`i64`)
    I64 = 3,
    /// Unsigned byte (`u8`)
    U8 = 4,
}

impl DType {
    /// Width in bytes of a single element of this type.
    #[must_use]
    pub const fn size(&self) -> usize {
        match self {
            Self::U8 => 1,
            Self::F32 | Self::I32 => 4,
            Self::F64 | Self::I64 => 8,
        }
    }

    /// Decodes an on-disk dtype code, yielding `None` for any code outside `0..=4`.
    #[must_use]
    pub const fn from_u32(v: u32) -> Option<Self> {
        let dtype = match v {
            0 => Self::F32,
            1 => Self::F64,
            2 => Self::I32,
            3 => Self::I64,
            4 => Self::U8,
            _ => return None,
        };
        Some(dtype)
    }
}
83
84/// A tensor with shape and data.
85#[derive(Debug, Clone)]
86pub struct Tensor {
87    /// Tensor name
88    pub name: String,
89    /// Data type
90    pub dtype: DType,
91    /// Shape dimensions
92    pub shape: Vec<u32>,
93    /// Raw data bytes
94    pub data: Vec<u8>,
95}
96
97impl Tensor {
98    /// Create a new tensor.
99    #[must_use]
100    pub fn new(name: impl Into<String>, dtype: DType, shape: Vec<u32>, data: Vec<u8>) -> Self {
101        Self {
102            name: name.into(),
103            dtype,
104            shape,
105            data,
106        }
107    }
108
109    /// Get number of elements.
110    #[must_use]
111    pub fn numel(&self) -> usize {
112        self.shape.iter().map(|&d| d as usize).product()
113    }
114
115    /// Get expected data size in bytes.
116    #[must_use]
117    pub fn expected_size(&self) -> usize {
118        self.numel() * self.dtype.size()
119    }
120
121    /// Validate tensor data size.
122    #[must_use]
123    pub fn is_valid(&self) -> bool {
124        self.data.len() == self.expected_size()
125    }
126
127    /// Get data as f32 vector (if dtype is F32).
128    pub fn to_f32_vec(&self) -> Option<Vec<f32>> {
129        if self.dtype != DType::F32 {
130            return None;
131        }
132        let floats: Vec<f32> = self
133            .data
134            .chunks_exact(4)
135            .map(|chunk| {
136                let arr: [u8; 4] = chunk.try_into().expect("chunk size");
137                f32::from_le_bytes(arr)
138            })
139            .collect();
140        Some(floats)
141    }
142
143    /// Create f32 tensor from slice.
144    #[must_use]
145    pub fn from_f32(name: impl Into<String>, shape: Vec<u32>, data: &[f32]) -> Self {
146        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
147        Self::new(name, DType::F32, shape, bytes)
148    }
149}
150
151/// Alimentar dataset (.ald file).
152#[derive(Debug, Clone)]
153pub struct AldDataset {
154    /// Format version
155    pub version: u32,
156    /// Tensors in the dataset
157    pub tensors: Vec<Tensor>,
158}
159
160/// Magic bytes for .ald files.
161const ALD_MAGIC: &[u8; 4] = b"ALD\0";
162
163/// Current .ald format version.
164const ALD_VERSION: u32 = 1;
165
166impl AldDataset {
167    /// Create a new empty dataset.
168    #[must_use]
169    pub fn new() -> Self {
170        Self {
171            version: ALD_VERSION,
172            tensors: Vec::new(),
173        }
174    }
175
176    /// Add a tensor to the dataset.
177    pub fn add_tensor(&mut self, tensor: Tensor) {
178        self.tensors.push(tensor);
179    }
180
181    /// Get tensor by name.
182    #[must_use]
183    pub fn get(&self, name: &str) -> Option<&Tensor> {
184        self.tensors.iter().find(|t| t.name == name)
185    }
186
187    /// Load from bytes.
188    ///
189    /// # Errors
190    ///
191    /// Returns error if the format is invalid.
192    pub fn load(data: &[u8]) -> Result<Self, FormatError> {
193        let mut cursor = io::Cursor::new(data);
194        Self::read_from(&mut cursor)
195    }
196
197    /// Read from a reader.
198    ///
199    /// # Errors
200    ///
201    /// Returns error if the format is invalid.
202    pub fn read_from<R: Read>(reader: &mut R) -> Result<Self, FormatError> {
203        // Read magic
204        let mut magic = [0u8; 4];
205        reader.read_exact(&mut magic)?;
206        if &magic != ALD_MAGIC {
207            return Err(FormatError::InvalidMagic);
208        }
209
210        // Read version
211        let version = read_u32(reader)?;
212        if version > ALD_VERSION {
213            return Err(FormatError::UnsupportedVersion(version));
214        }
215
216        // Read tensor count
217        let num_tensors = read_u32(reader)?;
218        let mut tensors = Vec::with_capacity(num_tensors as usize);
219
220        for _ in 0..num_tensors {
221            let tensor = read_tensor(reader)?;
222            tensors.push(tensor);
223        }
224
225        Ok(Self { version, tensors })
226    }
227
228    /// Write to bytes.
229    #[must_use]
230    pub fn save(&self) -> Vec<u8> {
231        let mut data = Vec::new();
232        self.write_to(&mut data).expect("write to vec");
233        data
234    }
235
236    /// Write to a writer.
237    ///
238    /// # Errors
239    ///
240    /// Returns error if writing fails.
241    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
242        // Write magic
243        writer.write_all(ALD_MAGIC)?;
244
245        // Write version
246        write_u32(writer, self.version)?;
247
248        // Write tensor count
249        write_u32(writer, self.tensors.len() as u32)?;
250
251        // Write each tensor
252        for tensor in &self.tensors {
253            write_tensor(writer, tensor)?;
254        }
255
256        Ok(())
257    }
258}
259
260impl Default for AldDataset {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
266/// Aprender model (.apr file).
267#[derive(Debug, Clone)]
268pub struct AprModel {
269    /// Format version
270    pub version: u32,
271    /// Model type (e.g., "linear", "mlp", "transformer")
272    pub model_type: String,
273    /// Model layers
274    pub layers: Vec<ModelLayer>,
275    /// Model metadata
276    pub metadata: std::collections::HashMap<String, String>,
277}
278
279/// A model layer with parameters.
280#[derive(Debug, Clone)]
281pub struct ModelLayer {
282    /// Layer type (e.g., "dense", "conv2d", "attention")
283    pub layer_type: String,
284    /// Layer parameters (weights, biases)
285    pub parameters: Vec<Tensor>,
286}
287
288/// Magic bytes for .apr files.
289const APR_MAGIC: &[u8; 4] = b"APR\0";
290
291/// Current .apr format version.
292const APR_VERSION: u32 = 1;
293
294impl AprModel {
295    /// Create a new model.
296    #[must_use]
297    pub fn new(model_type: impl Into<String>) -> Self {
298        Self {
299            version: APR_VERSION,
300            model_type: model_type.into(),
301            layers: Vec::new(),
302            metadata: std::collections::HashMap::new(),
303        }
304    }
305
306    /// Add a layer to the model.
307    pub fn add_layer(&mut self, layer: ModelLayer) {
308        self.layers.push(layer);
309    }
310
311    /// Get total parameter count.
312    #[must_use]
313    pub fn param_count(&self) -> usize {
314        self.layers
315            .iter()
316            .flat_map(|l| &l.parameters)
317            .map(Tensor::numel)
318            .sum()
319    }
320
321    /// Load from bytes.
322    ///
323    /// # Errors
324    ///
325    /// Returns error if the format is invalid.
326    pub fn load(data: &[u8]) -> Result<Self, FormatError> {
327        let mut cursor = io::Cursor::new(data);
328        Self::read_from(&mut cursor)
329    }
330
331    /// Read from a reader.
332    ///
333    /// # Errors
334    ///
335    /// Returns error if the format is invalid.
336    pub fn read_from<R: Read>(reader: &mut R) -> Result<Self, FormatError> {
337        // Read magic
338        let mut magic = [0u8; 4];
339        reader.read_exact(&mut magic)?;
340        if &magic != APR_MAGIC {
341            return Err(FormatError::InvalidMagic);
342        }
343
344        // Read version
345        let version = read_u32(reader)?;
346        if version > APR_VERSION {
347            return Err(FormatError::UnsupportedVersion(version));
348        }
349
350        // Read model type
351        let model_type = read_string(reader)?;
352
353        // Read layers
354        let num_layers = read_u32(reader)?;
355        let mut layers = Vec::with_capacity(num_layers as usize);
356
357        for _ in 0..num_layers {
358            let layer_type = read_string(reader)?;
359            let num_params = read_u32(reader)?;
360            let mut parameters = Vec::with_capacity(num_params as usize);
361
362            for _ in 0..num_params {
363                let tensor = read_tensor(reader)?;
364                parameters.push(tensor);
365            }
366
367            layers.push(ModelLayer {
368                layer_type,
369                parameters,
370            });
371        }
372
373        // Read metadata (optional, remaining bytes)
374        let mut metadata = std::collections::HashMap::new();
375        while let Ok(key) = read_string(reader) {
376            if let Ok(value) = read_string(reader) {
377                metadata.insert(key, value);
378            } else {
379                break;
380            }
381        }
382
383        Ok(Self {
384            version,
385            model_type,
386            layers,
387            metadata,
388        })
389    }
390
391    /// Write to bytes.
392    #[must_use]
393    pub fn save(&self) -> Vec<u8> {
394        let mut data = Vec::new();
395        self.write_to(&mut data).expect("write to vec");
396        data
397    }
398
399    /// Write to a writer.
400    ///
401    /// # Errors
402    ///
403    /// Returns error if writing fails.
404    pub fn write_to<W: Write>(&self, writer: &mut W) -> io::Result<()> {
405        // Write magic
406        writer.write_all(APR_MAGIC)?;
407
408        // Write version
409        write_u32(writer, self.version)?;
410
411        // Write model type
412        write_string(writer, &self.model_type)?;
413
414        // Write layers
415        write_u32(writer, self.layers.len() as u32)?;
416        for layer in &self.layers {
417            write_string(writer, &layer.layer_type)?;
418            write_u32(writer, layer.parameters.len() as u32)?;
419            for param in &layer.parameters {
420                write_tensor(writer, param)?;
421            }
422        }
423
424        // Write metadata
425        for (key, value) in &self.metadata {
426            write_string(writer, key)?;
427            write_string(writer, value)?;
428        }
429
430        Ok(())
431    }
432}
433
/// Error produced while decoding an `.ald` or `.apr` byte stream.
#[derive(Debug, Clone, PartialEq)]
pub enum FormatError {
    /// The leading four bytes did not match the expected magic.
    InvalidMagic,
    /// The header declared a version newer than this reader understands.
    UnsupportedVersion(u32),
    /// A tensor record carried an unknown dtype code.
    InvalidDType(u32),
    /// The input ended before a complete value could be read.
    TruncatedData,
    /// Any other I/O or decoding failure, carried as its message text.
    IoError(String),
}

impl std::fmt::Display for FormatError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InvalidMagic => f.write_str("Invalid file magic bytes"),
            Self::UnsupportedVersion(version) => {
                write!(f, "Unsupported format version: {version}")
            }
            Self::InvalidDType(code) => write!(f, "Invalid dtype: {code}"),
            Self::TruncatedData => f.write_str("Truncated data"),
            Self::IoError(ref message) => write!(f, "IO error: {message}"),
        }
    }
}

impl std::error::Error for FormatError {}

impl From<io::Error> for FormatError {
    /// Maps an unexpected end of input to [`FormatError::TruncatedData`]; every
    /// other I/O error keeps only its message text.
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::UnexpectedEof => Self::TruncatedData,
            _ => Self::IoError(err.to_string()),
        }
    }
}
472
473// =============================================================================
474// Helper functions
475// =============================================================================
476
477fn read_u32<R: Read>(reader: &mut R) -> Result<u32, FormatError> {
478    let mut buf = [0u8; 4];
479    reader.read_exact(&mut buf)?;
480    Ok(u32::from_le_bytes(buf))
481}
482
/// Appends `v` to `writer` as four little-endian bytes.
fn write_u32<W: Write>(writer: &mut W, v: u32) -> io::Result<()> {
    let encoded: [u8; 4] = v.to_le_bytes();
    writer.write_all(&encoded)
}
486
487fn read_string<R: Read>(reader: &mut R) -> Result<String, FormatError> {
488    let len = read_u32(reader)? as usize;
489    let mut buf = vec![0u8; len];
490    reader.read_exact(&mut buf)?;
491    String::from_utf8(buf).map_err(|e| FormatError::IoError(e.to_string()))
492}
493
494fn write_string<W: Write>(writer: &mut W, s: &str) -> io::Result<()> {
495    write_u32(writer, s.len() as u32)?;
496    writer.write_all(s.as_bytes())
497}
498
499fn read_tensor<R: Read>(reader: &mut R) -> Result<Tensor, FormatError> {
500    let name = read_string(reader)?;
501    let dtype_u32 = read_u32(reader)?;
502    let dtype = DType::from_u32(dtype_u32).ok_or(FormatError::InvalidDType(dtype_u32))?;
503
504    let num_dims = read_u32(reader)? as usize;
505    let mut shape = Vec::with_capacity(num_dims);
506    for _ in 0..num_dims {
507        shape.push(read_u32(reader)?);
508    }
509
510    let numel: usize = shape.iter().map(|&d| d as usize).product();
511    let data_size = numel * dtype.size();
512    let mut data = vec![0u8; data_size];
513    reader.read_exact(&mut data)?;
514
515    Ok(Tensor {
516        name,
517        dtype,
518        shape,
519        data,
520    })
521}
522
523fn write_tensor<W: Write>(writer: &mut W, tensor: &Tensor) -> io::Result<()> {
524    write_string(writer, &tensor.name)?;
525    write_u32(writer, tensor.dtype as u32)?;
526    write_u32(writer, tensor.shape.len() as u32)?;
527    for &dim in &tensor.shape {
528        write_u32(writer, dim)?;
529    }
530    writer.write_all(&tensor.data)
531}
532
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare `.ald` header (magic, version, tensor count) for hand-made inputs.
    fn ald_header(version: u32, num_tensors: u32) -> Vec<u8> {
        let mut bytes = b"ALD\0".to_vec();
        bytes.extend_from_slice(&version.to_le_bytes());
        bytes.extend_from_slice(&num_tensors.to_le_bytes());
        bytes
    }

    // =========================================================================
    // DType tests
    // =========================================================================

    #[test]
    fn test_dtype_size() {
        assert_eq!(DType::F32.size(), 4);
        assert_eq!(DType::F64.size(), 8);
        assert_eq!(DType::I32.size(), 4);
        assert_eq!(DType::I64.size(), 8);
        assert_eq!(DType::U8.size(), 1);
    }

    #[test]
    fn test_dtype_from_u32() {
        assert_eq!(DType::from_u32(0), Some(DType::F32));
        assert_eq!(DType::from_u32(1), Some(DType::F64));
        assert_eq!(DType::from_u32(2), Some(DType::I32));
        assert_eq!(DType::from_u32(3), Some(DType::I64));
        assert_eq!(DType::from_u32(4), Some(DType::U8));
        assert_eq!(DType::from_u32(5), None);
    }

    // =========================================================================
    // Tensor tests
    // =========================================================================

    #[test]
    fn test_tensor_numel() {
        let t = Tensor::new("test", DType::F32, vec![2, 3, 4], vec![0; 96]);
        assert_eq!(t.numel(), 24);
    }

    #[test]
    fn test_tensor_expected_size() {
        let t = Tensor::new("test", DType::F32, vec![2, 3], vec![]);
        assert_eq!(t.expected_size(), 24); // 6 elements * 4 bytes
    }

    #[test]
    fn test_tensor_is_valid() {
        let valid = Tensor::new("test", DType::F32, vec![2, 3], vec![0; 24]);
        assert!(valid.is_valid());

        let invalid = Tensor::new("test", DType::F32, vec![2, 3], vec![0; 10]);
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_tensor_from_f32() {
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let t = Tensor::from_f32("weights", vec![2, 2], &data);

        assert_eq!(t.name, "weights");
        assert_eq!(t.dtype, DType::F32);
        assert_eq!(t.shape, vec![2, 2]);
        assert_eq!(t.data.len(), 16);

        let vec = t.to_f32_vec().unwrap();
        assert_eq!(vec, data.to_vec());
    }

    #[test]
    fn test_tensor_to_f32_vec_rejects_other_dtypes() {
        let t = Tensor::new("bytes", DType::U8, vec![4], vec![1, 2, 3, 4]);
        assert!(t.to_f32_vec().is_none());
    }

    // =========================================================================
    // AldDataset tests
    // =========================================================================

    #[test]
    fn test_ald_new() {
        let ds = AldDataset::new();
        assert_eq!(ds.version, ALD_VERSION);
        assert!(ds.tensors.is_empty());
    }

    #[test]
    fn test_ald_add_get() {
        let mut ds = AldDataset::new();
        ds.add_tensor(Tensor::from_f32("x", vec![10], &[0.0; 10]));
        ds.add_tensor(Tensor::from_f32("y", vec![5], &[0.0; 5]));

        assert!(ds.get("x").is_some());
        assert!(ds.get("y").is_some());
        assert!(ds.get("z").is_none());
    }

    #[test]
    fn test_ald_roundtrip() {
        let mut ds = AldDataset::new();
        ds.add_tensor(Tensor::from_f32("weights", vec![3, 3], &[1.0; 9]));
        ds.add_tensor(Tensor::from_f32("bias", vec![3], &[0.5; 3]));

        let bytes = ds.save();
        let loaded = AldDataset::load(&bytes).unwrap();

        assert_eq!(loaded.version, ds.version);
        assert_eq!(loaded.tensors.len(), 2);
        assert_eq!(loaded.get("weights").unwrap().shape, vec![3, 3]);
        assert_eq!(loaded.get("bias").unwrap().shape, vec![3]);
        assert_eq!(loaded.get("weights").unwrap().to_f32_vec().unwrap(), vec![1.0; 9]);
        assert_eq!(loaded.get("bias").unwrap().to_f32_vec().unwrap(), vec![0.5; 3]);
    }

    #[test]
    fn test_ald_roundtrip_scalar_and_empty_tensors() {
        let mut ds = AldDataset::new();
        ds.add_tensor(Tensor::from_f32("scalar", vec![], &[3.5]));
        ds.add_tensor(Tensor::from_f32("empty", vec![0, 7], &[]));

        let loaded = AldDataset::load(&ds.save()).unwrap();

        assert_eq!(loaded.get("scalar").unwrap().to_f32_vec().unwrap(), vec![3.5]);
        let empty = loaded.get("empty").unwrap();
        assert_eq!(empty.shape, vec![0, 7]);
        assert!(empty.data.is_empty());
    }

    #[test]
    fn test_ald_invalid_magic() {
        let result = AldDataset::load(b"BAAD");
        assert!(matches!(result, Err(FormatError::InvalidMagic)));
    }

    #[test]
    fn test_ald_truncated() {
        let result = AldDataset::load(b"ALD\0");
        assert!(matches!(result, Err(FormatError::TruncatedData)));
    }

    #[test]
    fn test_ald_unsupported_version() {
        let bytes = ald_header(ALD_VERSION + 1, 0);
        assert_eq!(
            AldDataset::load(&bytes).unwrap_err(),
            FormatError::UnsupportedVersion(ALD_VERSION + 1)
        );
    }

    #[test]
    fn test_ald_invalid_dtype() {
        let mut bytes = ald_header(ALD_VERSION, 1);
        bytes.extend_from_slice(&1u32.to_le_bytes()); // name length
        bytes.push(b'x'); // name
        bytes.extend_from_slice(&9u32.to_le_bytes()); // unknown dtype code
        assert_eq!(AldDataset::load(&bytes).unwrap_err(), FormatError::InvalidDType(9));
    }

    #[test]
    fn test_ald_truncated_tensor_data() {
        let mut ds = AldDataset::new();
        ds.add_tensor(Tensor::from_f32("w", vec![4], &[1.0; 4]));
        let bytes = ds.save();

        let result = AldDataset::load(&bytes[..bytes.len() - 1]);
        assert_eq!(result.unwrap_err(), FormatError::TruncatedData);
    }

    // =========================================================================
    // AprModel tests
    // =========================================================================

    #[test]
    fn test_apr_new() {
        let model = AprModel::new("mlp");
        assert_eq!(model.version, APR_VERSION);
        assert_eq!(model.model_type, "mlp");
        assert!(model.layers.is_empty());
    }

    #[test]
    fn test_apr_param_count() {
        let mut model = AprModel::new("test");
        model.add_layer(ModelLayer {
            layer_type: "dense".to_string(),
            parameters: vec![
                Tensor::from_f32("w", vec![10, 5], &[0.0; 50]),
                Tensor::from_f32("b", vec![5], &[0.0; 5]),
            ],
        });

        assert_eq!(model.param_count(), 55);
    }

    #[test]
    fn test_apr_roundtrip() {
        let mut model = AprModel::new("classifier");
        model.add_layer(ModelLayer {
            layer_type: "dense".to_string(),
            parameters: vec![
                Tensor::from_f32("weight", vec![4, 2], &[1.0; 8]),
                Tensor::from_f32("bias", vec![2], &[0.1, 0.2]),
            ],
        });
        model
            .metadata
            .insert("trained_epochs".to_string(), "100".to_string());

        let bytes = model.save();
        let loaded = AprModel::load(&bytes).unwrap();

        assert_eq!(loaded.version, model.version);
        assert_eq!(loaded.model_type, "classifier");
        assert_eq!(loaded.layers.len(), 1);
        assert_eq!(loaded.layers[0].layer_type, "dense");
        assert_eq!(loaded.layers[0].parameters.len(), 2);
        assert_eq!(
            loaded.layers[0].parameters[1].to_f32_vec().unwrap(),
            vec![0.1, 0.2]
        );
        assert_eq!(loaded.metadata.len(), 1);
        assert_eq!(
            loaded.metadata.get("trained_epochs").map(String::as_str),
            Some("100")
        );
    }

    #[test]
    fn test_apr_roundtrip_without_metadata() {
        let model = AprModel::new("empty");
        let loaded = AprModel::load(&model.save()).unwrap();

        assert_eq!(loaded.model_type, "empty");
        assert!(loaded.layers.is_empty());
        assert!(loaded.metadata.is_empty());
    }

    #[test]
    fn test_apr_invalid_magic() {
        let result = AprModel::load(b"NOPE");
        assert!(matches!(result, Err(FormatError::InvalidMagic)));
    }

    #[test]
    fn test_apr_truncated_header() {
        let result = AprModel::load(b"APR\0");
        assert!(matches!(result, Err(FormatError::TruncatedData)));
    }

    #[test]
    fn test_apr_unsupported_version() {
        let mut bytes = b"APR\0".to_vec();
        bytes.extend_from_slice(&(APR_VERSION + 1).to_le_bytes());
        assert_eq!(
            AprModel::load(&bytes).unwrap_err(),
            FormatError::UnsupportedVersion(APR_VERSION + 1)
        );
    }

    // =========================================================================
    // FormatError tests
    // =========================================================================

    #[test]
    fn test_format_error_display() {
        assert!(FormatError::InvalidMagic.to_string().contains("magic"));
        assert!(FormatError::UnsupportedVersion(99)
            .to_string()
            .contains("99"));
        assert!(FormatError::InvalidDType(255).to_string().contains("255"));
        assert!(FormatError::TruncatedData.to_string().contains("Truncated"));
        assert!(FormatError::IoError("boom".to_string())
            .to_string()
            .contains("boom"));
    }

    #[test]
    fn test_format_error_from_io() {
        let eof = std::io::Error::from(std::io::ErrorKind::UnexpectedEof);
        assert_eq!(FormatError::from(eof), FormatError::TruncatedData);

        let other = std::io::Error::new(std::io::ErrorKind::Other, "disk on fire");
        assert_eq!(
            FormatError::from(other),
            FormatError::IoError("disk on fire".to_string())
        );
    }
}