// ferrum_cli/commands/serve.rs
use std::path::{Path, PathBuf};
use std::sync::Arc;

use clap::Args;
use colored::*;
use tokio::signal;

use ferrum_interfaces::InferenceEngine;
use ferrum_models::source::ModelFormat;
use ferrum_server::{AxumServer, HttpServer, ServerConfig};
use ferrum_types::Result;

use crate::config::CliConfig;
13
// Arguments for `ferrum serve`: which model to load and where to bind.
//
// NOTE: plain `//` comments are used deliberately here — `///` doc comments
// on clap derive items become user-visible help text and would change the
// CLI's output.
#[derive(Args)]
pub struct ServeCommand {
    // Positional model name or alias (e.g. `ferrum serve tinyllama`).
    #[arg(value_name = "MODEL")]
    pub model: Option<String>,

    // Flag-style alternative to the positional form; `conflicts_with`
    // makes clap reject supplying both at once.
    #[arg(
        short = 'm',
        long = "model",
        value_name = "MODEL",
        conflicts_with = "model"
    )]
    pub model_option: Option<String>,

    // Bind address; falls back to config `server.host` when omitted.
    #[arg(long)]
    pub host: Option<String>,

    // Bind port; falls back to config `server.port` when omitted.
    #[arg(short, long)]
    pub port: Option<u16>,
}
37
/// Run the `serve` command: resolve a model from the local HF cache, build
/// an inference engine for it, and host an OpenAI-compatible HTTP server
/// until the server exits or Ctrl+C is received.
///
/// # Errors
/// Fails when the model is not present in the local cache, when the model
/// config cannot be loaded, or when engine construction fails.
pub async fn execute(cmd: ServeCommand, config: CliConfig) -> Result<()> {
    let ServeCommand {
        model,
        model_option,
        host,
        port,
    } = cmd;

    print_banner();

    // Model selection precedence: positional arg > `-m/--model` flag >
    // config default > hard-coded TinyLlama fallback.
    let model_name = model
        .or(model_option)
        .or(config.models.default_model.clone())
        .unwrap_or_else(|| "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string());

    let model_id = resolve_model_alias(&model_name);
    println!("{} {}", "Model:".dimmed(), model_id.cyan());

    // CLI flags override the config file for the bind address.
    let host = host.unwrap_or_else(|| config.server.host.clone());
    let port = port.unwrap_or(config.server.port);

    // Serving only works from an already-downloaded snapshot; we never
    // download here — the user is directed to `ferrum pull` instead.
    let cache_dir = get_hf_cache_dir(&config);
    let source = match find_cached_model(&cache_dir, &model_id) {
        Some(source) => {
            println!("{} {}", "Path:".dimmed(), source.local_path.display());
            source
        }
        None => {
            eprintln!(
                "{} Model '{}' not found. Run: ferrum pull {}",
                "Error:".red().bold(),
                model_id,
                model_name
            );
            return Err(ferrum_types::FerrumError::model("Model not found"));
        }
    };

    // NOTE(review): nothing in this file reads FERRUM_MODEL_PATH back —
    // presumably a downstream loader consumes it from the environment;
    // confirm against the server/engine crates.
    std::env::set_var(
        "FERRUM_MODEL_PATH",
        source.local_path.to_string_lossy().to_string(),
    );

    let device = select_device();
    println!("{} {:?}", "Device:".dimmed(), device);

    println!();
    let mut config_manager = ferrum_models::ConfigManager::new();
    let model_def = config_manager.load_from_path(&source.local_path).await?;

    // CLIP models get a dedicated embedding engine; everything else runs
    // through the continuous-batching text-generation engine.
    let engine: Arc<dyn InferenceEngine + Send + Sync> =
        if model_def.architecture == ferrum_models::Architecture::Clip {
            println!("{}", "Initializing CLIP embedding engine...".dimmed());
            // NOTE(review): the CLIP path pins candle to CPU/F32 regardless
            // of the `device` selected above — confirm this is intentional.
            let candle_device = candle_core::Device::Cpu;
            let executor = ferrum_models::ClipModelExecutor::from_path(
                &source.local_path.to_string_lossy(),
                candle_device,
                candle_core::DType::F32,
            )?;
            let tokenizer = crate::commands::embed::load_tokenizer(&source.local_path)?;
            let engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
            Arc::new(
                ferrum_engine::embedding_engine::EmbeddingEngine::new(executor, engine_config)
                    .with_tokenizer(tokenizer),
            )
        } else {
            println!(
                "{}",
                "Initializing engine (continuous batching)...".dimmed()
            );
            let mut engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
            engine_config.scheduler.policy = ferrum_types::SchedulingPolicy::ContinuousBatch;
            engine_config.kv_cache.cache_type = ferrum_types::KvCacheType::Paged;
            let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
            Arc::from(engine)
        };

    let server_config = ServerConfig {
        host: host.clone(),
        port,
        ..Default::default()
    };

    let server = AxumServer::new(engine);

    println!();
    println!(
        "{} {} {}",
        "🚀".green(),
        "Server running at".green().bold(),
        format!("http://{}:{}", host, port).cyan().bold()
    );
    println!();
    println!("Endpoints:");
    println!(" POST /v1/chat/completions - OpenAI-compatible chat");
    println!(" GET /v1/models - List models");
    println!(" GET /health - Health check");
    println!();
    println!("{}", "Press Ctrl+C to stop.".dimmed());
    println!();

    // Best-effort PID file; write failures are deliberately ignored and
    // nothing here reads it — presumably external process tooling does.
    let pid_file = std::env::temp_dir().join("ferrum.pid");
    std::fs::write(&pid_file, std::process::id().to_string()).ok();

    // Run until either the server future completes (report the error) or
    // Ctrl+C arrives (graceful shutdown message).
    tokio::select! {
        result = server.start(&server_config) => {
            if let Err(e) = result {
                eprintln!("{} Server error: {}", "Error:".red().bold(), e);
            }
        }
        _ = signal::ctrl_c() => {
            println!();
            println!("{}", "Shutting down...".yellow());
        }
    }

    // Best-effort cleanup; the file may already be gone.
    std::fs::remove_file(&pid_file).ok();

    Ok(())
}
170
171fn print_banner() {
172 println!();
173 println!("{}", " ______ ".bright_red());
174 println!("{}", " | ____| ".bright_red());
175 println!("{}", " | |__ ___ _ __ _ __ _ _ _ __ ___ ".bright_red());
176 println!("{}", " | __/ _ \\ '__| '__| | | | '_ ` _ \\ ".bright_red());
177 println!("{}", " | | | __/ | | | | |_| | | | | | ".bright_red());
178 println!("{}", " |_| \\___|_| |_| \\__,_|_| |_| |_|".bright_red());
179 println!();
180 println!(" {}", "🦀 Rust LLM Inference Server".bright_cyan().bold());
181 println!(
182 " {}",
183 format!("Version {}", env!("CARGO_PKG_VERSION")).dimmed()
184 );
185 println!();
186}
187
/// Map a short, user-friendly model alias to its Hugging Face repo id.
///
/// Matching is case-insensitive; unrecognized names pass through untouched
/// so full repo ids (e.g. `org/repo`) keep working.
fn resolve_model_alias(name: &str) -> String {
    // (accepted aliases, canonical repo id) pairs.
    const ALIASES: &[(&[&str], &str)] = &[
        (&["tinyllama", "tiny"], "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
        (&["qwen2.5:0.5b", "qwen:0.5b"], "Qwen/Qwen2.5-0.5B-Instruct"),
        (&["qwen2.5:1.5b", "qwen:1.5b"], "Qwen/Qwen2.5-1.5B-Instruct"),
        (&["qwen2.5:3b", "qwen:3b"], "Qwen/Qwen2.5-3B-Instruct"),
        (&["qwen2.5:7b", "qwen:7b"], "Qwen/Qwen2.5-7B-Instruct"),
        (&["qwen3:0.6b"], "Qwen/Qwen3-0.6B"),
        (&["qwen3:1.7b"], "Qwen/Qwen3-1.7B"),
        (&["qwen3:4b"], "Qwen/Qwen3-4B"),
        (&["llama3.2:1b"], "meta-llama/Llama-3.2-1B-Instruct"),
        (&["llama3.2:3b"], "meta-llama/Llama-3.2-3B-Instruct"),
    ];

    let needle = name.to_lowercase();
    ALIASES
        .iter()
        .find(|(keys, _)| keys.contains(&needle.as_str()))
        .map(|(_, repo)| (*repo).to_string())
        .unwrap_or_else(|| name.to_string())
}
203
204fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
205 if let Ok(hf_home) = std::env::var("HF_HOME") {
206 return PathBuf::from(hf_home);
207 }
208 let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
209 PathBuf::from(configured)
210}
211
212fn find_cached_model(
213 cache_dir: &PathBuf,
214 model_id: &str,
215) -> Option<ferrum_models::source::ResolvedModelSource> {
216 let repo_dir = cache_dir
217 .join("hub")
218 .join(format!("models--{}", model_id.replace('/', "--")));
219 let snapshots_dir = repo_dir.join("snapshots");
220
221 let ref_main = repo_dir.join("refs").join("main");
223 if let Ok(rev) = std::fs::read_to_string(&ref_main) {
224 let rev = rev.trim();
225 if !rev.is_empty() {
226 let snapshot = snapshots_dir.join(rev);
227 if snapshot.exists() {
228 let format = detect_format(&snapshot);
229 if format != ModelFormat::Unknown {
230 return Some(ferrum_models::source::ResolvedModelSource {
231 original: model_id.to_string(),
232 local_path: snapshot,
233 format,
234 from_cache: true,
235 });
236 }
237 }
238 }
239 }
240
241 if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
243 for entry in entries.flatten() {
244 let path = entry.path();
245 if path.is_dir() {
246 let format = detect_format(&path);
247 if format != ModelFormat::Unknown {
248 return Some(ferrum_models::source::ResolvedModelSource {
249 original: model_id.to_string(),
250 local_path: path,
251 format,
252 from_cache: true,
253 });
254 }
255 }
256 }
257 }
258
259 None
260}
261
262fn detect_format(path: &PathBuf) -> ModelFormat {
263 if path.join("model.safetensors").exists() || path.join("model.safetensors.index.json").exists()
264 {
265 ModelFormat::SafeTensors
266 } else if path.join("pytorch_model.bin").exists() {
267 ModelFormat::PyTorchBin
268 } else {
269 ModelFormat::Unknown
270 }
271}
272
// Pick the compute device based on compile-time features.
//
// Priority: Metal (macOS builds with the `metal` feature) > CUDA device 0
// (`cuda` feature) > CPU. The cfg-gated early returns make the trailing CPU
// expression unreachable in accelerated builds, hence the
// `#[allow(unreachable_code)]`.
fn select_device() -> ferrum_types::Device {
    #[cfg(all(target_os = "macos", feature = "metal"))]
    {
        return ferrum_types::Device::Metal;
    }

    #[cfg(feature = "cuda")]
    {
        return ferrum_types::Device::CUDA(0);
    }

    #[allow(unreachable_code)]
    ferrum_types::Device::CPU
}