Skip to main content

ferrum_cli/commands/
serve.rs

1//! Serve command - Start the HTTP inference server
2
3use crate::config::CliConfig;
4use clap::Args;
5use colored::*;
6use ferrum_interfaces::InferenceEngine;
7use ferrum_models::source::ModelFormat;
8use ferrum_server::{AxumServer, HttpServer, ServerConfig};
9use ferrum_types::Result;
10use std::path::PathBuf;
11use std::sync::Arc;
12use tokio::signal;
13
14#[derive(Args)]
15pub struct ServeCommand {
16    /// Model to serve (default: from config)
17    #[arg(value_name = "MODEL")]
18    pub model: Option<String>,
19
20    /// Model to serve (default: from config)
21    #[arg(
22        short = 'm',
23        long = "model",
24        value_name = "MODEL",
25        conflicts_with = "model"
26    )]
27    pub model_option: Option<String>,
28
29    /// Host to bind to
30    #[arg(long)]
31    pub host: Option<String>,
32
33    /// Port to listen on
34    #[arg(short, long)]
35    pub port: Option<u16>,
36}
37
38pub async fn execute(cmd: ServeCommand, config: CliConfig) -> Result<()> {
39    let ServeCommand {
40        model,
41        model_option,
42        host,
43        port,
44    } = cmd;
45
46    // Print banner
47    print_banner();
48
49    // Resolve model
50    let model_name = model
51        .or(model_option)
52        .or(config.models.default_model.clone())
53        .unwrap_or_else(|| "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string());
54
55    let model_id = resolve_model_alias(&model_name);
56    println!("{} {}", "Model:".dimmed(), model_id.cyan());
57
58    let host = host.unwrap_or_else(|| config.server.host.clone());
59    let port = port.unwrap_or(config.server.port);
60
61    // Find cached model
62    let cache_dir = get_hf_cache_dir(&config);
63    let source = match find_cached_model(&cache_dir, &model_id) {
64        Some(source) => {
65            println!("{} {}", "Path:".dimmed(), source.local_path.display());
66            source
67        }
68        None => {
69            eprintln!(
70                "{} Model '{}' not found. Run: ferrum pull {}",
71                "Error:".red().bold(),
72                model_id,
73                model_name
74            );
75            return Err(ferrum_types::FerrumError::model("Model not found"));
76        }
77    };
78
79    // Set model path for engine
80    std::env::set_var(
81        "FERRUM_MODEL_PATH",
82        source.local_path.to_string_lossy().to_string(),
83    );
84
85    // Select device
86    let device = select_device();
87    println!("{} {:?}", "Device:".dimmed(), device);
88
89    // Create engine with continuous batching for serve mode
90    println!();
91    println!(
92        "{}",
93        "Initializing engine (continuous batching)...".dimmed()
94    );
95    let mut engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
96    engine_config.scheduler.policy = ferrum_types::SchedulingPolicy::ContinuousBatch;
97    engine_config.kv_cache.cache_type = ferrum_types::KvCacheType::Paged;
98    let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
99    // Convert Box<dyn InferenceEngine> to Arc<dyn InferenceEngine>
100    let engine: Arc<dyn InferenceEngine + Send + Sync> = Arc::from(engine);
101
102    // Create server config
103    let server_config = ServerConfig {
104        host: host.clone(),
105        port,
106        ..Default::default()
107    };
108
109    // Create server with engine
110    let server = AxumServer::new(engine);
111
112    println!();
113    println!(
114        "{} {} {}",
115        "🚀".green(),
116        "Server running at".green().bold(),
117        format!("http://{}:{}", host, port).cyan().bold()
118    );
119    println!();
120    println!("Endpoints:");
121    println!("  POST /v1/chat/completions  - OpenAI-compatible chat");
122    println!("  GET  /v1/models            - List models");
123    println!("  GET  /health               - Health check");
124    println!();
125    println!("{}", "Press Ctrl+C to stop.".dimmed());
126    println!();
127
128    // Write PID file for stop command
129    let pid_file = std::env::temp_dir().join("ferrum.pid");
130    std::fs::write(&pid_file, std::process::id().to_string()).ok();
131
132    // Start server with graceful shutdown
133    tokio::select! {
134        result = server.start(&server_config) => {
135            if let Err(e) = result {
136                eprintln!("{} Server error: {}", "Error:".red().bold(), e);
137            }
138        }
139        _ = signal::ctrl_c() => {
140            println!();
141            println!("{}", "Shutting down...".yellow());
142        }
143    }
144
145    // Clean up PID file
146    std::fs::remove_file(&pid_file).ok();
147
148    Ok(())
149}
150
151fn print_banner() {
152    println!();
153    println!("{}", "  ______                            ".bright_red());
154    println!("{}", " |  ____|                           ".bright_red());
155    println!("{}", " | |__ ___ _ __ _ __ _   _ _ __ ___  ".bright_red());
156    println!("{}", " |  __/ _ \\ '__| '__| | | | '_ ` _ \\ ".bright_red());
157    println!("{}", " | | |  __/ |  | |  | |_| | | | | | ".bright_red());
158    println!("{}", " |_|  \\___|_|  |_|   \\__,_|_| |_| |_|".bright_red());
159    println!();
160    println!("   {}", "🦀 Rust LLM Inference Server".bright_cyan().bold());
161    println!(
162        "   {}",
163        format!("Version {}", env!("CARGO_PKG_VERSION")).dimmed()
164    );
165    println!();
166}
167
/// Expand a short model alias (e.g. `"tiny"`, `"qwen2.5:7b"`) into its full
/// Hugging Face repository id.
///
/// Matching is case-insensitive. Names that are not a known alias are returned
/// unchanged, with their original casing.
fn resolve_model_alias(name: &str) -> String {
    // (accepted aliases, repository id)
    const ALIASES: &[(&[&str], &str)] = &[
        (&["tinyllama", "tiny"], "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
        (&["qwen2.5:0.5b", "qwen:0.5b"], "Qwen/Qwen2.5-0.5B-Instruct"),
        (&["qwen2.5:1.5b", "qwen:1.5b"], "Qwen/Qwen2.5-1.5B-Instruct"),
        (&["qwen2.5:3b", "qwen:3b"], "Qwen/Qwen2.5-3B-Instruct"),
        (&["qwen2.5:7b", "qwen:7b"], "Qwen/Qwen2.5-7B-Instruct"),
        (&["qwen3:0.6b"], "Qwen/Qwen3-0.6B"),
        (&["qwen3:1.7b"], "Qwen/Qwen3-1.7B"),
        (&["qwen3:4b"], "Qwen/Qwen3-4B"),
        (&["llama3.2:1b"], "meta-llama/Llama-3.2-1B-Instruct"),
        (&["llama3.2:3b"], "meta-llama/Llama-3.2-3B-Instruct"),
    ];

    let needle = name.to_lowercase();
    ALIASES
        .iter()
        .find(|&&(aliases, _)| aliases.contains(&needle.as_str()))
        .map(|&(_, repo_id)| repo_id.to_string())
        .unwrap_or_else(|| name.to_string())
}
183
184fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
185    if let Ok(hf_home) = std::env::var("HF_HOME") {
186        return PathBuf::from(hf_home);
187    }
188    let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
189    PathBuf::from(configured)
190}
191
192fn find_cached_model(
193    cache_dir: &PathBuf,
194    model_id: &str,
195) -> Option<ferrum_models::source::ResolvedModelSource> {
196    let repo_dir = cache_dir
197        .join("hub")
198        .join(format!("models--{}", model_id.replace('/', "--")));
199    let snapshots_dir = repo_dir.join("snapshots");
200
201    // Try refs/main first
202    let ref_main = repo_dir.join("refs").join("main");
203    if let Ok(rev) = std::fs::read_to_string(&ref_main) {
204        let rev = rev.trim();
205        if !rev.is_empty() {
206            let snapshot = snapshots_dir.join(rev);
207            if snapshot.exists() {
208                let format = detect_format(&snapshot);
209                if format != ModelFormat::Unknown {
210                    return Some(ferrum_models::source::ResolvedModelSource {
211                        original: model_id.to_string(),
212                        local_path: snapshot,
213                        format,
214                        from_cache: true,
215                    });
216                }
217            }
218        }
219    }
220
221    // Fallback: first snapshot directory
222    if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
223        for entry in entries.flatten() {
224            let path = entry.path();
225            if path.is_dir() {
226                let format = detect_format(&path);
227                if format != ModelFormat::Unknown {
228                    return Some(ferrum_models::source::ResolvedModelSource {
229                        original: model_id.to_string(),
230                        local_path: path,
231                        format,
232                        from_cache: true,
233                    });
234                }
235            }
236        }
237    }
238
239    None
240}
241
242fn detect_format(path: &PathBuf) -> ModelFormat {
243    if path.join("model.safetensors").exists() || path.join("model.safetensors.index.json").exists()
244    {
245        ModelFormat::SafeTensors
246    } else if path.join("pytorch_model.bin").exists() {
247        ModelFormat::PyTorchBin
248    } else {
249        ModelFormat::Unknown
250    }
251}
252
253fn select_device() -> ferrum_types::Device {
254    #[cfg(all(target_os = "macos", feature = "metal"))]
255    {
256        return ferrum_types::Device::Metal;
257    }
258
259    #[cfg(feature = "cuda")]
260    {
261        return ferrum_types::Device::CUDA(0);
262    }
263
264    #[allow(unreachable_code)]
265    ferrum_types::Device::CPU
266}