1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7pub use crate::learn_prompt::SYSTEM_PROMPT;
8
/// Upper bound on a `--hint` value, compared against `str::len()` (i.e. bytes).
/// Longer hints are rejected before any request is made.
const MAX_HINT_LENGTH: usize = 1000;

/// Limit passed to `truncate_utf8` when the command is embedded in the prompt.
const MAX_COMMAND_LENGTH: usize = 100;

/// Maximum number of characters kept from each command word when building a
/// pattern file name in `label`.
const MAX_FILENAME_COMPONENT: usize = 50;
21
/// Shape of `config.toml` as far as this module is concerned: only the
/// optional `learn` table is declared here.
#[derive(Deserialize)]
struct ConfigFile {
    /// Contents of the `learn` table, or `None` when the file does not have one.
    learn: Option<LearnConfig>,
}
26
/// Settings that select the LLM backend used by `oo learn`.
#[derive(Deserialize, Clone)]
pub struct LearnConfig {
    /// Provider identifier; `run_learn_with_config` only recognises `"anthropic"`.
    pub provider: String,

    /// Model name sent in the request body.
    pub model: String,

    /// Name of the environment variable that holds the API key (not the key itself).
    pub api_key_env: String,
}
41
/// Everything `run_learn_with_config` needs besides the captured command,
/// output and exit code. All fields are borrowed from the caller.
pub(crate) struct LearnParams<'a> {
    /// Provider and model selection.
    pub config: &'a LearnConfig,
    /// API key value, sent as the `x-api-key` header.
    pub api_key: &'a str,
    /// Endpoint URL the request is POSTed to.
    pub base_url: &'a str,
    /// Directory the generated `<label>.toml` pattern file is written into.
    pub patterns_dir: &'a Path,
    /// File handed to `write_learn_status` after a pattern has been saved.
    pub learn_status_path: &'a Path,
    /// Optional user-supplied hint that is added to the prompt.
    pub hint: Option<&'a str>,
}
51
/// JSON payload that `spawn_background` writes to a temp file and
/// `run_background` reads back.
#[derive(Deserialize)]
struct LearnData {
    /// The command line that was run.
    command: String,
    /// Captured output of that command.
    output: String,
    /// Exit code as stored in the JSON; `run_background` narrows it to `i32`.
    exit_code: i64,
    /// Optional hint; the key is only present when a hint was supplied.
    hint: Option<String>,
}
60
/// The default configuration is obtained from provider auto-detection rather
/// than from fixed field values.
impl Default for LearnConfig {
    /// Returns whatever `detect_provider` selects.
    fn default() -> Self {
        detect_provider()
    }
}
66
67fn detect_provider() -> LearnConfig {
69 detect_provider_with(|key| std::env::var(key).ok())
70}
71
72fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
74 if env_lookup("ANTHROPIC_API_KEY").is_some() {
75 LearnConfig {
76 provider: "anthropic".into(),
77 model: "claude-haiku-4-5".into(),
78 api_key_env: "ANTHROPIC_API_KEY".into(),
79 }
80 } else {
81 LearnConfig {
83 provider: "anthropic".into(),
84 model: "claude-haiku-4-5".into(),
85 api_key_env: "ANTHROPIC_API_KEY".into(),
86 }
87 }
88}
89
90fn config_dir() -> PathBuf {
91 if let Some(test_dir) = std::env::var_os("OO_CONFIG_DIR") {
92 return PathBuf::from(test_dir);
93 }
94 dirs::config_dir()
95 .unwrap_or_else(|| PathBuf::from("/tmp"))
96 .join("oo")
97}
98
99pub fn patterns_dir() -> PathBuf {
103 config_dir().join("patterns")
104}
105
106pub fn learn_status_path() -> PathBuf {
108 config_dir().join("learn-status.log")
109}
110
111pub fn load_learn_config() -> Result<LearnConfig, Error> {
115 let path = config_dir().join("config.toml");
116 if !path.exists() {
117 return Ok(LearnConfig::default());
118 }
119 let content = std::fs::read_to_string(&path)
120 .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
121 let cf: ConfigFile =
122 toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
123 Ok(cf.learn.unwrap_or_default())
124}
125
126pub(crate) fn run_learn_with_config(
133 params: &LearnParams,
134 command: &str,
135 output: &str,
136 exit_code: i32,
137) -> Result<(), Error> {
138 let hint = match params.hint {
139 Some(h) if h.len() > MAX_HINT_LENGTH => {
140 return Err(Error::Learn(format!(
141 "--hint too long ({} > {} chars)",
142 h.len(),
143 MAX_HINT_LENGTH
144 )));
145 }
146 h => h,
147 };
148
149 let truncated_command = crate::learn_utils::truncate_utf8(command, MAX_COMMAND_LENGTH);
150
151 let user_msg = if let Some(h) = hint {
152 format!(
153 "Command: {truncated_command}\nExit code: {exit_code}\nHint: {h}\nOutput:\n{}",
154 truncate_for_prompt(output)
155 )
156 } else {
157 format!(
158 "Command: {truncated_command}\nExit code: {exit_code}\nOutput:\n{}",
159 truncate_for_prompt(output)
160 )
161 };
162
163 let get_response = |msg: &str| -> Result<String, Error> {
164 match params.config.provider.as_str() {
165 "anthropic" => {
166 call_anthropic(params.base_url, params.api_key, ¶ms.config.model, msg)
167 }
168 other => Err(Error::Learn(format!("unknown provider: {other}"))),
169 }
170 };
171
172 let mut last_err;
174 let toml = get_response(&user_msg)?;
175 let clean = crate::learn_utils::strip_fences(&toml);
176 if validate_pattern_toml_with_limits(&clean).is_ok() {
177 std::fs::create_dir_all(params.patterns_dir).map_err(|e| Error::Learn(e.to_string()))?;
178 let filename = format!("{}.toml", label(command));
179 let path = params.patterns_dir.join(&filename);
180 std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
181 let _ =
182 crate::commands::write_learn_status(params.learn_status_path, &label(command), &path);
183 return Ok(());
184 }
185 last_err = "initial TOML validation failed".to_string();
186
187 for _ in 0..2 {
189 let retry_msg = format!(
190 "Your previous TOML was invalid: {last_err}. Here is what you returned:\n{clean}\nOutput ONLY the corrected TOML, nothing else."
191 );
192 let toml = get_response(&retry_msg)?;
193 let clean = crate::learn_utils::strip_fences(&toml);
194 if validate_pattern_toml_with_limits(&clean).is_ok() {
195 std::fs::create_dir_all(params.patterns_dir)
196 .map_err(|e| Error::Learn(e.to_string()))?;
197 let filename = format!("{}.toml", label(command));
198 let path = params.patterns_dir.join(&filename);
199 std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
200 let _ = crate::commands::write_learn_status(
201 params.learn_status_path,
202 &label(command),
203 &path,
204 );
205 return Ok(());
206 }
207 last_err = "retry TOML validation failed".to_string();
208 }
209
210 Err(Error::Learn(format!("failed after 3 attempts: {last_err}")))
211}
212
213pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
219 run_learn_with_hint(command, output, exit_code, None)
220}
221
222fn run_learn_with_hint(
226 command: &str,
227 output: &str,
228 exit_code: i32,
229 hint: Option<&str>,
230) -> Result<(), Error> {
231 let config = load_learn_config()?;
232
233 let api_key = std::env::var(&config.api_key_env).map_err(|_| {
234 Error::Learn(format!(
235 "Set {} environment variable to use `oo learn`",
236 config.api_key_env
237 ))
238 })?;
239
240 let base_url = std::env::var("ANTHROPIC_API_URL")
241 .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
242
243 validate_anthropic_url(&base_url)?;
244
245 let params = LearnParams {
246 config: &config,
247 api_key: &api_key,
248 base_url: &base_url,
249 patterns_dir: &patterns_dir(),
250 learn_status_path: &learn_status_path(),
251 hint,
252 };
253
254 run_learn_with_config(¶ms, command, output, exit_code)
255}
256
257pub fn spawn_background(
259 command: &str,
260 output: &str,
261 exit_code: i32,
262 hint: Option<&str>,
263) -> Result<(), Error> {
264 let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
265
266 let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
269 let mut data = serde_json::json!({
270 "command": command,
271 "output": output,
272 "exit_code": exit_code,
273 });
274 if let Some(h) = hint {
275 data["hint"] = serde_json::Value::String(h.to_string());
276 }
277 tmp.write_all(data.to_string().as_bytes())
278 .map_err(|e| Error::Learn(e.to_string()))?;
279
280 let tmp_path = tmp.into_temp_path();
283
284 std::process::Command::new(exe)
286 .arg("_learn_bg")
287 .arg(&tmp_path)
288 .stdin(std::process::Stdio::null())
289 .stdout(std::process::Stdio::null())
290 .stderr(std::process::Stdio::null())
291 .spawn()
292 .map_err(|e| Error::Learn(e.to_string()))?;
293
294 tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
299
300 Ok(())
301}
302
303pub fn run_background(data_path: &str) -> Result<(), Error> {
305 let path = Path::new(data_path);
306 let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
307 let data: LearnData =
308 serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
309
310 let command = &data.command;
311 let output = &data.output;
312
313 let exit_code = i32::try_from(data.exit_code).map_err(|_| {
315 Error::Learn(format!(
316 "exit_code out of range for i32: {}",
317 data.exit_code
318 ))
319 })?;
320
321 let hint = data.hint.as_deref();
322
323 let result = run_learn_with_hint(command, output, exit_code, hint);
324
325 let _ = std::fs::remove_file(path);
327
328 if let Err(ref e) = result {
329 let cmd_label = label(command);
330 let status_path = learn_status_path();
331 let _ =
332 crate::commands::write_learn_status_failure(&status_path, &cmd_label, &e.to_string());
333 }
334
335 result
336}
337
338fn call_anthropic(
343 base_url: &str,
344 api_key: &str,
345 model: &str,
346 user_msg: &str,
347) -> Result<String, Error> {
348 let body = serde_json::json!({
349 "model": model,
350 "max_tokens": 1024,
351 "temperature": 0.0,
352 "system": SYSTEM_PROMPT,
353 "messages": [{"role": "user", "content": user_msg}],
354 });
355
356 use std::time::Duration;
357
358 let agent: ureq::Agent = ureq::Agent::config_builder()
360 .timeout_global(Some(Duration::from_secs(30)))
361 .timeout_connect(Some(Duration::from_secs(10)))
362 .build()
363 .into();
364
365 let response: serde_json::Value = agent
366 .post(base_url)
367 .header("x-api-key", api_key)
368 .header("anthropic-version", "2023-06-01")
369 .header("content-type", "application/json")
370 .send_json(&body)
371 .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
372 .body_mut()
373 .read_json()
374 .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
375
376 response["content"][0]["text"]
377 .as_str()
378 .map(|s| s.to_string())
379 .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
380}
381
382fn label(command: &str) -> String {
387 let mut words = command.split_whitespace();
388 let first = words
389 .next()
390 .unwrap_or("unknown")
391 .rsplit('/')
392 .next()
393 .unwrap_or("unknown");
394
395 let sanitized_first: String = first
398 .chars()
399 .filter(|c| c.is_ascii_alphanumeric() || *c == '-')
400 .take(MAX_FILENAME_COMPONENT)
401 .collect();
402
403 if sanitized_first.is_empty() {
404 return "unknown".to_string();
405 }
406
407 match words.next() {
409 Some(second) if !second.starts_with('-') => {
410 let sanitized_second: String = second
413 .chars()
414 .filter(|c| c.is_ascii_alphanumeric() || *c == '-')
415 .take(MAX_FILENAME_COMPONENT)
416 .collect();
417 if sanitized_second.is_empty() {
418 sanitized_first
419 } else {
420 format!("{sanitized_first}-{sanitized_second}")
421 }
422 }
423 _ => sanitized_first,
424 }
425}
426
427fn truncate_for_prompt(output: &str) -> &str {
428 crate::learn_utils::truncate_utf8(output, 4000)
429}
430
431fn validate_anthropic_url(url: &str) -> Result<(), Error> {
433 if url.starts_with("https://") {
434 return Ok(());
435 }
436 if let Some(rest) = url.strip_prefix("http://") {
439 let host = rest.split([':', '/']).next().unwrap_or("");
440 if host == "localhost" || host == "127.0.0.1" {
441 return Ok(());
442 }
443 }
444 Err(Error::Learn(format!(
445 "ANTHROPIC_API_URL must use HTTPS (got: {url}). HTTP is only allowed for localhost/127.0.0.1."
446 )))
447}
448
449fn validate_pattern_toml_with_limits(toml_str: &str) -> Result<(), Error> {
453 crate::pattern::validate_pattern_regexes(toml_str)
454 .map_err(|e| Error::Learn(format!("pattern validation: {e}")))
455}
456
457#[cfg(test)]
459#[path = "learn_tests.rs"]
460mod tests;
461
462#[cfg(test)]
463#[path = "learn_prompt_tests.rs"]
464mod prompt_tests;