aprender/format/
converter.rs

1//! APR Converter Module - Import Pipeline
2//!
3//! Implements Section 13 of APR-SPEC.md: Import/Convert Pipeline
4//!
5//! Supports:
6//! - HuggingFace Hub downloads (hf://org/repo)
7//! - SafeTensors conversion
8//! - Inline validation during conversion
9//! - Quantization and compression
10
11use crate::error::{AprenderError, Result};
12use crate::format::v2::{AprV2Metadata, AprV2Writer};
13use crate::format::validation::{AprValidator, TensorStats, ValidationReport};
14use crate::format::Compression;
15use crate::serialization::safetensors::{extract_tensor, load_safetensors, save_safetensors};
16use std::collections::BTreeMap;
17use std::fs;
18use std::io::Write;
19use std::path::{Path, PathBuf};
20
21// HF Hub integration is used via hf_hub::api::sync::ApiBuilder in download_from_hf()
22
23// ============================================================================
24// Source Parsing
25// ============================================================================
26
27/// Parsed source location
28#[derive(Debug, Clone, PartialEq)]
29pub enum Source {
30    /// HuggingFace Hub: hf://org/repo or hf://org/repo/file.safetensors
31    HuggingFace {
32        org: String,
33        repo: String,
34        file: Option<String>,
35    },
36    /// Local file path
37    Local(PathBuf),
38    /// HTTP/HTTPS URL
39    Url(String),
40}
41
42impl Source {
43    /// Parse a source string into a Source enum
44    pub fn parse(source: &str) -> Result<Self> {
45        if source.starts_with("hf://") {
46            Self::parse_hf(source)
47        } else if source.starts_with("http://") || source.starts_with("https://") {
48            Ok(Self::Url(source.to_string()))
49        } else {
50            Ok(Self::Local(PathBuf::from(source)))
51        }
52    }
53
54    fn parse_hf(source: &str) -> Result<Self> {
55        let path = source.strip_prefix("hf://").unwrap_or(source);
56        let parts: Vec<&str> = path.split('/').collect();
57
58        if parts.len() < 2 {
59            return Err(AprenderError::FormatError {
60                message: format!("Invalid HuggingFace source: {source}. Expected hf://org/repo"),
61            });
62        }
63
64        let org = parts[0].to_string();
65        let repo = parts[1].to_string();
66        let file = if parts.len() > 2 {
67            Some(parts[2..].join("/"))
68        } else {
69            None
70        };
71
72        Ok(Self::HuggingFace { org, repo, file })
73    }
74
75    /// Get the default model file for this source
76    pub fn default_file(&self) -> &str {
77        match self {
78            Self::HuggingFace { file: Some(f), .. } => f,
79            Self::HuggingFace { file: None, .. } => "model.safetensors",
80            Self::Local(p) => p.to_str().unwrap_or("model.safetensors"),
81            Self::Url(u) => u.rsplit('/').next().unwrap_or("model.safetensors"),
82        }
83    }
84}
85
86// ============================================================================
87// Architecture / Name Mapping
88// ============================================================================
89
/// Model architecture, used to select how source tensor names are mapped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Architecture {
    /// Auto-detect from tensor names (the default)
    #[default]
    Auto,
    /// OpenAI Whisper
    Whisper,
    /// Meta LLaMA
    Llama,
    /// Google BERT
    Bert,
}

impl Architecture {
    /// Map a source tensor name to its APR canonical name.
    ///
    /// Every architecture currently removes at most one leading namespace
    /// prefix and keeps the remainder unchanged:
    /// * Auto, Whisper, and LLaMA strip `model.`;
    /// * BERT strips `bert.`.
    pub fn map_name(&self, source_name: &str) -> String {
        let mapper: fn(&str) -> String = match self {
            Self::Auto => Self::auto_map_name,
            Self::Whisper => Self::whisper_map_name,
            Self::Llama => Self::llama_map_name,
            Self::Bert => Self::bert_map_name,
        };
        mapper(source_name)
    }

    /// Return `name` without one leading `prefix` (if present), as an owned string.
    fn without_prefix(name: &str, prefix: &str) -> String {
        match name.strip_prefix(prefix) {
            Some(rest) => String::from(rest),
            None => String::from(name),
        }
    }

    /// Auto mode: only the common `model.` prefix is removed.
    fn auto_map_name(name: &str) -> String {
        Self::without_prefix(name, "model.")
    }

    /// HuggingFace Whisper checkpoints nest tensors under `model.`.
    fn whisper_map_name(name: &str) -> String {
        Self::without_prefix(name, "model.")
    }

    /// LLaMA checkpoints nest tensors under `model.` (e.g. `model.layers.`).
    fn llama_map_name(name: &str) -> String {
        Self::without_prefix(name, "model.")
    }

    /// BERT checkpoints nest tensors under `bert.`.
    fn bert_map_name(name: &str) -> String {
        Self::without_prefix(name, "bert.")
    }
}
139
140// ============================================================================
141// Tensor Expectations
142// ============================================================================
143
144/// Expected statistics for a tensor type
145#[derive(Debug, Clone)]
146pub struct TensorExpectation {
147    /// Expected mean range (min, max)
148    pub mean_range: (f32, f32),
149    /// Expected std range (min, max)
150    pub std_range: Option<(f32, f32)>,
151    /// Description for error messages
152    pub description: &'static str,
153}
154
155impl TensorExpectation {
156    /// LayerNorm weight: gamma initialized to ~1.0
157    pub const LAYER_NORM_WEIGHT: Self = Self {
158        mean_range: (0.5, 3.0),
159        std_range: Some((0.0, 2.0)),
160        description: "LayerNorm weight (gamma)",
161    };
162
163    /// LayerNorm bias: beta initialized to ~0.0
164    pub const LAYER_NORM_BIAS: Self = Self {
165        mean_range: (-0.5, 0.5),
166        std_range: Some((0.0, 1.0)),
167        description: "LayerNorm bias (beta)",
168    };
169
170    /// Linear/Attention weight: Xavier/He initialized, mean ~0
171    pub const LINEAR_WEIGHT: Self = Self {
172        mean_range: (-0.1, 0.1),
173        std_range: None,
174        description: "Linear/Attention weight",
175    };
176
177    /// Embedding weight: varies by initialization
178    pub const EMBEDDING: Self = Self {
179        mean_range: (-1.0, 1.0),
180        std_range: None,
181        description: "Embedding",
182    };
183
184    /// Get expectation for a tensor name
185    pub fn for_tensor(name: &str) -> Option<Self> {
186        if name.contains("layer_norm") || name.contains("ln_") {
187            if name.ends_with(".weight") || name.ends_with(".gamma") {
188                return Some(Self::LAYER_NORM_WEIGHT);
189            }
190            if name.ends_with(".bias") || name.ends_with(".beta") {
191                return Some(Self::LAYER_NORM_BIAS);
192            }
193        }
194
195        if name.contains("embed") {
196            return Some(Self::EMBEDDING);
197        }
198
199        if name.ends_with(".weight") {
200            return Some(Self::LINEAR_WEIGHT);
201        }
202
203        None
204    }
205
206    /// Check if stats match expectation
207    pub fn check(&self, stats: &TensorStats) -> Result<()> {
208        let (min_mean, max_mean) = self.mean_range;
209
210        if stats.mean < min_mean || stats.mean > max_mean {
211            return Err(AprenderError::FormatError {
212                message: format!(
213                    "{}: mean={:.4} outside expected range [{:.1}, {:.1}]",
214                    self.description, stats.mean, min_mean, max_mean
215                ),
216            });
217        }
218
219        if let Some((min_std, max_std)) = self.std_range {
220            if stats.std < min_std || stats.std > max_std {
221                return Err(AprenderError::FormatError {
222                    message: format!(
223                        "{}: std={:.4} outside expected range [{:.1}, {:.1}]",
224                        self.description, stats.std, min_std, max_std
225                    ),
226                });
227            }
228        }
229
230        Ok(())
231    }
232}
233
234// ============================================================================
235// Validation Config
236// ============================================================================
237
/// Validation strictness configuration.
///
/// Defaults to [`ValidationConfig::Strict`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ValidationConfig {
    /// No validation
    None,
    /// Basic checks (NaN, Inf only)
    Basic,
    /// Full statistical validation
    #[default]
    Strict,
}

impl ValidationConfig {
    /// Create strict validation config (same as `ValidationConfig::default()`).
    #[must_use]
    pub const fn strict() -> Self {
        Self::Strict
    }
}
261
262// ============================================================================
263// Import Options
264// ============================================================================
265
/// Numeric precision that the import pipeline can quantize tensors to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantizationType {
    /// Quantize to 8-bit integers
    Int8,
    /// Quantize to 4-bit integers
    Int4,
    /// Store as 16-bit floating point
    Fp16,
}
276
277/// Options for the import pipeline
278#[derive(Debug, Clone)]
279pub struct ImportOptions {
280    /// Target architecture for name mapping
281    pub architecture: Architecture,
282    /// Validation configuration
283    pub validation: ValidationConfig,
284    /// Quantization (None = keep original precision)
285    pub quantize: Option<QuantizationType>,
286    /// Compression algorithm
287    pub compress: Option<Compression>,
288    /// Force import even if validation fails
289    pub force: bool,
290    /// Cache downloaded files
291    pub cache: bool,
292}
293
294impl Default for ImportOptions {
295    fn default() -> Self {
296        Self {
297            architecture: Architecture::Auto,
298            validation: ValidationConfig::Strict,
299            quantize: None,
300            compress: None,
301            force: false,
302            cache: true,
303        }
304    }
305}
306
307// ============================================================================
308// Import Error
309// ============================================================================
310
311/// Import-specific errors (GH-129: actionable error messages)
312#[derive(Debug, Clone)]
313pub enum ImportError {
314    /// Download failed
315    DownloadFailed { source: String, reason: String },
316    /// Unsupported format
317    UnsupportedFormat { extension: String },
318    /// Tensor validation failed
319    ValidationFailed { name: String, reason: String },
320    /// Unknown tensor name
321    UnknownTensor { source_name: String },
322    /// Missing required tensor
323    MissingTensor { name: String },
324    /// Resource not found (404)
325    NotFound { resource: String, status: u16 },
326    /// Rate limited by server
327    RateLimited { retry_after: Option<u64> },
328    /// Authentication required (gated model)
329    AuthRequired { resource: String },
330    /// Model requires sharded loading (GH-127)
331    ShardingRequired { model_size: u64, shard_count: usize },
332}
333
334impl std::fmt::Display for ImportError {
335    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336        match self {
337            Self::DownloadFailed { source, reason } => {
338                write!(f, "Download failed: {source} - {reason}")
339            }
340            Self::UnsupportedFormat { extension } => {
341                write!(f, "Unsupported format: {extension}")
342            }
343            Self::ValidationFailed { name, reason } => {
344                write!(f, "Tensor validation failed: {name} - {reason}")
345            }
346            Self::UnknownTensor { source_name } => {
347                write!(f, "Unknown tensor: {source_name}")
348            }
349            Self::MissingTensor { name } => {
350                write!(f, "Missing required tensor: {name}")
351            }
352            // GH-129: Actionable error messages
353            Self::NotFound { resource, status } => {
354                write!(
355                    f,
356                    "Resource not found ({status}): {resource}. \
357                     Fix: verify the model name exists on huggingface.co/models"
358                )
359            }
360            Self::RateLimited { retry_after } => {
361                if let Some(secs) = retry_after {
362                    write!(
363                        f,
364                        "Rate limited by server. Retry after {secs} seconds. \
365                         Fix: wait and retry, or use --cache to avoid re-downloads"
366                    )
367                } else {
368                    write!(
369                        f,
370                        "Rate limited by server. \
371                         Fix: wait a few minutes and retry"
372                    )
373                }
374            }
375            Self::AuthRequired { resource } => {
376                write!(
377                    f,
378                    "Authentication required for {resource}. \
379                     Fix: set HF_TOKEN environment variable with your HuggingFace token"
380                )
381            }
382            Self::ShardingRequired {
383                model_size,
384                shard_count,
385            } => {
386                let size_gb = *model_size as f64 / 1_000_000_000.0;
387                write!(
388                    f,
389                    "Model too large ({size_gb:.1} GB, {shard_count} shards) for single-file loading. \
390                     Fix: use streaming import with --sharded flag"
391                )
392            }
393        }
394    }
395}
396
397impl std::error::Error for ImportError {}
398
399impl From<ImportError> for AprenderError {
400    fn from(err: ImportError) -> Self {
401        AprenderError::FormatError {
402            message: err.to_string(),
403        }
404    }
405}
406
/// Parse error message to detect specific error types (GH-129)
///
/// Classification is keyword-based and checked in priority order:
/// 1. not found (404);
/// 2. authentication (401/403/gated);
/// 3. rate limiting (429);
/// 4. otherwise a generic [`ImportError::DownloadFailed`] carrying the
///    original message.
///
/// HTTP status codes and "gated" are matched as whole words. This keeps, for
/// example, a URL segment like `00404` or the word "propagated" from causing
/// a false match.
#[cfg(feature = "hf-hub-integration")]
fn parse_import_error(error_msg: &str, resource: &str) -> ImportError {
    /// Split `text` into non-empty ASCII-alphanumeric words.
    fn words<'a>(text: &'a str) -> impl Iterator<Item = &'a str> + 'a {
        text.split(|c: char| !c.is_ascii_alphanumeric())
            .filter(|w| !w.is_empty())
    }

    /// Extract the first whole number following the word "retry".
    ///
    /// Tolerates separators, punctuation, and unit suffixes, e.g.
    /// `retry-after: 30`, `retry after 30s`, `(retry in 30.)`.
    fn retry_after_secs(msg_lower: &str) -> Option<u64> {
        let pos = msg_lower.find("retry")?;
        msg_lower[pos..]
            .split(|c: char| c.is_whitespace() || c == ':' || c == '=')
            .map(|tok| {
                tok.trim_start_matches(|c: char| !c.is_ascii_alphanumeric())
                    .trim_end_matches(|c: char| !c.is_ascii_digit())
            })
            .find_map(|tok| tok.parse::<u64>().ok())
    }

    let msg_lower = error_msg.to_lowercase();
    let has_code = |code: &str| words(&msg_lower).any(|w| w == code);

    // Check for 404 / not found
    if has_code("404")
        || msg_lower.contains("not found")
        || msg_lower.contains("does not exist")
        || msg_lower.contains("no such")
    {
        return ImportError::NotFound {
            resource: resource.to_string(),
            status: 404,
        };
    }

    // Check for authentication / 401 / 403
    if has_code("401")
        || has_code("403")
        || msg_lower.contains("unauthorized")
        || msg_lower.contains("forbidden")
        || words(&msg_lower).any(|w| w.starts_with("gated"))
        || msg_lower.contains("access denied")
    {
        return ImportError::AuthRequired {
            resource: resource.to_string(),
        };
    }

    // Check for rate limiting / 429
    if has_code("429")
        || msg_lower.contains("rate limit")
        || msg_lower.contains("too many requests")
    {
        return ImportError::RateLimited {
            retry_after: retry_after_secs(&msg_lower),
        };
    }

    // Default to download failed
    ImportError::DownloadFailed {
        source: resource.to_string(),
        reason: error_msg.to_string(),
    }
}
459
460// ============================================================================
461// GH-127: Sharded Model Support
462// ============================================================================
463
464/// Parsed sharded model index (model.safetensors.index.json)
465///
466/// HuggingFace uses this format for large models split across multiple shards.
467/// Example: Llama-2-7b has 2 shards, Llama-2-70b has 15 shards.
468#[derive(Debug, Clone)]
469pub struct ShardedIndex {
470    /// Map of tensor name → shard filename
471    weight_map: std::collections::HashMap<String, String>,
472    /// Optional total size in bytes
473    total_size: Option<u64>,
474}
475
476impl ShardedIndex {
477    /// Parse a sharded index from JSON string
478    ///
479    /// # Example JSON format
480    /// ```json
481    /// {
482    ///   "metadata": {"total_size": 14000000000},
483    ///   "weight_map": {
484    ///     "model.encoder.weight": "model-00001-of-00002.safetensors",
485    ///     "model.decoder.weight": "model-00002-of-00002.safetensors"
486    ///   }
487    /// }
488    /// ```
489    pub fn parse(json: &str) -> Result<Self> {
490        // Minimal JSON parsing without serde dependency
491        // Look for "weight_map" key and parse the object
492
493        let json = json.trim();
494        if !json.starts_with('{') || !json.ends_with('}') {
495            return Err(AprenderError::FormatError {
496                message: "Invalid JSON: expected object".to_string(),
497            });
498        }
499
500        // Find weight_map section
501        let weight_map_start =
502            json.find("\"weight_map\"")
503                .ok_or_else(|| AprenderError::FormatError {
504                    message: "Missing 'weight_map' key in index.json".to_string(),
505                })?;
506
507        // Parse weight_map object
508        let after_key = &json[weight_map_start + 12..]; // Skip "weight_map"
509        let obj_start = after_key
510            .find('{')
511            .ok_or_else(|| AprenderError::FormatError {
512                message: "Invalid weight_map: expected object".to_string(),
513            })?;
514
515        let obj_content = &after_key[obj_start..];
516        let mut weight_map = std::collections::HashMap::new();
517        let mut depth = 0;
518        let mut obj_end = 0;
519
520        for (i, c) in obj_content.char_indices() {
521            match c {
522                '{' => depth += 1,
523                '}' => {
524                    depth -= 1;
525                    if depth == 0 {
526                        obj_end = i;
527                        break;
528                    }
529                }
530                _ => {}
531            }
532        }
533
534        let inner = &obj_content[1..obj_end];
535
536        // Parse key-value pairs: "tensor_name": "shard_file"
537        for pair in inner.split(',') {
538            let pair = pair.trim();
539            if pair.is_empty() {
540                continue;
541            }
542
543            let parts: Vec<&str> = pair.splitn(2, ':').collect();
544            if parts.len() == 2 {
545                let key = parts[0].trim().trim_matches('"');
546                let val = parts[1].trim().trim_matches('"');
547                if !key.is_empty() && !val.is_empty() {
548                    weight_map.insert(key.to_string(), val.to_string());
549                }
550            }
551        }
552
553        // Parse optional total_size from metadata
554        let total_size = json.find("\"total_size\"").and_then(|pos| {
555            let after = &json[pos + 12..];
556            let colon = after.find(':')?;
557            let after_colon = after[colon + 1..].trim_start();
558            let end = after_colon.find(|c: char| !c.is_ascii_digit())?;
559            after_colon[..end].parse::<u64>().ok()
560        });
561
562        Ok(Self {
563            weight_map,
564            total_size,
565        })
566    }
567
568    /// Number of unique shard files
569    #[must_use]
570    pub fn shard_count(&self) -> usize {
571        let unique: std::collections::HashSet<_> = self.weight_map.values().collect();
572        unique.len()
573    }
574
575    /// Number of tensors in the index
576    #[must_use]
577    pub fn tensor_count(&self) -> usize {
578        self.weight_map.len()
579    }
580
581    /// Total model size in bytes (if available)
582    #[must_use]
583    pub fn total_size(&self) -> Option<u64> {
584        self.total_size
585    }
586
587    /// Get the shard file containing a specific tensor
588    #[must_use]
589    pub fn shard_for_tensor(&self, tensor_name: &str) -> Option<&str> {
590        self.weight_map.get(tensor_name).map(String::as_str)
591    }
592
593    /// Get all tensor names in a specific shard
594    #[must_use]
595    pub fn tensors_in_shard(&self, shard_file: &str) -> Vec<&str> {
596        self.weight_map
597            .iter()
598            .filter(|(_, v)| v.as_str() == shard_file)
599            .map(|(k, _)| k.as_str())
600            .collect()
601    }
602
603    /// Get sorted list of shard files
604    #[must_use]
605    pub fn shard_files(&self) -> Vec<&str> {
606        let mut files: Vec<_> = self
607            .weight_map
608            .values()
609            .map(String::as_str)
610            .collect::<std::collections::HashSet<_>>()
611            .into_iter()
612            .collect();
613        files.sort_unstable();
614        files
615    }
616}
617
/// Look for a sharded-model index next to `base_name` inside `dir`.
///
/// A file named `{base_name}.index.json` (e.g. `model.safetensors.index.json`)
/// signals a sharded checkpoint. Its path is returned when it exists;
/// otherwise the result is `None`.
#[must_use]
pub fn detect_sharded_model(dir: &Path, base_name: &str) -> Option<PathBuf> {
    let candidate = dir.join(format!("{base_name}.index.json"));
    candidate.exists().then_some(candidate)
}
632
633// ============================================================================
634// Converter
635// ============================================================================
636
637/// APR Converter with builder pattern
638#[derive(Debug)]
639pub struct AprConverter {
640    source: Option<Source>,
641    architecture: Architecture,
642    validation: ValidationConfig,
643    quantize: Option<QuantizationType>,
644    compress: Option<Compression>,
645}
646
647impl AprConverter {
648    /// Create a new converter
649    pub fn new() -> Self {
650        Self {
651            source: None,
652            architecture: Architecture::Auto,
653            validation: ValidationConfig::Strict,
654            quantize: None,
655            compress: None,
656        }
657    }
658
659    /// Set the source
660    pub fn source(mut self, source: &str) -> Result<Self> {
661        self.source = Some(Source::parse(source)?);
662        Ok(self)
663    }
664
665    /// Set the architecture
666    pub fn architecture(mut self, arch: Architecture) -> Self {
667        self.architecture = arch;
668        self
669    }
670
671    /// Set validation config
672    pub fn validate(mut self, config: ValidationConfig) -> Self {
673        self.validation = config;
674        self
675    }
676
677    /// Set quantization
678    pub fn quantize(mut self, quant: QuantizationType) -> Self {
679        self.quantize = Some(quant);
680        self
681    }
682
683    /// Set compression
684    pub fn compress(mut self, comp: Compression) -> Self {
685        self.compress = Some(comp);
686        self
687    }
688
689    /// Run the conversion
690    pub fn convert(self) -> Result<Vec<u8>> {
691        let source = self.source.ok_or_else(|| AprenderError::FormatError {
692            message: "No source specified".to_string(),
693        })?;
694
695        // NOTE: Full conversion pipeline is tracked in GH-80 (metaheuristics milestone)
696        // Current limitation: Returns error for unsupported sources
697        Err(AprenderError::FormatError {
698            message: format!(
699                "Conversion from {:?} not yet implemented - see GH-80",
700                source
701            ),
702        })
703    }
704}
705
706impl Default for AprConverter {
707    fn default() -> Self {
708        Self::new()
709    }
710}
711
712// ============================================================================
713// High-level API
714// ============================================================================
715
716/// Import a model from source to APR format
717///
718/// # Arguments
719/// * `source` - Source path: local file, hf://org/repo, or URL
720/// * `output` - Output APR file path
721/// * `options` - Import configuration
722///
723/// # Returns
724/// * `ValidationReport` with 100-point checklist results
725///
726/// # Example
727/// ```rust,ignore
728/// use aprender::format::{apr_import, ImportOptions, Architecture};
729///
730/// let options = ImportOptions {
731///     architecture: Architecture::Whisper,
732///     ..Default::default()
733/// };
734/// let report = apr_import("model.safetensors", "model.apr", options)?;
735/// println!("Score: {}/100", report.total_score);
736/// ```
737pub fn apr_import<P: AsRef<Path>>(
738    source: &str,
739    output: P,
740    options: ImportOptions,
741) -> Result<ValidationReport> {
742    let parsed_source = Source::parse(source)?;
743    let output_path = output.as_ref();
744
745    // Step 1: Resolve source to local path
746    let local_path = resolve_source(&parsed_source, options.cache)?;
747
748    // Step 2: Detect format and load tensors
749    let tensors = load_source_tensors(&local_path, &options)?;
750
751    // Step 3: Map tensor names to canonical APR names
752    let mapped_tensors = map_tensor_names(&tensors, options.architecture);
753
754    // Step 4: Validate tensors (inline validation)
755    let validation_result = validate_tensors(&mapped_tensors, &options)?;
756
757    // Step 5: Write APR format
758    write_apr_file(&mapped_tensors, output_path, &options)?;
759
760    Ok(validation_result)
761}
762
763/// Resolve a source to a local file path
764fn resolve_source(source: &Source, cache: bool) -> Result<PathBuf> {
765    match source {
766        Source::Local(path) => {
767            if !path.exists() {
768                // GH-129: Use ImportError for actionable message
769                let err = ImportError::NotFound {
770                    resource: path.display().to_string(),
771                    status: 0, // Local file, not HTTP
772                };
773                return Err(AprenderError::from(err));
774            }
775            Ok(path.clone())
776        }
777        Source::HuggingFace { org, repo, file } => {
778            let filename = file.as_deref().unwrap_or("model.safetensors");
779
780            // Check standard cache locations first
781            if cache {
782                if let Some(path) = find_in_cache(org, repo, filename) {
783                    return Ok(path);
784                }
785            }
786
787            // Try to download using hf-hub if feature is enabled
788            #[cfg(feature = "hf-hub-integration")]
789            {
790                let repo_id = format!("{org}/{repo}");
791                match download_from_hf(&repo_id, filename) {
792                    Ok(path) => return Ok(path),
793                    Err(e) => {
794                        // Fall through to manual download instructions
795                        eprintln!("HF download failed: {e}");
796                    }
797                }
798            }
799
800            Err(AprenderError::FormatError {
801                message: format!(
802                    "HuggingFace model not found in cache. Download manually:\n\
803                     huggingface-cli download {org}/{repo} {filename}\n\
804                     Or provide a local path to the SafeTensors file.",
805                ),
806            })
807        }
808        Source::Url(url) => Err(AprenderError::FormatError {
809            message: format!("URL download not yet implemented: {url}"),
810        }),
811    }
812}
813
/// Get the XDG cache base directory, or a fallback.
///
/// Follows the XDG Base Directory specification: `$XDG_CACHE_HOME` is used
/// only when it holds an absolute path. Empty or relative values are ignored
/// in favour of `$HOME/.cache`. When `HOME` is unset or empty, a relative
/// `.cache` is returned. Non-UTF-8 paths are accepted (read via `var_os`).
fn get_xdg_cache_dir() -> PathBuf {
    std::env::var_os("XDG_CACHE_HOME")
        .map(PathBuf::from)
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| {
            std::env::var_os("HOME")
                .filter(|home| !home.is_empty())
                .map_or_else(
                    || PathBuf::from(".cache"),
                    |home| PathBuf::from(home).join(".cache"),
                )
        })
}
825
826/// Get HuggingFace cache directory.
827fn get_hf_cache_dir() -> PathBuf {
828    std::env::var("HF_HOME")
829        .ok()
830        .map(PathBuf::from)
831        .unwrap_or_else(|| {
832            std::env::var("HOME")
833                .map(|h| PathBuf::from(h).join(".cache").join("huggingface"))
834                .unwrap_or_else(|_| PathBuf::from(".cache").join("huggingface"))
835        })
836}
837
/// Look up `filename` in aprender's own HuggingFace download cache.
///
/// The expected location is
/// `<cache_base>/aprender/hf/<org>/<repo>/<filename>`. That path is returned
/// only if it exists.
fn find_in_aprender_cache(
    cache_base: &Path,
    org: &str,
    repo: &str,
    filename: &str,
) -> Option<PathBuf> {
    let mut candidate = cache_base.join("aprender");
    for component in ["hf", org, repo, filename] {
        candidate.push(component);
    }
    if candidate.exists() {
        Some(candidate)
    } else {
        None
    }
}
853
/// Check the HuggingFace hub cache for a file.
///
/// Expected layout:
/// `<cache_base>/hub/models--{org}--{repo}/snapshots/<revision>/<filename>`,
/// where `refs/main` holds the revision of the current `main` snapshot.
///
/// The snapshot named by `refs/main` is preferred, so a stale revision is not
/// picked when several are cached. Otherwise the first snapshot directory
/// containing the file is returned; that order is `read_dir` order, which is
/// arbitrary.
fn find_in_hf_hub_cache(
    cache_base: &Path,
    org: &str,
    repo: &str,
    filename: &str,
) -> Option<PathBuf> {
    let repo_cache = cache_base
        .join("hub")
        .join(format!("models--{org}--{repo}"));
    let snapshots = repo_cache.join("snapshots");

    if let Ok(revision) = fs::read_to_string(repo_cache.join("refs").join("main")) {
        let revision = revision.trim();
        // Only accept a plain revision id so the ref cannot point outside `snapshots`.
        if !revision.is_empty() && revision.chars().all(|c| c.is_ascii_alphanumeric()) {
            let candidate = snapshots.join(revision).join(filename);
            if candidate.exists() {
                return Some(candidate);
            }
        }
    }

    // Fallback: scan snapshots (also covers a missing or stale refs/main).
    fs::read_dir(&snapshots)
        .ok()?
        .flatten()
        .map(|entry| entry.path().join(filename))
        .find(|candidate| candidate.exists())
}
880
881/// Find a model file in standard cache locations
882fn find_in_cache(org: &str, repo: &str, filename: &str) -> Option<PathBuf> {
883    let cache_paths = [get_xdg_cache_dir(), get_hf_cache_dir()];
884
885    for cache_base in &cache_paths {
886        if let Some(path) = find_in_aprender_cache(cache_base, org, repo, filename) {
887            return Some(path);
888        }
889        if let Some(path) = find_in_hf_hub_cache(cache_base, org, repo, filename) {
890            return Some(path);
891        }
892    }
893
894    None
895}
896
/// Download a file from HuggingFace Hub
///
/// Uses `HF_TOKEN` as the access token when it is set to a non-blank value,
/// with surrounding whitespace such as a trailing newline trimmed. Returns
/// the local path reported by `hf_hub` for the downloaded file.
///
/// # Errors
///
/// hf-hub failures are classified by `parse_import_error` (GH-129):
/// * not found;
/// * authentication required;
/// * rate limited;
/// * generic download failure.
#[cfg(feature = "hf-hub-integration")]
fn download_from_hf(repo_id: &str, filename: &str) -> Result<PathBuf> {
    use hf_hub::api::sync::ApiBuilder;

    let resource = format!("{repo_id}/{filename}");
    // GH-129: map hf-hub error text onto actionable ImportError variants.
    let to_apr_error =
        |message: String| AprenderError::from(parse_import_error(&message, &resource));

    let mut builder = ApiBuilder::new();
    // A blank HF_TOKEN is treated as unset rather than passed through as the token.
    if let Some(token) = std::env::var("HF_TOKEN")
        .ok()
        .map(|t| t.trim().to_string())
        .filter(|t| !t.is_empty())
    {
        builder = builder.with_token(Some(token));
    }

    let api = builder.build().map_err(|e| to_apr_error(e.to_string()))?;

    api.model(repo_id.to_string())
        .get(filename)
        .map_err(|e| to_apr_error(e.to_string()))
}
927
928/// Load tensors from source file (SafeTensors format)
929fn load_source_tensors(
930    path: &Path,
931    _options: &ImportOptions,
932) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
933    let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
934
935    match extension {
936        "safetensors" => load_safetensors_tensors(path),
937        "apr" => {
938            // Already APR format - extract tensors
939            Err(AprenderError::FormatError {
940                message: "Cannot import from APR format - use direct loading instead".to_string(),
941            })
942        }
943        "gguf" => Err(AprenderError::FormatError {
944            message: "GGUF import not yet implemented".to_string(),
945        }),
946        "bin" | "pt" | "pth" => Err(AprenderError::FormatError {
947            message: format!(
948                "PyTorch format ({extension}) not supported. Convert to SafeTensors first."
949            ),
950        }),
951        other => Err(AprenderError::FormatError {
952            message: format!("Unknown file format: .{other}. Supported: .safetensors"),
953        }),
954    }
955}
956
957/// Load tensors from SafeTensors file
958fn load_safetensors_tensors(path: &Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
959    let (metadata, raw_data) = load_safetensors(path).map_err(|e| AprenderError::FormatError {
960        message: format!("Failed to load SafeTensors: {e}"),
961    })?;
962
963    let mut tensors = BTreeMap::new();
964
965    for (name, tensor_meta) in metadata.iter() {
966        // Skip __metadata__ key if present
967        if name.starts_with("__") {
968            continue;
969        }
970
971        let data =
972            extract_tensor(&raw_data, tensor_meta).map_err(|e| AprenderError::FormatError {
973                message: format!("Failed to extract tensor '{name}': {e}"),
974            })?;
975
976        tensors.insert(name.clone(), (data, tensor_meta.shape.clone()));
977    }
978
979    Ok(tensors)
980}
981
982/// Map tensor names to APR canonical format
983fn map_tensor_names(
984    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
985    architecture: Architecture,
986) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
987    tensors
988        .iter()
989        .map(|(name, data)| {
990            let mapped_name = architecture.map_name(name);
991            (mapped_name, data.clone())
992        })
993        .collect()
994}
995
996/// Check tensor expectations and return error message if failed.
997fn check_tensor_expectation(
998    name: &str,
999    stats: &TensorStats,
1000    options: &ImportOptions,
1001) -> Option<String> {
1002    if options.validation == ValidationConfig::None {
1003        return None;
1004    }
1005    let expectation = TensorExpectation::for_tensor(name)?;
1006    let err = expectation.check(stats).err()?;
1007    if options.validation == ValidationConfig::Strict && !options.force {
1008        Some(format!("{name}: {err}"))
1009    } else {
1010        None
1011    }
1012}
1013
1014/// Check for special values (NaN/Inf) and return error messages.
1015fn check_special_values(name: &str, stats: &TensorStats, options: &ImportOptions) -> Vec<String> {
1016    if options.validation == ValidationConfig::None {
1017        return Vec::new();
1018    }
1019    let mut errors = Vec::new();
1020    if stats.nan_count > 0 {
1021        errors.push(format!("{name}: contains {} NaN values", stats.nan_count));
1022    }
1023    if stats.inf_count > 0 {
1024        errors.push(format!("{name}: contains {} Inf values", stats.inf_count));
1025    }
1026    errors
1027}
1028
1029/// Validate a single tensor and collect errors.
1030fn validate_single_tensor(
1031    name: &str,
1032    data: &[f32],
1033    options: &ImportOptions,
1034    validator: &mut AprValidator,
1035    errors: &mut Vec<String>,
1036) {
1037    let stats = compute_tensor_stats(name, data);
1038
1039    if let Some(err) = check_tensor_expectation(name, &stats, options) {
1040        errors.push(err);
1041    }
1042    errors.extend(check_special_values(name, &stats, options));
1043
1044    validator.add_tensor_stats(stats);
1045}
1046
1047/// Validate tensors according to architecture expectations
1048fn validate_tensors(
1049    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1050    options: &ImportOptions,
1051) -> Result<ValidationReport> {
1052    let mut validator = AprValidator::new();
1053    let mut validation_errors = Vec::new();
1054
1055    for (name, (data, _shape)) in tensors {
1056        validate_single_tensor(name, data, options, &mut validator, &mut validation_errors);
1057    }
1058
1059    let report = validator.validate();
1060
1061    if !validation_errors.is_empty() && !options.force {
1062        return Err(AprenderError::FormatError {
1063            message: format!(
1064                "Validation failed ({} errors):\n  - {}",
1065                validation_errors.len(),
1066                validation_errors.join("\n  - ")
1067            ),
1068        });
1069    }
1070
1071    Ok(report)
1072}
1073
/// Running totals collected in a single pass over a tensor's values.
///
/// NaN and ±Inf values are only counted. Every finite value feeds the sum, the
/// min/max bounds, the finite count and (if it equals zero) the zero count.
struct TensorAccumulator {
    /// Sum of finite values, kept in f64.
    sum: f64,
    /// Smallest finite value seen; stays `f32::INFINITY` until one arrives.
    min: f32,
    /// Largest finite value seen; stays `f32::NEG_INFINITY` until one arrives.
    max: f32,
    /// Number of NaN values.
    nan_count: usize,
    /// Number of +Inf / -Inf values.
    inf_count: usize,
    /// Number of finite values equal to zero (either sign).
    zero_count: usize,
    /// Number of finite values.
    valid_count: usize,
}

impl TensorAccumulator {
    /// Start with empty totals; min/max begin at the opposite infinities.
    fn new() -> Self {
        TensorAccumulator {
            sum: 0.0,
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
            valid_count: 0,
        }
    }

    /// Fold a single value into the totals.
    fn accumulate(&mut self, v: f32) {
        if v.is_nan() {
            self.nan_count += 1;
            return;
        }
        if v.is_infinite() {
            self.inf_count += 1;
            return;
        }
        self.valid_count += 1;
        self.sum += f64::from(v);
        self.min = self.min.min(v);
        self.max = self.max.max(v);
        self.zero_count += usize::from(v == 0.0);
    }

    /// Mean of the finite values, or 0.0 when there were none.
    fn mean(&self) -> f32 {
        match self.valid_count {
            0 => 0.0,
            n => (self.sum / n as f64) as f32,
        }
    }

    /// Minimum finite value, or 0.0 when none was seen.
    fn safe_min(&self) -> f32 {
        Self::sentinel_to_zero(self.min, f32::INFINITY)
    }

    /// Maximum finite value, or 0.0 when none was seen.
    fn safe_max(&self) -> f32 {
        Self::sentinel_to_zero(self.max, f32::NEG_INFINITY)
    }

    /// Replace a bound that still holds its starting sentinel with 0.0.
    fn sentinel_to_zero(value: f32, sentinel: f32) -> f32 {
        if value == sentinel {
            0.0
        } else {
            value
        }
    }
}
1138
/// Sample standard deviation (Bessel-corrected, divisor `valid_count - 1`) of
/// the finite values in `data` around `mean`.
///
/// NaN and ±Inf entries are skipped. Returns 0.0 when fewer than two finite
/// values are reported by `valid_count`.
fn compute_std(data: &[f32], mean: f32, valid_count: usize) -> f32 {
    if valid_count < 2 {
        return 0.0;
    }

    let center = mean as f64;
    let squared_deviations: f64 = data
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .map(|v| {
            let delta = v as f64 - center;
            delta * delta
        })
        .sum();

    let divisor = (valid_count - 1) as f64;
    (squared_deviations / divisor).sqrt() as f32
}
1154
1155/// Compute statistics for a tensor
1156fn compute_tensor_stats(name: &str, data: &[f32]) -> TensorStats {
1157    if data.is_empty() {
1158        return TensorStats {
1159            name: name.to_string(),
1160            count: 0,
1161            min: 0.0,
1162            max: 0.0,
1163            mean: 0.0,
1164            std: 0.0,
1165            nan_count: 0,
1166            inf_count: 0,
1167            zero_count: 0,
1168        };
1169    }
1170
1171    let mut acc = TensorAccumulator::new();
1172    for &v in data {
1173        acc.accumulate(v);
1174    }
1175
1176    let mean = acc.mean();
1177    let std = compute_std(data, mean, acc.valid_count);
1178
1179    TensorStats {
1180        name: name.to_string(),
1181        count: data.len(),
1182        min: acc.safe_min(),
1183        max: acc.safe_max(),
1184        mean,
1185        std,
1186        nan_count: acc.nan_count,
1187        inf_count: acc.inf_count,
1188        zero_count: acc.zero_count,
1189    }
1190}
1191
1192/// Write tensors to native APR v2 format
1193fn write_apr_file(
1194    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1195    output: &Path,
1196    options: &ImportOptions,
1197) -> Result<()> {
1198    // Create metadata with architecture info
1199    let mut metadata = AprV2Metadata::default();
1200    metadata.model_type = format!("{:?}", options.architecture);
1201    metadata.name = Some(output.file_stem()
1202        .and_then(|s| s.to_str())
1203        .unwrap_or("model")
1204        .to_string());
1205
1206    // Calculate total parameter count
1207    let param_count: u64 = tensors.values()
1208        .map(|(data, _)| data.len() as u64)
1209        .sum();
1210    metadata.param_count = param_count;
1211
1212    // Create APR v2 writer
1213    let mut writer = AprV2Writer::new(metadata);
1214
1215    // Add all tensors
1216    for (name, (data, shape)) in tensors {
1217        writer.add_f32_tensor(name, shape.clone(), data);
1218    }
1219
1220    // Write to file
1221    let bytes = writer.write().map_err(|e| AprenderError::FormatError {
1222        message: format!("Failed to serialize APR format: {e}"),
1223    })?;
1224
1225    let mut file = fs::File::create(output).map_err(|e| AprenderError::FormatError {
1226        message: format!("Failed to create output file: {e}"),
1227    })?;
1228
1229    file.write_all(&bytes).map_err(|e| AprenderError::FormatError {
1230        message: format!("Failed to write APR file: {e}"),
1231    })?;
1232
1233    Ok(())
1234}
1235
1236// ============================================================================
1237// Model Conversion (apr convert)
1238// ============================================================================
1239
1240/// Options for model conversion
1241#[derive(Debug, Clone)]
1242pub struct ConvertOptions {
1243    /// Quantization method (int8, int4, fp16)
1244    pub quantize: Option<QuantizationType>,
1245    /// Compression method
1246    pub compress: Option<Compression>,
1247    /// Validate after conversion
1248    pub validate: bool,
1249}
1250
1251impl Default for ConvertOptions {
1252    fn default() -> Self {
1253        Self {
1254            quantize: None,
1255            compress: None,
1256            validate: true,
1257        }
1258    }
1259}
1260
1261/// Convert a model with quantization and/or compression
1262///
1263/// # Arguments
1264/// * `input` - Input model path (.safetensors or .apr)
1265/// * `output` - Output model path
1266/// * `options` - Conversion options
1267///
1268/// # Returns
1269/// * `ConvertReport` with size reduction stats
1270///
1271/// # Example
1272/// ```rust,ignore
1273/// use aprender::format::{apr_convert, ConvertOptions, QuantizationType};
1274///
1275/// let options = ConvertOptions {
1276///     quantize: Some(QuantizationType::Int8),
1277///     ..Default::default()
1278/// };
1279/// let report = apr_convert("model.safetensors", "model-int8.apr", options)?;
1280/// println!("Reduced from {} to {} bytes", report.original_size, report.converted_size);
1281/// ```
1282pub fn apr_convert<P: AsRef<Path>>(
1283    input: P,
1284    output: P,
1285    options: ConvertOptions,
1286) -> Result<ConvertReport> {
1287    let input_path = input.as_ref();
1288    let output_path = output.as_ref();
1289
1290    // Step 1: Load tensors
1291    let tensors = load_model_tensors(input_path)?;
1292    let original_size = calculate_tensor_size(&tensors);
1293    let original_count = tensors.len();
1294
1295    // Step 2: Apply quantization if requested
1296    let tensors = if let Some(quant_type) = &options.quantize {
1297        quantize_tensors(&tensors, quant_type)?
1298    } else {
1299        tensors
1300    };
1301
1302    // Step 3: Save output (compression applied during save)
1303    save_model_tensors(&tensors, output_path, options.compress)?;
1304
1305    // Step 4: Calculate stats
1306    let converted_size = fs::metadata(output_path)
1307        .map(|m| m.len() as usize)
1308        .unwrap_or(0);
1309
1310    Ok(ConvertReport {
1311        original_size,
1312        converted_size,
1313        tensor_count: original_count,
1314        quantization: options.quantize,
1315        compression: options.compress,
1316        reduction_ratio: if converted_size > 0 {
1317            original_size as f64 / converted_size as f64
1318        } else {
1319            0.0
1320        },
1321    })
1322}
1323
1324/// Report from model conversion
1325#[derive(Debug, Clone)]
1326pub struct ConvertReport {
1327    /// Original model size in bytes
1328    pub original_size: usize,
1329    /// Converted model size in bytes
1330    pub converted_size: usize,
1331    /// Number of tensors
1332    pub tensor_count: usize,
1333    /// Quantization applied
1334    pub quantization: Option<QuantizationType>,
1335    /// Compression applied
1336    pub compression: Option<Compression>,
1337    /// Size reduction ratio (original/converted)
1338    pub reduction_ratio: f64,
1339}
1340
1341impl ConvertReport {
1342    /// Format reduction as percentage string
1343    pub fn reduction_percent(&self) -> String {
1344        if self.original_size > 0 && self.converted_size > 0 {
1345            let reduction = 100.0 * (1.0 - self.converted_size as f64 / self.original_size as f64);
1346            format!("{:.1}%", reduction)
1347        } else {
1348            "N/A".to_string()
1349        }
1350    }
1351}
1352
1353/// Load tensors from model file
1354fn load_model_tensors(path: &Path) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
1355    let extension = path.extension().and_then(|e| e.to_str()).unwrap_or("");
1356
1357    match extension {
1358        "safetensors" | "apr" => load_safetensors_tensors(path),
1359        other => Err(AprenderError::FormatError {
1360            message: format!("Unsupported format for conversion: .{other}"),
1361        }),
1362    }
1363}
1364
/// Total size in bytes of all tensor data when stored as f32.
fn calculate_tensor_size(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) -> usize {
    let element_count: usize = tensors.values().map(|(values, _)| values.len()).sum();
    element_count * std::mem::size_of::<f32>()
}
1369
1370/// Apply quantization to tensors
1371fn quantize_tensors(
1372    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1373    quant_type: &QuantizationType,
1374) -> Result<BTreeMap<String, (Vec<f32>, Vec<usize>)>> {
1375    let mut result = BTreeMap::new();
1376
1377    for (name, (data, shape)) in tensors {
1378        let quantized_data = match quant_type {
1379            QuantizationType::Fp16 => quantize_fp16(data),
1380            QuantizationType::Int8 => quantize_int8(data),
1381            QuantizationType::Int4 => quantize_int4(data),
1382        };
1383        result.insert(name.clone(), (quantized_data, shape.clone()));
1384    }
1385
1386    Ok(result)
1387}
1388
/// Simulate fp16 storage by rounding every value to the nearest IEEE 754
/// binary16 value and returning it widened back to f32.
///
/// The emulation follows binary16 rules:
/// - Rounding is round-to-nearest, ties-to-even (the IEEE default).
/// - Magnitudes at or above 65520 overflow to ±Inf.
/// - Values below the smallest normal (2^-14) are rounded onto the subnormal
///   grid (multiples of 2^-24).
/// - NaN and ±Inf pass through unchanged, and signed zeros keep their sign.
fn quantize_fp16(data: &[f32]) -> Vec<f32> {
    /// Round one f32 to binary16 precision and range.
    fn round_to_f16(v: f32) -> f32 {
        // The largest finite f16 is 65504 and the next step would be 65536.
        // Their midpoint is 65520, and a tie there rounds to the even encoding,
        // which is infinity. So every magnitude >= 65520 overflows.
        const OVERFLOW_THRESHOLD: f32 = 65520.0;
        // Smallest positive normal f16: 2^-14.
        const MIN_NORMAL: f32 = 6.103_515_625e-5;
        // 2^24 scales the f16 subnormal step (2^-24) up to 1.0.
        const SUBNORMAL_SCALE: f32 = 16_777_216.0;

        if !v.is_finite() {
            // Return NaN/Inf untouched. Plain mantissa truncation would turn a NaN
            // whose payload sits only in the low 13 bits into Inf.
            return v;
        }
        let magnitude = v.abs();
        if magnitude >= OVERFLOW_THRESHOLD {
            return f32::INFINITY.copysign(v);
        }
        if magnitude < MIN_NORMAL {
            // Every step here is exact: scaling by a power of two, then splitting a
            // value below 2^10 into its whole and fractional parts.
            let scaled = magnitude * SUBNORMAL_SCALE;
            let whole = scaled.floor();
            let frac = scaled - whole;
            let round_up = frac > 0.5 || (frac == 0.5 && whole % 2.0 != 0.0);
            let steps = if round_up { whole + 1.0 } else { whole };
            return (steps / SUBNORMAL_SCALE).copysign(v);
        }
        // Normal range: keep 10 of the 23 mantissa bits with round-half-to-even.
        // Adding 0x0FFF plus the lowest kept bit, then masking, rounds correctly.
        // A carry out of the mantissa correctly bumps the exponent.
        let bits = v.to_bits();
        let lowest_kept_bit = (bits >> 13) & 1;
        f32::from_bits((bits + 0x0FFF + lowest_kept_bit) & !0x1FFF)
    }

    data.iter().map(|&v| round_to_f16(v)).collect()
}
1408
/// Simulate symmetric per-tensor int8 quantization: quantize to [-127, 127],
/// then dequantize back to f32.
///
/// The scale is `max_abs / 127`, where `max_abs` is the largest *finite*
/// magnitude. Non-finite values are handled as follows:
/// - ±Inf saturates to ±`max_abs` through the clamp.
/// - NaN becomes 0.0, because `as i8` maps NaN to 0.
///
/// Empty, all-zero or all-non-finite input yields zeros.
fn quantize_int8(data: &[f32]) -> Vec<f32> {
    // Only finite values set the scale. A single ±Inf would otherwise make the
    // scale infinite, and `0 * inf = NaN` would poison every output value.
    let max_abs = data
        .iter()
        .filter(|v| v.is_finite())
        .map(|v| v.abs())
        .fold(0.0f32, f32::max);

    // Also covers empty input: `vec![0.0; 0]` is empty.
    if max_abs == 0.0 {
        return vec![0.0; data.len()];
    }

    let scale = max_abs / 127.0;

    data.iter()
        .map(|&v| {
            let quantized = (v / scale).round().clamp(-127.0, 127.0) as i8;
            f32::from(quantized) * scale
        })
        .collect()
}
1432
/// Simulate per-tensor int4 quantization: quantize, then dequantize back to f32.
///
/// The scale is `max_abs / 7`, where `max_abs` is the largest *finite*
/// magnitude, so finite values land in [-7, 7]. The clamp uses the full signed
/// 4-bit range [-8, 7], so -Inf saturates to `-8 * scale` and +Inf to
/// `7 * scale`. NaN becomes 0.0, because `as i8` maps NaN to 0. Empty,
/// all-zero or all-non-finite input yields zeros.
fn quantize_int4(data: &[f32]) -> Vec<f32> {
    // Only finite values set the scale. A single ±Inf would otherwise make the
    // scale infinite, and `0 * inf = NaN` would poison every output value.
    let max_abs = data
        .iter()
        .filter(|v| v.is_finite())
        .map(|v| v.abs())
        .fold(0.0f32, f32::max);

    // Also covers empty input: `vec![0.0; 0]` is empty.
    if max_abs == 0.0 {
        return vec![0.0; data.len()];
    }

    let scale = max_abs / 7.0; // 4-bit signed range: -8 to 7

    data.iter()
        .map(|&v| {
            let quantized = (v / scale).round().clamp(-8.0, 7.0) as i8;
            f32::from(quantized) * scale
        })
        .collect()
}
1456
1457/// Save model tensors with optional compression
1458fn save_model_tensors(
1459    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1460    output: &Path,
1461    _compression: Option<Compression>,
1462) -> Result<()> {
1463    // NOTE: Compression support deferred to APR-FORMAT-003 milestone
1464    // Currently saves as uncompressed SafeTensors (sufficient for most models <2GB)
1465    save_safetensors(output, tensors).map_err(|e| AprenderError::FormatError {
1466        message: format!("Failed to save converted model: {e}"),
1467    })
1468}
1469
1470// ============================================================================
1471// EXPORT FUNCTIONALITY (APR-SPEC §4.6)
1472// ============================================================================
1473
/// Target formats for `apr_export`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExportFormat {
    /// SafeTensors (.safetensors), used by the HuggingFace ecosystem
    SafeTensors,
    /// GGUF (.gguf), used by llama.cpp and local inference
    Gguf,
    /// ONNX (.onnx), for cross-framework inference (not yet implemented)
    Onnx,
    /// TorchScript (.pt), for PyTorch deployment (not yet implemented)
    TorchScript,
}

impl std::str::FromStr for ExportFormat {
    type Err = String;

    /// Parse a case-insensitive format name or alias (`st`, `pt`, `torch`, ...).
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let format = match s.to_lowercase().as_str() {
            "safetensors" | "st" => Self::SafeTensors,
            "gguf" => Self::Gguf,
            "onnx" => Self::Onnx,
            "torchscript" | "pt" | "torch" => Self::TorchScript,
            _ => return Err(format!("Unknown export format: {s}")),
        };
        Ok(format)
    }
}

impl ExportFormat {
    /// Conventional file extension for this format (without the dot).
    #[must_use]
    pub fn extension(&self) -> &'static str {
        match *self {
            ExportFormat::SafeTensors => "safetensors",
            ExportFormat::Gguf => "gguf",
            ExportFormat::Onnx => "onnx",
            ExportFormat::TorchScript => "pt",
        }
    }

    /// Whether `apr_export` can currently produce this format.
    #[must_use]
    pub fn is_supported(&self) -> bool {
        match *self {
            ExportFormat::SafeTensors | ExportFormat::Gguf => true,
            ExportFormat::Onnx | ExportFormat::TorchScript => false,
        }
    }
}
1519
1520/// Options for model export
1521#[derive(Debug, Clone)]
1522pub struct ExportOptions {
1523    /// Target format
1524    pub format: ExportFormat,
1525    /// Optional quantization
1526    pub quantize: Option<QuantizationType>,
1527}
1528
1529impl Default for ExportOptions {
1530    fn default() -> Self {
1531        Self {
1532            format: ExportFormat::SafeTensors,
1533            quantize: None,
1534        }
1535    }
1536}
1537
1538/// Report from export operation
1539#[derive(Debug, Clone)]
1540pub struct ExportReport {
1541    /// Original size in bytes
1542    pub original_size: usize,
1543    /// Exported size in bytes
1544    pub exported_size: usize,
1545    /// Number of tensors exported
1546    pub tensor_count: usize,
1547    /// Export format used
1548    pub format: ExportFormat,
1549    /// Quantization applied
1550    pub quantization: Option<QuantizationType>,
1551}
1552
1553/// Export APR/SafeTensors model to another format
1554///
1555/// # Arguments
1556///
1557/// * `input` - Input model path (.apr or .safetensors)
1558/// * `output` - Output file path
1559/// * `options` - Export options
1560///
1561/// # Returns
1562///
1563/// Export report with size and format information
1564///
1565/// # Errors
1566///
1567/// Returns error if:
1568/// - Input file doesn't exist
1569/// - Format not supported
1570/// - Export fails
1571///
1572/// # Example
1573///
1574/// ```rust,ignore
1575/// use aprender::format::{apr_export, ExportOptions, ExportFormat};
1576///
1577/// let options = ExportOptions {
1578///     format: ExportFormat::Gguf,
1579///     quantize: None,
1580/// };
1581/// let report = apr_export("model.apr", "model.gguf", options)?;
1582/// ```
1583pub fn apr_export<P: AsRef<Path>>(
1584    input: P,
1585    output: P,
1586    options: ExportOptions,
1587) -> Result<ExportReport> {
1588    let input_path = input.as_ref();
1589    let output_path = output.as_ref();
1590
1591    // Validate input exists
1592    if !input_path.exists() {
1593        return Err(AprenderError::FormatError {
1594            message: format!("Input file not found: {}", input_path.display()),
1595        });
1596    }
1597
1598    // Check if format is supported
1599    if !options.format.is_supported() {
1600        return Err(AprenderError::FormatError {
1601            message: format!(
1602                "Export format {:?} is not yet supported. Use 'safetensors' or 'gguf'.",
1603                options.format
1604            ),
1605        });
1606    }
1607
1608    // Load tensors
1609    let tensors = load_model_tensors(input_path)?;
1610    let original_size = calculate_tensor_size(&tensors);
1611    let original_count = tensors.len();
1612
1613    // Apply quantization if requested
1614    let tensors = if let Some(ref quant_type) = options.quantize {
1615        quantize_tensors(&tensors, quant_type)?
1616    } else {
1617        tensors
1618    };
1619
1620    // Export to target format
1621    match options.format {
1622        ExportFormat::SafeTensors => {
1623            save_safetensors(output_path, &tensors).map_err(|e| AprenderError::FormatError {
1624                message: format!("Failed to export to SafeTensors: {e}"),
1625            })?;
1626        }
1627        ExportFormat::Gguf => {
1628            export_to_gguf(&tensors, output_path)?;
1629        }
1630        ExportFormat::Onnx | ExportFormat::TorchScript => {
1631            return Err(AprenderError::FormatError {
1632                message: format!("Export format {:?} is not yet implemented", options.format),
1633            });
1634        }
1635    }
1636
1637    // Get exported file size
1638    let exported_size = fs::metadata(output_path)
1639        .map(|m| m.len() as usize)
1640        .unwrap_or(0);
1641
1642    Ok(ExportReport {
1643        original_size,
1644        exported_size,
1645        tensor_count: original_count,
1646        format: options.format,
1647        quantization: options.quantize,
1648    })
1649}
1650
1651/// Export tensors to GGUF format
1652fn export_to_gguf(tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>, output: &Path) -> Result<()> {
1653    use crate::format::gguf::{export_tensors_to_gguf, GgmlType, GgufTensor, GgufValue};
1654    use std::fs::File;
1655    use std::io::BufWriter;
1656
1657    // Convert tensors to GGUF format
1658    let gguf_tensors: Vec<GgufTensor> = tensors
1659        .iter()
1660        .map(|(name, (data, shape))| {
1661            // Convert f32 data to bytes
1662            let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
1663
1664            GgufTensor {
1665                name: name.clone(),
1666                shape: shape.iter().map(|&d| d as u64).collect(),
1667                dtype: GgmlType::F32,
1668                data: bytes,
1669            }
1670        })
1671        .collect();
1672
1673    // Basic metadata
1674    let metadata = vec![
1675        (
1676            "general.name".to_string(),
1677            GgufValue::String("model".to_string()),
1678        ),
1679        (
1680            "general.quantization_version".to_string(),
1681            GgufValue::Uint32(1),
1682        ),
1683    ];
1684
1685    // Write to file
1686    let file = File::create(output).map_err(|e| AprenderError::FormatError {
1687        message: format!("Failed to create output file: {e}"),
1688    })?;
1689    let mut writer = BufWriter::new(file);
1690
1691    export_tensors_to_gguf(&mut writer, &gguf_tensors, &metadata)
1692}
1693
1694// ============================================================================
1695// MERGE FUNCTIONALITY (APR-SPEC §4.9)
1696// ============================================================================
1697
/// How `apr_merge` combines the weights of several models.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeStrategy {
    /// Plain element-wise mean of all models
    Average,
    /// Weighted mean using caller-supplied weights
    Weighted,
    /// TIES merging (trim, elect, sign) - not yet implemented
    Ties,
    /// DARE merging (drop and rescale) - not yet implemented
    Dare,
    /// Spherical linear interpolation - not yet implemented
    Slerp,
}

impl std::str::FromStr for MergeStrategy {
    type Err = String;

    /// Parse a case-insensitive strategy name (`avg` is accepted for `average`).
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let strategy = match s.to_lowercase().as_str() {
            "average" | "avg" => Self::Average,
            "weighted" => Self::Weighted,
            "ties" => Self::Ties,
            "dare" => Self::Dare,
            "slerp" => Self::Slerp,
            _ => return Err(format!("Unknown merge strategy: {s}")),
        };
        Ok(strategy)
    }
}

impl MergeStrategy {
    /// Whether `apr_merge` implements this strategy today.
    #[must_use]
    pub fn is_supported(&self) -> bool {
        match *self {
            MergeStrategy::Average | MergeStrategy::Weighted => true,
            MergeStrategy::Ties | MergeStrategy::Dare | MergeStrategy::Slerp => false,
        }
    }
}
1735
1736/// Options for model merging
1737#[derive(Debug, Clone)]
1738pub struct MergeOptions {
1739    /// Merge strategy to use
1740    pub strategy: MergeStrategy,
1741    /// Weights for weighted merging (must match number of models)
1742    pub weights: Option<Vec<f32>>,
1743}
1744
1745impl Default for MergeOptions {
1746    fn default() -> Self {
1747        Self {
1748            strategy: MergeStrategy::Average,
1749            weights: None,
1750        }
1751    }
1752}
1753
1754/// Report from merge operation
1755#[derive(Debug, Clone)]
1756pub struct MergeReport {
1757    /// Number of models merged
1758    pub model_count: usize,
1759    /// Number of tensors in merged model
1760    pub tensor_count: usize,
1761    /// Output file size in bytes
1762    pub output_size: usize,
1763    /// Strategy used
1764    pub strategy: MergeStrategy,
1765    /// Weights used (if weighted merge)
1766    pub weights_used: Option<Vec<f32>>,
1767}
1768
1769// ============================================================================
1770// MERGE HELPER FUNCTIONS (Refactored for reduced complexity)
1771// ============================================================================
1772
1773/// Validate merge options and input count.
1774fn validate_merge_options<P: AsRef<Path>>(inputs: &[P], options: &MergeOptions) -> Result<()> {
1775    if inputs.len() < 2 {
1776        return Err(AprenderError::FormatError {
1777            message: "Merge requires at least 2 input models".to_string(),
1778        });
1779    }
1780
1781    if !options.strategy.is_supported() {
1782        return Err(AprenderError::FormatError {
1783            message: format!(
1784                "Merge strategy {:?} is not yet supported. Use 'average' or 'weighted'.",
1785                options.strategy
1786            ),
1787        });
1788    }
1789
1790    if options.strategy == MergeStrategy::Weighted {
1791        match &options.weights {
1792            Some(weights) if weights.len() != inputs.len() => {
1793                return Err(AprenderError::FormatError {
1794                    message: format!(
1795                        "Weighted merge requires {} weights, got {}",
1796                        inputs.len(),
1797                        weights.len()
1798                    ),
1799                });
1800            }
1801            None => {
1802                return Err(AprenderError::FormatError {
1803                    message: "Weighted merge requires weights to be specified".to_string(),
1804                });
1805            }
1806            _ => {}
1807        }
1808    }
1809    Ok(())
1810}
1811
1812/// Load all model tensors from input files.
1813fn load_all_models<P: AsRef<Path>>(
1814    inputs: &[P],
1815) -> Result<Vec<BTreeMap<String, (Vec<f32>, Vec<usize>)>>> {
1816    let mut all_tensors = Vec::new();
1817    for input_path in inputs {
1818        let path = input_path.as_ref();
1819        if !path.exists() {
1820            return Err(AprenderError::FormatError {
1821                message: format!("Input file not found: {}", path.display()),
1822            });
1823        }
1824        all_tensors.push(load_model_tensors(path)?);
1825    }
1826    Ok(all_tensors)
1827}
1828
1829/// Verify all models have compatible tensor structures.
1830fn verify_tensor_compatibility(
1831    all_tensors: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
1832) -> Result<()> {
1833    let reference = &all_tensors[0];
1834    for (i, tensors) in all_tensors.iter().enumerate().skip(1) {
1835        if tensors.len() != reference.len() {
1836            return Err(AprenderError::FormatError {
1837                message: format!(
1838                    "Model {} has {} tensors, but model 0 has {}",
1839                    i,
1840                    tensors.len(),
1841                    reference.len()
1842                ),
1843            });
1844        }
1845        verify_single_model_tensors(reference, tensors, i)?;
1846    }
1847    Ok(())
1848}
1849
1850/// Verify tensor compatibility for a single model against reference.
1851fn verify_single_model_tensors(
1852    reference: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1853    tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>,
1854    model_idx: usize,
1855) -> Result<()> {
1856    for (name, (_, shape)) in reference {
1857        match tensors.get(name) {
1858            None => {
1859                return Err(AprenderError::FormatError {
1860                    message: format!("Model {} is missing tensor '{}'", model_idx, name),
1861                });
1862            }
1863            Some((_, other_shape)) if other_shape != shape => {
1864                return Err(AprenderError::FormatError {
1865                    message: format!(
1866                        "Tensor '{}' has shape {:?} in model 0 but {:?} in model {}",
1867                        name, shape, other_shape, model_idx
1868                    ),
1869                });
1870            }
1871            _ => {}
1872        }
1873    }
1874    Ok(())
1875}
1876
1877/// Calculate normalized merge weights based on strategy.
1878fn calculate_merge_weights(input_count: usize, options: &MergeOptions) -> Result<Vec<f32>> {
1879    match options.strategy {
1880        MergeStrategy::Average => {
1881            let w = 1.0 / input_count as f32;
1882            Ok(vec![w; input_count])
1883        }
1884        MergeStrategy::Weighted => {
1885            let raw_weights = options.weights.as_ref().expect("validated above");
1886            let sum: f32 = raw_weights.iter().sum();
1887            if sum <= 0.0 {
1888                return Err(AprenderError::FormatError {
1889                    message: "Weights must sum to a positive value".to_string(),
1890                });
1891            }
1892            Ok(raw_weights.iter().map(|w| w / sum).collect())
1893        }
1894        _ => unreachable!("unsupported strategies filtered above"),
1895    }
1896}
1897
/// Merge tensors from multiple models using given weights.
///
/// For every tensor name in the first model, computes the element-wise weighted sum
/// `sum_i weights[i] * all_tensors[i][name]`. The output shape is taken from the first model.
/// Models are accumulated in order, so results are deterministic.
///
/// Callers are expected to have verified compatibility first, i.e. every model contains every
/// tensor of the first model with the same shape.
///
/// Returns an empty map when `all_tensors` is empty.
///
/// # Panics
///
/// Panics if `weights.len() != all_tensors.len()`, or if a model is missing a tensor that is
/// present in the first model (both are invariants established by the caller).
fn merge_tensors(
    all_tensors: &[BTreeMap<String, (Vec<f32>, Vec<usize>)>],
    weights: &[f32],
) -> BTreeMap<String, (Vec<f32>, Vec<usize>)> {
    let reference = match all_tensors.first() {
        Some(first) => first,
        None => return BTreeMap::new(),
    };
    assert_eq!(
        weights.len(),
        all_tensors.len(),
        "merge_tensors requires exactly one weight per model"
    );

    reference
        .iter()
        .map(|(name, (ref_data, shape))| {
            // Size the accumulator directly from the reference data instead of re-looking it up.
            let mut merged_data = vec![0.0f32; ref_data.len()];

            for (model_tensors, &weight) in all_tensors.iter().zip(weights) {
                let (data, _) = model_tensors
                    .get(name)
                    .expect("tensor presence verified before merging");
                debug_assert_eq!(data.len(), ref_data.len(), "tensor '{name}' length mismatch");
                for (acc, &val) in merged_data.iter_mut().zip(data) {
                    *acc += val * weight;
                }
            }

            (name.clone(), (merged_data, shape.clone()))
        })
        .collect()
}
1922
1923/// Merge multiple models into one
1924///
1925/// # Arguments
1926///
1927/// * `inputs` - Input model paths (.apr or .safetensors)
1928/// * `output` - Output file path
1929/// * `options` - Merge options
1930///
1931/// # Returns
1932///
1933/// Merge report with statistics
1934///
1935/// # Errors
1936///
1937/// Returns error if:
1938/// - Less than 2 input files
1939/// - Input files don't exist
1940/// - Models have incompatible tensor shapes
1941/// - Strategy not supported
1942///
1943/// # Example
1944///
1945/// ```rust,ignore
1946/// use aprender::format::{apr_merge, MergeOptions, MergeStrategy};
1947///
1948/// let options = MergeOptions {
1949///     strategy: MergeStrategy::Average,
1950///     weights: None,
1951/// };
1952/// let report = apr_merge(&["model1.apr", "model2.apr"], "merged.apr", options)?;
1953/// ```
1954pub fn apr_merge<P: AsRef<Path>>(
1955    inputs: &[P],
1956    output: P,
1957    options: MergeOptions,
1958) -> Result<MergeReport> {
1959    // Validate inputs and options
1960    validate_merge_options(inputs, &options)?;
1961
1962    // Load all models
1963    let all_tensors = load_all_models(inputs)?;
1964
1965    // Verify tensor compatibility
1966    verify_tensor_compatibility(&all_tensors)?;
1967
1968    // Calculate weights
1969    let weights = calculate_merge_weights(inputs.len(), &options)?;
1970
1971    // Merge tensors
1972    let merged = merge_tensors(&all_tensors, &weights);
1973
1974    // Save merged model
1975    let output_path = output.as_ref();
1976    save_safetensors(output_path, &merged).map_err(|e| AprenderError::FormatError {
1977        message: format!("Failed to save merged model: {e}"),
1978    })?;
1979
1980    // Get output file size
1981    let output_size = fs::metadata(output_path)
1982        .map(|m| m.len() as usize)
1983        .unwrap_or(0);
1984
1985    Ok(MergeReport {
1986        model_count: inputs.len(),
1987        tensor_count: merged.len(),
1988        output_size,
1989        strategy: options.strategy,
1990        weights_used: Some(weights),
1991    })
1992}
1993
1994// ============================================================================
1995// TESTS - EXTREME TDD
1996// ============================================================================
1997
/// Tests for `Source::parse` (the `hf://`, local-path, and `http(s)://` forms) and for
/// `Source::default_file`.
#[cfg(test)]
mod tests_source_parsing {
    use super::*;

    // `hf://org/repo` without a file component yields `file: None`.
    #[test]
    fn test_parse_hf_org_repo() {
        let source = Source::parse("hf://openai/whisper-tiny").unwrap();
        assert_eq!(
            source,
            Source::HuggingFace {
                org: "openai".to_string(),
                repo: "whisper-tiny".to_string(),
                file: None,
            }
        );
    }

    // A third path segment is captured as the file name.
    #[test]
    fn test_parse_hf_org_repo_file() {
        let source = Source::parse("hf://openai/whisper-tiny/model.safetensors").unwrap();
        assert_eq!(
            source,
            Source::HuggingFace {
                org: "openai".to_string(),
                repo: "whisper-tiny".to_string(),
                file: Some("model.safetensors".to_string()),
            }
        );
    }

    #[test]
    fn test_parse_hf_nested_file() {
        let source =
            Source::parse("hf://meta-llama/Llama-2-7b/pytorch_model-00001-of-00002.bin").unwrap();
        assert_eq!(
            source,
            Source::HuggingFace {
                org: "meta-llama".to_string(),
                repo: "Llama-2-7b".to_string(),
                file: Some("pytorch_model-00001-of-00002.bin".to_string()),
            }
        );
    }

    // Anything without a recognized scheme is treated as a local path, unchanged.
    #[test]
    fn test_parse_local_path() {
        let source = Source::parse("./models/model.safetensors").unwrap();
        assert_eq!(
            source,
            Source::Local(PathBuf::from("./models/model.safetensors"))
        );
    }

    #[test]
    fn test_parse_url() {
        let source = Source::parse("https://example.com/model.safetensors").unwrap();
        assert_eq!(
            source,
            Source::Url("https://example.com/model.safetensors".to_string())
        );
    }

    // `hf://` with only one path segment (no repo) must be rejected.
    #[test]
    fn test_parse_hf_invalid() {
        let result = Source::parse("hf://invalid");
        assert!(result.is_err());
    }

    // With no explicit file, `default_file` falls back to "model.safetensors";
    // an explicit file is returned as-is.
    #[test]
    fn test_default_file() {
        let hf = Source::HuggingFace {
            org: "openai".to_string(),
            repo: "whisper".to_string(),
            file: None,
        };
        assert_eq!(hf.default_file(), "model.safetensors");

        let hf_with_file = Source::HuggingFace {
            org: "openai".to_string(),
            repo: "whisper".to_string(),
            file: Some("custom.safetensors".to_string()),
        };
        assert_eq!(hf_with_file.default_file(), "custom.safetensors");
    }
}
2083
/// Tests for `Architecture::map_name`, which rewrites HuggingFace tensor names
/// (e.g. stripping a leading `model.` or `bert.` prefix).
#[cfg(test)]
mod tests_name_mapping {
    use super::*;

    #[test]
    fn test_whisper_strip_model_prefix() {
        let mapped = Architecture::Whisper.map_name("model.encoder.conv1.weight");
        assert_eq!(mapped, "encoder.conv1.weight");
    }

    // Names that already lack the prefix pass through unchanged.
    #[test]
    fn test_whisper_no_prefix() {
        let mapped = Architecture::Whisper.map_name("encoder.conv1.weight");
        assert_eq!(mapped, "encoder.conv1.weight");
    }

    #[test]
    fn test_whisper_decoder_layer_norm() {
        let mapped = Architecture::Whisper.map_name("model.decoder.layer_norm.weight");
        assert_eq!(mapped, "decoder.layer_norm.weight");
    }

    #[test]
    fn test_auto_strips_model_prefix() {
        let mapped = Architecture::Auto.map_name("model.encoder.layers.0.self_attn.q_proj.weight");
        assert_eq!(mapped, "encoder.layers.0.self_attn.q_proj.weight");
    }

    #[test]
    fn test_llama_mapping() {
        let mapped = Architecture::Llama.map_name("model.layers.0.self_attn.q_proj.weight");
        assert_eq!(mapped, "layers.0.self_attn.q_proj.weight");
    }

    // BERT checkpoints use a `bert.` prefix instead of `model.`.
    #[test]
    fn test_bert_mapping() {
        let mapped =
            Architecture::Bert.map_name("bert.encoder.layer.0.attention.self.query.weight");
        assert_eq!(mapped, "encoder.layer.0.attention.self.query.weight");
    }
}
2125
/// Tests for `TensorExpectation::for_tensor` (lookup of expected statistics by tensor name)
/// and `TensorExpectation::check` (validating computed `TensorStats` against them).
#[cfg(test)]
mod tests_tensor_expectations {
    use super::*;

    #[test]
    fn test_layer_norm_weight_expectation() {
        let exp = TensorExpectation::for_tensor("encoder.layer_norm.weight");
        assert!(exp.is_some());
        let exp = exp.unwrap();
        assert_eq!(exp.mean_range, (0.5, 3.0));
    }

    #[test]
    fn test_layer_norm_bias_expectation() {
        let exp = TensorExpectation::for_tensor("decoder.layers.0.self_attn_layer_norm.bias");
        assert!(exp.is_some());
        let exp = exp.unwrap();
        assert_eq!(exp.mean_range, (-0.5, 0.5));
    }

    #[test]
    fn test_linear_weight_expectation() {
        let exp = TensorExpectation::for_tensor("encoder.layers.0.fc1.weight");
        assert!(exp.is_some());
        let exp = exp.unwrap();
        assert_eq!(exp.mean_range, (-0.1, 0.1));
    }

    #[test]
    fn test_embedding_expectation() {
        let exp = TensorExpectation::for_tensor("decoder.embed_tokens.weight");
        assert!(exp.is_some());
    }

    // A mean of 1.0 lies inside the LayerNorm weight range, so the check passes.
    #[test]
    fn test_check_layer_norm_valid() {
        let stats = TensorStats {
            name: "encoder.layer_norm.weight".to_string(),
            count: 384,
            min: 0.5,
            max: 2.0,
            mean: 1.0,
            std: 0.3,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
        };

        let exp = TensorExpectation::LAYER_NORM_WEIGHT;
        assert!(exp.check(&stats).is_ok());
    }

    // A mean of 11.0 is outside the range; the error message must name the offending mean.
    #[test]
    fn test_check_layer_norm_invalid_mean() {
        let stats = TensorStats {
            name: "decoder.layer_norm.weight".to_string(),
            count: 384,
            min: 5.0,
            max: 15.0,
            mean: 11.0, // BUG: should be ~1.0
            std: 2.0,
            nan_count: 0,
            inf_count: 0,
            zero_count: 0,
        };

        let exp = TensorExpectation::LAYER_NORM_WEIGHT;
        let result = exp.check(&stats);
        assert!(result.is_err());

        let err = result.unwrap_err().to_string();
        assert!(err.contains("mean=11"));
        assert!(err.contains("outside expected range"));
    }
}
2201
/// Tests for the `AprConverter` builder API.
#[cfg(test)]
mod tests_converter_builder {
    use super::*;

    // Each builder method records its setting on the converter.
    #[test]
    fn test_converter_builder_chain() {
        let converter = AprConverter::new()
            .source("hf://openai/whisper-tiny")
            .unwrap()
            .architecture(Architecture::Whisper)
            .validate(ValidationConfig::Strict)
            .quantize(QuantizationType::Int8)
            .compress(Compression::Lz4);

        assert_eq!(converter.architecture, Architecture::Whisper);
        assert_eq!(converter.validation, ValidationConfig::Strict);
        assert_eq!(converter.quantize, Some(QuantizationType::Int8));
        assert_eq!(converter.compress, Some(Compression::Lz4));
    }

    // `convert()` without a configured source must fail rather than do anything.
    #[test]
    fn test_converter_no_source_error() {
        let converter = AprConverter::new();
        let result = converter.convert();
        assert!(result.is_err());
    }
}
2229
/// Pins the values produced by `ImportOptions::default()`.
#[cfg(test)]
mod tests_import_options {
    use super::*;

    #[test]
    fn test_default_options() {
        let opts = ImportOptions::default();
        assert_eq!(opts.architecture, Architecture::Auto);
        assert_eq!(opts.validation, ValidationConfig::Strict);
        assert_eq!(opts.quantize, None);
        assert_eq!(opts.compress, None);
        assert!(!opts.force);
        assert!(opts.cache);
    }
}
2245
/// End-to-end tests for `apr_import` on local SafeTensors files: validation pass/fail,
/// `force` bypass, NaN detection, error messages for missing/unsupported inputs, and
/// Whisper name mapping in the written APR v2 output.
#[cfg(test)]
mod tests_conversion {
    use super::*;

    /// Write `tensors` to `path` as a SafeTensors fixture, panicking on failure.
    fn create_test_safetensors(path: &Path, tensors: &BTreeMap<String, (Vec<f32>, Vec<usize>)>) {
        save_safetensors(path, tensors).expect("Failed to create test SafeTensors file");
    }

    #[test]
    fn test_convert_valid_safetensors() {
        let input = "/tmp/test_valid_input.safetensors";
        let output = "/tmp/test_valid_output.apr";

        // Create valid test tensors
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "encoder.layer_norm.weight".to_string(),
            (vec![1.0f32; 384], vec![384]),
        );
        tensors.insert(
            "encoder.layer_norm.bias".to_string(),
            (vec![0.0f32; 384], vec![384]),
        );
        // Data length must match the shape's element count (80 * 1 * 3 = 240); a mismatched
        // length would produce an internally inconsistent SafeTensors fixture.
        tensors.insert(
            "encoder.conv1.weight".to_string(),
            (vec![0.01f32; 240], vec![80, 1, 3]),
        );

        create_test_safetensors(Path::new(input), &tensors);

        // Run conversion
        let options = ImportOptions::default();
        let result = apr_import(input, output, options);

        assert!(
            result.is_ok(),
            "Valid tensors should convert successfully: {:?}",
            result.err()
        );
        let report = result.unwrap();
        assert!(report.total_score > 0, "Score should be > 0");

        // Cleanup
        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }

    #[test]
    fn test_convert_invalid_layernorm_fails_strict() {
        let input = "/tmp/test_invalid_ln_input.safetensors";
        let output = "/tmp/test_invalid_ln_output.apr";

        // Create tensors with INVALID LayerNorm (mean=11, should be ~1)
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "decoder.layer_norm.weight".to_string(),
            (vec![11.0f32; 384], vec![384]),
        );

        create_test_safetensors(Path::new(input), &tensors);

        // Run conversion with strict validation
        let options = ImportOptions {
            validation: ValidationConfig::Strict,
            force: false,
            ..Default::default()
        };
        let result = apr_import(input, output, options);

        assert!(
            result.is_err(),
            "Invalid LayerNorm should fail strict validation"
        );
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("mean=11") || err.contains("LayerNorm"),
            "Error should mention LayerNorm issue: {err}"
        );

        // Cleanup
        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }

    #[test]
    fn test_convert_invalid_layernorm_force_succeeds() {
        let input = "/tmp/test_force_ln_input.safetensors";
        let output = "/tmp/test_force_ln_output.apr";

        // Create tensors with invalid LayerNorm
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "decoder.layer_norm.weight".to_string(),
            (vec![11.0f32; 384], vec![384]),
        );

        create_test_safetensors(Path::new(input), &tensors);

        // Run conversion with force=true
        let options = ImportOptions {
            validation: ValidationConfig::Strict,
            force: true,
            ..Default::default()
        };
        let result = apr_import(input, output, options);

        assert!(
            result.is_ok(),
            "Force should bypass validation: {:?}",
            result.err()
        );

        // Cleanup
        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }

    #[test]
    fn test_convert_nan_fails() {
        let input = "/tmp/test_nan_input.safetensors";
        let output = "/tmp/test_nan_output.apr";

        // Create tensors with NaN
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "test.weight".to_string(),
            (vec![1.0, f32::NAN, 3.0], vec![3]),
        );

        create_test_safetensors(Path::new(input), &tensors);

        let options = ImportOptions::default();
        let result = apr_import(input, output, options);

        assert!(result.is_err(), "NaN should fail validation");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("NaN"), "Error should mention NaN: {err}");

        // Cleanup
        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }

    #[test]
    fn test_convert_nonexistent_file() {
        let result = apr_import(
            "/tmp/nonexistent_model.safetensors",
            "/tmp/out.apr",
            ImportOptions::default(),
        );
        assert!(result.is_err(), "Nonexistent file should fail");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("not found") || err.contains("No such file"),
            "Error should mention file not found: {err}"
        );
    }

    #[test]
    fn test_convert_unsupported_format() {
        let input = "/tmp/test_bad_format.gguf";
        fs::write(input, b"test").expect("Failed to create test file");

        let result = apr_import(input, "/tmp/out.apr", ImportOptions::default());
        assert!(result.is_err(), "Unsupported format should fail");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("GGUF") || err.contains("not yet"),
            "Error should mention unsupported: {err}"
        );

        fs::remove_file(input).ok();
    }

    #[test]
    fn test_name_mapping_whisper() {
        use crate::format::v2::AprV2Reader;

        let input = "/tmp/test_whisper_input.safetensors";
        let output = "/tmp/test_whisper_output.apr";

        // Create tensors with HuggingFace-style names
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "model.encoder.conv1.weight".to_string(),
            (vec![0.01f32; 100], vec![10, 10]),
        );
        tensors.insert(
            "model.decoder.layer_norm.weight".to_string(),
            (vec![1.0f32; 384], vec![384]),
        );

        create_test_safetensors(Path::new(input), &tensors);

        let options = ImportOptions {
            architecture: Architecture::Whisper,
            ..Default::default()
        };
        let result = apr_import(input, output, options);
        assert!(
            result.is_ok(),
            "Whisper mapping should work: {:?}",
            result.err()
        );

        // Load output as APR v2 and verify names are mapped
        let data = fs::read(output).expect("Failed to read output");
        let reader = AprV2Reader::from_bytes(&data).expect("Failed to parse APR v2");
        let tensor_names = reader.tensor_names();

        assert!(
            tensor_names.contains(&"encoder.conv1.weight"),
            "Should strip 'model.' prefix, got: {:?}",
            tensor_names
        );
        assert!(
            tensor_names.contains(&"decoder.layer_norm.weight"),
            "Should strip 'model.' prefix, got: {:?}",
            tensor_names
        );

        // Cleanup
        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }
}
2472
/// Tests for `compute_tensor_stats`: basic statistics, NaN/Inf counting (non-finite values
/// are counted in `count` but excluded from the mean), empty input, and zero counting.
#[cfg(test)]
mod tests_tensor_stats {
    use super::*;

    #[test]
    fn test_compute_stats_basic() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0];
        let stats = compute_tensor_stats("test", &data);

        assert_eq!(stats.count, 5);
        assert!((stats.mean - 3.0).abs() < 0.001, "Mean should be 3.0");
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.nan_count, 0);
        assert_eq!(stats.inf_count, 0);
    }

    #[test]
    fn test_compute_stats_with_nan() {
        let data = vec![1.0f32, f32::NAN, 3.0];
        let stats = compute_tensor_stats("test", &data);

        assert_eq!(stats.nan_count, 1);
        assert_eq!(stats.count, 3);
        // Mean computed from valid values only
        assert!(
            (stats.mean - 2.0).abs() < 0.001,
            "Mean should be 2.0 (from valid values)"
        );
    }

    // Both +Inf and -Inf are counted and excluded from the mean.
    #[test]
    fn test_compute_stats_with_inf() {
        let data = vec![1.0f32, f32::INFINITY, f32::NEG_INFINITY, 3.0];
        let stats = compute_tensor_stats("test", &data);

        assert_eq!(stats.inf_count, 2);
        assert!(
            (stats.mean - 2.0).abs() < 0.001,
            "Mean should be 2.0 (from valid values)"
        );
    }

    // Empty input must not divide by zero: mean and std are reported as 0.0.
    #[test]
    fn test_compute_stats_empty() {
        let data: Vec<f32> = vec![];
        let stats = compute_tensor_stats("test", &data);

        assert_eq!(stats.count, 0);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.std, 0.0);
    }

    #[test]
    fn test_compute_stats_all_zeros() {
        let data = vec![0.0f32; 100];
        let stats = compute_tensor_stats("test", &data);

        assert_eq!(stats.zero_count, 100);
        assert_eq!(stats.mean, 0.0);
    }
}
2535
/// Tests for `quantize_int8`, `quantize_int4`, `quantize_fp16`, and `quantize_tensors`.
/// The quantizers return `f32` values, and the tests compare those outputs directly
/// against the inputs.
#[cfg(test)]
mod tests_quantization {
    use super::*;

    #[test]
    fn test_quantize_int8_basic() {
        let data = vec![1.0f32, -1.0, 0.5, -0.5, 0.0];
        let quantized = quantize_int8(&data);

        assert_eq!(quantized.len(), data.len());
        // Values should be close but not exact due to quantization
        for (orig, quant) in data.iter().zip(quantized.iter()) {
            assert!((orig - quant).abs() < 0.02, "Quantization error too large");
        }
    }

    // An all-zero tensor must map exactly back to zeros.
    #[test]
    fn test_quantize_int8_preserves_zeros() {
        let data = vec![0.0f32; 10];
        let quantized = quantize_int8(&data);
        assert!(
            quantized.iter().all(|&v| v == 0.0),
            "Zeros should remain zeros"
        );
    }

    #[test]
    fn test_quantize_int8_empty() {
        let data: Vec<f32> = vec![];
        let quantized = quantize_int8(&data);
        assert!(quantized.is_empty());
    }

    #[test]
    fn test_quantize_int4_basic() {
        let data = vec![1.0f32, -1.0, 0.5, -0.5, 0.0];
        let quantized = quantize_int4(&data);

        assert_eq!(quantized.len(), data.len());
        // Int4 has more error than int8
        for (orig, quant) in data.iter().zip(quantized.iter()) {
            assert!(
                (orig - quant).abs() < 0.2,
                "Int4 quantization error too large"
            );
        }
    }

    // Only values exactly representable in FP16 are checked for equality; the last input
    // value is present only to exercise rounding and is not asserted.
    #[test]
    fn test_quantize_fp16_basic() {
        let data = vec![1.0f32, -1.0, 0.5, -0.5, 0.0, 0.123456789];
        let quantized = quantize_fp16(&data);

        assert_eq!(quantized.len(), data.len());
        // FP16 should have minimal error for simple values
        assert_eq!(quantized[0], 1.0);
        assert_eq!(quantized[1], -1.0);
        assert_eq!(quantized[4], 0.0);
    }

    // Quantizing a tensor map keeps names and shapes intact.
    #[test]
    fn test_quantize_tensors_int8() {
        let mut tensors = BTreeMap::new();
        tensors.insert("test".to_string(), (vec![1.0f32, -1.0, 0.5], vec![3]));

        let result = quantize_tensors(&tensors, &QuantizationType::Int8).unwrap();

        assert_eq!(result.len(), 1);
        assert!(result.contains_key("test"));
        let (data, shape) = result.get("test").unwrap();
        assert_eq!(shape, &vec![3]);
        assert_eq!(data.len(), 3);
    }
}
2610
/// Tests for `apr_convert` (with and without quantization), `ConvertReport::reduction_percent`,
/// and `ConvertOptions::default()`.
#[cfg(test)]
mod tests_convert {
    use super::*;

    /// Write a three-tensor SafeTensors fixture (shapes consistent with data lengths) to `path`.
    fn create_test_model(path: &Path) {
        let mut tensors = BTreeMap::new();
        tensors.insert(
            "encoder.weight".to_string(),
            (vec![0.01f32; 1000], vec![100, 10]),
        );
        tensors.insert("encoder.bias".to_string(), (vec![0.0f32; 100], vec![100]));
        tensors.insert(
            "decoder.weight".to_string(),
            (vec![0.02f32; 500], vec![50, 10]),
        );
        save_safetensors(path, &tensors).expect("Failed to create test model");
    }

    #[test]
    fn test_convert_no_quantization() {
        let input = Path::new("/tmp/test_convert_input.safetensors");
        let output = Path::new("/tmp/test_convert_output.apr");

        create_test_model(input);

        let options = ConvertOptions::default();
        let result = apr_convert(input, output, options);

        assert!(
            result.is_ok(),
            "Convert without quantization should work: {:?}",
            result.err()
        );
        let report = result.unwrap();
        assert_eq!(report.tensor_count, 3);
        assert!(report.quantization.is_none());

        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }

    #[test]
    fn test_convert_with_int8_quantization() {
        let input = Path::new("/tmp/test_convert_int8_input.safetensors");
        let output = Path::new("/tmp/test_convert_int8_output.apr");

        create_test_model(input);

        let options = ConvertOptions {
            quantize: Some(QuantizationType::Int8),
            ..Default::default()
        };
        let result = apr_convert(input, output, options);

        assert!(
            result.is_ok(),
            "Int8 quantization should work: {:?}",
            result.err()
        );
        let report = result.unwrap();
        assert_eq!(report.quantization, Some(QuantizationType::Int8));
        assert_eq!(report.tensor_count, 3);

        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }

    #[test]
    fn test_convert_with_fp16_quantization() {
        let input = Path::new("/tmp/test_convert_fp16_input.safetensors");
        let output = Path::new("/tmp/test_convert_fp16_output.apr");

        create_test_model(input);

        let options = ConvertOptions {
            quantize: Some(QuantizationType::Fp16),
            ..Default::default()
        };
        let result = apr_convert(input, output, options);

        assert!(
            result.is_ok(),
            "FP16 quantization should work: {:?}",
            result.err()
        );

        fs::remove_file(input).ok();
        fs::remove_file(output).ok();
    }

    #[test]
    fn test_convert_nonexistent_file() {
        let options = ConvertOptions::default();
        let result = apr_convert("/tmp/nonexistent.safetensors", "/tmp/out.apr", options);

        assert!(result.is_err(), "Nonexistent file should fail");
    }

    // 250 of 1000 bytes remaining corresponds to a 75.0% size reduction.
    #[test]
    fn test_convert_report_reduction_percent() {
        let report = ConvertReport {
            original_size: 1000,
            converted_size: 250,
            tensor_count: 5,
            quantization: Some(QuantizationType::Int8),
            compression: None,
            reduction_ratio: 4.0,
        };

        assert_eq!(report.reduction_percent(), "75.0%");
    }

    #[test]
    fn test_convert_options_default() {
        let options = ConvertOptions::default();
        assert!(options.quantize.is_none());
        assert!(options.compress.is_none());
        assert!(options.validate);
    }
}
2731
2732// ============================================================================
2733// GH-127: Multi-tensor (sharded) model import tests
2734// ============================================================================
2735
/// GH-127: tests for `ShardedIndex` parsing of `*.safetensors.index.json` content and for
/// `detect_sharded_model`.
#[cfg(test)]
mod tests_sharded_import {
    use super::*;

    // Three tensors spread over two distinct shard files; `metadata.total_size` is optional.
    #[test]
    fn test_sharded_index_parse_valid() {
        let json = r#"{
            "metadata": {"total_size": 1000000},
            "weight_map": {
                "encoder.conv1.weight": "model-00001-of-00002.safetensors",
                "encoder.conv2.weight": "model-00001-of-00002.safetensors",
                "decoder.fc.weight": "model-00002-of-00002.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).expect("Valid index should parse");
        assert_eq!(index.shard_count(), 2);
        assert_eq!(index.tensor_count(), 3);
        assert!(index.total_size().is_some());
    }

    #[test]
    fn test_sharded_index_shard_for_tensor() {
        let json = r#"{
            "weight_map": {
                "encoder.weight": "shard-00001.safetensors",
                "decoder.weight": "shard-00002.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).unwrap();
        assert_eq!(
            index.shard_for_tensor("encoder.weight"),
            Some("shard-00001.safetensors")
        );
        assert_eq!(
            index.shard_for_tensor("decoder.weight"),
            Some("shard-00002.safetensors")
        );
        assert_eq!(index.shard_for_tensor("unknown"), None);
    }

    #[test]
    fn test_sharded_index_tensors_in_shard() {
        let json = r#"{
            "weight_map": {
                "a": "shard1.safetensors",
                "b": "shard1.safetensors",
                "c": "shard2.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).unwrap();
        let shard1_tensors = index.tensors_in_shard("shard1.safetensors");
        assert_eq!(shard1_tensors.len(), 2);
        assert!(shard1_tensors.contains(&"a"));
        assert!(shard1_tensors.contains(&"b"));
    }

    #[test]
    fn test_sharded_index_parse_invalid_json() {
        let result = ShardedIndex::parse("not valid json");
        assert!(result.is_err());
    }

    // `weight_map` is mandatory.
    #[test]
    fn test_sharded_index_parse_missing_weight_map() {
        let result = ShardedIndex::parse(r#"{"metadata": {}}"#);
        assert!(result.is_err());
    }

    #[test]
    fn test_detect_sharded_model_index_exists() {
        // Create a temp dir with index.json
        let dir = tempfile::tempdir().unwrap();
        let index_path = dir.path().join("model.safetensors.index.json");
        fs::write(&index_path, r#"{"weight_map": {"a": "shard.safetensors"}}"#).unwrap();

        let result = detect_sharded_model(dir.path(), "model.safetensors");
        assert!(result.is_some());
    }

    #[test]
    fn test_detect_sharded_model_single_file() {
        let dir = tempfile::tempdir().unwrap();
        let model_path = dir.path().join("model.safetensors");
        fs::write(&model_path, &[0u8; 8]).unwrap(); // Minimal file

        let result = detect_sharded_model(dir.path(), "model.safetensors");
        assert!(result.is_none(), "Single file should not be sharded");
    }

    // `shard_files` returns shard names in sorted order regardless of weight_map order.
    #[test]
    fn test_sharded_index_shard_files_sorted() {
        let json = r#"{
            "weight_map": {
                "a": "model-00002-of-00003.safetensors",
                "b": "model-00001-of-00003.safetensors",
                "c": "model-00003-of-00003.safetensors"
            }
        }"#;

        let index = ShardedIndex::parse(json).unwrap();
        let shards = index.shard_files();
        assert_eq!(shards[0], "model-00001-of-00003.safetensors");
        assert_eq!(shards[1], "model-00002-of-00003.safetensors");
        assert_eq!(shards[2], "model-00003-of-00003.safetensors");
    }
}
2845
2846// ============================================================================
2847// GH-129: Import error message tests
2848// ============================================================================
2849
#[cfg(test)]
mod tests_import_errors {
    // GH-129: `ImportError` messages must carry the relevant details and an
    // actionable hint, and `parse_import_error` must map raw error text to the
    // matching `ImportError` variant.
    use super::*;

    #[test]
    fn test_import_error_not_found_message() {
        let rendered = ImportError::NotFound {
            resource: "openai/whisper-tiny".to_string(),
            status: 404,
        }
        .to_string();
        assert!(rendered.contains("404"), "Should include status code");
        assert!(rendered.contains("whisper-tiny"), "Should include resource");
    }

    #[test]
    fn test_import_error_rate_limited_message() {
        let rendered = ImportError::RateLimited {
            retry_after: Some(60),
        }
        .to_string();
        let lowered = rendered.to_lowercase();
        assert!(lowered.contains("rate"), "Should mention rate limit");
        assert!(rendered.contains("60"), "Should include retry time");
    }

    #[test]
    fn test_import_error_auth_required_message() {
        let rendered = ImportError::AuthRequired {
            resource: "meta-llama/Llama-2-7b".to_string(),
        }
        .to_string();
        assert!(rendered.contains("HF_TOKEN"), "Should suggest HF_TOKEN");
        assert!(rendered.contains("Llama-2-7b"), "Should include resource");
    }

    #[test]
    fn test_import_error_actionable_suggestions() {
        let rendered = ImportError::NotFound {
            resource: "openai/whisper-tiny".to_string(),
            status: 404,
        }
        .to_string();

        // Any one of these markers counts as telling the user what to do next.
        let actionable = ["Fix:", "check", "verify"]
            .iter()
            .any(|marker| rendered.contains(*marker));
        assert!(actionable, "Error should be actionable");
    }

    #[test]
    fn test_import_error_sharding_oom() {
        let rendered = ImportError::ShardingRequired {
            model_size: 14_000_000_000, // 14GB
            shard_count: 7,
        }
        .to_string();
        assert!(rendered.contains("14"), "Should include size");
        assert!(rendered.contains("7"), "Should include shard count");
    }

    // The parse_import_error tests below are only built when the
    // `hf-hub-integration` feature is enabled.
    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_404() {
        let parsed = parse_import_error("HTTP 404: Repository not found", "openai/whisper-tiny");
        let (resource, status) = match parsed {
            ImportError::NotFound { resource, status } => (resource, status),
            other => panic!("Expected NotFound error, got {:?}", other),
        };
        assert_eq!(resource, "openai/whisper-tiny");
        assert_eq!(status, 404);
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_not_found_text() {
        let parsed = parse_import_error("The requested resource does not exist", "test/model");
        assert!(
            matches!(parsed, ImportError::NotFound { .. }),
            "Expected NotFound error, got {:?}",
            parsed
        );
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_401() {
        let parsed = parse_import_error("HTTP 401: Unauthorized access", "meta-llama/Llama-2-7b");
        let resource = match parsed {
            ImportError::AuthRequired { resource } => resource,
            other => panic!("Expected AuthRequired error, got {:?}", other),
        };
        assert_eq!(resource, "meta-llama/Llama-2-7b");
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_gated_model() {
        let parsed = parse_import_error(
            "This model is gated. Access requires acceptance.",
            "meta-llama/Llama-2-7b",
        );
        assert!(
            matches!(parsed, ImportError::AuthRequired { .. }),
            "Expected AuthRequired error, got {:?}",
            parsed
        );
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_429() {
        let parsed = parse_import_error(
            "HTTP 429: Too many requests. Retry after 60 seconds.",
            "test/model",
        );
        let retry_after = match parsed {
            ImportError::RateLimited { retry_after } => retry_after,
            other => panic!("Expected RateLimited error, got {:?}", other),
        };
        assert_eq!(retry_after, Some(60));
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_rate_limit_no_retry() {
        let parsed = parse_import_error("Rate limit exceeded", "test/model");
        let retry_after = match parsed {
            ImportError::RateLimited { retry_after } => retry_after,
            other => panic!("Expected RateLimited error, got {:?}", other),
        };
        assert_eq!(retry_after, None);
    }

    #[cfg(feature = "hf-hub-integration")]
    #[test]
    fn test_parse_import_error_generic() {
        let parsed = parse_import_error("Connection timeout", "test/model");
        let (source, reason) = match parsed {
            ImportError::DownloadFailed { source, reason } => (source, reason),
            other => panic!("Expected DownloadFailed error, got {:?}", other),
        };
        assert_eq!(source, "test/model");
        assert_eq!(reason, "Connection timeout");
    }

    #[test]
    fn test_import_error_from_conversion() {
        let converted: AprenderError = ImportError::NotFound {
            resource: "test".to_string(),
            status: 404,
        }
        .into();
        // The converted error's message must still carry the status and resource.
        let rendered = converted.to_string();
        assert!(rendered.contains("404"));
        assert!(rendered.contains("test"));
    }
}