use std::net::SocketAddr;
use std::sync::Arc;
use axum::Router;
use axum::routing::{get, post};
use tokio::sync::{RwLock, Semaphore};
use tower_http::cors::{Any, CorsLayer};
use crate::engine::ChatTemplate;
use crate::gguf::GgufFile;
use crate::model::{ModelConfig, ModelLoader};
use crate::tokenizer::Tokenizer;
use crate::{Backend, Model};
use super::handlers::{self, AppState, RequestQueue};
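/// Settings for the OpenAI-compatible HTTP server.
///
/// `max_concurrent`, `max_queue_depth`, and `max_context_len` treat `0` as
/// "use the default": 1 in-flight request, 64 queued requests, and the context
/// length taken from `CARDOZO_CONTEXT_SIZE` (or the model's native limit).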
pub struct ServerConfig {
pub host: String,
pub port: u16,
pub model_path: String,
pub max_concurrent: usize,
pub max_queue_depth: usize,
pub max_context_len: usize,
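/// Database connection string for the RAG endpoints; when `None`, the RAG routes are not mounted.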
#[cfg(feature = "rag")]
pub rag_database_url: Option<String>,
}
pub async fn run_server(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
eprintln!("Loading model from: {}", config.model_path);
let gguf = GgufFile::open(&config.model_path)?;
let tokenizer = Tokenizer::from_gguf(&gguf)?;
eprintln!("Tokenizer loaded: {} tokens", tokenizer.vocab_size);
let chat_template = ChatTemplate::detect(&gguf);
eprintln!("Chat template: {:?}", chat_template);
let mut loader = ModelLoader::load(&config.model_path)?;
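// A non-zero max_context_len takes precedence; otherwise consult CARDOZO_CONTEXT_SIZE.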
let ctx_override = if config.max_context_len > 0 {
Some(config.max_context_len)
} else {
std::env::var("CARDOZO_CONTEXT_SIZE")
.ok()
.and_then(|v| v.parse::<usize>().ok())
};
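// Never exceed the model's native context window.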
if let Some(ctx_len) = ctx_override {
let native = loader.config().max_seq_len;
let clamped = ctx_len.min(native);
loader.config_mut().max_seq_len = clamped;
eprintln!("Context size: {} (native: {})", clamped, native);
}
let model_config = loader.config().clone();
eprintln!(
"Model config: {} layers, {} heads, {} dim",
model_config.num_layers, model_config.num_heads, model_config.hidden_size
);
let model = loader.build_model()?;
eprintln!("Model loaded successfully");
let (model, backend) = select_model_and_backend(model, &model_config);
let model_name = std::path::Path::new(&config.model_path)
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("llama")
.to_string();
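// Zero means "use the default": one in-flight request, 64 queued.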
let max_concurrent = if config.max_concurrent == 0 {
1
} else {
config.max_concurrent
};
let max_queue_depth = if config.max_queue_depth == 0 {
64
} else {
config.max_queue_depth
};
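// Shared state for every handler; RwLocks allow /v1/models/load (and SIGHUP) to hot-swap the model.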
let app_state = Arc::new(AppState {
model: RwLock::new(model),
tokenizer: RwLock::new(Arc::new(tokenizer)),
config: RwLock::new(model_config),
model_name: RwLock::new(model_name),
model_path: RwLock::new(config.model_path.clone()),
chat_template: RwLock::new(chat_template),
backend: RwLock::new(backend),
inference_semaphore: Arc::new(Semaphore::new(max_concurrent)),
request_queue: RequestQueue::new(max_queue_depth, max_concurrent),
});
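// On Unix, reload the model from its current path whenever the process receives SIGHUP.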
#[cfg(unix)]
{
let state_for_signal = app_state.clone();
tokio::spawn(async move {
use tokio::signal::unix::{SignalKind, signal};
let mut stream = match signal(SignalKind::hangup()) {
Ok(s) => s,
Err(e) => {
tracing::warn!("Failed to register SIGHUP handler: {}", e);
return;
}
};
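// Each SIGHUP triggers a reload; `recv` returns `None` only if the signal stream closes.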
while stream.recv().await.is_some() {
tracing::info!("SIGHUP received: reloading model...");
let path = state_for_signal.model_path.read().await.clone();
match handlers::reload_model_from_path(&state_for_signal, &path).await {
Ok((name, ctx)) => {
tracing::info!("Model reloaded via SIGHUP: {} (ctx={})", name, ctx);
}
Err(e) => {
tracing::error!("Model reload via SIGHUP failed: {}", e);
}
}
}
});
}
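// Fully permissive CORS so browser-based clients can call the API directly.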
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
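// OpenAI-compatible endpoints plus health and queue introspection.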
let mut app = Router::new()
.route("/v1/chat/completions", post(handlers::chat_completions))
.route("/v1/completions", post(handlers::completions))
.route("/v1/embeddings", post(handlers::embeddings))
.route("/v1/models", get(handlers::list_models))
.route("/v1/models/load", post(handlers::load_model))
.route("/v1/queue/status", get(handlers::queue_status))
.route("/health", get(handlers::health))
.route("/", get(|| async { "llama-gguf server" }))
.with_state(app_state.clone());
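// RAG routes are mounted only when the `rag` feature is compiled in and a database URL was provided.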
#[cfg(feature = "rag")]
let rag_enabled = config.rag_database_url.is_some();
#[cfg(not(feature = "rag"))]
let rag_enabled = false;
#[cfg(feature = "rag")]
if let Some(ref db_url) = config.rag_database_url {
use super::handlers::RagState;
use crate::rag::RagConfig;
eprintln!("RAG enabled with database connection");
let rag_config = RagConfig::new(db_url);
let rag_state = Arc::new(RagState::new(rag_config));
let rag_routes = Router::new()
.route("/knowledgebases", post(handlers::list_knowledge_bases))
.route("/knowledgebases/:kb_id", get(handlers::get_knowledge_base))
.route(
"/knowledgebases/:kb_id",
axum::routing::delete(handlers::delete_knowledge_base),
)
.route("/retrieve", post(handlers::retrieve))
.route("/ingest", post(handlers::ingest))
.with_state(rag_state.clone());
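// retrieveAndGenerate needs both the model state and the RAG state.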
let rag_gen_routes = Router::new()
.route(
"/retrieveAndGenerate",
post(handlers::retrieve_and_generate),
)
.with_state((app_state.clone(), rag_state));
app = app
.nest("/v1/rag", rag_routes)
.nest("/v1/rag", rag_gen_routes);
}
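// Layer CORS after all routes are registered so it also covers the RAG routes.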
app = app.layer(cors);
let addr = format!("{}:{}", config.host, config.port);
let socket_addr: SocketAddr = addr.parse()?;
eprintln!();
eprintln!("â•────────────────────────────────────────────────────────────────────╮");
eprintln!("│ llama-gguf Server │");
eprintln!("├────────────────────────────────────────────────────────────────────┤");
eprintln!("│ Listening on: http://{:<48}│", addr);
eprintln!("│ Concurrency: {} concurrent, {} max queued{:<27}│", max_concurrent, max_queue_depth, "");
eprintln!("├────────────────────────────────────────────────────────────────────┤");
eprintln!("│ Endpoints: │");
eprintln!("│ POST /v1/chat/completions - Chat completions (OpenAI API) │");
eprintln!("│ POST /v1/completions - Text completions (OpenAI API) │");
eprintln!("│ POST /v1/embeddings - Embeddings (OpenAI API) │");
eprintln!("│ GET /v1/models - List models │");
eprintln!("│ POST /v1/models/load - Hot-swap model │");
eprintln!("│ GET /v1/queue/status - Queue status │");
eprintln!("│ GET /health - Health check │");
if rag_enabled {
eprintln!("├────────────────────────────────────────────────────────────────────┤");
eprintln!("│ RAG / Knowledge Base Endpoints (Bedrock-style): │");
eprintln!("│ POST /v1/rag/retrieve - Retrieve from KB │");
eprintln!("│ POST /v1/rag/retrieveAndGenerate - RAG pipeline │");
eprintln!("│ POST /v1/rag/ingest - Ingest documents │");
eprintln!("│ POST /v1/rag/knowledgebases - List knowledge bases │");
eprintln!("│ GET /v1/rag/knowledgebases/:id - Get KB details │");
eprintln!("│ DEL /v1/rag/knowledgebases/:id - Delete KB │");
}
#[cfg(unix)]
eprintln!("├────────────────────────────────────────────────────────────────────┤");
#[cfg(unix)]
eprintln!("│ Send SIGHUP to reload model from the same path │");
eprintln!("╰────────────────────────────────────────────────────────────────────╯");
eprintln!();
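// Bind the listener and serve requests until the process exits.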
let listener = tokio::net::TcpListener::bind(socket_addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
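/// Selects the execution path for the loaded model. With `LLAMA_GPU=1` and a
/// compiled GPU feature (cuda/vulkan/metal/dx12) the model is wrapped for
/// full-GPU inference; otherwise it runs on the CPU backend.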
pub(crate) fn select_model_and_backend(
model: crate::model::LlamaModel,
config: &ModelConfig,
) -> (Arc<dyn Model>, Arc<dyn Backend>) {
let use_gpu = std::env::var("LLAMA_GPU")
.map(|v| matches!(v.to_lowercase().as_str(), "1" | "true" | "yes"))
.unwrap_or(false);
if !use_gpu {
eprintln!("GPU disabled (LLAMA_GPU not set or 0)");
return (
Arc::new(model),
Arc::new(crate::backend::cpu::CpuBackend::new()),
);
}
let architecture = model.architecture();
let max_seq_len = config.max_seq_len;
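// Backend preference order: CUDA, then Vulkan, then Metal (macOS), then DX12 (Windows).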
#[cfg(feature = "cuda")]
{
if cudarc::driver::CudaDevice::new(0).is_ok() {
match crate::backend::cuda::gpu_only::GpuOnlyInference::from_model(
model,
max_seq_len,
) {
Ok(gpu) => {
eprintln!(
"Using full GPU inference (attention + DeltaNet + MoE all on CUDA)"
);
let wrapper = crate::backend::GpuModelWrapper::new(
gpu,
config.clone(),
architecture,
);
return (
Arc::new(wrapper),
Arc::new(crate::backend::cpu::CpuBackend::new()),
);
}
Err(e) => {
eprintln!("FATAL: CUDA GPU inference init failed: {}", e);
eprintln!("Model was consumed during init. Restart without LLAMA_GPU=1.");
std::process::exit(1);
}
}
} else {
eprintln!("LLAMA_GPU=1 but no CUDA device found");
}
}
#[cfg(feature = "vulkan")]
{
if crate::backend::vulkan::VulkanBackend::new().is_ok() {
match crate::backend::vulkan::gpu_only::VulkanGpuInference::from_model(
model,
max_seq_len,
) {
Ok(gpu) => {
eprintln!("Using full GPU inference on Vulkan");
let wrapper = crate::backend::GpuModelWrapper::new(
gpu,
config.clone(),
architecture,
);
return (
Arc::new(wrapper),
Arc::new(crate::backend::cpu::CpuBackend::new()),
);
}
Err(e) => {
eprintln!("FATAL: Vulkan GPU inference init failed: {}", e);
eprintln!("Model was consumed during init. Restart without LLAMA_GPU=1.");
std::process::exit(1);
}
}
}
}
#[cfg(all(feature = "metal", target_os = "macos"))]
{
if crate::backend::metal::MetalBackend::new().is_ok() {
match crate::backend::metal::gpu_only::MetalGpuInference::from_model(
model,
max_seq_len,
) {
Ok(gpu) => {
eprintln!("Using full GPU inference on Metal");
let wrapper = crate::backend::GpuModelWrapper::new(
gpu,
config.clone(),
architecture,
);
return (
Arc::new(wrapper),
Arc::new(crate::backend::cpu::CpuBackend::new()),
);
}
Err(e) => {
eprintln!("FATAL: Metal GPU inference init failed: {}", e);
eprintln!("Model was consumed during init. Restart without LLAMA_GPU=1.");
std::process::exit(1);
}
}
}
}
#[cfg(all(feature = "dx12", target_os = "windows"))]
{
if crate::backend::dx12::Dx12Backend::new().is_ok() {
match crate::backend::dx12::gpu_only::Dx12GpuInference::from_model(
model,
max_seq_len,
) {
Ok(gpu) => {
eprintln!("Using full GPU inference on DX12");
let wrapper = crate::backend::GpuModelWrapper::new(
gpu,
config.clone(),
architecture,
);
return (
Arc::new(wrapper),
Arc::new(crate::backend::cpu::CpuBackend::new()),
);
}
Err(e) => {
eprintln!("FATAL: DX12 GPU inference init failed: {}", e);
eprintln!("Model was consumed during init. Restart without LLAMA_GPU=1.");
std::process::exit(1);
}
}
}
}
#[cfg(not(any(
feature = "cuda",
feature = "vulkan",
all(feature = "metal", target_os = "macos"),
all(feature = "dx12", target_os = "windows")
)))]
{
eprintln!("LLAMA_GPU=1 but no GPU backend compiled. Build with --features cuda/vulkan/metal/dx12");
}
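// No usable GPU: run everything on the CPU backend.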
eprintln!("Falling back to CPU");
(
Arc::new(model),
Arc::new(crate::backend::cpu::CpuBackend::new()),
)
}