1use super::types::{ModelFormat, ModelType, Quantization};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::path::PathBuf;
8
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
11pub struct HardwareRequirements {
12 pub min_vram_mb: u32,
14 pub recommended_vram_mb: u32,
16 pub min_ram_mb: u32,
18 pub execution_providers: Vec<String>,
20 pub architectures: Vec<String>,
22}
23
24impl Default for HardwareRequirements {
25 fn default() -> Self {
26 Self {
27 min_vram_mb: 0,
28 recommended_vram_mb: 0,
29 min_ram_mb: 2048,
30 execution_providers: vec!["cpu".to_string()],
31 architectures: vec!["aarch64".to_string(), "x86_64".to_string()],
32 }
33 }
34}
35
36impl HardwareRequirements {
37 pub fn small_llm() -> Self {
39 Self {
40 min_vram_mb: 2048,
41 recommended_vram_mb: 4096,
42 min_ram_mb: 4096,
43 execution_providers: vec!["cuda".to_string(), "cpu".to_string()],
44 architectures: vec!["aarch64".to_string(), "x86_64".to_string()],
45 }
46 }
47
48 pub fn medium_llm() -> Self {
50 Self {
51 min_vram_mb: 4096,
52 recommended_vram_mb: 8192,
53 min_ram_mb: 8192,
54 execution_providers: vec!["cuda".to_string(), "cpu".to_string()],
55 architectures: vec!["aarch64".to_string(), "x86_64".to_string()],
56 }
57 }
58
59 pub fn yolo_nano() -> Self {
61 Self {
62 min_vram_mb: 512,
63 recommended_vram_mb: 1024,
64 min_ram_mb: 1024,
65 execution_providers: vec![
66 "tensorrt".to_string(),
67 "cuda".to_string(),
68 "cpu".to_string(),
69 ],
70 architectures: vec!["aarch64".to_string(), "x86_64".to_string()],
71 }
72 }
73
74 pub fn can_run_on(&self, available_vram_mb: u32, available_ram_mb: u32, arch: &str) -> bool {
76 let vram_ok = available_vram_mb >= self.min_vram_mb || self.min_vram_mb == 0;
77 let ram_ok = available_ram_mb >= self.min_ram_mb;
78 let arch_ok = self.architectures.iter().any(|a| a == arch);
79 vram_ok && ram_ok && arch_ok
80 }
81
82 pub fn supports_provider(&self, provider: &str) -> bool {
84 self.execution_providers.iter().any(|p| p == provider)
85 }
86}
87
88#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
92pub struct ModelManifest {
93 pub model_id: String,
95
96 pub name: String,
98
99 pub model_type: ModelType,
101
102 pub format: ModelFormat,
104
105 pub version: String,
107
108 pub quantization: Quantization,
110
111 pub size_bytes: u64,
113
114 pub sha256: String,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
119 pub blob_hash: Option<String>,
120
121 #[serde(skip_serializing_if = "Option::is_none")]
123 pub download_url: Option<String>,
124
125 pub requirements: HardwareRequirements,
127
128 pub features: Vec<String>,
130
131 #[serde(skip_serializing_if = "Option::is_none")]
133 pub params_billions: Option<f32>,
134
135 #[serde(skip_serializing_if = "Option::is_none")]
137 pub context_length: Option<u32>,
138
139 #[serde(skip_serializing_if = "Option::is_none")]
141 pub classes: Option<Vec<String>>,
142
143 pub license: String,
145
146 pub source: String,
148
149 pub created_at: DateTime<Utc>,
151
152 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
154 pub metadata: HashMap<String, serde_json::Value>,
155}
156
157impl ModelManifest {
158 pub fn new(
160 model_id: impl Into<String>,
161 name: impl Into<String>,
162 model_type: ModelType,
163 ) -> Self {
164 Self {
165 model_id: model_id.into(),
166 name: name.into(),
167 model_type,
168 format: ModelFormat::Gguf,
169 version: "1.0.0".to_string(),
170 quantization: Quantization::Q4_K_M,
171 size_bytes: 0,
172 sha256: String::new(),
173 blob_hash: None,
174 download_url: None,
175 requirements: HardwareRequirements::default(),
176 features: Vec::new(),
177 params_billions: None,
178 context_length: None,
179 classes: None,
180 license: "Apache-2.0".to_string(),
181 source: String::new(),
182 created_at: Utc::now(),
183 metadata: HashMap::new(),
184 }
185 }
186
187 pub fn ministral_3b(quantization: Quantization) -> Self {
189 Self::new("ministral-3b", "Ministral 3B Instruct", ModelType::Llm)
190 .with_version("25.12")
191 .with_format(ModelFormat::Gguf)
192 .with_quantization(quantization)
193 .with_params(3.0)
194 .with_context_length(256_000)
195 .with_requirements(HardwareRequirements::small_llm())
196 .with_source("Mistral AI")
197 .with_license("Apache-2.0")
198 .with_feature("chat")
199 .with_feature("function_calling")
200 .with_feature("vision")
201 }
202
203 pub fn ministral_8b(quantization: Quantization) -> Self {
205 Self::new("ministral-8b", "Ministral 8B Instruct", ModelType::Llm)
206 .with_version("25.12")
207 .with_format(ModelFormat::Gguf)
208 .with_quantization(quantization)
209 .with_params(8.0)
210 .with_context_length(256_000)
211 .with_requirements(HardwareRequirements::medium_llm())
212 .with_source("Mistral AI")
213 .with_license("Apache-2.0")
214 .with_feature("chat")
215 .with_feature("function_calling")
216 .with_feature("vision")
217 }
218
219 pub fn yolov8n() -> Self {
221 Self::new("yolov8n", "YOLOv8 Nano", ModelType::Detector)
222 .with_version("8.0.0")
223 .with_format(ModelFormat::Onnx)
224 .with_quantization(Quantization::F16)
225 .with_requirements(HardwareRequirements::yolo_nano())
226 .with_source("Ultralytics")
227 .with_license("AGPL-3.0")
228 .with_feature("coco_80")
229 }
230
231 pub fn with_version(mut self, version: impl Into<String>) -> Self {
234 self.version = version.into();
235 self
236 }
237
238 pub fn with_format(mut self, format: ModelFormat) -> Self {
239 self.format = format;
240 self
241 }
242
243 pub fn with_quantization(mut self, quantization: Quantization) -> Self {
244 self.quantization = quantization;
245 self
246 }
247
248 pub fn with_size_bytes(mut self, size: u64) -> Self {
249 self.size_bytes = size;
250 self
251 }
252
253 pub fn with_sha256(mut self, hash: impl Into<String>) -> Self {
254 self.sha256 = hash.into();
255 self
256 }
257
258 pub fn with_blob_hash(mut self, hash: impl Into<String>) -> Self {
259 self.blob_hash = Some(hash.into());
260 self
261 }
262
263 pub fn with_download_url(mut self, url: impl Into<String>) -> Self {
264 self.download_url = Some(url.into());
265 self
266 }
267
268 pub fn with_requirements(mut self, requirements: HardwareRequirements) -> Self {
269 self.requirements = requirements;
270 self
271 }
272
273 pub fn with_feature(mut self, feature: impl Into<String>) -> Self {
274 self.features.push(feature.into());
275 self
276 }
277
278 pub fn with_params(mut self, billions: f32) -> Self {
279 self.params_billions = Some(billions);
280 self
281 }
282
283 pub fn with_context_length(mut self, length: u32) -> Self {
284 self.context_length = Some(length);
285 self
286 }
287
288 pub fn with_classes(mut self, classes: Vec<String>) -> Self {
289 self.classes = Some(classes);
290 self
291 }
292
293 pub fn with_license(mut self, license: impl Into<String>) -> Self {
294 self.license = license.into();
295 self
296 }
297
298 pub fn with_source(mut self, source: impl Into<String>) -> Self {
299 self.source = source.into();
300 self
301 }
302
303 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
304 self.metadata.insert(key.into(), value);
305 self
306 }
307
308 pub fn estimated_vram_mb(&self) -> u32 {
310 if let Some(params) = self.params_billions {
311 let base_mb = (params * 2.0 * 1024.0) as u32;
313 (base_mb as f32 * self.quantization.memory_factor() * 1.2) as u32
314 } else {
315 self.requirements.recommended_vram_mb
316 }
317 }
318
319 pub fn filename(&self) -> String {
321 format!(
322 "{}-{}-{}.{}",
323 self.model_id,
324 self.version.replace('.', "_"),
325 self.quantization.as_str().to_lowercase(),
326 self.format.extension()
327 )
328 }
329
330 pub fn can_run_on(&self, vram_mb: u32, ram_mb: u32, arch: &str) -> bool {
332 self.requirements.can_run_on(vram_mb, ram_mb, arch)
333 }
334}
335
336#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
338pub enum ModelStatus {
339 Available,
341 Downloading,
343 Ready,
345 Loaded,
347 Failed,
349}
350
351#[derive(Debug, Clone, Serialize, Deserialize)]
353pub struct LocalModelState {
354 pub model_id: String,
356 pub status: ModelStatus,
358 pub local_path: Option<PathBuf>,
360 pub download_progress: f32,
362 pub verified_at: Option<DateTime<Utc>>,
364 pub error: Option<String>,
366}
367
368impl LocalModelState {
369 pub fn available(model_id: impl Into<String>) -> Self {
371 Self {
372 model_id: model_id.into(),
373 status: ModelStatus::Available,
374 local_path: None,
375 download_progress: 0.0,
376 verified_at: None,
377 error: None,
378 }
379 }
380
381 pub fn downloading(mut self, progress: f32) -> Self {
383 self.status = ModelStatus::Downloading;
384 self.download_progress = progress.clamp(0.0, 1.0);
385 self
386 }
387
388 pub fn ready(mut self, path: PathBuf) -> Self {
390 self.status = ModelStatus::Ready;
391 self.local_path = Some(path);
392 self.download_progress = 1.0;
393 self.verified_at = Some(Utc::now());
394 self
395 }
396
397 pub fn loaded(mut self) -> Self {
399 self.status = ModelStatus::Loaded;
400 self
401 }
402
403 pub fn failed(mut self, error: impl Into<String>) -> Self {
405 self.status = ModelStatus::Failed;
406 self.error = Some(error.into());
407 self
408 }
409}
410
411#[derive(Debug, Clone, Serialize, Deserialize)]
413pub struct ModelUpdateCommand {
414 pub command_id: String,
416 pub manifest: ModelManifest,
418 pub target_nodes: Vec<String>,
420 pub priority: u8,
422 pub auto_load: bool,
424 pub rollback_model_id: Option<String>,
426 pub timestamp: DateTime<Utc>,
428}
429
430impl ModelUpdateCommand {
431 pub fn new(manifest: ModelManifest) -> Self {
433 Self {
434 command_id: uuid::Uuid::new_v4().to_string(),
435 manifest,
436 target_nodes: Vec::new(),
437 priority: 3,
438 auto_load: true,
439 rollback_model_id: None,
440 timestamp: Utc::now(),
441 }
442 }
443
444 pub fn with_targets(mut self, nodes: Vec<String>) -> Self {
446 self.target_nodes = nodes;
447 self
448 }
449
450 pub fn with_priority(mut self, priority: u8) -> Self {
452 self.priority = priority.clamp(1, 5);
453 self
454 }
455
456 pub fn with_rollback(mut self, model_id: impl Into<String>) -> Self {
458 self.rollback_model_id = Some(model_id.into());
459 self
460 }
461
462 pub fn without_auto_load(mut self) -> Self {
464 self.auto_load = false;
465 self
466 }
467
468 pub fn targets_node(&self, node_id: &str) -> bool {
470 self.target_nodes.is_empty() || self.target_nodes.iter().any(|n| n == node_id)
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;

    /// The 3B preset carries the expected identity, format, quantization and features.
    #[test]
    fn test_manifest_ministral() {
        let m = ModelManifest::ministral_3b(Quantization::Q4_K_M);

        assert_eq!(m.model_id, "ministral-3b");
        assert_eq!(m.model_type, ModelType::Llm);
        assert_eq!(m.format, ModelFormat::Gguf);
        assert_eq!(m.quantization, Quantization::Q4_K_M);
        assert_eq!(m.context_length, Some(256_000));
        assert!(m.features.iter().any(|feature| feature == "chat"));
    }

    /// The YOLO preset is a detector shipped as ONNX.
    #[test]
    fn test_manifest_yolo() {
        let m = ModelManifest::yolov8n();

        assert_eq!(m.model_id, "yolov8n");
        assert_eq!(m.model_type, ModelType::Detector);
        assert_eq!(m.format, ModelFormat::Onnx);
    }

    /// File names replace version dots with underscores and lower-case the quantization.
    #[test]
    fn test_filename_generation() {
        let m = ModelManifest::ministral_3b(Quantization::Q4_K_M).with_version("25.12");
        let generated = m.filename();

        assert_eq!(generated, "ministral-3b-25_12-q4_k_m.gguf");
    }

    /// `can_run_on` requires enough VRAM, enough RAM and a listed architecture.
    #[test]
    fn test_hardware_requirements() {
        let small = HardwareRequirements::small_llm();

        // Plenty of VRAM and RAM on a listed architecture.
        assert!(small.can_run_on(4096, 8192, "aarch64"));

        // Too little VRAM and RAM.
        assert!(!small.can_run_on(512, 2048, "aarch64"));

        // Enough memory, but an architecture that is not listed.
        assert!(!small.can_run_on(4096, 8192, "armv7"));
    }

    /// An empty target list matches any node; an explicit list matches only its members.
    #[test]
    fn test_update_command_targeting() {
        let broadcast = ModelUpdateCommand::new(ModelManifest::ministral_3b(Quantization::Q4_K_M));

        // No explicit targets: every node is addressed.
        assert!(broadcast.targets_node("any-node"));

        let targeted = broadcast.with_targets(vec!["node-1".to_string(), "node-2".to_string()]);
        assert!(targeted.targets_node("node-1"));
        assert!(targeted.targets_node("node-2"));
        assert!(!targeted.targets_node("node-3"));
    }

    /// Walks a state through available -> downloading -> ready -> loaded.
    #[test]
    fn test_local_model_state_transitions() {
        let initial = LocalModelState::available("ministral-3b");
        assert_eq!(initial.status, ModelStatus::Available);

        let in_flight = initial.downloading(0.5);
        assert_eq!(in_flight.status, ModelStatus::Downloading);
        assert_eq!(in_flight.download_progress, 0.5);

        let on_disk = in_flight.ready(PathBuf::from("/models/ministral-3b.gguf"));
        assert_eq!(on_disk.status, ModelStatus::Ready);
        assert!(on_disk.verified_at.is_some());

        let in_memory = on_disk.loaded();
        assert_eq!(in_memory.status, ModelStatus::Loaded);
    }
}