1use super::{Compression, Header, Metadata, ModelInfo, ModelType, SaveOptions, HEADER_SIZE};
4use crate::error::{AprenderError, Result};
5use serde::{de::DeserializeOwned, Serialize};
6use std::fs::File;
7#[cfg(feature = "format-compression")]
8use std::io::Cursor;
9use std::io::{BufReader, BufWriter, Read, Write};
10use std::path::Path;
11
12#[cfg(feature = "format-encryption")]
13use super::{KEY_SIZE, NONCE_SIZE, SALT_SIZE};
14
15#[allow(clippy::unnecessary_wraps)] pub(crate) fn compress_payload(
18 data: &[u8],
19 compression: Compression,
20) -> Result<(Vec<u8>, Compression)> {
21 match compression {
22 Compression::None => Ok((data.to_vec(), Compression::None)),
23 #[cfg(feature = "format-compression")]
24 Compression::ZstdDefault => {
25 let compressed = zstd::encode_all(Cursor::new(data), 3).map_err(|e| {
27 AprenderError::Serialization(format!("Zstd compression failed: {e}"))
28 })?;
29 Ok((compressed, Compression::ZstdDefault))
30 }
31 #[cfg(feature = "format-compression")]
32 Compression::ZstdMax => {
33 let compressed = zstd::encode_all(Cursor::new(data), 19).map_err(|e| {
35 AprenderError::Serialization(format!("Zstd compression failed: {e}"))
36 })?;
37 Ok((compressed, Compression::ZstdMax))
38 }
39 #[cfg(not(feature = "format-compression"))]
40 Compression::ZstdDefault | Compression::ZstdMax => {
41 Ok((data.to_vec(), Compression::None))
43 }
44 #[cfg(feature = "format-compression")]
45 Compression::Lz4 => {
46 let compressed = lz4_flex::compress_prepend_size(data);
48 Ok((compressed, Compression::Lz4))
49 }
50 #[cfg(not(feature = "format-compression"))]
51 Compression::Lz4 => {
52 Ok((data.to_vec(), Compression::None))
54 }
55 }
56}
57
58pub(crate) fn decompress_payload(data: &[u8], compression: Compression) -> Result<Vec<u8>> {
60 match compression {
61 Compression::None => Ok(data.to_vec()),
62 #[cfg(feature = "format-compression")]
63 Compression::ZstdDefault | Compression::ZstdMax => zstd::decode_all(Cursor::new(data))
64 .map_err(|e| AprenderError::Serialization(format!("Zstd decompression failed: {e}"))),
65 #[cfg(not(feature = "format-compression"))]
66 Compression::ZstdDefault | Compression::ZstdMax => Err(AprenderError::FormatError {
67 message: "Zstd compression not supported (enable format-compression feature)"
68 .to_string(),
69 }),
70 #[cfg(feature = "format-compression")]
71 Compression::Lz4 => lz4_flex::decompress_size_prepended(data)
72 .map_err(|e| AprenderError::Serialization(format!("LZ4 decompression failed: {e}"))),
73 #[cfg(not(feature = "format-compression"))]
74 Compression::Lz4 => Err(AprenderError::FormatError {
75 message: "LZ4 compression not supported (enable format-compression feature)"
76 .to_string(),
77 }),
78 }
79}
80
/// CRC-32 (reflected polynomial `0xEDB8_8320`, as used by zip/PNG) of `data`.
///
/// Table-driven implementation; the 256-entry lookup table is computed once
/// at compile time in a `const` block, then the message is folded through it
/// byte by byte.
pub(crate) fn crc32(data: &[u8]) -> u32 {
    const TABLE: [u32; 256] = {
        let mut table = [0u32; 256];
        let mut i = 0;
        while i < 256 {
            let mut entry = i as u32;
            let mut bit = 0;
            while bit < 8 {
                entry = if entry & 1 == 0 {
                    entry >> 1
                } else {
                    (entry >> 1) ^ 0xEDB8_8320
                };
                bit += 1;
            }
            table[i] = entry;
            i += 1;
        }
        table
    };

    // Standard form: init with all-ones, fold each byte, final bitwise NOT.
    !data.iter().fold(0xFFFF_FFFF_u32, |crc, &byte| {
        (crc >> 8) ^ TABLE[((crc ^ u32::from(byte)) & 0xFF) as usize]
    })
}
111
/// Read the entire file at `path` into memory.
///
/// # Errors
/// Propagates the underlying I/O error (converted via `From`) if the file
/// cannot be opened or read.
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
pub(crate) fn read_file_content(path: &Path) -> Result<Vec<u8>> {
    let handle = File::open(path)?;
    let mut buffered = BufReader::new(handle);
    let mut bytes = Vec::new();
    buffered.read_to_end(&mut bytes)?;
    Ok(bytes)
}
125
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
/// Validate the trailing CRC-32 that every `.apr` file carries in its last
/// four bytes (little-endian) against a checksum recomputed over the rest.
///
/// # Errors
/// `FormatError` when `content` is shorter than the checksum itself;
/// `ChecksumMismatch` when the stored and computed values disagree.
pub(crate) fn verify_file_checksum(content: &[u8]) -> Result<()> {
    if content.len() < 4 {
        return Err(AprenderError::FormatError {
            message: "File too small for checksum".to_string(),
        });
    }
    let (body, trailer) = content.split_at(content.len() - 4);
    let expected = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
    let actual = crc32(body);
    if expected == actual {
        Ok(())
    } else {
        Err(AprenderError::ChecksumMismatch { expected, actual })
    }
}
149
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
/// Decode the fixed-size header at the front of `content` and confirm it
/// describes a model of `expected_type`.
///
/// # Errors
/// Propagates header parse failures; returns `FormatError` on a model-type
/// mismatch.
pub(crate) fn parse_and_validate_header(
    content: &[u8],
    expected_type: ModelType,
) -> Result<Header> {
    let header = Header::from_bytes(&content[..HEADER_SIZE])?;
    if header.model_type == expected_type {
        Ok(header)
    } else {
        Err(AprenderError::FormatError {
            message: format!(
                "Model type mismatch: file contains {:?}, expected {:?}",
                header.model_type, expected_type
            ),
        })
    }
}
167
#[cfg(feature = "format-signing")]
/// Ensure the header advertises a signature (SIGNED flag set).
pub(crate) fn verify_signed_flag(header: &Header) -> Result<()> {
    if header.flags.is_signed() {
        Ok(())
    } else {
        Err(AprenderError::FormatError {
            message: "File is not signed (SIGNED flag not set)".to_string(),
        })
    }
}
178
#[cfg(feature = "format-encryption")]
/// Ensure the header advertises encryption (ENCRYPTED flag set).
pub(crate) fn verify_encrypted_flag(header: &Header) -> Result<()> {
    if header.flags.is_encrypted() {
        Ok(())
    } else {
        Err(AprenderError::FormatError {
            message: "File is not encrypted (ENCRYPTED flag not set)".to_string(),
        })
    }
}
189
#[cfg(any(feature = "format-signing", feature = "format-encryption"))]
/// Check that the payload (ending at byte offset `payload_end`) fits inside
/// the file while leaving room for the 4-byte trailing checksum.
///
/// # Errors
/// Returns `FormatError` when the payload would overlap the checksum or run
/// past the end of the file.
pub(crate) fn verify_payload_boundary(payload_end: usize, content_len: usize) -> Result<()> {
    // `saturating_sub` avoids the unsigned underflow (debug-build panic) that
    // a bare `content_len - 4` would hit when `content_len < 4`.
    if payload_end > content_len.saturating_sub(4) {
        return Err(AprenderError::FormatError {
            message: "Payload extends beyond file boundary".to_string(),
        });
    }
    Ok(())
}
200
#[cfg(feature = "format-signing")]
/// Decompress a payload and decode it as a bincode-serialized model.
///
/// # Errors
/// Propagates decompression errors; returns `Serialization` when bincode
/// decoding fails.
pub(crate) fn decompress_and_deserialize<M: DeserializeOwned>(
    payload_compressed: &[u8],
    compression: Compression,
) -> Result<M> {
    let raw = decompress_payload(payload_compressed, compression)?;
    let model = bincode::deserialize(&raw)
        .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))?;
    Ok(model)
}
211
212#[allow(clippy::needless_pass_by_value)] pub fn save<M: Serialize>(
224 model: &M,
225 model_type: ModelType,
226 path: impl AsRef<Path>,
227 options: SaveOptions,
228) -> Result<()> {
229 let path = path.as_ref();
230
231 if options.quality_score == Some(0) {
234 return Err(AprenderError::ValidationError {
235 message: "Jidoka: Refusing to save model with quality_score=0. \
236 Fix validation errors or use score=None to skip validation."
237 .to_string(),
238 });
239 }
240
241 let payload_uncompressed = bincode::serialize(model)
243 .map_err(|e| AprenderError::Serialization(format!("Failed to serialize model: {e}")))?;
244
245 let (payload_compressed, compression) =
247 compress_payload(&payload_uncompressed, options.compression)?;
248
249 let metadata_bytes = rmp_serde::to_vec_named(&options.metadata)
252 .map_err(|e| AprenderError::Serialization(format!("Failed to serialize metadata: {e}")))?;
253
254 let mut header = Header::new(model_type);
256 header.compression = compression;
257 header.metadata_size = metadata_bytes.len() as u32;
258 header.payload_size = payload_compressed.len() as u32;
259 header.uncompressed_size = payload_uncompressed.len() as u32;
260
261 if options.metadata.license.is_some() {
263 header.flags = header.flags.with_licensed();
264 }
265
266 header.quality_score = options.quality_score.unwrap_or(0);
268
269 let mut content = Vec::new();
271 content.extend_from_slice(&header.to_bytes());
272 content.extend_from_slice(&metadata_bytes);
273 content.extend_from_slice(&payload_compressed);
274
275 let checksum = crc32(&content);
277 content.extend_from_slice(&checksum.to_le_bytes());
278
279 let file = File::create(path)?;
281 let mut writer = BufWriter::new(file);
282 writer.write_all(&content)?;
283 writer.flush()?;
284
285 Ok(())
286}
287
288pub fn load<M: DeserializeOwned>(path: impl AsRef<Path>, expected_type: ModelType) -> Result<M> {
297 let path = path.as_ref();
298
299 let file = File::open(path)?;
301 let mut reader = BufReader::new(file);
302 let mut content = Vec::new();
303 reader.read_to_end(&mut content)?;
304
305 if content.len() < HEADER_SIZE + 4 {
307 return Err(AprenderError::FormatError {
308 message: format!("File too small: {} bytes", content.len()),
309 });
310 }
311
312 let stored_checksum = u32::from_le_bytes([
314 content[content.len() - 4],
315 content[content.len() - 3],
316 content[content.len() - 2],
317 content[content.len() - 1],
318 ]);
319 let computed_checksum = crc32(&content[..content.len() - 4]);
320 if stored_checksum != computed_checksum {
321 return Err(AprenderError::ChecksumMismatch {
322 expected: stored_checksum,
323 actual: computed_checksum,
324 });
325 }
326
327 let header = Header::from_bytes(&content[..HEADER_SIZE])?;
329
330 if header.model_type != expected_type {
332 return Err(AprenderError::FormatError {
333 message: format!(
334 "Model type mismatch: file contains {:?}, expected {:?}",
335 header.model_type, expected_type
336 ),
337 });
338 }
339
340 let metadata_end = HEADER_SIZE + header.metadata_size as usize;
342 let payload_end = metadata_end + header.payload_size as usize;
343
344 if payload_end > content.len() - 4 {
345 return Err(AprenderError::FormatError {
346 message: "Payload extends beyond file boundary".to_string(),
347 });
348 }
349
350 let payload_compressed = &content[metadata_end..payload_end];
351
352 let payload_uncompressed = decompress_payload(payload_compressed, header.compression)?;
354
355 bincode::deserialize(&payload_uncompressed)
357 .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))
358}
359
360pub fn load_from_bytes<M: DeserializeOwned>(data: &[u8], expected_type: ModelType) -> Result<M> {
387 if data.len() < HEADER_SIZE + 4 {
389 return Err(AprenderError::FormatError {
390 message: format!("Data too small: {} bytes", data.len()),
391 });
392 }
393
394 let stored_checksum = u32::from_le_bytes([
396 data[data.len() - 4],
397 data[data.len() - 3],
398 data[data.len() - 2],
399 data[data.len() - 1],
400 ]);
401 let computed_checksum = crc32(&data[..data.len() - 4]);
402 if stored_checksum != computed_checksum {
403 return Err(AprenderError::ChecksumMismatch {
404 expected: stored_checksum,
405 actual: computed_checksum,
406 });
407 }
408
409 let header = Header::from_bytes(&data[..HEADER_SIZE])?;
411
412 if header.model_type != expected_type {
414 return Err(AprenderError::FormatError {
415 message: format!(
416 "Model type mismatch: data contains {:?}, expected {:?}",
417 header.model_type, expected_type
418 ),
419 });
420 }
421
422 let metadata_end = HEADER_SIZE + header.metadata_size as usize;
424 let payload_end = metadata_end + header.payload_size as usize;
425
426 if payload_end > data.len() - 4 {
427 return Err(AprenderError::FormatError {
428 message: "Payload extends beyond data boundary".to_string(),
429 });
430 }
431
432 let payload_compressed = &data[metadata_end..payload_end];
433
434 let payload_uncompressed = decompress_payload(payload_compressed, header.compression)?;
436
437 bincode::deserialize(&payload_uncompressed)
439 .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))
440}
441
/// File-size cutoff (1 MiB) above which [`load_auto`] prefers memory-mapped
/// loading over reading the whole file into a buffer.
pub const MMAP_THRESHOLD: u64 = 1024 * 1024;
447
448pub fn load_mmap<M: DeserializeOwned>(
482 path: impl AsRef<Path>,
483 expected_type: ModelType,
484) -> Result<M> {
485 use crate::bundle::MappedFile;
486
487 let mapped = MappedFile::open(path.as_ref())?;
488
489 load_from_bytes(mapped.as_slice(), expected_type)
490}
491
492pub fn load_auto<M: DeserializeOwned>(
515 path: impl AsRef<Path>,
516 expected_type: ModelType,
517) -> Result<M> {
518 let metadata = std::fs::metadata(path.as_ref())?;
519
520 if metadata.len() > MMAP_THRESHOLD {
521 load_mmap(path, expected_type)
522 } else {
523 load(path, expected_type)
524 }
525}
526
#[cfg(feature = "format-encryption")]
/// Minimum-size check for an encrypted image: header, salt, nonce, and the
/// 4-byte trailing checksum must all fit.
pub(crate) fn verify_encrypted_data_size(data: &[u8]) -> Result<()> {
    let min_len = HEADER_SIZE + SALT_SIZE + NONCE_SIZE + 4;
    if data.len() >= min_len {
        Ok(())
    } else {
        Err(AprenderError::FormatError {
            message: format!("Data too small for encrypted model: {} bytes", data.len()),
        })
    }
}
541
#[cfg(feature = "format-encryption")]
/// Verify the trailing little-endian CRC-32 of an encrypted image.
///
/// Callers must have already established `data.len() >= 4` (see
/// [`verify_encrypted_data_size`]).
pub(crate) fn verify_encrypted_checksum(data: &[u8]) -> Result<()> {
    let (body, trailer) = data.split_at(data.len() - 4);
    let expected = u32::from_le_bytes([trailer[0], trailer[1], trailer[2], trailer[3]]);
    let actual = crc32(body);
    if expected == actual {
        Ok(())
    } else {
        Err(AprenderError::ChecksumMismatch { expected, actual })
    }
}
560
#[cfg(feature = "format-encryption")]
/// Validate an encrypted image's header: the ENCRYPTED flag must be set and
/// the model type must match what the caller expects.
pub(crate) fn verify_encrypted_header(header: &Header, expected_type: ModelType) -> Result<()> {
    if !header.flags.is_encrypted() {
        return Err(AprenderError::FormatError {
            message: "Data is not encrypted (ENCRYPTED flag not set)".to_string(),
        });
    }
    if header.model_type == expected_type {
        return Ok(());
    }
    Err(AprenderError::FormatError {
        message: format!(
            "Model type mismatch: data contains {:?}, expected {:?}",
            header.model_type, expected_type
        ),
    })
}
579
#[cfg(feature = "format-encryption")]
/// Split an encrypted image into its salt, nonce, and ciphertext sections.
///
/// Encrypted payload layout (after the metadata block):
/// `salt (SALT_SIZE) | nonce (NONCE_SIZE) | ciphertext`.
///
/// # Errors
/// Returns `FormatError` when the declared payload is too small to hold a
/// salt and nonce, or extends past the data (minus trailing checksum).
pub(crate) fn extract_encrypted_components<'a>(
    data: &'a [u8],
    header: &Header,
) -> Result<([u8; SALT_SIZE], [u8; NONCE_SIZE], &'a [u8])> {
    let metadata_end = HEADER_SIZE + header.metadata_size as usize;
    let payload_len = header.payload_size as usize;
    let payload_end = metadata_end + payload_len;

    // Without this check, a malformed header with a tiny payload_size would
    // make the slice ranges below panic (start > end / out of bounds).
    if payload_len < SALT_SIZE + NONCE_SIZE {
        return Err(AprenderError::FormatError {
            message: "Encrypted payload too small for salt and nonce".to_string(),
        });
    }
    // `saturating_sub` keeps the comparison safe even if the caller skipped
    // the minimum-size check on `data`.
    if payload_end > data.len().saturating_sub(4) {
        return Err(AprenderError::FormatError {
            message: "Encrypted payload extends beyond data boundary".to_string(),
        });
    }

    let salt_end = metadata_end + SALT_SIZE;
    let nonce_end = salt_end + NONCE_SIZE;

    let salt: [u8; SALT_SIZE] =
        data[metadata_end..salt_end]
            .try_into()
            .map_err(|_| AprenderError::FormatError {
                message: "Invalid salt size".to_string(),
            })?;
    let nonce: [u8; NONCE_SIZE] =
        data[salt_end..nonce_end]
            .try_into()
            .map_err(|_| AprenderError::FormatError {
                message: "Invalid nonce size".to_string(),
            })?;
    let ciphertext = &data[nonce_end..payload_end];

    Ok((salt, nonce, ciphertext))
}
613
#[cfg(feature = "format-encryption")]
/// Derive a cipher key from `password` via Argon2 (with `salt`) and decrypt
/// `ciphertext` with AES-256-GCM under `nonce_bytes`.
///
/// # Errors
/// - `Other` when key derivation or cipher construction fails.
/// - `DecryptionFailed` when AEAD decryption fails — the message does not
///   distinguish a wrong password from corrupted/tampered data.
pub(crate) fn decrypt_encrypted_payload(
    password: &str,
    salt: &[u8; SALT_SIZE],
    nonce_bytes: &[u8; NONCE_SIZE],
    ciphertext: &[u8],
) -> Result<Vec<u8>> {
    use aes_gcm::{
        aead::{Aead, KeyInit},
        Aes256Gcm, Nonce,
    };
    use argon2::Argon2;

    // Stretch the password into a KEY_SIZE-byte key with Argon2's default
    // parameters; the salt comes from the file, so the key is file-specific.
    let mut key = [0u8; KEY_SIZE];
    Argon2::default()
        .hash_password_into(password.as_bytes(), salt, &mut key)
        .map_err(|e| AprenderError::Other(format!("Key derivation failed: {e}")))?;

    let cipher = Aes256Gcm::new_from_slice(&key)
        .map_err(|e| AprenderError::Other(format!("Failed to create cipher: {e}")))?;
    let nonce = Nonce::from_slice(nonce_bytes);

    // AEAD decrypt; the error detail from the cipher is deliberately dropped
    // in favor of a single generic message.
    cipher
        .decrypt(nonce, ciphertext)
        .map_err(|_| AprenderError::DecryptionFailed {
            message: "Decryption failed (wrong password or corrupted data)".to_string(),
        })
}
643
#[cfg(feature = "format-encryption")]
/// Decrypt and deserialize a password-protected model from an in-memory
/// `.apr` image.
///
/// Pipeline: size check → CRC-32 check → header validation →
/// salt/nonce/ciphertext extraction → decrypt → decompress → bincode decode.
pub fn load_from_bytes_encrypted<M: DeserializeOwned>(
    data: &[u8],
    expected_type: ModelType,
    password: &str,
) -> Result<M> {
    verify_encrypted_data_size(data)?;
    verify_encrypted_checksum(data)?;

    let header = Header::from_bytes(&data[..HEADER_SIZE])?;
    verify_encrypted_header(&header, expected_type)?;

    let (salt, nonce, ciphertext) = extract_encrypted_components(data, &header)?;
    let compressed = decrypt_encrypted_payload(password, &salt, &nonce, ciphertext)?;
    let plaintext = decompress_payload(&compressed, header.compression)?;

    bincode::deserialize(&plaintext)
        .map_err(|e| AprenderError::Serialization(format!("Failed to deserialize model: {e}")))
}
697
698pub fn inspect_bytes(data: &[u8]) -> Result<ModelInfo> {
709 if data.len() < HEADER_SIZE {
711 return Err(AprenderError::FormatError {
712 message: format!("Data too small: {} bytes", data.len()),
713 });
714 }
715
716 let header = Header::from_bytes(&data[..HEADER_SIZE])?;
718
719 let metadata_end = HEADER_SIZE + header.metadata_size as usize;
721 if metadata_end > data.len() {
722 return Err(AprenderError::FormatError {
723 message: "Metadata extends beyond data boundary".to_string(),
724 });
725 }
726
727 let metadata_bytes = &data[HEADER_SIZE..metadata_end];
728 let metadata: Metadata = rmp_serde::from_slice(metadata_bytes)
729 .map_err(|e| AprenderError::Serialization(format!("Failed to parse metadata: {e}")))?;
730
731 Ok(ModelInfo {
732 model_type: header.model_type,
733 format_version: header.version,
734 metadata,
735 payload_size: header.payload_size as usize,
736 uncompressed_size: header.uncompressed_size as usize,
737 encrypted: header.flags.is_encrypted(),
738 signed: header.flags.is_signed(),
739 streaming: header.flags.is_streaming(),
740 licensed: header.flags.is_licensed(),
741 trueno_native: header.flags.is_trueno_native(),
742 quantized: header.flags.is_quantized(),
743 has_model_card: header.flags.has_model_card(),
744 })
745}
746
747pub fn inspect(path: impl AsRef<Path>) -> Result<ModelInfo> {
755 let path = path.as_ref();
756
757 let file = File::open(path)?;
759 let mut reader = BufReader::new(file);
760
761 let mut header_bytes = [0u8; HEADER_SIZE];
763 reader.read_exact(&mut header_bytes)?;
764 let header = Header::from_bytes(&header_bytes)?;
765
766 let mut metadata_bytes = vec![0u8; header.metadata_size as usize];
768 reader.read_exact(&mut metadata_bytes)?;
769 let metadata: Metadata = rmp_serde::from_slice(&metadata_bytes)
770 .map_err(|e| AprenderError::Serialization(format!("Failed to parse metadata: {e}")))?;
771
772 Ok(ModelInfo {
773 model_type: header.model_type,
774 format_version: header.version,
775 metadata,
776 payload_size: header.payload_size as usize,
777 uncompressed_size: header.uncompressed_size as usize,
778 encrypted: header.flags.is_encrypted(),
779 signed: header.flags.is_signed(),
780 streaming: header.flags.is_streaming(),
781 licensed: header.flags.is_licensed(),
782 trueno_native: header.flags.is_trueno_native(),
783 quantized: header.flags.is_quantized(),
784 has_model_card: header.flags.has_model_card(),
785 })
786}
787
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};

    // ---- CRC-32 ----

    #[test]
    fn test_crc32_empty() {
        // CRC-32 of the empty message is zero.
        assert_eq!(crc32(&[]), 0x0000_0000);
    }

    #[test]
    fn test_crc32_known_values() {
        // "123456789" is the standard CRC-32 check vector.
        let data = b"123456789";
        assert_eq!(crc32(data), 0xCBF4_3926);
    }

    #[test]
    fn test_crc32_single_byte() {
        assert_eq!(crc32(&[0x00]), 0xD202_EF8D);
        assert_eq!(crc32(&[0xFF]), 0xFF00_0000);
    }

    #[test]
    fn test_crc32_multiple_bytes() {
        let data = b"Hello, World!";
        let crc = crc32(data);
        // Deterministic on identical input...
        assert_eq!(crc, crc32(data));
        // ...and sensitive to truncation by one byte.
        assert_ne!(crc, crc32(b"Hello, World"));
    }

    // ---- compression round-trips ----

    #[test]
    fn test_compress_payload_none() {
        let data = b"test data for compression";
        let (compressed, compression) =
            compress_payload(data, Compression::None).expect("compress");
        assert_eq!(compression, Compression::None);
        // `Compression::None` is the identity transform.
        assert_eq!(compressed, data);
    }

    #[test]
    fn test_decompress_payload_none() {
        let data = b"test data for decompression";
        let decompressed = decompress_payload(data, Compression::None).expect("decompress");
        assert_eq!(decompressed, data);
    }

    #[cfg(feature = "format-compression")]
    #[test]
    fn test_compress_decompress_zstd_default() {
        let data = b"test data that should compress well with zstd compression";
        let (compressed, compression) =
            compress_payload(data, Compression::ZstdDefault).expect("compress");
        assert_eq!(compression, Compression::ZstdDefault);
        let decompressed = decompress_payload(&compressed, compression).expect("decompress");
        assert_eq!(decompressed, data);
    }

    #[cfg(feature = "format-compression")]
    #[test]
    fn test_compress_decompress_lz4() {
        let data = b"test data for lz4 compression";
        let (compressed, compression) = compress_payload(data, Compression::Lz4).expect("compress");
        assert_eq!(compression, Compression::Lz4);
        let decompressed = decompress_payload(&compressed, compression).expect("decompress");
        assert_eq!(decompressed, data);
    }

    // Minimal serde-able model used by the save/load/inspect tests below.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TestModel {
        name: String,
        values: Vec<f32>,
    }

    // ---- save / load ----

    #[test]
    fn test_save_load_roundtrip() {
        let model = TestModel {
            name: "test_model".to_string(),
            values: vec![1.0, 2.0, 3.0, 4.0],
        };

        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("test.apr");

        let options = SaveOptions::default();
        save(&model, ModelType::LinearRegression, &path, options).expect("save");

        let loaded: TestModel = load(&path, ModelType::LinearRegression).expect("load");
        assert_eq!(model, loaded);
    }

    #[test]
    fn test_save_with_metadata() {
        let model = TestModel {
            name: "metadata_test".to_string(),
            values: vec![1.0],
        };

        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("test_metadata.apr");

        let mut metadata = Metadata::default();
        metadata.description = Some("A test model".to_string());

        let options = SaveOptions {
            metadata,
            compression: Compression::None,
            quality_score: Some(85),
        };
        save(&model, ModelType::LinearRegression, &path, options).expect("save");

        // Metadata must survive the round trip and be visible to `inspect`.
        let info = inspect(&path).expect("inspect");
        assert_eq!(info.metadata.description, Some("A test model".to_string()));
    }

    #[test]
    fn test_save_rejects_quality_score_zero() {
        let model = TestModel {
            name: "bad_model".to_string(),
            values: vec![],
        };

        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("should_not_exist.apr");

        // quality_score == Some(0) triggers the Jidoka guard in `save`.
        let options = SaveOptions {
            quality_score: Some(0),
            ..Default::default()
        };

        let result = save(&model, ModelType::LinearRegression, &path, options);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_wrong_model_type() {
        let model = TestModel {
            name: "type_test".to_string(),
            values: vec![1.0],
        };

        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("type_test.apr");

        save(
            &model,
            ModelType::LinearRegression,
            &path,
            SaveOptions::default(),
        )
        .expect("save");

        // Requesting a different model type than the file contains must fail.
        let result: Result<TestModel> = load(&path, ModelType::KMeans);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_nonexistent_file() {
        let result: Result<TestModel> =
            load("/nonexistent/path/model.apr", ModelType::LinearRegression);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_file_too_small() {
        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("tiny.apr");

        // 10 bytes cannot hold header + checksum.
        std::fs::write(&path, &[0u8; 10]).expect("write tiny file");

        let result: Result<TestModel> = load(&path, ModelType::LinearRegression);
        assert!(result.is_err());
    }

    // ---- inspect ----

    #[test]
    fn test_inspect_model() {
        let model = TestModel {
            name: "inspect_test".to_string(),
            values: vec![1.0, 2.0, 3.0],
        };

        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("inspect_test.apr");

        let mut metadata = Metadata::default();
        metadata.model_name = Some("Test Model".to_string());

        let options = SaveOptions {
            metadata,
            compression: Compression::None,
            quality_score: Some(90),
        };
        save(&model, ModelType::LinearRegression, &path, options).expect("save");

        let info = inspect(&path).expect("inspect");
        assert_eq!(info.model_type, ModelType::LinearRegression);
        assert_eq!(info.metadata.model_name, Some("Test Model".to_string()));
    }

    #[test]
    fn test_inspect_with_license_flag() {
        use super::super::{LicenseInfo, LicenseTier};

        let model = TestModel {
            name: "licensed".to_string(),
            values: vec![1.0],
        };

        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("licensed.apr");

        // Saving with license metadata should set the LICENSED header flag.
        let mut metadata = Metadata::default();
        metadata.license = Some(LicenseInfo {
            uuid: "test-uuid".to_string(),
            hash: "test-hash".to_string(),
            expiry: None,
            seats: None,
            licensee: Some("Test User".to_string()),
            tier: LicenseTier::Enterprise,
        });

        let options = SaveOptions {
            metadata,
            compression: Compression::None,
            quality_score: None,
        };
        save(&model, ModelType::LinearRegression, &path, options).expect("save");

        let info = inspect(&path).expect("inspect");
        assert!(info.licensed);
    }

    #[test]
    fn test_inspect_bytes_valid() {
        let model = TestModel {
            name: "bytes_test".to_string(),
            values: vec![1.0],
        };

        let dir = tempfile::tempdir().expect("create temp dir");
        let path = dir.path().join("bytes_test.apr");

        save(
            &model,
            ModelType::LinearRegression,
            &path,
            SaveOptions::default(),
        )
        .expect("save");

        // `inspect_bytes` should agree with on-disk `inspect`.
        let data = std::fs::read(&path).expect("read file");
        let info = inspect_bytes(&data).expect("inspect bytes");
        assert_eq!(info.model_type, ModelType::LinearRegression);
    }

    #[test]
    fn test_inspect_bytes_too_small() {
        // Smaller than HEADER_SIZE must be rejected, not panic.
        let data = vec![0u8; 10];
        let result = inspect_bytes(&data);
        assert!(result.is_err());
    }
}