1use anyhow::{anyhow, Result};
4use std::collections::HashSet;
5use std::fs;
6use std::path::Path;
7use std::sync::atomic::{AtomicBool, Ordering};
8
9use crate::config::Config;
10
/// Process-wide debug flag behind [`set_debug_mode`] and [`is_debug_mode`].
/// Starts out disabled.
static DEBUG_MODE: AtomicBool = AtomicBool::new(false);

/// Switches debug mode on (`true`) or off (`false`) for the whole process.
pub fn set_debug_mode(enabled: bool) {
    // Relaxed: within this module the flag is not used to order any other memory accesses.
    let ordering = Ordering::Relaxed;
    DEBUG_MODE.store(enabled, ordering);
}

/// Reports whether debug mode is currently switched on.
pub fn is_debug_mode() -> bool {
    DEBUG_MODE.load(Ordering::Relaxed)
}
23
/// Returns `true` if `ext` (a file extension without the leading dot) names a
/// source-code, script, markup or structured-config file type.
///
/// The comparison ignores ASCII case, so `"RS"` and `"rs"` are equivalent.
/// A few extension-less file names (`dockerfile`, `makefile`) are also listed.
pub fn is_code_file(ext: &str) -> bool {
    // Kept as a constant slice: the set is small and fixed, so a linear scan is
    // cheap and avoids building (and hashing into) a fresh HashSet on every call.
    // Entries must stay lower-case ASCII for the case-insensitive match below.
    const CODE_EXTENSIONS: &[&str] = &[
        "rs",
        "py",
        "js",
        "ts",
        "java",
        "cpp",
        "c",
        "h",
        "hpp",
        "go",
        "rb",
        "php",
        "swift",
        "kt",
        "scala",
        "sh",
        "bash",
        "zsh",
        "fish",
        "ps1",
        "bat",
        "cmd",
        "html",
        "css",
        "scss",
        "sass",
        "less",
        "xml",
        "json",
        "yaml",
        "yml",
        "toml",
        "ini",
        "cfg",
        "conf",
        "sql",
        "r",
        "m",
        "mm",
        "pl",
        "pm",
        "lua",
        "vim",
        "dockerfile",
        "makefile",
        "cmake",
        "gradle",
        "maven",
    ];

    // `eq_ignore_ascii_case` compares in place, with no lower-cased copy of `ext`.
    CODE_EXTENSIONS
        .iter()
        .any(|known| known.eq_ignore_ascii_case(ext))
}
82
83pub fn read_and_format_attachments(attachments: &[String]) -> Result<String> {
85 if attachments.is_empty() {
86 return Ok(String::new());
87 }
88
89 let mut result = String::new();
90
91 for attachment_path in attachments {
92 let path = Path::new(attachment_path);
93 let filename = path
94 .file_name()
95 .and_then(|n| n.to_str())
96 .unwrap_or("unknown");
97
98 let content = fs::read_to_string(path)
100 .map_err(|e| anyhow!("Failed to read file '{}': {}", attachment_path, e))?;
101
102 result.push_str(&format!("=== File: {} ===\n", filename));
104
105 if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
107 if is_code_file(ext) {
108 result.push_str(&format!("```{}\n{}\n```\n", ext.to_lowercase(), content));
109 } else {
110 result.push_str(&content);
111 }
112 } else {
113 result.push_str(&content);
114 }
115
116 result.push('\n');
117 }
118
119 Ok(result)
120}
121
122pub fn resolve_model_and_provider(
124 config: &Config,
125 provider_override: Option<String>,
126 model_override: Option<String>,
127) -> Result<(String, String)> {
128 let has_provider_override = provider_override.is_some();
130
131 let provider = match provider_override {
132 Some(p) => {
133 if !config.providers.contains_key(&p) {
134 return Err(anyhow!("Provider '{}' not found in configuration", p));
135 }
136 p
137 }
138 None => config
139 .default_provider
140 .clone()
141 .ok_or_else(|| anyhow!("No default provider configured and none specified"))?,
142 };
143
144 let model = match model_override {
145 Some(m) => {
146 if m.contains(':') && !has_provider_override {
148 let parts: Vec<&str> = m.splitn(2, ':').collect();
149 if parts.len() == 2 {
150 let alias_provider = parts[0].to_string();
151 let alias_model = parts[1].to_string();
152
153 if !config.providers.contains_key(&alias_provider) {
155 return Err(anyhow!(
156 "Provider '{}' not found in configuration",
157 alias_provider
158 ));
159 }
160
161 return Ok((alias_provider, alias_model));
162 }
163 }
164
165 if !has_provider_override {
167 if let Some(alias_target) = config.aliases.get(&m) {
168 let parts: Vec<&str> = alias_target.splitn(2, ':').collect();
169 if parts.len() == 2 {
170 let alias_provider = parts[0].to_string();
171 let alias_model = parts[1].to_string();
172
173 if !config.providers.contains_key(&alias_provider) {
175 return Err(anyhow!(
176 "Provider '{}' from alias not found in configuration",
177 alias_provider
178 ));
179 }
180
181 return Ok((alias_provider, alias_model));
182 } else {
183 return Err(anyhow!(
184 "Invalid alias target format: '{}'. Expected 'provider:model'",
185 alias_target
186 ));
187 }
188 }
189 }
190
191 m
192 }
193 None => config
194 .default_model
195 .clone()
196 .ok_or_else(|| anyhow!("No default model configured and none specified"))?,
197 };
198
199 Ok((provider, model))
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ProviderConfig;
    use std::collections::HashMap;

    /// Minimal provider entry. The resolver tests only care that the key exists.
    fn test_provider() -> ProviderConfig {
        ProviderConfig {
            endpoint: "https://api.openai.com".to_string(),
            models_path: "/v1/models".to_string(),
            chat_path: "/v1/chat/completions".to_string(),
            images_path: Some("/images/generations".to_string()),
            embeddings_path: Some("/embeddings".to_string()),
            api_key: Some("key".to_string()),
            models: Vec::new(),
            headers: HashMap::new(),
            token_url: None,
            cached_token: None,
            auth_type: None,
            vars: HashMap::new(),
            chat_templates: None,
            images_templates: None,
            embeddings_templates: None,
            models_templates: None,
            audio_path: None,
            speech_path: None,
            audio_templates: None,
            speech_templates: None,
        }
    }

    /// Config with a single "openai" provider and defaults of openai / gpt-4.
    fn test_config() -> Config {
        let mut config = Config {
            providers: HashMap::new(),
            default_provider: Some("openai".to_string()),
            default_model: Some("gpt-4".to_string()),
            aliases: HashMap::new(),
            system_prompt: None,
            templates: HashMap::new(),
            max_tokens: None,
            temperature: None,
            stream: None,
        };
        config
            .providers
            .insert("openai".to_string(), test_provider());
        config
    }

    #[test]
    fn test_is_code_file() {
        assert!(is_code_file("rs"));
        assert!(is_code_file("py"));
        assert!(is_code_file("js"));
        assert!(is_code_file("json"));
        // Matching ignores case.
        assert!(is_code_file("RS"));
        assert!(is_code_file("Py"));

        assert!(!is_code_file("txt"));
        assert!(!is_code_file("md"));
        assert!(!is_code_file("pdf"));
        assert!(!is_code_file(""));
    }

    #[test]
    fn test_debug_mode() {
        set_debug_mode(true);
        assert!(is_debug_mode());

        set_debug_mode(false);
        assert!(!is_debug_mode());
    }

    #[test]
    fn test_resolve_model_basic() {
        let config = test_config();

        let (provider, model) = resolve_model_and_provider(&config, None, None).unwrap();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4");
    }

    #[test]
    fn test_resolve_model_provider_prefix() {
        let config = test_config();

        let (provider, model) =
            resolve_model_and_provider(&config, None, Some("openai:gpt-4o".to_string()))
                .unwrap();
        assert_eq!(provider, "openai");
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn test_resolve_model_unknown_provider() {
        let config = test_config();

        assert!(resolve_model_and_provider(&config, Some("missing".to_string()), None).is_err());
        assert!(
            resolve_model_and_provider(&config, None, Some("missing:model".to_string()))
                .is_err()
        );
    }

    #[test]
    fn test_resolve_model_explicit_provider_keeps_colon_model() {
        let config = test_config();

        // With an explicit provider, a ':' in the model name is not a provider prefix.
        let (provider, model) = resolve_model_and_provider(
            &config,
            Some("openai".to_string()),
            Some("llama3:8b".to_string()),
        )
        .unwrap();
        assert_eq!(provider, "openai");
        assert_eq!(model, "llama3:8b");
    }
}