1use crate::error::{AprenderError, Result};
12use crate::format::v2::{AprV2Metadata, AprV2Writer};
13use crate::format::validation::{AprValidator, TensorStats, ValidationReport};
14use crate::format::Compression;
15use crate::serialization::safetensors::{extract_tensor, load_safetensors, save_safetensors};
16use std::collections::BTreeMap;
17use std::fs;
18use std::io::Write;
19use std::path::{Path, PathBuf};
20
21#[derive(Debug, Clone, PartialEq)]
29pub enum Source {
30 HuggingFace {
32 org: String,
33 repo: String,
34 file: Option<String>,
35 },
36 Local(PathBuf),
38 Url(String),
40}
41
42impl Source {
43 pub fn parse(source: &str) -> Result<Self> {
45 if source.starts_with("hf://") {
46 Self::parse_hf(source)
47 } else if source.starts_with("http://") || source.starts_with("https://") {
48 Ok(Self::Url(source.to_string()))
49 } else {
50 Ok(Self::Local(PathBuf::from(source)))
51 }
52 }
53
54 fn parse_hf(source: &str) -> Result<Self> {
55 let path = source.strip_prefix("hf://").unwrap_or(source);
56 let parts: Vec<&str> = path.split('/').collect();
57
58 if parts.len() < 2 {
59 return Err(AprenderError::FormatError {
60 message: format!("Invalid HuggingFace source: {source}. Expected hf://org/repo"),
61 });
62 }
63
64 let org = parts[0].to_string();
65 let repo = parts[1].to_string();
66 let file = if parts.len() > 2 {
67 Some(parts[2..].join("/"))
68 } else {
69 None
70 };
71
72 Ok(Self::HuggingFace { org, repo, file })
73 }
74
75 pub fn default_file(&self) -> &str {
77 match self {
78 Self::HuggingFace { file: Some(f), .. } => f,
79 Self::HuggingFace { file: None, .. } => "model.safetensors",
80 Self::Local(p) => p.to_str().unwrap_or("model.safetensors"),
81 Self::Url(u) => u.rsplit('/').next().unwrap_or("model.safetensors"),
82 }
83 }
84}
85
/// Model architecture hint used when renaming tensors from a source checkpoint.
///
/// `Auto` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Architecture {
    /// No explicit architecture; the generic `model.` prefix rule applies.
    #[default]
    Auto,
    /// Whisper models.
    Whisper,
    /// LLaMA-family models.
    Llama,
    /// BERT-family models.
    Bert,
}

impl Architecture {
    /// Maps a source tensor name to the APR naming scheme.
    ///
    /// Currently this strips a single leading prefix: `bert.` for
    /// [`Architecture::Bert`] and `model.` for every other variant. Names that
    /// do not start with that prefix are returned unchanged.
    pub fn map_name(&self, source_name: &str) -> String {
        let rename: fn(&str) -> String = match *self {
            Self::Auto => Self::auto_map_name,
            Self::Whisper => Self::whisper_map_name,
            Self::Llama => Self::llama_map_name,
            Self::Bert => Self::bert_map_name,
        };
        rename(source_name)
    }

    /// Generic rule: drop one leading `model.`.
    fn auto_map_name(name: &str) -> String {
        name.strip_prefix("model.")
            .map_or_else(|| name.to_owned(), str::to_owned)
    }

    /// Whisper rule: drop one leading `model.`.
    fn whisper_map_name(name: &str) -> String {
        name.strip_prefix("model.")
            .map_or_else(|| name.to_owned(), str::to_owned)
    }

    /// LLaMA rule: drop one leading `model.`.
    fn llama_map_name(name: &str) -> String {
        name.strip_prefix("model.")
            .map_or_else(|| name.to_owned(), str::to_owned)
    }

    /// BERT rule: drop one leading `bert.`.
    fn bert_map_name(name: &str) -> String {
        name.strip_prefix("bert.")
            .map_or_else(|| name.to_owned(), str::to_owned)
    }
}
139
140#[derive(Debug, Clone)]
146pub struct TensorExpectation {
147 pub mean_range: (f32, f32),
149 pub std_range: Option<(f32, f32)>,
151 pub description: &'static str,
153}
154
155impl TensorExpectation {
156 pub const LAYER_NORM_WEIGHT: Self = Self {
158 mean_range: (0.5, 3.0),
159 std_range: Some((0.0, 2.0)),
160 description: "LayerNorm weight (gamma)",
161 };
162
163 pub const LAYER_NORM_BIAS: Self = Self {
165 mean_range: (-0.5, 0.5),
166 std_range: Some((0.0, 1.0)),
167 description: "LayerNorm bias (beta)",
168 };
169
170 pub const LINEAR_WEIGHT: Self = Self {
172 mean_range: (-0.1, 0.1),
173 std_range: None,
174 description: "Linear/Attention weight",
175 };
176
177 pub const EMBEDDING: Self = Self {
179 mean_range: (-1.0, 1.0),
180 std_range: None,
181 description: "Embedding",
182 };
183
184 pub fn for_tensor(name: &str) -> Option<Self> {
186 if name.contains("layer_norm") || name.contains("ln_") {
187 if name.ends_with(".weight") || name.ends_with(".gamma") {
188 return Some(Self::LAYER_NORM_WEIGHT);
189 }
190 if name.ends_with(".bias") || name.ends_with(".beta") {
191 return Some(Self::LAYER_NORM_BIAS);
192 }
193 }
194
195 if name.contains("embed") {
196 return Some(Self::EMBEDDING);
197 }
198
199 if name.ends_with(".weight") {
200 return Some(Self::LINEAR_WEIGHT);
201 }
202
203 None
204 }
205
206 pub fn check(&self, stats: &TensorStats) -> Result<()> {
208 let (min_mean, max_mean) = self.mean_range;
209
210 if stats.mean < min_mean || stats.mean > max_mean {
211 return Err(AprenderError::FormatError {
212 message: format!(
213 "{}: mean={:.4} outside expected range [{:.1}, {:.1}]",
214 self.description, stats.mean, min_mean, max_mean
215 ),
216 });
217 }
218
219 if let Some((min_std, max_std)) = self.std_range {
220 if stats.std < min_std || stats.std > max_std {
221 return Err(AprenderError::FormatError {
222 message: format!(
223 "{}: std={:.4} outside expected range [{:.1}, {:.1}]",
224 self.description, stats.std, min_std, max_std
225 ),
226 });
227 }
228 }
229
230 Ok(())
231 }
232}
233
/// How strictly imported tensors are validated.
///
/// `Strict` is the default. The `Default` impl is derived via `#[default]`, as
/// `Architecture` does, instead of being hand-written.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ValidationConfig {
    /// Skip all tensor checks.
    None,
    /// Report NaN/Inf values only; per-tensor range expectations are not enforced.
    Basic,
    /// Also enforce per-tensor range expectations (unless the import is forced).
    #[default]
    Strict,
}

impl ValidationConfig {
    /// Returns [`ValidationConfig::Strict`].
    #[must_use]
    pub fn strict() -> Self {
        Self::Strict
    }
}
261
/// Numeric format requested for quantization.
///
/// In this file, `quantize_tensors` maps each variant to a round-trip helper
/// (`quantize_int8`, `quantize_int4`, `quantize_fp16`) whose output is still `f32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// 8-bit integer quantization (handled by `quantize_int8`).
    Int8,
    /// 4-bit integer quantization (handled by `quantize_int4`).
    Int4,
    /// 16-bit floating point (handled by `quantize_fp16`).
    Fp16,
}
276
277#[derive(Debug, Clone)]
279pub struct ImportOptions {
280 pub architecture: Architecture,
282 pub validation: ValidationConfig,
284 pub quantize: Option<QuantizationType>,
286 pub compress: Option<Compression>,
288 pub force: bool,
290 pub cache: bool,
292}
293
294impl Default for ImportOptions {
295 fn default() -> Self {
296 Self {
297 architecture: Architecture::Auto,
298 validation: ValidationConfig::Strict,
299 quantize: None,
300 compress: None,
301 force: false,
302 cache: true,
303 }
304 }
305}
306
/// Structured failures that can occur while importing a model.
///
/// The `Display` text of most variants ends with a `Fix:` hint for the user.
#[derive(Debug, Clone)]
pub enum ImportError {
    /// A download failed for a reason not covered by a more specific variant.
    DownloadFailed { source: String, reason: String },
    /// The input file extension is not supported.
    UnsupportedFormat { extension: String },
    /// A tensor failed validation.
    ValidationFailed { name: String, reason: String },
    /// A source tensor name could not be mapped.
    UnknownTensor { source_name: String },
    /// A required tensor was absent.
    MissingTensor { name: String },
    /// A resource was not found. `status` is the HTTP status code, or `0` for a
    /// local filesystem miss (as produced by `resolve_source`).
    NotFound { resource: String, status: u16 },
    /// The server rate-limited the request; `retry_after` is in seconds, if known.
    RateLimited { retry_after: Option<u64> },
    /// The resource requires authentication.
    AuthRequired { resource: String },
    /// The model is too large to load as a single file.
    ShardingRequired { model_size: u64, shard_count: usize },
}

impl std::fmt::Display for ImportError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DownloadFailed { source, reason } => {
                write!(f, "Download failed: {source} - {reason}")
            }
            Self::UnsupportedFormat { extension } => {
                write!(f, "Unsupported format: {extension}")
            }
            Self::ValidationFailed { name, reason } => {
                write!(f, "Tensor validation failed: {name} - {reason}")
            }
            Self::UnknownTensor { source_name } => {
                write!(f, "Unknown tensor: {source_name}")
            }
            Self::MissingTensor { name } => {
                write!(f, "Missing required tensor: {name}")
            }
            // Status 0 marks a local path. A HuggingFace hint would be misleading here.
            Self::NotFound {
                resource,
                status: 0,
            } => write!(
                f,
                "Resource not found: {resource}. \
                 Fix: check that the local path exists"
            ),
            Self::NotFound { resource, status } => write!(
                f,
                "Resource not found ({status}): {resource}. \
                 Fix: verify the model name exists on huggingface.co/models"
            ),
            Self::RateLimited {
                retry_after: Some(secs),
            } => write!(
                f,
                "Rate limited by server. Retry after {secs} seconds. \
                 Fix: wait and retry, or use --cache to avoid re-downloads"
            ),
            Self::RateLimited { retry_after: None } => write!(
                f,
                "Rate limited by server. \
                 Fix: wait a few minutes and retry"
            ),
            Self::AuthRequired { resource } => write!(
                f,
                "Authentication required for {resource}. \
                 Fix: set HF_TOKEN environment variable with your HuggingFace token"
            ),
            Self::ShardingRequired {
                model_size,
                shard_count,
            } => {
                let size_gb = *model_size as f64 / 1_000_000_000.0;
                write!(
                    f,
                    "Model too large ({size_gb:.1} GB, {shard_count} shards) for single-file loading. \
                     Fix: use streaming import with --sharded flag"
                )
            }
        }
    }
}

impl std::error::Error for ImportError {}
398
399impl From<ImportError> for AprenderError {
400 fn from(err: ImportError) -> Self {
401 AprenderError::FormatError {
402 message: err.to_string(),
403 }
404 }
405}
406
/// Classifies a free-form hub error message into an [`ImportError`].
///
/// Matching is case-insensitive and keyword-based, checked in this order:
/// not-found, authentication, rate limiting. Anything else becomes
/// `DownloadFailed` carrying the original message.
///
/// For rate limits, the first number after the word `retry` becomes
/// `retry_after`. Surrounding punctuation and unit suffixes are ignored, so
/// `retry after 30s` and `retry-after: (30)` both yield `Some(30)`.
#[cfg(feature = "hf-hub-integration")]
fn parse_import_error(error_msg: &str, resource: &str) -> ImportError {
    /// True if `haystack` contains any of `needles`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|needle| haystack.contains(*needle))
    }

    let msg_lower = error_msg.to_lowercase();

    if contains_any(&msg_lower, &["404", "not found", "does not exist", "no such"]) {
        return ImportError::NotFound {
            resource: resource.to_string(),
            status: 404,
        };
    }

    if contains_any(
        &msg_lower,
        &["401", "403", "unauthorized", "forbidden", "gated", "access denied"],
    ) {
        return ImportError::AuthRequired {
            resource: resource.to_string(),
        };
    }

    if contains_any(&msg_lower, &["429", "rate limit", "too many requests"]) {
        let retry_after = msg_lower.find("retry").and_then(|pos| {
            msg_lower[pos..].split_whitespace().find_map(|token| {
                token
                    .trim_matches(|c: char| !c.is_ascii_digit())
                    .parse::<u64>()
                    .ok()
            })
        });
        return ImportError::RateLimited { retry_after };
    }

    ImportError::DownloadFailed {
        source: resource.to_string(),
        reason: error_msg.to_string(),
    }
}
459
460#[derive(Debug, Clone)]
469pub struct ShardedIndex {
470 weight_map: std::collections::HashMap<String, String>,
472 total_size: Option<u64>,
474}
475
476impl ShardedIndex {
477 pub fn parse(json: &str) -> Result<Self> {
490 let json = json.trim();
494 if !json.starts_with('{') || !json.ends_with('}') {
495 return Err(AprenderError::FormatError {
496 message: "Invalid JSON: expected object".to_string(),
497 });
498 }
499
500 let weight_map_start =
502 json.find("\"weight_map\"")
503 .ok_or_else(|| AprenderError::FormatError {
504 message: "Missing 'weight_map' key in index.json".to_string(),
505 })?;
506
507 let after_key = &json[weight_map_start + 12..]; let obj_start = after_key
510 .find('{')
511 .ok_or_else(|| AprenderError::FormatError {
512 message: "Invalid weight_map: expected object".to_string(),
513 })?;
514
515 let obj_content = &after_key[obj_start..];
516 let mut weight_map = std::collections::HashMap::new();
517 let mut depth = 0;
518 let mut obj_end = 0;
519
520 for (i, c) in obj_content.char_indices() {
521 match c {
522 '{' => depth += 1,
523 '}' => {
524 depth -= 1;
525 if depth == 0 {
526 obj_end = i;
527 break;
528 }
529 }
530 _ => {}
531 }
532 }
533
534 let inner = &obj_content[1..obj_end];
535
536 for pair in inner.split(',') {
538 let pair = pair.trim();
539 if pair.is_empty() {
540 continue;
541 }
542
543 let parts: Vec<&str> = pair.splitn(2, ':').collect();
544 if parts.len() == 2 {
545 let key = parts[0].trim().trim_matches('"');
546 let val = parts[1].trim().trim_matches('"');
547 if !key.is_empty() && !val.is_empty() {
548 weight_map.insert(key.to_string(), val.to_string());
549 }
550 }
551 }
552
553 let total_size = json.find("\"total_size\"").and_then(|pos| {
555 let after = &json[pos + 12..];
556 let colon = after.find(':')?;
557 let after_colon = after[colon + 1..].trim_start();
558 let end = after_colon.find(|c: char| !c.is_ascii_digit())?;
559 after_colon[..end].parse::<u64>().ok()
560 });
561
562 Ok(Self {
563 weight_map,
564 total_size,
565 })
566 }
567
568 #[must_use]
570 pub fn shard_count(&self) -> usize {
571 let unique: std::collections::HashSet<_> = self.weight_map.values().collect();
572 unique.len()
573 }
574
575 #[must_use]
577 pub fn tensor_count(&self) -> usize {
578 self.weight_map.len()
579 }
580
581 #[must_use]
583 pub fn total_size(&self) -> Option<u64> {
584 self.total_size
585 }
586
587 #[must_use]
589 pub fn shard_for_tensor(&self, tensor_name: &str) -> Option<&str> {
590 self.weight_map.get(tensor_name).map(String::as_str)
591 }
592
593 #[must_use]
595 pub fn tensors_in_shard(&self, shard_file: &str) -> Vec<&str> {
596 self.weight_map
597 .iter()
598 .filter(|(_, v)| v.as_str() == shard_file)
599 .map(|(k, _)| k.as_str())
600 .collect()
601 }
602
603 #[must_use]
605 pub fn shard_files(&self) -> Vec<&str> {
606 let mut files: Vec<_> = self
607 .weight_map
608 .values()
609 .map(String::as_str)
610 .collect::<std::collections::HashSet<_>>()
611 .into_iter()
612 .collect();
613 files.sort_unstable();
614 files
615 }
616}
617
/// Returns the path of `{base_name}.index.json` inside `dir` if that file
/// exists, signalling a sharded checkpoint; otherwise `None`.
#[must_use]
pub fn detect_sharded_model(dir: &Path, base_name: &str) -> Option<PathBuf> {
    let candidate = dir.join(format!("{base_name}.index.json"));
    candidate.exists().then_some(candidate)
}
632
633#[derive(Debug)]
639pub struct AprConverter {
640 source: Option<Source>,
641 architecture: Architecture,
642 validation: ValidationConfig,
643 quantize: Option<QuantizationType>,
644 compress: Option<Compression>,
645}
646
647impl AprConverter {
648 pub fn new() -> Self {
650 Self {
651 source: None,
652 architecture: Architecture::Auto,
653 validation: ValidationConfig::Strict,
654 quantize: None,
655 compress: None,
656 }
657 }
658
659 pub fn source(mut self, source: &str) -> Result<Self> {
661 self.source = Some(Source::parse(source)?);
662 Ok(self)
663 }
664
665 pub fn architecture(mut self, arch: Architecture) -> Self {
667 self.architecture = arch;
668 self
669 }
670
671 pub fn validate(mut self, config: ValidationConfig) -> Self {
673 self.validation = config;
674 self
675 }
676
677 pub fn quantize(mut self, quant: QuantizationType) -> Self {
679 self.quantize = Some(quant);
680 self
681 }
682
683 pub fn compress(mut self, comp: Compression) -> Self {
685 self.compress = Some(comp);
686 self
687 }
688
689 pub fn convert(self) -> Result<Vec<u8>> {
691 let source = self.source.ok_or_else(|| AprenderError::FormatError {
692 message: "No source specified".to_string(),
693 })?;
694
695 Err(AprenderError::FormatError {
698 message: format!(
699 "Conversion from {:?} not yet implemented - see GH-80",
700 source
701 ),
702 })
703 }
704}
705
706impl Default for AprConverter {
707 fn default() -> Self {
708 Self::new()
709 }
710}
711
712pub fn apr_import<P: AsRef<Path>>(
738 source: &str,
739 output: P,
740 options: ImportOptions,
741) -> Result<ValidationReport> {
742 let parsed_source = Source::parse(source)?;
743 let output_path = output.as_ref();
744
745 let local_path = resolve_source(&parsed_source, options.cache)?;
747
748 let tensors = load_source_tensors(&local_path, &options)?;
750
751 let mapped_tensors = map_tensor_names(&tensors, options.architecture);
753
754 let validation_result = validate_tensors(&mapped_tensors, &options)?;
756
757 write_apr_file(&mapped_tensors, output_path, &options)?;
759
760 Ok(validation_result)
761}
762
763fn resolve_source(source: &Source, cache: bool) -> Result<PathBuf> {
765 match source {
766 Source::Local(path) => {
767 if !path.exists() {
768 let err = ImportError::NotFound {
770 resource: path.display().to_string(),
771 status: 0, };
773 return Err(AprenderError::from(err));
774 }
775 Ok(path.clone())
776 }
777 Source::HuggingFace { org, repo, file } => {
778 let filename = file.as_deref().unwrap_or("model.safetensors");
779
780 if cache {
782 if let Some(path) = find_in_cache(org, repo, filename) {
783 return Ok(path);
784 }
785 }
786
787 #[cfg(feature = "hf-hub-integration")]
789 {
790 let repo_id = format!("{org}/{repo}");
791 match download_from_hf(&repo_id, filename) {
792 Ok(path) => return Ok(path),
793 Err(e) => {
794 eprintln!("HF download failed: {e}");
796 }
797 }
798 }
799
800 Err(AprenderError::FormatError {
801 message: format!(
802 "HuggingFace model not found in cache. Download manually:\n\
803 huggingface-cli download {org}/{repo} {filename}\n\
804 Or provide a local path to the SafeTensors file.",
805 ),
806 })
807 }
808 Source::Url(url) => Err(AprenderError::FormatError {
809 message: format!("URL download not yet implemented: {url}"),
810 }),
811 }
812}
813
/// Returns the user cache directory per the XDG Base Directory spec.
///
/// Resolution order:
/// 1. `$XDG_CACHE_HOME`, if it is set to an absolute path;
/// 2. `$HOME/.cache`;
/// 3. a relative `.cache`.
///
/// The spec says empty or relative values of `XDG_CACHE_HOME` must be ignored.
/// Previously they were used verbatim, so `XDG_CACHE_HOME=""` resolved to the
/// current directory.
fn get_xdg_cache_dir() -> PathBuf {
    std::env::var("XDG_CACHE_HOME")
        .ok()
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| {
            std::env::var("HOME")
                .map(|h| PathBuf::from(h).join(".cache"))
                .unwrap_or_else(|_| PathBuf::from(".cache"))
        })
}
825
/// Returns the HuggingFace home directory, the parent of the `hub/` cache.
///
/// Resolution order, intended to mirror the `huggingface_hub` defaults:
/// 1. `$HF_HOME`;
/// 2. `$XDG_CACHE_HOME/huggingface`;
/// 3. `$HOME/.cache/huggingface`;
/// 4. a relative `.cache/huggingface`.
///
/// Empty variables are treated as unset. Previously `XDG_CACHE_HOME` was
/// ignored here, so models downloaded by the HF tooling under a custom XDG
/// cache were never found.
fn get_hf_cache_dir() -> PathBuf {
    let non_empty = |key: &str| {
        std::env::var(key)
            .ok()
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };

    if let Some(hf_home) = non_empty("HF_HOME") {
        return hf_home;
    }

    non_empty("XDG_CACHE_HOME")
        .or_else(|| non_empty("HOME").map(|home| home.join(".cache")))
        .unwrap_or_else(|| PathBuf::from(".cache"))
        .join("huggingface")
}
837
/// Looks for `filename` in aprender's own cache layout,
/// `{cache_base}/aprender/hf/{org}/{repo}/{filename}`, returning the path if
/// the file exists.
fn find_in_aprender_cache(
    cache_base: &Path,
    org: &str,
    repo: &str,
    filename: &str,
) -> Option<PathBuf> {
    let mut candidate = cache_base.to_path_buf();
    for component in ["aprender", "hf", org, repo, filename] {
        candidate.push(component);
    }

    if candidate.exists() {
        Some(candidate)
    } else {
        None
    }
}
853
/// Looks for `filename` in a HuggingFace Hub cache rooted at `cache_base`.
///
/// The layout is `hub/models--{org}--{repo}/snapshots/<revision>/<filename>`.
/// Lookup order:
/// 1. The revision named in `refs/main`. Several snapshot directories can
///    coexist and `read_dir` order is unspecified, so previously an arbitrary,
///    possibly stale snapshot could be chosen.
/// 2. The remaining snapshots, newest first by modification time, with ties
///    broken by name in descending order.
///
/// Returns `None` if the repository is not cached or no snapshot has the file.
fn find_in_hf_hub_cache(
    cache_base: &Path,
    org: &str,
    repo: &str,
    filename: &str,
) -> Option<PathBuf> {
    let repo_cache = cache_base
        .join("hub")
        .join(format!("models--{org}--{repo}"));
    let snapshots = repo_cache.join("snapshots");

    if let Ok(main_ref) = fs::read_to_string(repo_cache.join("refs").join("main")) {
        let revision = main_ref.trim();
        if !revision.is_empty() {
            let candidate = snapshots.join(revision).join(filename);
            if candidate.exists() {
                return Some(candidate);
            }
        }
    }

    let mut revisions: Vec<(std::time::SystemTime, PathBuf)> = fs::read_dir(&snapshots)
        .ok()?
        .flatten()
        .map(|entry| {
            let modified = entry
                .metadata()
                .and_then(|meta| meta.modified())
                .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
            (modified, entry.path())
        })
        .collect();
    // Newest first; equal timestamps fall back to reverse path order.
    revisions.sort_unstable_by(|a, b| b.cmp(a));

    revisions
        .into_iter()
        .map(|(_, dir)| dir.join(filename))
        .find(|path| path.exists())
}
880
881fn find_in_cache(org: &str, repo: &str, filename: &str) -> Option<PathBuf> {
883 let cache_paths = [get_xdg_cache_dir(), get_hf_cache_dir()];
884
885 for cache_base in &cache_paths {
886 if let Some(path) = find_in_aprender_cache(cache_base, org, repo, filename) {
887 return Some(path);
888 }
889 if let Some(path) = find_in_hf_hub_cache(cache_base, org, repo, filename) {
890 return Some(path);
891 }
892 }
893
894 None
895}
896
/// Downloads `filename` from the HuggingFace repository `repo_id` using
/// `hf_hub`'s blocking API, returning the local path `hf_hub` reports.
///
/// `HF_TOKEN` is used as the access token when set. A blank or whitespace-only
/// value is treated as unset rather than passed through as an empty token.
///
/// # Errors
///
/// Failures from building the client or fetching the file are classified by
/// `parse_import_error` and converted into `AprenderError`.
#[cfg(feature = "hf-hub-integration")]
fn download_from_hf(repo_id: &str, filename: &str) -> Result<PathBuf> {
    use hf_hub::api::sync::ApiBuilder;

    let resource = format!("{repo_id}/{filename}");
    // Both fallible hub calls are classified the same way.
    let classify =
        |message: String| AprenderError::from(parse_import_error(&message, &resource));

    // NOTE(review): assumes `with_token(Some(""))` would override whatever token
    // the builder would otherwise use; a blank variable is dropped for that reason.
    let token = std::env::var("HF_TOKEN")
        .ok()
        .map(|t| t.trim().to_string())
        .filter(|t| !t.is_empty());

    let mut builder = ApiBuilder::new();
    if let Some(t) = token {
        builder = builder.with_token(Some(t));
    }

    let api = builder.build().map_err(|e| classify(e.to_string()))?;

    api.model(repo_id.to_string())
        .get(filename)
        .map_err(|e| classify(e.to_string()))
}
927
928fn load_source_tensors(
930 path: &Path,
931 _options: &ImportOptions,
932) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
933 let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
934
935 match extension {
936 "safetensors" => load_safetensors_tensors(path),
937 "apr" => {
938 Err(AprenderError::FormatError {
940 message: "Cannot import from APR format - use direct loading instead".to_string(),
941 })
942 }
943 "gguf" => Err(AprenderError::FormatError {
944 message: "GGUF import not yet implemented".to_string(),
945 }),
946 "bin" | "pt" | "pth" => Err(AprenderError::FormatError {
947 message: format!(
948 "PyTorch format ({extension}) not supported. Convert to SafeTensors first."
949 ),
950 }),
951 other => Err(AprenderError::FormatError {
952 message: format!("Unknown file format: .{other}. Supported: .safetensors"),
953 }),
954 }
955}
956
957fn load_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
959 let (metadata, raw_data) = load_safetensors(path).map_err(|e| AprenderError::FormatError {
960 message: format!("Failed to load SafeTensors: {e}"),
961 })?;
962
963 let mut tensors = BTreeMap::new();
964
965 for (name, tensor_meta) in metadata.iter() {
966 if name.starts_with("__") {
968 continue;
969 }
970
971 let data =
972 extract_tensor(&raw_data, tensor_meta).map_err(|e| AprenderError::FormatError {
973 message: format!("Failed to extract tensor '{name}': {e}"),
974 })?;
975
976 tensors.insert(name.clone(), (data, tensor_meta.shape.clone()));
977 }
978
979 Ok(tensors)
980}
981
982fn map_tensor_names(
984 tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
985 architecture: Architecture,
986) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
987 tensors
988 .iter()
989 .map(|(name, data)| {
990 let mapped_name = architecture.map_name(name);
991 (mapped_name, data.clone())
992 })
993 .collect()
994}
995
996fn check_tensor_expectation(
998 name: &str,
999 stats: &TensorStats,
1000 options: &ImportOptions,
1001) -> Option<String> {
1002 if options.validation == ValidationConfig::None {
1003 return None;
1004 }
1005 let expectation = TensorExpectation::for_tensor(name)?;
1006 let err = expectation.check(stats).err()?;
1007 if options.validation == ValidationConfig::Strict && !options.force {
1008 Some(format!("{name}: {err}"))
1009 } else {
1010 None
1011 }
1012}
1013
1014fn check_special_values(name: &str, stats: &TensorStats, options: &ImportOptions) -> Vec<String> {
1016 if options.validation == ValidationConfig::None {
1017 return Vec::new();
1018 }
1019 let mut errors = Vec::new();
1020 if stats.nan_count > 0 {
1021 errors.push(format!("{name}: contains {} NaN values", stats.nan_count));
1022 }
1023 if stats.inf_count > 0 {
1024 errors.push(format!("{name}: contains {} Inf values", stats.inf_count));
1025 }
1026 errors
1027}
1028
1029fn validate_single_tensor(
1031 name: &str,
1032 data: &[f32],
1033 options: &ImportOptions,
1034 validator: &mut AprValidator,
1035 errors: &mut Vec<String>,
1036) {
1037 let stats = compute_tensor_stats(name, data);
1038
1039 if let Some(err) = check_tensor_expectation(name, &stats, options) {
1040 errors.push(err);
1041 }
1042 errors.extend(check_special_values(name, &stats, options));
1043
1044 validator.add_tensor_stats(stats);
1045}
1046
1047fn validate_tensors(
1049 tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1050 options: &ImportOptions,
1051) -> Result<ValidationReport> {
1052 let mut validator = AprValidator::new();
1053 let mut validation_errors = Vec::new();
1054
1055 for (name, (data, _shape)) in tensors {
1056 validate_single_tensor(name, data, options, &mut validator, &mut validation_errors);
1057 }
1058
1059 let report = validator.validate();
1060
1061 if !validation_errors.is_empty() && !options.force {
1062 return Err(AprenderError::FormatError {
1063 message: format!(
1064 "Validation failed ({} errors):\n - {}",
1065 validation_errors.len(),
1066 validation_errors.join("\n - ")
1067 ),
1068 });
1069 }
1070
1071 Ok(report)
1072}
1073
/// Single-pass accumulator of per-tensor statistics.
///
/// NaN and ±Inf values are only counted. Every other value contributes to the
/// sum, min, max, zero count and valid count.
struct TensorAccumulator {
    sum: f64,
    min: f32,
    max: f32,
    nan_count: usize,
    inf_count: usize,
    zero_count: usize,
    valid_count: usize,
}

impl TensorAccumulator {
    /// Empty accumulator. `min`/`max` start at +∞/−∞ as "unset" sentinels.
    fn new() -> Self {
        Self {
            sum: 0.0,
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
            valid_count: 0,
        }
    }

    /// Folds one value into the running statistics.
    fn accumulate(&mut self, v: f32) {
        if v.is_nan() {
            self.nan_count += 1;
            return;
        }
        if v.is_infinite() {
            self.inf_count += 1;
            return;
        }

        self.sum += f64::from(v);
        self.min = self.min.min(v);
        self.max = self.max.max(v);
        self.valid_count += 1;
        // `-0.0 == 0.0`, so both signed zeros are counted.
        self.zero_count += usize::from(v == 0.0);
    }

    /// Mean of the finite values, or `0.0` if there were none.
    fn mean(&self) -> f32 {
        match self.valid_count {
            0 => 0.0,
            n => (self.sum / n as f64) as f32,
        }
    }

    /// Minimum finite value, or `0.0` if none was seen.
    fn safe_min(&self) -> f32 {
        let unset = self.min == f32::INFINITY;
        if unset {
            0.0
        } else {
            self.min
        }
    }

    /// Maximum finite value, or `0.0` if none was seen.
    fn safe_max(&self) -> f32 {
        let unset = self.max == f32::NEG_INFINITY;
        if unset {
            0.0
        } else {
            self.max
        }
    }
}
1138
/// Sample standard deviation (Bessel-corrected, dividing by `valid_count - 1`)
/// of the finite values in `data` around `mean`.
///
/// Returns `0.0` when fewer than two finite values are present.
fn compute_std(data: &[f32], mean: f32, valid_count: usize) -> f32 {
    if valid_count < 2 {
        return 0.0;
    }

    let centre = f64::from(mean);
    let mut squared_deviation = 0.0f64;
    for &value in data.iter().filter(|v| v.is_finite()) {
        let delta = f64::from(value) - centre;
        squared_deviation += delta * delta;
    }

    let variance = squared_deviation / (valid_count - 1) as f64;
    variance.sqrt() as f32
}
1154
1155fn compute_tensor_stats(name: &str, data: &[f32]) -> TensorStats {
1157 if data.is_empty() {
1158 return TensorStats {
1159 name: name.to_string(),
1160 count: 0,
1161 min: 0.0,
1162 max: 0.0,
1163 mean: 0.0,
1164 std: 0.0,
1165 nan_count: 0,
1166 inf_count: 0,
1167 zero_count: 0,
1168 };
1169 }
1170
1171 let mut acc = TensorAccumulator::new();
1172 for &v in data {
1173 acc.accumulate(v);
1174 }
1175
1176 let mean = acc.mean();
1177 let std = compute_std(data, mean, acc.valid_count);
1178
1179 TensorStats {
1180 name: name.to_string(),
1181 count: data.len(),
1182 min: acc.safe_min(),
1183 max: acc.safe_max(),
1184 mean,
1185 std,
1186 nan_count: acc.nan_count,
1187 inf_count: acc.inf_count,
1188 zero_count: acc.zero_count,
1189 }
1190}
1191
1192fn write_apr_file(
1194 tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1195 output: &Path,
1196 options: &ImportOptions,
1197) -> Result<()> {
1198 let mut metadata = AprV2Metadata::default();
1200 metadata.model_type = format!("{:?}", options.architecture);
1201 metadata.name = Some(output.file_stem()
1202 .and_then(|s| s.to_str())
1203 .unwrap_or("model")
1204 .to_string());
1205
1206 let param_count: u64 = tensors.values()
1208 .map(|(data, _)| data.len() as u64)
1209 .sum();
1210 metadata.param_count = param_count;
1211
1212 let mut writer = AprV2Writer::new(metadata);
1214
1215 for (name, (data, shape)) in tensors {
1217 writer.add_f32_tensor(name, shape.clone(), data);
1218 }
1219
1220 let bytes = writer.write().map_err(|e| AprenderError::FormatError {
1222 message: format!("Failed to serialize APR format: {e}"),
1223 })?;
1224
1225 let mut file = fs::File::create(output).map_err(|e| AprenderError::FormatError {
1226 message: format!("Failed to create output file: {e}"),
1227 })?;
1228
1229 file.write_all(&bytes).map_err(|e| AprenderError::FormatError {
1230 message: format!("Failed to write APR file: {e}"),
1231 })?;
1232
1233 Ok(())
1234}
1235
1236#[derive(Debug, Clone)]
1242pub struct ConvertOptions {
1243 pub quantize: Option<QuantizationType>,
1245 pub compress: Option<Compression>,
1247 pub validate: bool,
1249}
1250
1251impl Default for ConvertOptions {
1252 fn default() -> Self {
1253 Self {
1254 quantize: None,
1255 compress: None,
1256 validate: true,
1257 }
1258 }
1259}
1260
1261pub fn apr_convert<P: AsRef<Path>>(
1283 input: P,
1284 output: P,
1285 options: ConvertOptions,
1286) -> Result<ConvertReport> {
1287 let input_path = input.as_ref();
1288 let output_path = output.as_ref();
1289
1290 let tensors = load_model_tensors(input_path)?;
1292 let original_size = calculate_tensor_size(&tensors);
1293 let original_count = tensors.len();
1294
1295 let tensors = if let Some(quant_type) = &options.quantize {
1297 quantize_tensors(&tensors, quant_type)?
1298 } else {
1299 tensors
1300 };
1301
1302 save_model_tensors(&tensors, output_path, options.compress)?;
1304
1305 let converted_size = fs::metadata(output_path)
1307 .map(|m| m.len() as usize)
1308 .unwrap_or(0);
1309
1310 Ok(ConvertReport {
1311 original_size,
1312 converted_size,
1313 tensor_count: original_count,
1314 quantization: options.quantize,
1315 compression: options.compress,
1316 reduction_ratio: if converted_size > 0 {
1317 original_size as f64 / converted_size as f64
1318 } else {
1319 0.0
1320 },
1321 })
1322}
1323
1324#[derive(Debug, Clone)]
1326pub struct ConvertReport {
1327 pub original_size: usize,
1329 pub converted_size: usize,
1331 pub tensor_count: usize,
1333 pub quantization: Option<QuantizationType>,
1335 pub compression: Option<Compression>,
1337 pub reduction_ratio: f64,
1339}
1340
1341impl ConvertReport {
1342 pub fn reduction_percent(&self) -> String {
1344 if self.original_size > 0 && self.converted_size > 0 {
1345 let reduction = 100.0 * (1.0 - self.converted_size as f64 / self.original_size as f64);
1346 format!("{:.1}%", reduction)
1347 } else {
1348 "N/A".to_string()
1349 }
1350 }
1351}
1352
1353fn load_model_tensors(path: &Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
1355 let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
1356
1357 match extension {
1358 "safetensors" | "apr" => load_safetensors_tensors(path),
1359 other => Err(AprenderError::FormatError {
1360 message: format!("Unsupported format for conversion: .{other}"),
1361 }),
1362 }
1363}
1364
/// Total in-memory size of all tensors in bytes, assuming 4-byte `f32` elements.
fn calculate_tensor_size(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) -> usize {
    let mut bytes = 0;
    for (values, _) in tensors.values() {
        bytes += values.len() * std::mem::size_of::<f32>();
    }
    bytes
}
1369
1370fn quantize_tensors(
1372 tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1373 quant_type: &QuantizationType,
1374) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
1375 let mut result = BTreeMap::new();
1376
1377 for (name, (data, shape)) in tensors {
1378 let quantized_data = match quant_type {
1379 QuantizationType::Fp16 => quantize_fp16(data),
1380 QuantizationType::Int8 => quantize_int8(data),
1381 QuantizationType::Int4 => quantize_int4(data),
1382 };
1383 result.insert(name.clone(), (quantized_data, shape.clone()));
1384 }
1385
1386 Ok(result)
1387}
1388
/// Round-trips each value through IEEE 754 binary16 (half precision) and
/// returns the `f32` value the half-precision number represents.
///
/// Conversion uses round-to-nearest-even, so it matches a real f32→f16 cast:
/// * magnitudes that round above 65504 become ±infinity;
/// * magnitudes below half the smallest subnormal (2^-24) flush to a signed zero;
/// * values in the subnormal range keep reduced precision;
/// * NaN stays NaN.
///
/// The previous version only truncated the mantissa and ignored the narrower
/// exponent range.
fn quantize_fp16(data: &[f32]) -> Vec<f32> {
    /// f32 → binary16 bit pattern with round-to-nearest-even.
    fn to_half_bits(value: f32) -> u16 {
        let bits = value.to_bits();
        let sign = ((bits >> 16) & 0x8000) as u16;
        let exponent = ((bits >> 23) & 0xFF) as i32;
        let mantissa = bits & 0x007F_FFFF;

        if exponent == 0xFF {
            // Infinity keeps a zero mantissa; any NaN becomes a quiet NaN.
            let payload = if mantissa == 0 { 0 } else { 0x0200 };
            return sign | 0x7C00 | payload;
        }

        // Re-bias from f32 (127) to f16 (15).
        let half_exponent = exponent - 127 + 15;
        if half_exponent >= 0x1F {
            return sign | 0x7C00;
        }

        if half_exponent <= 0 {
            // Result is subnormal (or zero) in f16: value = m × 2^-24.
            if half_exponent < -10 {
                return sign;
            }
            let full = mantissa | 0x0080_0000; // restore the implicit leading 1
            let shift = (14 - half_exponent) as u32; // 14..=24
            let truncated = full >> shift;
            let remainder = full & ((1u32 << shift) - 1);
            let halfway = 1u32 << (shift - 1);
            let rounded = if remainder > halfway || (remainder == halfway && truncated & 1 == 1) {
                truncated + 1
            } else {
                truncated
            };
            // A carry to 0x400 correctly encodes the smallest normal number.
            return sign | rounded as u16;
        }

        let truncated = ((half_exponent as u32) << 10) | (mantissa >> 13);
        let remainder = mantissa & 0x1FFF;
        // A carry out of the mantissa bumps the exponent; 0x7C00 is infinity.
        let rounded = if remainder > 0x1000 || (remainder == 0x1000 && truncated & 1 == 1) {
            truncated + 1
        } else {
            truncated
        };
        sign | rounded as u16
    }

    /// binary16 bit pattern → exact f32 value.
    fn from_half_bits(half: u16) -> f32 {
        let sign = u32::from(half & 0x8000) << 16;
        let exponent = u32::from((half >> 10) & 0x1F);
        let mantissa = u32::from(half & 0x03FF);

        match exponent {
            0 => {
                // Zero or subnormal: mantissa × 2^-24, exactly representable in f32.
                let magnitude = mantissa as f32 / 16_777_216.0;
                if sign == 0 {
                    magnitude
                } else {
                    -magnitude
                }
            }
            0x1F => f32::from_bits(sign | 0x7F80_0000 | (mantissa << 13)),
            _ => f32::from_bits(sign | ((exponent + 112) << 23) | (mantissa << 13)),
        }
    }

    data.iter()
        .map(|&v| from_half_bits(to_half_bits(v)))
        .collect()
}
1408
/// Simulates symmetric per-tensor int8 quantization.
///
/// Finite values are mapped to integer levels in [-127, 127] using
/// `scale = max_abs / 127`, then dequantized back to `f32`. `max_abs` is taken
/// over finite values only.
///
/// Non-finite values (NaN, ±Inf) pass through unchanged so later validation can
/// still see them. Previously a single ±Inf made the scale infinite and turned
/// every output into NaN or 0, and NaN was silently converted to 0.
///
/// An empty input yields an empty output. If every finite value is zero, the
/// finite values become `0.0`.
fn quantize_int8(data: &[f32]) -> Vec<f32> {
    let max_abs = data
        .iter()
        .filter(|v| v.is_finite())
        .fold(0.0f32, |acc, v| acc.max(v.abs()));

    if max_abs == 0.0 {
        return data
            .iter()
            .map(|&v| if v.is_finite() { 0.0 } else { v })
            .collect();
    }

    let scale = max_abs / 127.0;

    data.iter()
        .map(|&v| {
            if !v.is_finite() {
                return v;
            }
            let level = (v / scale).round().clamp(-127.0, 127.0) as i8;
            f32::from(level) * scale
        })
        .collect()
}
1432
/// Simulates symmetric per-tensor int4 quantization.
///
/// Finite values are mapped to integer levels using `scale = max_abs / 7`,
/// clamped to [-8, 7], then dequantized back to `f32`. `max_abs` is taken over
/// finite values only.
///
/// Non-finite values (NaN, ±Inf) pass through unchanged. Previously a single
/// ±Inf made the scale infinite and turned every output into NaN or 0, and NaN
/// was silently converted to 0.
///
/// An empty input yields an empty output. If every finite value is zero, the
/// finite values become `0.0`.
fn quantize_int4(data: &[f32]) -> Vec<f32> {
    let max_abs = data
        .iter()
        .filter(|v| v.is_finite())
        .fold(0.0f32, |acc, v| acc.max(v.abs()));

    if max_abs == 0.0 {
        return data
            .iter()
            .map(|&v| if v.is_finite() { 0.0 } else { v })
            .collect();
    }

    let scale = max_abs / 7.0;

    data.iter()
        .map(|&v| {
            if !v.is_finite() {
                return v;
            }
            let level = (v / scale).round().clamp(-8.0, 7.0) as i8;
            f32::from(level) * scale
        })
        .collect()
}
1456
1457fn save_model_tensors(
1459 tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1460 output: &Path,
1461 _compression: Option<Compression>,
1462) -> Result<()> {
1463 save_safetensors(output, tensors).map_err(|e| AprenderError::FormatError {
1466 message: format!("Failed to save converted model: {e}"),
1467 })
1468}
1469
/// Target format for model export.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// SafeTensors (`.safetensors`).
    SafeTensors,
    /// GGUF (`.gguf`).
    Gguf,
    /// ONNX (`.onnx`), not supported yet.
    Onnx,
    /// TorchScript (`.pt`), not supported yet.
    TorchScript,
}

impl std::str::FromStr for ExportFormat {
    type Err = String;

    /// Parses a format name case-insensitively.
    ///
    /// Aliases: `st` for SafeTensors; `pt` and `torch` for TorchScript.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let format = match lowered.as_str() {
            "safetensors" | "st" => Self::SafeTensors,
            "gguf" => Self::Gguf,
            "onnx" => Self::Onnx,
            "torchscript" | "pt" | "torch" => Self::TorchScript,
            _ => return Err(format!("Unknown export format: {s}")),
        };
        Ok(format)
    }
}

impl ExportFormat {
    /// File extension (without the dot) conventionally used for this format.
    #[must_use]
    pub fn extension(&self) -> &'static str {
        match *self {
            Self::SafeTensors => "safetensors",
            Self::Gguf => "gguf",
            Self::Onnx => "onnx",
            Self::TorchScript => "pt",
        }
    }

    /// Whether export to this format is implemented (SafeTensors and GGUF).
    #[must_use]
    pub fn is_supported(&self) -> bool {
        match self {
            Self::SafeTensors | Self::Gguf => true,
            Self::Onnx | Self::TorchScript => false,
        }
    }
}
1519
1520#[derive(Debug, Clone)]
1522pub struct ExportOptions {
1523 pub format: ExportFormat,
1525 pub quantize: Option<QuantizationType>,
1527}
1528
1529impl Default for ExportOptions {
1530 fn default() -> Self {
1531 Self {
1532 format: ExportFormat::SafeTensors,
1533 quantize: None,
1534 }
1535 }
1536}
1537
1538#[derive(Debug, Clone)]
1540pub struct ExportReport {
1541 pub original_size: usize,
1543 pub exported_size: usize,
1545 pub tensor_count: usize,
1547 pub format: ExportFormat,
1549 pub quantization: Option<QuantizationType>,
1551}
1552
1553pub fn apr_export<P: AsRef<Path>>(
1584 input: P,
1585 output: P,
1586 options: ExportOptions,
1587) -> Result<ExportReport> {
1588 let input_path = input.as_ref();
1589 let output_path = output.as_ref();
1590
1591 if !input_path.exists() {
1593 return Err(AprenderError::FormatError {
1594 message: format!("Input file not found: {}", input_path.display()),
1595 });
1596 }
1597
1598 if !options.format.is_supported() {
1600 return Err(AprenderError::FormatError {
1601 message: format!(
1602 "Export format {:?} is not yet supported. Use 'safetensors' or 'gguf'.",
1603 options.format
1604 ),
1605 });
1606 }
1607
1608 let tensors = load_model_tensors(input_path)?;
1610 let original_size = calculate_tensor_size(&tensors);
1611 let original_count = tensors.len();
1612
1613 let tensors = if let Some(ref quant_type) = options.quantize {
1615 quantize_tensors(&tensors, quant_type)?
1616 } else {
1617 tensors
1618 };
1619
1620 match options.format {
1622 ExportFormat::SafeTensors => {
1623 save_safetensors(output_path, &tensors).map_err(|e| AprenderError::FormatError {
1624 message: format!("Failed to export to SafeTensors: {e}"),
1625 })?;
1626 }
1627 ExportFormat::Gguf => {
1628 export_to_gguf(&tensors, output_path)?;
1629 }
1630 ExportFormat::Onnx | ExportFormat::TorchScript => {
1631 return Err(AprenderError::FormatError {
1632 message: format!("Export format {:?} is not yet implemented", options.format),
1633 });
1634 }
1635 }
1636
1637 let exported_size = fs::metadata(output_path)
1639 .map(|m| m.len() as usize)
1640 .unwrap_or(0);
1641
1642 Ok(ExportReport {
1643 original_size,
1644 exported_size,
1645 tensor_count: original_count,
1646 format: options.format,
1647 quantization: options.quantize,
1648 })
1649}
1650
1651fn export_to_gguf(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>, output: &Path) -> Result<()> {
1653 use crate::format::gguf::{export_tensors_to_gguf, GgmlType, GgufTensor, GgufValue};
1654 use std::fs::File;
1655 use std::io::BufWriter;
1656
1657 let gguf_tensors: Vec<GgufTensor> = tensors
1659 .iter()
1660 .map(|(name, (data, shape))| {
1661 let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
1663
1664 GgufTensor {
1665 name: name.clone(),
1666 shape: shape.iter().map(|&d| d as u64).collect(),
1667 dtype: GgmlType::F32,
1668 data: bytes,
1669 }
1670 })
1671 .collect();
1672
1673 let metadata = vec![
1675 (
1676 "general.name".to_string(),
1677 GgufValue::String("model".to_string()),
1678 ),
1679 (
1680 "general.quantization_version".to_string(),
1681 GgufValue::Uint32(1),
1682 ),
1683 ];
1684
1685 let file = File::create(output).map_err(|e| AprenderError::FormatError {
1687 message: format!("Failed to create output file: {e}"),
1688 })?;
1689 let mut writer = BufWriter::new(file);
1690
1691 export_tensors_to_gguf(&mut writer, &gguf_tensors, &metadata)
1692}
1693
/// Algorithms for combining several models' weights into one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Uniform mean of all inputs.
    Average,
    /// Mean using user-supplied weights, normalized to sum to 1.
    Weighted,
    /// TIES merging; recognised but not yet supported.
    Ties,
    /// DARE merging; recognised but not yet supported.
    Dare,
    /// Spherical linear interpolation; recognised but not yet supported.
    Slerp,
}

impl std::str::FromStr for MergeStrategy {
    type Err = String;

    /// Parses a case-insensitive strategy name (`avg` is an alias for `average`).
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        const ALIASES: [(&str, MergeStrategy); 6] = [
            ("average", MergeStrategy::Average),
            ("avg", MergeStrategy::Average),
            ("weighted", MergeStrategy::Weighted),
            ("ties", MergeStrategy::Ties),
            ("dare", MergeStrategy::Dare),
            ("slerp", MergeStrategy::Slerp),
        ];

        let key = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == key)
            .map(|&(_, strategy)| strategy)
            .ok_or_else(|| format!("Unknown merge strategy: {s}"))
    }
}

impl MergeStrategy {
    /// Whether `apr_merge` currently implements this strategy.
    #[must_use]
    pub fn is_supported(&self) -> bool {
        match self {
            Self::Average | Self::Weighted => true,
            Self::Ties | Self::Dare | Self::Slerp => false,
        }
    }
}
1735
1736#[derive(Debug, Clone)]
1738pub struct MergeOptions {
1739 pub strategy: MergeStrategy,
1741 pub weights: Option<Vec<f32>>,
1743}
1744
1745impl Default for MergeOptions {
1746 fn default() -> Self {
1747 Self {
1748 strategy: MergeStrategy::Average,
1749 weights: None,
1750 }
1751 }
1752}
1753
1754#[derive(Debug, Clone)]
1756pub struct MergeReport {
1757 pub model_count: usize,
1759 pub tensor_count: usize,
1761 pub output_size: usize,
1763 pub strategy: MergeStrategy,
1765 pub weights_used: Option<Vec<f32>>,
1767}
1768
1769fn validate_merge_options<P: AsRef<Path>>(inputs: &[P], options: &MergeOptions) -> Result<()> {
1775 if inputs.len() < 2 {
1776 return Err(AprenderError::FormatError {
1777 message: "Merge requires at least 2 input models".to_string(),
1778 });
1779 }
1780
1781 if !options.strategy.is_supported() {
1782 return Err(AprenderError::FormatError {
1783 message: format!(
1784 "Merge strategy {:?} is not yet supported. Use 'average' or 'weighted'.",
1785 options.strategy
1786 ),
1787 });
1788 }
1789
1790 if options.strategy == MergeStrategy::Weighted {
1791 match &options.weights {
1792 Some(weights) if weights.len() != inputs.len() => {
1793 return Err(AprenderError::FormatError {
1794 message: format!(
1795 "Weighted merge requires {} weights, got {}",
1796 inputs.len(),
1797 weights.len()
1798 ),
1799 });
1800 }
1801 None => {
1802 return Err(AprenderError::FormatError {
1803 message: "Weighted merge requires weights to be specified".to_string(),
1804 });
1805 }
1806 _ => {}
1807 }
1808 }
1809 Ok(())
1810}
1811
1812fn load_all_models<P: AsRef<Path>>(
1814 inputs: &[P],
1815) -> Result<Vec<BTreeMap<String, (Vec<f32>, Vec<usize>)>>> {
1816 let mut all_tensors = Vec::new();
1817 for input_path in inputs {
1818 let path = input_path.as_ref();
1819 if !path.exists() {
1820 return Err(AprenderError::FormatError {
1821 message: format!("Input file not found: {}", path.display()),
1822 });
1823 }
1824 all_tensors.push(load_model_tensors(path)?);
1825 }
1826 Ok(all_tensors)
1827}
1828
1829fn verify_tensor_compatibility(
1831 all_tensors: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
1832) -> Result<()> {
1833 let reference = &all_tensors[0];
1834 for (i, tensors) in all_tensors.iter().enumerate().skip(1) {
1835 if tensors.len() != reference.len() {
1836 return Err(AprenderError::FormatError {
1837 message: format!(
1838 "Model {} has {} tensors, but model 0 has {}",
1839 i,
1840 tensors.len(),
1841 reference.len()
1842 ),
1843 });
1844 }
1845 verify_single_model_tensors(reference, tensors, i)?;
1846 }
1847 Ok(())
1848}
1849
1850fn verify_single_model_tensors(
1852 reference: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1853 tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1854 model_idx: usize,
1855) -> Result<()> {
1856 for (name, (_, shape)) in reference {
1857 match tensors.get(name) {
1858 None => {
1859 return Err(AprenderError::FormatError {
1860 message: format!("Model {} is missing tensor '{}'", model_idx, name),
1861 });
1862 }
1863 Some((_, other_shape)) if other_shape != shape => {
1864 return Err(AprenderError::FormatError {
1865 message: format!(
1866 "Tensor '{}' has shape {:?} in model 0 but {:?} in model {}",
1867 name, shape, other_shape, model_idx
1868 ),
1869 });
1870 }
1871 _ => {}
1872 }
1873 }
1874 Ok(())
1875}
1876
1877fn calculate_merge_weights(input_count: usize, options: &MergeOptions) -> Result<Vec<f32>> {
1879 match options.strategy {
1880 MergeStrategy::Average => {
1881 let w = 1.0 / input_count as f32;
1882 Ok(vec![w; input_count])
1883 }
1884 MergeStrategy::Weighted => {
1885 let raw_weights = options.weights.as_ref().expect("validated above");
1886 let sum: f32 = raw_weights.iter().sum();
1887 if sum <= 0.0 {
1888 return Err(AprenderError::FormatError {
1889 message: "Weights must sum to a positive value".to_string(),
1890 });
1891 }
1892 Ok(raw_weights.iter().map(|w| w / sum).collect())
1893 }
1894 _ => unreachable!("unsupported strategies filtered above"),
1895 }
1896}
1897
/// Linearly combines the models' tensors: `merged = sum_i weights[i] * model_i`.
///
/// The tensor names and shapes of the first model define the output. Callers
/// should first verify compatibility with `verify_tensor_compatibility` and
/// pass one weight per model. An empty `all_tensors` yields an empty map.
///
/// # Panics
/// Panics if a tensor present in the first model is missing from another model.
fn merge_tensors(
    all_tensors: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
    weights: &[f32],
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
    debug_assert_eq!(
        weights.len(),
        all_tensors.len(),
        "merge_tensors needs exactly one weight per model"
    );

    let mut merged = BTreeMap::new();
    let reference = match all_tensors.first() {
        Some(first) => first,
        None => return merged,
    };

    // The reference data is already in hand; no second map lookup is needed.
    for (name, (reference_data, shape)) in reference {
        let mut merged_data = vec![0.0f32; reference_data.len()];

        for (model_tensors, &weight) in all_tensors.iter().zip(weights) {
            let (data, _) = model_tensors.get(name).expect("validated above");
            // Zipped iteration avoids per-element bounds checks and index panics.
            for (acc, &val) in merged_data.iter_mut().zip(data) {
                *acc += val * weight;
            }
        }

        merged.insert(name.clone(), (merged_data, shape.clone()));
    }
    merged
}
1922
1923pub fn apr_merge<P: AsRef<Path>>(
1955 inputs: &[P],
1956 output: P,
1957 options: MergeOptions,
1958) -> Result<MergeReport> {
1959 validate_merge_options(inputs, &options)?;
1961
1962 let all_tensors = load_all_models(inputs)?;
1964
1965 verify_tensor_compatibility(&all_tensors)?;
1967
1968 let weights = calculate_merge_weights(inputs.len(), &options)?;
1970
1971 let merged = merge_tensors(&all_tensors, &weights);
1973
1974 let output_path = output.as_ref();
1976 save_safetensors(output_path, &merged).map_err(|e| AprenderError::FormatError {
1977 message: format!("Failed to save merged model: {e}"),
1978 })?;
1979
1980 let output_size = fs::metadata(output_path)
1982 .map(|m| m.len() as usize)
1983 .unwrap_or(0);
1984
1985 Ok(MergeReport {
1986 model_count: inputs.len(),
1987 tensor_count: merged.len(),
1988 output_size,
1989 strategy: options.strategy,
1990 weights_used: Some(weights),
1991 })
1992}
1993
#[cfg(test)]
mod tests_source_parsing {
    use super::*;

    /// Builds the expected `Source::HuggingFace` value.
    fn hf(org: &str, repo: &str, file: Option<&str>) -> Source {
        Source::HuggingFace {
            org: org.to_string(),
            repo: repo.to_string(),
            file: file.map(str::to_string),
        }
    }

    #[test]
    fn test_parse_hf_org_repo() {
        let parsed = Source::parse("hf://openai/whisper-tiny").unwrap();
        assert_eq!(parsed, hf("openai", "whisper-tiny", None));
    }

    #[test]
    fn test_parse_hf_org_repo_file() {
        let parsed = Source::parse("hf://openai/whisper-tiny/model.safetensors").unwrap();
        assert_eq!(
            parsed,
            hf("openai", "whisper-tiny", Some("model.safetensors"))
        );
    }

    #[test]
    fn test_parse_hf_nested_file() {
        let parsed =
            Source::parse("hf://meta-llama/Llama-2-7b/pytorch_model-00001-of-00002.bin").unwrap();
        assert_eq!(
            parsed,
            hf(
                "meta-llama",
                "Llama-2-7b",
                Some("pytorch_model-00001-of-00002.bin")
            )
        );
    }

    #[test]
    fn test_parse_local_path() {
        let raw = "./models/model.safetensors";
        assert_eq!(Source::parse(raw).unwrap(), Source::Local(PathBuf::from(raw)));
    }

    #[test]
    fn test_parse_url() {
        let raw = "https://example.com/model.safetensors";
        assert_eq!(Source::parse(raw).unwrap(), Source::Url(raw.to_string()));
    }

    #[test]
    fn test_parse_hf_invalid() {
        // Only an org, no repo.
        assert!(Source::parse("hf://invalid").is_err());
    }

    #[test]
    fn test_default_file() {
        assert_eq!(
            hf("openai", "whisper", None).default_file(),
            "model.safetensors"
        );
        assert_eq!(
            hf("openai", "whisper", Some("custom.safetensors")).default_file(),
            "custom.safetensors"
        );
    }
}
2083
#[cfg(test)]
mod tests_name_mapping {
    use super::*;

    /// Asserts that `arch` rewrites `source` to exactly `expected`.
    fn assert_maps(arch: Architecture, source: &str, expected: &str) {
        assert_eq!(arch.map_name(source), expected);
    }

    #[test]
    fn test_whisper_strip_model_prefix() {
        assert_maps(
            Architecture::Whisper,
            "model.encoder.conv1.weight",
            "encoder.conv1.weight",
        );
    }

    #[test]
    fn test_whisper_no_prefix() {
        assert_maps(
            Architecture::Whisper,
            "encoder.conv1.weight",
            "encoder.conv1.weight",
        );
    }

    #[test]
    fn test_whisper_decoder_layer_norm() {
        assert_maps(
            Architecture::Whisper,
            "model.decoder.layer_norm.weight",
            "decoder.layer_norm.weight",
        );
    }

    #[test]
    fn test_auto_strips_model_prefix() {
        assert_maps(
            Architecture::Auto,
            "model.encoder.layers.0.self_attn.q_proj.weight",
            "encoder.layers.0.self_attn.q_proj.weight",
        );
    }

    #[test]
    fn test_llama_mapping() {
        assert_maps(
            Architecture::Llama,
            "model.layers.0.self_attn.q_proj.weight",
            "layers.0.self_attn.q_proj.weight",
        );
    }

    #[test]
    fn test_bert_mapping() {
        assert_maps(
            Architecture::Bert,
            "bert.encoder.layer.0.attention.self.query.weight",
            "encoder.layer.0.attention.self.query.weight",
        );
    }
}
2125
#[cfg(test)]
mod tests_tensor_expectations {
    use super::*;

    #[test]
    fn test_layer_norm_weight_expectation() {
        let exp = TensorExpectation::for_tensor("encoder.layer_norm.weight")
            .expect("LayerNorm weight should have an expectation");
        assert_eq!(exp.mean_range, (0.5, 3.0));
    }

    #[test]
    fn test_layer_norm_bias_expectation() {
        let exp = TensorExpectation::for_tensor("decoder.layers.0.self_attn_layer_norm.bias")
            .expect("LayerNorm bias should have an expectation");
        assert_eq!(exp.mean_range, (-0.5, 0.5));
    }

    #[test]
    fn test_linear_weight_expectation() {
        let exp = TensorExpectation::for_tensor("encoder.layers.0.fc1.weight")
            .expect("linear weight should have an expectation");
        assert_eq!(exp.mean_range, (-0.1, 0.1));
    }

    #[test]
    fn test_embedding_expectation() {
        assert!(TensorExpectation::for_tensor("decoder.embed_tokens.weight").is_some());
    }

    #[test]
    fn test_check_layer_norm_valid() {
        let stats = TensorStats {
            name: "encoder.layer_norm.weight".to_string(),
            count: 384,
            min: 0.5,
            max: 2.0,
            mean: 1.0,
            std: 0.3,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
        };

        assert!(TensorExpectation::LAYER_NORM_WEIGHT.check(&stats).is_ok());
    }

    #[test]
    fn test_check_layer_norm_invalid_mean() {
        let stats = TensorStats {
            name: "decoder.layer_norm.weight".to_string(),
            count: 384,
            min: 5.0,
            max: 15.0,
            mean: 11.0,
            std: 2.0,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
        };

        let message = TensorExpectation::LAYER_NORM_WEIGHT
            .check(&stats)
            .expect_err("a LayerNorm weight mean of 11 must be rejected")
            .to_string();
        assert!(message.contains("mean=11"));
        assert!(message.contains("outside expected range"));
    }
}
2201
#[cfg(test)]
mod tests_converter_builder {
    use super::*;

    #[test]
    fn test_converter_builder_chain() {
        let built = AprConverter::new()
            .source("hf://openai/whisper-tiny")
            .expect("hf:// source should parse")
            .architecture(Architecture::Whisper)
            .validate(ValidationConfig::Strict)
            .quantize(QuantizationType::Int8)
            .compress(Compression::Lz4);

        // Every builder call must be reflected in the converter's state.
        assert_eq!(built.architecture, Architecture::Whisper);
        assert_eq!(built.validation, ValidationConfig::Strict);
        assert_eq!(built.quantize, Some(QuantizationType::Int8));
        assert_eq!(built.compress, Some(Compression::Lz4));
    }

    #[test]
    fn test_converter_no_source_error() {
        let outcome = AprConverter::new().convert();
        assert!(outcome.is_err());
    }
}
2229
#[cfg(test)]
mod tests_import_options {
    use super::*;

    #[test]
    fn test_default_options() {
        let defaults = ImportOptions::default();

        assert_eq!(defaults.architecture, Architecture::Auto);
        assert_eq!(defaults.validation, ValidationConfig::Strict);
        assert!(defaults.quantize.is_none());
        assert!(defaults.compress.is_none());
        // Validation is enforced and downloads are cached unless overridden.
        assert!(!defaults.force);
        assert!(defaults.cache);
    }
}
2245
#[cfg(test)]
mod tests_conversion {
    use super::*;

    /// Writes `tensors` to `path` as a SafeTensors file, panicking on failure.
    fn create_test_safetensors(path: &Path, tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) {
        save_safetensors(path, tensors).expect("Failed to create test SafeTensors file");
    }

    /// Creates a private scratch directory. It is deleted, with its contents,
    /// when the returned guard is dropped, even if an assertion panics first.
    /// This replaces hard-coded `/tmp` paths, which are non-portable and leaked
    /// files whenever a test failed before its cleanup lines.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("Failed to create temporary directory")
    }

    /// Returns `dir/name` as a UTF-8 string, the argument form these tests pass
    /// to `apr_import`.
    fn path_in(dir: &tempfile::TempDir, name: &str) -> String {
        dir.path()
            .join(name)
            .to_str()
            .expect("temporary path should be valid UTF-8")
            .to_string()
    }

    #[test]
    fn test_convert_valid_safetensors() {
        let dir = scratch_dir();
        let input = path_in(&dir, "test_valid_input.safetensors");
        let output = path_in(&dir, "test_valid_output.apr");

        let mut tensors = BTreeMap::new();
        tensors.insert(
            "encoder.layer_norm.weight".to_string(),
            (vec![1.0f32; 384], vec![384]),
        );
        tensors.insert(
            "encoder.layer_norm.bias".to_string(),
            (vec![0.0f32; 384], vec![384]),
        );
        // 80 * 1 * 3 = 240 elements, so the data length matches the declared shape.
        tensors.insert(
            "encoder.conv1.weight".to_string(),
            (vec![0.01f32; 240], vec![80, 1, 3]),
        );

        create_test_safetensors(Path::new(&input), &tensors);

        let options = ImportOptions::default();
        let result = apr_import(input.as_str(), output.as_str(), options);

        assert!(
            result.is_ok(),
            "Valid tensors should convert successfully: {:?}",
            result.err()
        );
        let report = result.unwrap();
        assert!(report.total_score > 0, "Score should be > 0");
    }

    #[test]
    fn test_convert_invalid_layernorm_fails_strict() {
        let dir = scratch_dir();
        let input = path_in(&dir, "test_invalid_ln_input.safetensors");
        let output = path_in(&dir, "test_invalid_ln_output.apr");

        // A LayerNorm weight with mean 11 is far outside the expected range.
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "decoder.layer_norm.weight".to_string(),
            (vec![11.0f32; 384], vec![384]),
        );

        create_test_safetensors(Path::new(&input), &tensors);

        let options = ImportOptions {
            validation: ValidationConfig::Strict,
            force: false,
            ..Default::default()
        };
        let result = apr_import(input.as_str(), output.as_str(), options);

        assert!(
            result.is_err(),
            "Invalid LayerNorm should fail strict validation"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("mean=11") || err.contains("LayerNorm"),
            "Error should mention LayerNorm issue: {err}"
        );
    }

    #[test]
    fn test_convert_invalid_layernorm_force_succeeds() {
        let dir = scratch_dir();
        let input = path_in(&dir, "test_force_ln_input.safetensors");
        let output = path_in(&dir, "test_force_ln_output.apr");

        let mut tensors = BTreeMap::new();
        tensors.insert(
            "decoder.layer_norm.weight".to_string(),
            (vec![11.0f32; 384], vec![384]),
        );

        create_test_safetensors(Path::new(&input), &tensors);

        // `force` must bypass the strict validation failure above.
        let options = ImportOptions {
            validation: ValidationConfig::Strict,
            force: true,
            ..Default::default()
        };
        let result = apr_import(input.as_str(), output.as_str(), options);

        assert!(
            result.is_ok(),
            "Force should bypass validation: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_convert_nan_fails() {
        let dir = scratch_dir();
        let input = path_in(&dir, "test_nan_input.safetensors");
        let output = path_in(&dir, "test_nan_output.apr");

        let mut tensors = BTreeMap::new();
        tensors.insert(
            "test.weight".to_string(),
            (vec![1.0, f32::NAN, 3.0], vec![3]),
        );

        create_test_safetensors(Path::new(&input), &tensors);

        let options = ImportOptions::default();
        let result = apr_import(input.as_str(), output.as_str(), options);

        assert!(result.is_err(), "NaN should fail validation");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("NaN"), "Error should mention NaN: {err}");
    }

    #[test]
    fn test_convert_nonexistent_file() {
        let dir = scratch_dir();
        let input = path_in(&dir, "nonexistent_model.safetensors");
        let output = path_in(&dir, "out.apr");

        let result = apr_import(input.as_str(), output.as_str(), ImportOptions::default());
        assert!(result.is_err(), "Nonexistent file should fail");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("not found") || err.contains("No such file"),
            "Error should mention file not found: {err}"
        );
    }

    #[test]
    fn test_convert_unsupported_format() {
        let dir = scratch_dir();
        let input = path_in(&dir, "test_bad_format.gguf");
        let output = path_in(&dir, "out.apr");
        fs::write(&input, b"test").expect("Failed to create test file");

        let result = apr_import(input.as_str(), output.as_str(), ImportOptions::default());
        assert!(result.is_err(), "Unsupported format should fail");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("GGUF") || err.contains("not yet"),
            "Error should mention unsupported: {err}"
        );
    }

    #[test]
    fn test_name_mapping_whisper() {
        use crate::format::v2::AprV2Reader;

        let dir = scratch_dir();
        let input = path_in(&dir, "test_whisper_input.safetensors");
        let output = path_in(&dir, "test_whisper_output.apr");

        // HuggingFace Whisper checkpoints prefix every tensor with `model.`.
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "model.encoder.conv1.weight".to_string(),
            (vec![0.01f32; 100], vec![10, 10]),
        );
        tensors.insert(
            "model.decoder.layer_norm.weight".to_string(),
            (vec![1.0f32; 384], vec![384]),
        );

        create_test_safetensors(Path::new(&input), &tensors);

        let options = ImportOptions {
            architecture: Architecture::Whisper,
            ..Default::default()
        };
        let result = apr_import(input.as_str(), output.as_str(), options);
        assert!(
            result.is_ok(),
            "Whisper mapping should work: {:?}",
            result.err()
        );

        let data = fs::read(&output).expect("Failed to read output");
        let reader = AprV2Reader::from_bytes(&data).expect("Failed to parse APR v2");
        let tensor_names = reader.tensor_names();

        assert!(
            tensor_names.contains(&"encoder.conv1.weight"),
            "Should strip 'model.' prefix, got: {:?}",
            tensor_names
        );
        assert!(
            tensor_names.contains(&"decoder.layer_norm.weight"),
            "Should strip 'model.' prefix, got: {:?}",
            tensor_names
        );
    }
}
2472
#[cfg(test)]
mod tests_tensor_stats {
    use super::*;

    #[test]
    fn test_compute_stats_basic() {
        let values = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let stats = compute_tensor_stats("test", &values);

        assert_eq!(stats.count, 5);
        assert!((stats.mean - 3.0).abs() < 0.001, "Mean should be 3.0");
        assert_eq!((stats.min, stats.max), (1.0, 5.0));
        assert_eq!((stats.nan_count, stats.inf_count), (0, 0));
    }

    #[test]
    fn test_compute_stats_with_nan() {
        let values = vec![1.0f32, f32::NAN, 3.0];
        let stats = compute_tensor_stats("test", &values);

        assert_eq!(stats.nan_count, 1);
        // NaNs are counted in `count` but excluded from the mean.
        assert_eq!(stats.count, 3);
        assert!(
            (stats.mean - 2.0).abs() < 0.001,
            "Mean should be 2.0 (from valid values)"
        );
    }

    #[test]
    fn test_compute_stats_with_inf() {
        let values = vec![1.0f32, f32::INFINITY, f32::NEG_INFINITY, 3.0];
        let stats = compute_tensor_stats("test", &values);

        assert_eq!(stats.inf_count, 2);
        assert!(
            (stats.mean - 2.0).abs() < 0.001,
            "Mean should be 2.0 (from valid values)"
        );
    }

    #[test]
    fn test_compute_stats_empty() {
        let values: Vec<f32> = Vec::new();
        let stats = compute_tensor_stats("test", &values);

        assert_eq!(stats.count, 0);
        assert_eq!((stats.mean, stats.std), (0.0, 0.0));
    }

    #[test]
    fn test_compute_stats_all_zeros() {
        let values = vec![0.0f32; 100];
        let stats = compute_tensor_stats("test", &values);

        assert_eq!(stats.zero_count, 100);
        assert_eq!(stats.mean, 0.0);
    }
}
2535
#[cfg(test)]
mod tests_quantization {
    use super::*;

    /// Small symmetric sample covering both signs and zero.
    fn sample() -> Vec<f32> {
        vec![1.0f32, -1.0, 0.5, -0.5, 0.0]
    }

    /// Asserts equal lengths and that every element is within `tolerance` of its source.
    fn assert_close(original: &[f32], quantized: &[f32], tolerance: f32, what: &str) {
        assert_eq!(quantized.len(), original.len());
        for (orig, quant) in original.iter().zip(quantized) {
            assert!((orig - quant).abs() < tolerance, "{what}");
        }
    }

    #[test]
    fn test_quantize_int8_basic() {
        let data = sample();
        let quantized = quantize_int8(&data);
        assert_close(&data, &quantized, 0.02, "Quantization error too large");
    }

    #[test]
    fn test_quantize_int8_preserves_zeros() {
        let data = vec![0.0f32; 10];
        let quantized = quantize_int8(&data);
        assert!(
            quantized.iter().all(|&v| v == 0.0),
            "Zeros should remain zeros"
        );
    }

    #[test]
    fn test_quantize_int8_empty() {
        let data: Vec<f32> = Vec::new();
        assert!(quantize_int8(&data).is_empty());
    }

    #[test]
    fn test_quantize_int4_basic() {
        let data = sample();
        let quantized = quantize_int4(&data);
        assert_close(&data, &quantized, 0.2, "Int4 quantization error too large");
    }

    #[test]
    fn test_quantize_fp16_basic() {
        let data = vec![1.0f32, -1.0, 0.5, -0.5, 0.0, 0.123456789];
        let quantized = quantize_fp16(&data);

        assert_eq!(quantized.len(), data.len());
        // These values are exactly representable in half precision.
        assert_eq!(quantized[0], 1.0);
        assert_eq!(quantized[1], -1.0);
        assert_eq!(quantized[4], 0.0);
    }

    #[test]
    fn test_quantize_tensors_int8() {
        let mut tensors = BTreeMap::new();
        tensors.insert("test".to_string(), (vec![1.0f32, -1.0, 0.5], vec![3]));

        let result = quantize_tensors(&tensors, &QuantizationType::Int8).unwrap();

        assert_eq!(result.len(), 1);
        let (data, shape) = result.get("test").expect("tensor 'test' should survive quantization");
        assert_eq!(shape, &vec![3]);
        assert_eq!(data.len(), 3);
    }
}
2610
#[cfg(test)]
mod tests_convert {
    use super::*;

    /// Writes a small three-tensor model (shapes consistent with data) to `path`.
    fn create_test_model(path: &Path) {
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "encoder.weight".to_string(),
            (vec![0.01f32; 1000], vec![100, 10]),
        );
        tensors.insert("encoder.bias".to_string(), (vec![0.0f32; 100], vec![100]));
        tensors.insert(
            "decoder.weight".to_string(),
            (vec![0.02f32; 500], vec![50, 10]),
        );
        save_safetensors(path, &tensors).expect("Failed to create test model");
    }

    /// Creates a scratch directory that is removed when dropped, even if the
    /// test panics. This replaces fixed `/tmp` paths, which are non-portable
    /// and leaked files on failure.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("Failed to create temporary directory")
    }

    #[test]
    fn test_convert_no_quantization() {
        let dir = scratch_dir();
        let input = dir.path().join("test_convert_input.safetensors");
        let output = dir.path().join("test_convert_output.apr");

        create_test_model(&input);

        let options = ConvertOptions::default();
        let result = apr_convert(input.as_path(), output.as_path(), options);

        assert!(
            result.is_ok(),
            "Convert without quantization should work: {:?}",
            result.err()
        );
        let report = result.unwrap();
        assert_eq!(report.tensor_count, 3);
        assert!(report.quantization.is_none());
    }

    #[test]
    fn test_convert_with_int8_quantization() {
        let dir = scratch_dir();
        let input = dir.path().join("test_convert_int8_input.safetensors");
        let output = dir.path().join("test_convert_int8_output.apr");

        create_test_model(&input);

        let options = ConvertOptions {
            quantize: Some(QuantizationType::Int8),
            ..Default::default()
        };
        let result = apr_convert(input.as_path(), output.as_path(), options);

        assert!(
            result.is_ok(),
            "Int8 quantization should work: {:?}",
            result.err()
        );
        let report = result.unwrap();
        assert_eq!(report.quantization, Some(QuantizationType::Int8));
        assert_eq!(report.tensor_count, 3);
    }

    #[test]
    fn test_convert_with_fp16_quantization() {
        let dir = scratch_dir();
        let input = dir.path().join("test_convert_fp16_input.safetensors");
        let output = dir.path().join("test_convert_fp16_output.apr");

        create_test_model(&input);

        let options = ConvertOptions {
            quantize: Some(QuantizationType::Fp16),
            ..Default::default()
        };
        let result = apr_convert(input.as_path(), output.as_path(), options);

        assert!(
            result.is_ok(),
            "FP16 quantization should work: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_convert_nonexistent_file() {
        let dir = scratch_dir();
        let missing = dir.path().join("nonexistent.safetensors");
        let output = dir.path().join("out.apr");

        let options = ConvertOptions::default();
        let result = apr_convert(missing.as_path(), output.as_path(), options);

        assert!(result.is_err(), "Nonexistent file should fail");
    }

    #[test]
    fn test_convert_report_reduction_percent() {
        let report = ConvertReport {
            original_size: 1000,
            converted_size: 250,
            tensor_count: 5,
            quantization: Some(QuantizationType::Int8),
            compression: None,
            reduction_ratio: 4.0,
        };

        // 1000 -> 250 bytes is a 75% reduction.
        assert_eq!(report.reduction_percent(), "75.0%");
    }

    #[test]
    fn test_convert_options_default() {
        let options = ConvertOptions::default();
        assert!(options.quantize.is_none());
        assert!(options.compress.is_none());
        assert!(options.validate);
    }
}
2731
#[cfg(test)]
mod tests_sharded_import {
    use super::*;

    #[test]
    fn test_sharded_index_parse_valid() {
        let json = r#"{
            "metadata": {"total_size": 1000000},
            "weight_map": {
                "encoder.conv1.weight": "model-00001-of-00002.safetensors",
                "encoder.conv2.weight": "model-00001-of-00002.safetensors",
                "decoder.fc.weight": "model-00002-of-00002.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).expect("Valid index should parse");
        // Three tensors spread across two distinct shard files.
        assert_eq!(index.shard_count(), 2);
        assert_eq!(index.tensor_count(), 3);
        assert!(index.total_size().is_some());
    }

    #[test]
    fn test_sharded_index_shard_for_tensor() {
        let json = r#"{
            "weight_map": {
                "encoder.weight": "shard-00001.safetensors",
                "decoder.weight": "shard-00002.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).expect("index should parse");
        assert_eq!(
            index.shard_for_tensor("encoder.weight"),
            Some("shard-00001.safetensors")
        );
        assert_eq!(
            index.shard_for_tensor("decoder.weight"),
            Some("shard-00002.safetensors")
        );
        assert_eq!(index.shard_for_tensor("unknown"), None);
    }

    #[test]
    fn test_sharded_index_tensors_in_shard() {
        let json = r#"{
            "weight_map": {
                "a": "shard1.safetensors",
                "b": "shard1.safetensors",
                "c": "shard2.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).expect("index should parse");
        let in_first = index.tensors_in_shard("shard1.safetensors");
        assert_eq!(in_first.len(), 2);
        assert!(in_first.contains(&"a"));
        assert!(in_first.contains(&"b"));
    }

    #[test]
    fn test_sharded_index_parse_invalid_json() {
        assert!(ShardedIndex::parse("not valid json").is_err());
    }

    #[test]
    fn test_sharded_index_parse_missing_weight_map() {
        assert!(ShardedIndex::parse(r#"{"metadata": {}}"#).is_err());
    }

    #[test]
    fn test_detect_sharded_model_index_exists() {
        let dir = tempfile::tempdir().unwrap();
        // The presence of `<file>.index.json` marks the model as sharded.
        fs::write(
            dir.path().join("model.safetensors.index.json"),
            r#"{"weight_map": {"a": "shard.safetensors"}}"#,
        )
        .unwrap();

        assert!(detect_sharded_model(dir.path(), "model.safetensors").is_some());
    }

    #[test]
    fn test_detect_sharded_model_single_file() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("model.safetensors"), [0u8; 8]).unwrap();

        assert!(
            detect_sharded_model(dir.path(), "model.safetensors").is_none(),
            "Single file should not be sharded"
        );
    }

    #[test]
    fn test_sharded_index_shard_files_sorted() {
        let json = r#"{
            "weight_map": {
                "a": "model-00002-of-00003.safetensors",
                "b": "model-00001-of-00003.safetensors",
                "c": "model-00003-of-00003.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).expect("index should parse");
        let shards = index.shard_files();
        let expected = [
            "model-00001-of-00003.safetensors",
            "model-00002-of-00003.safetensors",
            "model-00003-of-00003.safetensors",
        ];
        for (position, name) in expected.iter().enumerate() {
            assert_eq!(shards[position], *name);
        }
    }
}
2845
#[cfg(test)]
mod tests_import_errors {
    // Tests for `ImportError`:
    // - the content of its `Display` messages;
    // - how `parse_import_error` classifies raw error text (only with the
    //   `hf-hub-integration` feature);
    // - conversion into `AprenderError`.
    use super::*;

    #[test]
    fn test_import_error_not_found_message() {
        let err = ImportError::NotFound {
            resource: "openai/whisper-tiny".to_string(),
            status: 404,
        };
        let msg = err.to_string();
        assert!(msg.contains("404"), "Should include status code");
        assert!(msg.contains("whisper-tiny"), "Should include resource");
    }

    #[test]
    fn test_import_error_rate_limited_message() {
        let err = ImportError::RateLimited {
            retry_after: Some(60),
        };
        let msg = err.to_string();
        assert!(msg.to_lowercase().contains("rate"), "Should mention rate limit");
        assert!(msg.contains("60"), "Should include retry time");
    }

    #[test]
    fn test_import_error_auth_required_message() {
        let err = ImportError::AuthRequired {
            resource: "meta-llama/Llama-2-7b".to_string(),
        };
        let msg = err.to_string();
        assert!(msg.contains("HF_TOKEN"), "Should suggest HF_TOKEN");
        assert!(msg.contains("Llama-2-7b"), "Should include resource");
    }

    #[test]
    fn test_import_error_actionable_suggestions() {
        let err = ImportError::NotFound {
            resource: "openai/whisper-tiny".to_string(),
            status: 404,
        };

        // The message should tell the user what to do next, not just what failed.
        let msg = err.to_string();
        assert!(
            msg.contains("Fix:") || msg.contains("check") || msg.contains("verify"),
            "Error should be actionable"
        );
    }

    #[test]
    fn test_import_error_sharding_oom() {
        let err = ImportError::ShardingRequired {
            model_size: 14_000_000_000,
            shard_count: 7,
        };
        let msg = err.to_string();
        assert!(msg.contains("14"), "Should include size");
        assert!(msg.contains("7"), "Should include shard count");
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_404() {
        let err = parse_import_error("HTTP 404: Repository not found", "openai/whisper-tiny");
        match err {
            ImportError::NotFound { resource, status } => {
                assert_eq!(resource, "openai/whisper-tiny");
                assert_eq!(status, 404);
            }
            _ => panic!("Expected NotFound error, got {:?}", err),
        }
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_not_found_text() {
        // Textual "does not exist" (no status code) should still classify as NotFound.
        let err = parse_import_error("The requested resource does not exist", "test/model");
        assert!(
            matches!(err, ImportError::NotFound { .. }),
            "Expected NotFound error, got {:?}",
            err
        );
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_401() {
        let err = parse_import_error("HTTP 401: Unauthorized access", "meta-llama/Llama-2-7b");
        match err {
            ImportError::AuthRequired { resource } => {
                assert_eq!(resource, "meta-llama/Llama-2-7b");
            }
            _ => panic!("Expected AuthRequired error, got {:?}", err),
        }
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_gated_model() {
        // Gated-model wording (no 401 code) should still classify as AuthRequired.
        let err = parse_import_error(
            "This model is gated. Access requires acceptance.",
            "meta-llama/Llama-2-7b",
        );
        assert!(
            matches!(err, ImportError::AuthRequired { .. }),
            "Expected AuthRequired error, got {:?}",
            err
        );
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_429() {
        let err = parse_import_error(
            "HTTP 429: Too many requests. Retry after 60 seconds.",
            "test/model",
        );
        match err {
            ImportError::RateLimited { retry_after } => {
                assert_eq!(retry_after, Some(60));
            }
            _ => panic!("Expected RateLimited error, got {:?}", err),
        }
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_rate_limit_no_retry() {
        // Without a retry hint in the text, `retry_after` must be `None`.
        let err = parse_import_error("Rate limit exceeded", "test/model");
        match err {
            ImportError::RateLimited { retry_after } => {
                assert_eq!(retry_after, None);
            }
            _ => panic!("Expected RateLimited error, got {:?}", err),
        }
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_generic() {
        // Unrecognised text falls back to DownloadFailed carrying the raw message.
        let err = parse_import_error("Connection timeout", "test/model");
        match err {
            ImportError::DownloadFailed { source, reason } => {
                assert_eq!(source, "test/model");
                assert_eq!(reason, "Connection timeout");
            }
            _ => panic!("Expected DownloadFailed error, got {:?}", err),
        }
    }

    #[test]
    fn test_import_error_from_conversion() {
        let import_err = ImportError::NotFound {
            resource: "test".to_string(),
            status: 404,
        };
        // The `From` conversion must preserve the status code and resource in the message.
        let aprender_err: AprenderError = import_err.into();
        let msg = aprender_err.to_string();
        assert!(msg.contains("404"));
        assert!(msg.contains("test"));
    }
}