Skip to main content

double_o/
learn.rs

1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7use crate::pattern::FailureSection;
8
9// ---------------------------------------------------------------------------
10// Config
11// ---------------------------------------------------------------------------
12
/// Deserialization target for `config.toml`; only the `learn` table is
/// declared here.
#[derive(Deserialize)]
struct ConfigFile {
    /// Contents of the `learn` table, or `None` when the file does not have one.
    learn: Option<LearnConfig>,
}
17
18#[derive(Deserialize, Clone)]
19pub struct LearnConfig {
20    pub provider: String,
21    pub model: String,
22    pub api_key_env: String,
23}
24
/// Fully resolved inputs for one learn run.
///
/// Bundles the configuration together with every value that would otherwise
/// be read from the environment or derived from the config directory, so the
/// learn flow can be exercised in tests without mutating env vars.
pub(crate) struct LearnParams<'a> {
    /// Provider and model selection.
    pub config: &'a LearnConfig,
    /// API key value, sent in the `x-api-key` request header.
    pub api_key: &'a str,
    /// URL the API request is POSTed to.
    pub base_url: &'a str,
    /// Directory a validated pattern file is written into (created on demand).
    pub patterns_dir: &'a Path,
    /// Status file path handed to `write_learn_status` after a pattern is saved.
    pub learn_status_path: &'a Path,
}
33
impl Default for LearnConfig {
    /// Returns the configuration produced by `detect_provider()`, which
    /// consults the process environment — so this default is not a constant.
    fn default() -> Self {
        detect_provider()
    }
}
39
40// Auto-detect provider from available API keys (checked in priority order).
41fn detect_provider() -> LearnConfig {
42    detect_provider_with(|key| std::env::var(key).ok())
43}
44
45// Testable variant — accepts a closure for env lookup to avoid env mutation in tests.
46fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
47    if env_lookup("ANTHROPIC_API_KEY").is_some() {
48        LearnConfig {
49            provider: "anthropic".into(),
50            model: "claude-haiku-4-5".into(),
51            api_key_env: "ANTHROPIC_API_KEY".into(),
52        }
53    } else {
54        // Default to anthropic; will fail at runtime if no key is set.
55        LearnConfig {
56            provider: "anthropic".into(),
57            model: "claude-haiku-4-5".into(),
58            api_key_env: "ANTHROPIC_API_KEY".into(),
59        }
60    }
61}
62
63fn config_dir() -> PathBuf {
64    if let Some(test_dir) = std::env::var_os("OO_CONFIG_DIR") {
65        return PathBuf::from(test_dir);
66    }
67    dirs::config_dir()
68        .unwrap_or_else(|| PathBuf::from("/tmp"))
69        .join("oo")
70}
71
72pub fn patterns_dir() -> PathBuf {
73    config_dir().join("patterns")
74}
75
76/// Path to the one-line status file written by the background learn process.
77pub fn learn_status_path() -> PathBuf {
78    config_dir().join("learn-status.log")
79}
80
81pub fn load_learn_config() -> Result<LearnConfig, Error> {
82    let path = config_dir().join("config.toml");
83    if !path.exists() {
84        return Ok(LearnConfig::default());
85    }
86    let content = std::fs::read_to_string(&path)
87        .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
88    let cf: ConfigFile =
89        toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
90    Ok(cf.learn.unwrap_or_default())
91}
92
93// ---------------------------------------------------------------------------
94// Background learning
95// ---------------------------------------------------------------------------
96
/// System prompt sent as the `system` field of every Anthropic request made by
/// the learn flow. It tells the model how to write a pattern TOML document
/// (`command_match`, optional `[success]` and `[failure]` sections).
///
/// This text is runtime data: any edit changes what the model is asked to do.
const SYSTEM_PROMPT: &str = r#"You generate output classification patterns for `oo`, a shell command runner used by an LLM coding agent.

The agent reads your pattern to decide its next action. Returning nothing is the WORST outcome — an empty summary forces a costly recall cycle.

IMPORTANT: Use named capture groups (?P<name>...) only — never numbered groups like (\d+). Summary templates use {name} placeholders matching the named groups.

## oo's 4-tier system

- Passthrough: output <4 KB passes through unchanged
- Failure: failed commands get ✗ prefix with filtered error output
- Success: successful commands get ✓ prefix with a pattern-extracted summary
- Large: if regex fails to match, output is FTS5 indexed for recall

## Examples

Test runner — capture RESULT line, not header; strategy=tail for failures:
    command_match = "\\bcargo\\s+test\\b"
    [success]
    pattern = 'test result: ok\. (?P<passed>\d+) passed.*finished in (?P<time>[\d.]+)s'
    summary = "{passed} passed, {time}s"
    [failure]
    strategy = "tail"
    lines = 30

Build/lint — quiet on success (only useful when failing); strategy=head for failures:
    command_match = "\\bcargo\\s+build\\b"
    [success]
    pattern = "(?s).*"
    summary = ""
    [failure]
    strategy = "head"
    lines = 20

## Rules

- Test runners: capture SUMMARY line (e.g. 'test result: ok. 5 passed'), NOT headers (e.g. 'running 5 tests')
- Build/lint tools: empty summary for success; head/lines=20 for failures
- Large tabular output (ls, git log): omit success section — falls to Large tier"#;
135
136/// Run the learn flow with explicit config and base URL — testable variant.
137///
138/// This internal function bypasses `load_learn_config()` and env var lookup,
139/// making it suitable for testing without environment mutation.
140pub(crate) fn run_learn_with_config(
141    params: &LearnParams,
142    command: &str,
143    output: &str,
144    exit_code: i32,
145) -> Result<(), Error> {
146    let user_msg = format!(
147        "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
148        truncate_for_prompt(output)
149    );
150
151    let get_response = |msg: &str| -> Result<String, Error> {
152        match params.config.provider.as_str() {
153            "anthropic" => {
154                call_anthropic(params.base_url, params.api_key, &params.config.model, msg)
155            }
156            other => Err(Error::Learn(format!("unknown provider: {other}"))),
157        }
158    };
159
160    // First attempt
161    let mut last_err;
162    let toml = get_response(&user_msg)?;
163    let clean = strip_fences(&toml);
164    match validate_pattern_toml(&clean) {
165        Ok(()) => {
166            std::fs::create_dir_all(params.patterns_dir)
167                .map_err(|e| Error::Learn(e.to_string()))?;
168            let filename = format!("{}.toml", label(command));
169            let path = params.patterns_dir.join(&filename);
170            std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
171            let _ = crate::commands::write_learn_status(
172                params.learn_status_path,
173                &label(command),
174                &path,
175            );
176            return Ok(());
177        }
178        Err(e) => last_err = e,
179    }
180
181    // Up to 2 retries
182    for _ in 0..2 {
183        let retry_msg = format!(
184            "Your previous TOML was invalid: {last_err}. Here is what you returned:\n{clean}\nOutput ONLY the corrected TOML, nothing else."
185        );
186        let toml = get_response(&retry_msg)?;
187        let clean = strip_fences(&toml);
188        match validate_pattern_toml(&clean) {
189            Ok(()) => {
190                std::fs::create_dir_all(params.patterns_dir)
191                    .map_err(|e| Error::Learn(e.to_string()))?;
192                let filename = format!("{}.toml", label(command));
193                let path = params.patterns_dir.join(&filename);
194                std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
195                let _ = crate::commands::write_learn_status(
196                    params.learn_status_path,
197                    &label(command),
198                    &path,
199                );
200                return Ok(());
201            }
202            Err(e) => last_err = e,
203        }
204    }
205
206    Err(Error::Learn(format!("failed after 3 attempts: {last_err}")))
207}
208
209/// Run the learn flow: call LLM, validate + save pattern.
210pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
211    let config = load_learn_config()?;
212
213    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
214        Error::Learn(format!(
215            "Set {} environment variable to use `oo learn`",
216            config.api_key_env
217        ))
218    })?;
219
220    let base_url = std::env::var("ANTHROPIC_API_URL")
221        .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
222
223    let params = LearnParams {
224        config: &config,
225        api_key: &api_key,
226        base_url: &base_url,
227        patterns_dir: &patterns_dir(),
228        learn_status_path: &learn_status_path(),
229    };
230
231    run_learn_with_config(&params, command, output, exit_code)
232}
233
234/// Spawn the learning process in the background by re-exec'ing ourselves.
235pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
236    let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
237
238    // Use a secure named temp file to avoid PID-based predictable filenames
239    // (symlink/TOCTOU attacks). The file is kept alive until the child spawns.
240    let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
241    let data = serde_json::json!({
242        "command": command,
243        "output": output,
244        "exit_code": exit_code,
245    });
246    tmp.write_all(data.to_string().as_bytes())
247        .map_err(|e| Error::Learn(e.to_string()))?;
248
249    // Convert to TempPath: closes the file handle but keeps the file on disk
250    // until the TempPath is dropped — after the child has been spawned.
251    let tmp_path = tmp.into_temp_path();
252
253    // Spawn detached child
254    std::process::Command::new(exe)
255        .arg("_learn_bg")
256        .arg(&tmp_path)
257        .stdin(std::process::Stdio::null())
258        .stdout(std::process::Stdio::null())
259        .stderr(std::process::Stdio::null())
260        .spawn()
261        .map_err(|e| Error::Learn(e.to_string()))?;
262
263    // Prevent the parent from deleting the temp file on drop. On a loaded
264    // system the child process may not have opened the file yet by the time
265    // the parent exits this function. `keep()` makes the file persist on disk
266    // until the child cleans it up at run_background (line ~218 below).
267    tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
268
269    Ok(())
270}
271
272/// Entry point for the background learn child process.
273pub fn run_background(data_path: &str) -> Result<(), Error> {
274    let path = Path::new(data_path);
275    let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
276    let data: serde_json::Value =
277        serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
278
279    let command = data["command"].as_str().unwrap_or("");
280    let output = data["output"].as_str().unwrap_or("");
281    let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
282
283    let result = run_learn(command, output, exit_code);
284
285    // Clean up temp file
286    let _ = std::fs::remove_file(path);
287
288    if let Err(ref e) = result {
289        let cmd_label = label(command);
290        let status_path = learn_status_path();
291        let _ =
292            crate::commands::write_learn_status_failure(&status_path, &cmd_label, &e.to_string());
293    }
294
295    result
296}
297
298// ---------------------------------------------------------------------------
299// LLM API calls
300// ---------------------------------------------------------------------------
301
302fn call_anthropic(
303    base_url: &str,
304    api_key: &str,
305    model: &str,
306    user_msg: &str,
307) -> Result<String, Error> {
308    let body = serde_json::json!({
309        "model": model,
310        "max_tokens": 1024,
311        "temperature": 0.0,
312        "system": SYSTEM_PROMPT,
313        "messages": [{"role": "user", "content": user_msg}],
314    });
315
316    let response: serde_json::Value = ureq::post(base_url)
317        .header("x-api-key", api_key)
318        .header("anthropic-version", "2023-06-01")
319        .header("content-type", "application/json")
320        .send_json(&body)
321        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
322        .body_mut()
323        .read_json()
324        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
325
326    response["content"][0]["text"]
327        .as_str()
328        .map(|s| s.to_string())
329        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
330}
331
332// ---------------------------------------------------------------------------
333// Helpers
334// ---------------------------------------------------------------------------
335
// Derive a short label for a command line; it is used as the stem of the
// pattern file name and in status messages.
//
// The label is the basename of the first word (text after the last `/`),
// followed by `-<second word>` when the second word is a subcommand rather
// than a flag. Examples: `cargo test` -> `cargo-test`,
// `/usr/bin/ls -la` -> `ls`.
//
// An empty command, or a first word with an empty basename (e.g. `bin/`),
// yields `unknown`, so the label is never empty.
fn label(command: &str) -> String {
    let mut words = command.split_whitespace();
    let first = words
        .next()
        .and_then(|word| word.rsplit('/').next())
        // A trailing slash leaves an empty basename; treat it as missing.
        .filter(|basename| !basename.is_empty())
        .unwrap_or("unknown");
    // Include the second word only when it is a subcommand (not a flag).
    match words.next() {
        Some(second) if !second.starts_with('-') => {
            // Sanitize: keep only ASCII alphanumeric and hyphens to ensure
            // the label is safe as a filename component.
            let sanitized: String = second
                .chars()
                .filter(|c| c.is_ascii_alphanumeric() || *c == '-')
                .collect();
            if sanitized.is_empty() {
                first.to_string()
            } else {
                format!("{first}-{sanitized}")
            }
        }
        _ => first.to_string(),
    }
}
362
363fn truncate_for_prompt(output: &str) -> &str {
364    truncate_utf8(output, 4000)
365}
366
// Return the longest prefix of `s` that is at most `max_bytes` bytes long and
// ends on a char boundary, so slicing never panics in the middle of a
// multibyte UTF-8 sequence. Strings already within the limit are returned
// unchanged.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Scan backwards from the limit; index 0 is always a boundary, so the
    // search cannot fail.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
378
// Remove a surrounding Markdown code fence from an LLM reply.
//
// After trimming whitespace, a leading "```toml" or bare "```" is dropped,
// then a trailing "```" if present, and the remainder is trimmed again. Text
// that does not start with a fence is returned trimmed but otherwise
// untouched (a trailing fence alone is not removed).
fn strip_fences(s: &str) -> String {
    const FENCE: &str = "```";

    let text = s.trim();
    // Try the tagged opener first so the "toml" tag is not left in the body.
    let opened = text
        .strip_prefix("```toml")
        .or_else(|| text.strip_prefix(FENCE));

    match opened {
        Some(body) => body.strip_suffix(FENCE).unwrap_or(body).trim().to_string(),
        None => text.to_string(),
    }
}
389
390fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
391    // Try to parse as our pattern format
392    #[derive(Deserialize)]
393    struct Check {
394        command_match: String,
395        // Deserialization target: field must exist for TOML parsing even if not read in code
396        #[allow(dead_code)] // used only for TOML deserialization validation
397        success: Option<SuccessCheck>,
398        failure: Option<FailureSection>,
399    }
400    #[derive(Deserialize)]
401    struct SuccessCheck {
402        pattern: String,
403        // Deserialization target: field must exist for TOML parsing even if not read in code
404        #[allow(dead_code)] // used only for TOML deserialization validation
405        summary: String,
406    }
407
408    let check: Check =
409        toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;
410
411    // Verify regexes compile
412    regex::Regex::new(&check.command_match)
413        .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;
414
415    if let Some(s) = &check.success {
416        regex::Regex::new(&s.pattern)
417            .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
418    }
419
420    if let Some(f) = &check.failure {
421        match f.strategy.as_deref().unwrap_or("tail") {
422            "grep" => {
423                let pat = f.grep_pattern.as_deref().ok_or_else(|| {
424                    Error::Learn("failure grep strategy requires a 'grep' field".into())
425                })?;
426                if pat.is_empty() {
427                    return Err(Error::Learn("failure grep regex must not be empty".into()));
428                }
429                regex::Regex::new(pat)
430                    .map_err(|e| Error::Learn(format!("invalid failure grep regex: {e}")))?;
431            }
432            "between" => {
433                let start = f.start.as_deref().ok_or_else(|| {
434                    Error::Learn("between strategy requires 'start' field".into())
435                })?;
436                if start.is_empty() {
437                    return Err(Error::Learn("between 'start' must not be empty".into()));
438                }
439                regex::Regex::new(start)
440                    .map_err(|e| Error::Learn(format!("invalid start regex: {e}")))?;
441                let end = f
442                    .end
443                    .as_deref()
444                    .ok_or_else(|| Error::Learn("between strategy requires 'end' field".into()))?;
445                if end.is_empty() {
446                    return Err(Error::Learn("between 'end' must not be empty".into()));
447                }
448                regex::Regex::new(end)
449                    .map_err(|e| Error::Learn(format!("invalid end regex: {e}")))?;
450            }
451            "tail" | "head" => {} // no regex to validate
452            other => {
453                return Err(Error::Learn(format!("unknown failure strategy: {other}")));
454            }
455        }
456    }
457
458    Ok(())
459}
460
461// Tests live in separate files to keep this module under 500 lines.
462#[cfg(test)]
463#[path = "learn_tests.rs"]
464mod tests;
465
466#[cfg(test)]
467#[path = "learn_prompt_tests.rs"]
468mod prompt_tests;