infernum_core/
model.rs

1//! Model metadata and architecture types.
2
3use std::path::PathBuf;
4
5use serde::{Deserialize, Serialize};
6
7use crate::types::{ModelId, QuantizationType};
8
/// Source location for a model.
///
/// Serialized with serde's internally tagged representation: a `"type"`
/// field holds the snake_case variant name, with the variant's fields
/// inlined alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ModelSource {
    /// HuggingFace Hub model.
    HuggingFace {
        /// Repository ID (e.g., "meta-llama/Llama-3.2-3B-Instruct").
        repo_id: String,
        /// Optional revision (branch, tag, or commit). `None` selects the
        /// repository's default revision.
        revision: Option<String>,
    },
    /// Local filesystem path.
    LocalPath {
        /// Path to the model directory or file.
        path: PathBuf,
    },
    /// S3 bucket.
    S3 {
        /// Bucket name.
        bucket: String,
        /// Object key.
        key: String,
        /// Optional region. `None` presumably defers to the consumer's
        /// default region configuration — confirm with the download code.
        region: Option<String>,
    },
    /// GGUF file format.
    Gguf {
        /// Path to the GGUF file.
        path: PathBuf,
    },
}
40
41impl ModelSource {
42    /// Creates a HuggingFace source.
43    #[must_use]
44    pub fn huggingface(repo_id: impl Into<String>) -> Self {
45        Self::HuggingFace {
46            repo_id: repo_id.into(),
47            revision: None,
48        }
49    }
50
51    /// Creates a HuggingFace source with a specific revision.
52    #[must_use]
53    pub fn huggingface_rev(repo_id: impl Into<String>, revision: impl Into<String>) -> Self {
54        Self::HuggingFace {
55            repo_id: repo_id.into(),
56            revision: Some(revision.into()),
57        }
58    }
59
60    /// Creates a local path source.
61    #[must_use]
62    pub fn local(path: impl Into<PathBuf>) -> Self {
63        Self::LocalPath { path: path.into() }
64    }
65
66    /// Creates a GGUF source.
67    #[must_use]
68    pub fn gguf(path: impl Into<PathBuf>) -> Self {
69        Self::Gguf { path: path.into() }
70    }
71}
72
/// Llama model version.
///
/// Unit variants; no serde rename attribute, so each serializes as its
/// variant name (e.g. `"V3_1"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum LlamaVersion {
    /// Llama 2.
    V2,
    /// Llama 3.
    V3,
    /// Llama 3.1.
    V3_1,
    /// Llama 3.2.
    V3_2,
}
85
/// Mistral model variant.
///
/// Unit variants; serialized by serde as the variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MistralVariant {
    /// Mistral 7B.
    Mistral7B,
    /// Mistral Nemo.
    Nemo,
    /// Mistral Large.
    Large,
}
96
/// Qwen model version.
///
/// Unit variants; serialized by serde as the variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QwenVersion {
    /// Qwen 2.
    V2,
    /// Qwen 2.5.
    V2_5,
}
105
/// Phi model version.
///
/// Unit variants; serialized by serde as the variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum PhiVersion {
    /// Phi 3.
    V3,
    /// Phi 3.5.
    V3_5,
    /// Phi 4.
    V4,
}
116
/// Gemma model version.
///
/// Unit variants; serialized by serde as the variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GemmaVersion {
    /// Gemma 1.
    V1,
    /// Gemma 2.
    V2,
}
125
/// Supported model architectures.
///
/// Internally tagged for serde: the `"type"` field carries the snake_case
/// variant name, with any variant fields inlined next to it. Capability
/// predicates (`supports_vision`, `is_embedding_model`,
/// `is_code_specialized`) are provided in the `impl` below.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ModelArchitecture {
    // === Decoder-only (Causal LM) ===
    /// Llama family models.
    Llama {
        /// Model version.
        version: LlamaVersion,
    },
    /// Mistral family models.
    Mistral {
        /// Model variant.
        variant: MistralVariant,
    },
    /// Mixtral MoE models.
    Mixtral {
        /// Number of experts.
        num_experts: u8,
    },
    /// Qwen family models.
    Qwen {
        /// Model version.
        version: QwenVersion,
    },
    /// Phi family models.
    Phi {
        /// Model version.
        version: PhiVersion,
    },
    /// Gemma family models.
    Gemma {
        /// Model version.
        version: GemmaVersion,
    },
    /// DeepSeek models.
    DeepSeek {
        /// Model version (presumably the major version number — confirm
        /// against the loaders that consume this).
        version: u8,
    },

    // === Encoder-only (Embeddings) ===
    /// BERT-based models.
    Bert,
    /// Nomic Embed models.
    NomicEmbed,
    /// Jina Embed models.
    JinaEmbed,

    // === Vision-Language ===
    /// LLaVA-Next models.
    LlavaNext,
    /// Qwen2-VL models.
    Qwen2VL,
    /// Pixtral models.
    Pixtral,

    // === Code-specialized ===
    /// CodeLlama models.
    CodeLlama,
    /// StarCoder 2 models.
    StarCoder2,
    /// DeepSeek Coder models.
    DeepSeekCoder {
        /// Model version (presumably the major version number — confirm
        /// against the loaders that consume this).
        version: u8,
    },
}
194
195impl ModelArchitecture {
196    /// Returns `true` if this architecture supports vision input.
197    #[must_use]
198    pub fn supports_vision(&self) -> bool {
199        matches!(self, Self::LlavaNext | Self::Qwen2VL | Self::Pixtral)
200    }
201
202    /// Returns `true` if this is an embedding model.
203    #[must_use]
204    pub fn is_embedding_model(&self) -> bool {
205        matches!(self, Self::Bert | Self::NomicEmbed | Self::JinaEmbed)
206    }
207
208    /// Returns `true` if this is specialized for code.
209    #[must_use]
210    pub fn is_code_specialized(&self) -> bool {
211        matches!(
212            self,
213            Self::CodeLlama | Self::StarCoder2 | Self::DeepSeekCoder { .. }
214        )
215    }
216}
217
/// Model metadata and capabilities.
///
/// Typically constructed through [`ModelMetadata::builder`]; all fields are
/// public, so direct construction and field access also work.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Unique model identifier.
    pub id: ModelId,
    /// Model architecture.
    pub architecture: ModelArchitecture,
    /// Model source location.
    pub source: ModelSource,
    /// Maximum context length in tokens.
    pub context_length: u32,
    /// Vocabulary size.
    pub vocab_size: u32,
    /// Hidden dimension size.
    pub hidden_size: u32,
    /// Number of layers.
    pub num_layers: u32,
    /// Number of attention heads.
    pub num_attention_heads: u32,
    /// Number of key-value heads (for GQA). `None` presumably means no GQA,
    /// i.e. as many KV heads as attention heads — confirm with consumers.
    pub num_kv_heads: Option<u32>,
    /// Quantization applied to the model, if any.
    pub quantization: Option<QuantizationType>,
    /// Model size in bytes, when known.
    pub size_bytes: Option<u64>,
    /// Human-readable description.
    pub description: Option<String>,
}
246
247impl ModelMetadata {
248    /// Creates a new `ModelMetadata` builder.
249    #[must_use]
250    pub fn builder(
251        id: impl Into<ModelId>,
252        architecture: ModelArchitecture,
253    ) -> ModelMetadataBuilder {
254        ModelMetadataBuilder::new(id, architecture)
255    }
256}
257
/// Builder for `ModelMetadata`.
///
/// `id` and `architecture` are required at construction; all other fields
/// start from the defaults assigned in `new` and are overridden through the
/// setter methods.
#[derive(Debug)]
pub struct ModelMetadataBuilder {
    // Required; supplied to `new` and never changed afterwards.
    id: ModelId,
    architecture: ModelArchitecture,
    // Must be set via `source()` before building — `build()` panics otherwise.
    source: Option<ModelSource>,
    // Numeric defaults are assigned in `new`.
    context_length: u32,
    vocab_size: u32,
    hidden_size: u32,
    num_layers: u32,
    num_attention_heads: u32,
    // Optional fields; `None` until the corresponding setter is called.
    num_kv_heads: Option<u32>,
    quantization: Option<QuantizationType>,
    size_bytes: Option<u64>,
    description: Option<String>,
}
274
275impl ModelMetadataBuilder {
276    /// Creates a new builder.
277    #[must_use]
278    pub fn new(id: impl Into<ModelId>, architecture: ModelArchitecture) -> Self {
279        Self {
280            id: id.into(),
281            architecture,
282            source: None,
283            context_length: 4096,
284            vocab_size: 32000,
285            hidden_size: 4096,
286            num_layers: 32,
287            num_attention_heads: 32,
288            num_kv_heads: None,
289            quantization: None,
290            size_bytes: None,
291            description: None,
292        }
293    }
294
295    /// Sets the model source.
296    #[must_use]
297    pub fn source(mut self, source: ModelSource) -> Self {
298        self.source = Some(source);
299        self
300    }
301
302    /// Sets the context length.
303    #[must_use]
304    pub fn context_length(mut self, length: u32) -> Self {
305        self.context_length = length;
306        self
307    }
308
309    /// Sets the vocabulary size.
310    #[must_use]
311    pub fn vocab_size(mut self, size: u32) -> Self {
312        self.vocab_size = size;
313        self
314    }
315
316    /// Sets the hidden size.
317    #[must_use]
318    pub fn hidden_size(mut self, size: u32) -> Self {
319        self.hidden_size = size;
320        self
321    }
322
323    /// Sets the number of layers.
324    #[must_use]
325    pub fn num_layers(mut self, layers: u32) -> Self {
326        self.num_layers = layers;
327        self
328    }
329
330    /// Sets the number of attention heads.
331    #[must_use]
332    pub fn num_attention_heads(mut self, heads: u32) -> Self {
333        self.num_attention_heads = heads;
334        self
335    }
336
337    /// Sets the number of KV heads.
338    #[must_use]
339    pub fn num_kv_heads(mut self, heads: u32) -> Self {
340        self.num_kv_heads = Some(heads);
341        self
342    }
343
344    /// Sets the quantization type.
345    #[must_use]
346    pub fn quantization(mut self, quant: QuantizationType) -> Self {
347        self.quantization = Some(quant);
348        self
349    }
350
351    /// Sets the model size in bytes.
352    #[must_use]
353    pub fn size_bytes(mut self, size: u64) -> Self {
354        self.size_bytes = Some(size);
355        self
356    }
357
358    /// Sets the description.
359    #[must_use]
360    pub fn description(mut self, desc: impl Into<String>) -> Self {
361        self.description = Some(desc.into());
362        self
363    }
364
365    /// Builds the `ModelMetadata`.
366    ///
367    /// # Panics
368    ///
369    /// Panics if source is not set.
370    #[must_use]
371    pub fn build(self) -> ModelMetadata {
372        ModelMetadata {
373            id: self.id,
374            architecture: self.architecture,
375            source: self.source.expect("source must be set"),
376            context_length: self.context_length,
377            vocab_size: self.vocab_size,
378            hidden_size: self.hidden_size,
379            num_layers: self.num_layers,
380            num_attention_heads: self.num_attention_heads,
381            num_kv_heads: self.num_kv_heads,
382            quantization: self.quantization,
383            size_bytes: self.size_bytes,
384            description: self.description,
385        }
386    }
387}