// peat_protocol/distribution/types.rs
use serde::{Deserialize, Serialize};

/// Category of model handled by the distribution layer.
///
/// Serialized in `snake_case` (e.g. `"detector"`, `"llm"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelType {
    /// Object detector (`OBJECT_DETECTION` capability).
    Detector,
    /// Language model (`LLM_INFERENCE` capability).
    Llm,
    /// Object tracker (`OBJECT_TRACKING` capability).
    Tracker,
    /// Feature embedder (`FEATURE_EMBEDDING` capability).
    Embedder,
    /// Vision-language model (`VISION_LANGUAGE` capability).
    Vlm,
    /// Audio transcription model (`AUDIO_TRANSCRIPTION` capability).
    Whisper,
    /// Catch-all for models outside the categories above.
    Custom,
}

25impl ModelType {
26 pub fn display_name(&self) -> &'static str {
28 match self {
29 Self::Detector => "Object Detector",
30 Self::Llm => "Language Model",
31 Self::Tracker => "Tracker",
32 Self::Embedder => "Embedder",
33 Self::Vlm => "Vision-Language Model",
34 Self::Whisper => "Audio Transcription",
35 Self::Custom => "Custom Model",
36 }
37 }
38
39 pub fn capability_type(&self) -> &'static str {
41 match self {
42 Self::Detector => "OBJECT_DETECTION",
43 Self::Llm => "LLM_INFERENCE",
44 Self::Tracker => "OBJECT_TRACKING",
45 Self::Embedder => "FEATURE_EMBEDDING",
46 Self::Vlm => "VISION_LANGUAGE",
47 Self::Whisper => "AUDIO_TRANSCRIPTION",
48 Self::Custom => "CUSTOM",
49 }
50 }
51}
52
/// On-disk serialization format of a model artifact.
///
/// Serialized in lowercase (e.g. `"onnx"`, `"tensorrt"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelFormat {
    /// ONNX graph (`.onnx`).
    Onnx,
    /// GGUF, llama.cpp-style quantized weights (`.gguf`).
    Gguf,
    /// TensorRT compiled engine (`.engine`).
    TensorRT,
    /// PyTorch checkpoint (`.pt`).
    PyTorch,
    /// SafeTensors weights (`.safetensors`).
    SafeTensors,
}

69impl ModelFormat {
70 pub fn extension(&self) -> &'static str {
72 match self {
73 Self::Onnx => "onnx",
74 Self::Gguf => "gguf",
75 Self::TensorRT => "engine",
76 Self::PyTorch => "pt",
77 Self::SafeTensors => "safetensors",
78 }
79 }
80
81 pub fn mime_type(&self) -> &'static str {
83 match self {
84 Self::Onnx => "application/onnx",
85 Self::Gguf => "application/octet-stream",
86 Self::TensorRT => "application/octet-stream",
87 Self::PyTorch => "application/octet-stream",
88 Self::SafeTensors => "application/octet-stream",
89 }
90 }
91}
92
/// Weight quantization scheme of a model artifact.
// Variant names mirror the conventional GGUF quant labels (Q4_K_M, ...),
// hence the snake-ish casing.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Quantization {
    /// 32-bit float (unquantized).
    F32,
    /// 16-bit float.
    F16,
    /// bfloat16.
    BF16,
    /// 8-bit integer quantization.
    INT8,
    /// GGUF 4-bit, legacy layout.
    Q4_0,
    /// GGUF 4-bit K-quant, small.
    Q4_K_S,
    /// GGUF 4-bit K-quant, medium.
    Q4_K_M,
    /// GGUF 5-bit K-quant, small.
    Q5_K_S,
    /// GGUF 5-bit K-quant, medium.
    Q5_K_M,
    /// GGUF 6-bit K-quant.
    Q6_K,
    /// GGUF 8-bit quantization.
    Q8_0,
}

121impl Quantization {
122 pub fn as_str(&self) -> &'static str {
124 match self {
125 Self::F32 => "F32",
126 Self::F16 => "F16",
127 Self::BF16 => "BF16",
128 Self::INT8 => "INT8",
129 Self::Q4_0 => "Q4_0",
130 Self::Q4_K_S => "Q4_K_S",
131 Self::Q4_K_M => "Q4_K_M",
132 Self::Q5_K_S => "Q5_K_S",
133 Self::Q5_K_M => "Q5_K_M",
134 Self::Q6_K => "Q6_K",
135 Self::Q8_0 => "Q8_0",
136 }
137 }
138
139 pub fn memory_factor(&self) -> f32 {
141 match self {
142 Self::F32 => 2.0,
143 Self::F16 | Self::BF16 => 1.0,
144 Self::INT8 | Self::Q8_0 => 0.5,
145 Self::Q6_K => 0.41,
146 Self::Q5_K_S | Self::Q5_K_M => 0.35,
147 Self::Q4_0 | Self::Q4_K_S | Self::Q4_K_M => 0.28,
148 }
149 }
150
151 pub fn from_filename(s: &str) -> Option<Self> {
153 let lower = s.to_lowercase();
154 match lower.as_str() {
155 "f32" | "fp32" => Some(Self::F32),
156 "f16" | "fp16" => Some(Self::F16),
157 "bf16" => Some(Self::BF16),
158 "int8" => Some(Self::INT8),
159 "q4_0" => Some(Self::Q4_0),
160 "q4_k_s" => Some(Self::Q4_K_S),
161 "q4_k_m" => Some(Self::Q4_K_M),
162 "q5_k_s" => Some(Self::Q5_K_S),
163 "q5_k_m" => Some(Self::Q5_K_M),
164 "q6_k" => Some(Self::Q6_K),
165 "q8_0" => Some(Self::Q8_0),
166 _ => None,
167 }
168 }
169}
170
171impl std::fmt::Display for Quantization {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 write!(f, "{}", self.as_str())
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_type_capability() {
        assert_eq!(ModelType::Detector.capability_type(), "OBJECT_DETECTION");
        assert_eq!(ModelType::Llm.capability_type(), "LLM_INFERENCE");
    }

    #[test]
    fn test_model_format_extension() {
        // Table-driven: each format paired with its expected extension.
        let cases = [
            (ModelFormat::Onnx, "onnx"),
            (ModelFormat::Gguf, "gguf"),
            (ModelFormat::TensorRT, "engine"),
        ];
        for (format, expected) in cases {
            assert_eq!(format.extension(), expected);
        }
    }

    #[test]
    fn test_quantization_memory_factor() {
        // Smaller quants must report strictly smaller footprints.
        let f16 = Quantization::F16.memory_factor();
        let q8 = Quantization::Q8_0.memory_factor();
        let q4 = Quantization::Q4_K_M.memory_factor();
        assert!(q4 < f16);
        assert!(q8 < f16);
        assert!(q4 < q8);
    }

    #[test]
    fn test_quantization_from_filename() {
        let cases = [
            ("q4_k_m", Some(Quantization::Q4_K_M)),
            ("Q4_K_M", Some(Quantization::Q4_K_M)),
            ("fp16", Some(Quantization::F16)),
            ("unknown", None),
        ];
        for (input, expected) in cases {
            assert_eq!(Quantization::from_filename(input), expected);
        }
    }
}