use std::sync::Arc;
use axum::{
Json,
extract::State,
http::StatusCode,
response::{
IntoResponse,
sse::{Event, KeepAlive, Sse},
},
};
use crate::budget::TokenPool;
use crate::inference::{InferenceRequest, Message, Role};
use super::AppState;
use super::types::*;
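/// OpenAI-compatible chat completions handler.
///
/// The request flows through validation, provider routing, per-provider rate
/// limiting, context compaction, and token-budget reservation before being
/// dispatched to the streaming or non-streaming path.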
pub(crate) async fn chat_completions(
State(state): State<Arc<AppState>>,
Json(req): Json<ChatRequest>,
) -> impl IntoResponse {
let request_id = uuid::Uuid::new_v4();
tracing::info!(request_id = %request_id, model = %req.model, "chat_completions");
if let Some(resp) = validate_chat_request(&req) {
return resp;
}
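    // With the `hwaccel` feature the router also weighs live hardware state:
    // the model's estimated parameter count against free VRAM minus the
    // configured reserve. Without it, routing matches on model patterns alone.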
#[cfg(feature = "hwaccel")]
let selected_route = {
let hw = state.hardware.read().unwrap_or_else(|e| e.into_inner());
let model_params = estimate_model_params(&req.model);
state
.router
.read()
.unwrap_or_else(|e| e.into_inner())
.select_with_hardware(&req.model, &hw, model_params, state.vram_reserve_bytes)
.cloned()
};
#[cfg(not(feature = "hwaccel"))]
let selected_route = state
.router
.read()
.unwrap_or_else(|e| e.into_inner())
.select(&req.model)
.cloned();
let route = match selected_route {
Some(r) => r,
None => {
return error_response(
StatusCode::NOT_FOUND,
format!(
"No provider configured for model '{}'. Configure provider routes to handle this model.",
req.model
),
)
.into_response();
}
};
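    // Rate limits are keyed on provider + base URL so separate endpoints of
    // the same provider are throttled independently.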
let rate_key = format!("{}:{}", route.provider, route.base_url);
if !state.rate_limiter.check(&rate_key) {
return error_response(
StatusCode::TOO_MANY_REQUESTS,
format!(
"Rate limit exceeded for provider '{}'. Please try again later.",
route.provider
),
)
.into_response();
}
let provider = match state.providers.get(route.provider, &route.base_url) {
Some(p) => p,
None => {
return error_response(
StatusCode::SERVICE_UNAVAILABLE,
format!(
"Provider '{}' matched for model '{}' but no backend is registered. \
Is the '{}' feature enabled?",
route.provider, req.model, route.provider
),
)
.into_response();
}
};
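    // Clamp the requested max_tokens to the route's configured limit; an
    // unset request value passes through unchanged.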
let max_tokens = match (req.max_tokens, route.max_tokens_limit) {
(Some(requested), Some(limit)) if requested > limit => {
tracing::warn!(
"clamping max_tokens from {} to provider limit {}",
requested,
limit
);
Some(limit)
}
(requested, _) => requested,
};
let mut inference_req = InferenceRequest {
model: req.model.clone(),
prompt: String::new(),
system: None,
messages: req
.messages
.iter()
.map(|m| Message {
role: match m.role.as_str() {
"system" => Role::System,
"assistant" => Role::Assistant,
"tool" => Role::Tool,
_ => Role::User,
},
content: m.content.clone(),
tool_call_id: m.tool_call_id.clone(),
tool_calls: m.tool_calls.clone(),
})
.collect(),
max_tokens,
temperature: req.temperature,
top_p: req.top_p,
stream: req.stream,
tools: req.tools.clone(),
tool_choice: req.tool_choice.clone(),
};
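    // Compaction uses a provider-specific token counter so estimates line up
    // with how the target provider tokenizes; if anything was trimmed, the
    // reduced message list replaces the original.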
let counter = crate::context::tokens::ProviderTokenCounter::for_provider(route.provider);
if let Some(result) = state.compactor.compact(
&inference_req.model,
&inference_req.messages,
&state.model_registry,
&counter,
) {
tracing::info!(
request_id = %request_id,
original_tokens = result.original_tokens,
compacted_tokens = result.compacted_tokens,
messages_dropped = result.messages_dropped,
"context compacted"
);
inference_req.messages = result.messages;
}
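    // Reserve an estimated budget up front: counted input tokens plus the
    // output budget (defaulting to 1024 when max_tokens is unset). Actual
    // usage is reported back against the reservation when the request ends.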
let input_estimate =
crate::context::tokens::TokenCounter::count_messages(&counter, &inference_req.messages);
let output_budget = max_tokens.unwrap_or(1024);
let estimated_tokens = (input_estimate.saturating_add(output_budget)) as u64;
let pool_name = req.pool.clone();
{
let mut budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
match budget.get_pool(&pool_name) {
None => {
return error_response(
StatusCode::BAD_REQUEST,
format!("Token pool '{}' does not exist", pool_name),
)
.into_response();
}
Some(pool) if !pool.can_reserve(estimated_tokens) => {
let remaining = pool.available();
return error_response(
StatusCode::TOO_MANY_REQUESTS,
format!(
"Token budget exceeded: pool '{}' has {} tokens remaining, requested {}",
pool_name, remaining, estimated_tokens
),
)
.into_response();
}
_ => {}
}
budget.reserve(&pool_name, estimated_tokens);
}
if req.stream {
return handle_streaming(
state,
inference_req,
&route,
provider,
pool_name,
estimated_tokens,
req.model.clone(),
)
.await;
}
handle_non_streaming(
state,
inference_req,
&route,
provider,
pool_name,
estimated_tokens,
req.model.clone(),
)
.await
}
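/// Validates request shape and sampling parameters, returning an error
/// response on the first violation. The byte-level model-name check rejects
/// control characters, backslashes, and quotes, which also keeps the name
/// safe to interpolate into the hand-built streaming JSON.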
fn validate_chat_request(req: &ChatRequest) -> Option<axum::response::Response> {
if req.messages.len() > 256 {
return Some(
error_response(
StatusCode::BAD_REQUEST,
format!("Too many messages: {} (max 256)", req.messages.len()),
)
.into_response(),
);
}
if req.messages.is_empty() {
return Some(
error_response(StatusCode::BAD_REQUEST, "messages array is empty").into_response(),
);
}
if req.model.is_empty() {
return Some(
error_response(StatusCode::BAD_REQUEST, "model field is required").into_response(),
);
}
if req.model.len() > 256
|| req
.model
.bytes()
.any(|b| b < 0x20 || b == b'\\' || b == b'"')
{
return Some(error_response(StatusCode::BAD_REQUEST, "invalid model name").into_response());
}
if let Some(temp) = req.temperature
&& !(0.0..=2.0).contains(&temp)
{
return Some(
error_response(
StatusCode::BAD_REQUEST,
format!("temperature must be between 0.0 and 2.0, got {temp}"),
)
.into_response(),
);
}
if let Some(tp) = req.top_p
&& !(0.0..=1.0).contains(&tp)
{
return Some(
error_response(
StatusCode::BAD_REQUEST,
format!("top_p must be between 0.0 and 1.0, got {tp}"),
)
.into_response(),
);
}
None
}
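/// Streams the response as SSE events in the OpenAI `chat.completion.chunk`
/// format, terminated by a final `finish_reason: "stop"` chunk and `[DONE]`.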
async fn handle_streaming(
state: Arc<AppState>,
inference_req: InferenceRequest,
route: &crate::router::ProviderRoute,
provider: Arc<dyn crate::provider::LlmProvider>,
pool_name: String,
estimated_tokens: u64,
model: String,
) -> axum::response::Response {
let id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
let rx = match provider.infer_stream(inference_req).await {
Ok(rx) => {
if let Some(ref audit) = state.audit {
audit.record(
"inference.request",
"info",
&format!("Streaming inference started for model {}", model),
Some(&route.provider.to_string()),
Some(&model),
None,
);
}
rx
}
Err(e) => {
if let Some(ref audit) = state.audit {
audit.record(
"inference.error",
"error",
&format!("Streaming inference error: {e}"),
Some(&route.provider.to_string()),
Some(&model),
None,
);
}
            tracing::error!("inference error: {e}");
            let mut budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
            budget.report(&pool_name, estimated_tokens, 0);
            return error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Inference request failed".to_string(),
            )
            .into_response();
}
};
let token_count = Arc::new(std::sync::atomic::AtomicU64::new(0));
let token_count_clone = token_count.clone();
let budget_guard = StreamBudgetGuard {
state: state.clone(),
pool: pool_name,
estimated: estimated_tokens,
actual: token_count_clone,
provider: route.provider.to_string(),
model: model.clone(),
start: std::time::Instant::now(),
};
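    // The guard travels inside the stream so the budget is reconciled on drop
    // even if the client disconnects mid-stream. Chunks are serialized by hand
    // into a reused buffer; only the token text needs JSON escaping, because
    // the model name was already validated against quotes and backslashes.
    // Each received chunk bumps the count by one, so the reported usage is a
    // per-chunk approximation rather than an exact token count.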
let s = async_stream::stream! {
let _guard = budget_guard;
let mut rx = rx;
let mut buf = String::with_capacity(256);
while let Some(result) = rx.recv().await {
match result {
Ok(token) => {
token_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
buf.clear();
let escaped = serde_json::to_string(&token).unwrap_or_default();
use std::fmt::Write;
let _ = write!(
buf,
r#"{{"id":"{}","object":"chat.completion.chunk","model":"{}","choices":[{{"index":0,"delta":{{"content":{}}},"finish_reason":null}}]}}"#,
&id, &model, escaped
);
yield Ok::<_, std::convert::Infallible>(
Event::default().data(buf.as_str())
);
}
Err(e) => {
tracing::error!("stream error: {e}");
break;
}
}
}
buf.clear();
use std::fmt::Write;
let _ = write!(
buf,
r#"{{"id":"{}","object":"chat.completion.chunk","model":"{}","choices":[{{"index":0,"delta":{{}},"finish_reason":"stop"}}]}}"#,
&id, &model
);
yield Ok(Event::default().data(buf.as_str()));
yield Ok(Event::default().data("[DONE]"));
};
Sse::new(s)
.keep_alive(KeepAlive::new().interval(std::time::Duration::from_secs(15)))
.into_response()
}
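/// Runs a single inference (with retries when enabled), then reports latency,
/// budget usage, cost, metrics, events, and audit records before building the
/// OpenAI-style response.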
async fn handle_non_streaming(
state: Arc<AppState>,
inference_req: InferenceRequest,
route: &crate::router::ProviderRoute,
provider: Arc<dyn crate::provider::LlmProvider>,
pool_name: String,
estimated_tokens: u64,
req_model: String,
) -> axum::response::Response {
let infer_result = if state.retry_manager.is_enabled() {
let p = provider.clone();
let req = inference_req.clone();
state
.retry_manager
.with_retry(|| {
let p = p.clone();
let r = req.clone();
async move { p.infer(&r).await }
})
.await
} else {
provider.infer(&inference_req).await
};
match infer_result {
Ok(result) => {
state
.router
.read()
.unwrap_or_else(|e| e.into_inner())
.report_latency(route.provider, &route.base_url, result.latency_ms);
let actual = result.usage.total_tokens as u64;
{
let mut budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
budget.report(&pool_name, estimated_tokens, actual);
}
let cost = state.cost_tracker.record(
route.provider,
&route.base_url,
&result.model,
&result.usage,
);
tracing::debug!(cost_usd = cost, model = %result.model, "request cost");
crate::metrics::record_request(
&result.provider,
&result.model,
"success",
result.latency_ms as f64 / 1000.0,
result.usage.prompt_tokens,
result.usage.completion_tokens,
);
state.event_bus.publish(
crate::events::topics::INFERENCE,
crate::events::ProviderEvent::InferenceCompleted {
provider: result.provider.clone(),
model: result.model.clone(),
latency_ms: result.latency_ms,
tokens: result.usage.total_tokens,
},
);
if let Some(ref audit) = state.audit {
audit.record(
"inference.response",
"info",
&format!("Inference completed for model {}", result.model),
Some(&route.provider.to_string()),
Some(&result.model),
Some(serde_json::json!({
"prompt_tokens": result.usage.prompt_tokens,
"completion_tokens": result.usage.completion_tokens,
"total_tokens": result.usage.total_tokens,
})),
);
}
let resp = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion",
created: chrono::Utc::now().timestamp(),
model: result.model.clone(),
choices: vec![ChatChoice {
index: 0,
message: ChatResponseMessage {
role: "assistant",
content: result.text,
tool_calls: result.tool_calls.clone(),
},
finish_reason: if result.tool_calls.is_empty() {
"stop"
} else {
"tool_calls"
},
}],
usage: ChatUsage {
prompt_tokens: result.usage.prompt_tokens,
completion_tokens: result.usage.completion_tokens,
total_tokens: result.usage.total_tokens,
},
};
(StatusCode::OK, Json(resp)).into_response()
}
Err(e) => {
state.event_bus.publish(
crate::events::topics::ERRORS,
crate::events::ProviderEvent::InferenceFailed {
provider: route.provider.to_string(),
model: req_model.clone(),
error: e.to_string(),
},
);
            // Log once here; the client receives only a generic message so
            // provider error details are not leaked in the response body.
            tracing::error!("inference error: {e}");
            if let Some(ref audit) = state.audit {
                audit.record(
                    "inference.error",
                    "error",
                    "Inference request failed",
                    Some(&route.provider.to_string()),
                    Some(&req_model),
                    None,
                );
            }
            crate::metrics::record_request(
                &route.provider.to_string(),
                &req_model,
                "error",
                0.0,
                0,
                0,
            );
            let mut budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
            budget.report(&pool_name, estimated_tokens, 0);
            error_response(
                StatusCode::INTERNAL_SERVER_ERROR,
                "Inference request failed".to_string(),
            )
            .into_response()
}
}
}
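/// Lists models across all enabled routes. When no backend is registered or a
/// live listing fails, the route's configured model patterns are returned as
/// a fallback.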
pub(crate) async fn list_models(State(state): State<Arc<AppState>>) -> Json<ModelsResponse> {
let mut models = Vec::new();
let routes: Vec<_> = state
.router
.read()
.unwrap_or_else(|e| e.into_inner())
.routes()
.to_vec();
for route in &routes {
if !route.enabled {
continue;
}
if let Some(provider) = state.providers.get(route.provider, &route.base_url) {
match provider.list_models().await {
Ok(live_models) => {
for m in live_models {
models.push(ModelObject {
id: m.id,
object: "model",
owned_by: m.provider,
});
}
}
Err(e) => {
tracing::warn!("failed to list models from {}: {e}", route.provider);
for pattern in &route.model_patterns {
models.push(ModelObject {
id: pattern.clone(),
object: "model",
owned_by: route.provider.to_string(),
});
}
}
}
} else {
for pattern in &route.model_patterns {
models.push(ModelObject {
id: pattern.clone(),
object: "model",
owned_by: route.provider.to_string(),
});
}
}
}
Json(ModelsResponse {
object: "list",
data: models,
})
}
pub(crate) async fn health(State(state): State<Arc<AppState>>) -> Json<HealthResponse> {
Json(HealthResponse {
status: "ok",
version: env!("CARGO_PKG_VERSION"),
providers_configured: state
.router
.read()
.unwrap_or_else(|e| e.into_inner())
.routes()
.len(),
})
}
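/// Reports per-provider health, preferring cached background-checker state
/// and falling back to a live check when none exists. Last errors are masked
/// to a generic message so details are not exposed.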
pub(crate) async fn health_providers(
State(state): State<Arc<AppState>>,
) -> Json<Vec<ProviderHealth>> {
let mut results = Vec::new();
let routes: Vec<_> = state
.router
.read()
.unwrap_or_else(|e| e.into_inner())
.routes()
.to_vec();
for route in &routes {
let key = (route.provider, route.base_url.clone());
let bg_state = state.health_map.get(&key);
let status = if !route.enabled {
"disabled".to_string()
} else if let Some(ref hs) = bg_state {
if hs.is_healthy {
"healthy".to_string()
} else {
"unhealthy".to_string()
}
} else if let Some(provider) = state.providers.get(route.provider, &route.base_url) {
match provider.health_check().await {
Ok(true) => "healthy".to_string(),
Ok(false) => "unhealthy".to_string(),
Err(e) => {
tracing::warn!("health check error for {}: {e}", route.provider);
"error".to_string()
}
}
} else {
"no_backend".to_string()
};
results.push(ProviderHealth {
provider: route.provider.to_string(),
base_url: route.base_url.clone(),
enabled: route.enabled,
status,
consecutive_failures: bg_state.as_ref().map(|s| s.consecutive_failures),
last_error: bg_state.as_ref().and_then(|s| {
s.last_error
.as_ref()
.map(|_| "health check failed".to_string())
}),
});
}
Json(results)
}
pub(crate) async fn health_heartbeat(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let stats = state.heartbeat.fleet_stats();
(StatusCode::OK, Json(stats)).into_response()
}
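/// Checks whether a pool could satisfy a reservation without reserving.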
pub(crate) async fn tokens_check(
State(state): State<Arc<AppState>>,
Json(req): Json<TokenCheckRequest>,
) -> impl IntoResponse {
let budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
match budget.get_pool(&req.pool) {
Some(pool) => (
StatusCode::OK,
Json(TokenCheckResponse {
allowed: pool.can_reserve(req.tokens),
available: pool.available(),
}),
)
.into_response(),
None => error_response(
StatusCode::NOT_FOUND,
format!("Token pool '{}' not found", req.pool),
)
.into_response(),
}
}
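/// Attempts the reservation, then reads back availability; an unknown pool
/// yields 404 (the reserve call is presumed a no-op in that case).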
pub(crate) async fn tokens_reserve(
State(state): State<Arc<AppState>>,
Json(req): Json<TokenReserveRequest>,
) -> impl IntoResponse {
let mut budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
let reserved = budget.reserve(&req.pool, req.tokens);
match budget.get_pool(&req.pool) {
Some(pool) => (
StatusCode::OK,
Json(TokenReserveResponse {
reserved,
available: pool.available(),
}),
)
.into_response(),
None => error_response(
StatusCode::NOT_FOUND,
format!("Token pool '{}' not found", req.pool),
)
.into_response(),
}
}
pub(crate) async fn tokens_report(
State(state): State<Arc<AppState>>,
Json(req): Json<TokenReportRequest>,
) -> impl IntoResponse {
let mut budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
budget.report(&req.pool, req.reserved, req.actual);
match budget.get_pool(&req.pool) {
Some(pool) => (
StatusCode::OK,
Json(TokenReportResponse {
used: pool.used,
available: pool.available(),
}),
)
.into_response(),
None => error_response(
StatusCode::NOT_FOUND,
format!("Token pool '{}' not found", req.pool),
)
.into_response(),
}
}
pub(crate) async fn tokens_pools(State(state): State<Arc<AppState>>) -> Json<Vec<TokenPool>> {
let budget = state.budget.lock().unwrap_or_else(|e| e.into_inner());
let pools: Vec<TokenPool> = budget.pools().values().cloned().collect();
Json(pools)
}
pub(crate) async fn costs_get(State(state): State<Arc<AppState>>) -> Json<CostsResponse> {
let (records, total_cost_usd) = state.cost_tracker.all_with_total();
Json(CostsResponse {
records,
total_cost_usd,
})
}
pub(crate) async fn costs_reset(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
if let Some(ref audit) = state.audit {
audit.record(
"admin.costs_reset",
"warn",
"Cost tracking counters reset",
None,
None,
None,
);
}
state.cost_tracker.reset();
Json(serde_json::json!({ "status": "ok" }))
}
pub(crate) async fn audit_log(State(state): State<Arc<AppState>>) -> impl IntoResponse {
match &state.audit {
Some(audit) => {
let (entries, total, chain_valid) = audit.snapshot(100);
(
StatusCode::OK,
Json(AuditResponse {
entries,
total,
chain_valid,
}),
)
.into_response()
}
None => error_response(
StatusCode::NOT_FOUND,
"Audit logging is not enabled. Set [audit] enabled = true in config.",
)
.into_response(),
}
}
pub(crate) async fn admin_reload(State(state): State<Arc<AppState>>) -> impl IntoResponse {
if let Some(path) = &state.config_path {
super::reload_config(&state, path);
(
StatusCode::OK,
Json(serde_json::json!({"status": "reloaded"})),
)
.into_response()
} else {
error_response(
StatusCode::BAD_REQUEST,
"no config path configured for reload",
)
.into_response()
}
}
pub(crate) async fn queue_status(State(state): State<Arc<AppState>>) -> impl IntoResponse {
(
StatusCode::OK,
Json(serde_json::json!({
"queued": state.inference_queue.len().await,
})),
)
.into_response()
}
pub(crate) async fn cache_stats(State(state): State<Arc<AppState>>) -> impl IntoResponse {
(StatusCode::OK, Json(serde_json::json!(state.cache.stats()))).into_response()
}
pub(crate) async fn prometheus_metrics() -> impl IntoResponse {
let body = crate::metrics::gather();
(
StatusCode::OK,
[(
axum::http::header::CONTENT_TYPE,
"text/plain; version=0.0.4; charset=utf-8",
)],
body,
)
}
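/// Embeddings endpoint: routes on the model name and forwards the request to
/// the matched provider, mirroring the chat route-selection flow.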
pub(crate) async fn embeddings(
State(state): State<Arc<AppState>>,
Json(req): Json<crate::inference::EmbeddingsRequest>,
) -> impl IntoResponse {
let route = match state
.router
.read()
.unwrap_or_else(|e| e.into_inner())
.select(&req.model)
.cloned()
{
Some(r) => r,
None => {
return error_response(
StatusCode::NOT_FOUND,
format!("No provider configured for model '{}'", req.model),
)
.into_response();
}
};
let provider = match state.providers.get(route.provider, &route.base_url) {
Some(p) => p,
None => {
return error_response(
StatusCode::SERVICE_UNAVAILABLE,
format!("Provider '{}' not available", route.provider),
)
.into_response();
}
};
match provider.embeddings(&req).await {
Ok(result) => (StatusCode::OK, Json(result)).into_response(),
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Embeddings error: {e}"),
)
.into_response(),
}
}
#[cfg(feature = "tools")]
pub(crate) async fn tools_list(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let result = state.mcp_bridge.list_tools();
(StatusCode::OK, Json(result)).into_response()
}
#[cfg(feature = "tools")]
pub(crate) async fn tools_call(
State(state): State<Arc<AppState>>,
Json(req): Json<super::types::ToolCallRequest>,
) -> impl IntoResponse {
let (result, is_error) = state.mcp_bridge.call_tool(&req.name, req.arguments);
let status = if is_error {
StatusCode::BAD_REQUEST
} else {
StatusCode::OK
};
(status, Json(result)).into_response()
}
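/// Heuristically extracts a parameter count from model names such as
/// `llama3-70b` or `qwen2.5:32b` by locating a number followed by a
/// standalone `b`. Mixture names like `8x7b` match only the trailing `7b`,
/// and names without the suffix (e.g. `gpt-4o`) yield `None`.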
#[cfg(feature = "hwaccel")]
fn estimate_model_params(model: &str) -> Option<u64> {
let lower = model.to_ascii_lowercase();
for (i, ch) in lower.char_indices() {
if ch == 'b'
&& i > 0
&& (i + 1 >= lower.len() || !lower.as_bytes()[i + 1].is_ascii_alphanumeric())
{
            let before = &lower[..i];
            // Walk back over the digits by char so a multi-byte character
            // just before the number cannot produce an out-of-boundary slice
            // (the previous `rfind(..).map(|pos| pos + 1)` assumed the
            // preceding char was one byte wide and could panic).
            let num_start = before
                .char_indices()
                .rev()
                .find(|&(_, c)| !c.is_ascii_digit() && c != '.')
                .map(|(pos, c)| pos + c.len_utf8())
                .unwrap_or(0);
            let num_str = &before[num_start..];
if let Ok(val) = num_str.parse::<f64>()
&& val > 0.0
&& val < 10_000.0
{
return Some((val * 1_000_000_000.0) as u64);
}
}
}
None
}
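/// Reports detected accelerators (CPU excluded), VRAM totals net of the
/// configured reserve, and the runtime environment (container, Kubernetes,
/// cloud instance).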
#[cfg(feature = "hwaccel")]
pub(crate) async fn hardware_info(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let hw = state.hardware.read().unwrap_or_else(|e| e.into_inner());
let vram_reserve = state.vram_reserve_bytes;
let accelerators: Vec<super::types::AcceleratorInfo> = hw
.all_profiles()
.iter()
.filter(|p| p.available && !matches!(p.accelerator, ai_hwaccel::AcceleratorType::Cpu))
.map(|p| {
let fallback = p.accelerator.to_string();
let name = p.device_name.as_deref().unwrap_or(&fallback);
super::types::AcceleratorInfo {
name: name.to_string(),
family: format!("{:?}", p.accelerator.family()),
memory_bytes: p.memory_bytes,
memory_used_bytes: p.memory_used_bytes,
memory_free_bytes: p.memory_free_bytes,
utilization_pct: p.gpu_utilization_percent,
temperature_c: p.temperature_c,
power_watts: p.power_watts,
bandwidth_gbps: p.memory_bandwidth_gbps,
}
})
.collect();
let environment =
hw.runtime_environment()
.map(|env| super::types::EnvironmentInfo {
is_docker: env.is_docker,
is_kubernetes: env.is_kubernetes,
namespace: env.kubernetes_namespace.clone(),
cloud_provider: env.cloud_instance.as_ref().map(|c| c.provider.clone()),
instance_type: env
.cloud_instance
.as_ref()
.and_then(|c| c.instance_type.clone()),
kubernetes_gpu: env.kubernetes_gpu.as_ref().map(|k| {
super::types::KubernetesGpuInfo {
device_ids: k.device_ids.clone(),
gpu_count: k.gpu_count,
source: k.source.clone(),
}
}),
});
let resp = super::types::HardwareResponse {
accelerators,
total_vram_bytes: hw.total_accelerator_memory(),
available_vram_bytes: hw.available_vram(vram_reserve),
vram_reserve_bytes: vram_reserve,
has_fast_interconnect: hw.has_fast_interconnect(),
environment,
};
(StatusCode::OK, Json(resp)).into_response()
}
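/// Recommends a placement for a model of the given size; when it does not fit
/// in local VRAM, up to five cloud instance alternatives are suggested.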
#[cfg(feature = "hwaccel")]
pub(crate) async fn hardware_placement(
State(state): State<Arc<AppState>>,
Json(req): Json<super::types::PlacementRequest>,
) -> impl IntoResponse {
let hw = state.hardware.read().unwrap_or_else(|e| e.into_inner());
let recommendation = hw.recommend_placement(req.model_params, &req.providers);
let cloud_instances = if !recommendation.fits_in_vram {
hw.recommend_cloud_instance(req.model_params, None)
.into_iter()
.take(5)
.map(|rec| super::types::CloudInstanceInfo {
name: rec.instance.name,
provider: rec.instance.provider,
gpu: rec.instance.gpu,
gpu_count: rec.instance.gpu_count,
total_gpu_memory_gb: rec.instance.total_gpu_memory_gb,
price_per_hour: rec.instance.price_per_hour,
memory_headroom_pct: rec.memory_headroom_pct,
})
.collect()
} else {
Vec::new()
};
let resp = super::types::PlacementResponse {
recommendation,
cloud_alternatives: cloud_instances,
};
(StatusCode::OK, Json(resp)).into_response()
}
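/// Checks model/hardware compatibility. Without an explicit quantization,
/// the registry's suggestion for a 7B-parameter model is used as the default.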
#[cfg(feature = "hwaccel")]
pub(crate) async fn hardware_models(
State(state): State<Arc<AppState>>,
Json(req): Json<super::types::ModelCompatRequest>,
) -> impl IntoResponse {
let hw = state.hardware.read().unwrap_or_else(|e| e.into_inner());
let quant = req
.quantization
.as_deref()
.and_then(parse_quantization)
.unwrap_or_else(|| hw.registry().suggest_quantization(7_000_000_000));
if let Some(name) = &req.model {
if let Some(profile) = crate::hardware::HardwareManager::find_model(name) {
let can_run = hw.can_run_model(name, &quant);
let estimated = ai_hwaccel::AcceleratorRegistry::estimate_memory(
(profile.params_billions * 1e9) as u64,
&quant,
);
let total = hw.total_accelerator_memory();
let headroom = if total > 0 {
(((total as f64 - estimated as f64) / total as f64) * 100.0).max(0.0)
} else {
0.0
};
let info = super::types::CompatibleModelInfo {
name: profile.name.clone(),
family: profile.family.clone(),
params_billions: profile.params_billions,
memory_required_bytes: estimated,
headroom_pct: headroom,
};
let resp = serde_json::json!({
"model": info,
"can_run": can_run,
"total_vram_bytes": total,
});
return (StatusCode::OK, Json(resp)).into_response();
}
return super::types::error_response(StatusCode::NOT_FOUND, "model not found in catalogue")
.into_response();
}
let results = hw.compatible_models(&quant);
let compatible: Vec<super::types::CompatibleModelInfo> = results
.iter()
.map(|r| super::types::CompatibleModelInfo {
name: r.model.name.clone(),
family: r.model.family.clone(),
params_billions: r.model.params_billions,
memory_required_bytes: r.memory_required_bytes,
headroom_pct: r.headroom_pct,
})
.collect();
let resp = super::types::ModelCompatResponse {
compatible,
total_vram_bytes: hw.total_accelerator_memory(),
};
(StatusCode::OK, Json(resp)).into_response()
}
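/// What-if analysis: removes up to `remove_count` non-CPU devices and/or adds
/// hypothetical devices (modeled here as CUDA accelerators), then compares
/// the sharding plan before and after.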
#[cfg(feature = "hwaccel")]
pub(crate) async fn hardware_simulate(
State(state): State<Arc<AppState>>,
Json(req): Json<super::types::SimulateRequest>,
) -> impl IntoResponse {
if req.model_params == 0 {
return super::types::error_response(
StatusCode::BAD_REQUEST,
"model_params must be greater than 0",
)
.into_response();
}
if req.add_devices.len() > 64 {
return super::types::error_response(
StatusCode::BAD_REQUEST,
"add_devices limited to 64 entries",
)
.into_response();
}
for d in &req.add_devices {
if d.memory_bytes == 0 {
return super::types::error_response(
StatusCode::BAD_REQUEST,
"device memory_bytes must be greater than 0",
)
.into_response();
}
}
let hw = state.hardware.read().unwrap_or_else(|e| e.into_inner());
let original = super::types::SimulateSnapshot {
device_count: hw.available_profiles().len(),
total_vram_bytes: hw.total_accelerator_memory(),
sharding: hw.plan_sharding(req.model_params),
};
let mut hypothetical = if req.remove_count > 0 {
let mut profiles: Vec<ai_hwaccel::AcceleratorProfile> =
hw.registry().all_profiles().to_vec();
let mut removed = 0usize;
profiles.retain(|p| {
if removed < req.remove_count
&& !matches!(p.accelerator, ai_hwaccel::AcceleratorType::Cpu)
{
removed += 1;
false
} else {
true
}
});
hw.what_if_replace(profiles)
} else {
crate::hardware::HardwareManager::from_registry(hw.registry().clone())
};
if !req.add_devices.is_empty() {
let additional: Vec<ai_hwaccel::AcceleratorProfile> = req
.add_devices
.iter()
.enumerate()
.map(|(i, d)| ai_hwaccel::AcceleratorProfile::cuda(i as u32, d.memory_bytes))
.collect();
hypothetical = hypothetical.what_if_add(&additional);
}
let simulated = super::types::SimulateSnapshot {
device_count: hypothetical.available_profiles().len(),
total_vram_bytes: hypothetical.total_accelerator_memory(),
sharding: hypothetical.plan_sharding(req.model_params),
};
let resp = super::types::SimulateResponse {
original,
simulated,
};
(StatusCode::OK, Json(resp)).into_response()
}
#[cfg(feature = "hwaccel")]
pub(crate) async fn hardware_format(
Json(req): Json<super::types::ModelFormatRequest>,
) -> impl IntoResponse {
let raw = std::path::Path::new(&req.path);
if !raw.is_absolute() {
return super::types::error_response(StatusCode::BAD_REQUEST, "path must be absolute")
.into_response();
}
let path = match raw.canonicalize() {
Ok(p) => p,
Err(_) => {
return super::types::error_response(StatusCode::NOT_FOUND, "file not found")
.into_response();
}
};
let valid_ext = matches!(
path.extension().and_then(|e| e.to_str()),
Some("safetensors" | "gguf" | "onnx" | "pt" | "bin" | "model")
);
if !valid_ext {
return super::types::error_response(
StatusCode::BAD_REQUEST,
"unsupported file extension (expected .safetensors, .gguf, .onnx, .pt, .bin, .model)",
)
.into_response();
}
match crate::hardware::HardwareManager::detect_model_format(&path) {
Some(meta) => {
let resp = super::types::ModelFormatResponse {
format: format!("{:?}", meta.format),
param_count: meta.param_count,
dtype: meta.dtype,
tensor_count: meta.tensor_count,
format_version: meta.format_version,
};
(StatusCode::OK, Json(resp)).into_response()
}
None => super::types::error_response(StatusCode::BAD_REQUEST, "unrecognized model format")
.into_response(),
}
}
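/// Maps common quantization spellings (`fp16`, `bf16`, `q8`, `int4`, ...)
/// onto `ai_hwaccel::QuantizationLevel`; unknown strings yield `None`.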
#[cfg(feature = "hwaccel")]
fn parse_quantization(s: &str) -> Option<ai_hwaccel::QuantizationLevel> {
match s.to_lowercase().as_str() {
"fp32" | "f32" | "none" => Some(ai_hwaccel::QuantizationLevel::None),
"fp16" | "f16" | "float16" => Some(ai_hwaccel::QuantizationLevel::Float16),
"bf16" | "bfloat16" => Some(ai_hwaccel::QuantizationLevel::BFloat16),
"int8" | "i8" | "q8" => Some(ai_hwaccel::QuantizationLevel::Int8),
"int4" | "i4" | "q4" => Some(ai_hwaccel::QuantizationLevel::Int4),
_ => None,
}
}
#[cfg(all(test, feature = "hwaccel"))]
mod tests {
use super::*;
#[test]
fn estimate_params_standard_patterns() {
assert_eq!(estimate_model_params("llama3-70b"), Some(70_000_000_000));
assert_eq!(
estimate_model_params("mistral-7b-instruct"),
Some(7_000_000_000)
);
assert_eq!(estimate_model_params("qwen2.5:32b"), Some(32_000_000_000));
assert_eq!(estimate_model_params("phi-3.5b"), Some(3_500_000_000));
}
#[test]
fn estimate_params_with_decimals() {
assert_eq!(estimate_model_params("llama-1.5b"), Some(1_500_000_000));
assert_eq!(estimate_model_params("model-0.5b"), Some(500_000_000));
}
#[test]
fn estimate_params_no_match() {
assert_eq!(estimate_model_params("gpt-4o"), None);
assert_eq!(estimate_model_params("claude-sonnet-4"), None);
assert_eq!(estimate_model_params("some-model"), None);
}
#[test]
fn estimate_params_edge_cases() {
assert_eq!(estimate_model_params("bert-base"), None);
}
}