Skip to main content

ferrum_cli/commands/
serve.rs

1//! Serve command - Start the HTTP inference server
2
3use crate::config::CliConfig;
4use clap::Args;
5use colored::*;
6use ferrum_interfaces::InferenceEngine;
7use ferrum_models::source::ModelFormat;
8use ferrum_server::{AxumServer, HttpServer, ServerConfig};
9use ferrum_types::Result;
10use std::path::PathBuf;
11use std::sync::Arc;
12use tokio::signal;
13
14#[derive(Args)]
15pub struct ServeCommand {
16    /// Model to serve (default: from config)
17    #[arg(value_name = "MODEL")]
18    pub model: Option<String>,
19
20    /// Model to serve (default: from config)
21    #[arg(
22        short = 'm',
23        long = "model",
24        value_name = "MODEL",
25        conflicts_with = "model"
26    )]
27    pub model_option: Option<String>,
28
29    /// Host to bind to
30    #[arg(long)]
31    pub host: Option<String>,
32
33    /// Port to listen on
34    #[arg(short, long)]
35    pub port: Option<u16>,
36}
37
38pub async fn execute(cmd: ServeCommand, config: CliConfig) -> Result<()> {
39    let ServeCommand {
40        model,
41        model_option,
42        host,
43        port,
44    } = cmd;
45
46    // Print banner
47    print_banner();
48
49    // Resolve model
50    let model_name = model
51        .or(model_option)
52        .or(config.models.default_model.clone())
53        .unwrap_or_else(|| "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string());
54
55    let model_id = resolve_model_alias(&model_name);
56    println!("{} {}", "Model:".dimmed(), model_id.cyan());
57
58    let host = host.unwrap_or_else(|| config.server.host.clone());
59    let port = port.unwrap_or(config.server.port);
60
61    // Find cached model
62    let cache_dir = get_hf_cache_dir(&config);
63    let source = match find_cached_model(&cache_dir, &model_id) {
64        Some(source) => {
65            println!("{} {}", "Path:".dimmed(), source.local_path.display());
66            source
67        }
68        None => {
69            eprintln!(
70                "{} Model '{}' not found. Run: ferrum pull {}",
71                "Error:".red().bold(),
72                model_id,
73                model_name
74            );
75            return Err(ferrum_types::FerrumError::model("Model not found"));
76        }
77    };
78
79    // Set model path for engine
80    std::env::set_var(
81        "FERRUM_MODEL_PATH",
82        source.local_path.to_string_lossy().to_string(),
83    );
84
85    // Select device
86    let device = select_device();
87    println!("{} {:?}", "Device:".dimmed(), device);
88
89    // Detect architecture to choose engine type
90    println!();
91    let mut config_manager = ferrum_models::ConfigManager::new();
92    let model_def = config_manager.load_from_path(&source.local_path).await?;
93
94    let engine: Arc<dyn InferenceEngine + Send + Sync> =
95        if model_def.architecture == ferrum_models::Architecture::Clip {
96            println!("{}", "Initializing CLIP embedding engine...".dimmed());
97            let candle_device = candle_core::Device::Cpu;
98            let executor = ferrum_models::ClipModelExecutor::from_path(
99                &source.local_path.to_string_lossy(),
100                candle_device,
101                candle_core::DType::F32,
102            )?;
103            // Load tokenizer for text embedding
104            let tokenizer = crate::commands::embed::load_tokenizer(&source.local_path)?;
105            let engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
106            Arc::new(
107                ferrum_engine::embedding_engine::EmbeddingEngine::new(executor, engine_config)
108                    .with_tokenizer(tokenizer),
109            )
110        } else {
111            println!(
112                "{}",
113                "Initializing engine (continuous batching)...".dimmed()
114            );
115            let mut engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
116            engine_config.scheduler.policy = ferrum_types::SchedulingPolicy::ContinuousBatch;
117            engine_config.kv_cache.cache_type = ferrum_types::KvCacheType::Paged;
118            let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
119            Arc::from(engine)
120        };
121
122    // Create server config
123    let server_config = ServerConfig {
124        host: host.clone(),
125        port,
126        ..Default::default()
127    };
128
129    // Create server with engine
130    let server = AxumServer::new(engine);
131
132    println!();
133    println!(
134        "{} {} {}",
135        "🚀".green(),
136        "Server running at".green().bold(),
137        format!("http://{}:{}", host, port).cyan().bold()
138    );
139    println!();
140    println!("Endpoints:");
141    println!("  POST /v1/chat/completions  - OpenAI-compatible chat");
142    println!("  GET  /v1/models            - List models");
143    println!("  GET  /health               - Health check");
144    println!();
145    println!("{}", "Press Ctrl+C to stop.".dimmed());
146    println!();
147
148    // Write PID file for stop command
149    let pid_file = std::env::temp_dir().join("ferrum.pid");
150    std::fs::write(&pid_file, std::process::id().to_string()).ok();
151
152    // Start server with graceful shutdown
153    tokio::select! {
154        result = server.start(&server_config) => {
155            if let Err(e) = result {
156                eprintln!("{} Server error: {}", "Error:".red().bold(), e);
157            }
158        }
159        _ = signal::ctrl_c() => {
160            println!();
161            println!("{}", "Shutting down...".yellow());
162        }
163    }
164
165    // Clean up PID file
166    std::fs::remove_file(&pid_file).ok();
167
168    Ok(())
169}
170
171fn print_banner() {
172    println!();
173    println!("{}", "  ______                            ".bright_red());
174    println!("{}", " |  ____|                           ".bright_red());
175    println!("{}", " | |__ ___ _ __ _ __ _   _ _ __ ___  ".bright_red());
176    println!("{}", " |  __/ _ \\ '__| '__| | | | '_ ` _ \\ ".bright_red());
177    println!("{}", " | | |  __/ |  | |  | |_| | | | | | ".bright_red());
178    println!("{}", " |_|  \\___|_|  |_|   \\__,_|_| |_| |_|".bright_red());
179    println!();
180    println!("   {}", "🦀 Rust LLM Inference Server".bright_cyan().bold());
181    println!(
182        "   {}",
183        format!("Version {}", env!("CARGO_PKG_VERSION")).dimmed()
184    );
185    println!();
186}
187
/// Expand a short, case-insensitive model alias (e.g. `qwen:7b`) into its
/// Hugging Face repository id.
///
/// Input that is not a known alias is returned unchanged, keeping its
/// original casing, so full repo ids pass straight through.
fn resolve_model_alias(name: &str) -> String {
    // (lower-case alias, Hugging Face repo id)
    const ALIASES: &[(&str, &str)] = &[
        ("tinyllama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
        ("tiny", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
        ("qwen2.5:0.5b", "Qwen/Qwen2.5-0.5B-Instruct"),
        ("qwen:0.5b", "Qwen/Qwen2.5-0.5B-Instruct"),
        ("qwen2.5:1.5b", "Qwen/Qwen2.5-1.5B-Instruct"),
        ("qwen:1.5b", "Qwen/Qwen2.5-1.5B-Instruct"),
        ("qwen2.5:3b", "Qwen/Qwen2.5-3B-Instruct"),
        ("qwen:3b", "Qwen/Qwen2.5-3B-Instruct"),
        ("qwen2.5:7b", "Qwen/Qwen2.5-7B-Instruct"),
        ("qwen:7b", "Qwen/Qwen2.5-7B-Instruct"),
        ("qwen3:0.6b", "Qwen/Qwen3-0.6B"),
        ("qwen3:1.7b", "Qwen/Qwen3-1.7B"),
        ("qwen3:4b", "Qwen/Qwen3-4B"),
        ("llama3.2:1b", "meta-llama/Llama-3.2-1B-Instruct"),
        ("llama3.2:3b", "meta-llama/Llama-3.2-3B-Instruct"),
    ];

    let key = name.to_lowercase();
    match ALIASES.iter().find(|&&(alias, _)| alias == key) {
        Some(&(_, repo_id)) => repo_id.to_string(),
        None => name.to_string(),
    }
}
203
204fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
205    if let Ok(hf_home) = std::env::var("HF_HOME") {
206        return PathBuf::from(hf_home);
207    }
208    let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
209    PathBuf::from(configured)
210}
211
212fn find_cached_model(
213    cache_dir: &PathBuf,
214    model_id: &str,
215) -> Option<ferrum_models::source::ResolvedModelSource> {
216    let repo_dir = cache_dir
217        .join("hub")
218        .join(format!("models--{}", model_id.replace('/', "--")));
219    let snapshots_dir = repo_dir.join("snapshots");
220
221    // Try refs/main first
222    let ref_main = repo_dir.join("refs").join("main");
223    if let Ok(rev) = std::fs::read_to_string(&ref_main) {
224        let rev = rev.trim();
225        if !rev.is_empty() {
226            let snapshot = snapshots_dir.join(rev);
227            if snapshot.exists() {
228                let format = detect_format(&snapshot);
229                if format != ModelFormat::Unknown {
230                    return Some(ferrum_models::source::ResolvedModelSource {
231                        original: model_id.to_string(),
232                        local_path: snapshot,
233                        format,
234                        from_cache: true,
235                    });
236                }
237            }
238        }
239    }
240
241    // Fallback: first snapshot directory
242    if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
243        for entry in entries.flatten() {
244            let path = entry.path();
245            if path.is_dir() {
246                let format = detect_format(&path);
247                if format != ModelFormat::Unknown {
248                    return Some(ferrum_models::source::ResolvedModelSource {
249                        original: model_id.to_string(),
250                        local_path: path,
251                        format,
252                        from_cache: true,
253                    });
254                }
255            }
256        }
257    }
258
259    None
260}
261
262fn detect_format(path: &PathBuf) -> ModelFormat {
263    if path.join("model.safetensors").exists() || path.join("model.safetensors.index.json").exists()
264    {
265        ModelFormat::SafeTensors
266    } else if path.join("pytorch_model.bin").exists() {
267        ModelFormat::PyTorchBin
268    } else {
269        ModelFormat::Unknown
270    }
271}
272
273fn select_device() -> ferrum_types::Device {
274    #[cfg(all(target_os = "macos", feature = "metal"))]
275    {
276        return ferrum_types::Device::Metal;
277    }
278
279    #[cfg(feature = "cuda")]
280    {
281        return ferrum_types::Device::CUDA(0);
282    }
283
284    #[allow(unreachable_code)]
285    ferrum_types::Device::CPU
286}