use super::types::{ModelFormat, ModelType, Quantization};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Minimum and recommended hardware needed to run a model.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HardwareRequirements {
    /// Minimum GPU memory in MiB. 0 means the model can run without a GPU
    /// (see `can_run_on`, which treats 0 as always satisfied).
    pub min_vram_mb: u32,
    /// GPU memory in MiB for comfortable operation; used as a fallback
    /// estimate by `ModelManifest::estimated_vram_mb`.
    pub recommended_vram_mb: u32,
    /// Minimum system RAM in MiB.
    pub min_ram_mb: u32,
    /// Supported execution providers in priority order (e.g. "tensorrt",
    /// "cuda", "cpu").
    pub execution_providers: Vec<String>,
    /// CPU architectures the model can run on (e.g. "aarch64", "x86_64").
    pub architectures: Vec<String>,
}
impl Default for HardwareRequirements {
    /// CPU-only baseline: no GPU required, 2 GiB of RAM, and both common
    /// 64-bit architectures supported.
    fn default() -> Self {
        let providers = vec![String::from("cpu")];
        let archs = vec![String::from("aarch64"), String::from("x86_64")];
        Self {
            min_vram_mb: 0,
            recommended_vram_mb: 0,
            min_ram_mb: 2048,
            execution_providers: providers,
            architectures: archs,
        }
    }
}
impl HardwareRequirements {
    /// Profile for small (~3B-parameter) LLMs: 2 GiB VRAM minimum,
    /// CUDA preferred with a CPU fallback.
    pub fn small_llm() -> Self {
        Self {
            min_vram_mb: 2048,
            recommended_vram_mb: 4096,
            min_ram_mb: 4096,
            execution_providers: vec!["cuda".to_string(), "cpu".to_string()],
            architectures: vec!["aarch64".to_string(), "x86_64".to_string()],
        }
    }

    /// Profile for medium (~8B-parameter) LLMs: 4 GiB VRAM minimum,
    /// CUDA preferred with a CPU fallback.
    pub fn medium_llm() -> Self {
        Self {
            min_vram_mb: 4096,
            recommended_vram_mb: 8192,
            min_ram_mb: 8192,
            execution_providers: vec!["cuda".to_string(), "cpu".to_string()],
            architectures: vec!["aarch64".to_string(), "x86_64".to_string()],
        }
    }

    /// Profile for lightweight detectors such as YOLOv8 nano: very small
    /// VRAM footprint, TensorRT preferred, then CUDA, then CPU.
    pub fn yolo_nano() -> Self {
        Self {
            min_vram_mb: 512,
            recommended_vram_mb: 1024,
            min_ram_mb: 1024,
            execution_providers: vec![
                "tensorrt".to_string(),
                "cuda".to_string(),
                "cpu".to_string(),
            ],
            execution_providers_note: (),
            architectures: vec!["aarch64".to_string(), "x86_64".to_string()],
        }
    }

    /// Returns true when the given hardware meets every minimum: enough
    /// VRAM, enough RAM, and a supported architecture.
    ///
    /// A CPU-only profile (`min_vram_mb == 0`) is always VRAM-satisfied;
    /// no special case is needed because `x >= 0` holds for any `u32`.
    pub fn can_run_on(&self, available_vram_mb: u32, available_ram_mb: u32, arch: &str) -> bool {
        // NOTE: the previous `|| self.min_vram_mb == 0` clause was dead code —
        // the unsigned comparison already accepts everything when the minimum is 0.
        let vram_ok = available_vram_mb >= self.min_vram_mb;
        let ram_ok = available_ram_mb >= self.min_ram_mb;
        let arch_ok = self.architectures.iter().any(|a| a == arch);
        vram_ok && ram_ok && arch_ok
    }

    /// Whether `provider` (e.g. "cuda") is one of the supported
    /// execution providers.
    pub fn supports_provider(&self, provider: &str) -> bool {
        self.execution_providers.iter().any(|p| p == provider)
    }
}
/// Complete description of a distributable model artifact: identity,
/// format, integrity hashes, hardware requirements, and provenance.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ModelManifest {
    /// Stable unique identifier, e.g. "ministral-3b" or "yolov8n".
    pub model_id: String,
    /// Human-readable display name.
    pub name: String,
    pub model_type: ModelType,
    pub format: ModelFormat,
    /// Version string; dots are replaced with underscores in `filename()`.
    pub version: String,
    pub quantization: Quantization,
    /// Artifact size in bytes (0 until known).
    pub size_bytes: u64,
    /// SHA-256 hex digest of the artifact (empty until known).
    pub sha256: String,
    /// Optional content-addressed blob hash — presumably for a blob store;
    /// distinct from `sha256`. TODO confirm against the store implementation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blob_hash: Option<String>,
    /// Optional URL the artifact can be fetched from.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub download_url: Option<String>,
    pub requirements: HardwareRequirements,
    /// Capability tags, e.g. "chat", "function_calling", "vision", "coco_80".
    pub features: Vec<String>,
    /// Parameter count in billions; drives `estimated_vram_mb` when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params_billions: Option<f32>,
    /// Maximum context window — presumably in tokens; verify with consumers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u32>,
    /// Class labels for detector models; None for non-detectors.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub classes: Option<Vec<String>>,
    /// SPDX-style license identifier, e.g. "Apache-2.0", "AGPL-3.0".
    pub license: String,
    /// Upstream publisher, e.g. "Mistral AI", "Ultralytics".
    pub source: String,
    pub created_at: DateTime<Utc>,
    /// Free-form extra metadata; omitted from serialization when empty.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
}
impl ModelManifest {
    /// Creates a manifest with defaults: GGUF format, version "1.0.0",
    /// Q4_K_M quantization, Apache-2.0 license, CPU-only requirements,
    /// empty hashes/sizes, and a creation timestamp of now.
    pub fn new(
        model_id: impl Into<String>,
        name: impl Into<String>,
        model_type: ModelType,
    ) -> Self {
        Self {
            model_id: model_id.into(),
            name: name.into(),
            model_type,
            format: ModelFormat::Gguf,
            version: String::from("1.0.0"),
            quantization: Quantization::Q4_K_M,
            size_bytes: 0,
            sha256: String::new(),
            blob_hash: None,
            download_url: None,
            requirements: HardwareRequirements::default(),
            features: Vec::new(),
            params_billions: None,
            context_length: None,
            classes: None,
            license: String::from("Apache-2.0"),
            source: String::new(),
            created_at: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Preset: Ministral 3B Instruct (GGUF) with the caller-chosen quantization.
    pub fn ministral_3b(quantization: Quantization) -> Self {
        Self::new("ministral-3b", "Ministral 3B Instruct", ModelType::Llm)
            .with_version("25.12")
            .with_format(ModelFormat::Gguf)
            .with_quantization(quantization)
            .with_params(3.0)
            .with_context_length(256_000)
            .with_requirements(HardwareRequirements::small_llm())
            .with_source("Mistral AI")
            .with_license("Apache-2.0")
            .with_feature("chat")
            .with_feature("function_calling")
            .with_feature("vision")
    }

    /// Preset: Ministral 8B Instruct (GGUF) with the caller-chosen quantization.
    pub fn ministral_8b(quantization: Quantization) -> Self {
        Self::new("ministral-8b", "Ministral 8B Instruct", ModelType::Llm)
            .with_version("25.12")
            .with_format(ModelFormat::Gguf)
            .with_quantization(quantization)
            .with_params(8.0)
            .with_context_length(256_000)
            .with_requirements(HardwareRequirements::medium_llm())
            .with_source("Mistral AI")
            .with_license("Apache-2.0")
            .with_feature("chat")
            .with_feature("function_calling")
            .with_feature("vision")
    }

    /// Preset: YOLOv8 Nano detector (ONNX, F16).
    pub fn yolov8n() -> Self {
        Self::new("yolov8n", "YOLOv8 Nano", ModelType::Detector)
            .with_version("8.0.0")
            .with_format(ModelFormat::Onnx)
            .with_quantization(Quantization::F16)
            .with_requirements(HardwareRequirements::yolo_nano())
            .with_source("Ultralytics")
            .with_license("AGPL-3.0")
            .with_feature("coco_80")
    }

    /// Sets the version string.
    pub fn with_version(self, version: impl Into<String>) -> Self {
        Self { version: version.into(), ..self }
    }

    /// Sets the on-disk format.
    pub fn with_format(self, format: ModelFormat) -> Self {
        Self { format, ..self }
    }

    /// Sets the quantization scheme.
    pub fn with_quantization(self, quantization: Quantization) -> Self {
        Self { quantization, ..self }
    }

    /// Sets the artifact size in bytes.
    pub fn with_size_bytes(self, size: u64) -> Self {
        Self { size_bytes: size, ..self }
    }

    /// Sets the SHA-256 digest.
    pub fn with_sha256(self, hash: impl Into<String>) -> Self {
        Self { sha256: hash.into(), ..self }
    }

    /// Sets the optional blob-store hash.
    pub fn with_blob_hash(self, hash: impl Into<String>) -> Self {
        Self { blob_hash: Some(hash.into()), ..self }
    }

    /// Sets the optional download URL.
    pub fn with_download_url(self, url: impl Into<String>) -> Self {
        Self { download_url: Some(url.into()), ..self }
    }

    /// Replaces the hardware requirements.
    pub fn with_requirements(self, requirements: HardwareRequirements) -> Self {
        Self { requirements, ..self }
    }

    /// Appends a capability tag (order is preserved; duplicates are not deduped).
    pub fn with_feature(mut self, feature: impl Into<String>) -> Self {
        self.features.push(feature.into());
        self
    }

    /// Sets the parameter count in billions.
    pub fn with_params(self, billions: f32) -> Self {
        Self { params_billions: Some(billions), ..self }
    }

    /// Sets the maximum context length.
    pub fn with_context_length(self, length: u32) -> Self {
        Self { context_length: Some(length), ..self }
    }

    /// Sets the detector class labels.
    pub fn with_classes(self, classes: Vec<String>) -> Self {
        Self { classes: Some(classes), ..self }
    }

    /// Sets the license identifier.
    pub fn with_license(self, license: impl Into<String>) -> Self {
        Self { license: license.into(), ..self }
    }

    /// Sets the upstream source.
    pub fn with_source(self, source: impl Into<String>) -> Self {
        Self { source: source.into(), ..self }
    }

    /// Inserts one key/value pair into the free-form metadata map.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }

    /// Rough VRAM estimate in MiB: ~2 bytes per parameter (f16 baseline),
    /// scaled by the quantization's memory factor plus 20% overhead.
    /// Falls back to the profile's recommended VRAM when the parameter
    /// count is unknown.
    pub fn estimated_vram_mb(&self) -> u32 {
        match self.params_billions {
            Some(billions) => {
                // Truncate to whole MiB at the f16 baseline first, then scale.
                let fp16_mb = (billions * 2.0 * 1024.0) as u32;
                (fp16_mb as f32 * self.quantization.memory_factor() * 1.2) as u32
            }
            None => self.requirements.recommended_vram_mb,
        }
    }

    /// Canonical artifact filename:
    /// `<model_id>-<version with dots as underscores>-<quant lowercase>.<ext>`.
    pub fn filename(&self) -> String {
        let version = self.version.replace('.', "_");
        let quant = self.quantization.as_str().to_lowercase();
        let ext = self.format.extension();
        format!("{}-{}-{}.{}", self.model_id, version, quant, ext)
    }

    /// Convenience pass-through to `HardwareRequirements::can_run_on`.
    pub fn can_run_on(&self, vram_mb: u32, ram_mb: u32, arch: &str) -> bool {
        self.requirements.can_run_on(vram_mb, ram_mb, arch)
    }
}
/// Lifecycle of a model on a local node; transitions are driven by the
/// `LocalModelState` builder methods below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelStatus {
    /// Known in the catalog but not yet downloaded.
    Available,
    /// Download in progress (`download_progress` tracks 0.0..=1.0).
    Downloading,
    /// On disk and verified (`local_path` and `verified_at` are set).
    Ready,
    /// Loaded — presumably into an inference runtime; confirm with the loader.
    Loaded,
    /// Download or verification failed (`error` holds the message).
    Failed,
}
/// Per-node tracking record for one model's download/load lifecycle.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalModelState {
    /// Identifier matching `ModelManifest::model_id`.
    pub model_id: String,
    pub status: ModelStatus,
    /// Filesystem location of the artifact once downloaded.
    pub local_path: Option<PathBuf>,
    /// Download fraction in 0.0..=1.0 (clamped by `downloading`).
    pub download_progress: f32,
    /// When the artifact was last verified; set by `ready`.
    pub verified_at: Option<DateTime<Utc>>,
    /// Failure message; set by `failed`.
    pub error: Option<String>,
}
impl LocalModelState {
    /// Initial state: cataloged but not yet downloaded.
    pub fn available(model_id: impl Into<String>) -> Self {
        Self {
            model_id: model_id.into(),
            status: ModelStatus::Available,
            local_path: None,
            download_progress: 0.0,
            verified_at: None,
            error: None,
        }
    }

    /// Marks the model as downloading; `progress` is clamped to 0.0..=1.0.
    pub fn downloading(self, progress: f32) -> Self {
        Self {
            status: ModelStatus::Downloading,
            download_progress: progress.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Marks the model ready: stores its path, completes the progress bar,
    /// and stamps the verification time.
    pub fn ready(self, path: PathBuf) -> Self {
        Self {
            status: ModelStatus::Ready,
            local_path: Some(path),
            download_progress: 1.0,
            verified_at: Some(Utc::now()),
            ..self
        }
    }

    /// Marks the model as loaded.
    pub fn loaded(self) -> Self {
        Self { status: ModelStatus::Loaded, ..self }
    }

    /// Marks the model as failed with the given error message.
    pub fn failed(self, error: impl Into<String>) -> Self {
        Self {
            status: ModelStatus::Failed,
            error: Some(error.into()),
            ..self
        }
    }
}
/// Instruction to roll a model out to a set of nodes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpdateCommand {
    /// Unique command id (a UUID v4 string; see `new`).
    pub command_id: String,
    /// The model being rolled out.
    pub manifest: ModelManifest,
    /// Nodes to target; an empty list means every node (see `targets_node`).
    pub target_nodes: Vec<String>,
    /// Priority 1..=5 (clamped by `with_priority`; default 3).
    pub priority: u8,
    /// Whether nodes should load the model immediately after download.
    pub auto_load: bool,
    /// Model to fall back to if this rollout fails — presumably consumed by
    /// the rollout executor; confirm there.
    pub rollback_model_id: Option<String>,
    pub timestamp: DateTime<Utc>,
}
impl ModelUpdateCommand {
    /// Builds a broadcast command (all nodes) with a fresh UUID, priority 3,
    /// auto-load enabled, and no rollback target.
    pub fn new(manifest: ModelManifest) -> Self {
        Self {
            command_id: uuid::Uuid::new_v4().to_string(),
            manifest,
            target_nodes: Vec::new(),
            priority: 3,
            auto_load: true,
            rollback_model_id: None,
            timestamp: Utc::now(),
        }
    }

    /// Restricts the command to the given nodes.
    pub fn with_targets(self, nodes: Vec<String>) -> Self {
        Self { target_nodes: nodes, ..self }
    }

    /// Sets the priority, clamped into 1..=5.
    pub fn with_priority(self, priority: u8) -> Self {
        Self { priority: priority.clamp(1, 5), ..self }
    }

    /// Sets the model to roll back to on failure.
    pub fn with_rollback(self, model_id: impl Into<String>) -> Self {
        Self { rollback_model_id: Some(model_id.into()), ..self }
    }

    /// Disables automatic loading after download.
    pub fn without_auto_load(self) -> Self {
        Self { auto_load: false, ..self }
    }

    /// Whether this command applies to `node_id`.
    pub fn targets_node(&self, node_id: &str) -> bool {
        // An empty target list is a broadcast to every node.
        if self.target_nodes.is_empty() {
            return true;
        }
        self.target_nodes.iter().any(|n| n == node_id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ministral 3B preset carries the expected identity, format, and features.
    #[test]
    fn test_manifest_ministral() {
        let m = ModelManifest::ministral_3b(Quantization::Q4_K_M);
        assert_eq!(m.model_id, "ministral-3b");
        assert_eq!(m.model_type, ModelType::Llm);
        assert_eq!(m.format, ModelFormat::Gguf);
        assert_eq!(m.quantization, Quantization::Q4_K_M);
        assert_eq!(m.context_length, Some(256_000));
        assert!(m.features.contains(&"chat".to_string()));
    }

    /// YOLOv8n preset is an ONNX detector.
    #[test]
    fn test_manifest_yolo() {
        let m = ModelManifest::yolov8n();
        assert_eq!(m.model_id, "yolov8n");
        assert_eq!(m.model_type, ModelType::Detector);
        assert_eq!(m.format, ModelFormat::Onnx);
    }

    /// Dots in the version become underscores; quantization is lowercased.
    #[test]
    fn test_filename_generation() {
        let m = ModelManifest::ministral_3b(Quantization::Q4_K_M).with_version("25.12");
        assert_eq!(m.filename(), "ministral-3b-25_12-q4_k_m.gguf");
    }

    /// VRAM, RAM, and architecture checks all gate `can_run_on`.
    #[test]
    fn test_hardware_requirements() {
        let reqs = HardwareRequirements::small_llm();
        assert!(reqs.can_run_on(4096, 8192, "aarch64"));
        assert!(!reqs.can_run_on(512, 2048, "aarch64"));
        assert!(!reqs.can_run_on(4096, 8192, "armv7"));
    }

    /// An empty target list broadcasts; a populated one restricts.
    #[test]
    fn test_update_command_targeting() {
        let broadcast = ModelUpdateCommand::new(ModelManifest::ministral_3b(Quantization::Q4_K_M));
        assert!(broadcast.targets_node("any-node"));

        let restricted = broadcast.with_targets(vec!["node-1".to_string(), "node-2".to_string()]);
        assert!(restricted.targets_node("node-1"));
        assert!(restricted.targets_node("node-2"));
        assert!(!restricted.targets_node("node-3"));
    }

    /// Walks Available -> Downloading -> Ready -> Loaded.
    #[test]
    fn test_local_model_state_transitions() {
        let state = LocalModelState::available("ministral-3b");
        assert_eq!(state.status, ModelStatus::Available);

        let state = state.downloading(0.5);
        assert_eq!(state.status, ModelStatus::Downloading);
        assert_eq!(state.download_progress, 0.5);

        let state = state.ready(PathBuf::from("/models/ministral-3b.gguf"));
        assert_eq!(state.status, ModelStatus::Ready);
        assert!(state.verified_at.is_some());

        let state = state.loaded();
        assert_eq!(state.status, ModelStatus::Loaded);
    }
}