Skip to main content

pacha/
format.rs

1//! Model Format Detection and Metadata
2//!
3//! Detects model file formats and extracts metadata from model files.
4//!
5//! ## Supported Formats
6//!
7//! - **GGUF**: GGML Universal Format (llama.cpp, ollama)
8//! - **SafeTensors**: HuggingFace safe tensor format
9//! - **APR**: Aprender native format
10//! - **ONNX**: Open Neural Network Exchange
11//! - **PyTorch**: `.pt`/`.pth` files (detection only)
12//!
13//! ## Example
14//!
15//! ```rust,ignore
16//! use pacha::format::{detect_format, ModelFormat};
17//!
//! let format = detect_format(&data);
//! match format {
//!     ModelFormat::Gguf(info) => println!("GGUF v{}: {} tensors", info.version, info.tensor_count),
21//!     ModelFormat::SafeTensors(info) => println!("SafeTensors: {} tensors", info.tensor_count),
22//!     _ => println!("Other format"),
23//! }
24//! ```
25
26use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28
29// ============================================================================
30// FMT-001: Model Format Enum
31// ============================================================================
32
33/// Detected model format
34#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
35pub enum ModelFormat {
36    /// GGUF format (llama.cpp)
37    Gguf(GgufInfo),
38    /// SafeTensors format (HuggingFace)
39    SafeTensors(SafeTensorsInfo),
40    /// Aprender native format
41    Apr(AprInfo),
42    /// ONNX format
43    Onnx(OnnxInfo),
44    /// PyTorch format (limited detection)
45    PyTorch,
46    /// Unknown format
47    Unknown,
48}
49
50impl ModelFormat {
51    /// Get format name
52    #[must_use]
53    pub fn name(&self) -> &'static str {
54        match self {
55            Self::Gguf(_) => "GGUF",
56            Self::SafeTensors(_) => "SafeTensors",
57            Self::Apr(_) => "APR",
58            Self::Onnx(_) => "ONNX",
59            Self::PyTorch => "PyTorch",
60            Self::Unknown => "Unknown",
61        }
62    }
63
64    /// Get file extension
65    #[must_use]
66    pub fn extension(&self) -> &'static str {
67        match self {
68            Self::Gguf(_) => ".gguf",
69            Self::SafeTensors(_) => ".safetensors",
70            Self::Apr(_) => ".apr",
71            Self::Onnx(_) => ".onnx",
72            Self::PyTorch => ".pt",
73            Self::Unknown => "",
74        }
75    }
76
77    /// Check if format is quantized
78    #[must_use]
79    pub fn is_quantized(&self) -> bool {
80        match self {
81            Self::Gguf(info) => info.quantization.is_some(),
82            Self::Apr(info) => info.quantization.is_some(),
83            _ => false,
84        }
85    }
86
87    /// Get quantization type if available
88    #[must_use]
89    pub fn quantization(&self) -> Option<&str> {
90        match self {
91            Self::Gguf(info) => info.quantization.as_deref(),
92            Self::Apr(info) => info.quantization.as_deref(),
93            _ => None,
94        }
95    }
96
97    /// Get parameter count if available
98    #[must_use]
99    pub fn parameters(&self) -> Option<u64> {
100        match self {
101            Self::Gguf(info) => info.parameters,
102            Self::SafeTensors(info) => info.parameters,
103            Self::Apr(info) => info.parameters,
104            Self::Onnx(info) => info.parameters,
105            Self::PyTorch | Self::Unknown => None,
106        }
107    }
108}
109
110// ============================================================================
111// FMT-002: Format-Specific Info
112// ============================================================================
113
114/// GGUF file information
115#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
116pub struct GgufInfo {
117    /// GGUF version
118    pub version: u32,
119    /// Number of tensors
120    pub tensor_count: u64,
121    /// Number of metadata key-value pairs
122    pub metadata_count: u64,
123    /// Model architecture (e.g., "llama", "mistral")
124    pub architecture: Option<String>,
125    /// Quantization type (e.g., "Q4_K_M", "Q8_0")
126    pub quantization: Option<String>,
127    /// Context length
128    pub context_length: Option<u32>,
129    /// Embedding dimension
130    pub embedding_dim: Option<u32>,
131    /// Number of layers
132    pub num_layers: Option<u32>,
133    /// Number of attention heads
134    pub num_heads: Option<u32>,
135    /// Vocabulary size
136    pub vocab_size: Option<u32>,
137    /// Estimated parameter count
138    pub parameters: Option<u64>,
139    /// Model name from metadata
140    pub name: Option<String>,
141    /// Author from metadata
142    pub author: Option<String>,
143    /// License from metadata
144    pub license: Option<String>,
145}
146
147impl Default for GgufInfo {
148    fn default() -> Self {
149        Self {
150            version: 0,
151            tensor_count: 0,
152            metadata_count: 0,
153            architecture: None,
154            quantization: None,
155            context_length: None,
156            embedding_dim: None,
157            num_layers: None,
158            num_heads: None,
159            vocab_size: None,
160            parameters: None,
161            name: None,
162            author: None,
163            license: None,
164        }
165    }
166}
167
168/// SafeTensors file information
169#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
170pub struct SafeTensorsInfo {
171    /// Number of tensors
172    pub tensor_count: usize,
173    /// Tensor names and shapes
174    pub tensors: HashMap<String, TensorInfo>,
175    /// Metadata from header
176    pub metadata: HashMap<String, String>,
177    /// Estimated parameter count
178    pub parameters: Option<u64>,
179    /// Data type
180    pub dtype: Option<String>,
181}
182
183impl Default for SafeTensorsInfo {
184    fn default() -> Self {
185        Self {
186            tensor_count: 0,
187            tensors: HashMap::new(),
188            metadata: HashMap::new(),
189            parameters: None,
190            dtype: None,
191        }
192    }
193}
194
195/// Tensor information
196#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
197pub struct TensorInfo {
198    /// Tensor shape
199    pub shape: Vec<usize>,
200    /// Data type
201    pub dtype: String,
202    /// Offset in file
203    pub offset: usize,
204}
205
206/// APR (Aprender) file information
207#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
208pub struct AprInfo {
209    /// APR version
210    pub version: u32,
211    /// Model type (e.g., "LogisticRegression")
212    pub model_type: String,
213    /// Quantization type
214    pub quantization: Option<String>,
215    /// Compressed
216    pub compressed: bool,
217    /// Encrypted
218    pub encrypted: bool,
219    /// Signed
220    pub signed: bool,
221    /// Parameter count
222    pub parameters: Option<u64>,
223    /// CRC32 checksum
224    pub checksum: Option<u32>,
225}
226
227impl Default for AprInfo {
228    fn default() -> Self {
229        Self {
230            version: 0,
231            model_type: String::new(),
232            quantization: None,
233            compressed: false,
234            encrypted: false,
235            signed: false,
236            parameters: None,
237            checksum: None,
238        }
239    }
240}
241
242/// ONNX file information
243#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
244pub struct OnnxInfo {
245    /// ONNX IR version
246    pub ir_version: u64,
247    /// Producer name
248    pub producer_name: Option<String>,
249    /// Producer version
250    pub producer_version: Option<String>,
251    /// Model description
252    pub description: Option<String>,
253    /// Number of nodes
254    pub node_count: usize,
255    /// Estimated parameters
256    pub parameters: Option<u64>,
257}
258
259impl Default for OnnxInfo {
260    fn default() -> Self {
261        Self {
262            ir_version: 0,
263            producer_name: None,
264            producer_version: None,
265            description: None,
266            node_count: 0,
267            parameters: None,
268        }
269    }
270}
271
272// ============================================================================
273// FMT-003: Format Detection
274// ============================================================================
275
/// Leading-byte signatures consulted by `detect_format`.
mod magic {
    /// GGUF files open with the ASCII bytes `GGUF`.
    pub(super) const GGUF: [u8; 4] = *b"GGUF";
    /// SafeTensors files open with their JSON header size as a little-endian u64.
    pub(super) const _SAFETENSORS_MIN_HEADER: u64 = 8;
    /// APR files open with `APR` followed by a NUL byte.
    pub(super) const APR: [u8; 4] = *b"APR\0";
    /// ONNX (protobuf): tag byte 0x08 means field 1, varint wire type.
    /// Only the first byte is compared during detection.
    pub(super) const ONNX: [u8; 2] = [0x08, 0x00];
    /// Newer PyTorch checkpoints are zip archives, which open with `PK`.
    pub(super) const PYTORCH_ZIP: [u8; 2] = *b"PK";
    /// Older PyTorch checkpoints are raw pickles, which open with byte 0x80.
    pub(super) const PYTORCH_PICKLE: u8 = 0x80;
}
290
291/// Detect model format from bytes
292///
293/// # Arguments
294///
295/// * `data` - At least first 1KB of the file
296///
297/// # Returns
298///
299/// Detected `ModelFormat` with extracted metadata
300#[must_use]
301pub fn detect_format(data: &[u8]) -> ModelFormat {
302    if data.len() < 8 {
303        return ModelFormat::Unknown;
304    }
305
306    // Check GGUF magic
307    if data[..4] == magic::GGUF {
308        return parse_gguf_header(data);
309    }
310
311    // Check APR magic
312    if data[..4] == magic::APR {
313        return parse_apr_header(data);
314    }
315
316    // Try SafeTensors BEFORE PyTorch: SafeTensors has a more specific signature
317    // (u64 header size + valid JSON), while PyTorch pickle is just data[0]==0x80.
318    // SafeTensors files whose header_size has low byte 0x80 would otherwise be
319    // misidentified as PyTorch pickle. (Fixes pacha#4)
320    if let Some(info) = try_parse_safetensors(data) {
321        return ModelFormat::SafeTensors(info);
322    }
323
324    // Check PyTorch (zip or pickle) — AFTER SafeTensors to avoid false positives
325    if data[..2] == magic::PYTORCH_ZIP || data[0] == magic::PYTORCH_PICKLE {
326        return ModelFormat::PyTorch;
327    }
328
329    // Try ONNX (protobuf)
330    if data[0] == magic::ONNX[0] {
331        if let Some(info) = try_parse_onnx(data) {
332            return ModelFormat::Onnx(info);
333        }
334    }
335
336    ModelFormat::Unknown
337}
338
/// Lower-case file-name suffixes paired with the format label they map to.
///
/// Consulted in order; the first suffix that matches wins.
const FORMAT_EXTENSIONS: &[(&str, &str)] = &[
    (".gguf", "GGUF"),
    (".safetensors", "SafeTensors"),
    (".apr", "APR"),
    (".onnx", "ONNX"),
    (".pt", "PyTorch"),
    (".pth", "PyTorch"),
    (".bin", "Binary"),
];

/// Guess a format label from the extension of `path`, ignoring case.
///
/// Only the path text is inspected; nothing is read from disk. Returns
/// `None` when the path ends in none of the known suffixes.
#[must_use]
pub fn detect_format_from_path(path: &str) -> Option<&'static str> {
    let lowered = path.to_lowercase();
    for &(suffix, label) in FORMAT_EXTENSIONS {
        if lowered.ends_with(suffix) {
            return Some(label);
        }
    }
    None
}
356
357/// Parse GGUF header
358fn parse_gguf_header(data: &[u8]) -> ModelFormat {
359    if data.len() < 24 {
360        return ModelFormat::Gguf(GgufInfo::default());
361    }
362
363    // GGUF header format:
364    // 0-3: magic "GGUF"
365    // 4-7: version (u32 LE)
366    // 8-15: tensor_count (u64 LE)
367    // 16-23: metadata_kv_count (u64 LE)
368
369    let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
370    let tensor_count = u64::from_le_bytes([
371        data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
372    ]);
373    let metadata_count = u64::from_le_bytes([
374        data[16], data[17], data[18], data[19], data[20], data[21], data[22], data[23],
375    ]);
376
377    // For full metadata parsing, we'd need to parse the key-value pairs
378    // This is a simplified version that just extracts the header info
379
380    ModelFormat::Gguf(GgufInfo { version, tensor_count, metadata_count, ..Default::default() })
381}
382
383/// Parse APR header
384fn parse_apr_header(data: &[u8]) -> ModelFormat {
385    if data.len() < 16 {
386        return ModelFormat::Apr(AprInfo::default());
387    }
388
389    // APR header format:
390    // 0-3: magic "APR\0"
391    // 4-7: version (u32 LE)
392    // 8-11: flags (u32 LE)
393    // 12-15: model_type_len (u32 LE)
394
395    let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
396    let flags = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
397
398    let compressed = (flags & 0x01) != 0;
399    let encrypted = (flags & 0x02) != 0;
400    let signed = (flags & 0x04) != 0;
401
402    // Extract model type if we have enough data
403    let model_type_len = u32::from_le_bytes([data[12], data[13], data[14], data[15]]) as usize;
404    let model_type = if data.len() >= 16 + model_type_len {
405        String::from_utf8_lossy(&data[16..16 + model_type_len]).to_string()
406    } else {
407        String::new()
408    };
409
410    ModelFormat::Apr(AprInfo {
411        version,
412        model_type,
413        compressed,
414        encrypted,
415        signed,
416        ..Default::default()
417    })
418}
419
/// Return the JSON header of a SafeTensors file, or `None` if `data` does not
/// begin with a complete, plausible one.
///
/// A SafeTensors file starts with a little-endian `u64` header length `N`,
/// followed by `N` bytes of JSON that must begin with `{`. A declared length
/// of 0 or above 100_000_000 is rejected, and so is input too short to hold
/// the whole header.
fn read_safetensors_header(data: &[u8]) -> Option<&[u8]> {
    // Largest declared header size accepted; anything bigger is treated as
    // "not SafeTensors" rather than trusted.
    const MAX_HEADER_SIZE: u64 = 100_000_000;

    let prefix = data.get(..8)?;
    let header_size = u64::from_le_bytes([
        prefix[0], prefix[1], prefix[2], prefix[3], prefix[4], prefix[5], prefix[6], prefix[7],
    ]);

    // Range-check in u64 *before* narrowing. `u64 as usize` truncates on
    // 32-bit targets, which could turn a huge size into a small accepted one.
    if header_size == 0 || header_size > MAX_HEADER_SIZE {
        return None;
    }
    let header_size = header_size as usize; // lossless: <= MAX_HEADER_SIZE

    let header_json = data.get(8..8 + header_size)?;
    (header_json.first() == Some(&b'{')).then_some(header_json)
}
444
445/// Extract __metadata__ string values from a parsed SafeTensors header
446fn extract_metadata(header: &HashMap<String, serde_json::Value>, info: &mut SafeTensorsInfo) {
447    let Some(meta) = header.get("__metadata__") else { return };
448    let Some(obj) = meta.as_object() else { return };
449    for (k, v) in obj {
450        if let Some(s) = v.as_str() {
451            info.metadata.insert(k.clone(), s.to_string());
452        }
453    }
454}
455
456/// Try to parse SafeTensors header
457fn try_parse_safetensors(data: &[u8]) -> Option<SafeTensorsInfo> {
458    // Check for partial header (enough to identify format, not enough to parse)
459    if data.len() >= 8 {
460        let header_size = u64::from_le_bytes([
461            data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
462        ]) as usize;
463        if header_size > 0
464            && header_size <= 100_000_000
465            && data.len() < 8 + header_size
466            && data.get(8) == Some(&b'{')
467        {
468            return Some(SafeTensorsInfo { tensor_count: 0, ..Default::default() });
469        }
470    }
471
472    let header_json = read_safetensors_header(data)?;
473    let header: HashMap<String, serde_json::Value> = serde_json::from_slice(header_json).ok()?;
474
475    let mut info = SafeTensorsInfo::default();
476    info.tensor_count = header.keys().filter(|k| *k != "__metadata__").count();
477    extract_metadata(&header, &mut info);
478    info.parameters = Some(extract_tensor_info(&header, &mut info));
479    Some(info)
480}
481
482/// Extract tensor info from a SafeTensors JSON header
483fn extract_tensor_info(
484    header: &HashMap<String, serde_json::Value>,
485    info: &mut SafeTensorsInfo,
486) -> u64 {
487    let mut total_params: u64 = 0;
488    for (name, value) in header {
489        if name == "__metadata__" {
490            continue;
491        }
492        let Some(obj) = value.as_object() else {
493            continue;
494        };
495        let (Some(dtype), Some(shape)) = (obj.get("dtype"), obj.get("shape")) else {
496            continue;
497        };
498        let dtype_str = dtype.as_str().unwrap_or("F32").to_string();
499        let shape_vec: Vec<usize> = shape
500            .as_array()
501            .map(|arr| arr.iter().filter_map(|v| v.as_u64().map(|n| n as usize)).collect())
502            .unwrap_or_default();
503
504        let elements: u64 = shape_vec.iter().map(|&s| s as u64).product();
505        total_params += elements;
506
507        info.tensors.insert(
508            name.clone(),
509            TensorInfo { shape: shape_vec, dtype: dtype_str.clone(), offset: 0 },
510        );
511
512        if info.dtype.is_none() {
513            info.dtype = Some(dtype_str);
514        }
515    }
516    total_params
517}
518
519/// Try to parse ONNX header (simplified)
520fn try_parse_onnx(data: &[u8]) -> Option<OnnxInfo> {
521    // ONNX uses protobuf format
522    // This is a simplified detection that just checks for valid protobuf structure
523
524    if data.len() < 16 {
525        return None;
526    }
527
528    // Very basic protobuf field detection
529    // Field 1 (ir_version) should be present at the start
530    if data[0] != 0x08 {
531        return None;
532    }
533
534    // Read varint for ir_version
535    let (ir_version, _) = read_varint(&data[1..])?;
536
537    Some(OnnxInfo { ir_version, ..Default::default() })
538}
539
/// Decode one protobuf base-128 varint from the start of `data`.
///
/// Returns `(value, bytes_consumed)`. Returns `None` in three cases: `data`
/// ends before a terminating byte (high bit clear), the encoding exceeds the
/// 10-byte maximum, or the 10th byte carries bits beyond bit 63, meaning the
/// value does not fit in a `u64`.
fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    // 10 groups of 7 bits cover all 64 bits; a valid u64 varint is never longer.
    const MAX_LEN: usize = 10;

    let mut value: u64 = 0;
    for (index, &byte) in data.iter().take(MAX_LEN).enumerate() {
        let payload = u64::from(byte & 0x7F);
        // The 10th group starts at bit 63, so only its lowest bit fits.
        // Anything larger would be silently shifted out, so reject it as overflow.
        if index == MAX_LEN - 1 && payload > 1 {
            return None;
        }
        value |= payload << (7 * index);
        if byte & 0x80 == 0 {
            return Some((value, index + 1));
        }
    }
    None
}
560
561// ============================================================================
562// FMT-004: Quantization Types
563// ============================================================================
564
565/// Common quantization types (GGUF spec naming convention)
566#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
567#[allow(non_camel_case_types)]
568pub enum QuantType {
569    /// Full precision (FP32)
570    F32,
571    /// Half precision (FP16)
572    F16,
573    /// Brain floating point (BF16)
574    BF16,
575    /// 8-bit integer
576    Q8_0,
577    /// 8-bit with K-quants
578    Q8_K,
579    /// 6-bit K-quants
580    Q6_K,
581    /// 5-bit K-quants (small)
582    Q5_K_S,
583    /// 5-bit K-quants (medium)
584    Q5_K_M,
585    /// 5-bit (legacy)
586    Q5_0,
587    /// 5-bit with 1 (legacy)
588    Q5_1,
589    /// 4-bit K-quants (small)
590    Q4_K_S,
591    /// 4-bit K-quants (medium)
592    Q4_K_M,
593    /// 4-bit (legacy)
594    Q4_0,
595    /// 4-bit with 1 (legacy)
596    Q4_1,
597    /// 3-bit K-quants (small)
598    Q3_K_S,
599    /// 3-bit K-quants (medium)
600    Q3_K_M,
601    /// 3-bit K-quants (large)
602    Q3_K_L,
603    /// 2-bit K-quants (small)
604    Q2_K_S,
605    /// 2-bit K-quants
606    Q2_K,
607    /// Importance-weighted 4-bit (non-linear)
608    IQ4_NL,
609    /// Importance-weighted 4-bit (extra small)
610    IQ4_XS,
611    /// Importance-weighted 3-bit (small)
612    IQ3_S,
613    /// Importance-weighted 3-bit (medium)
614    IQ3_M,
615    /// Importance-weighted 3-bit (extra small)
616    IQ3_XS,
617    /// Importance-weighted 3-bit (extra extra small)
618    IQ3_XXS,
619    /// Importance-weighted 2-bit (small)
620    IQ2_S,
621    /// Importance-weighted 2-bit (extra small)
622    IQ2_XS,
623    /// Importance-weighted 2-bit (extra extra small)
624    IQ2_XXS,
625    /// Importance-weighted 1-bit (small)
626    IQ1_S,
627    /// Importance-weighted 1-bit (medium)
628    IQ1_M,
629}
630
631impl QuantType {
632    /// Parse quantization type from string
633    #[must_use]
634    pub fn from_str(s: &str) -> Option<Self> {
635        let s = s.to_uppercase();
636        match s.as_str() {
637            "F32" | "FP32" => Some(Self::F32),
638            "F16" | "FP16" => Some(Self::F16),
639            "BF16" => Some(Self::BF16),
640            "Q8_0" => Some(Self::Q8_0),
641            "Q8_K" => Some(Self::Q8_K),
642            "Q6_K" => Some(Self::Q6_K),
643            "Q5_K_S" => Some(Self::Q5_K_S),
644            "Q5_K_M" => Some(Self::Q5_K_M),
645            "Q5_0" => Some(Self::Q5_0),
646            "Q5_1" => Some(Self::Q5_1),
647            "Q4_K_S" => Some(Self::Q4_K_S),
648            "Q4_K_M" => Some(Self::Q4_K_M),
649            "Q4_0" => Some(Self::Q4_0),
650            "Q4_1" => Some(Self::Q4_1),
651            "Q3_K_S" => Some(Self::Q3_K_S),
652            "Q3_K_M" => Some(Self::Q3_K_M),
653            "Q3_K_L" => Some(Self::Q3_K_L),
654            "Q2_K_S" => Some(Self::Q2_K_S),
655            "Q2_K" => Some(Self::Q2_K),
656            "IQ4_NL" => Some(Self::IQ4_NL),
657            "IQ4_XS" => Some(Self::IQ4_XS),
658            "IQ3_S" => Some(Self::IQ3_S),
659            "IQ3_M" => Some(Self::IQ3_M),
660            "IQ3_XS" => Some(Self::IQ3_XS),
661            "IQ3_XXS" => Some(Self::IQ3_XXS),
662            "IQ2_S" => Some(Self::IQ2_S),
663            "IQ2_XS" => Some(Self::IQ2_XS),
664            "IQ2_XXS" => Some(Self::IQ2_XXS),
665            "IQ1_S" => Some(Self::IQ1_S),
666            "IQ1_M" => Some(Self::IQ1_M),
667            _ => None,
668        }
669    }
670
671    /// Get bits per weight
672    #[must_use]
673    pub const fn bits_per_weight(&self) -> f32 {
674        match self {
675            Self::F32 => 32.0,
676            Self::F16 | Self::BF16 => 16.0,
677            Self::Q8_0 | Self::Q8_K => 8.0,
678            Self::Q6_K => 6.5,
679            Self::Q5_K_S | Self::Q5_K_M | Self::Q5_0 | Self::Q5_1 => 5.5,
680            Self::Q4_K_S | Self::Q4_K_M | Self::Q4_0 | Self::Q4_1 => 4.5,
681            Self::Q3_K_S | Self::Q3_K_M | Self::Q3_K_L => 3.5,
682            Self::Q2_K_S | Self::Q2_K => 2.5,
683            Self::IQ4_NL | Self::IQ4_XS => 4.25,
684            Self::IQ3_S | Self::IQ3_M | Self::IQ3_XS | Self::IQ3_XXS => 3.0,
685            Self::IQ2_S | Self::IQ2_XS | Self::IQ2_XXS => 2.0,
686            Self::IQ1_S | Self::IQ1_M => 1.5,
687        }
688    }
689
690    /// Estimate file size for given parameter count
691    #[must_use]
692    pub fn estimate_size(&self, parameters: u64) -> u64 {
693        let bits = self.bits_per_weight() as f64;
694        let bytes = (parameters as f64 * bits) / 8.0;
695        // Add ~10% overhead for metadata
696        (bytes * 1.1) as u64
697    }
698
699    /// Get quality tier (1-5, higher is better quality)
700    #[must_use]
701    pub const fn quality_tier(&self) -> u8 {
702        match self {
703            Self::F32 | Self::F16 | Self::BF16 => 5,
704            Self::Q8_0 | Self::Q8_K => 5,
705            Self::Q6_K => 4,
706            Self::Q5_K_S | Self::Q5_K_M | Self::Q5_0 | Self::Q5_1 => 4,
707            Self::Q4_K_S | Self::Q4_K_M | Self::Q4_0 | Self::Q4_1 => 3,
708            Self::IQ4_NL | Self::IQ4_XS => 3,
709            Self::Q3_K_S | Self::Q3_K_M | Self::Q3_K_L => 2,
710            Self::IQ3_S | Self::IQ3_M | Self::IQ3_XS | Self::IQ3_XXS => 2,
711            Self::Q2_K_S | Self::Q2_K => 1,
712            Self::IQ2_S | Self::IQ2_XS | Self::IQ2_XXS => 1,
713            Self::IQ1_S | Self::IQ1_M => 1,
714        }
715    }
716
717    /// Get recommended VRAM in GB for given parameter count
718    #[must_use]
719    pub fn vram_requirement(&self, parameters: u64) -> f64 {
720        // Base model size
721        let model_size = self.estimate_size(parameters) as f64;
722        // Context cache (rough estimate: 2GB per 4K context)
723        let context_overhead = 2.0 * 1024.0 * 1024.0 * 1024.0;
724        // Total in GB
725        (model_size + context_overhead) / (1024.0 * 1024.0 * 1024.0)
726    }
727}
728
729impl std::fmt::Display for QuantType {
730    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
731        let s = match self {
732            Self::F32 => "F32",
733            Self::F16 => "F16",
734            Self::BF16 => "BF16",
735            Self::Q8_0 => "Q8_0",
736            Self::Q8_K => "Q8_K",
737            Self::Q6_K => "Q6_K",
738            Self::Q5_K_S => "Q5_K_S",
739            Self::Q5_K_M => "Q5_K_M",
740            Self::Q5_0 => "Q5_0",
741            Self::Q5_1 => "Q5_1",
742            Self::Q4_K_S => "Q4_K_S",
743            Self::Q4_K_M => "Q4_K_M",
744            Self::Q4_0 => "Q4_0",
745            Self::Q4_1 => "Q4_1",
746            Self::Q3_K_S => "Q3_K_S",
747            Self::Q3_K_M => "Q3_K_M",
748            Self::Q3_K_L => "Q3_K_L",
749            Self::Q2_K_S => "Q2_K_S",
750            Self::Q2_K => "Q2_K",
751            Self::IQ4_NL => "IQ4_NL",
752            Self::IQ4_XS => "IQ4_XS",
753            Self::IQ3_S => "IQ3_S",
754            Self::IQ3_M => "IQ3_M",
755            Self::IQ3_XS => "IQ3_XS",
756            Self::IQ3_XXS => "IQ3_XXS",
757            Self::IQ2_S => "IQ2_S",
758            Self::IQ2_XS => "IQ2_XS",
759            Self::IQ2_XXS => "IQ2_XXS",
760            Self::IQ1_S => "IQ1_S",
761            Self::IQ1_M => "IQ1_M",
762        };
763        write!(f, "{s}")
764    }
765}
766
767// ============================================================================
768// Tests
769// ============================================================================
770
771#[cfg(test)]
772mod tests {
773    use super::*;
774
775    // ========================================================================
776    // FMT-001: Format Detection Tests
777    // ========================================================================
778
779    #[test]
780    fn test_detect_gguf_format() {
781        // GGUF magic + version 3 + 100 tensors + 50 metadata
782        let mut data = vec![0u8; 100];
783        data[0..4].copy_from_slice(&magic::GGUF);
784        data[4..8].copy_from_slice(&3u32.to_le_bytes());
785        data[8..16].copy_from_slice(&100u64.to_le_bytes());
786        data[16..24].copy_from_slice(&50u64.to_le_bytes());
787
788        let format = detect_format(&data);
789        assert!(matches!(format, ModelFormat::Gguf(_)));
790
791        if let ModelFormat::Gguf(info) = format {
792            assert_eq!(info.version, 3);
793            assert_eq!(info.tensor_count, 100);
794            assert_eq!(info.metadata_count, 50);
795        }
796    }
797
798    #[test]
799    fn test_detect_apr_format() {
800        // APR magic + version 1 + flags (compressed + signed)
801        let mut data = vec![0u8; 100];
802        data[0..4].copy_from_slice(&magic::APR);
803        data[4..8].copy_from_slice(&1u32.to_le_bytes());
804        data[8..12].copy_from_slice(&0x05u32.to_le_bytes()); // compressed + signed
805        data[12..16].copy_from_slice(&4u32.to_le_bytes()); // model type len
806        data[16..20].copy_from_slice(b"Test");
807
808        let format = detect_format(&data);
809        assert!(matches!(format, ModelFormat::Apr(_)));
810
811        if let ModelFormat::Apr(info) = format {
812            assert_eq!(info.version, 1);
813            assert!(info.compressed);
814            assert!(!info.encrypted);
815            assert!(info.signed);
816            assert_eq!(info.model_type, "Test");
817        }
818    }
819
820    #[test]
821    fn test_detect_pytorch_zip_format() {
822        let mut data = vec![0u8; 100];
823        data[0..2].copy_from_slice(&magic::PYTORCH_ZIP);
824
825        let format = detect_format(&data);
826        assert!(matches!(format, ModelFormat::PyTorch));
827    }
828
829    #[test]
830    fn test_detect_pytorch_pickle_format() {
831        let mut data = vec![0u8; 100];
832        data[0] = magic::PYTORCH_PICKLE;
833
834        let format = detect_format(&data);
835        assert!(matches!(format, ModelFormat::PyTorch));
836    }
837
838    #[test]
839    fn test_detect_safetensors_format() {
840        // SafeTensors: header_size (u64 LE) + JSON header
841        let header = r#"{"tensor1":{"dtype":"F32","shape":[100,100]}}"#;
842        let header_bytes = header.as_bytes();
843        let header_size = header_bytes.len() as u64;
844
845        let mut data = Vec::new();
846        data.extend_from_slice(&header_size.to_le_bytes());
847        data.extend_from_slice(header_bytes);
848
849        let format = detect_format(&data);
850        assert!(matches!(format, ModelFormat::SafeTensors(_)));
851
852        if let ModelFormat::SafeTensors(info) = format {
853            assert_eq!(info.tensor_count, 1);
854            assert!(info.tensors.contains_key("tensor1"));
855            assert_eq!(info.parameters, Some(10000)); // 100 * 100
856        }
857    }
858
859    #[test]
860    fn test_detect_unknown_format() {
861        let data = vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
862        let format = detect_format(&data);
863        assert!(matches!(format, ModelFormat::Unknown));
864    }
865
866    #[test]
867    fn test_detect_empty_data() {
868        let data = vec![];
869        let format = detect_format(&data);
870        assert!(matches!(format, ModelFormat::Unknown));
871    }
872
873    #[test]
874    fn test_detect_short_data() {
875        let data = vec![0x47, 0x47, 0x55]; // Incomplete GGUF magic
876        let format = detect_format(&data);
877        assert!(matches!(format, ModelFormat::Unknown));
878    }
879
880    // ========================================================================
881    // FMT-002: Format Info Tests
882    // ========================================================================
883
884    #[test]
885    fn test_model_format_name() {
886        assert_eq!(ModelFormat::Gguf(GgufInfo::default()).name(), "GGUF");
887        assert_eq!(ModelFormat::SafeTensors(SafeTensorsInfo::default()).name(), "SafeTensors");
888        assert_eq!(ModelFormat::Apr(AprInfo::default()).name(), "APR");
889        assert_eq!(ModelFormat::Onnx(OnnxInfo::default()).name(), "ONNX");
890        assert_eq!(ModelFormat::PyTorch.name(), "PyTorch");
891        assert_eq!(ModelFormat::Unknown.name(), "Unknown");
892    }
893
894    #[test]
895    fn test_model_format_extension() {
896        assert_eq!(ModelFormat::Gguf(GgufInfo::default()).extension(), ".gguf");
897        assert_eq!(
898            ModelFormat::SafeTensors(SafeTensorsInfo::default()).extension(),
899            ".safetensors"
900        );
901        assert_eq!(ModelFormat::Apr(AprInfo::default()).extension(), ".apr");
902    }
903
904    #[test]
905    fn test_model_format_is_quantized() {
906        let gguf_quant = ModelFormat::Gguf(GgufInfo {
907            quantization: Some("Q4_K_M".to_string()),
908            ..Default::default()
909        });
910        assert!(gguf_quant.is_quantized());
911
912        let gguf_no_quant = ModelFormat::Gguf(GgufInfo::default());
913        assert!(!gguf_no_quant.is_quantized());
914    }
915
916    #[test]
917    fn test_model_format_quantization() {
918        let format = ModelFormat::Gguf(GgufInfo {
919            quantization: Some("Q8_0".to_string()),
920            ..Default::default()
921        });
922        assert_eq!(format.quantization(), Some("Q8_0"));
923    }
924
925    #[test]
926    fn test_model_format_parameters() {
927        let format =
928            ModelFormat::Gguf(GgufInfo { parameters: Some(7_000_000_000), ..Default::default() });
929        assert_eq!(format.parameters(), Some(7_000_000_000));
930
931        assert_eq!(ModelFormat::PyTorch.parameters(), None);
932    }
933
934    // ========================================================================
935    // FMT-003: Path Detection Tests
936    // ========================================================================
937
938    #[test]
939    fn test_detect_format_from_path() {
940        assert_eq!(detect_format_from_path("model.gguf"), Some("GGUF"));
941        assert_eq!(detect_format_from_path("model.GGUF"), Some("GGUF"));
942        assert_eq!(detect_format_from_path("model.safetensors"), Some("SafeTensors"));
943        assert_eq!(detect_format_from_path("model.apr"), Some("APR"));
944        assert_eq!(detect_format_from_path("model.onnx"), Some("ONNX"));
945        assert_eq!(detect_format_from_path("model.pt"), Some("PyTorch"));
946        assert_eq!(detect_format_from_path("model.pth"), Some("PyTorch"));
947        assert_eq!(detect_format_from_path("model.bin"), Some("Binary"));
948        assert_eq!(detect_format_from_path("model.txt"), None);
949    }
950
951    // ========================================================================
952    // FMT-004: Quantization Tests
953    // ========================================================================
954
955    #[test]
956    fn test_quant_type_from_str() {
957        assert_eq!(QuantType::from_str("Q4_K_M"), Some(QuantType::Q4_K_M));
958        assert_eq!(QuantType::from_str("q4_k_m"), Some(QuantType::Q4_K_M));
959        assert_eq!(QuantType::from_str("F16"), Some(QuantType::F16));
960        assert_eq!(QuantType::from_str("fp16"), Some(QuantType::F16));
961        assert_eq!(QuantType::from_str("invalid"), None);
962    }
963
964    #[test]
965    fn test_quant_type_bits_per_weight() {
966        assert!((QuantType::F32.bits_per_weight() - 32.0).abs() < f32::EPSILON);
967        assert!((QuantType::F16.bits_per_weight() - 16.0).abs() < f32::EPSILON);
968        assert!((QuantType::Q8_0.bits_per_weight() - 8.0).abs() < f32::EPSILON);
969        assert!((QuantType::Q4_K_M.bits_per_weight() - 4.5).abs() < f32::EPSILON);
970    }
971
972    #[test]
973    fn test_quant_type_estimate_size() {
974        let params = 7_000_000_000u64; // 7B parameters
975
976        // F32: 7B * 32 bits / 8 * 1.1 = ~30.8 GB
977        let f32_size = QuantType::F32.estimate_size(params);
978        assert!(f32_size > 28_000_000_000 && f32_size < 32_000_000_000);
979
980        // Q4_K_M: 7B * 4.5 bits / 8 * 1.1 = ~4.3 GB
981        let q4_size = QuantType::Q4_K_M.estimate_size(params);
982        assert!(q4_size > 4_000_000_000 && q4_size < 5_000_000_000);
983    }
984
985    #[test]
986    fn test_quant_type_quality_tier() {
987        assert_eq!(QuantType::F32.quality_tier(), 5);
988        assert_eq!(QuantType::Q8_0.quality_tier(), 5);
989        assert_eq!(QuantType::Q4_K_M.quality_tier(), 3);
990        assert_eq!(QuantType::Q2_K.quality_tier(), 1);
991    }
992
993    #[test]
994    fn test_quant_type_vram_requirement() {
995        let params = 7_000_000_000u64;
996        let vram_f32 = QuantType::F32.vram_requirement(params);
997        let vram_q4 = QuantType::Q4_K_M.vram_requirement(params);
998
999        // F32 should require more VRAM than Q4
1000        assert!(vram_f32 > vram_q4);
1001        // Both should be positive
1002        assert!(vram_f32 > 0.0);
1003        assert!(vram_q4 > 0.0);
1004    }
1005
1006    #[test]
1007    fn test_quant_type_display() {
1008        assert_eq!(format!("{}", QuantType::Q4_K_M), "Q4_K_M");
1009        assert_eq!(format!("{}", QuantType::F16), "F16");
1010        assert_eq!(format!("{}", QuantType::IQ3_XXS), "IQ3_XXS");
1011    }
1012
1013    // ========================================================================
1014    // Serialization Tests
1015    // ========================================================================
1016
1017    #[test]
1018    fn test_gguf_info_serialization() {
1019        let info = GgufInfo {
1020            version: 3,
1021            tensor_count: 100,
1022            metadata_count: 50,
1023            architecture: Some("llama".to_string()),
1024            quantization: Some("Q4_K_M".to_string()),
1025            ..Default::default()
1026        };
1027
1028        let json = serde_json::to_string(&info).unwrap();
1029        assert!(json.contains("llama"));
1030        assert!(json.contains("Q4_K_M"));
1031
1032        let parsed: GgufInfo = serde_json::from_str(&json).unwrap();
1033        assert_eq!(parsed.version, 3);
1034        assert_eq!(parsed.architecture, Some("llama".to_string()));
1035    }
1036
1037    #[test]
1038    fn test_safetensors_info_serialization() {
1039        let mut tensors = HashMap::new();
1040        tensors.insert(
1041            "weight".to_string(),
1042            TensorInfo { shape: vec![100, 100], dtype: "F32".to_string(), offset: 0 },
1043        );
1044
1045        let info = SafeTensorsInfo {
1046            tensor_count: 1,
1047            tensors,
1048            parameters: Some(10000),
1049            ..Default::default()
1050        };
1051
1052        let json = serde_json::to_string(&info).unwrap();
1053        let parsed: SafeTensorsInfo = serde_json::from_str(&json).unwrap();
1054
1055        assert_eq!(parsed.tensor_count, 1);
1056        assert_eq!(parsed.parameters, Some(10000));
1057    }
1058
1059    #[test]
1060    fn test_model_format_serialization() {
1061        let format = ModelFormat::Gguf(GgufInfo { version: 3, ..Default::default() });
1062
1063        let json = serde_json::to_string(&format).unwrap();
1064        let parsed: ModelFormat = serde_json::from_str(&json).unwrap();
1065
1066        assert!(matches!(parsed, ModelFormat::Gguf(_)));
1067    }
1068
1069    #[test]
1070    fn test_quant_type_serialization() {
1071        let qt = QuantType::Q4_K_M;
1072        let json = serde_json::to_string(&qt).unwrap();
1073        let parsed: QuantType = serde_json::from_str(&json).unwrap();
1074        assert_eq!(parsed, QuantType::Q4_K_M);
1075    }
1076
1077    // ========================================================================
1078    // Edge Cases
1079    // ========================================================================
1080
1081    #[test]
1082    fn test_safetensors_with_metadata() {
1083        let header = r#"{"__metadata__":{"format":"pt"},"tensor1":{"dtype":"F16","shape":[512]}}"#;
1084        let header_bytes = header.as_bytes();
1085        let header_size = header_bytes.len() as u64;
1086
1087        let mut data = Vec::new();
1088        data.extend_from_slice(&header_size.to_le_bytes());
1089        data.extend_from_slice(header_bytes);
1090
1091        let format = detect_format(&data);
1092        if let ModelFormat::SafeTensors(info) = format {
1093            assert_eq!(info.tensor_count, 1); // Excludes __metadata__
1094            assert_eq!(info.metadata.get("format"), Some(&"pt".to_string()));
1095            assert_eq!(info.dtype, Some("F16".to_string()));
1096        } else {
1097            panic!("Expected SafeTensors format");
1098        }
1099    }
1100
1101    #[test]
1102    fn test_pacha4_safetensors_header_size_0x80_not_pytorch() {
1103        // Regression test for pacha#4: SafeTensors file whose header_size
1104        // has low byte 0x80 was misidentified as PyTorch pickle because
1105        // detect_format checked data[0]==0x80 before trying SafeTensors.
1106        //
1107        // Real-world case: Qwen2.5-Coder-1.5B-Instruct model.safetensors
1108        // has header_size=38528 → first byte is 0x80.
1109        let header = r#"{"__metadata__":{"format":"pt"},"tensor1":{"dtype":"F32","shape":[32],"data_offsets":[0,128]}}"#;
1110        let header_bytes = header.as_bytes();
1111        // Force header_size to have low byte 0x80 (= 128)
1112        // We need header_size = header_bytes.len(), and pad to make it end in 0x80
1113        let target_size = 128usize; // 0x80
1114        assert!(header_bytes.len() <= target_size, "header too large for test");
1115        let padding = target_size - header_bytes.len();
1116
1117        let mut data = Vec::new();
1118        data.extend_from_slice(&(target_size as u64).to_le_bytes());
1119        data.extend_from_slice(header_bytes);
1120        // Pad JSON with spaces before closing brace
1121        // Actually, we need to pad the header itself. Let's build a padded header.
1122        let padded_header = format!(
1123            r#"{{"__metadata__":{{"format":"pt"}},"tensor1":{{"dtype":"F32","shape":[32],"data_offsets":[0,128]}}{}}}"#,
1124            " ".repeat(padding)
1125        );
1126        let padded_bytes = padded_header.as_bytes();
1127
1128        let mut data2 = Vec::new();
1129        data2.extend_from_slice(&(padded_bytes.len() as u64).to_le_bytes());
1130        data2.extend_from_slice(padded_bytes);
1131        // Add some fake tensor data
1132        data2.extend_from_slice(&[0u8; 128]);
1133
1134        // First byte should be 0x80 (the pickle magic)
1135        assert_eq!(data2[0], 0x80, "Test setup: first byte must be 0x80");
1136
1137        let format = detect_format(&data2);
1138        match format {
1139            ModelFormat::SafeTensors(info) => {
1140                assert_eq!(info.tensor_count, 1);
1141                assert_eq!(info.metadata.get("format"), Some(&"pt".to_string()));
1142            }
1143            other => panic!("Expected SafeTensors but got {:?} — pacha#4 regression", other),
1144        }
1145    }
1146
1147    #[test]
1148    fn test_safetensors_invalid_header_size() {
1149        // Header size larger than file
1150        let header_size = 1_000_000u64;
1151        let mut data = Vec::new();
1152        data.extend_from_slice(&header_size.to_le_bytes());
1153        data.extend_from_slice(b"{}");
1154
1155        let format = detect_format(&data);
1156        // Should still identify as SafeTensors due to JSON structure
1157        assert!(matches!(format, ModelFormat::SafeTensors(_)));
1158    }
1159
1160    #[test]
1161    fn test_gguf_info_default() {
1162        let info = GgufInfo::default();
1163        assert_eq!(info.version, 0);
1164        assert!(info.architecture.is_none());
1165    }
1166
1167    #[test]
1168    fn test_apr_info_default() {
1169        let info = AprInfo::default();
1170        assert_eq!(info.version, 0);
1171        assert!(!info.compressed);
1172        assert!(!info.encrypted);
1173    }
1174}