Skip to main content

aprender/format/
core_io.rs

1//! APR format core I/O operations (save, load, inspect)
2
3use super::{Compression, Header, Metadata, ModelInfo, ModelType, SaveOptions, HEADER_SIZE};
4use crate::error::{AprenderError, Result};
5use serde::{de::DeserializeOwned, Serialize};
6use std::fs::File;
7#[cfg(feature = "format-compression")]
8use std::io::Cursor;
9use std::io::{BufReader, BufWriter, Read, Write};
10use std::path::Path;
11
12#[cfg(feature = "format-encryption")]
13use super::{KEY_SIZE, NONCE_SIZE, SALT_SIZE};
14
15/// Compress payload based on algorithm (spec 3.3)
16#[allow(clippy::unnecessary_wraps)] // Returns Result to handle compression errors when feature enabled
17pub(crate) fn compress_payload(
18    data: &[u8],
19    compression: Compression,
20) -> Result<(Vec<u8>, Compression)> {
21    match compression {
22        Compression::None => Ok((data.to_vec(), Compression::None)),
23        #[cfg(feature = "format-compression")]
24        Compression::ZstdDefault => {
25            // Zstd level 3 (good balance of speed and ratio)
26            let compressed = zstd::encode_all(Cursor::new(data), 3).map_err(|e| {
27                AprenderError::Serialization(format!("Zstd compression failed: {e}"))
28            })?;
29            Ok((compressed, Compression::ZstdDefault))
30        }
31        #[cfg(feature = "format-compression")]
32        Compression::ZstdMax => {
33            // Zstd level 19 (maximum compression for archival)
34            let compressed = zstd::encode_all(Cursor::new(data), 19).map_err(|e| {
35                AprenderError::Serialization(format!("Zstd compression failed: {e}"))
36            })?;
37            Ok((compressed, Compression::ZstdMax))
38        }
39        #[cfg(not(feature = "format-compression"))]
40        Compression::ZstdDefault | Compression::ZstdMax => {
41            // Feature not enabled, fall back to no compression
42            Ok((data.to_vec(), Compression::None))
43        }
44        #[cfg(feature = "format-compression")]
45        Compression::Lz4 => {
46            // LZ4 compression using lz4_flex with prepended size (GH-146)
47            let compressed = lz4_flex::compress_prepend_size(data);
48            Ok((compressed, Compression::Lz4))
49        }
50        #[cfg(not(feature = "format-compression"))]
51        Compression::Lz4 => {
52            // Feature not enabled, fall back to no compression
53            Ok((data.to_vec(), Compression::None))
54        }
55    }
56}
57
58/// Decompress payload based on algorithm (spec 3.3)
59pub(crate) fn decompress_payload(data: &[u8], compression: Compression) -> Result<Vec<u8>> {
60    match compression {
61        Compression::None => Ok(data.to_vec()),
62        #[cfg(feature = "format-compression")]
63        Compression::ZstdDefault | Compression::ZstdMax => zstd::decode_all(Cursor::new(data))
64            .map_err(|e| AprenderError::Serialization(format!("Zstd decompression failed: {e}"))),
65        #[cfg(not(feature = "format-compression"))]
66        Compression::ZstdDefault | Compression::ZstdMax => Err(AprenderError::FormatError {
67            message: "Zstd compression not supported (enable format-compression feature)"
68                .to_string(),
69        }),
70        #[cfg(feature = "format-compression")]
71        Compression::Lz4 => lz4_flex::decompress_size_prepended(data)
72            .map_err(|e| AprenderError::Serialization(format!("LZ4 decompression failed: {e}"))),
73        #[cfg(not(feature = "format-compression"))]
74        Compression::Lz4 => Err(AprenderError::FormatError {
75            message: "LZ4 compression not supported (enable format-compression feature)"
76                .to_string(),
77        }),
78    }
79}
80
/// Compute the CRC-32 of `data` using the IEEE polynomial (reflected form `0xEDB8_8320`).
///
/// Table-driven implementation with initial value `0xFFFF_FFFF` and a final
/// bitwise inversion, so `crc32(b"123456789") == 0xCBF4_3926` and the CRC of
/// an empty slice is `0`.
pub(crate) fn crc32(data: &[u8]) -> u32 {
    /// Reflected IEEE 802.3 polynomial.
    const POLY: u32 = 0xEDB8_8320;

    /// Build the 256-entry byte lookup table at compile time.
    const fn build_table() -> [u32; 256] {
        let mut table = [0u32; 256];
        let mut slot = 0usize;
        while slot < 256 {
            let mut value = slot as u32;
            let mut round = 0;
            while round < 8 {
                // Branch-free "shift right, xor POLY if the low bit was set":
                // `0u32.wrapping_sub(value & 1)` is all ones when the bit is set, else zero.
                value = (value >> 1) ^ (POLY & 0u32.wrapping_sub(value & 1));
                round += 1;
            }
            table[slot] = value;
            slot += 1;
        }
        table
    }

    const TABLE: [u32; 256] = build_table();

    !data.iter().fold(u32::MAX, |acc, &byte| {
        TABLE[((acc ^ u32::from(byte)) & 0xFF) as usize] ^ (acc >> 8)
    })
}
111
112// ============================================================================
113// FILE LOADING HELPER FUNCTIONS (Refactored for reduced complexity)
114// ============================================================================
115
/// Read the entire file at `path` into memory.
///
/// Uses `std::fs::read`, which sizes its buffer from the file metadata up front
/// instead of growing a `Vec` repeatedly through a `BufReader`.
///
/// # Errors
/// Returns the I/O error (converted into `AprenderError`) if the file cannot be
/// opened or read.
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
pub(crate) fn read_file_content(path: &Path) -> Result<Vec<u8>> {
    Ok(std::fs::read(path)?)
}
125
/// Check the trailing little-endian CRC-32 of `content` against the bytes that precede it.
///
/// # Errors
/// * `AprenderError::FormatError` if `content` is shorter than the 4-byte checksum.
/// * `AprenderError::ChecksumMismatch` if the stored and recomputed checksums differ.
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
pub(crate) fn verify_file_checksum(content: &[u8]) -> Result<()> {
    let body_len = match content.len().checked_sub(4) {
        Some(len) => len,
        None => {
            return Err(AprenderError::FormatError {
                message: "File too small for checksum".to_string(),
            })
        }
    };
    let (body, trailer) = content.split_at(body_len);
    let stored = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
    let computed = crc32(body);
    if stored == computed {
        Ok(())
    } else {
        Err(AprenderError::ChecksumMismatch {
            expected: stored,
            actual: computed,
        })
    }
}
149
/// Parse the fixed-size header from the start of `content` and check its model type.
///
/// # Errors
/// * `AprenderError::FormatError` if `content` is shorter than `HEADER_SIZE`,
///   or if the header's model type differs from `expected_type`.
/// * Any error produced by `Header::from_bytes`.
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
pub(crate) fn parse_and_validate_header(
    content: &[u8],
    expected_type: ModelType,
) -> Result<Header> {
    // Callers may only have verified the 4-byte checksum trailer, so a short
    // buffer must be rejected here instead of panicking on `..HEADER_SIZE`.
    let header_bytes = content
        .get(..HEADER_SIZE)
        .ok_or_else(|| AprenderError::FormatError {
            message: format!("File too small: {} bytes", content.len()),
        })?;
    let header = Header::from_bytes(header_bytes)?;
    if header.model_type != expected_type {
        return Err(AprenderError::FormatError {
            message: format!(
                "Model type mismatch: file contains {:?}, expected {:?}",
                header.model_type, expected_type
            ),
        });
    }
    Ok(header)
}
167
/// Ensure the header advertises a signed file.
///
/// # Errors
/// Returns `AprenderError::FormatError` when the SIGNED flag is not set.
#[cfg(feature = "format-signing")]
pub(crate) fn verify_signed_flag(header: &Header) -> Result<()> {
    if header.flags.is_signed() {
        Ok(())
    } else {
        Err(AprenderError::FormatError {
            message: String::from("File is not signed (SIGNED flag not set)"),
        })
    }
}
178
/// Ensure the header advertises an encrypted file.
///
/// # Errors
/// Returns `AprenderError::FormatError` when the ENCRYPTED flag is not set.
#[cfg(feature = "format-encryption")]
pub(crate) fn verify_encrypted_flag(header: &Header) -> Result<()> {
    if header.flags.is_encrypted() {
        Ok(())
    } else {
        Err(AprenderError::FormatError {
            message: String::from("File is not encrypted (ENCRYPTED flag not set)"),
        })
    }
}
189
/// Verify that a payload ending at `payload_end` fits before the 4-byte CRC-32 trailer.
///
/// # Errors
/// Returns `AprenderError::FormatError` if `payload_end` lies past
/// `content_len - 4`, or if `content_len` is too small to hold a trailer.
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
pub(crate) fn verify_payload_boundary(payload_end: usize, content_len: usize) -> Result<()> {
    // `checked_sub` rejects inputs shorter than the trailer. A plain `content_len - 4`
    // would panic in debug builds and wrap to a huge (always-passing) limit in release.
    let within_bounds = matches!(content_len.checked_sub(4), Some(limit) if payload_end <= limit);
    if within_bounds {
        Ok(())
    } else {
        Err(AprenderError::FormatError {
            message: "Payload extends beyond file boundary".to_string(),
        })
    }
}
200
/// Decompress `payload_compressed` with `compression`, then bincode-decode it as `M`.
///
/// # Errors
/// Propagates decompression errors; returns `AprenderError::Serialization` if
/// bincode cannot decode the bytes as `M`.
#[cfg(feature = "format-signing")]
pub(crate) fn decompress_and_deserialize<M: DeserializeOwned>(
    payload_compressed: &[u8],
    compression: Compression,
) -> Result<M> {
    let raw = decompress_payload(payload_compressed, compression)?;
    bincode::deserialize(raw.as_slice()).map_err(|err| {
        AprenderError::Serialization(format!("Failed to deserialize model: {err}"))
    })
}
211
212/// Save a model to .apr format
213///
214/// # Arguments
215/// * `model` - The model to save (must implement Serialize)
216/// * `model_type` - Model type identifier
217/// * `path` - Output file path
218/// * `options` - Save options (compression, metadata)
219///
220/// # Errors
221/// Returns error on I/O failure or serialization error
222#[allow(clippy::needless_pass_by_value)] // SaveOptions is small and passed by value for ergonomics
223pub fn save<M: Serialize>(
224    model: &M,
225    model_type: ModelType,
226    path: impl AsRef<Path>,
227    options: SaveOptions,
228) -> Result<()> {
229    let path = path.as_ref();
230
231    // APR-POKA-001: Jidoka gate - refuse to write if validation explicitly failed
232    // Score 0 means "validation rules exist but model failed all of them"
233    if options.quality_score == Some(0) {
234        return Err(AprenderError::ValidationError {
235            message: "Jidoka: Refusing to save model with quality_score=0. \
236                      Fix validation errors or use score=None to skip validation."
237                .to_string(),
238        });
239    }
240
241    // Serialize payload with bincode
242    let payload_uncompressed = bincode::serialize(model)
243        .map_err(|e| AprenderError::Serialization(format!("Failed to serialize model: {e}")))?;
244
245    // Compress payload
246    let (payload_compressed, compression) =
247        compress_payload(&payload_uncompressed, options.compression)?;
248
249    // Serialize metadata as MessagePack with named fields (spec 2)
250    // Must use to_vec_named() for map mode to preserve field names with skip_serializing_if
251    let metadata_bytes = rmp_serde::to_vec_named(&options.metadata)
252        .map_err(|e| AprenderError::Serialization(format!("Failed to serialize metadata: {e}")))?;
253
254    // Build header
255    let mut header = Header::new(model_type);
256    header.compression = compression;
257    header.metadata_size = metadata_bytes.len() as u32;
258    header.payload_size = payload_compressed.len() as u32;
259    header.uncompressed_size = payload_uncompressed.len() as u32;
260
261    // Set LICENSED flag if license info present (spec 9.1)
262    if options.metadata.license.is_some() {
263        header.flags = header.flags.with_licensed();
264    }
265
266    // APR-POKA-001: Set quality score in header (0 = no validation performed)
267    header.quality_score = options.quality_score.unwrap_or(0);
268
269    // Assemble file content (without checksum)
270    let mut content = Vec::new();
271    content.extend_from_slice(&header.to_bytes());
272    content.extend_from_slice(&metadata_bytes);
273    content.extend_from_slice(&payload_compressed);
274
275    // Calculate and append checksum
276    let checksum = crc32(&content);
277    content.extend_from_slice(&checksum.to_le_bytes());
278
279    // Write to file
280    let file = File::create(path)?;
281    let mut writer = BufWriter::new(file);
282    writer.write_all(&content)?;
283    writer.flush()?;
284
285    Ok(())
286}
287
288/// Load a model from .apr format
289///
290/// # Arguments
291/// * `path` - Input file path
292/// * `expected_type` - Expected model type (for type safety)
293///
294/// # Errors
295/// Returns error on I/O failure, format error, or type mismatch
296pub fn load<M: DeserializeOwned>(path: impl AsRef<Path>, expected_type: ModelType) -> Result<M> {
297    let path = path.as_ref();
298
299    // Read entire file
300    let file = File::open(path)?;
301    let mut reader = BufReader::new(file);
302    let mut content = Vec::new();
303    reader.read_to_end(&mut content)?;
304
305    // Verify minimum size
306    if content.len() < HEADER_SIZE + 4 {
307        return Err(AprenderError::FormatError {
308            message: format!("File too small: {} bytes", content.len()),
309        });
310    }
311
312    // Verify checksum (Jidoka: stop the line on corruption)
313    let stored_checksum = u32::from_le_bytes([
314        content[content.len() - 4],
315        content[content.len() - 3],
316        content[content.len() - 2],
317        content[content.len() - 1],
318    ]);
319    let computed_checksum = crc32(&content[..content.len() - 4]);
320    if stored_checksum != computed_checksum {
321        return Err(AprenderError::ChecksumMismatch {
322            expected: stored_checksum,
323            actual: computed_checksum,
324        });
325    }
326
327    // Parse header
328    let header = Header::from_bytes(&content[..HEADER_SIZE])?;
329
330    // Verify model type
331    if header.model_type != expected_type {
332        return Err(AprenderError::FormatError {
333            message: format!(
334                "Model type mismatch: file contains {:?}, expected {:?}",
335                header.model_type, expected_type
336            ),
337        });
338    }
339
340    // Extract payload
341    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
342    let payload_end = metadata_end + header.payload_size as usize;
343
344    if payload_end > content.len() - 4 {
345        return Err(AprenderError::FormatError {
346            message: "Payload extends beyond file boundary".to_string(),
347        });
348    }
349
350    let payload_compressed = &content[metadata_end..payload_end];
351
352    // Decompress payload
353    let payload_uncompressed = decompress_payload(payload_compressed, header.compression)?;
354
355    // Deserialize model
356    bincode::deserialize(&payload_uncompressed)
357        .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))
358}
359
360/// Load a model from a byte slice (spec 1.1 - Single Binary Deployment)
361///
362/// Enables the `include_bytes!()` pattern for embedding models directly
363/// in executables. This is the key function for zero-dependency ML deployment.
364///
365/// # Arguments
366/// * `data` - Raw .apr file bytes (e.g., from `include_bytes!()`)
367/// * `expected_type` - Expected model type (for type safety)
368///
369/// # Example
370///
371/// ```rust,ignore
372/// use aprender::format::{load_from_bytes, ModelType};
373///
374/// // Embed model at compile time
375/// const MODEL: &[u8] = include_bytes!("sentiment.apr");
376///
377/// fn main() -> Result<()> {
378///     let model: LogisticRegression = load_from_bytes(MODEL, ModelType::LogisticRegression)?;
379///     let prediction = model.predict(&input)?;
380///     Ok(())
381/// }
382/// ```
383///
384/// # Errors
385/// Returns error on format error, type mismatch, or checksum failure
386pub fn load_from_bytes<M: DeserializeOwned>(data: &[u8], expected_type: ModelType) -> Result<M> {
387    // Verify minimum size
388    if data.len() < HEADER_SIZE + 4 {
389        return Err(AprenderError::FormatError {
390            message: format!("Data too small: {} bytes", data.len()),
391        });
392    }
393
394    // Verify checksum (Jidoka: stop the line on corruption)
395    let stored_checksum = u32::from_le_bytes([
396        data[data.len() - 4],
397        data[data.len() - 3],
398        data[data.len() - 2],
399        data[data.len() - 1],
400    ]);
401    let computed_checksum = crc32(&data[..data.len() - 4]);
402    if stored_checksum != computed_checksum {
403        return Err(AprenderError::ChecksumMismatch {
404            expected: stored_checksum,
405            actual: computed_checksum,
406        });
407    }
408
409    // Parse header
410    let header = Header::from_bytes(&data[..HEADER_SIZE])?;
411
412    // Verify model type
413    if header.model_type != expected_type {
414        return Err(AprenderError::FormatError {
415            message: format!(
416                "Model type mismatch: data contains {:?}, expected {:?}",
417                header.model_type, expected_type
418            ),
419        });
420    }
421
422    // Extract payload
423    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
424    let payload_end = metadata_end + header.payload_size as usize;
425
426    if payload_end > data.len() - 4 {
427        return Err(AprenderError::FormatError {
428            message: "Payload extends beyond data boundary".to_string(),
429        });
430    }
431
432    let payload_compressed = &data[metadata_end..payload_end];
433
434    // Decompress payload
435    let payload_uncompressed = decompress_payload(payload_compressed, header.compression)?;
436
437    // Deserialize model
438    bincode::deserialize(&payload_uncompressed)
439        .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))
440}
441
/// File-size cutoff (1 MiB = 1,048,576 bytes) above which `load_auto` uses mmap.
///
/// Larger files are memory-mapped to avoid copying them onto the heap; files at
/// or below this size are read normally, which has less setup overhead.
pub const MMAP_THRESHOLD: u64 = 1 << 20;
447
448/// Load a model using memory-mapped I/O (zero-copy where possible)
449///
450/// Toyota Way Principle: *Muda* (Waste Elimination) - Eliminates redundant
451/// data copies by mapping the file directly into the process address space.
452///
453/// # Performance
454///
455/// - Cold load: ~4x faster than standard `load()` for large models
456/// - Memory: Uses ~1x file size vs ~2x for standard load
457/// - Syscalls: Reduces `brk` calls from ~970 to ~50
458///
459/// # Safety
460///
461/// Uses OS-level memory mapping. The file must not be modified while loaded.
462/// See `bundle-mmap-spec.md` Section 4 for safety considerations.
463///
464/// # Example
465///
466/// ```rust,ignore
467/// use aprender::format::{load_mmap, ModelType};
468///
469/// // Load large model efficiently
470/// let model: RandomForest = load_mmap("large_model.apr", ModelType::RandomForest)?;
471/// ```
472///
473/// # Feature Flag
474///
475/// When `format-mmap` is enabled, uses real OS mmap via `memmap2`.
476/// Otherwise, falls back to standard file I/O (same API, heap-allocated).
477///
478/// # Errors
479///
480/// Returns error on file not found, format error, type mismatch, or checksum failure
481pub fn load_mmap<M: DeserializeOwned>(
482    path: impl AsRef<Path>,
483    expected_type: ModelType,
484) -> Result<M> {
485    use crate::bundle::MappedFile;
486
487    let mapped = MappedFile::open(path.as_ref())?;
488
489    load_from_bytes(mapped.as_slice(), expected_type)
490}
491
492/// Load a model with automatic strategy selection based on file size
493///
494/// Toyota Way Principle: *Heijunka* (Level Loading) - Chooses the optimal
495/// loading strategy based on file size to balance memory and performance.
496///
497/// # Strategy
498///
499/// - Files <= 1MB: Standard `load()` (lower overhead for small files)
500/// - Files > 1MB: Memory-mapped `load_mmap()` (better for large files)
501///
502/// # Example
503///
504/// ```rust,ignore
505/// use aprender::format::{load_auto, ModelType};
506///
507/// // Automatically chooses best loading strategy
508/// let model: KMeans = load_auto("model.apr", ModelType::KMeans)?;
509/// ```
510///
511/// # Errors
512///
513/// Returns error on file not found, format error, type mismatch, or checksum failure
514pub fn load_auto<M: DeserializeOwned>(
515    path: impl AsRef<Path>,
516    expected_type: ModelType,
517) -> Result<M> {
518    let metadata = std::fs::metadata(path.as_ref())?;
519
520    if metadata.len() > MMAP_THRESHOLD {
521        load_mmap(path, expected_type)
522    } else {
523        load(path, expected_type)
524    }
525}
526
527// ============================================================================
528// ENCRYPTION HELPER FUNCTIONS
529// ============================================================================
530
/// Reject encrypted inputs too short to hold header, salt, nonce and CRC-32 trailer.
///
/// # Errors
/// Returns `AprenderError::FormatError` when `data` is below that minimum length.
#[cfg(feature = "format-encryption")]
pub(crate) fn verify_encrypted_data_size(data: &[u8]) -> Result<()> {
    // header + salt + nonce + 4-byte checksum trailer
    let minimum = HEADER_SIZE + SALT_SIZE + NONCE_SIZE + 4;
    if data.len() >= minimum {
        return Ok(());
    }
    Err(AprenderError::FormatError {
        message: format!("Data too small for encrypted model: {} bytes", data.len()),
    })
}
541
/// Verify the trailing little-endian CRC-32 of encrypted `data`.
///
/// # Errors
/// * `AprenderError::FormatError` if `data` is shorter than the 4-byte trailer.
/// * `AprenderError::ChecksumMismatch` if the stored and recomputed checksums differ.
#[cfg(feature = "format-encryption")]
pub(crate) fn verify_encrypted_checksum(data: &[u8]) -> Result<()> {
    // Guard the trailer split so this helper is safe on its own, not only after
    // `verify_encrypted_data_size` has run (`data.len() - 4` would otherwise panic).
    if data.len() < 4 {
        return Err(AprenderError::FormatError {
            message: format!("Data too small for checksum: {} bytes", data.len()),
        });
    }
    let (body, trailer) = data.split_at(data.len() - 4);
    let stored_checksum = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
    let computed_checksum = crc32(body);
    if stored_checksum != computed_checksum {
        return Err(AprenderError::ChecksumMismatch {
            expected: stored_checksum,
            actual: computed_checksum,
        });
    }
    Ok(())
}
560
/// Check that the header is flagged ENCRYPTED and carries `expected_type`.
///
/// The ENCRYPTED flag is checked first, so its error takes precedence.
///
/// # Errors
/// Returns `AprenderError::FormatError` for a missing flag or a model-type mismatch.
#[cfg(feature = "format-encryption")]
pub(crate) fn verify_encrypted_header(header: &Header, expected_type: ModelType) -> Result<()> {
    if !header.flags.is_encrypted() {
        Err(AprenderError::FormatError {
            message: String::from("Data is not encrypted (ENCRYPTED flag not set)"),
        })
    } else if header.model_type != expected_type {
        Err(AprenderError::FormatError {
            message: format!(
                "Model type mismatch: data contains {:?}, expected {:?}",
                header.model_type, expected_type
            ),
        })
    } else {
        Ok(())
    }
}
579
/// Split the encrypted payload section of `data` into salt, nonce and ciphertext.
///
/// The payload begins right after the metadata and is laid out as
/// `salt (SALT_SIZE) | nonce (NONCE_SIZE) | ciphertext`, and it must end before
/// the 4-byte CRC-32 trailer.
///
/// # Errors
/// Returns `AprenderError::FormatError` if the header's section sizes point past
/// the data, or if the payload is too short to contain the salt and nonce.
#[cfg(feature = "format-encryption")]
pub(crate) fn extract_encrypted_components<'a>(
    data: &'a [u8],
    header: &Header,
) -> Result<([u8; SALT_SIZE], [u8; NONCE_SIZE], &'a [u8])> {
    // Section sizes come from the (untrusted) header: compute offsets with checked
    // arithmetic and never subtract the trailer length from a too-short buffer.
    let bounds = HEADER_SIZE
        .checked_add(header.metadata_size as usize)
        .and_then(|metadata_end| {
            metadata_end
                .checked_add(header.payload_size as usize)
                .map(|payload_end| (metadata_end, payload_end))
        });
    let limit = data.len().checked_sub(4);
    let (metadata_end, payload_end) = match (bounds, limit) {
        (Some((start, end)), Some(limit)) if end <= limit => (start, end),
        _ => {
            return Err(AprenderError::FormatError {
                message: "Encrypted payload extends beyond data boundary".to_string(),
            })
        }
    };

    // Both additions are bounded by `payload_end` once the check below passes.
    let salt_end = metadata_end + SALT_SIZE;
    let nonce_end = salt_end + NONCE_SIZE;

    // A crafted header with a tiny `payload_size` would otherwise make the
    // slices below panic (start index greater than end index).
    if nonce_end > payload_end {
        return Err(AprenderError::FormatError {
            message: format!(
                "Encrypted payload too small: {} bytes (need at least {} for salt and nonce)",
                header.payload_size,
                SALT_SIZE + NONCE_SIZE
            ),
        });
    }

    let salt: [u8; SALT_SIZE] =
        data[metadata_end..salt_end]
            .try_into()
            .map_err(|_| AprenderError::FormatError {
                message: "Invalid salt size".to_string(),
            })?;
    let nonce: [u8; NONCE_SIZE] =
        data[salt_end..nonce_end]
            .try_into()
            .map_err(|_| AprenderError::FormatError {
                message: "Invalid nonce size".to_string(),
            })?;
    let ciphertext = &data[nonce_end..payload_end];

    Ok((salt, nonce, ciphertext))
}
613
/// Derive a key from `password` and `salt` with Argon2, then AES-256-GCM-decrypt `ciphertext`.
///
/// # Errors
/// * `AprenderError::Other` if key derivation or cipher construction fails.
/// * `AprenderError::DecryptionFailed` if authentication/decryption fails
///   (wrong password or tampered data).
#[cfg(feature = "format-encryption")]
pub(crate) fn decrypt_encrypted_payload(
    password: &str,
    salt: &[u8; SALT_SIZE],
    nonce_bytes: &[u8; NONCE_SIZE],
    ciphertext: &[u8],
) -> Result<Vec<u8>> {
    use aes_gcm::aead::{Aead, KeyInit};
    use aes_gcm::{Aes256Gcm, Nonce};
    use argon2::Argon2;

    // Password -> symmetric key (Argon2 with its default parameters).
    let mut derived_key = [0u8; KEY_SIZE];
    let kdf = Argon2::default();
    kdf.hash_password_into(password.as_bytes(), salt, &mut derived_key)
        .map_err(|e| AprenderError::Other(format!("Key derivation failed: {e}")))?;

    let cipher = Aes256Gcm::new_from_slice(&derived_key)
        .map_err(|e| AprenderError::Other(format!("Failed to create cipher: {e}")))?;

    // GCM authenticates as it decrypts, so a wrong key and tampering fail identically here.
    cipher
        .decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
        .map_err(|_| AprenderError::DecryptionFailed {
            message: "Decryption failed (wrong password or corrupted data)".to_string(),
        })
}
643
/// Load a password-encrypted model from a byte slice (spec 1.1 + 4.1.2)
///
/// Combines the `include_bytes!()` single-binary pattern with password-based
/// encryption. Validation order: size, checksum, header flag/type, then
/// decryption, decompression and deserialization.
///
/// # Arguments
/// * `data` - Raw encrypted .apr file bytes
/// * `expected_type` - Expected model type
/// * `password` - Password for decryption
///
/// # Example
///
/// ```rust,ignore
/// use aprender::format::{load_from_bytes_encrypted, ModelType};
///
/// // Embed encrypted model at compile time
/// const MODEL: &[u8] = include_bytes!("model.apr.enc");
///
/// fn main() -> Result<()> {
///     let model: NaiveBayes = load_from_bytes_encrypted(
///         MODEL,
///         ModelType::NaiveBayes,
///         &get_password_from_env(),
///     )?;
///     Ok(())
/// }
/// ```
///
/// # Errors
/// Returns error on format error, type mismatch, or decryption failure
#[cfg(feature = "format-encryption")]
pub fn load_from_bytes_encrypted<M: DeserializeOwned>(
    data: &[u8],
    expected_type: ModelType,
    password: &str,
) -> Result<M> {
    // Jidoka: reject truncated or corrupted input before doing any crypto work.
    verify_encrypted_data_size(data)?;
    verify_encrypted_checksum(data)?;

    let header = Header::from_bytes(&data[..HEADER_SIZE])?;
    verify_encrypted_header(&header, expected_type)?;

    let (salt, nonce, ciphertext) = extract_encrypted_components(data, &header)?;
    let compressed = decrypt_encrypted_payload(password, &salt, &nonce, ciphertext)?;

    decompress_payload(&compressed, header.compression).and_then(|serialized| {
        bincode::deserialize(&serialized).map_err(|err| {
            AprenderError::Serialization(format!("Failed to deserialize model: {err}"))
        })
    })
}
697
698/// Inspect model data without loading the payload (spec 1.1)
699///
700/// Useful for validating embedded models or checking metadata
701/// without deserializing the full model.
702///
703/// # Arguments
704/// * `data` - Raw .apr file bytes
705///
706/// # Errors
707/// Returns error on format error
708pub fn inspect_bytes(data: &[u8]) -> Result<ModelInfo> {
709    // Verify minimum size
710    if data.len() < HEADER_SIZE {
711        return Err(AprenderError::FormatError {
712            message: format!("Data too small: {} bytes", data.len()),
713        });
714    }
715
716    // Parse header
717    let header = Header::from_bytes(&data[..HEADER_SIZE])?;
718
719    // Extract metadata
720    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
721    if metadata_end > data.len() {
722        return Err(AprenderError::FormatError {
723            message: "Metadata extends beyond data boundary".to_string(),
724        });
725    }
726
727    let metadata_bytes = &data[HEADER_SIZE..metadata_end];
728    let metadata: Metadata = rmp_serde::from_slice(metadata_bytes)
729        .map_err(|e| AprenderError::Serialization(format!("Failed to parse metadata: {e}")))?;
730
731    Ok(ModelInfo {
732        model_type: header.model_type,
733        format_version: header.version,
734        metadata,
735        payload_size: header.payload_size as usize,
736        uncompressed_size: header.uncompressed_size as usize,
737        encrypted: header.flags.is_encrypted(),
738        signed: header.flags.is_signed(),
739        streaming: header.flags.is_streaming(),
740        licensed: header.flags.is_licensed(),
741        trueno_native: header.flags.is_trueno_native(),
742        quantized: header.flags.is_quantized(),
743        has_model_card: header.flags.has_model_card(),
744    })
745}
746
747/// Inspect a model file without loading the payload
748///
749/// # Arguments
750/// * `path` - Input file path
751///
752/// # Errors
753/// Returns error on I/O failure or format error
754pub fn inspect(path: impl AsRef<Path>) -> Result<ModelInfo> {
755    let path = path.as_ref();
756
757    // Read header + metadata only
758    let file = File::open(path)?;
759    let mut reader = BufReader::new(file);
760
761    // Read header
762    let mut header_bytes = [0u8; HEADER_SIZE];
763    reader.read_exact(&mut header_bytes)?;
764    let header = Header::from_bytes(&header_bytes)?;
765
766    // Read metadata (MessagePack per spec 2)
767    let mut metadata_bytes = vec![0u8; header.metadata_size as usize];
768    reader.read_exact(&mut metadata_bytes)?;
769    let metadata: Metadata = rmp_serde::from_slice(&metadata_bytes)
770        .map_err(|e| AprenderError::Serialization(format!("Failed to parse metadata: {e}")))?;
771
772    Ok(ModelInfo {
773        model_type: header.model_type,
774        format_version: header.version,
775        metadata,
776        payload_size: header.payload_size as usize,
777        uncompressed_size: header.uncompressed_size as usize,
778        encrypted: header.flags.is_encrypted(),
779        signed: header.flags.is_signed(),
780        streaming: header.flags.is_streaming(),
781        licensed: header.flags.is_licensed(),
782        trueno_native: header.flags.is_trueno_native(),
783        quantized: header.flags.is_quantized(),
784        has_model_card: header.flags.has_model_card(),
785    })
786}
787
788#[cfg(test)]
789mod tests {
790    use super::*;
791    use serde::{Deserialize, Serialize};
792
793    // ============================================================================
794    // CRC32 Tests
795    // ============================================================================
796
797    #[test]
798    fn test_crc32_empty() {
799        // CRC32 of empty data (IEEE polynomial)
800        assert_eq!(crc32(&[]), 0x0000_0000);
801    }
802
803    #[test]
804    fn test_crc32_known_values() {
805        // "123456789" should give CRC32 = 0xCBF43926
806        let data = b"123456789";
807        assert_eq!(crc32(data), 0xCBF4_3926);
808    }
809
810    #[test]
811    fn test_crc32_single_byte() {
812        // Single byte values
813        assert_eq!(crc32(&[0x00]), 0xD202_EF8D);
814        assert_eq!(crc32(&[0xFF]), 0xFF00_0000);
815    }
816
817    #[test]
818    fn test_crc32_multiple_bytes() {
819        let data = b"Hello, World!";
820        let crc = crc32(data);
821        // Verify it's deterministic
822        assert_eq!(crc, crc32(data));
823        // Verify different data gives different CRC
824        assert_ne!(crc, crc32(b"Hello, World"));
825    }
826
827    // ============================================================================
828    // Compression Tests
829    // ============================================================================
830
831    #[test]
832    fn test_compress_payload_none() {
833        let data = b"test data for compression";
834        let (compressed, compression) =
835            compress_payload(data, Compression::None).expect("compress");
836        assert_eq!(compression, Compression::None);
837        assert_eq!(compressed, data);
838    }
839
840    #[test]
841    fn test_decompress_payload_none() {
842        let data = b"test data for decompression";
843        let decompressed = decompress_payload(data, Compression::None).expect("decompress");
844        assert_eq!(decompressed, data);
845    }
846
#[cfg(feature = "format-compression")]
#[test]
fn test_compress_decompress_zstd_default() {
    let original = b"test data that should compress well with zstd compression";

    let (packed, algo) =
        compress_payload(original, Compression::ZstdDefault).expect("compress");
    assert_eq!(algo, Compression::ZstdDefault);

    // Decompress with whatever algorithm the compressor reported back.
    let unpacked = decompress_payload(&packed, algo).expect("decompress");
    assert_eq!(unpacked, original);
}
857
#[cfg(feature = "format-compression")]
#[test]
fn test_compress_decompress_lz4() {
    let original = b"test data for lz4 compression";

    let (packed, algo) = compress_payload(original, Compression::Lz4).expect("compress");
    assert_eq!(algo, Compression::Lz4);

    // Round-trip through the reported algorithm must restore the input exactly.
    let unpacked = decompress_payload(&packed, algo).expect("decompress");
    assert_eq!(unpacked, original);
}
867
868    // ============================================================================
869    // Save/Load Round-Trip Tests
870    // ============================================================================
871
872    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
873    struct TestModel {
874        name: String,
875        values: Vec<f32>,
876    }
877
#[test]
fn test_save_load_roundtrip() {
    let tmp = tempfile::tempdir().expect("create temp dir");
    let file = tmp.path().join("test.apr");

    let original = TestModel {
        name: "test_model".to_string(),
        values: vec![1.0, 2.0, 3.0, 4.0],
    };

    // Write with default options, then read back under the same model type.
    save(&original, ModelType::LinearRegression, &file, SaveOptions::default())
        .expect("save");
    let restored: TestModel = load(&file, ModelType::LinearRegression).expect("load");

    assert_eq!(original, restored);
}
894
#[test]
fn test_save_with_metadata() {
    let model = TestModel {
        name: "metadata_test".to_string(),
        values: vec![1.0],
    };

    let dir = tempfile::tempdir().expect("create temp dir");
    let path = dir.path().join("test_metadata.apr");

    // Build the metadata with struct-update syntax rather than mutating a
    // `Default` value afterwards (clippy::field_reassign_with_default).
    let options = SaveOptions {
        metadata: Metadata {
            description: Some("A test model".to_string()),
            ..Default::default()
        },
        compression: Compression::None,
        quality_score: Some(85),
    };
    save(&model, ModelType::LinearRegression, &path, options).expect("save");

    // The description supplied at save time must be reported by `inspect`.
    let info = inspect(&path).expect("inspect");
    assert_eq!(info.metadata.description.as_deref(), Some("A test model"));
}
918
#[test]
fn test_save_rejects_quality_score_zero() {
    let dir = tempfile::tempdir().expect("create temp dir");
    let target = dir.path().join("should_not_exist.apr");

    let empty_model = TestModel {
        name: "bad_model".to_string(),
        values: vec![],
    };
    // Everything default except a quality score of zero, which save must refuse.
    let zero_quality = SaveOptions {
        quality_score: Some(0),
        ..Default::default()
    };

    assert!(save(&empty_model, ModelType::LinearRegression, &target, zero_quality).is_err());
}
937
#[test]
fn test_load_wrong_model_type() {
    let dir = tempfile::tempdir().expect("create temp dir");
    let path = dir.path().join("type_test.apr");

    let model = TestModel {
        name: "type_test".to_string(),
        values: vec![1.0],
    };
    save(&model, ModelType::LinearRegression, &path, SaveOptions::default()).expect("save");

    // Asking for a different model type than the one stored must be rejected.
    let mismatched: Result<TestModel> = load(&path, ModelType::KMeans);
    assert!(mismatched.is_err());
}
959
#[test]
fn test_load_nonexistent_file() {
    // A missing file must surface as an `Err`, not a panic.
    let missing = "/nonexistent/path/model.apr";
    let outcome: Result<TestModel> = load(missing, ModelType::LinearRegression);
    assert!(outcome.is_err());
}
966
#[test]
fn test_load_file_too_small() {
    let dir = tempfile::tempdir().expect("create temp dir");
    let path = dir.path().join("tiny.apr");

    // A 10-byte file is too small to be a valid APR file. `fs::write` takes
    // `impl AsRef<[u8]>`, so the array is passed by value (no needless borrow).
    std::fs::write(&path, [0u8; 10]).expect("write tiny file");

    let result: Result<TestModel> = load(&path, ModelType::LinearRegression);
    assert!(result.is_err(), "loading a truncated file must fail");
}
977
978    // ============================================================================
979    // Inspect Tests
980    // ============================================================================
981
#[test]
fn test_inspect_model() {
    let model = TestModel {
        name: "inspect_test".to_string(),
        values: vec![1.0, 2.0, 3.0],
    };

    let dir = tempfile::tempdir().expect("create temp dir");
    let path = dir.path().join("inspect_test.apr");

    // Struct-update syntax instead of mutating a `Default` value
    // (clippy::field_reassign_with_default).
    let options = SaveOptions {
        metadata: Metadata {
            model_name: Some("Test Model".to_string()),
            ..Default::default()
        },
        compression: Compression::None,
        quality_score: Some(90),
    };
    save(&model, ModelType::LinearRegression, &path, options).expect("save");

    // `inspect` must report both the saved model type and the stored model name.
    let info = inspect(&path).expect("inspect");
    assert_eq!(info.model_type, ModelType::LinearRegression);
    assert_eq!(info.metadata.model_name.as_deref(), Some("Test Model"));
}
1006
#[test]
fn test_inspect_with_license_flag() {
    use super::super::{LicenseInfo, LicenseTier};

    let model = TestModel {
        name: "licensed".to_string(),
        values: vec![1.0],
    };

    let dir = tempfile::tempdir().expect("create temp dir");
    let path = dir.path().join("licensed.apr");

    let license = LicenseInfo {
        uuid: "test-uuid".to_string(),
        hash: "test-hash".to_string(),
        expiry: None,
        seats: None,
        licensee: Some("Test User".to_string()),
        tier: LicenseTier::Enterprise,
    };
    // Struct-update syntax instead of mutating a `Default` value
    // (clippy::field_reassign_with_default).
    let options = SaveOptions {
        metadata: Metadata {
            license: Some(license),
            ..Default::default()
        },
        compression: Compression::None,
        quality_score: None,
    };
    save(&model, ModelType::LinearRegression, &path, options).expect("save");

    // Attaching license metadata must make `inspect` flag the file as licensed.
    let info = inspect(&path).expect("inspect");
    assert!(info.licensed);
}
1039
#[test]
fn test_inspect_bytes_valid() {
    let dir = tempfile::tempdir().expect("create temp dir");
    let path = dir.path().join("bytes_test.apr");

    let model = TestModel {
        name: "bytes_test".to_string(),
        values: vec![1.0],
    };
    save(&model, ModelType::LinearRegression, &path, SaveOptions::default()).expect("save");

    // Inspecting the file's raw bytes in memory must agree with what was saved.
    let bytes = std::fs::read(&path).expect("read file");
    let summary = inspect_bytes(&bytes).expect("inspect bytes");
    assert_eq!(summary.model_type, ModelType::LinearRegression);
}
1062
#[test]
fn test_inspect_bytes_too_small() {
    // A 10-byte buffer is too small to be a valid APR file, so inspection must fail.
    let truncated = vec![0u8; 10];
    assert!(inspect_bytes(&truncated).is_err());
}
1069}