use std::collections::HashMap;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
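/// Request body for loading a model into the registry.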
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LoadModelRequest {
pub model: String,
#[serde(default)]
pub options: Option<ModelLoadOptions>,
}
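/// Optional tuning knobs supplied when loading a model.
///
/// All fields are optional and omitted from serialized output when unset.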
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
pub struct ModelLoadOptions {
#[serde(skip_serializing_if = "Option::is_none")]
pub gpu_layers: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub context_length: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub quantization: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub memory_limit: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub flash_attention: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tensor_split: Option<Vec<f32>>,
}
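/// Response returned after a load attempt, reporting load time and memory reserved.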
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LoadModelResponse {
pub model: String,
pub status: ModelStatus,
pub load_time_ms: u64,
pub memory_bytes: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
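/// Request body for unloading a model. Set `force` to unload even while
/// requests are still in flight.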
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UnloadModelRequest {
pub model: String,
#[serde(default)]
pub force: bool,
}
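/// Response returned after an unload, reporting how much memory was freed.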
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct UnloadModelResponse {
pub model: String,
pub success: bool,
pub memory_freed: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
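/// Request body for warming up a loaded model. Iteration and token counts
/// fall back to defaults when omitted.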
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WarmupModelRequest {
pub model: String,
#[serde(default = "default_warmup_iterations")]
pub iterations: u32,
#[serde(default = "default_warmup_tokens")]
pub tokens: u32,
}
fn default_warmup_iterations() -> u32 {
3
}
fn default_warmup_tokens() -> u32 {
128
}
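/// Response returned after a warmup run, including average per-iteration latency.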
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WarmupModelResponse {
pub model: String,
pub success: bool,
pub iterations_completed: u32,
pub avg_latency_ms: f64,
#[serde(skip_serializing_if = "Option::is_none")]
pub message: Option<String>,
}
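/// Aggregate status report for all loaded models plus memory accounting.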
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelsStatusResponse {
pub models: Vec<AdminModelInfo>,
pub total_memory_bytes: u64,
pub available_memory_bytes: u64,
}
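/// Per-model status entry as exposed through the admin API. `loaded_at` is a
/// Unix timestamp in seconds.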
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AdminModelInfo {
pub model: String,
pub status: ModelStatus,
pub memory_bytes: u64,
pub loaded_at: u64,
pub requests_total: u64,
pub active_requests: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<ModelLoadOptions>,
}
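/// Lifecycle state of a model as tracked by the registry, serialized in
/// snake_case.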
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum ModelStatus {
Loading,
#[default]
Ready,
WarmingUp,
Unloading,
Failed,
Idle,
}
impl fmt::Display for ModelStatus {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Loading => write!(f, "loading"),
Self::Ready => write!(f, "ready"),
Self::WarmingUp => write!(f, "warming_up"),
Self::Unloading => write!(f, "unloading"),
Self::Failed => write!(f, "failed"),
Self::Idle => write!(f, "idle"),
}
}
}
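/// Errors returned by admin operations on the model registry.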
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdminError {
ModelNotFound(String),
ModelAlreadyLoaded(String),
ModelLoading(String),
ModelBusy(String, u32),
InsufficientMemory {
required: u64,
available: u64,
},
InvalidConfig(String),
Timeout(Duration),
Internal(String),
}
impl std::error::Error for AdminError {}
impl fmt::Display for AdminError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::ModelNotFound(model) => write!(f, "model not found: {model}"),
Self::ModelAlreadyLoaded(model) => write!(f, "model already loaded: {model}"),
Self::ModelLoading(model) => write!(f, "model is currently loading: {model}"),
Self::ModelBusy(model, count) => {
write!(f, "model has {count} active requests: {model}")
},
Self::InsufficientMemory {
required,
available,
} => {
write!(
f,
"insufficient memory: required {required} bytes, available {available} bytes"
)
},
Self::InvalidConfig(msg) => write!(f, "invalid configuration: {msg}"),
Self::Timeout(duration) => write!(f, "operation timed out after {duration:?}"),
Self::Internal(msg) => write!(f, "internal error: {msg}"),
}
}
}
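/// Internal bookkeeping record for a model currently held by the registry.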
#[derive(Debug)]
struct LoadedModel {
model: String,
status: RwLock<ModelStatus>,
memory_bytes: u64,
loaded_at: Instant,
loaded_at_unix: u64,
requests_total: AtomicU64,
active_requests: std::sync::atomic::AtomicU32,
options: Option<ModelLoadOptions>,
}
#[allow(dead_code)]
impl LoadedModel {
#[must_use]
fn loaded_at(&self) -> Instant {
self.loaded_at
}
#[must_use]
fn uptime(&self) -> std::time::Duration {
self.loaded_at.elapsed()
}
}
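/// Thread-safe registry of loaded models with simple memory accounting.
///
/// `available_memory` is a fixed budget in bytes; loads are rejected once the
/// sum of per-model estimates would exceed it.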
#[derive(Debug)]
pub struct ModelRegistry {
models: RwLock<HashMap<String, Arc<LoadedModel>>>,
total_memory: AtomicU64,
available_memory: u64,
}
impl ModelRegistry {
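/// Creates an empty registry with the given memory budget in bytes.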
pub fn new(available_memory: u64) -> Self {
Self {
models: RwLock::new(HashMap::new()),
total_memory: AtomicU64::new(0),
available_memory,
}
}
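/// Registers a model as loaded, reserving its estimated memory.
///
/// Fails if the model is already loaded (or still loading) or if the
/// estimate would exceed the remaining memory budget.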
pub fn load_model(&self, request: &LoadModelRequest) -> Result<LoadModelResponse, AdminError> {
let start = Instant::now();
let required_memory = self.estimate_memory(&request.options);
// Hold the write lock across the duplicate check, the memory check, and the
// insert so that two concurrent loads of the same model cannot both succeed.
let mut models = self.models.write();
if let Some(existing) = models.get(&request.model) {
let status = *existing.status.read();
if status == ModelStatus::Loading {
return Err(AdminError::ModelLoading(request.model.clone()));
}
return Err(AdminError::ModelAlreadyLoaded(request.model.clone()));
}
let current_usage = self.total_memory.load(Ordering::Acquire);
if current_usage + required_memory > self.available_memory {
return Err(AdminError::InsufficientMemory {
required: required_memory,
available: self.available_memory.saturating_sub(current_usage),
});
}
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let loaded_model = Arc::new(LoadedModel {
model: request.model.clone(),
status: RwLock::new(ModelStatus::Ready),
memory_bytes: required_memory,
loaded_at: Instant::now(),
loaded_at_unix: now,
requests_total: AtomicU64::new(0),
active_requests: std::sync::atomic::AtomicU32::new(0),
options: request.options.clone(),
});
models.insert(request.model.clone(), loaded_model);
drop(models);
self.total_memory
.fetch_add(required_memory, Ordering::AcqRel);
Ok(LoadModelResponse {
model: request.model.clone(),
status: ModelStatus::Ready,
load_time_ms: start.elapsed().as_millis() as u64,
memory_bytes: required_memory,
message: None,
})
}
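/// Removes a model from the registry and releases its reserved memory.
///
/// Fails with `ModelBusy` if the model still has active requests and
/// `force` is not set.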
pub fn unload_model(
&self,
request: &UnloadModelRequest,
) -> Result<UnloadModelResponse, AdminError> {
let model = {
let models = self.models.read();
models
.get(&request.model)
.cloned()
.ok_or_else(|| AdminError::ModelNotFound(request.model.clone()))?
};
let active = model.active_requests.load(Ordering::Acquire);
if active > 0 && !request.force {
return Err(AdminError::ModelBusy(request.model.clone(), active));
}
*model.status.write() = ModelStatus::Unloading;
let memory_freed = {
let mut models = self.models.write();
if let Some(removed) = models.remove(&request.model) {
removed.memory_bytes
} else {
0
}
};
self.total_memory.fetch_sub(memory_freed, Ordering::AcqRel);
Ok(UnloadModelResponse {
model: request.model.clone(),
success: true,
memory_freed,
message: if request.force && active > 0 {
Some(format!("Force unloaded with {} active requests", active))
} else {
None
},
})
}
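/// Marks a model as warming up and reports timing for up to 10 iterations.
///
/// The registry only updates status and timing bookkeeping here; it does not
/// drive inference itself.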
pub fn warmup_model(
&self,
request: &WarmupModelRequest,
) -> Result<WarmupModelResponse, AdminError> {
let model = {
let models = self.models.read();
models
.get(&request.model)
.cloned()
.ok_or_else(|| AdminError::ModelNotFound(request.model.clone()))?
};
let original_status = *model.status.read();
*model.status.write() = ModelStatus::WarmingUp;
let start = Instant::now();
let iterations = request.iterations.min(10);
*model.status.write() = original_status;
let total_time = start.elapsed();
let avg_latency = if iterations > 0 {
total_time.as_secs_f64() * 1000.0 / f64::from(iterations)
} else {
0.0
};
Ok(WarmupModelResponse {
model: request.model.clone(),
success: true,
iterations_completed: iterations,
avg_latency_ms: avg_latency,
message: None,
})
}
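/// Returns a snapshot of every loaded model plus total and available memory.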
pub fn get_status(&self) -> ModelsStatusResponse {
let models = self.models.read();
let model_infos: Vec<AdminModelInfo> = models
.values()
.map(|m| AdminModelInfo {
model: m.model.clone(),
status: *m.status.read(),
memory_bytes: m.memory_bytes,
loaded_at: m.loaded_at_unix,
requests_total: m.requests_total.load(Ordering::Acquire),
active_requests: m.active_requests.load(Ordering::Acquire),
options: m.options.clone(),
})
.collect();
ModelsStatusResponse {
models: model_infos,
total_memory_bytes: self.total_memory.load(Ordering::Acquire),
available_memory_bytes: self.available_memory,
}
}
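/// Returns a snapshot of a single model, or `None` if it is not loaded.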
pub fn get_model_info(&self, model_id: &str) -> Option<AdminModelInfo> {
let models = self.models.read();
models.get(model_id).map(|m| AdminModelInfo {
model: m.model.clone(),
status: *m.status.read(),
memory_bytes: m.memory_bytes,
loaded_at: m.loaded_at_unix,
requests_total: m.requests_total.load(Ordering::Acquire),
active_requests: m.active_requests.load(Ordering::Acquire),
options: m.options.clone(),
})
}
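/// Records the start of a request against a model, returning `false` if the
/// model is not loaded.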
pub fn record_request_start(&self, model_id: &str) -> bool {
let models = self.models.read();
if let Some(model) = models.get(model_id) {
model.active_requests.fetch_add(1, Ordering::AcqRel);
model.requests_total.fetch_add(1, Ordering::AcqRel);
true
} else {
false
}
}
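/// Records the end of a request. Calls should be paired with a successful
/// `record_request_start`, otherwise the active counter underflows.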
pub fn record_request_end(&self, model_id: &str) {
let models = self.models.read();
if let Some(model) = models.get(model_id) {
model.active_requests.fetch_sub(1, Ordering::AcqRel);
}
}
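/// Estimates the memory footprint of a model from its load options.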
fn estimate_memory(&self, options: &Option<ModelLoadOptions>) -> u64 {
// Crude heuristic: 1 GiB baseline, scaled linearly with context length in
// 4096-token steps and doubled when any layers are offloaded to the GPU.
let base: u64 = 1024 * 1024 * 1024;
if let Some(opts) = options {
// Clamp the factor to at least 1 so contexts shorter than 4096 tokens do
// not round the estimate down to zero bytes.
let ctx_factor = (u64::from(opts.context_length.unwrap_or(4096)) / 4096).max(1);
let gpu_factor = if opts.gpu_layers.unwrap_or(0) > 0 {
2
} else {
1
};
base * ctx_factor * gpu_factor
} else {
base
}
}
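/// Returns `true` if the model is currently registered.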
pub fn is_loaded(&self, model_id: &str) -> bool {
self.models.read().contains_key(model_id)
}
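/// Returns the number of currently registered models.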
pub fn model_count(&self) -> usize {
self.models.read().len()
}
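/// Renders per-model and aggregate metrics in the Prometheus text exposition
/// format. Model names are interpolated into label values without escaping.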
pub fn render_prometheus(&self) -> String {
let status = self.get_status();
let mut output = String::with_capacity(1024);
output.push_str("# HELP infernum_model_memory_bytes Memory used by each model\n");
output.push_str("# TYPE infernum_model_memory_bytes gauge\n");
for model in &status.models {
output.push_str(&format!(
"infernum_model_memory_bytes{{model=\"{}\"}} {}\n",
model.model, model.memory_bytes
));
}
output.push_str("# HELP infernum_model_active_requests Active requests per model\n");
output.push_str("# TYPE infernum_model_active_requests gauge\n");
for model in &status.models {
output.push_str(&format!(
"infernum_model_active_requests{{model=\"{}\"}} {}\n",
model.model, model.active_requests
));
}
output.push_str("# HELP infernum_model_requests_total Total requests per model\n");
output.push_str("# TYPE infernum_model_requests_total counter\n");
for model in &status.models {
output.push_str(&format!(
"infernum_model_requests_total{{model=\"{}\"}} {}\n",
model.model, model.requests_total
));
}
output.push_str("# HELP infernum_model_status Model status (1=ready, 0=other)\n");
output.push_str("# TYPE infernum_model_status gauge\n");
for model in &status.models {
let ready = if model.status == ModelStatus::Ready {
1
} else {
0
};
output.push_str(&format!(
"infernum_model_status{{model=\"{}\",status=\"{}\"}} {}\n",
model.model, model.status, ready
));
}
output.push_str(
"# HELP infernum_models_memory_total_bytes Total memory used by all models\n",
);
output.push_str("# TYPE infernum_models_memory_total_bytes gauge\n");
output.push_str(&format!(
"infernum_models_memory_total_bytes {}\n",
status.total_memory_bytes
));
output.push_str(
"# HELP infernum_models_memory_available_bytes Available memory for models\n",
);
output.push_str("# TYPE infernum_models_memory_available_bytes gauge\n");
output.push_str(&format!(
"infernum_models_memory_available_bytes {}\n",
status.available_memory_bytes
));
output.push_str("# HELP infernum_models_loaded_total Number of loaded models\n");
output.push_str("# TYPE infernum_models_loaded_total gauge\n");
output.push_str(&format!(
"infernum_models_loaded_total {}\n",
status.models.len()
));
output
}
}
impl Default for ModelRegistry {
fn default() -> Self {
Self::new(16 * 1024 * 1024 * 1024)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_load_model_request_serialization() {
let request = LoadModelRequest {
model: "llama-3b".to_string(),
options: Some(ModelLoadOptions {
gpu_layers: Some(32),
context_length: Some(8192),
quantization: Some("q4_k_m".to_string()),
memory_limit: None,
flash_attention: Some(true),
tensor_split: None,
}),
};
let json = serde_json::to_string(&request).expect("serialize");
let parsed: LoadModelRequest = serde_json::from_str(&json).expect("deserialize");
assert_eq!(request, parsed);
}
#[test]
fn test_load_model_request_minimal() {
let json = r#"{"model": "llama-3b"}"#;
let request: LoadModelRequest = serde_json::from_str(json).expect("deserialize");
assert_eq!(request.model, "llama-3b");
assert!(request.options.is_none());
}
#[test]
fn test_unload_model_request_default_force() {
let json = r#"{"model": "llama-3b"}"#;
let request: UnloadModelRequest = serde_json::from_str(json).expect("deserialize");
assert_eq!(request.model, "llama-3b");
assert!(!request.force);
}
#[test]
fn test_warmup_model_request_defaults() {
let json = r#"{"model": "llama-3b"}"#;
let request: WarmupModelRequest = serde_json::from_str(json).expect("deserialize");
assert_eq!(request.model, "llama-3b");
assert_eq!(request.iterations, 3);
assert_eq!(request.tokens, 128);
}
#[test]
fn test_model_status_display() {
assert_eq!(ModelStatus::Loading.to_string(), "loading");
assert_eq!(ModelStatus::Ready.to_string(), "ready");
assert_eq!(ModelStatus::WarmingUp.to_string(), "warming_up");
assert_eq!(ModelStatus::Unloading.to_string(), "unloading");
assert_eq!(ModelStatus::Failed.to_string(), "failed");
assert_eq!(ModelStatus::Idle.to_string(), "idle");
}
#[test]
fn test_model_status_serialization() {
let json = serde_json::to_string(&ModelStatus::Ready).expect("serialize");
assert_eq!(json, "\"ready\"");
let parsed: ModelStatus = serde_json::from_str("\"warming_up\"").expect("deserialize");
assert_eq!(parsed, ModelStatus::WarmingUp);
}
#[test]
fn test_admin_error_display() {
assert_eq!(
AdminError::ModelNotFound("llama".to_string()).to_string(),
"model not found: llama"
);
assert_eq!(
AdminError::ModelAlreadyLoaded("llama".to_string()).to_string(),
"model already loaded: llama"
);
assert_eq!(
AdminError::ModelBusy("llama".to_string(), 5).to_string(),
"model has 5 active requests: llama"
);
assert_eq!(
AdminError::InsufficientMemory {
required: 1000,
available: 500
}
.to_string(),
"insufficient memory: required 1000 bytes, available 500 bytes"
);
}
#[test]
fn test_registry_load_model() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let request = LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
};
let response = registry.load_model(&request).expect("load should succeed");
assert_eq!(response.model, "llama-3b");
assert_eq!(response.status, ModelStatus::Ready);
assert!(response.memory_bytes > 0);
assert!(registry.is_loaded("llama-3b"));
assert_eq!(registry.model_count(), 1);
}
#[test]
fn test_registry_load_duplicate() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let request = LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
};
let _ = registry.load_model(&request).expect("first load");
let result = registry.load_model(&request);
assert!(matches!(result, Err(AdminError::ModelAlreadyLoaded(_))));
}
#[test]
fn test_registry_insufficient_memory() {
let registry = ModelRegistry::new(1024);
let request = LoadModelRequest {
model: "llama-3b".to_string(),
options: Some(ModelLoadOptions {
context_length: Some(8192),
..Default::default()
}),
};
let result = registry.load_model(&request);
assert!(matches!(result, Err(AdminError::InsufficientMemory { .. })));
}
#[test]
fn test_registry_unload_model() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let load_req = LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
};
let _ = registry.load_model(&load_req).expect("load");
let unload_req = UnloadModelRequest {
model: "llama-3b".to_string(),
force: false,
};
let response = registry.unload_model(&unload_req).expect("unload");
assert!(response.success);
assert!(response.memory_freed > 0);
assert!(!registry.is_loaded("llama-3b"));
assert_eq!(registry.model_count(), 0);
}
#[test]
fn test_registry_unload_not_found() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let request = UnloadModelRequest {
model: "nonexistent".to_string(),
force: false,
};
let result = registry.unload_model(&request);
assert!(matches!(result, Err(AdminError::ModelNotFound(_))));
}
#[test]
fn test_registry_unload_with_active_requests() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let load_req = LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
};
let _ = registry.load_model(&load_req).expect("load");
registry.record_request_start("llama-3b");
let unload_req = UnloadModelRequest {
model: "llama-3b".to_string(),
force: false,
};
let result = registry.unload_model(&unload_req);
assert!(matches!(result, Err(AdminError::ModelBusy(_, 1))));
let unload_req = UnloadModelRequest {
model: "llama-3b".to_string(),
force: true,
};
let response = registry.unload_model(&unload_req).expect("force unload");
assert!(response.success);
assert!(response.message.is_some());
}
#[test]
fn test_registry_warmup_model() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let load_req = LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
};
let _ = registry.load_model(&load_req).expect("load");
let warmup_req = WarmupModelRequest {
model: "llama-3b".to_string(),
iterations: 3,
tokens: 128,
};
let response = registry.warmup_model(&warmup_req).expect("warmup");
assert!(response.success);
assert_eq!(response.iterations_completed, 3);
}
#[test]
fn test_registry_warmup_not_found() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let request = WarmupModelRequest {
model: "nonexistent".to_string(),
iterations: 3,
tokens: 128,
};
let result = registry.warmup_model(&request);
assert!(matches!(result, Err(AdminError::ModelNotFound(_))));
}
#[test]
fn test_registry_get_status() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let _ = registry
.load_model(&LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
})
.expect("load");
let _ = registry
.load_model(&LoadModelRequest {
model: "mistral-7b".to_string(),
options: None,
})
.expect("load");
let status = registry.get_status();
assert_eq!(status.models.len(), 2);
assert!(status.total_memory_bytes > 0);
assert_eq!(status.available_memory_bytes, 10 * 1024 * 1024 * 1024);
}
#[test]
fn test_registry_get_model_info() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let _ = registry
.load_model(&LoadModelRequest {
model: "llama-3b".to_string(),
options: Some(ModelLoadOptions {
gpu_layers: Some(32),
..Default::default()
}),
})
.expect("load");
let info = registry.get_model_info("llama-3b").expect("should exist");
assert_eq!(info.model, "llama-3b");
assert_eq!(info.status, ModelStatus::Ready);
assert!(info.options.is_some());
assert_eq!(info.options.as_ref().expect("options").gpu_layers, Some(32));
}
#[test]
fn test_registry_request_tracking() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let _ = registry
.load_model(&LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
})
.expect("load");
assert!(registry.record_request_start("llama-3b"));
assert!(registry.record_request_start("llama-3b"));
let info = registry.get_model_info("llama-3b").expect("info");
assert_eq!(info.active_requests, 2);
assert_eq!(info.requests_total, 2);
registry.record_request_end("llama-3b");
let info = registry.get_model_info("llama-3b").expect("info");
assert_eq!(info.active_requests, 1);
assert_eq!(info.requests_total, 2);
}
#[test]
fn test_registry_request_start_nonexistent() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
assert!(!registry.record_request_start("nonexistent"));
}
#[test]
fn test_registry_prometheus_output() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let _ = registry
.load_model(&LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
})
.expect("load");
let output = registry.render_prometheus();
assert!(output.contains("infernum_model_memory_bytes"));
assert!(output.contains("llama-3b"));
assert!(output.contains("infernum_model_status"));
assert!(output.contains("infernum_models_loaded_total 1"));
}
#[test]
fn test_load_response_serialization() {
let response = LoadModelResponse {
model: "llama-3b".to_string(),
status: ModelStatus::Ready,
load_time_ms: 1500,
memory_bytes: 4 * 1024 * 1024 * 1024,
message: Some("Loaded with quantization".to_string()),
};
let json = serde_json::to_string(&response).expect("serialize");
let parsed: LoadModelResponse = serde_json::from_str(&json).expect("deserialize");
assert_eq!(response, parsed);
}
#[test]
fn test_models_status_response_serialization() {
let response = ModelsStatusResponse {
models: vec![AdminModelInfo {
model: "llama-3b".to_string(),
status: ModelStatus::Ready,
memory_bytes: 4 * 1024 * 1024 * 1024,
loaded_at: 1700000000,
requests_total: 100,
active_requests: 5,
options: None,
}],
total_memory_bytes: 4 * 1024 * 1024 * 1024,
available_memory_bytes: 16 * 1024 * 1024 * 1024,
};
let json = serde_json::to_string(&response).expect("serialize");
assert!(json.contains("llama-3b"));
assert!(json.contains("ready"));
}
#[test]
fn test_model_load_options_skip_none_serialization() {
let options = ModelLoadOptions {
gpu_layers: Some(32),
context_length: None,
quantization: None,
memory_limit: None,
flash_attention: Some(true),
tensor_split: None,
};
let json = serde_json::to_string(&options).expect("serialize");
assert!(json.contains("gpu_layers"));
assert!(json.contains("flash_attention"));
assert!(!json.contains("context_length"));
assert!(!json.contains("quantization"));
}
#[test]
fn test_default_model_registry() {
let registry = ModelRegistry::default();
let status = registry.get_status();
assert_eq!(status.available_memory_bytes, 16 * 1024 * 1024 * 1024);
}
#[test]
fn test_model_status_default() {
let status = ModelStatus::default();
assert_eq!(status, ModelStatus::Ready);
}
#[test]
fn test_warmup_capped_iterations() {
let registry = ModelRegistry::new(10 * 1024 * 1024 * 1024);
let _ = registry
.load_model(&LoadModelRequest {
model: "llama-3b".to_string(),
options: None,
})
.expect("load");
let warmup_req = WarmupModelRequest {
model: "llama-3b".to_string(),
iterations: 100,
tokens: 128,
};
let response = registry.warmup_model(&warmup_req).expect("warmup");
assert_eq!(response.iterations_completed, 10);
}
}