Skip to main content

lc/utils/
cli_utils.rs

1//! CLI utility functions used throughout the application and tests
2
3use anyhow::{anyhow, Result};
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use std::sync::atomic::{AtomicBool, Ordering};
8
9use crate::config::Config;
10
/// Process-wide switch recording whether debug mode has been requested.
///
/// All accesses use `Ordering::Relaxed`: reads and writes are atomic but
/// impose no ordering on surrounding memory operations.
static DEBUG_MODE: AtomicBool = AtomicBool::new(false);

/// Turn debug mode on (`true`) or off (`false`) for the whole process.
pub fn set_debug_mode(enabled: bool) {
    let flag: &AtomicBool = &DEBUG_MODE;
    flag.store(enabled, Ordering::Relaxed);
}

/// Report whether debug mode is currently switched on.
pub fn is_debug_mode() -> bool {
    let flag: &AtomicBool = &DEBUG_MODE;
    flag.load(Ordering::Relaxed)
}
23
/// Decide whether a file extension (given without the leading dot) belongs to a
/// source-code, script, markup, or configuration format.
///
/// Matching is case-insensitive: `ext` is lower-cased before the lookup, so
/// `"RS"` and `"rs"` are treated alike. Unknown and empty extensions yield
/// `false`.
pub fn is_code_file(ext: &str) -> bool {
    let known_extensions: HashSet<&str> = HashSet::from([
        // General-purpose programming languages
        "rs", "py", "js", "ts", "java", "cpp", "c", "h", "hpp", "go", "rb", "php", "swift",
        "kt", "scala",
        // Shells and command scripts
        "sh", "bash", "zsh", "fish", "ps1", "bat", "cmd",
        // Markup and stylesheets
        "html", "css", "scss", "sass", "less", "xml",
        // Data and configuration formats
        "json", "yaml", "yml", "toml", "ini", "cfg", "conf",
        // Miscellaneous languages
        "sql", "r", "m", "mm", "pl", "pm", "lua", "vim",
        // Build-tool related names
        "dockerfile", "makefile", "cmake", "gradle", "maven",
    ]);

    let lowered = ext.to_lowercase();
    known_extensions.contains(lowered.as_str())
}
82
83/// Read and format attachment files for inclusion in prompts
84pub fn read_and_format_attachments(attachments: &[String]) -> Result<String> {
85    if attachments.is_empty() {
86        return Ok(String::new());
87    }
88
89    let mut result = String::new();
90
91    for attachment_path in attachments {
92        let path = Path::new(attachment_path);
93        let filename = path
94            .file_name()
95            .and_then(|n| n.to_str())
96            .unwrap_or("unknown");
97
98        // Read file content
99        let content = fs::read_to_string(path)
100            .map_err(|e| anyhow!("Failed to read file '{}': {}", attachment_path, e))?;
101
102        // Add file header
103        result.push_str(&format!("=== File: {} ===\n", filename));
104
105        // Check if this is a code file based on extension
106        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
107            if is_code_file(ext) {
108                result.push_str(&format!("```{}\n{}\n```\n", ext.to_lowercase(), content));
109            } else {
110                result.push_str(&content);
111            }
112        } else {
113            result.push_str(&content);
114        }
115
116        result.push('\n');
117    }
118
119    Ok(result)
120}
121
122/// Resolve model and provider from configuration and CLI overrides
123pub fn resolve_model_and_provider(
124    config: &Config,
125    provider_override: Option<String>,
126    model_override: Option<String>,
127) -> Result<(String, String)> {
128    // Store whether we have explicit provider override to avoid borrow issues
129    let has_provider_override = provider_override.is_some();
130
131    let provider = match provider_override {
132        Some(p) => {
133            if !config.providers.contains_key(&p) {
134                return Err(anyhow!("Provider '{}' not found in configuration", p));
135            }
136            p
137        }
138        None => config
139            .default_provider
140            .clone()
141            .ok_or_else(|| anyhow!("No default provider configured and none specified"))?,
142    };
143
144    let model = match model_override {
145        Some(m) => {
146            // Check if model is in format "provider:model"
147            if m.contains(':') && !has_provider_override {
148                let parts: Vec<&str> = m.splitn(2, ':').collect();
149                if parts.len() == 2 {
150                    let alias_provider = parts[0].to_string();
151                    let alias_model = parts[1].to_string();
152
153                    // Verify provider exists
154                    if !config.providers.contains_key(&alias_provider) {
155                        return Err(anyhow!(
156                            "Provider '{}' not found in configuration",
157                            alias_provider
158                        ));
159                    }
160
161                    return Ok((alias_provider, alias_model));
162                }
163            }
164
165            // Check if it's an alias (only if provider is not explicitly set)
166            if !has_provider_override {
167                if let Some(alias_target) = config.aliases.get(&m) {
168                    let parts: Vec<&str> = alias_target.splitn(2, ':').collect();
169                    if parts.len() == 2 {
170                        let alias_provider = parts[0].to_string();
171                        let alias_model = parts[1].to_string();
172
173                        // Verify provider exists
174                        if !config.providers.contains_key(&alias_provider) {
175                            return Err(anyhow!(
176                                "Provider '{}' from alias not found in configuration",
177                                alias_provider
178                            ));
179                        }
180
181                        return Ok((alias_provider, alias_model));
182                    } else {
183                        return Err(anyhow!(
184                            "Invalid alias target format: '{}'. Expected 'provider:model'",
185                            alias_target
186                        ));
187                    }
188                }
189            }
190
191            m
192        }
193        None => config
194            .default_model
195            .clone()
196            .ok_or_else(|| anyhow!("No default model configured and none specified"))?,
197    };
198
199    Ok((provider, model))
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ProviderConfig;
    use std::collections::HashMap;

    /// Build a provider entry with fixed OpenAI-style paths; only the endpoint
    /// varies between callers.
    fn provider_config(endpoint: &str) -> ProviderConfig {
        ProviderConfig {
            endpoint: endpoint.to_string(),
            models_path: "/v1/models".to_string(),
            chat_path: "/v1/chat/completions".to_string(),
            images_path: Some("/images/generations".to_string()),
            embeddings_path: Some("/embeddings".to_string()),
            api_key: Some("key".to_string()),
            models: Vec::new(),
            headers: HashMap::new(),
            token_url: None,
            cached_token: None,
            auth_type: None,
            vars: HashMap::new(),
            chat_templates: None,
            images_templates: None,
            embeddings_templates: None,
            models_templates: None,
            audio_path: None,
            speech_path: None,
            audio_templates: None,
            speech_templates: None,
        }
    }

    /// Build a config containing the named providers, with `openai` as the
    /// default provider, `gpt-4` as the default model, and no aliases.
    fn config_with_providers(names: &[&str]) -> Config {
        Config {
            providers: names
                .iter()
                .map(|name| (name.to_string(), provider_config("https://api.openai.com")))
                .collect(),
            default_provider: Some("openai".to_string()),
            default_model: Some("gpt-4".to_string()),
            aliases: HashMap::new(),
            system_prompt: None,
            templates: HashMap::new(),
            max_tokens: None,
            temperature: None,
            stream: None,
        }
    }

    #[test]
    fn test_is_code_file() {
        assert!(is_code_file("rs"));
        assert!(is_code_file("py"));
        assert!(is_code_file("js"));
        assert!(is_code_file("json"));
        // Lookup is case-insensitive.
        assert!(is_code_file("RS"));

        assert!(!is_code_file("txt"));
        assert!(!is_code_file("md"));
        assert!(!is_code_file("pdf"));
    }

    #[test]
    fn test_debug_mode() {
        set_debug_mode(true);
        assert!(is_debug_mode());

        set_debug_mode(false);
        assert!(!is_debug_mode());
    }

    #[test]
    fn test_resolve_model_basic() {
        let config = config_with_providers(&["openai"]);
        let (provider, model) = resolve_model_and_provider(&config, None, None).unwrap();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4");
    }

    #[test]
    fn test_resolve_provider_model_shorthand() {
        let config = config_with_providers(&["openai", "anthropic"]);
        let resolved =
            resolve_model_and_provider(&config, None, Some("anthropic:claude-3".to_string()))
                .unwrap();
        assert_eq!(resolved, ("anthropic".to_string(), "claude-3".to_string()));
    }

    #[test]
    fn test_resolve_shorthand_unknown_provider() {
        let config = config_with_providers(&["openai"]);
        let err = resolve_model_and_provider(&config, None, Some("missing:model".to_string()))
            .unwrap_err();
        assert!(err.to_string().contains("'missing'"));
    }

    #[test]
    fn test_resolve_alias() {
        let mut config = config_with_providers(&["openai", "anthropic"]);
        config
            .aliases
            .insert("fast".to_string(), "anthropic:claude-3-haiku".to_string());
        let resolved =
            resolve_model_and_provider(&config, None, Some("fast".to_string())).unwrap();
        assert_eq!(
            resolved,
            ("anthropic".to_string(), "claude-3-haiku".to_string())
        );
    }

    #[test]
    fn test_resolve_alias_invalid_target() {
        let mut config = config_with_providers(&["openai"]);
        config
            .aliases
            .insert("broken".to_string(), "no-colon-here".to_string());
        let err =
            resolve_model_and_provider(&config, None, Some("broken".to_string())).unwrap_err();
        assert!(err.to_string().contains("Invalid alias target format"));
    }

    #[test]
    fn test_resolve_provider_override_skips_model_parsing() {
        let config = config_with_providers(&["openai", "anthropic"]);
        let resolved = resolve_model_and_provider(
            &config,
            Some("openai".to_string()),
            Some("llama3:8b".to_string()),
        )
        .unwrap();
        assert_eq!(resolved, ("openai".to_string(), "llama3:8b".to_string()));
    }

    #[test]
    fn test_resolve_unknown_provider_override() {
        let config = config_with_providers(&["openai"]);
        assert!(resolve_model_and_provider(&config, Some("missing".to_string()), None).is_err());
    }

    #[test]
    fn test_resolve_missing_default_model() {
        let mut config = config_with_providers(&["openai"]);
        config.default_model = None;
        assert!(resolve_model_and_provider(&config, None, None).is_err());
    }

    #[test]
    fn test_read_and_format_attachments_empty() {
        assert_eq!(read_and_format_attachments(&[]).unwrap(), "");
    }

    #[test]
    fn test_read_and_format_attachments_code_and_text() {
        // File names include the process id so parallel test runs do not collide.
        let pid = std::process::id();
        let code_name = format!("lc_cli_utils_{}_sample.rs", pid);
        let text_name = format!("lc_cli_utils_{}_notes.txt", pid);
        let code_path = std::env::temp_dir().join(&code_name);
        let text_path = std::env::temp_dir().join(&text_name);
        std::fs::write(&code_path, "fn main() {}").unwrap();
        std::fs::write(&text_path, "hello").unwrap();

        let attachments = vec![
            code_path.to_string_lossy().into_owned(),
            text_path.to_string_lossy().into_owned(),
        ];
        let formatted = read_and_format_attachments(&attachments);

        // Clean up before asserting so a failure does not leave files behind.
        let _ = std::fs::remove_file(&code_path);
        let _ = std::fs::remove_file(&text_path);

        let expected = format!(
            "=== File: {} ===\n```rs\nfn main() {{}}\n```\n\n=== File: {} ===\nhello\n",
            code_name, text_name
        );
        assert_eq!(formatted.unwrap(), expected);
    }

    #[test]
    fn test_read_and_format_attachments_missing_file() {
        let missing = std::env::temp_dir()
            .join(format!("lc_cli_utils_{}_does_not_exist.rs", std::process::id()))
            .to_string_lossy()
            .into_owned();
        let err = read_and_format_attachments(&[missing]).unwrap_err();
        assert!(err.to_string().contains("Failed to read file"));
    }
}