use anyhow::{Context, Result};
use axum::{
extract::{Json, State},
http::StatusCode,
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse,
},
routing::{get, post},
Router,
};
use colored::Colorize;
use console::style;
use futures::stream::{self, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use crate::models::{resolve_model_id, QuantPreset};
/// Server-wide state shared by every request handler.
struct ServerState {
/// Resolved model identifier, echoed by `/v1/models`, `/health` and `/metrics`.
model_id: String,
/// Inference backend. `run` always stores `Some(..)` here, even when model loading
/// failed; handlers check `is_model_loaded()` and fall back to mock responses.
backend: Option<Box<dyn ruvllm::LlmBackend>>,
/// Number of chat-completion requests received (streaming and non-streaming).
request_count: u64,
/// Running token total. The non-streaming path adds whitespace-split word counts of
/// prompt + completion; the streaming path adds the backend-reported `total_tokens`.
total_tokens: u64,
/// Moment the state was created; used for uptime and per-second rates.
start_time: Instant,
}
/// Shared handle to [`ServerState`], injected into handlers via axum's `State` extractor.
type SharedState = Arc<RwLock<ServerState>>;
pub async fn run(
model: &str,
host: &str,
port: u16,
max_concurrent: usize,
max_context: usize,
quantization: &str,
cache_dir: &str,
) -> Result<()> {
let model_id = resolve_model_id(model);
let quant = QuantPreset::from_str(quantization)
.ok_or_else(|| anyhow::anyhow!("Invalid quantization format: {}", quantization))?;
println!();
println!("{}", style("RuvLLM Inference Server").bold().cyan());
println!();
println!(" {} {}", "Model:".dimmed(), model_id);
println!(" {} {}", "Quantization:".dimmed(), quant);
println!(" {} {}", "Max Concurrent:".dimmed(), max_concurrent);
println!(" {} {}", "Max Context:".dimmed(), max_context);
println!();
println!("{}", "Loading model...".yellow());
let mut backend = ruvllm::create_backend();
let config = ruvllm::ModelConfig {
architecture: detect_architecture(&model_id),
quantization: Some(map_quantization(quant)),
max_sequence_length: max_context,
..Default::default()
};
let model_path = PathBuf::from(cache_dir).join("models").join(&model_id);
let load_result = if model_path.exists() {
backend.load_model(model_path.to_str().unwrap(), config.clone())
} else {
backend.load_model(&model_id, config)
};
match load_result {
Ok(_) => {
if let Some(info) = backend.model_info() {
println!(
"{} Loaded {} ({:.1}B params, {} memory)",
style("Success!").green().bold(),
info.name,
info.num_parameters as f64 / 1e9,
bytesize::ByteSize(info.memory_usage as u64)
);
} else {
println!("{} Model loaded", style("Success!").green().bold());
}
}
Err(e) => {
println!(
"{} Model loading failed: {}. Running in mock mode.",
style("Warning:").yellow().bold(),
e
);
}
}
let state = Arc::new(RwLock::new(ServerState {
model_id: model_id.clone(),
backend: Some(backend),
request_count: 0,
total_tokens: 0,
start_time: Instant::now(),
}));
let app = Router::new()
.route("/v1/chat/completions", post(chat_completions))
.route("/v1/models", get(list_models))
.route("/health", get(health_check))
.route("/metrics", get(metrics))
.route("/", get(root))
.with_state(state)
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any),
)
.layer(TraceLayer::new_for_http());
let addr = format!("{}:{}", host, port)
.parse::<SocketAddr>()
.context("Invalid address")?;
println!();
println!("{}", style("Server ready!").bold().green());
println!();
println!(" {} http://{}/v1/chat/completions", "API:".cyan(), addr);
println!(" {} http://{}/health", "Health:".cyan(), addr);
println!(" {} http://{}/metrics", "Metrics:".cyan(), addr);
println!();
println!("{}", "Example curl:".dimmed());
println!(
r#" curl http://{}/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{{"model": "{}", "messages": [{{"role": "user", "content": "Hello!"}}]}}'"#,
addr, model_id
);
println!();
println!("Press Ctrl+C to stop the server.");
println!();
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal())
.await
.context("Server error")?;
println!();
println!("{}", "Server stopped.".dimmed());
Ok(())
}
/// Body of `POST /v1/chat/completions` (OpenAI chat-completions request subset).
/// Fields not listed here are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct ChatCompletionRequest {
/// Model name supplied by the client; echoed back in responses, not used to pick a model.
model: String,
/// Conversation history, flattened into a prompt by `build_prompt`.
messages: Vec<ChatMessage>,
/// Generation length limit; defaults to [`default_max_tokens`] when omitted.
#[serde(default = "default_max_tokens")]
max_tokens: usize,
/// Sampling temperature; defaults to [`default_temperature`] when omitted.
#[serde(default = "default_temperature")]
temperature: f32,
/// Nucleus-sampling cutoff; the handlers substitute 0.9 when this is `None`.
#[serde(default)]
top_p: Option<f32>,
/// `true` selects the SSE streaming handler.
#[serde(default)]
stream: bool,
/// Stop sequences forwarded to the backend (empty list when omitted).
#[serde(default)]
stop: Option<Vec<String>>,
}
/// Serde default for `ChatCompletionRequest::max_tokens`.
fn default_max_tokens() -> usize {
512
}
/// Serde default for `ChatCompletionRequest::temperature`.
fn default_temperature() -> f32 {
0.7
}
/// One chat turn, used both in requests and in non-streaming responses.
#[derive(Debug, Serialize, Deserialize)]
struct ChatMessage {
/// `"system"`, `"user"`, `"assistant"`, or any other role (rendered as plain text).
role: String,
content: String,
}
/// Non-streaming response (`object: "chat.completion"`).
/// Field order here is the JSON field order on the wire.
#[derive(Debug, Serialize)]
struct ChatCompletionResponse {
/// `chatcmpl-<uuid v4>`.
id: String,
object: String,
/// Unix timestamp in seconds.
created: u64,
model: String,
choices: Vec<ChatChoice>,
usage: Usage,
}
/// A single completion choice in a [`ChatCompletionResponse`].
#[derive(Debug, Serialize)]
struct ChatChoice {
index: usize,
message: ChatMessage,
finish_reason: String,
}
/// Token accounting for a non-streaming response. The handler fills these with
/// whitespace-split word counts, not tokenizer counts.
#[derive(Debug, Serialize)]
struct Usage {
prompt_tokens: usize,
completion_tokens: usize,
total_tokens: usize,
}
/// One SSE payload of a streaming response (`object: "chat.completion.chunk"`).
#[derive(Debug, Serialize)]
struct ChatCompletionChunk {
id: String,
object: String,
created: u64,
model: String,
choices: Vec<ChunkChoice>,
}
/// A choice inside a [`ChatCompletionChunk`]; `finish_reason` is set only on the final chunk.
#[derive(Debug, Serialize)]
struct ChunkChoice {
index: usize,
delta: Delta,
finish_reason: Option<String>,
}
/// Incremental message content; `None` fields are omitted from the JSON entirely.
#[derive(Debug, Serialize)]
struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
}
/// `POST /v1/chat/completions` entry point.
///
/// Routes to the SSE handler when the client asked for `"stream": true`, and to
/// the single-response handler otherwise.
async fn chat_completions(
    State(state): State<SharedState>,
    Json(request): Json<ChatCompletionRequest>,
) -> axum::response::Response {
    if !request.stream {
        return chat_completions_non_stream(state, request)
            .await
            .into_response();
    }
    chat_completions_stream(state, request).await.into_response()
}
async fn chat_completions_non_stream(
state: SharedState,
request: ChatCompletionRequest,
) -> impl IntoResponse {
let start = Instant::now();
let prompt = build_prompt(&request.messages);
let mut state_lock = state.write().await;
state_lock.request_count += 1;
let response_text = if let Some(backend) = &state_lock.backend {
if backend.is_model_loaded() {
let params = ruvllm::GenerateParams {
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p.unwrap_or(0.9),
stop_sequences: request.stop.unwrap_or_default(),
..Default::default()
};
match backend.generate(&prompt, params) {
Ok(text) => text,
Err(e) => format!("Generation error: {}", e),
}
} else {
mock_response(&prompt)
}
} else {
mock_response(&prompt)
};
let prompt_tokens = prompt.split_whitespace().count();
let completion_tokens = response_text.split_whitespace().count();
state_lock.total_tokens += (prompt_tokens + completion_tokens) as u64;
drop(state_lock);
let response = ChatCompletionResponse {
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
object: "chat.completion".to_string(),
created: chrono::Utc::now().timestamp() as u64,
model: request.model,
choices: vec![ChatChoice {
index: 0,
message: ChatMessage {
role: "assistant".to_string(),
content: response_text,
},
finish_reason: "stop".to_string(),
}],
usage: Usage {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
},
};
tracing::info!(
"Chat completion: {} tokens in {:.2}ms",
response.usage.total_tokens,
start.elapsed().as_secs_f64() * 1000.0
);
Json(response)
}
/// Handles a streaming (`"stream": true`) chat completion as Server-Sent Events.
///
/// Events are emitted in this order:
/// 1. an initial chunk carrying `role: "assistant"`;
/// 2. one chunk per backend token;
/// 3. a final chunk with `finish_reason: "stop"`;
/// 4. the literal `[DONE]` sentinel.
///
/// When no model is loaded, or the backend fails to create a stream, a
/// word-by-word mock response from `mock_stream_response` is sent instead.
///
/// NOTE(review): if the backend reports an error mid-stream, the loop stops
/// without a `finish_reason` chunk (only `[DONE]` follows). The mock paths never
/// add to `total_tokens`.
/// NOTE(review): `token_stream` is consumed with a plain `for` loop inside the
/// async stream. If its iterator blocks while waiting for tokens, it occupies the
/// async worker thread; confirm against the backend implementation.
async fn chat_completions_stream(
state: SharedState,
request: ChatCompletionRequest,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
// Everything the stream needs is captured up front; the stream body runs lazily
// when the SSE response is polled.
let completion_id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
let created = chrono::Utc::now().timestamp() as u64;
let model = request.model.clone();
let prompt = build_prompt(&request.messages);
let state_clone = state.clone();
let params = ruvllm::GenerateParams {
max_tokens: request.max_tokens,
temperature: request.temperature,
top_p: request.top_p.unwrap_or(0.9),
stop_sequences: request.stop.unwrap_or_default(),
..Default::default()
};
let stream = async_stream::stream! {
// Count the request; the inner scope releases the write lock immediately.
{
let mut state_lock = state_clone.write().await;
state_lock.request_count += 1;
}
// First chunk announces the assistant role with no content.
let initial_chunk = ChatCompletionChunk {
id: completion_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: Some("assistant".to_string()),
content: None,
},
finish_reason: None,
}],
};
yield Ok(Event::default().data(serde_json::to_string(&initial_chunk).unwrap_or_default()));
// Read lock only while inspecting the backend / starting generation; every branch
// below drops it before yielding further events.
let state_lock = state_clone.read().await;
let backend_opt = state_lock.backend.as_ref();
if let Some(backend) = backend_opt {
if backend.is_model_loaded() {
match backend.generate_stream_v2(&prompt, params.clone()) {
Ok(token_stream) => {
// Release the lock before forwarding tokens so other requests are not blocked.
drop(state_lock);
for event_result in token_stream {
match event_result {
Ok(ruvllm::StreamEvent::Token(token)) => {
// One content delta per backend token.
let chunk = ChatCompletionChunk {
id: completion_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: None,
content: Some(token.text),
},
finish_reason: None,
}],
};
yield Ok(Event::default().data(serde_json::to_string(&chunk).unwrap_or_default()));
}
Ok(ruvllm::StreamEvent::Done { total_tokens, .. }) => {
// Record the backend-reported token total, then send the closing chunk.
let mut state_lock = state_clone.write().await;
state_lock.total_tokens += total_tokens as u64;
drop(state_lock);
let final_chunk = ChatCompletionChunk {
id: completion_id.clone(),
object: "chat.completion.chunk".to_string(),
created,
model: model.clone(),
choices: vec![ChunkChoice {
index: 0,
delta: Delta {
role: None,
content: None,
},
finish_reason: Some("stop".to_string()),
}],
};
yield Ok(Event::default().data(serde_json::to_string(&final_chunk).unwrap_or_default()));
break;
}
// Errors are only logged; the client receives `[DONE]` without a finish chunk.
Ok(ruvllm::StreamEvent::Error(msg)) => {
tracing::error!("Stream error: {}", msg);
break;
}
Err(e) => {
tracing::error!("Stream error: {}", e);
break;
}
}
}
}
Err(e) => {
// Backend could not start a stream: fall back to the mock response.
drop(state_lock);
tracing::error!("Failed to create stream: {}", e);
for chunk_data in mock_stream_response(&prompt, &completion_id, created, &model) {
yield Ok(Event::default().data(chunk_data));
}
}
}
} else {
// Backend present but no model loaded: mock mode.
drop(state_lock);
for chunk_data in mock_stream_response(&prompt, &completion_id, created, &model) {
yield Ok(Event::default().data(chunk_data));
}
}
} else {
// No backend at all: mock mode.
drop(state_lock);
for chunk_data in mock_stream_response(&prompt, &completion_id, created, &model) {
yield Ok(Event::default().data(chunk_data));
}
}
// OpenAI-style end-of-stream sentinel, sent on every path.
yield Ok(Event::default().data("[DONE]"));
};
Sse::new(stream).keep_alive(KeepAlive::default())
}
/// Builds the serialized SSE payloads for a mock streaming reply.
///
/// The text of `mock_response(prompt)` is split on whitespace. Each word becomes
/// one content chunk; every word after the first gets a single leading space.
/// A final empty-delta chunk with `finish_reason: "stop"` is appended. The
/// `[DONE]` sentinel is not included.
fn mock_stream_response(prompt: &str, id: &str, created: u64, model: &str) -> Vec<String> {
    // Serializes one chunk with the given delta content and finish reason.
    let encode = |content: Option<String>, finish_reason: Option<String>| -> String {
        let chunk = ChatCompletionChunk {
            id: id.to_string(),
            object: "chat.completion.chunk".to_string(),
            created,
            model: model.to_string(),
            choices: vec![ChunkChoice {
                index: 0,
                delta: Delta {
                    role: None,
                    content,
                },
                finish_reason,
            }],
        };
        serde_json::to_string(&chunk).unwrap_or_default()
    };

    let reply = mock_response(prompt);
    let mut payloads: Vec<String> = reply
        .split_whitespace()
        .enumerate()
        .map(|(position, word)| {
            let piece = if position > 0 {
                format!(" {}", word)
            } else {
                word.to_string()
            };
            encode(Some(piece), None)
        })
        .collect();
    payloads.push(encode(None, Some("stop".to_string())));
    payloads
}
/// Flattens chat messages into a single prompt in a `<|role|>` chat template.
///
/// Messages are rendered in order:
/// * `system`, `user` and `assistant` become `<|role|>\n{content}\n`;
/// * any other role becomes a plain `{role}: {content}\n` line.
///
/// The prompt always ends with an open `<|assistant|>\n` header, so the model
/// continues as the assistant. An empty slice yields just that header.
fn build_prompt(messages: &[ChatMessage]) -> String {
    use std::fmt::Write as _;

    let mut prompt = String::new();
    for msg in messages {
        // `writeln!` formats straight into `prompt`, avoiding a temporary `String`
        // per message as `push_str(&format!(..))` would create.
        let written = match msg.role.as_str() {
            tag @ ("system" | "user" | "assistant") => {
                writeln!(prompt, "<|{}|>\n{}", tag, msg.content)
            }
            other => writeln!(prompt, "{}: {}", other, msg.content),
        };
        written.expect("formatting into a String cannot fail");
    }
    prompt.push_str("<|assistant|>\n");
    prompt
}
/// Canned reply used when no model is loaded.
///
/// Checks are case-insensitive:
/// * the prompt contains "hello", or has "hi" as a standalone word: a greeting;
/// * otherwise, it contains "code" or "function": a sample Rust snippet;
/// * otherwise: a generic mock-mode notice.
fn mock_response(prompt: &str) -> String {
    let prompt_lower = prompt.to_lowercase();
    // "hi" must be a whole word: a substring test also fires on "this", "which",
    // "think", ..., which used to shadow the code/function branch below.
    let says_hi = prompt_lower
        .split(|c: char| !c.is_alphanumeric())
        .any(|word| word == "hi");
    if prompt_lower.contains("hello") || says_hi {
        "Hello! I'm RuvLLM, a local AI assistant running on your Mac. How can I help you today?"
            .to_string()
    } else if prompt_lower.contains("code") || prompt_lower.contains("function") {
        "Here's an example function:\n\n```rust\nfn hello() {\n println!(\"Hello, world!\");\n}\n```\n\nWould you like me to explain this code?".to_string()
    } else {
        "I understand your request. To provide real responses, please ensure the model is properly loaded. Currently running in mock mode for development.".to_string()
    }
}
/// `GET /v1/models`: OpenAI-style model list containing only the served model.
async fn list_models(State(state): State<SharedState>) -> impl IntoResponse {
    // Only the id is needed, so the read guard is released at the end of this statement.
    let model_id = state.read().await.model_id.clone();
    Json(serde_json::json!({
        "object": "list",
        "data": [{
            "id": model_id,
            "object": "model",
            "owned_by": "ruvllm",
            "permission": []
        }]
    }))
}
/// `GET /health`: reports `"healthy"` when a backend with a loaded model is present,
/// `"degraded"` (mock mode) otherwise, plus the model id and uptime in seconds.
async fn health_check(State(state): State<SharedState>) -> impl IntoResponse {
    let guard = state.read().await;
    let model_loaded = matches!(guard.backend.as_ref(), Some(b) if b.is_model_loaded());
    let status = if model_loaded { "healthy" } else { "degraded" };
    Json(serde_json::json!({
        "status": status,
        "model": guard.model_id,
        "uptime_seconds": guard.start_time.elapsed().as_secs()
    }))
}
/// `GET /metrics`: cumulative counters and average per-second rates since startup.
///
/// `uptime_seconds` is whole seconds. The rates divide by the fractional uptime,
/// so they are neither inflated by truncation (e.g. 1.99 s treated as 1 s) nor
/// stuck at zero during the first second. A rate is 0.0 only if no time has
/// elapsed.
async fn metrics(State(state): State<SharedState>) -> impl IntoResponse {
    let state_lock = state.read().await;
    let uptime = state_lock.start_time.elapsed();
    let uptime_secs = uptime.as_secs_f64();
    let rate = |count: u64| -> f64 {
        if uptime_secs > 0.0 {
            count as f64 / uptime_secs
        } else {
            0.0
        }
    };
    let metrics = serde_json::json!({
        "model": state_lock.model_id,
        "requests_total": state_lock.request_count,
        "tokens_total": state_lock.total_tokens,
        "uptime_seconds": uptime.as_secs(),
        "requests_per_second": rate(state_lock.request_count),
        "tokens_per_second": rate(state_lock.total_tokens)
    });
    Json(metrics)
}
/// `GET /`: server name, crate version, and the list of available endpoints.
async fn root() -> impl IntoResponse {
    let endpoints = serde_json::json!({
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "health": "/health",
        "metrics": "/metrics"
    });
    Json(serde_json::json!({
        "name": "RuvLLM Inference Server",
        "version": env!("CARGO_PKG_VERSION"),
        "endpoints": endpoints
    }))
}
/// Resolves on the first of Ctrl+C or (Unix only) SIGTERM, then prints a notice.
/// Used as axum's graceful-shutdown trigger.
///
/// # Panics
/// Panics if a signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        tokio::signal::ctrl_c()
            .await
            .expect("Failed to install Ctrl+C handler");
    };

    // Non-Unix targets have no SIGTERM, so that arm simply never completes.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();
    #[cfg(unix)]
    let sigterm = async {
        let mut listener =
            tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
                .expect("Failed to install signal handler");
        listener.recv().await;
    };

    tokio::select! {
        () = interrupt => {},
        () = sigterm => {},
    }

    println!();
    println!("{}", "Shutting down...".yellow());
}
/// Guesses the model architecture from family names in `model_id` (case-insensitive).
///
/// Checks run in order: "mistral", "llama", "qwen", "gemma" as substrings, then
/// Phi. Anything unrecognised falls back to Llama.
///
/// Phi is matched only when a name component (split on non-alphanumerics) starts
/// with "phi", and it is tested last. A plain substring test misfires on names
/// such as "dolphin-…-qwen2", which contain "phi" inside "dolphin".
fn detect_architecture(model_id: &str) -> ruvllm::ModelArchitecture {
    let lower = model_id.to_lowercase();
    let is_phi = lower
        .split(|c: char| !c.is_ascii_alphanumeric())
        .any(|part| part.starts_with("phi"));
    if lower.contains("mistral") {
        ruvllm::ModelArchitecture::Mistral
    } else if lower.contains("llama") {
        ruvllm::ModelArchitecture::Llama
    } else if lower.contains("qwen") {
        ruvllm::ModelArchitecture::Qwen
    } else if lower.contains("gemma") {
        ruvllm::ModelArchitecture::Gemma
    } else if is_phi {
        ruvllm::ModelArchitecture::Phi
    } else {
        // Default for unknown families.
        ruvllm::ModelArchitecture::Llama
    }
}
fn map_quantization(quant: QuantPreset) -> ruvllm::Quantization {
match quant {
QuantPreset::Q4K => ruvllm::Quantization::Q4K,
QuantPreset::Q8 => ruvllm::Quantization::Q8,
QuantPreset::F16 => ruvllm::Quantization::F16,
QuantPreset::None => ruvllm::Quantization::None,
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ChatMessage` in tests.
    fn msg(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    #[test]
    fn test_build_prompt() {
        let messages = vec![msg("system", "You are helpful."), msg("user", "Hello!")];
        let prompt = build_prompt(&messages);
        assert!(prompt.contains("You are helpful"));
        assert!(prompt.contains("Hello"));
        assert!(prompt.ends_with("<|assistant|>\n"));
    }

    #[test]
    fn test_build_prompt_exact_layout() {
        // Known roles get template headers; unknown roles are rendered as "role: content".
        let messages = vec![
            msg("user", "Hi"),
            msg("assistant", "Hello"),
            msg("tool", "42"),
        ];
        assert_eq!(
            build_prompt(&messages),
            "<|user|>\nHi\n<|assistant|>\nHello\ntool: 42\n<|assistant|>\n"
        );
        assert_eq!(build_prompt(&[]), "<|assistant|>\n");
    }

    #[test]
    fn test_detect_architecture() {
        assert_eq!(
            detect_architecture("mistralai/Mistral-7B"),
            ruvllm::ModelArchitecture::Mistral
        );
        assert_eq!(
            detect_architecture("Qwen/Qwen2.5-14B"),
            ruvllm::ModelArchitecture::Qwen
        );
        assert_eq!(
            detect_architecture("meta-llama/Llama-3.1-8B"),
            ruvllm::ModelArchitecture::Llama
        );
        assert_eq!(
            detect_architecture("google/gemma-2b"),
            ruvllm::ModelArchitecture::Gemma
        );
        assert_eq!(
            detect_architecture("microsoft/Phi-3-mini-4k-instruct"),
            ruvllm::ModelArchitecture::Phi
        );
        // Unknown families fall back to Llama.
        assert_eq!(
            detect_architecture("some-unknown-model"),
            ruvllm::ModelArchitecture::Llama
        );
    }

    #[test]
    fn test_mock_stream_response_framing() {
        let chunks = mock_stream_response("hello", "id-1", 7, "m");
        assert!(chunks.len() >= 2);

        let parsed: Vec<serde_json::Value> = chunks
            .iter()
            .map(|c| serde_json::from_str(c).expect("chunk is valid JSON"))
            .collect();
        let (last, body) = parsed.split_last().expect("at least one chunk");

        // Final chunk closes the choice without content.
        assert_eq!(last["choices"][0]["finish_reason"], "stop");
        assert!(last["choices"][0]["delta"].get("content").is_none());

        // Concatenated deltas reproduce the mock reply, which uses single spaces.
        let text: String = body
            .iter()
            .map(|v| v["choices"][0]["delta"]["content"].as_str().unwrap().to_string())
            .collect();
        assert_eq!(text, mock_response("hello"));
        assert_eq!(last["id"], "id-1");
        assert_eq!(last["created"], 7);
    }
}