// ferrum_cli/commands/serve.rs
use crate::config::CliConfig;
use clap::Args;
use colored::*;
use ferrum_interfaces::InferenceEngine;
use ferrum_models::source::ModelFormat;
use ferrum_server::{AxumServer, HttpServer, ServerConfig};
use ferrum_types::Result;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::signal;
13
// CLI arguments for `ferrum serve`.
//
// NOTE: plain `//` comments are used deliberately — `///` doc comments on
// clap derive fields become `--help` text and would change CLI output.
#[derive(Args)]
pub struct ServeCommand {
    // Positional model name or alias (e.g. "tinyllama", "qwen2.5:0.5b").
    #[arg(value_name = "MODEL")]
    pub model: Option<String>,

    // Same value supplied as `-m/--model`; clap rejects providing both the
    // positional and the flag form at once (`conflicts_with = "model"`).
    #[arg(
        short = 'm',
        long = "model",
        value_name = "MODEL",
        conflicts_with = "model"
    )]
    pub model_option: Option<String>,

    // Bind address override; falls back to the config file's server host.
    #[arg(long)]
    pub host: Option<String>,

    // Bind port override (`-p/--port`); falls back to the config's port.
    #[arg(short, long)]
    pub port: Option<u16>,
}
37
38pub async fn execute(cmd: ServeCommand, config: CliConfig) -> Result<()> {
39 let ServeCommand {
40 model,
41 model_option,
42 host,
43 port,
44 } = cmd;
45
46 print_banner();
48
49 let model_name = model
51 .or(model_option)
52 .or(config.models.default_model.clone())
53 .unwrap_or_else(|| "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string());
54
55 let model_id = resolve_model_alias(&model_name);
56 println!("{} {}", "Model:".dimmed(), model_id.cyan());
57
58 let host = host.unwrap_or_else(|| config.server.host.clone());
59 let port = port.unwrap_or(config.server.port);
60
61 let cache_dir = get_hf_cache_dir(&config);
63 let source = match find_cached_model(&cache_dir, &model_id) {
64 Some(source) => {
65 println!("{} {}", "Path:".dimmed(), source.local_path.display());
66 source
67 }
68 None => {
69 eprintln!(
70 "{} Model '{}' not found. Run: ferrum pull {}",
71 "Error:".red().bold(),
72 model_id,
73 model_name
74 );
75 return Err(ferrum_types::FerrumError::model("Model not found"));
76 }
77 };
78
79 std::env::set_var(
81 "FERRUM_MODEL_PATH",
82 source.local_path.to_string_lossy().to_string(),
83 );
84
85 let device = select_device();
87 println!("{} {:?}", "Device:".dimmed(), device);
88
89 println!();
91 println!(
92 "{}",
93 "Initializing engine (continuous batching)...".dimmed()
94 );
95 let mut engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
96 engine_config.scheduler.policy = ferrum_types::SchedulingPolicy::ContinuousBatch;
97 engine_config.kv_cache.cache_type = ferrum_types::KvCacheType::Paged;
98 let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
99 let engine: Arc<dyn InferenceEngine + Send + Sync> = Arc::from(engine);
101
102 let server_config = ServerConfig {
104 host: host.clone(),
105 port,
106 ..Default::default()
107 };
108
109 let server = AxumServer::new(engine);
111
112 println!();
113 println!(
114 "{} {} {}",
115 "🚀".green(),
116 "Server running at".green().bold(),
117 format!("http://{}:{}", host, port).cyan().bold()
118 );
119 println!();
120 println!("Endpoints:");
121 println!(" POST /v1/chat/completions - OpenAI-compatible chat");
122 println!(" GET /v1/models - List models");
123 println!(" GET /health - Health check");
124 println!();
125 println!("{}", "Press Ctrl+C to stop.".dimmed());
126 println!();
127
128 let pid_file = std::env::temp_dir().join("ferrum.pid");
130 std::fs::write(&pid_file, std::process::id().to_string()).ok();
131
132 tokio::select! {
134 result = server.start(&server_config) => {
135 if let Err(e) = result {
136 eprintln!("{} Server error: {}", "Error:".red().bold(), e);
137 }
138 }
139 _ = signal::ctrl_c() => {
140 println!();
141 println!("{}", "Shutting down...".yellow());
142 }
143 }
144
145 std::fs::remove_file(&pid_file).ok();
147
148 Ok(())
149}
150
151fn print_banner() {
152 println!();
153 println!("{}", " ______ ".bright_red());
154 println!("{}", " | ____| ".bright_red());
155 println!("{}", " | |__ ___ _ __ _ __ _ _ _ __ ___ ".bright_red());
156 println!("{}", " | __/ _ \\ '__| '__| | | | '_ ` _ \\ ".bright_red());
157 println!("{}", " | | | __/ | | | | |_| | | | | | ".bright_red());
158 println!("{}", " |_| \\___|_| |_| \\__,_|_| |_| |_|".bright_red());
159 println!();
160 println!(" {}", "🦀 Rust LLM Inference Server".bright_cyan().bold());
161 println!(
162 " {}",
163 format!("Version {}", env!("CARGO_PKG_VERSION")).dimmed()
164 );
165 println!();
166}
167
/// Expand a short, case-insensitive CLI alias (e.g. "tiny", "qwen:7b") into
/// a full Hugging Face repo id. Names with no known alias pass through
/// unchanged, so full repo ids keep working.
fn resolve_model_alias(name: &str) -> String {
    // Alias table: each entry maps the accepted lowercase spellings to the
    // canonical repository id.
    const ALIASES: &[(&[&str], &str)] = &[
        (&["tinyllama", "tiny"], "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
        (&["qwen2.5:0.5b", "qwen:0.5b"], "Qwen/Qwen2.5-0.5B-Instruct"),
        (&["qwen2.5:1.5b", "qwen:1.5b"], "Qwen/Qwen2.5-1.5B-Instruct"),
        (&["qwen2.5:3b", "qwen:3b"], "Qwen/Qwen2.5-3B-Instruct"),
        (&["qwen2.5:7b", "qwen:7b"], "Qwen/Qwen2.5-7B-Instruct"),
        (&["qwen3:0.6b"], "Qwen/Qwen3-0.6B"),
        (&["qwen3:1.7b"], "Qwen/Qwen3-1.7B"),
        (&["qwen3:4b"], "Qwen/Qwen3-4B"),
        (&["llama3.2:1b"], "meta-llama/Llama-3.2-1B-Instruct"),
        (&["llama3.2:3b"], "meta-llama/Llama-3.2-3B-Instruct"),
    ];

    let needle = name.to_lowercase();
    ALIASES
        .iter()
        .find(|(spellings, _)| spellings.contains(&needle.as_str()))
        .map_or_else(|| name.to_string(), |(_, repo)| (*repo).to_string())
}
183
184fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
185 if let Ok(hf_home) = std::env::var("HF_HOME") {
186 return PathBuf::from(hf_home);
187 }
188 let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
189 PathBuf::from(configured)
190}
191
192fn find_cached_model(
193 cache_dir: &PathBuf,
194 model_id: &str,
195) -> Option<ferrum_models::source::ResolvedModelSource> {
196 let repo_dir = cache_dir
197 .join("hub")
198 .join(format!("models--{}", model_id.replace('/', "--")));
199 let snapshots_dir = repo_dir.join("snapshots");
200
201 let ref_main = repo_dir.join("refs").join("main");
203 if let Ok(rev) = std::fs::read_to_string(&ref_main) {
204 let rev = rev.trim();
205 if !rev.is_empty() {
206 let snapshot = snapshots_dir.join(rev);
207 if snapshot.exists() {
208 let format = detect_format(&snapshot);
209 if format != ModelFormat::Unknown {
210 return Some(ferrum_models::source::ResolvedModelSource {
211 original: model_id.to_string(),
212 local_path: snapshot,
213 format,
214 from_cache: true,
215 });
216 }
217 }
218 }
219 }
220
221 if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
223 for entry in entries.flatten() {
224 let path = entry.path();
225 if path.is_dir() {
226 let format = detect_format(&path);
227 if format != ModelFormat::Unknown {
228 return Some(ferrum_models::source::ResolvedModelSource {
229 original: model_id.to_string(),
230 local_path: path,
231 format,
232 from_cache: true,
233 });
234 }
235 }
236 }
237 }
238
239 None
240}
241
242fn detect_format(path: &PathBuf) -> ModelFormat {
243 if path.join("model.safetensors").exists() || path.join("model.safetensors.index.json").exists()
244 {
245 ModelFormat::SafeTensors
246 } else if path.join("pytorch_model.bin").exists() {
247 ModelFormat::PyTorchBin
248 } else {
249 ModelFormat::Unknown
250 }
251}
252
253fn select_device() -> ferrum_types::Device {
254 #[cfg(all(target_os = "macos", feature = "metal"))]
255 {
256 return ferrum_types::Device::Metal;
257 }
258
259 #[cfg(feature = "cuda")]
260 {
261 return ferrum_types::Device::CUDA(0);
262 }
263
264 #[allow(unreachable_code)]
265 ferrum_types::Device::CPU
266}