// ferrum_cli/commands/run.rs

//! Run command - Interactive chat with a model (ollama-style)
2
3use crate::config::CliConfig;
4use chrono::Utc;
5use clap::Args;
6use colored::*;
7use ferrum_models::source::{ModelFormat, ResolvedModelSource};
8use ferrum_models::HfDownloader;
9use ferrum_types::{InferenceRequest, Priority, RequestId, Result, SamplingParams};
10use futures::StreamExt;
11use std::collections::HashMap;
12use std::io::{self, BufRead, IsTerminal, Write};
13use std::path::PathBuf;
14use std::sync::atomic::{AtomicBool, Ordering};
15use std::sync::Arc;
16use uuid::Uuid;
17
18#[derive(Args)]
19pub struct RunCommand {
20    /// Model name (e.g., tinyllama, qwen2.5:7b, or full path)
21    #[arg(default_value = "tinyllama")]
22    pub model: String,
23
24    /// System prompt
25    #[arg(long)]
26    pub system: Option<String>,
27
28    /// Maximum tokens to generate
29    #[arg(long, default_value = "512")]
30    pub max_tokens: u32,
31
32    /// Temperature (0.0-2.0)
33    #[arg(long, default_value = "0.7")]
34    pub temperature: f32,
35
36    /// Backend: auto, cpu, metal (default: auto)
37    #[arg(long, default_value = "auto")]
38    pub backend: String,
39}
40
41pub async fn execute(cmd: RunCommand, config: CliConfig) -> Result<()> {
42    // Resolve model
43    let model_id = resolve_model_alias(&cmd.model);
44    eprintln!("{}", format!("Loading {}...", model_id).dimmed());
45
46    // Find cached model or auto-download
47    let cache_dir = get_hf_cache_dir(&config);
48    let source = match find_cached_model(&cache_dir, &model_id) {
49        Some(source) => source,
50        None => {
51            // Model not found, try to download automatically
52            eprintln!(
53                "{} Model '{}' not found locally, downloading...",
54                "📥".cyan(),
55                model_id
56            );
57
58            // Get HF token from environment
59            let token = std::env::var("HF_TOKEN")
60                .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
61                .ok();
62
63            // Create downloader and download
64            let downloader = HfDownloader::new(cache_dir.clone(), token)?;
65            let snapshot_path = downloader.download(&model_id, None).await?;
66
67            // Now find the downloaded model
68            let format = detect_format(&snapshot_path);
69            if format == ModelFormat::Unknown {
70                return Err(ferrum_types::FerrumError::model(
71                    "Downloaded model has unknown format",
72                ));
73            }
74
75            ResolvedModelSource {
76                original: model_id.clone(),
77                local_path: snapshot_path,
78                format,
79                from_cache: false,
80            }
81        }
82    };
83
84    // Set model path for engine
85    // NOTE: std::env::set_var is unsafe on Rust 2024; keep it minimal and explicit.
86    unsafe {
87        std::env::set_var(
88            "FERRUM_MODEL_PATH",
89            source.local_path.to_string_lossy().to_string(),
90        );
91    }
92
93    // Select device
94    let device = select_device(&cmd.backend);
95    eprintln!("{}", format!("Using {:?} backend", device).dimmed());
96
97    // Create engine
98    let engine_config = ferrum_engine::simple_engine_config(model_id.clone(), device);
99    let engine = ferrum_engine::create_mvp_engine(engine_config).await?;
100
101    // Print ready message
102    eprintln!();
103    eprintln!("{}", "Ready. Type your message and press Enter.".green());
104    eprintln!("{}", "Use /bye or Ctrl+D to exit.".dimmed());
105    eprintln!();
106
107    // Interactive loop
108    let mut history: Vec<(String, String)> = Vec::new(); // (role, content)
109    let generating = Arc::new(AtomicBool::new(false));
110
111    // If stdin is not a TTY (piped input), don't print prompts and just consume lines.
112    // This enables: `printf "hi\n/bye\n" | ferrum run ...` for automation/profiling.
113    let stdin_is_tty = io::stdin().is_terminal();
114    let mut stdin = io::stdin().lock();
115
116    loop {
117        if stdin_is_tty {
118            // Show prompt
119            print!("{} ", ">>>".bright_green().bold());
120            io::stdout().flush().unwrap();
121        }
122
123        // Read input
124        let mut input = String::new();
125        match stdin.read_line(&mut input) {
126            Ok(0) => break, // EOF
127            Ok(_) => {
128                let input = input.trim();
129                if input.is_empty() {
130                    continue;
131                }
132                if input == "/bye" || input == "exit" || input == "quit" {
133                    break;
134                }
135
136                // Build prompt with history
137                let prompt = build_chat_prompt(&history, input, cmd.system.as_deref(), &model_id);
138
139                // Model-specific sampling defaults
140                let model_lower = model_id.to_lowercase();
141                let (default_top_p, default_top_k) = if model_lower.contains("qwen3") {
142                    // Qwen3 non-thinking mode: top_p=0.8, top_k=20
143                    (0.8, Some(20))
144                } else {
145                    (0.9, None)
146                };
147
148                // Create request
149                let request = InferenceRequest {
150                    id: RequestId(Uuid::new_v4()),
151                    model_id: ferrum_types::ModelId(model_id.clone()),
152                    prompt,
153                    sampling_params: SamplingParams {
154                        max_tokens: cmd.max_tokens as usize,
155                        temperature: cmd.temperature,
156                        top_p: default_top_p,
157                        top_k: default_top_k,
158                        repetition_penalty: 1.1,
159                        stop_sequences: vec![
160                            "<|im_end|>".to_string(),
161                            "</s>".to_string(),
162                            "<|endoftext|>".to_string(),
163                        ],
164                        ..Default::default()
165                    },
166                    stream: true,
167                    priority: Priority::Normal,
168                    client_id: None,
169                    session_id: None,
170                    created_at: Utc::now(),
171                    metadata: HashMap::new(),
172                };
173
174                // Start generation
175                generating.store(true, Ordering::SeqCst);
176                let mut stream = engine.infer_stream(request).await?;
177                let mut response = String::new();
178                let start = std::time::Instant::now();
179                let mut token_count = 0;
180
181                while let Some(result) = stream.next().await {
182                    match result {
183                        Ok(chunk) => {
184                            if !chunk.text.is_empty() {
185                                print!("{}", chunk.text);
186                                io::stdout().flush().unwrap();
187                                response.push_str(&chunk.text);
188                            }
189                            if chunk.token.is_some() {
190                                token_count += 1;
191                            }
192                            if chunk.finish_reason.is_some() {
193                                break;
194                            }
195                        }
196                        Err(e) => {
197                            eprintln!("\n{} {}", "Error:".red(), e);
198                            break;
199                        }
200                    }
201                }
202
203                generating.store(false, Ordering::SeqCst);
204
205                // Print stats
206                let elapsed = start.elapsed();
207                let tps = if elapsed.as_secs_f64() > 0.0 {
208                    token_count as f64 / elapsed.as_secs_f64()
209                } else {
210                    0.0
211                };
212                println!();
213                eprintln!(
214                    "{}",
215                    format!(
216                        "[{} tokens, {:.1} tok/s, {:.1}s]",
217                        token_count,
218                        tps,
219                        elapsed.as_secs_f64()
220                    )
221                    .dimmed()
222                );
223                eprintln!();
224
225                // In non-interactive mode, don't wait for terminal formatting/spacing.
226                if !stdin_is_tty {
227                    io::stdout().flush().ok();
228                    io::stderr().flush().ok();
229                }
230
231                // Add to history
232                history.push(("user".to_string(), input.to_string()));
233                let clean_response = response.trim().to_string();
234                if !clean_response.is_empty() {
235                    history.push(("assistant".to_string(), clean_response));
236                }
237
238                // Limit history
239                while history.len() > 10 {
240                    history.remove(0);
241                }
242            }
243            Err(e) => {
244                eprintln!("{} {}", "Error reading input:".red(), e);
245                break;
246            }
247        }
248    }
249
250    eprintln!("{}", "Goodbye!".bright_yellow());
251    Ok(())
252}
253
/// Expands a short, ollama-style model alias into a Hugging Face repository id.
///
/// Aliases are matched ASCII case-insensitively. A name that is not a known
/// alias (for example a full repository id or a local path) is returned
/// unchanged.
pub fn resolve_model_alias(name: &str) -> String {
    // Each entry maps one or more accepted aliases to a repository id.
    const ALIASES: &[(&[&str], &str)] = &[
        (&["tinyllama", "tiny"], "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
        (&["qwen2.5:0.5b", "qwen:0.5b"], "Qwen/Qwen2.5-0.5B-Instruct"),
        (&["qwen2.5:1.5b", "qwen:1.5b"], "Qwen/Qwen2.5-1.5B-Instruct"),
        (&["qwen2.5:3b", "qwen:3b"], "Qwen/Qwen2.5-3B-Instruct"),
        (&["qwen2.5:7b", "qwen:7b"], "Qwen/Qwen2.5-7B-Instruct"),
        (&["qwen3:0.6b"], "Qwen/Qwen3-0.6B"),
        (&["qwen3:1.7b"], "Qwen/Qwen3-1.7B"),
        (&["qwen3:4b"], "Qwen/Qwen3-4B"),
        (&["llama3.2:1b"], "meta-llama/Llama-3.2-1B-Instruct"),
        (&["llama3.2:3b"], "meta-llama/Llama-3.2-3B-Instruct"),
    ];

    ALIASES
        .iter()
        .find(|(aliases, _)| aliases.iter().any(|alias| alias.eq_ignore_ascii_case(name)))
        .map_or(name, |&(_, repo)| repo)
        .to_string()
}
269
270pub fn get_hf_cache_dir(config: &CliConfig) -> PathBuf {
271    // Check environment variable first
272    if let Ok(hf_home) = std::env::var("HF_HOME") {
273        return PathBuf::from(hf_home);
274    }
275
276    // Use config value
277    let configured = shellexpand::tilde(&config.models.download.hf_cache_dir).to_string();
278    PathBuf::from(configured)
279}
280
281pub fn find_cached_model(cache_dir: &PathBuf, model_id: &str) -> Option<ResolvedModelSource> {
282    let repo_dir = cache_dir
283        .join("hub")
284        .join(format!("models--{}", model_id.replace('/', "--")));
285    let snapshots_dir = repo_dir.join("snapshots");
286
287    // Try refs/main first
288    let ref_main = repo_dir.join("refs").join("main");
289    if let Ok(rev) = std::fs::read_to_string(&ref_main) {
290        let rev = rev.trim();
291        if !rev.is_empty() {
292            let snapshot = snapshots_dir.join(rev);
293            if snapshot.exists() {
294                let format = detect_format(&snapshot);
295                if format != ModelFormat::Unknown {
296                    return Some(ResolvedModelSource {
297                        original: model_id.to_string(),
298                        local_path: snapshot,
299                        format,
300                        from_cache: true,
301                    });
302                }
303            }
304        }
305    }
306
307    // Fallback: first snapshot directory
308    if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
309        for entry in entries.flatten() {
310            let path = entry.path();
311            if path.is_dir() {
312                let format = detect_format(&path);
313                if format != ModelFormat::Unknown {
314                    return Some(ResolvedModelSource {
315                        original: model_id.to_string(),
316                        local_path: path,
317                        format,
318                        from_cache: true,
319                    });
320                }
321            }
322        }
323    }
324
325    None
326}
327
328pub fn detect_format(path: &PathBuf) -> ModelFormat {
329    if path.join("model.safetensors").exists() || path.join("model.safetensors.index.json").exists()
330    {
331        ModelFormat::SafeTensors
332    } else if path.join("pytorch_model.bin").exists() {
333        ModelFormat::PyTorchBin
334    } else {
335        ModelFormat::Unknown
336    }
337}
338
339pub fn select_device(backend: &str) -> ferrum_types::Device {
340    match backend.to_lowercase().as_str() {
341        "cpu" => ferrum_types::Device::CPU,
342        "metal" => {
343            #[cfg(all(target_os = "macos", feature = "metal"))]
344            {
345                return ferrum_types::Device::Metal;
346            }
347            #[allow(unreachable_code)]
348            {
349                eprintln!("Metal not available, falling back to CPU");
350                ferrum_types::Device::CPU
351            }
352        }
353        "cuda" => {
354            #[cfg(feature = "cuda")]
355            {
356                return ferrum_types::Device::CUDA(0);
357            }
358            #[allow(unreachable_code)]
359            {
360                eprintln!("CUDA not available, falling back to CPU");
361                ferrum_types::Device::CPU
362            }
363        }
364        "auto" | _ => {
365            #[cfg(all(target_os = "macos", feature = "metal"))]
366            {
367                return ferrum_types::Device::Metal;
368            }
369            #[cfg(feature = "cuda")]
370            {
371                return ferrum_types::Device::CUDA(0);
372            }
373            #[allow(unreachable_code)]
374            ferrum_types::Device::CPU
375        }
376    }
377}
378
/// Renders the conversation into the prompt text expected by the model.
///
/// The template is chosen from `model_id` (case-insensitively):
/// - names containing `qwen` use ChatML; Qwen3 additionally gets an empty
///   `<think>` block appended so that thinking mode is disabled;
/// - Llama 3 family names (see `is_llama3_model`) use the Llama 3 header format;
/// - everything else uses the TinyLlama-style `<|role|>` format, with a default
///   system prompt when none is given.
///
/// `history` holds `(role, content)` pairs in chronological order and
/// `user_input` is the new user message. The returned prompt always ends with
/// an open assistant turn.
fn build_chat_prompt(
    history: &[(String, String)],
    user_input: &str,
    system: Option<&str>,
    model_id: &str,
) -> String {
    let model_lower = model_id.to_lowercase();
    let mut prompt = String::new();

    if model_lower.contains("qwen") {
        // Qwen ChatML format (Qwen2, Qwen2.5, Qwen3)
        if let Some(sys) = system {
            push_chatml_turn(&mut prompt, "system", sys);
        }
        for (role, content) in history {
            push_chatml_turn(&mut prompt, role, content);
        }
        push_chatml_turn(&mut prompt, "user", user_input);
        prompt.push_str("<|im_start|>assistant\n");
        // Qwen3: disable thinking mode by inserting empty think block
        if model_lower.contains("qwen3") {
            prompt.push_str("<think>\n\n</think>\n\n");
        }
    } else if is_llama3_model(&model_lower) {
        // Llama 3 format
        prompt.push_str("<|begin_of_text|>");
        if let Some(sys) = system {
            push_llama3_turn(&mut prompt, "system", sys);
        }
        for (role, content) in history {
            push_llama3_turn(&mut prompt, role, content);
        }
        push_llama3_turn(&mut prompt, "user", user_input);
        prompt.push_str("<|start_header_id|>assistant<|end_header_id|>\n\n");
    } else {
        // TinyLlama / generic chat format
        push_tagged_turn(
            &mut prompt,
            "system",
            system.unwrap_or("You are a helpful assistant."),
        );
        for (role, content) in history {
            // Any role other than "user" is rendered as an assistant turn.
            let tag = if role == "user" { "user" } else { "assistant" };
            push_tagged_turn(&mut prompt, tag, content);
        }
        push_tagged_turn(&mut prompt, "user", user_input);
        prompt.push_str("<|assistant|>\n");
    }

    prompt
}

/// Reports whether a lower-cased model name refers to the Llama 3 family.
///
/// Matches `llama` followed by an optional `-`, `_` or space and the version
/// digit `3` (`llama3`, `llama-3.2-1b`, `meta-llama-3-8b`). A `3` that starts a
/// parameter count (`llama-3b`, `llama-30b`) or appears elsewhere in the name
/// (`llama-2-13b`) does not count.
fn is_llama3_model(model_lower: &str) -> bool {
    model_lower.match_indices("llama").any(|(start, family)| {
        let mut rest = model_lower[start + family.len()..]
            .trim_start_matches(|c: char| matches!(c, '-' | '_' | ' '))
            .chars();
        rest.next() == Some('3')
            && !matches!(rest.next(), Some(next) if next == 'b' || next.is_ascii_digit())
    })
}

/// Appends one ChatML turn: `<|im_start|>{role}\n{content}<|im_end|>\n`.
fn push_chatml_turn(prompt: &mut String, role: &str, content: &str) {
    prompt.push_str("<|im_start|>");
    prompt.push_str(role);
    prompt.push('\n');
    prompt.push_str(content);
    prompt.push_str("<|im_end|>\n");
}

/// Appends one Llama 3 turn:
/// `<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>`.
fn push_llama3_turn(prompt: &mut String, role: &str, content: &str) {
    prompt.push_str("<|start_header_id|>");
    prompt.push_str(role);
    prompt.push_str("<|end_header_id|>\n\n");
    prompt.push_str(content);
    prompt.push_str("<|eot_id|>");
}

/// Appends one TinyLlama-style turn: `<|{tag}|>\n{content}</s>\n`.
fn push_tagged_turn(prompt: &mut String, tag: &str, content: &str) {
    prompt.push_str("<|");
    prompt.push_str(tag);
    prompt.push_str("|>\n");
    prompt.push_str(content);
    prompt.push_str("</s>\n");
}