use crate::schema::{
CostModel, GenerateParam, ModelCapability, ModelSchema, ModelSource, PerformanceEnvelope,
};
use serde::{Deserialize, Serialize};
use tracing::{debug, info, warn};
/// Default base URL of a locally running vLLM-MLX server.
pub const DEFAULT_ENDPOINT: &str = "http://localhost:8000";
/// Configuration for connecting to a vLLM-MLX OpenAI-compatible server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VllmMlxConfig {
    /// Base URL of the server, e.g. "http://localhost:8000".
    #[serde(default = "default_endpoint")]
    pub endpoint: String,
    /// Whether to run model discovery automatically (serde default: true).
    #[serde(default = "default_true")]
    pub auto_discover: bool,
}
/// Serde default for [`VllmMlxConfig::endpoint`].
fn default_endpoint() -> String {
    String::from(DEFAULT_ENDPOINT)
}
/// Serde default for boolean flags that should be enabled unless
/// explicitly turned off in the deserialized config.
fn default_true() -> bool {
    true
}
impl Default for VllmMlxConfig {
    /// Build a config from the environment: `VLLM_MLX_ENDPOINT` overrides the
    /// compiled-in default endpoint, and auto-discovery is enabled.
    ///
    /// NOTE(review): the serde default (`default_endpoint`) does NOT consult
    /// the environment variable, so deserialized and `Default` configs can
    /// disagree when `VLLM_MLX_ENDPOINT` is set — confirm this is intended.
    fn default() -> Self {
        let endpoint = match std::env::var("VLLM_MLX_ENDPOINT") {
            Ok(value) => value,
            Err(_) => DEFAULT_ENDPOINT.to_string(),
        };
        Self {
            endpoint,
            auto_discover: true,
        }
    }
}
/// Result of probing a vLLM-MLX server's `/health` endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerHealth {
    /// Whether the server answered `/health` with a 2xx status.
    pub healthy: bool,
    /// The endpoint that was probed.
    pub endpoint: String,
    /// Currently loaded model, when known. `health_check` never fills this
    /// in; it exists for callers that obtain the model by other means.
    pub model: Option<String>,
}
/// One model entry from the server's OpenAI-compatible `/v1/models` listing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscoveredModel {
    /// Model identifier, e.g. "mlx-community/Qwen3-4B-4bit".
    pub id: String,
    /// The `owned_by` field of the API response, if it was a string.
    pub owned_by: Option<String>,
}
pub async fn health_check(endpoint: &str) -> ServerHealth {
let url = format!("{}/health", endpoint.trim_end_matches('/'));
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(3))
.build()
.unwrap_or_default();
match client.get(&url).send().await {
Ok(resp) if resp.status().is_success() => {
debug!(endpoint, "vLLM-MLX server is healthy");
ServerHealth {
healthy: true,
endpoint: endpoint.to_string(),
model: None,
}
}
Ok(resp) => {
debug!(endpoint, status = %resp.status(), "vLLM-MLX server responded but unhealthy");
ServerHealth {
healthy: false,
endpoint: endpoint.to_string(),
model: None,
}
}
Err(e) => {
debug!(endpoint, error = %e, "vLLM-MLX server not reachable");
ServerHealth {
healthy: false,
endpoint: endpoint.to_string(),
model: None,
}
}
}
}
/// Query the server's OpenAI-compatible `/v1/models` endpoint and return the
/// listed models. Bounded by a 5-second timeout.
///
/// # Errors
/// Returns a human-readable message when the HTTP client cannot be built, the
/// server is unreachable, the response status is not 2xx, or the body is not
/// valid JSON.
pub async fn discover_models(endpoint: &str) -> Result<Vec<DiscoveredModel>, String> {
    let url = format!("{}/v1/models", endpoint.trim_end_matches('/'));
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(5))
        .build()
        .map_err(|e| format!("http client error: {e}"))?;
    let resp = client
        .get(&url)
        .send()
        .await
        .map_err(|e| format!("failed to reach vLLM-MLX at {}: {}", endpoint, e))?;
    // Fail fast on HTTP errors; otherwise an HTML error page would surface
    // below as a confusing "invalid JSON" message.
    if !resp.status().is_success() {
        return Err(format!(
            "/v1/models at {} returned HTTP {}",
            endpoint,
            resp.status()
        ));
    }
    let body: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| format!("invalid JSON from /v1/models: {e}"))?;
    // OpenAI list format wraps entries in a "data" array; tolerate its
    // absence by returning an empty list.
    let models = body
        .get("data")
        .and_then(|d| d.as_array())
        .cloned()
        .unwrap_or_default();
    Ok(models
        .into_iter()
        .filter_map(|m| {
            // Entries without a string "id" field are skipped entirely.
            let id = m.get("id").and_then(|v| v.as_str())?.to_string();
            let owned_by = m
                .get("owned_by")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string());
            Some(DiscoveredModel { id, owned_by })
        })
        .collect())
}
/// Heuristically classify a model name as an embedding model.
fn is_embedding_model(name_lower: &str) -> bool {
    name_lower.contains("embed") || name_lower.contains("bge") || name_lower.contains("minilm")
}

/// True when the lowercased id contains a literal `gemma` segment immediately
/// followed by a literal `4` segment (segment separators: `/ _ - . :`), so
/// `gemma-4bit`, `llama-4`, or plain `gemma` do not match.
fn has_gemma_4_segment(name_lower: &str) -> bool {
    let segs: Vec<&str> = name_lower
        .split(|c: char| matches!(c, '/' | '_' | '-' | '.' | ':'))
        .collect();
    segs.windows(2).any(|w| w[0] == "gemma" && w[1] == "4")
}

/// Map well-known model-name substrings to a provider label; unrecognized
/// names fall back to "vllm-mlx".
fn infer_provider(name_lower: &str) -> &'static str {
    if name_lower.contains("qwen") {
        "qwen"
    } else if name_lower.contains("llama") {
        "meta"
    } else if name_lower.contains("mistral") || name_lower.contains("mixtral") {
        "mistral"
    } else if name_lower.contains("phi") {
        "microsoft"
    } else if name_lower.contains("gemma") {
        "google"
    } else {
        "vllm-mlx"
    }
}

/// Map well-known model-name substrings to a model-family label. Order
/// matters: more specific families are checked before their prefixes
/// (e.g. "qwen3" before "qwen2", "gemma-4" before "gemma").
fn infer_family(name_lower: &str) -> &'static str {
    if name_lower.contains("qwen3") {
        "qwen3"
    } else if name_lower.contains("gemma-4") {
        "gemma-4"
    } else if name_lower.contains("gemma-3") {
        "gemma-3"
    } else if name_lower.contains("gemma") {
        "gemma"
    } else if name_lower.contains("qwen2.5-vl") || name_lower.contains("qwen2-vl") {
        "qwen2.5-vl"
    } else if name_lower.contains("qwen2") {
        "qwen2"
    } else if name_lower.contains("llama-3") || name_lower.contains("llama3") {
        "llama-3"
    } else if name_lower.contains("mistral") {
        "mistral"
    } else if name_lower.contains("phi-3") || name_lower.contains("phi3") {
        "phi-3"
    } else {
        "unknown"
    }
}

/// Rough on-disk size estimate (MB) from the parameter-count substring in the
/// model name; unknown sizes default to 4 GB.
fn estimate_size_mb(name_lower: &str) -> u64 {
    if name_lower.contains("0.5b") || name_lower.contains("0.6b") {
        500
    } else if name_lower.contains("1b") || name_lower.contains("1.5b") {
        1_500
    } else if name_lower.contains("3b") {
        3_000
    } else if name_lower.contains("4b") {
        4_000
    } else if name_lower.contains("7b") || name_lower.contains("8b") {
        7_000
    } else if name_lower.contains("13b") || name_lower.contains("14b") {
        13_000
    } else if name_lower.contains("30b") || name_lower.contains("32b") {
        30_000
    } else if name_lower.contains("70b") || name_lower.contains("72b") {
        70_000
    } else {
        4_000
    }
}

/// Convert a discovered model into a registry [`ModelSchema`], inferring
/// capabilities, context length, provider, family, and size from the model
/// name. All inferences are name-based heuristics — the server does not
/// expose this metadata.
pub fn to_model_schema(model: &DiscoveredModel, endpoint: &str) -> ModelSchema {
    let name_lower = model.id.to_lowercase();
    let gemma_4 = has_gemma_4_segment(&name_lower);
    let looks_vision = name_lower.contains("vl") || name_lower.contains("vision");

    // Embedding models only embed; everything else generates and can use tools.
    let mut capabilities = if is_embedding_model(&name_lower) {
        vec![ModelCapability::Embed]
    } else {
        vec![ModelCapability::Generate, ModelCapability::ToolUse]
    };
    if looks_vision {
        capabilities.push(ModelCapability::Vision);
    }
    if gemma_4 {
        // gemma-4 models are multimodal; avoid duplicating Vision when the
        // name already triggered the vision heuristic above.
        if !capabilities.contains(&ModelCapability::Vision) {
            capabilities.push(ModelCapability::Vision);
        }
        capabilities.push(ModelCapability::VideoUnderstanding);
        capabilities.push(ModelCapability::Reasoning);
        // The small E-series (E2B/E4B) additionally understand audio.
        if name_lower.contains("e2b") || name_lower.contains("e4b") {
            capabilities.push(ModelCapability::AudioUnderstanding);
        }
    }

    // Context window: explicit markers win; vision models without a marker
    // get a larger default; gemma-4 always gets 128k.
    let mut context_length = if name_lower.contains("128k") {
        131_072
    } else if name_lower.contains("32k") {
        32_768
    } else {
        8_192
    };
    if looks_vision && context_length == 8_192 {
        context_length = 32_768;
    }
    if gemma_4 {
        context_length = 131_072;
    }

    let registry_id = format!("vllm-mlx/{}", model.id.replace('/', "_"));
    // `rsplit(...).next()` is the clippy-preferred form of `split(...).last()`.
    let display_name = model.id.rsplit('/').next().unwrap_or(&model.id).to_string();

    ModelSchema {
        id: registry_id,
        name: display_name,
        provider: infer_provider(&name_lower).to_string(),
        family: infer_family(&name_lower).to_string(),
        version: String::new(),
        capabilities,
        context_length,
        param_count: String::new(),
        quantization: None,
        performance: PerformanceEnvelope {
            latency_p50_ms: Some(50),
            latency_p99_ms: Some(200),
            tokens_per_second: Some(200.0),
        },
        cost: CostModel {
            // Local models are free per-token; only disk size is tracked.
            input_per_mtok: None,
            output_per_mtok: None,
            size_mb: Some(estimate_size_mb(&name_lower)),
            ram_mb: None,
        },
        source: ModelSource::VllmMlx {
            endpoint: endpoint.to_string(),
            model_name: model.id.clone(),
        },
        supported_params: vec![
            GenerateParam::Temperature,
            GenerateParam::TopP,
            GenerateParam::MaxTokens,
            GenerateParam::StopSequences,
        ],
        tags: vec![
            "vllm-mlx".to_string(),
            "local".to_string(),
            "apple-silicon".to_string(),
        ],
        public_benchmarks: vec![],
        available: true,
    }
}
/// Discover models from the configured vLLM-MLX server and register each one
/// in the unified registry. Returns the number of models registered; a server
/// that is down or a failed discovery yields 0 rather than an error.
pub async fn discover_and_register(
    config: &VllmMlxConfig,
    registry: &mut crate::registry::UnifiedRegistry,
) -> usize {
    // Skip discovery entirely when the server does not answer the health probe.
    if !health_check(&config.endpoint).await.healthy {
        info!(
            endpoint = %config.endpoint,
            "vLLM-MLX server not available, skipping discovery"
        );
        return 0;
    }
    let models = match discover_models(&config.endpoint).await {
        Ok(models) => models,
        Err(e) => {
            warn!(
                endpoint = %config.endpoint,
                error = %e,
                "vLLM-MLX model discovery failed"
            );
            return 0;
        }
    };
    let count = models.len();
    for model in &models {
        let schema = to_model_schema(model, &config.endpoint);
        info!(
            id = %schema.id,
            name = %schema.name,
            endpoint = %config.endpoint,
            "discovered vLLM-MLX model"
        );
        registry.register(schema);
    }
    info!(
        count,
        endpoint = %config.endpoint,
        "vLLM-MLX model discovery complete"
    );
    count
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- Pure unit tests: schema-mapping heuristics, no server required. ---

    #[test]
    fn to_model_schema_basic() {
        let model = DiscoveredModel {
            id: "mlx-community/Qwen3-4B-4bit".to_string(),
            owned_by: None,
        };
        let schema = to_model_schema(&model, "http://localhost:8000");
        assert_eq!(schema.id, "vllm-mlx/mlx-community_Qwen3-4B-4bit");
        assert_eq!(schema.name, "Qwen3-4B-4bit");
        assert_eq!(schema.provider, "qwen");
        assert_eq!(schema.family, "qwen3");
        assert!(schema.is_vllm_mlx());
        assert!(schema.is_local());
        assert!(!schema.is_remote());
        assert!(schema.has_capability(ModelCapability::Generate));
        assert!(schema.has_capability(ModelCapability::ToolUse));
        assert_eq!(schema.cost_per_1k_output(), 0.0);
    }

    #[test]
    fn to_model_schema_embedding() {
        // "bge" names are classified as embed-only models.
        let model = DiscoveredModel {
            id: "mlx-community/bge-large-en-v1.5-4bit".to_string(),
            owned_by: None,
        };
        let schema = to_model_schema(&model, "http://localhost:8000");
        assert!(schema.has_capability(ModelCapability::Embed));
        assert!(!schema.has_capability(ModelCapability::Generate));
    }

    #[test]
    fn to_model_schema_vision() {
        // "VL" in the name grants Vision and bumps the default context window.
        let model = DiscoveredModel {
            id: "mlx-community/Qwen2.5-VL-3B-Instruct-4bit".to_string(),
            owned_by: None,
        };
        let schema = to_model_schema(&model, "http://localhost:8000");
        assert!(schema.has_capability(ModelCapability::Generate));
        assert!(schema.has_capability(ModelCapability::Vision));
        assert!(schema.has_capability(ModelCapability::ToolUse));
        assert_eq!(schema.family, "qwen2.5-vl");
        assert_eq!(schema.context_length, 32_768);
    }

    #[test]
    fn to_model_schema_gemma_family() {
        // gemma-4 gets the full multimodal set but NOT audio (not E-series).
        let model = DiscoveredModel {
            id: "mlx-community/gemma-4-26B-A4B-it".to_string(),
            owned_by: None,
        };
        let schema = to_model_schema(&model, "http://localhost:8000");
        assert_eq!(schema.provider, "google");
        assert_eq!(schema.family, "gemma-4");
        assert!(schema.has_capability(ModelCapability::Generate));
        assert!(schema.has_capability(ModelCapability::ToolUse));
        assert!(schema.has_capability(ModelCapability::Vision));
        assert!(schema.has_capability(ModelCapability::VideoUnderstanding));
        assert!(schema.has_capability(ModelCapability::Reasoning));
        assert!(!schema.has_capability(ModelCapability::AudioUnderstanding));
        assert_eq!(schema.context_length, 131_072);
    }

    #[test]
    fn gemma_4_small_e_series_gets_audio_understanding() {
        for id in ["google/gemma-4-E2B-it", "mlx-community/gemma-4-E4B-it-4bit"] {
            let model = DiscoveredModel {
                id: id.to_string(),
                owned_by: None,
            };
            let schema = to_model_schema(&model, "http://localhost:8000");
            assert!(
                schema.has_capability(ModelCapability::AudioUnderstanding),
                "{id}: E-series should declare audio understanding"
            );
            assert_eq!(schema.context_length, 131_072, "{id}");
        }
    }

    #[test]
    fn gemma_4_segment_match_is_anchored_not_substring() {
        // Positive cases: a literal "gemma" segment followed by a "4" segment.
        for id in [
            "google/gemma-4-E2B-it",
            "mlx-community/gemma-4-26B-A4B-it",
            "local/gemma-4-E4B-it-4bit-mlx",
        ] {
            let m = DiscoveredModel {
                id: id.to_string(),
                owned_by: None,
            };
            let s = to_model_schema(&m, "http://localhost:8000");
            assert!(
                s.has_capability(ModelCapability::VideoUnderstanding),
                "{id} should match gemma-4 segment"
            );
        }
        // Negative cases: "gemma-4bit", non-gemma "-4", older gemma, bare gemma.
        for id in [
            "community/gemma-4bit-q8",
            "meta/llama-4-scout",
            "google/gemma-2-9b-it",
            "google/gemma",
        ] {
            let m = DiscoveredModel {
                id: id.to_string(),
                owned_by: None,
            };
            let s = to_model_schema(&m, "http://localhost:8000");
            assert!(
                !s.has_capability(ModelCapability::VideoUnderstanding),
                "{id} must not match gemma-4 segment"
            );
        }
    }

    #[test]
    fn config_default() {
        // NOTE(review): this assumes VLLM_MLX_ENDPOINT is unset in the test
        // environment; `Default` reads that variable.
        let config = VllmMlxConfig::default();
        assert_eq!(config.endpoint, DEFAULT_ENDPOINT);
        assert!(config.auto_discover);
    }

    // --- Integration tests: require a live vLLM-MLX server (ignored by default). ---

    #[tokio::test]
    #[ignore = "requires running vLLM-MLX server"]
    async fn integration_health_check() {
        let endpoint =
            std::env::var("VLLM_MLX_ENDPOINT").unwrap_or_else(|_| DEFAULT_ENDPOINT.to_string());
        let health = health_check(&endpoint).await;
        assert!(
            health.healthy,
            "vLLM-MLX server at {} is not healthy",
            endpoint
        );
        println!("Health: {:?}", health);
    }

    #[tokio::test]
    #[ignore = "requires running vLLM-MLX server"]
    async fn integration_discover_models() {
        let endpoint =
            std::env::var("VLLM_MLX_ENDPOINT").unwrap_or_else(|_| DEFAULT_ENDPOINT.to_string());
        let models = discover_models(&endpoint)
            .await
            .expect("failed to discover models");
        assert!(!models.is_empty(), "no models discovered from {}", endpoint);
        println!("Discovered {} models:", models.len());
        for m in &models {
            println!(" - {} (owned_by: {:?})", m.id, m.owned_by);
        }
    }

    #[tokio::test]
    #[ignore = "requires running vLLM-MLX server"]
    async fn integration_discover_and_register() {
        let config = VllmMlxConfig::default();
        // Registry backing dir is a throwaway temp path.
        let models_dir = std::env::temp_dir().join("car-test-models");
        let _ = std::fs::create_dir_all(&models_dir);
        let mut registry = crate::registry::UnifiedRegistry::new(models_dir);
        let count = discover_and_register(&config, &mut registry).await;
        assert!(count > 0, "no models registered");
        println!("Registered {} models", count);
        for m in registry.list() {
            if m.is_vllm_mlx() {
                println!(" - {} ({}) available={}", m.id, m.name, m.available);
                assert!(m.available);
            }
        }
    }

    #[tokio::test]
    #[ignore = "requires running vLLM-MLX server"]
    async fn integration_generate() {
        let endpoint =
            std::env::var("VLLM_MLX_ENDPOINT").unwrap_or_else(|_| DEFAULT_ENDPOINT.to_string());
        let models = discover_models(&endpoint)
            .await
            .expect("failed to discover models");
        assert!(!models.is_empty(), "no models");
        // Exercise the first discovered model with a tiny deterministic prompt.
        let model_name = &models[0].id;
        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .unwrap();
        let resp = client
            .post(format!("{}/v1/chat/completions", endpoint))
            .json(&serde_json::json!({
                "model": model_name,
                "messages": [{"role": "user", "content": "Say hello in exactly 3 words."}],
                "max_tokens": 20,
                "temperature": 0.0,
            }))
            .send()
            .await
            .expect("request failed");
        assert!(resp.status().is_success(), "status: {}", resp.status());
        let body: serde_json::Value = resp.json().await.expect("invalid json");
        let text = body["choices"][0]["message"]["content"]
            .as_str()
            .expect("no content in response");
        assert!(!text.is_empty(), "empty response");
        println!("Model: {}", model_name);
        println!("Response: {}", text);
    }
}