Skip to main content

alimentar/format/
mod.rs

1//! Alimentar Dataset Format (.ald)
2//!
3//! A binary format for secure, verifiable dataset distribution.
4//! See `docs/specifications/dataset-format-spec.md` for full specification.
5//!
6//! # Format Structure
7//!
8//! ```text
9//! ┌─────────────────────────────────────────┐
10//! │ Header (32 bytes, fixed)                │
11//! ├─────────────────────────────────────────┤
12//! │ Metadata (variable, MessagePack)        │
13//! ├─────────────────────────────────────────┤
14//! │ Schema (variable, Arrow IPC)            │
15//! ├─────────────────────────────────────────┤
16//! │ Payload (variable, Arrow IPC + zstd)    │
17//! ├─────────────────────────────────────────┤
18//! │ Checksum (4 bytes, CRC32)               │
19//! └─────────────────────────────────────────┘
20//! ```
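//!
//! Optional blocks are written only when enabled: an encryption header sits
//! between the schema and the payload, and a signature block followed by a
//! license block sit between the payload and the checksum.
//!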
22//! # Example
23//!
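//! ```ignore
//! use alimentar::format::{load, save, DatasetType, SaveOptions};
//!
//! // Save record batches to any `std::io::Write` sink
//! let mut out = std::fs::File::create("data.ald")?;
//! save(&mut out, &batches, DatasetType::Tabular, &SaveOptions::default())?;
//!
//! // Load from any `std::io::Read` source
//! let mut input = std::fs::File::open("data.ald")?;
//! let dataset = load(&mut input)?;
//! ```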
33
34mod crc;
35#[cfg(feature = "format-encryption")]
36pub mod encryption;
37pub mod license;
38pub mod piracy;
39#[cfg(feature = "format-signing")]
40pub mod signing;
41#[cfg(feature = "format-streaming")]
42pub mod streaming;
43
44pub use crc::crc32;
45use serde::{Deserialize, Serialize};
46
47use crate::error::{Error, Result};
48
/// File signature: the ASCII bytes `ALDF` (0x41 0x4C 0x44 0x46).
pub const MAGIC: [u8; 4] = *b"ALDF";

/// Major component of the current format version.
pub const FORMAT_VERSION_MAJOR: u8 = 1;
/// Minor component of the current format version.
pub const FORMAT_VERSION_MINOR: u8 = 2;

/// Length of the fixed file header, in bytes.
pub const HEADER_SIZE: usize = 32;
59
/// Bit masks for the header `flags` byte (one bit per feature).
pub mod flags {
    /// Bit 0: payload is encrypted (AES-256-GCM)
    pub const ENCRYPTED: u8 = 1 << 0;
    /// Bit 1: file carries a digital signature (Ed25519)
    pub const SIGNED: u8 = 1 << 1;
    /// Bit 2: supports chunked/mmap loading (native only)
    pub const STREAMING: u8 = 1 << 2;
    /// Bit 3: file carries a commercial license block
    pub const LICENSED: u8 = 1 << 3;
    /// Bit 4: 64-byte aligned arrays for zero-copy SIMD (native only)
    pub const TRUENO_NATIVE: u8 = 1 << 4;
}
73
74/// Dataset type identifiers (§3.1)
75#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
76#[repr(u16)]
77pub enum DatasetType {
78    // Structured Data (0x0001-0x000F)
79    /// Generic columnar data
80    Tabular = 0x0001,
81    /// Temporal sequences
82    TimeSeries = 0x0002,
83    /// Node/edge structures
84    Graph = 0x0003,
85    /// Geospatial coordinates
86    Spatial = 0x0004,
87
88    // Text & NLP (0x0010-0x001F)
89    /// Raw text documents
90    TextCorpus = 0x0010,
91    /// Labeled text
92    TextClassification = 0x0011,
93    /// Sentence pairs (NLI, STS)
94    TextPairs = 0x0012,
95    /// Token-level labels (NER)
96    SequenceLabeling = 0x0013,
97    /// QA datasets (SQuAD-style)
98    QuestionAnswering = 0x0014,
99    /// Document + summary pairs
100    Summarization = 0x0015,
101    /// Parallel corpora
102    Translation = 0x0016,
103
104    // Computer Vision (0x0020-0x002F)
105    /// Images + class labels
106    ImageClassification = 0x0020,
107    /// Images + bounding boxes
108    ObjectDetection = 0x0021,
109    /// Images + pixel masks
110    Segmentation = 0x0022,
111    /// Image matching/similarity
112    ImagePairs = 0x0023,
113    /// Sequential frames
114    Video = 0x0024,
115
116    // Audio (0x0030-0x003F)
117    /// Audio + class labels
118    AudioClassification = 0x0030,
119    /// Audio + transcripts (ASR)
120    SpeechRecognition = 0x0031,
121    /// Audio + speaker labels
122    SpeakerIdentification = 0x0032,
123
124    // Recommender Systems (0x0040-0x004F)
125    /// Collaborative filtering
126    UserItemRatings = 0x0040,
127    /// Click/view interactions
128    ImplicitFeedback = 0x0041,
129    /// Session-based sequences
130    SequentialRecs = 0x0042,
131
132    // Multimodal (0x0050-0x005F)
133    /// Image captioning/VQA
134    ImageText = 0x0050,
135    /// Speech-to-text pairs
136    AudioText = 0x0051,
137    /// Video descriptions
138    VideoText = 0x0052,
139
140    // Special
141    /// User extensions
142    Custom = 0x00FF,
143}
144
145impl DatasetType {
146    /// Convert from u16 value
147    #[must_use]
148    pub fn from_u16(value: u16) -> Option<Self> {
149        match value {
150            0x0001 => Some(Self::Tabular),
151            0x0002 => Some(Self::TimeSeries),
152            0x0003 => Some(Self::Graph),
153            0x0004 => Some(Self::Spatial),
154            0x0010 => Some(Self::TextCorpus),
155            0x0011 => Some(Self::TextClassification),
156            0x0012 => Some(Self::TextPairs),
157            0x0013 => Some(Self::SequenceLabeling),
158            0x0014 => Some(Self::QuestionAnswering),
159            0x0015 => Some(Self::Summarization),
160            0x0016 => Some(Self::Translation),
161            0x0020 => Some(Self::ImageClassification),
162            0x0021 => Some(Self::ObjectDetection),
163            0x0022 => Some(Self::Segmentation),
164            0x0023 => Some(Self::ImagePairs),
165            0x0024 => Some(Self::Video),
166            0x0030 => Some(Self::AudioClassification),
167            0x0031 => Some(Self::SpeechRecognition),
168            0x0032 => Some(Self::SpeakerIdentification),
169            0x0040 => Some(Self::UserItemRatings),
170            0x0041 => Some(Self::ImplicitFeedback),
171            0x0042 => Some(Self::SequentialRecs),
172            0x0050 => Some(Self::ImageText),
173            0x0051 => Some(Self::AudioText),
174            0x0052 => Some(Self::VideoText),
175            0x00FF => Some(Self::Custom),
176            _ => None,
177        }
178    }
179
180    /// Convert to u16 value
181    #[must_use]
182    pub const fn as_u16(self) -> u16 {
183        self as u16
184    }
185}
186
187/// Compression algorithm identifiers (§3.3)
188#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
189#[repr(u8)]
190pub enum Compression {
191    /// No compression (debugging)
192    None = 0x00,
193    /// Zstd level 3 (standard distribution)
194    #[default]
195    ZstdL3 = 0x01,
196    /// Zstd level 19 (archival, max compression)
197    ZstdL19 = 0x02,
198    /// LZ4 (high-throughput streaming)
199    Lz4 = 0x03,
200}
201
202impl Compression {
203    /// Convert from u8 value
204    #[must_use]
205    pub fn from_u8(value: u8) -> Option<Self> {
206        match value {
207            0x00 => Some(Self::None),
208            0x01 => Some(Self::ZstdL3),
209            0x02 => Some(Self::ZstdL19),
210            0x03 => Some(Self::Lz4),
211            _ => None,
212        }
213    }
214
215    /// Convert to u8 value
216    #[must_use]
217    pub const fn as_u8(self) -> u8 {
218        self as u8
219    }
220}
221
222/// File header (32 bytes, fixed)
223///
224/// | Offset | Size | Field |
225/// |--------|------|-------|
226/// | 0 | 4 | magic |
227/// | 4 | 2 | format_version (major.minor) |
228/// | 6 | 2 | dataset_type |
229/// | 8 | 4 | metadata_size |
230/// | 12 | 4 | payload_size (compressed) |
231/// | 16 | 4 | uncompressed_size |
232/// | 20 | 1 | compression |
233/// | 21 | 1 | flags |
234/// | 22 | 2 | schema_size |
235/// | 24 | 8 | num_rows |
236#[derive(Debug, Clone, PartialEq, Eq)]
237pub struct Header {
238    /// Format version (major, minor)
239    pub version: (u8, u8),
240    /// Dataset type identifier
241    pub dataset_type: DatasetType,
242    /// Metadata block size in bytes
243    pub metadata_size: u32,
244    /// Compressed payload size in bytes
245    pub payload_size: u32,
246    /// Uncompressed payload size in bytes
247    pub uncompressed_size: u32,
248    /// Compression algorithm
249    pub compression: Compression,
250    /// Feature flags
251    pub flags: u8,
252    /// Schema block size in bytes
253    pub schema_size: u16,
254    /// Total row count
255    pub num_rows: u64,
256}
257
258impl Header {
259    /// Create a new header with default values
260    #[must_use]
261    pub fn new(dataset_type: DatasetType) -> Self {
262        // Contract: configuration-v1.yaml precondition (pv codegen)
263        contract_pre_configuration!();
264        Self {
265            version: (FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR),
266            dataset_type,
267            metadata_size: 0,
268            payload_size: 0,
269            uncompressed_size: 0,
270            compression: Compression::default(),
271            flags: 0,
272            schema_size: 0,
273            num_rows: 0,
274        }
275    }
276
277    /// Serialize header to 32 bytes
278    #[must_use]
279    pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
280        // Contract: serialization-v1.yaml precondition (pv codegen)
281        contract_pre_configuration!();
282        let mut buf = [0u8; HEADER_SIZE];
283
284        // Magic (0-3)
285        buf[0..4].copy_from_slice(&MAGIC);
286
287        // Version (4-5)
288        buf[4] = self.version.0;
289        buf[5] = self.version.1;
290
291        // Dataset type (6-7)
292        let dt = self.dataset_type.as_u16().to_le_bytes();
293        buf[6..8].copy_from_slice(&dt);
294
295        // Metadata size (8-11)
296        buf[8..12].copy_from_slice(&self.metadata_size.to_le_bytes());
297
298        // Payload size (12-15)
299        buf[12..16].copy_from_slice(&self.payload_size.to_le_bytes());
300
301        // Uncompressed size (16-19)
302        buf[16..20].copy_from_slice(&self.uncompressed_size.to_le_bytes());
303
304        // Compression (20)
305        buf[20] = self.compression.as_u8();
306
307        // Flags (21)
308        buf[21] = self.flags;
309
310        // Schema size (22-23)
311        buf[22..24].copy_from_slice(&self.schema_size.to_le_bytes());
312
313        // Num rows (24-31)
314        buf[24..32].copy_from_slice(&self.num_rows.to_le_bytes());
315
316        buf
317    }
318
319    /// Deserialize header from bytes
320    ///
321    /// # Errors
322    ///
323    /// Returns error if magic is invalid, version is unsupported, or types are
324    /// unknown.
325    pub fn from_bytes(buf: &[u8]) -> Result<Self> {
326        if buf.len() < HEADER_SIZE {
327            return Err(Error::Format(format!(
328                "Header too short: {} bytes, expected {}",
329                buf.len(),
330                HEADER_SIZE
331            )));
332        }
333
334        // Validate magic
335        if buf[0..4] != MAGIC {
336            return Err(Error::Format(format!(
337                "Invalid magic: expected {:?}, got {:?}",
338                MAGIC,
339                &buf[0..4]
340            )));
341        }
342
343        // Contract: serialization-v1.yaml precondition (pv codegen)
344        contract_pre_configuration!(buf);
345
346        // Version
347        let version = (buf[4], buf[5]);
348        if version.0 > FORMAT_VERSION_MAJOR {
349            return Err(Error::Format(format!(
350                "Unsupported version: {}.{}, max supported: {}.{}",
351                version.0, version.1, FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR
352            )));
353        }
354
355        // Dataset type
356        let dt_value = u16::from_le_bytes([buf[6], buf[7]]);
357        let dataset_type = DatasetType::from_u16(dt_value)
358            .ok_or_else(|| Error::Format(format!("Unknown dataset type: 0x{:04X}", dt_value)))?;
359
360        // Metadata size
361        let metadata_size = u32::from_le_bytes([buf[8], buf[9], buf[10], buf[11]]);
362
363        // Payload size
364        let payload_size = u32::from_le_bytes([buf[12], buf[13], buf[14], buf[15]]);
365
366        // Uncompressed size
367        let uncompressed_size = u32::from_le_bytes([buf[16], buf[17], buf[18], buf[19]]);
368
369        // Compression
370        let compression = Compression::from_u8(buf[20])
371            .ok_or_else(|| Error::Format(format!("Unknown compression: 0x{:02X}", buf[20])))?;
372
373        // Flags
374        let flags = buf[21];
375
376        // Schema size
377        let schema_size = u16::from_le_bytes([buf[22], buf[23]]);
378
379        // Num rows
380        let num_rows = u64::from_le_bytes([
381            buf[24], buf[25], buf[26], buf[27], buf[28], buf[29], buf[30], buf[31],
382        ]);
383
384        Ok(Self {
385            version,
386            dataset_type,
387            metadata_size,
388            payload_size,
389            uncompressed_size,
390            compression,
391            flags,
392            schema_size,
393            num_rows,
394        })
395    }
396
397    /// Check if encrypted flag is set
398    #[must_use]
399    pub const fn is_encrypted(&self) -> bool {
400        self.flags & flags::ENCRYPTED != 0
401    }
402
403    /// Check if signed flag is set
404    #[must_use]
405    pub const fn is_signed(&self) -> bool {
406        self.flags & flags::SIGNED != 0
407    }
408
409    /// Check if streaming flag is set
410    #[must_use]
411    pub const fn is_streaming(&self) -> bool {
412        self.flags & flags::STREAMING != 0
413    }
414
415    /// Check if licensed flag is set
416    #[must_use]
417    pub const fn is_licensed(&self) -> bool {
418        self.flags & flags::LICENSED != 0
419    }
420
421    /// Check if trueno-native flag is set
422    #[must_use]
423    pub const fn is_trueno_native(&self) -> bool {
424        self.flags & flags::TRUENO_NATIVE != 0
425    }
426}
427
/// Dataset metadata (MessagePack-encoded)
///
/// Every field is optional; `Metadata::default()` is what `save` writes when
/// the caller supplies no metadata.
// NOTE(review): `save` encodes this with `rmp_serde::to_vec`; if that encoding
// is positional, field order is part of the on-disk format — confirm before
// reordering or inserting fields.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Metadata {
    /// Human-readable name
    pub name: Option<String>,
    /// Version (semver)
    pub version: Option<String>,
    /// SPDX license identifier
    pub license: Option<String>,
    /// Searchable tags (empty when absent from the encoded form)
    #[serde(default)]
    pub tags: Vec<String>,
    /// Markdown description
    pub description: Option<String>,
    /// Citation (BibTeX)
    pub citation: Option<String>,
    /// Creation timestamp (RFC 3339)
    pub created_at: Option<String>,
    /// SHA-256 hash of the payload data (hex string, 64 chars)
    ///
    /// Omitted from the encoded form when `None`; read back as `None` when
    /// absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub sha256: Option<String>,
}
450
/// Hashes `data` with SHA-256 and returns the digest as lowercase hex.
///
/// # Example
///
/// ```
/// use alimentar::format::sha256_hex;
///
/// let hash = sha256_hex(b"Hello, World!");
/// assert_eq!(hash.len(), 64); // 256 bits = 32 bytes = 64 hex chars
/// ```
#[cfg(feature = "provenance")]
#[must_use]
pub fn sha256_hex(data: &[u8]) -> String {
    use std::fmt::Write as _;

    use sha2::{Digest, Sha256};

    let mut hasher = Sha256::new();
    hasher.update(data);
    let digest = hasher.finalize();

    // Two hex characters per digest byte.
    let mut hex = String::with_capacity(64);
    for byte in digest.iter() {
        // Formatting into a `String` cannot fail, so the result is ignored.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
477
478/// Options for saving datasets
479#[derive(Debug, Clone)]
480pub struct SaveOptions {
481    /// Compression algorithm to use
482    pub compression: Compression,
483    /// Optional metadata to include
484    pub metadata: Option<Metadata>,
485    /// Encryption parameters (requires `format-encryption` feature)
486    #[cfg(feature = "format-encryption")]
487    pub encryption: Option<encryption::EncryptionParams>,
488    /// Signing key pair (requires `format-signing` feature)
489    #[cfg(feature = "format-signing")]
490    pub signing_key: Option<signing::SigningKeyPair>,
491    /// License block for commercial distribution
492    pub license: Option<license::LicenseBlock>,
493}
494
495impl Default for SaveOptions {
496    fn default() -> Self {
497        Self {
498            compression: Compression::ZstdL3,
499            metadata: None,
500            #[cfg(feature = "format-encryption")]
501            encryption: None,
502            #[cfg(feature = "format-signing")]
503            signing_key: None,
504            license: None,
505        }
506    }
507}
508
509impl SaveOptions {
510    /// Set compression algorithm
511    #[must_use]
512    pub fn with_compression(mut self, compression: Compression) -> Self {
513        self.compression = compression;
514        self
515    }
516
517    /// Set metadata
518    #[must_use]
519    pub fn with_metadata(mut self, metadata: Metadata) -> Self {
520        self.metadata = Some(metadata);
521        self
522    }
523
524    /// Set password-based encryption (requires `format-encryption` feature)
525    #[cfg(feature = "format-encryption")]
526    #[must_use]
527    pub fn with_password(mut self, password: impl Into<String>) -> Self {
528        self.encryption = Some(encryption::EncryptionParams::password(password));
529        self
530    }
531
532    /// Set recipient-based encryption (requires `format-encryption` feature)
533    #[cfg(feature = "format-encryption")]
534    #[must_use]
535    pub fn with_recipient(mut self, public_key: [u8; 32]) -> Self {
536        self.encryption = Some(encryption::EncryptionParams::recipient(public_key));
537        self
538    }
539
540    /// Set signing key (requires `format-signing` feature)
541    #[cfg(feature = "format-signing")]
542    #[must_use]
543    pub fn with_signing_key(mut self, key: signing::SigningKeyPair) -> Self {
544        self.signing_key = Some(key);
545        self
546    }
547
548    /// Set license block
549    #[must_use]
550    pub fn with_license(mut self, license: license::LicenseBlock) -> Self {
551        self.license = Some(license);
552        self
553    }
554}
555
/// Settings that control how `load_with_options` reads a dataset.
///
/// The default has no decryption parameters, no trusted keys and license
/// verification disabled.
#[derive(Debug, Clone, Default)]
pub struct LoadOptions {
    /// Decryption parameters (required if dataset is encrypted)
    #[cfg(feature = "format-encryption")]
    pub decryption: Option<encryption::DecryptionParams>,
    /// Trusted public keys for signature verification (if empty, skip
    /// verification)
    #[cfg(feature = "format-signing")]
    pub trusted_keys: Vec<[u8; 32]>,
    /// Whether to verify license expiration
    pub verify_license: bool,
}

impl LoadOptions {
    /// Decrypt with a password
    ///
    /// Replaces any previously configured decryption parameters.
    #[cfg(feature = "format-encryption")]
    #[must_use]
    pub fn with_password(self, password: impl Into<String>) -> Self {
        Self {
            decryption: Some(encryption::DecryptionParams::password(password)),
            ..self
        }
    }

    /// Decrypt with a private key
    ///
    /// Replaces any previously configured decryption parameters.
    #[cfg(feature = "format-encryption")]
    #[must_use]
    pub fn with_private_key(self, key: [u8; 32]) -> Self {
        Self {
            decryption: Some(encryption::DecryptionParams::private_key(key)),
            ..self
        }
    }

    /// Append a public key to the trusted set used for signature verification
    #[cfg(feature = "format-signing")]
    #[must_use]
    pub fn with_trusted_key(mut self, key: [u8; 32]) -> Self {
        self.trusted_keys.push(key);
        self
    }

    /// Turn on license verification
    #[must_use]
    pub fn verify_license(mut self) -> Self {
        self.verify_license = true;
        self
    }
}
602
603/// Save an Arrow dataset to the .ald format
604///
605/// # Errors
606///
607/// Returns error if serialization or I/O fails.
608#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
609pub fn save<W: std::io::Write>(
610    writer: &mut W,
611    batches: &[arrow::array::RecordBatch],
612    dataset_type: DatasetType,
613    options: &SaveOptions,
614) -> Result<()> {
615    use arrow::ipc::writer::StreamWriter;
616
617    if batches.is_empty() {
618        return Err(Error::EmptyDataset);
619    }
620
621    let schema = batches[0].schema();
622
623    // Serialize schema via Arrow IPC
624    let mut schema_buf = Vec::new();
625    {
626        let mut schema_writer =
627            StreamWriter::try_new(&mut schema_buf, &schema).map_err(Error::Arrow)?;
628        schema_writer.finish().map_err(Error::Arrow)?;
629    }
630
631    // Serialize payload via Arrow IPC
632    let mut payload_buf = Vec::new();
633    {
634        let mut payload_writer =
635            StreamWriter::try_new(&mut payload_buf, &schema).map_err(Error::Arrow)?;
636        for batch in batches {
637            payload_writer.write(batch).map_err(Error::Arrow)?;
638        }
639        payload_writer.finish().map_err(Error::Arrow)?;
640    }
641
642    let uncompressed_size = payload_buf.len() as u32;
643
644    // Compress payload if needed
645    let compressed_payload = compress_payload(payload_buf, options.compression)?;
646
647    // Build flags
648    let mut header_flags: u8 = 0;
649
650    // Encryption: build block, split into header and ciphertext payload
651    #[cfg(feature = "format-encryption")]
652    let (final_payload, encryption_header) = if let Some(ref enc_params) = options.encryption {
653        header_flags |= flags::ENCRYPTED;
654        let block = build_encryption_block(&compressed_payload, enc_params)?;
655        let hdr_size = encryption_block_header_size(block[0]);
656        (block[hdr_size..].to_vec(), block[..hdr_size].to_vec())
657    } else {
658        (compressed_payload, Vec::new())
659    };
660    #[cfg(not(feature = "format-encryption"))]
661    let (final_payload, encryption_header): (Vec<u8>, Vec<u8>) = (compressed_payload, Vec::new());
662
663    // Signing setup
664    #[cfg(feature = "format-signing")]
665    if options.signing_key.is_some() {
666        header_flags |= flags::SIGNED;
667    }
668
669    // License setup
670    if options.license.is_some() {
671        header_flags |= flags::LICENSED;
672    }
673
674    // Serialize metadata
675    let metadata_buf = if let Some(ref meta) = options.metadata {
676        rmp_serde::to_vec(meta).map_err(|e| Error::Format(e.to_string()))?
677    } else {
678        rmp_serde::to_vec(&Metadata::default()).map_err(|e| Error::Format(e.to_string()))?
679    };
680
681    // Count total rows
682    let num_rows: u64 = batches.iter().map(|b| b.num_rows() as u64).sum();
683
684    // Build header
685    let header = Header {
686        version: (FORMAT_VERSION_MAJOR, FORMAT_VERSION_MINOR),
687        dataset_type,
688        metadata_size: metadata_buf.len() as u32,
689        payload_size: final_payload.len() as u32,
690        uncompressed_size,
691        compression: options.compression,
692        flags: header_flags,
693        schema_size: schema_buf.len() as u16,
694        num_rows,
695    };
696
697    // Build all data for checksum and signature
698    let mut all_data = Vec::new();
699    let header_bytes = header.to_bytes();
700    all_data.extend_from_slice(&header_bytes);
701    all_data.extend_from_slice(&metadata_buf);
702    all_data.extend_from_slice(&schema_buf);
703    all_data.extend_from_slice(&encryption_header);
704    all_data.extend_from_slice(&final_payload);
705
706    // Add signature block if signing
707    #[cfg(feature = "format-signing")]
708    let signature_block: Option<[u8; signing::SignatureBlock::SIZE]> =
709        if let Some(ref key) = options.signing_key {
710            let sig_block = signing::SignatureBlock::sign(&all_data, key);
711            let sig_bytes = sig_block.to_bytes();
712            all_data.extend_from_slice(&sig_bytes);
713            Some(sig_bytes)
714        } else {
715            None
716        };
717    #[cfg(not(feature = "format-signing"))]
718    let signature_block: Option<[u8; 96]> = None;
719
720    // Add license block if present
721    let license_bytes: Option<Vec<u8>> = if let Some(ref lic) = options.license {
722        let lic_bytes = lic.to_bytes();
723        all_data.extend_from_slice(&lic_bytes);
724        Some(lic_bytes)
725    } else {
726        None
727    };
728
729    // Calculate checksum over all preceding data
730    let checksum = crc32(&all_data);
731
732    // Write everything
733    writer.write_all(&header_bytes).map_err(Error::io_no_path)?;
734    writer.write_all(&metadata_buf).map_err(Error::io_no_path)?;
735    writer.write_all(&schema_buf).map_err(Error::io_no_path)?;
736    writer
737        .write_all(&encryption_header)
738        .map_err(Error::io_no_path)?;
739    writer
740        .write_all(&final_payload)
741        .map_err(Error::io_no_path)?;
742
743    if let Some(ref sig) = signature_block {
744        writer.write_all(sig).map_err(Error::io_no_path)?;
745    }
746
747    if let Some(ref lic) = license_bytes {
748        writer.write_all(lic).map_err(Error::io_no_path)?;
749    }
750
751    writer
752        .write_all(&checksum.to_le_bytes())
753        .map_err(Error::io_no_path)?;
754
755    Ok(())
756}
757
758/// Compress a payload buffer using the specified compression method.
759fn compress_payload(payload: Vec<u8>, compression: Compression) -> Result<Vec<u8>> {
760    match compression {
761        Compression::None => Ok(payload),
762        Compression::ZstdL3 => zstd::encode_all(payload.as_slice(), 3).map_err(Error::io_no_path),
763        Compression::ZstdL19 => zstd::encode_all(payload.as_slice(), 19).map_err(Error::io_no_path),
764        Compression::Lz4 => {
765            let mut encoder = lz4_flex::frame::FrameEncoder::new(Vec::new());
766            std::io::Write::write_all(&mut encoder, &payload).map_err(Error::io_no_path)?;
767            encoder
768                .finish()
769                .map_err(|e| Error::Format(format!("LZ4 compression error: {e}")))
770        }
771    }
772}
773
/// Size in bytes of the encryption header that precedes the ciphertext, as
/// determined by the block's mode byte.
///
/// Any mode other than `encryption::mode::PASSWORD` is sized as a recipient
/// block.
#[cfg(feature = "format-encryption")]
fn encryption_block_header_size(mode: u8) -> usize {
    // Password blocks carry a 16-byte salt; other blocks carry a 32-byte
    // ephemeral public key.
    let key_material_len = if mode == encryption::mode::PASSWORD {
        16
    } else {
        32
    };

    // mode byte + key material + 12-byte nonce
    1 + key_material_len + 12
}
783
/// Encrypt `plaintext` and lay the result out as one contiguous block:
/// mode byte, key material (salt or ephemeral public key), nonce, ciphertext.
#[cfg(feature = "format-encryption")]
fn build_encryption_block(
    plaintext: &[u8],
    params: &encryption::EncryptionParams,
) -> Result<Vec<u8>> {
    /// Concatenate the parts of an encryption block in on-disk order.
    fn assemble(mode: u8, key_material: &[u8], nonce: &[u8], ciphertext: &[u8]) -> Vec<u8> {
        let mut block =
            Vec::with_capacity(1 + key_material.len() + nonce.len() + ciphertext.len());
        block.push(mode);
        block.extend_from_slice(key_material);
        block.extend_from_slice(nonce);
        block.extend_from_slice(ciphertext);
        block
    }

    let block = match &params.mode {
        encryption::EncryptionMode::Password(password) => {
            let (mode, salt, nonce, ciphertext) =
                encryption::encrypt_password(plaintext, password)?;
            assemble(mode, &salt, &nonce, &ciphertext)
        }
        encryption::EncryptionMode::Recipient {
            recipient_public_key,
        } => {
            let (mode, ephemeral_pub, nonce, ciphertext) =
                encryption::encrypt_recipient(plaintext, recipient_public_key)?;
            assemble(mode, &ephemeral_pub, &nonce, &ciphertext)
        }
    };

    Ok(block)
}
815
/// Loaded dataset from .ald format
///
/// Returned by `load` and `load_with_options`.
#[derive(Debug)]
pub struct LoadedDataset {
    /// Parsed header
    pub header: Header,
    /// Dataset metadata, decoded from the file's MessagePack metadata block
    pub metadata: Metadata,
    /// Arrow record batches read from the decompressed payload stream
    pub batches: Vec<arrow::array::RecordBatch>,
    /// License block (if present)
    pub license: Option<license::LicenseBlock>,
    /// Signer public key (if signed and verified)
    pub signer_public_key: Option<[u8; 32]>,
}
830
831/// Load an Arrow dataset from the .ald format (unencrypted only)
832///
833/// For encrypted or signed datasets, use `load_with_options`.
834///
835/// # Errors
836///
837/// Returns error if dataset is encrypted, or if deserialization,
838/// decompression, or checksum validation fails.
839pub fn load<R: std::io::Read>(reader: &mut R) -> Result<LoadedDataset> {
840    load_with_options(reader, &LoadOptions::default())
841}
842
843/// Load an Arrow dataset with decryption and verification options
844///
845/// # Errors
846///
847/// Returns error if deserialization, decompression, decryption,
848/// signature verification, license validation, or checksum validation fails.
849#[allow(clippy::too_many_lines)]
850pub fn load_with_options<R: std::io::Read>(
851    reader: &mut R,
852    options: &LoadOptions,
853) -> Result<LoadedDataset> {
854    use arrow::ipc::reader::StreamReader;
855
856    // Read all data
857    let mut all_data = Vec::new();
858    reader
859        .read_to_end(&mut all_data)
860        .map_err(Error::io_no_path)?;
861
862    if all_data.len() < HEADER_SIZE + 4 {
863        return Err(Error::Format("File too small".to_string()));
864    }
865
866    // Split off checksum (last 4 bytes)
867    let checksum_offset = all_data.len() - 4;
868    let stored_checksum = u32::from_le_bytes([
869        all_data[checksum_offset],
870        all_data[checksum_offset + 1],
871        all_data[checksum_offset + 2],
872        all_data[checksum_offset + 3],
873    ]);
874
875    // Verify checksum
876    let computed_checksum = crc32(&all_data[..checksum_offset]);
877    if stored_checksum != computed_checksum {
878        return Err(Error::ChecksumMismatch {
879            expected: stored_checksum,
880            actual: computed_checksum,
881        });
882    }
883
884    // Parse header
885    let header = Header::from_bytes(&all_data[..HEADER_SIZE])?;
886
887    // Parse metadata
888    let metadata_start = HEADER_SIZE;
889    let metadata_end = metadata_start + header.metadata_size as usize;
890    let metadata: Metadata = rmp_serde::from_slice(&all_data[metadata_start..metadata_end])
891        .map_err(|e| Error::Format(format!("Metadata parse error: {e}")))?;
892
893    // Skip schema (embedded in payload IPC stream)
894    let schema_end = metadata_end + header.schema_size as usize;
895
896    // Determine encryption header size
897    let encryption_header_size = determine_encryption_header_size(&header, &all_data, schema_end)?;
898
899    let payload_start = schema_end + encryption_header_size;
900    let payload_end = payload_start + header.payload_size as usize;
901
902    if payload_end > checksum_offset {
903        return Err(Error::Format("Payload extends beyond data".to_string()));
904    }
905
906    // Extract and decrypt payload if encrypted
907    let compressed_payload: Vec<u8> = if header.is_encrypted() {
908        #[cfg(feature = "format-encryption")]
909        {
910            let enc_header = &all_data[schema_end..payload_start];
911            let ciphertext = &all_data[payload_start..payload_end];
912
913            let decryption_params = options.decryption.as_ref().ok_or_else(|| {
914                Error::Format("Dataset is encrypted but no decryption params provided".to_string())
915            })?;
916
917            decrypt_payload(enc_header, ciphertext, decryption_params)?
918        }
919        #[cfg(not(feature = "format-encryption"))]
920        {
921            return Err(Error::Format(
922                "Dataset is encrypted but format-encryption feature is not enabled".to_string(),
923            ));
924        }
925    } else {
926        all_data[payload_start..payload_end].to_vec()
927    };
928
929    // Parse trailing blocks (signature, license)
930    let (signer_public_key, license_block) =
931        parse_trailing_blocks(&header, &all_data, payload_end, checksum_offset, options)?;
932
933    // Decompress payload
934    let decompressed_payload = decompress_payload(compressed_payload, header.compression)?;
935
936    // Parse Arrow IPC stream
937    let cursor = std::io::Cursor::new(decompressed_payload);
938    let stream_reader = StreamReader::try_new(cursor, None).map_err(Error::Arrow)?;
939
940    let batches: Vec<_> = stream_reader
941        .into_iter()
942        .collect::<std::result::Result<Vec<_>, _>>()
943        .map_err(Error::Arrow)?;
944
945    Ok(LoadedDataset {
946        header,
947        metadata,
948        batches,
949        license: license_block,
950        signer_public_key,
951    })
952}
953
954/// Determine the encryption header size from the mode byte.
955fn determine_encryption_header_size(
956    header: &Header,
957    all_data: &[u8],
958    schema_end: usize,
959) -> Result<usize> {
960    if !header.is_encrypted() {
961        return Ok(0);
962    }
963
964    if all_data.len() <= schema_end {
965        return Err(Error::Format("Missing encryption header".to_string()));
966    }
967
968    #[cfg(feature = "format-encryption")]
969    {
970        Ok(encryption_block_header_size(all_data[schema_end]))
971    }
972    #[cfg(not(feature = "format-encryption"))]
973    {
974        Err(Error::Format(
975            "Dataset is encrypted but format-encryption feature is not enabled".to_string(),
976        ))
977    }
978}
979
980/// Parse trailing blocks (signature, license) after the payload.
981fn parse_trailing_blocks(
982    header: &Header,
983    all_data: &[u8],
984    payload_end: usize,
985    checksum_offset: usize,
986    options: &LoadOptions,
987) -> Result<(Option<[u8; 32]>, Option<license::LicenseBlock>)> {
988    #[allow(unused_mut)]
989    let mut trailing_offset = payload_end;
990    #[allow(unused_mut)]
991    let mut signer_public_key: Option<[u8; 32]> = None;
992    let mut license_block: Option<license::LicenseBlock> = None;
993
994    if header.is_signed() {
995        #[cfg(feature = "format-signing")]
996        {
997            let sig_end = trailing_offset + signing::SignatureBlock::SIZE;
998            if sig_end > checksum_offset {
999                return Err(Error::Format(
1000                    "Signature block extends beyond data".to_string(),
1001                ));
1002            }
1003
1004            let sig_block =
1005                signing::SignatureBlock::from_bytes(&all_data[trailing_offset..sig_end])?;
1006
1007            if !options.trusted_keys.is_empty() {
1008                let signed_data = &all_data[..trailing_offset];
1009                if !options.trusted_keys.contains(&sig_block.public_key) {
1010                    return Err(Error::Format("Signer not in trusted keys list".to_string()));
1011                }
1012                sig_block.verify(signed_data)?;
1013            }
1014
1015            signer_public_key = Some(sig_block.public_key);
1016            trailing_offset = sig_end;
1017        }
1018        #[cfg(not(feature = "format-signing"))]
1019        {
1020            return Err(Error::Format(
1021                "Dataset is signed but format-signing feature is not enabled".to_string(),
1022            ));
1023        }
1024    }
1025
1026    if header.is_licensed() {
1027        if trailing_offset >= checksum_offset {
1028            return Err(Error::Format("Missing license block".to_string()));
1029        }
1030        let lic = license::LicenseBlock::from_bytes(&all_data[trailing_offset..checksum_offset])?;
1031        if options.verify_license {
1032            lic.verify()?;
1033        }
1034        license_block = Some(lic);
1035    }
1036
1037    Ok((signer_public_key, license_block))
1038}
1039
1040/// Decompress a payload buffer using the specified compression method.
1041fn decompress_payload(payload: Vec<u8>, compression: Compression) -> Result<Vec<u8>> {
1042    match compression {
1043        Compression::None => Ok(payload),
1044        Compression::ZstdL3 | Compression::ZstdL19 => zstd::decode_all(payload.as_slice())
1045            .map_err(|e| Error::Format(format!("Zstd decompression error: {e}"))),
1046        Compression::Lz4 => {
1047            let mut decoder = lz4_flex::frame::FrameDecoder::new(payload.as_slice());
1048            let mut decompressed = Vec::new();
1049            std::io::Read::read_to_end(&mut decoder, &mut decompressed)
1050                .map_err(|e| Error::Format(format!("LZ4 decompression error: {e}")))?;
1051            Ok(decompressed)
1052        }
1053    }
1054}
1055
/// Decrypt an encrypted payload with the caller-supplied parameters.
///
/// The first byte of `enc_header` is the encryption mode. The remaining bytes
/// are read as:
///
/// - password mode: 16 bytes passed as the salt, then 12 bytes passed as the
///   nonce;
/// - recipient mode: 32 bytes passed as the ephemeral public key, then
///   12 bytes passed as the nonce.
///
/// # Errors
///
/// Returns [`Error::Format`] when the header is empty or too short for its
/// mode, when the supplied parameters do not match the mode, or when the mode
/// byte is not recognised. Errors from the underlying decryption routines are
/// propagated unchanged.
#[cfg(feature = "format-encryption")]
fn decrypt_payload(
    enc_header: &[u8],
    ciphertext: &[u8],
    params: &encryption::DecryptionParams,
) -> Result<Vec<u8>> {
    const SALT_LEN: usize = 16;
    const EPHEMERAL_KEY_LEN: usize = 32;
    const NONCE_LEN: usize = 12;

    // Peel the mode byte off; `fields` is everything after it.
    let (mode, fields) = match enc_header.split_first() {
        Some((&mode, fields)) => (mode, fields),
        None => return Err(Error::Format("Empty encryption header".to_string())),
    };

    match (mode, params) {
        (encryption::mode::PASSWORD, encryption::DecryptionParams::Password(password)) => {
            if fields.len() < SALT_LEN + NONCE_LEN {
                return Err(Error::Format(
                    "Invalid password encryption header".to_string(),
                ));
            }
            let mut salt = [0u8; SALT_LEN];
            let mut nonce = [0u8; NONCE_LEN];
            salt.copy_from_slice(&fields[..SALT_LEN]);
            nonce.copy_from_slice(&fields[SALT_LEN..SALT_LEN + NONCE_LEN]);

            encryption::decrypt_password(ciphertext, password, &salt, &nonce)
        }
        (encryption::mode::RECIPIENT, encryption::DecryptionParams::PrivateKey(private_key)) => {
            if fields.len() < EPHEMERAL_KEY_LEN + NONCE_LEN {
                return Err(Error::Format(
                    "Invalid recipient encryption header".to_string(),
                ));
            }
            let mut ephemeral_pub = [0u8; EPHEMERAL_KEY_LEN];
            let mut nonce = [0u8; NONCE_LEN];
            ephemeral_pub.copy_from_slice(&fields[..EPHEMERAL_KEY_LEN]);
            nonce.copy_from_slice(&fields[EPHEMERAL_KEY_LEN..EPHEMERAL_KEY_LEN + NONCE_LEN]);

            encryption::decrypt_recipient(ciphertext, private_key, &ephemeral_pub, &nonce)
        }
        // Mode and supplied credentials disagree.
        (encryption::mode::PASSWORD, encryption::DecryptionParams::PrivateKey(_)) => Err(
            Error::Format("Dataset encrypted with password but private key provided".to_string()),
        ),
        (encryption::mode::RECIPIENT, encryption::DecryptionParams::Password(_)) => Err(
            Error::Format("Dataset encrypted for recipient but password provided".to_string()),
        ),
        _ => Err(Error::Format(format!("Unknown encryption mode: {mode}"))),
    }
}
1105
1106/// Save dataset to a file path
1107///
1108/// # Errors
1109///
1110/// Returns error if file creation or serialization fails.
1111pub fn save_to_file<P: AsRef<std::path::Path>>(
1112    path: P,
1113    batches: &[arrow::array::RecordBatch],
1114    dataset_type: DatasetType,
1115    options: &SaveOptions,
1116) -> Result<()> {
1117    let file = std::fs::File::create(path.as_ref())
1118        .map_err(|e| Error::io(e, path.as_ref().to_path_buf()))?;
1119    let mut writer = std::io::BufWriter::new(file);
1120    save(&mut writer, batches, dataset_type, options)
1121}
1122
1123/// Load dataset from a file path
1124///
1125/// # Errors
1126///
1127/// Returns error if file reading or deserialization fails.
1128pub fn load_from_file<P: AsRef<std::path::Path>>(path: P) -> Result<LoadedDataset> {
1129    load_from_file_with_options(path, &LoadOptions::default())
1130}
1131
1132/// Load dataset from a file path with decryption and verification options
1133///
1134/// # Errors
1135///
1136/// Returns error if file reading, decryption, verification, or deserialization
1137/// fails.
1138pub fn load_from_file_with_options<P: AsRef<std::path::Path>>(
1139    path: P,
1140    options: &LoadOptions,
1141) -> Result<LoadedDataset> {
1142    let file = std::fs::File::open(path.as_ref())
1143        .map_err(|e| Error::io(e, path.as_ref().to_path_buf()))?;
1144    let mut reader = std::io::BufReader::new(file);
1145    load_with_options(&mut reader, options)
1146}
1147
1148#[cfg(test)]
1149mod tests;