use crate::config::CliConfig;
use chrono::Utc;
use clap::Args;
use colored::*;
use ferrum_models::source::{ModelFormat, ResolvedModelSource};
use ferrum_models::HfDownloader;
use ferrum_types::{InferenceRequest, Priority, RequestId, Result, SamplingParams};
use futures::StreamExt;
use std::collections::HashMap;
use std::io::{self, BufRead, IsTerminal, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use uuid::Uuid;
17
18#[derive(Args)]
19pub struct RunCommand {
20 #[arg(default_value = "tinyllama")]
22 pub model: String,
23
24 #[arg(long)]
26 pub system: Option<String>,
27
28 #[arg(long, default_value = "512")]
30 pub max_tokens: u32,
31
32 #[arg(long, default_value = "0.7")]
34 pub temperature: f32,
35
36 #[arg(long, default_value = "auto")]
38 pub backend: String,
39}
40
41pub async fn execute(cmd: RunCommand, config: CliConfig) -> Result<()> {
42 let model_id = resolve_model_alias(&cmd.model);
44 eprintln!("{}", format!("Loading {}...", model_id).dimmed());
45
46 let cache_dir = get_hf_cache_dir(&config);
48 let source = match find_cached_model(&cache_dir, &model_id) {
49 Some(source) => source,
50 None => {
51 eprintln!(
53 "{} Model '{}' not found locally, downloading...",
54 "📥".cyan(),
55 model_id
56 );
57
58 let token = std::env::var("HF_TOKEN")
60 .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
61 .ok();
62
63 let downloader = HfDownloader::new(cache_dir.clone(), token)?;
65 let snapshot_path = downloader.download(&model_id, None).await?;
66
67 let format = detect_format(&snapshot_path);
69 if format == ModelFormat::Unknown {
70 return Err(ferrum_types::FerrumError::model(
71 "Downloaded model has unknown format",
72 ));
73 }
74
75 ResolvedModelSource {
76 original: model_id.clone(),
77 local_path: snapshot_path,
78 format,
79 from_cache: false,
80 }
81 }
82 };
83
84 unsafe {
87 std::env::set_var(
88 "FERRUM_MODEL_PATH",
89 source.local_path.to_string_lossy().to_string(),
90 );
91 }
92
93 let device = select_device(&cmd.backend);
95 eprintln!("{}", format!("Using {:?} backend", device).dimmed());
96
97 let engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
99 let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
100
101 eprintln!();
103 eprintln!("{}", "Ready. Type your message and press Enter.".green());
104 eprintln!("{}", "Use /bye or Ctrl+D to exit.".dimmed());
105 eprintln!();
106
107 let mut history: Vec<(String, String)> = Vec::new(); let generating = Arc::new(AtomicBool::new(false));
110
111 let stdin_is_tty = io::stdin().is_terminal();
114 let mut stdin = io::stdin().lock();
115
116 loop {
117 if stdin_is_tty {
118 print!("{} ", ">>>".bright_green().bold());
120 io::stdout().flush().unwrap();
121 }
122
123 let mut input = String::new();
125 match stdin.read_line(&mut input) {
126 Ok(0) => break, Ok(_) => {
128 let input = input.trim();
129 if input.is_empty() {
130 continue;
131 }
132 if input == "/bye" || input == "exit" || input == "quit" {
133 break;
134 }
135
136 let prompt = build_chat_prompt(&history, input, cmd.system.as_deref(), &model_id);
138
139 let model_lower = model_id.to_lowercase();
141 let (default_top_p, default_top_k) = if model_lower.contains("qwen3") {
142 (0.8, Some(20))
144 } else {
145 (0.9, None)
146 };
147
148 let request = InferenceRequest {
150 id: RequestId(Uuid::new_v4()),
151 model_id: ferrum_types::ModelId(model_id.clone()),
152 prompt,
153 sampling_params: SamplingParams {
154 max_tokens: cmd.max_tokens as usize,
155 temperature: cmd.temperature,
156 top_p: default_top_p,
157 top_k: default_top_k,
158 repetition_penalty: 1.1,
159 stop_sequences: vec![
160 "<|im_end|>".to_string(),
161 "</s>".to_string(),
162 "<|endoftext|>".to_string(),
163 ],
164 ..Default::default()
165 },
166 stream: true,
167 priority: Priority::Normal,
168 client_id: None,
169 session_id: None,
170 created_at: Utc::now(),
171 metadata: HashMap::new(),
172 };
173
174 generating.store(true, Ordering::SeqCst);
176 let mut stream = engine.infer_stream(request).await?;
177 let mut response = String::new();
178 let start = std::time::Instant::now();
179 let mut token_count = 0;
180
181 while let Some(result) = stream.next().await {
182 match result {
183 Ok(chunk) => {
184 if !chunk.text.is_empty() {
185 print!("{}", chunk.text);
186 io::stdout().flush().unwrap();
187 response.push_str(&chunk.text);
188 }
189 if chunk.token.is_some() {
190 token_count += 1;
191 }
192 if chunk.finish_reason.is_some() {
193 break;
194 }
195 }
196 Err(e) => {
197 eprintln!("\n{} {}", "Error:".red(), e);
198 break;
199 }
200 }
201 }
202
203 generating.store(false, Ordering::SeqCst);
204
205 let elapsed = start.elapsed();
207 let tps = if elapsed.as_secs_f64() > 0.0 {
208 token_count as f64 / elapsed.as_secs_f64()
209 } else {
210 0.0
211 };
212 println!();
213 eprintln!(
214 "{}",
215 format!(
216 "[{} tokens, {:.1} tok/s, {:.1}s]",
217 token_count,
218 tps,
219 elapsed.as_secs_f64()
220 )
221 .dimmed()
222 );
223 eprintln!();
224
225 if !stdin_is_tty {
227 io::stdout().flush().ok();
228 io::stderr().flush().ok();
229 }
230
231 history.push(("user".to_string(), input.to_string()));
233 let clean_response = response.trim().to_string();
234 if !clean_response.is_empty() {
235 history.push(("assistant".to_string(), clean_response));
236 }
237
238 while history.len() > 10 {
240 history.remove(0);
241 }
242 }
243 Err(e) => {
244 eprintln!("{} {}", "Error reading input:".red(), e);
245 break;
246 }
247 }
248 }
249
250 eprintln!("{}", "Goodbye!".bright_yellow());
251 Ok(())
252}
253
254pub fn resolve_model_alias(name: &str) -> String {
255 match name.to_lowercase().as_str() {
256 "tinyllama" | "tiny" => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
257 "qwen2.5:0.5b" | "qwen:0.5b" => "Qwen/Qwen2.5-0.5B-Instruct".to_string(),
258 "qwen2.5:1.5b" | "qwen:1.5b" => "Qwen/Qwen2.5-1.5B-Instruct".to_string(),
259 "qwen2.5:3b" | "qwen:3b" => "Qwen/Qwen2.5-3B-Instruct".to_string(),
260 "qwen2.5:7b" | "qwen:7b" => "Qwen/Qwen2.5-7B-Instruct".to_string(),
261 "qwen3:0.6b" => "Qwen/Qwen3-0.6B".to_string(),
262 "qwen3:1.7b" => "Qwen/Qwen3-1.7B".to_string(),
263 "qwen3:4b" => "Qwen/Qwen3-4B".to_string(),
264 "llama3.2:1b" => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
265 "llama3.2:3b" => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
266 _ => name.to_string(),
267 }
268}
269
270pub fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
271 if let Ok(hf_home) = std::env::var("HF_HOME") {
273 return PathBuf::from(hf_home);
274 }
275
276 let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
278 PathBuf::from(configured)
279}
280
281pub fn find_cached_model(cache_dir: &PathBuf, model_id: &str) -> Option<ResolvedModelSource> {
282 let repo_dir = cache_dir
283 .join("hub")
284 .join(format!("models--{}", model_id.replace('/', "--")));
285 let snapshots_dir = repo_dir.join("snapshots");
286
287 let ref_main = repo_dir.join("refs").join("main");
289 if let Ok(rev) = std::fs::read_to_string(&ref_main) {
290 let rev = rev.trim();
291 if !rev.is_empty() {
292 let snapshot = snapshots_dir.join(rev);
293 if snapshot.exists() {
294 let format = detect_format(&snapshot);
295 if format != ModelFormat::Unknown {
296 return Some(ResolvedModelSource {
297 original: model_id.to_string(),
298 local_path: snapshot,
299 format,
300 from_cache: true,
301 });
302 }
303 }
304 }
305 }
306
307 if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
309 for entry in entries.flatten() {
310 let path = entry.path();
311 if path.is_dir() {
312 let format = detect_format(&path);
313 if format != ModelFormat::Unknown {
314 return Some(ResolvedModelSource {
315 original: model_id.to_string(),
316 local_path: path,
317 format,
318 from_cache: true,
319 });
320 }
321 }
322 }
323 }
324
325 None
326}
327
328pub fn detect_format(path: &PathBuf) -> ModelFormat {
329 if path.join("model.safetensors").exists() || path.join("model.safetensors.index.json").exists()
330 {
331 ModelFormat::SafeTensors
332 } else if path.join("pytorch_model.bin").exists() {
333 ModelFormat::PyTorchBin
334 } else {
335 ModelFormat::Unknown
336 }
337}
338
339pub fn select_device(backend: &str) -> ferrum_types::Device {
340 match backend.to_lowercase().as_str() {
341 "cpu" => ferrum_types::Device::CPU,
342 "metal" => {
343 #[cfg(all(target_os = "macos", feature = "metal"))]
344 {
345 return ferrum_types::Device::Metal;
346 }
347 #[allow(unreachable_code)]
348 {
349 eprintln!("Metal not available, falling back to CPU");
350 ferrum_types::Device::CPU
351 }
352 }
353 "cuda" => {
354 #[cfg(feature = "cuda")]
355 {
356 return ferrum_types::Device::CUDA(0);
357 }
358 #[allow(unreachable_code)]
359 {
360 eprintln!("CUDA not available, falling back to CPU");
361 ferrum_types::Device::CPU
362 }
363 }
364 "auto" | _ => {
365 #[cfg(all(target_os = "macos", feature = "metal"))]
366 {
367 return ferrum_types::Device::Metal;
368 }
369 #[cfg(feature = "cuda")]
370 {
371 return ferrum_types::Device::CUDA(0);
372 }
373 #[allow(unreachable_code)]
374 ferrum_types::Device::CPU
375 }
376 }
377}
378
379fn build_chat_prompt(
380 history: &[(String, String)],
381 user_input: &str,
382 system: Option<&str>,
383 model_id: &str,
384) -> String {
385 let model_lower = model_id.to_lowercase();
387
388 if model_lower.contains("qwen") {
389 let mut prompt = String::new();
391 if let Some(sys) = system {
392 prompt.push_str(&format!("<|im_start|>system\n{}<|im_end|>\n", sys));
393 }
394 for (role, content) in history {
395 prompt.push_str(&format!("<|im_start|>{}\n{}<|im_end|>\n", role, content));
396 }
397 prompt.push_str(&format!(
398 "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
399 user_input
400 ));
401 if model_lower.contains("qwen3") {
403 prompt.push_str("<think>\n\n</think>\n\n");
404 }
405 prompt
406 } else if model_lower.contains("llama") && model_lower.contains("3") {
407 let mut prompt = String::new();
409 prompt.push_str("<|begin_of_text|>");
410 if let Some(sys) = system {
411 prompt.push_str(&format!(
412 "<|start_header_id|>system<|end_header_id|>\n\n{}<|eot_id|>",
413 sys
414 ));
415 }
416 for (role, content) in history {
417 prompt.push_str(&format!(
418 "<|start_header_id|>{}<|end_header_id|>\n\n{}<|eot_id|>",
419 role, content
420 ));
421 }
422 prompt.push_str(&format!(
423 "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
424 user_input
425 ));
426 prompt
427 } else {
428 let sys = system.unwrap_or("You are a helpful assistant.");
430 let mut prompt = format!("<|system|>\n{}</s>\n", sys);
431 for (role, content) in history {
432 let tag = if role == "user" { "user" } else { "assistant" };
433 prompt.push_str(&format!("<|{}|>\n{}</s>\n", tag, content));
434 }
435 prompt.push_str(&format!("<|user|>\n{}</s>\n<|assistant|>\n", user_input));
436 prompt
437 }
438}