// peat_protocol/distribution/types.rs

//! Core types for AI model distribution

use serde::{Deserialize, Serialize};

/// Type of AI model
///
/// Serialized via serde in `snake_case` (e.g. `Vlm` -> `"vlm"`), so the
/// variant names are part of the wire/config format — renaming a variant is
/// a breaking change for serialized data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelType {
    /// Object detection model (YOLO, etc.)
    Detector,
    /// Large language model (Ministral, Llama, Phi, etc.)
    Llm,
    /// Object tracking / re-identification model
    Tracker,
    /// Feature embedding model
    Embedder,
    /// Vision-language model
    Vlm,
    /// Audio transcription model (Whisper, etc.)
    Whisper,
    /// Custom/other model type
    Custom,
}

25impl ModelType {
26    /// Get a human-readable name
27    pub fn display_name(&self) -> &'static str {
28        match self {
29            Self::Detector => "Object Detector",
30            Self::Llm => "Language Model",
31            Self::Tracker => "Tracker",
32            Self::Embedder => "Embedder",
33            Self::Vlm => "Vision-Language Model",
34            Self::Whisper => "Audio Transcription",
35            Self::Custom => "Custom Model",
36        }
37    }
38
39    /// Get capability type string (for CapabilityInfo compatibility)
40    pub fn capability_type(&self) -> &'static str {
41        match self {
42            Self::Detector => "OBJECT_DETECTION",
43            Self::Llm => "LLM_INFERENCE",
44            Self::Tracker => "OBJECT_TRACKING",
45            Self::Embedder => "FEATURE_EMBEDDING",
46            Self::Vlm => "VISION_LANGUAGE",
47            Self::Whisper => "AUDIO_TRANSCRIPTION",
48            Self::Custom => "CUSTOM",
49        }
50    }
51}
52
/// Model file format
///
/// Serialized via serde in lowercase (e.g. `TensorRT` -> `"tensorrt"`),
/// so the variant names are part of the serialized format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelFormat {
    /// ONNX format (cross-platform)
    Onnx,
    /// GGUF format (llama.cpp quantized models)
    Gguf,
    /// TensorRT engine (NVIDIA optimized)
    TensorRT,
    /// PyTorch format
    PyTorch,
    /// SafeTensors format
    SafeTensors,
}

69impl ModelFormat {
70    /// File extension for this format
71    pub fn extension(&self) -> &'static str {
72        match self {
73            Self::Onnx => "onnx",
74            Self::Gguf => "gguf",
75            Self::TensorRT => "engine",
76            Self::PyTorch => "pt",
77            Self::SafeTensors => "safetensors",
78        }
79    }
80
81    /// MIME type for this format
82    pub fn mime_type(&self) -> &'static str {
83        match self {
84            Self::Onnx => "application/onnx",
85            Self::Gguf => "application/octet-stream",
86            Self::TensorRT => "application/octet-stream",
87            Self::PyTorch => "application/octet-stream",
88            Self::SafeTensors => "application/octet-stream",
89        }
90    }
91}
92
/// Quantization level for model weights
///
/// Variant names follow the conventional GGUF/llama.cpp quantization tags
/// (`Q4_K_M`, `Q8_0`, ...), hence the `non_camel_case_types` allowance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[allow(non_camel_case_types)] // Keep standard quantization naming
pub enum Quantization {
    /// Full precision (FP32)
    F32,
    /// Half precision (FP16)
    F16,
    /// Brain float 16
    BF16,
    /// 8-bit integer
    INT8,
    /// 4-bit (Q4_0 - legacy)
    Q4_0,
    /// 4-bit K-quant small
    Q4_K_S,
    /// 4-bit K-quant medium (good balance of size/quality)
    Q4_K_M,
    /// 5-bit K-quant small
    Q5_K_S,
    /// 5-bit K-quant medium
    Q5_K_M,
    /// 6-bit K-quant
    Q6_K,
    /// 8-bit (Q8_0)
    Q8_0,
}

121impl Quantization {
122    /// Get display string
123    pub fn as_str(&self) -> &'static str {
124        match self {
125            Self::F32 => "F32",
126            Self::F16 => "F16",
127            Self::BF16 => "BF16",
128            Self::INT8 => "INT8",
129            Self::Q4_0 => "Q4_0",
130            Self::Q4_K_S => "Q4_K_S",
131            Self::Q4_K_M => "Q4_K_M",
132            Self::Q5_K_S => "Q5_K_S",
133            Self::Q5_K_M => "Q5_K_M",
134            Self::Q6_K => "Q6_K",
135            Self::Q8_0 => "Q8_0",
136        }
137    }
138
139    /// Approximate memory multiplier vs FP16 (lower = smaller)
140    pub fn memory_factor(&self) -> f32 {
141        match self {
142            Self::F32 => 2.0,
143            Self::F16 | Self::BF16 => 1.0,
144            Self::INT8 | Self::Q8_0 => 0.5,
145            Self::Q6_K => 0.41,
146            Self::Q5_K_S | Self::Q5_K_M => 0.35,
147            Self::Q4_0 | Self::Q4_K_S | Self::Q4_K_M => 0.28,
148        }
149    }
150
151    /// Parse from filename component (e.g., "q4_k_m" -> Q4_K_M)
152    pub fn from_filename(s: &str) -> Option<Self> {
153        let lower = s.to_lowercase();
154        match lower.as_str() {
155            "f32" | "fp32" => Some(Self::F32),
156            "f16" | "fp16" => Some(Self::F16),
157            "bf16" => Some(Self::BF16),
158            "int8" => Some(Self::INT8),
159            "q4_0" => Some(Self::Q4_0),
160            "q4_k_s" => Some(Self::Q4_K_S),
161            "q4_k_m" => Some(Self::Q4_K_M),
162            "q5_k_s" => Some(Self::Q5_K_S),
163            "q5_k_m" => Some(Self::Q5_K_M),
164            "q6_k" => Some(Self::Q6_K),
165            "q8_0" => Some(Self::Q8_0),
166            _ => None,
167        }
168    }
169}
170
171impl std::fmt::Display for Quantization {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        write!(f, "{}", self.as_str())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_type_capability() {
        // Capability strings must stay stable for CapabilityInfo consumers.
        for (ty, expected) in [
            (ModelType::Detector, "OBJECT_DETECTION"),
            (ModelType::Llm, "LLM_INFERENCE"),
        ] {
            assert_eq!(ty.capability_type(), expected);
        }
    }

    #[test]
    fn test_model_format_extension() {
        for (fmt, ext) in [
            (ModelFormat::Onnx, "onnx"),
            (ModelFormat::Gguf, "gguf"),
            (ModelFormat::TensorRT, "engine"),
        ] {
            assert_eq!(fmt.extension(), ext);
        }
    }

    #[test]
    fn test_quantization_memory_factor() {
        // Heavier quantization must always report a smaller footprint.
        let f16 = Quantization::F16.memory_factor();
        assert!(Quantization::Q4_K_M.memory_factor() < f16);
        assert!(Quantization::Q8_0.memory_factor() < f16);
        assert!(Quantization::Q4_K_M.memory_factor() < Quantization::Q8_0.memory_factor());
    }

    #[test]
    fn test_quantization_from_filename() {
        // Both lower- and upper-case tags parse; unknown tags do not.
        let expected = Some(Quantization::Q4_K_M);
        assert_eq!(Quantization::from_filename("q4_k_m"), expected);
        assert_eq!(Quantization::from_filename("Q4_K_M"), expected);
        assert_eq!(Quantization::from_filename("fp16"), Some(Quantization::F16));
        assert_eq!(Quantization::from_filename("unknown"), None);
    }
}