Skip to main content

double_o/
learn.rs

1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7use crate::pattern::FailureSection;
8
9// ---------------------------------------------------------------------------
10// Config
11// ---------------------------------------------------------------------------
12
/// On-disk shape of `config.toml`; only the optional `[learn]` table is read here.
#[derive(Deserialize)]
struct ConfigFile {
    /// The `[learn]` table, or `None` when the section is absent.
    learn: Option<LearnConfig>,
}
17
18/// Configuration for the `oo learn` LLM integration.
19///
20/// Specifies the LLM provider, model, and environment variable for the API key.
21#[derive(Deserialize, Clone)]
22pub struct LearnConfig {
23    /// LLM provider name (currently only "anthropic" is supported).
24    pub provider: String,
25
26    /// Model identifier (e.g., "claude-haiku-4-5").
27    pub model: String,
28
29    /// Environment variable containing the API key.
30    pub api_key_env: String,
31}
32
/// Testable variant of learn config and paths — avoids env var mutation.
///
/// Bundles every input `run_learn_with_config` needs, so a caller (e.g. a test) can
/// supply the endpoint and directories directly instead of going through environment
/// variables.
///
/// Do not derive `Debug` here: `api_key` holds the secret key value itself.
pub(crate) struct LearnParams<'a> {
    /// Provider and model to use.
    pub config: &'a LearnConfig,
    /// The API key value (not the name of the variable holding it).
    pub api_key: &'a str,
    /// Full endpoint URL the request is POSTed to.
    pub base_url: &'a str,
    /// Directory the validated pattern file is written into.
    pub patterns_dir: &'a Path,
    /// File that receives the learn status line after a pattern is saved.
    pub learn_status_path: &'a Path,
}
41
/// The default config is chosen by `detect_provider`, which inspects API-key environment
/// variables; at present it always yields the Anthropic `claude-haiku-4-5` setup.
impl Default for LearnConfig {
    fn default() -> Self {
        detect_provider()
    }
}
47
48// Auto-detect provider from available API keys (checked in priority order).
49fn detect_provider() -> LearnConfig {
50    detect_provider_with(|key| std::env::var(key).ok())
51}
52
53// Testable variant — accepts a closure for env lookup to avoid env mutation in tests.
54fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
55    if env_lookup("ANTHROPIC_API_KEY").is_some() {
56        LearnConfig {
57            provider: "anthropic".into(),
58            model: "claude-haiku-4-5".into(),
59            api_key_env: "ANTHROPIC_API_KEY".into(),
60        }
61    } else {
62        // Default to anthropic; will fail at runtime if no key is set.
63        LearnConfig {
64            provider: "anthropic".into(),
65            model: "claude-haiku-4-5".into(),
66            api_key_env: "ANTHROPIC_API_KEY".into(),
67        }
68    }
69}
70
71fn config_dir() -> PathBuf {
72    if let Some(test_dir) = std::env::var_os("OO_CONFIG_DIR") {
73        return PathBuf::from(test_dir);
74    }
75    dirs::config_dir()
76        .unwrap_or_else(|| PathBuf::from("/tmp"))
77        .join("oo")
78}
79
80/// Get the directory containing user-defined patterns.
81///
82/// Returns `~/.config/oo/patterns` or the overridden `OO_CONFIG_DIR/patterns`.
83pub fn patterns_dir() -> PathBuf {
84    config_dir().join("patterns")
85}
86
87/// Path to the one-line status file written by the background learn process.
88pub fn learn_status_path() -> PathBuf {
89    config_dir().join("learn-status.log")
90}
91
92/// Load learn configuration from `~/.config/oo/config.toml`.
93///
94/// Returns the default configuration if the file doesn't exist.
95pub fn load_learn_config() -> Result<LearnConfig, Error> {
96    let path = config_dir().join("config.toml");
97    if !path.exists() {
98        return Ok(LearnConfig::default());
99    }
100    let content = std::fs::read_to_string(&path)
101        .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
102    let cf: ConfigFile =
103        toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
104    Ok(cf.learn.unwrap_or_default())
105}
106
107// ---------------------------------------------------------------------------
108// Background learning
109// ---------------------------------------------------------------------------
110
/// System prompt sent as the `system` field of every pattern-generation request
/// (see `call_anthropic`).
///
/// Its examples use the layout that `validate_pattern_toml` checks: `command_match`, a
/// `[success]` table with `pattern`/`summary`, and a `[failure]` table with a `strategy`.
/// Only the `tail` and `head` strategies are shown, although validation also accepts
/// `grep` and `between`.
///
/// NOTE(review): the prompt calls an empty summary the "WORST outcome", yet tells
/// build/lint patterns to use `summary = ""`. Confirm which behavior is intended before
/// editing the text.
const SYSTEM_PROMPT: &str = r#"You generate output classification patterns for `oo`, a shell command runner used by an LLM coding agent.

The agent reads your pattern to decide its next action. Returning nothing is the WORST outcome — an empty summary forces a costly recall cycle.

IMPORTANT: Use named capture groups (?P<name>...) only — never numbered groups like (\d+). Summary templates use {name} placeholders matching the named groups.

## oo's 4-tier system

- Passthrough: output <4 KB passes through unchanged
- Failure: failed commands get ✗ prefix with filtered error output
- Success: successful commands get ✓ prefix with a pattern-extracted summary
- Large: if regex fails to match, output is FTS5 indexed for recall

## Examples

Test runner — capture RESULT line, not header; strategy=tail for failures:
    command_match = "\\bcargo\\s+test\\b"
    [success]
    pattern = 'test result: ok\. (?P<passed>\d+) passed.*finished in (?P<time>[\d.]+)s'
    summary = "{passed} passed, {time}s"
    [failure]
    strategy = "tail"
    lines = 30

Build/lint — quiet on success (only useful when failing); strategy=head for failures:
    command_match = "\\bcargo\\s+build\\b"
    [success]
    pattern = "(?s).*"
    summary = ""
    [failure]
    strategy = "head"
    lines = 20

## Rules

- Test runners: capture SUMMARY line (e.g. 'test result: ok. 5 passed'), NOT headers (e.g. 'running 5 tests')
- Build/lint tools: empty summary for success; head/lines=20 for failures
- Large tabular output (ls, git log): omit success section — falls to Large tier

## Command Categories

Note: oo categorizes commands (Status: tests/builds/lints, Content: git show/diff/cat, Data: git log/ls/gh, Unknown: others). Patterns are most valuable for Status commands. Content commands always pass through regardless of size; Data commands are indexed when large and unpatterned."#;
153
154/// Run the learn flow with explicit config and base URL — testable variant.
155///
156/// This internal function bypasses `load_learn_config()` and env var lookup,
157/// making it suitable for testing without environment mutation.
158pub(crate) fn run_learn_with_config(
159    params: &LearnParams,
160    command: &str,
161    output: &str,
162    exit_code: i32,
163) -> Result<(), Error> {
164    let user_msg = format!(
165        "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
166        truncate_for_prompt(output)
167    );
168
169    let get_response = |msg: &str| -> Result<String, Error> {
170        match params.config.provider.as_str() {
171            "anthropic" => {
172                call_anthropic(params.base_url, params.api_key, &params.config.model, msg)
173            }
174            other => Err(Error::Learn(format!("unknown provider: {other}"))),
175        }
176    };
177
178    // First attempt
179    let mut last_err;
180    let toml = get_response(&user_msg)?;
181    let clean = strip_fences(&toml);
182    match validate_pattern_toml(&clean) {
183        Ok(()) => {
184            std::fs::create_dir_all(params.patterns_dir)
185                .map_err(|e| Error::Learn(e.to_string()))?;
186            let filename = format!("{}.toml", label(command));
187            let path = params.patterns_dir.join(&filename);
188            std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
189            let _ = crate::commands::write_learn_status(
190                params.learn_status_path,
191                &label(command),
192                &path,
193            );
194            return Ok(());
195        }
196        Err(e) => last_err = e,
197    }
198
199    // Up to 2 retries
200    for _ in 0..2 {
201        let retry_msg = format!(
202            "Your previous TOML was invalid: {last_err}. Here is what you returned:\n{clean}\nOutput ONLY the corrected TOML, nothing else."
203        );
204        let toml = get_response(&retry_msg)?;
205        let clean = strip_fences(&toml);
206        match validate_pattern_toml(&clean) {
207            Ok(()) => {
208                std::fs::create_dir_all(params.patterns_dir)
209                    .map_err(|e| Error::Learn(e.to_string()))?;
210                let filename = format!("{}.toml", label(command));
211                let path = params.patterns_dir.join(&filename);
212                std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
213                let _ = crate::commands::write_learn_status(
214                    params.learn_status_path,
215                    &label(command),
216                    &path,
217                );
218                return Ok(());
219            }
220            Err(e) => last_err = e,
221        }
222    }
223
224    Err(Error::Learn(format!("failed after 3 attempts: {last_err}")))
225}
226
227/// Run the learn flow: call LLM, validate + save pattern.
228///
229/// Loads configuration from environment, calls the LLM to generate a pattern,
230/// validates the result, and saves the pattern to disk. Retries up to 2 times
231/// if the LLM returns invalid TOML.
232pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
233    let config = load_learn_config()?;
234
235    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
236        Error::Learn(format!(
237            "Set {} environment variable to use `oo learn`",
238            config.api_key_env
239        ))
240    })?;
241
242    let base_url = std::env::var("ANTHROPIC_API_URL")
243        .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
244
245    validate_anthropic_url(&base_url)?;
246
247    let params = LearnParams {
248        config: &config,
249        api_key: &api_key,
250        base_url: &base_url,
251        patterns_dir: &patterns_dir(),
252        learn_status_path: &learn_status_path(),
253    };
254
255    run_learn_with_config(&params, command, output, exit_code)
256}
257
258/// Spawn the learning process in the background by re-exec'ing ourselves.
259pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
260    let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
261
262    // Use a secure named temp file to avoid PID-based predictable filenames
263    // (symlink/TOCTOU attacks). The file is kept alive until the child spawns.
264    let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
265    let data = serde_json::json!({
266        "command": command,
267        "output": output,
268        "exit_code": exit_code,
269    });
270    tmp.write_all(data.to_string().as_bytes())
271        .map_err(|e| Error::Learn(e.to_string()))?;
272
273    // Convert to TempPath: closes the file handle but keeps the file on disk
274    // until the TempPath is dropped — after the child has been spawned.
275    let tmp_path = tmp.into_temp_path();
276
277    // Spawn detached child
278    std::process::Command::new(exe)
279        .arg("_learn_bg")
280        .arg(&tmp_path)
281        .stdin(std::process::Stdio::null())
282        .stdout(std::process::Stdio::null())
283        .stderr(std::process::Stdio::null())
284        .spawn()
285        .map_err(|e| Error::Learn(e.to_string()))?;
286
287    // Prevent the parent from deleting the temp file on drop. On a loaded
288    // system the child process may not have opened the file yet by the time
289    // the parent exits this function. `keep()` makes the file persist on disk
290    // until the child cleans it up at run_background (line ~218 below).
291    tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
292
293    Ok(())
294}
295
296/// Entry point for the background learn child process.
297pub fn run_background(data_path: &str) -> Result<(), Error> {
298    let path = Path::new(data_path);
299    let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
300    let data: serde_json::Value =
301        serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
302
303    let command = data["command"].as_str().unwrap_or("");
304    let output = data["output"].as_str().unwrap_or("");
305    let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
306
307    let result = run_learn(command, output, exit_code);
308
309    // Clean up temp file
310    let _ = std::fs::remove_file(path);
311
312    if let Err(ref e) = result {
313        let cmd_label = label(command);
314        let status_path = learn_status_path();
315        let _ =
316            crate::commands::write_learn_status_failure(&status_path, &cmd_label, &e.to_string());
317    }
318
319    result
320}
321
322// ---------------------------------------------------------------------------
323// LLM API calls
324// ---------------------------------------------------------------------------
325
326fn call_anthropic(
327    base_url: &str,
328    api_key: &str,
329    model: &str,
330    user_msg: &str,
331) -> Result<String, Error> {
332    let body = serde_json::json!({
333        "model": model,
334        "max_tokens": 1024,
335        "temperature": 0.0,
336        "system": SYSTEM_PROMPT,
337        "messages": [{"role": "user", "content": user_msg}],
338    });
339
340    let response: serde_json::Value = ureq::post(base_url)
341        .header("x-api-key", api_key)
342        .header("anthropic-version", "2023-06-01")
343        .header("content-type", "application/json")
344        .send_json(&body)
345        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
346        .body_mut()
347        .read_json()
348        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
349
350    response["content"][0]["text"]
351        .as_str()
352        .map(|s| s.to_string())
353        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
354}
355
356// ---------------------------------------------------------------------------
357// Helpers
358// ---------------------------------------------------------------------------
359
/// Derive a short, filesystem-safe label for `command`; it names the saved pattern file.
///
/// The label is the basename of the first word (everything up to the last `/` or `\` is
/// dropped). When the second word is a subcommand rather than a flag, `-<second word>` is
/// appended: `/usr/bin/git status` → `git-status`, `cargo --release` → `cargo`.
///
/// Both parts are sanitized so the result is always a single path component. The first
/// word keeps only alphanumerics, `-`, `_` and `.`; the subcommand keeps only ASCII
/// alphanumerics and `-`. If the first word sanitizes to nothing, `unknown` is used.
fn label(command: &str) -> String {
    let mut words = command.split_whitespace();

    // The first word is part of a file name too. Unsanitized, a Windows path
    // (`C:\bin\cargo.exe`) or a word ending in `/` would yield an unsafe or empty name.
    let first: String = words
        .next()
        .and_then(|word| word.rsplit(['/', '\\']).next())
        .unwrap_or("")
        .chars()
        .filter(|c| c.is_alphanumeric() || matches!(*c, '-' | '_' | '.'))
        .collect();
    let first = if first.is_empty() {
        "unknown".to_string()
    } else {
        first
    };

    // Include the second word only when it is a subcommand (not a flag).
    match words.next() {
        Some(second) if !second.starts_with('-') => {
            let sanitized: String = second
                .chars()
                .filter(|c| c.is_ascii_alphanumeric() || *c == '-')
                .collect();
            if sanitized.is_empty() {
                first
            } else {
                format!("{first}-{sanitized}")
            }
        }
        _ => first,
    }
}
386
387fn truncate_for_prompt(output: &str) -> &str {
388    truncate_utf8(output, 4000)
389}
390
/// Return the longest prefix of `s` that is at most `max_bytes` long and ends on a char
/// boundary, so slicing never panics inside a multibyte UTF-8 sequence.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Index 0 is always a char boundary, so the search always succeeds.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
402
403/// Validate ANTHROPIC_API_URL uses HTTPS (with localhost exceptions).
404fn validate_anthropic_url(url: &str) -> Result<(), Error> {
405    if url.starts_with("https://") {
406        return Ok(());
407    }
408    // HTTP only allowed for localhost/127.0.0.1
409    // Extract host portion: "http://HOST:port/path" or "http://HOST/path"
410    if let Some(rest) = url.strip_prefix("http://") {
411        let host = rest.split([':', '/']).next().unwrap_or("");
412        if host == "localhost" || host == "127.0.0.1" {
413            return Ok(());
414        }
415    }
416    Err(Error::Learn(format!(
417        "ANTHROPIC_API_URL must use HTTPS (got: {url}). HTTP is only allowed for localhost/127.0.0.1."
418    )))
419}
420
/// Strip a surrounding Markdown code fence from an LLM reply, if present.
///
/// Any language tag on the opening fence line is dropped (```` ```toml ````,
/// ```` ```TOML ````, ```` ```ini ````, ...), as is a trailing ```` ``` ````. Replies
/// without a fence are returned trimmed.
fn strip_fences(s: &str) -> String {
    let trimmed = s.trim();
    let Some(after_open) = trimmed.strip_prefix("```") else {
        return trimmed.to_string();
    };
    let body = match after_open.split_once('\n') {
        // A tag-only first line (identifier chars, possibly empty) is the fence's info
        // string, not content.
        Some((tag, rest))
            if tag
                .trim()
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') =>
        {
            rest
        }
        // Single-line replies such as "```toml a = 1```".
        _ => after_open.strip_prefix("toml").unwrap_or(after_open),
    };
    body.strip_suffix("```").unwrap_or(body).trim().to_string()
}
431
432fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
433    // Try to parse as our pattern format
434    #[derive(Deserialize)]
435    struct Check {
436        command_match: String,
437        // Deserialization target: field must exist for TOML parsing even if not read in code
438        #[allow(dead_code)] // used only for TOML deserialization validation
439        success: Option<SuccessCheck>,
440        failure: Option<FailureSection>,
441    }
442    #[derive(Deserialize)]
443    struct SuccessCheck {
444        pattern: String,
445        // Deserialization target: field must exist for TOML parsing even if not read in code
446        #[allow(dead_code)] // used only for TOML deserialization validation
447        summary: String,
448    }
449
450    let check: Check =
451        toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;
452
453    // Verify regexes compile
454    regex::Regex::new(&check.command_match)
455        .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;
456
457    if let Some(s) = &check.success {
458        regex::Regex::new(&s.pattern)
459            .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
460    }
461
462    if let Some(f) = &check.failure {
463        match f.strategy.as_deref().unwrap_or("tail") {
464            "grep" => {
465                let pat = f.grep_pattern.as_deref().ok_or_else(|| {
466                    Error::Learn("failure grep strategy requires a 'grep' field".into())
467                })?;
468                if pat.is_empty() {
469                    return Err(Error::Learn("failure grep regex must not be empty".into()));
470                }
471                regex::Regex::new(pat)
472                    .map_err(|e| Error::Learn(format!("invalid failure grep regex: {e}")))?;
473            }
474            "between" => {
475                let start = f.start.as_deref().ok_or_else(|| {
476                    Error::Learn("between strategy requires 'start' field".into())
477                })?;
478                if start.is_empty() {
479                    return Err(Error::Learn("between 'start' must not be empty".into()));
480                }
481                regex::Regex::new(start)
482                    .map_err(|e| Error::Learn(format!("invalid start regex: {e}")))?;
483                let end = f
484                    .end
485                    .as_deref()
486                    .ok_or_else(|| Error::Learn("between strategy requires 'end' field".into()))?;
487                if end.is_empty() {
488                    return Err(Error::Learn("between 'end' must not be empty".into()));
489                }
490                regex::Regex::new(end)
491                    .map_err(|e| Error::Learn(format!("invalid end regex: {e}")))?;
492            }
493            "tail" | "head" => {} // no regex to validate
494            other => {
495                return Err(Error::Learn(format!("unknown failure strategy: {other}")));
496            }
497        }
498    }
499
500    Ok(())
501}
502
// Unit tests live in separate files (`learn_tests.rs`, `learn_prompt_tests.rs`) to keep
// this module focused on the learn logic itself.
#[cfg(test)]
#[path = "learn_tests.rs"]
mod tests;

#[cfg(test)]
#[path = "learn_prompt_tests.rs"]
mod prompt_tests;