Skip to main content

entrenar/hf_pipeline/fetcher/
types.rs

1//! Type definitions for HuggingFace model fetching.
2//!
3//! Contains core enums and structs for model weights, architectures, and artifacts.
4
5use serde::{Deserialize, Serialize};
6use std::path::PathBuf;
7
/// On-disk encoding of a model's weight files.
///
/// [`WeightFormat::is_safe`] reports whether a format can be loaded without
/// risking arbitrary code execution.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WeightFormat {
    /// SafeTensors; the recommended format and considered safe to load.
    SafeTensors,
    /// GGUF container holding (typically quantized) weights.
    GGUF {
        /// Quantization scheme label; `"unknown"` when it was not determined.
        quant_type: String,
    },
    /// PyTorch pickle checkpoint. Unpickling can execute code: SECURITY RISK.
    PyTorchBin,
    /// ONNX model file.
    ONNX,
}
20
21impl WeightFormat {
22    /// Detect format from filename
23    #[must_use]
24    pub fn from_filename(filename: &str) -> Option<Self> {
25        if filename.ends_with(".safetensors") {
26            Some(Self::SafeTensors)
27        } else if filename.ends_with(".gguf") {
28            Some(Self::GGUF { quant_type: "unknown".into() })
29        } else if filename.ends_with(".bin") {
30            Some(Self::PyTorchBin)
31        } else if filename.ends_with(".onnx") {
32            Some(Self::ONNX)
33        } else {
34            None
35        }
36    }
37
38    /// Check if format is safe (no arbitrary code execution)
39    #[must_use]
40    pub fn is_safe(&self) -> bool {
41        matches!(self, Self::SafeTensors | Self::GGUF { .. } | Self::ONNX)
42    }
43}
44
45/// Model architecture information
46// CB-519: Serialize + Deserialize derive is intentional for config round-trip.
47// PartialEq enables exact structural validation (not just param_count) after deserialization.
48#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
49pub enum Architecture {
50    /// BERT-style encoder
51    BERT { num_layers: usize, hidden_size: usize, num_attention_heads: usize },
52    /// GPT-style decoder
53    GPT2 { num_layers: usize, hidden_size: usize, num_attention_heads: usize },
54    /// Llama architecture
55    Llama {
56        num_layers: usize,
57        hidden_size: usize,
58        num_attention_heads: usize,
59        intermediate_size: usize,
60    },
61    /// T5 encoder-decoder
62    T5 { encoder_layers: usize, decoder_layers: usize, hidden_size: usize },
63    /// Custom/unknown architecture
64    Custom { config: serde_json::Value },
65}
66
67impl Architecture {
68    /// Estimate parameter count
69    #[must_use]
70    pub fn param_count(&self) -> u64 {
71        match self {
72            Self::BERT { num_layers, hidden_size, num_attention_heads: _ } => {
73                // Rough estimate: 4 * hidden^2 per layer (Q, K, V, O projections + FFN)
74                let per_layer = 4 * (*hidden_size as u64).pow(2) + 4 * (*hidden_size as u64).pow(2);
75                per_layer * (*num_layers as u64)
76            }
77            Self::GPT2 { num_layers, hidden_size, .. } => {
78                let per_layer = 4 * (*hidden_size as u64).pow(2) + 4 * (*hidden_size as u64).pow(2);
79                per_layer * (*num_layers as u64)
80            }
81            Self::Llama { num_layers, hidden_size, intermediate_size, .. } => {
82                let attn = 4 * (*hidden_size as u64).pow(2);
83                let ffn = 2 * (*hidden_size as u64) * (*intermediate_size as u64);
84                (attn + ffn) * (*num_layers as u64)
85            }
86            Self::T5 { encoder_layers, decoder_layers, hidden_size } => {
87                let per_layer = 8 * (*hidden_size as u64).pow(2);
88                per_layer * ((*encoder_layers + *decoder_layers) as u64)
89            }
90            Self::Custom { .. } => 0, // Unknown
91        }
92    }
93}
94
95/// Downloaded model artifact
96#[derive(Debug)]
97pub struct ModelArtifact {
98    /// Local path to downloaded files
99    pub path: PathBuf,
100    /// Detected weight format
101    pub format: WeightFormat,
102    /// Model architecture (parsed from config.json)
103    pub architecture: Option<Architecture>,
104    /// SHA256 hash of model file
105    pub sha256: Option<String>,
106}