use crate::{anthropic_compat, api, openai_compat, util::diag::diag_handler, AppState};
use axum::extract::Request;
use axum::{
extract::State,
http::{HeaderValue, Method},
middleware::{self, Next},
response::Response,
routing::{get, post},
Json, Router,
};
use serde_json::{json, Value};
use std::{net::SocketAddr, sync::Arc};
async fn cors_layer(req: Request, next: Next) -> Response {
let method = req.method().clone();
let mut response = next.run(req).await;
let headers = response.headers_mut();
headers.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
headers.insert(
"Access-Control-Allow-Methods",
HeaderValue::from_static("GET, POST, OPTIONS"),
);
headers.insert(
"Access-Control-Allow-Headers",
HeaderValue::from_static("Content-Type, Authorization"),
);
headers.insert("Access-Control-Max-Age", HeaderValue::from_static("86400"));
if method == Method::OPTIONS {
*response.status_mut() = axum::http::StatusCode::OK;
}
response
}
async fn health_check(State(state): State<Arc<AppState>>) -> Json<Value> {
let models = state.registry.list_all_available();
let discovered = state.registry.discovered_models.len();
let manual = state.registry.list().len();
Json(json!({
"status": "ok",
"service": "shimmy",
"version": env!("CARGO_PKG_VERSION"),
"models": {
"total": models.len(),
"discovered": discovered,
"manual": manual
},
"endpoints": {
"health": "/health",
"models": "/v1/models",
"chat": "/v1/chat/completions",
"generate": "/api/generate"
},
"compatibility": {
"openai": true,
"cors": true
},
"timestamp": chrono::Utc::now().to_rfc3339(),
"uptime_seconds": std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}))
}
async fn metrics_endpoint(State(state): State<Arc<AppState>>) -> Json<Value> {
let models = state.registry.list_all_available();
let discovered_models = state.registry.discovered_models.clone();
let total_size_mb: u64 = discovered_models
.values()
.map(|m| m.size_bytes / (1024 * 1024))
.sum();
let memory_info = sys_info::mem_info().unwrap_or(sys_info::MemInfo {
total: 0,
free: 0,
avail: 0,
buffers: 0,
cached: 0,
swap_total: 0,
swap_free: 0,
});
let gpu_detected = detect_gpu();
let gpu_vendor = get_gpu_vendor();
Json(json!({
"service": "shimmy",
"version": env!("CARGO_PKG_VERSION"),
"gpu_detected": gpu_detected,
"gpu_vendor": gpu_vendor,
"models": {
"total_count": models.len(),
"total_size_mb": total_size_mb,
"by_type": {
"discovered": state.registry.discovered_models.len(),
"manual": state.registry.list().len()
}
},
"system": {
"memory_total_mb": memory_info.total / 1024,
"memory_free_mb": memory_info.free / 1024,
"memory_available_mb": memory_info.avail / 1024
},
"features": {
"llama": cfg!(feature = "llama"),
"huggingface": cfg!(feature = "huggingface")
},
"endpoints": [
"/health",
"/metrics",
"/v1/chat/completions",
"/v1/models",
"/api/generate",
"/api/models"
],
"timestamp": chrono::Utc::now().to_rfc3339()
}))
}
pub async fn run(addr: SocketAddr, state: Arc<AppState>) -> anyhow::Result<()> {
let listener = tokio::net::TcpListener::bind(addr).await?;
let app = Router::new()
.route("/health", get(health_check))
.route("/metrics", get(metrics_endpoint))
.route("/diag", get(diag_handler))
.route("/api/generate", post(api::generate))
.route("/api/models", get(api::list_models))
.route("/api/models/discover", post(api::discover_models))
.route("/api/models/:name/load", post(api::load_model))
.route("/api/models/:name/unload", post(api::unload_model))
.route("/api/models/:name/status", get(api::model_status))
.route("/api/tools", get(api::list_tools))
.route("/api/tools/:name/execute", post(api::execute_tool))
.route("/api/workflows/execute", post(api::execute_workflow))
.route("/ws/generate", get(api::ws_generate))
.route(
"/v1/chat/completions",
post(openai_compat::chat_completions),
)
.route("/v1/models", get(openai_compat::models))
.route("/v1/messages", post(anthropic_compat::messages))
.layer(middleware::from_fn(cors_layer))
.with_state(state);
axum::serve(listener, app).await?;
Ok(())
}
fn detect_gpu() -> bool {
detect_nvidia() || detect_amd() || detect_intel()
}
fn get_gpu_vendor() -> Option<String> {
if detect_nvidia() {
Some("nvidia".to_string())
} else if detect_amd() {
Some("amd".to_string())
} else if detect_intel() {
Some("intel".to_string())
} else {
None
}
}
fn detect_nvidia() -> bool {
std::process::Command::new("nvidia-smi")
.arg("--version")
.output()
.map(|o| o.status.success())
.unwrap_or(false)
}
fn detect_amd() -> bool {
std::process::Command::new("rocm-smi")
.arg("--version")
.output()
.map(|o| o.status.success())
.unwrap_or(false)
||
std::process::Command::new("wmic")
.args(["path", "win32_VideoController", "get", "name"])
.output()
.map(|o| String::from_utf8_lossy(&o.stdout).to_lowercase().contains("radeon"))
.unwrap_or(false)
}
fn detect_intel() -> bool {
std::process::Command::new("wmic")
.args(["path", "win32_VideoController", "get", "name"])
.output()
.map(|o| {
String::from_utf8_lossy(&o.stdout)
.to_lowercase()
.contains("intel")
})
.unwrap_or(false)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::model_registry::Registry;
use tokio::time::{timeout, Duration};
#[test]
fn test_health_response_format() {
let response = serde_json::json!({"status": "ok"});
assert_eq!(response["status"], "ok");
}
#[test]
fn test_app_state_creation() {
let registry = Registry::default();
let engine = Box::new(crate::engine::adapter::InferenceEngineAdapter::new());
let state = Arc::new(AppState::new(engine, registry));
assert_eq!(state.registry.list().len(), 0);
}
#[test]
fn test_socket_addr_parsing() {
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
assert_eq!(addr.port(), 0); assert!(addr.ip().is_loopback());
}
#[test]
fn test_router_route_creation() {
use axum::{routing::get, Router};
let _app: Router<()> = Router::new().route("/health", get(|| async { "ok" }));
}
#[test]
fn test_health_endpoint_response() {
let response = serde_json::json!({"status":"ok"});
assert_eq!(response["status"], "ok");
}
#[tokio::test]
async fn test_tcp_listener_binding() {
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let listener_result = tokio::net::TcpListener::bind(addr).await;
assert!(listener_result.is_ok());
}
#[tokio::test]
async fn test_run_function_creation() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(crate::AppState::new(engine, registry));
assert!(addr.port() == 0); assert_eq!(state.registry.list().len(), 0);
}
#[tokio::test]
async fn test_run_function_tcp_binding() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(crate::AppState::new(engine, registry));
let result = timeout(Duration::from_millis(100), async { run(addr, state).await }).await;
assert!(result.is_err()); }
#[tokio::test]
async fn test_router_construction_with_all_routes() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
use axum::routing::{get, post};
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(crate::AppState::new(engine, registry));
let _router: Router<Arc<AppState>> = Router::new()
.route("/health", get(|| async { "ok" }))
.route("/api/test", post(|| async { "test" }))
.with_state(state);
}
#[tokio::test]
async fn test_run_function_router_and_serve_setup() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(crate::AppState::new(engine, registry));
let run_future = run(addr, state);
let result = timeout(Duration::from_millis(50), run_future).await;
assert!(result.is_err()); }
#[tokio::test]
async fn test_run_function_complete_flow() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(crate::AppState::new(engine, registry));
let server_handle = tokio::spawn(async move { run(addr, state).await });
tokio::time::sleep(Duration::from_millis(10)).await;
server_handle.abort();
}
#[tokio::test]
async fn test_run_function_direct_call_coverage() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(crate::AppState::new(engine, registry));
let run_task = tokio::spawn(run(addr, state));
tokio::time::sleep(Duration::from_millis(20)).await;
run_task.abort();
}
#[test]
fn test_gpu_detection_functions() {
let _gpu_detected = detect_gpu();
let vendor = get_gpu_vendor();
if let Some(vendor_str) = vendor {
assert!(vendor_str == "nvidia" || vendor_str == "amd" || vendor_str == "intel");
}
}
#[test]
fn test_nvidia_detection() {
let _result = detect_nvidia();
}
#[test]
fn test_amd_detection() {
let _result = detect_amd();
}
#[test]
fn test_intel_detection() {
let _result = detect_intel();
}
#[tokio::test]
async fn test_metrics_endpoint_gpu_fields() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
let registry = Registry::default();
let engine = Box::new(InferenceEngineAdapter::new());
let state = Arc::new(crate::AppState::new(engine, registry));
let response = metrics_endpoint(State(state)).await;
let json_value = response.0;
assert!(json_value.get("gpu_detected").is_some());
assert!(json_value.get("gpu_vendor").is_some());
assert!(json_value["gpu_detected"].is_boolean());
if !json_value["gpu_vendor"].is_null() {
let vendor = json_value["gpu_vendor"].as_str().unwrap();
assert!(vendor == "nvidia" || vendor == "amd" || vendor == "intel");
}
}
}