1use std::path::PathBuf;
4
5use serde::{Deserialize, Serialize};
6
7use crate::types::{ModelId, QuantizationType};
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11#[serde(tag = "type", rename_all = "snake_case")]
12pub enum ModelSource {
13 HuggingFace {
15 repo_id: String,
17 revision: Option<String>,
19 },
20 LocalPath {
22 path: PathBuf,
24 },
25 S3 {
27 bucket: String,
29 key: String,
31 region: Option<String>,
33 },
34 Gguf {
36 path: PathBuf,
38 },
39}
40
41impl ModelSource {
42 #[must_use]
44 pub fn huggingface(repo_id: impl Into<String>) -> Self {
45 Self::HuggingFace {
46 repo_id: repo_id.into(),
47 revision: None,
48 }
49 }
50
51 #[must_use]
53 pub fn huggingface_rev(repo_id: impl Into<String>, revision: impl Into<String>) -> Self {
54 Self::HuggingFace {
55 repo_id: repo_id.into(),
56 revision: Some(revision.into()),
57 }
58 }
59
60 #[must_use]
62 pub fn local(path: impl Into<PathBuf>) -> Self {
63 Self::LocalPath { path: path.into() }
64 }
65
66 #[must_use]
68 pub fn gguf(path: impl Into<PathBuf>) -> Self {
69 Self::Gguf { path: path.into() }
70 }
71}
72
73#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
75pub enum LlamaVersion {
76 V2,
78 V3,
80 V3_1,
82 V3_2,
84}
85
86#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
88pub enum MistralVariant {
89 Mistral7B,
91 Nemo,
93 Large,
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
99pub enum QwenVersion {
100 V2,
102 V2_5,
104}
105
106#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
108pub enum PhiVersion {
109 V3,
111 V3_5,
113 V4,
115}
116
117#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
119pub enum GemmaVersion {
120 V1,
122 V2,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
128#[serde(tag = "type", rename_all = "snake_case")]
129pub enum ModelArchitecture {
130 Llama {
133 version: LlamaVersion,
135 },
136 Mistral {
138 variant: MistralVariant,
140 },
141 Mixtral {
143 num_experts: u8,
145 },
146 Qwen {
148 version: QwenVersion,
150 },
151 Phi {
153 version: PhiVersion,
155 },
156 Gemma {
158 version: GemmaVersion,
160 },
161 DeepSeek {
163 version: u8,
165 },
166
167 Bert,
170 NomicEmbed,
172 JinaEmbed,
174
175 LlavaNext,
178 Qwen2VL,
180 Pixtral,
182
183 CodeLlama,
186 StarCoder2,
188 DeepSeekCoder {
190 version: u8,
192 },
193}
194
195impl ModelArchitecture {
196 #[must_use]
198 pub fn supports_vision(&self) -> bool {
199 matches!(self, Self::LlavaNext | Self::Qwen2VL | Self::Pixtral)
200 }
201
202 #[must_use]
204 pub fn is_embedding_model(&self) -> bool {
205 matches!(self, Self::Bert | Self::NomicEmbed | Self::JinaEmbed)
206 }
207
208 #[must_use]
210 pub fn is_code_specialized(&self) -> bool {
211 matches!(
212 self,
213 Self::CodeLlama | Self::StarCoder2 | Self::DeepSeekCoder { .. }
214 )
215 }
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct ModelMetadata {
221 pub id: ModelId,
223 pub architecture: ModelArchitecture,
225 pub source: ModelSource,
227 pub context_length: u32,
229 pub vocab_size: u32,
231 pub hidden_size: u32,
233 pub num_layers: u32,
235 pub num_attention_heads: u32,
237 pub num_kv_heads: Option<u32>,
239 pub quantization: Option<QuantizationType>,
241 pub size_bytes: Option<u64>,
243 pub description: Option<String>,
245}
246
247impl ModelMetadata {
248 #[must_use]
250 pub fn builder(
251 id: impl Into<ModelId>,
252 architecture: ModelArchitecture,
253 ) -> ModelMetadataBuilder {
254 ModelMetadataBuilder::new(id, architecture)
255 }
256}
257
258#[derive(Debug)]
260pub struct ModelMetadataBuilder {
261 id: ModelId,
262 architecture: ModelArchitecture,
263 source: Option<ModelSource>,
264 context_length: u32,
265 vocab_size: u32,
266 hidden_size: u32,
267 num_layers: u32,
268 num_attention_heads: u32,
269 num_kv_heads: Option<u32>,
270 quantization: Option<QuantizationType>,
271 size_bytes: Option<u64>,
272 description: Option<String>,
273}
274
275impl ModelMetadataBuilder {
276 #[must_use]
278 pub fn new(id: impl Into<ModelId>, architecture: ModelArchitecture) -> Self {
279 Self {
280 id: id.into(),
281 architecture,
282 source: None,
283 context_length: 4096,
284 vocab_size: 32000,
285 hidden_size: 4096,
286 num_layers: 32,
287 num_attention_heads: 32,
288 num_kv_heads: None,
289 quantization: None,
290 size_bytes: None,
291 description: None,
292 }
293 }
294
295 #[must_use]
297 pub fn source(mut self, source: ModelSource) -> Self {
298 self.source = Some(source);
299 self
300 }
301
302 #[must_use]
304 pub fn context_length(mut self, length: u32) -> Self {
305 self.context_length = length;
306 self
307 }
308
309 #[must_use]
311 pub fn vocab_size(mut self, size: u32) -> Self {
312 self.vocab_size = size;
313 self
314 }
315
316 #[must_use]
318 pub fn hidden_size(mut self, size: u32) -> Self {
319 self.hidden_size = size;
320 self
321 }
322
323 #[must_use]
325 pub fn num_layers(mut self, layers: u32) -> Self {
326 self.num_layers = layers;
327 self
328 }
329
330 #[must_use]
332 pub fn num_attention_heads(mut self, heads: u32) -> Self {
333 self.num_attention_heads = heads;
334 self
335 }
336
337 #[must_use]
339 pub fn num_kv_heads(mut self, heads: u32) -> Self {
340 self.num_kv_heads = Some(heads);
341 self
342 }
343
344 #[must_use]
346 pub fn quantization(mut self, quant: QuantizationType) -> Self {
347 self.quantization = Some(quant);
348 self
349 }
350
351 #[must_use]
353 pub fn size_bytes(mut self, size: u64) -> Self {
354 self.size_bytes = Some(size);
355 self
356 }
357
358 #[must_use]
360 pub fn description(mut self, desc: impl Into<String>) -> Self {
361 self.description = Some(desc.into());
362 self
363 }
364
365 #[must_use]
371 pub fn build(self) -> ModelMetadata {
372 ModelMetadata {
373 id: self.id,
374 architecture: self.architecture,
375 source: self.source.expect("source must be set"),
376 context_length: self.context_length,
377 vocab_size: self.vocab_size,
378 hidden_size: self.hidden_size,
379 num_layers: self.num_layers,
380 num_attention_heads: self.num_attention_heads,
381 num_kv_heads: self.num_kv_heads,
382 quantization: self.quantization,
383 size_bytes: self.size_bytes,
384 description: self.description,
385 }
386 }
387}