// double_o — learn.rs
// LLM-assisted pattern learning for `oo`: config detection, background
// learn process, provider API calls, and pattern validation.
1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7use crate::pattern::FailureSection;
8
9// ---------------------------------------------------------------------------
10// Config
11// ---------------------------------------------------------------------------
12
/// Top-level shape of `config.toml`; only the optional `[learn]` table is
/// read by this module.
#[derive(Deserialize)]
struct ConfigFile {
    learn: Option<LearnConfig>,
}
17
/// Settings for `oo learn`: which LLM provider/model to call and which
/// environment variable holds the API key.
#[derive(Deserialize, Clone)]
pub struct LearnConfig {
    // Provider id: "anthropic", "openai", or "cerebras" (see run_learn).
    pub provider: String,
    // Model identifier passed through to the provider API.
    pub model: String,
    // Name of the env var read for the API key (e.g. "ANTHROPIC_API_KEY").
    pub api_key_env: String,
}
24
impl Default for LearnConfig {
    // Fall back to auto-detection from the environment when no config file
    // (or no [learn] table) is present.
    fn default() -> Self {
        detect_provider()
    }
}
30
31// Auto-detect provider from available API keys (checked in priority order).
32fn detect_provider() -> LearnConfig {
33    detect_provider_with(|key| std::env::var(key).ok())
34}
35
36// Testable variant — accepts a closure for env lookup to avoid env mutation in tests.
37fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
38    if env_lookup("ANTHROPIC_API_KEY").is_some() {
39        LearnConfig {
40            provider: "anthropic".into(),
41            model: "claude-haiku-4-5".into(),
42            api_key_env: "ANTHROPIC_API_KEY".into(),
43        }
44    } else if env_lookup("OPENAI_API_KEY").is_some() {
45        LearnConfig {
46            provider: "openai".into(),
47            model: "gpt-4o-mini".into(),
48            api_key_env: "OPENAI_API_KEY".into(),
49        }
50    } else if env_lookup("CEREBRAS_API_KEY").is_some() {
51        LearnConfig {
52            provider: "cerebras".into(),
53            // Cerebras model ID: check https://cloud.cerebras.ai/models for current model catalog
54            model: "zai-glm-4.7".into(),
55            api_key_env: "CEREBRAS_API_KEY".into(),
56        }
57    } else {
58        // Default to anthropic; will fail at runtime if no key is set.
59        LearnConfig {
60            provider: "anthropic".into(),
61            model: "claude-haiku-4-5".into(),
62            api_key_env: "ANTHROPIC_API_KEY".into(),
63        }
64    }
65}
66
67fn config_dir() -> PathBuf {
68    dirs::config_dir()
69        .unwrap_or_else(|| PathBuf::from("/tmp"))
70        .join("oo")
71}
72
73pub fn patterns_dir() -> PathBuf {
74    config_dir().join("patterns")
75}
76
77/// Path to the one-line status file written by the background learn process.
78pub fn learn_status_path() -> PathBuf {
79    config_dir().join("learn-status.log")
80}
81
82pub fn load_learn_config() -> Result<LearnConfig, Error> {
83    let path = config_dir().join("config.toml");
84    if !path.exists() {
85        return Ok(LearnConfig::default());
86    }
87    let content = std::fs::read_to_string(&path)
88        .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
89    let cf: ConfigFile =
90        toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
91    Ok(cf.learn.unwrap_or_default())
92}
93
94// ---------------------------------------------------------------------------
95// Background learning
96// ---------------------------------------------------------------------------
97
// System prompt sent on every learn call. Kept as one raw string so the
// embedded TOML examples need no escaping; the model's reply is expected to
// be a pattern file in this format (validated by validate_pattern_toml).
const SYSTEM_PROMPT: &str = r#"You generate output classification patterns for `oo`, a shell command runner used by an LLM coding agent.

The agent reads your pattern to decide its next action. Returning nothing is the WORST outcome — an empty summary forces a costly recall cycle.

IMPORTANT: Use named capture groups (?P<name>...) only — never numbered groups like (\d+). Summary templates use {name} placeholders matching the named groups.

## oo's 4-tier system

- Passthrough: output <4 KB passes through unchanged
- Failure: failed commands get ✗ prefix with filtered error output
- Success: successful commands get ✓ prefix with a pattern-extracted summary
- Large: if regex fails to match, output is FTS5 indexed for recall

## Examples

Test runner — capture RESULT line, not header; strategy=tail for failures:
    command_match = "\\bcargo\\s+test\\b"
    [success]
    pattern = 'test result: ok\. (?P<passed>\d+) passed.*finished in (?P<time>[\d.]+)s'
    summary = "{passed} passed, {time}s"
    [failure]
    strategy = "tail"
    lines = 30

Build/lint — quiet on success (only useful when failing); strategy=head for failures:
    command_match = "\\bcargo\\s+build\\b"
    [success]
    pattern = "(?s).*"
    summary = ""
    [failure]
    strategy = "head"
    lines = 20

## Rules

- Test runners: capture SUMMARY line (e.g. 'test result: ok. 5 passed'), NOT headers (e.g. 'running 5 tests')
- Build/lint tools: empty summary for success; head/lines=20 for failures
- Large tabular output (ls, git log): omit success section — falls to Large tier"#;
136
137/// Run the learn flow: call LLM, validate + save pattern.
138pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
139    let config = load_learn_config()?;
140
141    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
142        Error::Learn(format!(
143            "Set {} environment variable to use `oo learn`",
144            config.api_key_env
145        ))
146    })?;
147
148    let user_msg = format!(
149        "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
150        truncate_for_prompt(output)
151    );
152
153    let toml_response = match config.provider.as_str() {
154        "anthropic" => call_anthropic(
155            "https://api.anthropic.com/v1/messages",
156            &api_key,
157            &config.model,
158            &user_msg,
159        )?,
160        "openai" => call_openai(
161            "https://api.openai.com/v1/chat/completions",
162            &api_key,
163            &config.model,
164            &user_msg,
165        )?,
166        "cerebras" => call_openai(
167            "https://api.cerebras.ai/v1/chat/completions",
168            &api_key,
169            &config.model,
170            &user_msg,
171        )?,
172        other => return Err(Error::Learn(format!("unknown provider: {other}"))),
173    };
174
175    // Strip markdown fences if present
176    let toml_clean = strip_fences(&toml_response);
177
178    // Validate: parse as pattern
179    validate_pattern_toml(&toml_clean)?;
180
181    // Save
182    let dir = patterns_dir();
183    std::fs::create_dir_all(&dir).map_err(|e| Error::Learn(e.to_string()))?;
184    let filename = format!("{}.toml", label(command));
185    let path = dir.join(&filename);
186    std::fs::write(&path, &toml_clean).map_err(|e| Error::Learn(e.to_string()))?;
187
188    // Write status file for the foreground process to display on next invocation
189    let status_path = learn_status_path();
190    let cmd_label = label(command);
191    let _ = crate::commands::write_learn_status(&status_path, &cmd_label, &path);
192
193    Ok(())
194}
195
196/// Spawn the learning process in the background by re-exec'ing ourselves.
197pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
198    let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
199
200    // Use a secure named temp file to avoid PID-based predictable filenames
201    // (symlink/TOCTOU attacks). The file is kept alive until the child spawns.
202    let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
203    let data = serde_json::json!({
204        "command": command,
205        "output": output,
206        "exit_code": exit_code,
207    });
208    tmp.write_all(data.to_string().as_bytes())
209        .map_err(|e| Error::Learn(e.to_string()))?;
210
211    // Convert to TempPath: closes the file handle but keeps the file on disk
212    // until the TempPath is dropped — after the child has been spawned.
213    let tmp_path = tmp.into_temp_path();
214
215    // Spawn detached child
216    std::process::Command::new(exe)
217        .arg("_learn_bg")
218        .arg(&tmp_path)
219        .stdin(std::process::Stdio::null())
220        .stdout(std::process::Stdio::null())
221        .stderr(std::process::Stdio::null())
222        .spawn()
223        .map_err(|e| Error::Learn(e.to_string()))?;
224
225    // Prevent the parent from deleting the temp file on drop. On a loaded
226    // system the child process may not have opened the file yet by the time
227    // the parent exits this function. `keep()` makes the file persist on disk
228    // until the child cleans it up at run_background (line ~218 below).
229    tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
230
231    Ok(())
232}
233
234/// Entry point for the background learn child process.
235pub fn run_background(data_path: &str) -> Result<(), Error> {
236    let path = Path::new(data_path);
237    let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
238    let data: serde_json::Value =
239        serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
240
241    let command = data["command"].as_str().unwrap_or("");
242    let output = data["output"].as_str().unwrap_or("");
243    let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
244
245    let result = run_learn(command, output, exit_code);
246
247    // Clean up temp file
248    let _ = std::fs::remove_file(path);
249
250    if let Err(ref e) = result {
251        let cmd_label = label(command);
252        let status_path = learn_status_path();
253        let _ =
254            crate::commands::write_learn_status_failure(&status_path, &cmd_label, &e.to_string());
255    }
256
257    result
258}
259
260// ---------------------------------------------------------------------------
261// LLM API calls
262// ---------------------------------------------------------------------------
263
264fn call_anthropic(
265    base_url: &str,
266    api_key: &str,
267    model: &str,
268    user_msg: &str,
269) -> Result<String, Error> {
270    let body = serde_json::json!({
271        "model": model,
272        "max_tokens": 1024,
273        "system": SYSTEM_PROMPT,
274        "messages": [{"role": "user", "content": user_msg}],
275    });
276
277    let response: serde_json::Value = ureq::post(base_url)
278        .header("x-api-key", api_key)
279        .header("anthropic-version", "2023-06-01")
280        .header("content-type", "application/json")
281        .send_json(&body)
282        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
283        .body_mut()
284        .read_json()
285        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
286
287    response["content"][0]["text"]
288        .as_str()
289        .map(|s| s.to_string())
290        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
291}
292
293fn call_openai(
294    base_url: &str,
295    api_key: &str,
296    model: &str,
297    user_msg: &str,
298) -> Result<String, Error> {
299    let body = serde_json::json!({
300        "model": model,
301        "messages": [
302            {"role": "system", "content": SYSTEM_PROMPT},
303            {"role": "user", "content": user_msg},
304        ],
305    });
306
307    let response: serde_json::Value = ureq::post(base_url)
308        .header("Authorization", &format!("Bearer {api_key}"))
309        .header("Content-Type", "application/json")
310        .send_json(&body)
311        .map_err(|e| Error::Learn(format!("OpenAI API error: {e}")))?
312        .body_mut()
313        .read_json()
314        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
315
316    response["choices"][0]["message"]["content"]
317        .as_str()
318        .map(|s| s.to_string())
319        .ok_or_else(|| Error::Learn("unexpected OpenAI response format".into()))
320}
321
322// ---------------------------------------------------------------------------
323// Helpers
324// ---------------------------------------------------------------------------
325
/// Derive a filesystem-safe label for a command; used verbatim as the pattern
/// filename stem (`<label>.toml`).
///
/// Uses the program basename plus the second word when it looks like a
/// subcommand (not a flag): "cargo test" -> "cargo-test", "ls -la" -> "ls".
fn label(command: &str) -> String {
    let mut words = command.split_whitespace();
    // Basename of the program: "/usr/bin/git" -> "git".
    let program = words
        .next()
        .unwrap_or("unknown")
        .rsplit('/')
        .next()
        .unwrap_or("unknown");
    // Bug fix: the program name was previously used unsanitized, so shell
    // metacharacters, Windows '\' separators, or a ".." token ended up in the
    // filename. Keep alphanumerics, '-', '_' and '.' (so "python3.11" keeps
    // its label), but strip leading dots so the label can never be a hidden
    // or ".." path component.
    let cleaned: String = program
        .chars()
        .filter(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
        .collect();
    let first = match cleaned.trim_start_matches('.') {
        "" => "unknown".to_string(),
        s => s.to_string(),
    };
    // Include the second word only when it is a subcommand (not a flag).
    match words.next() {
        Some(second) if !second.starts_with('-') => {
            // Stricter filter for the subcommand (unchanged from before):
            // ASCII alphanumeric and hyphens only.
            let sanitized: String = second
                .chars()
                .filter(|c| c.is_ascii_alphanumeric() || *c == '-')
                .collect();
            if sanitized.is_empty() {
                first
            } else {
                format!("{first}-{sanitized}")
            }
        }
        _ => first,
    }
}
352
// Cap the command output included in the LLM prompt at 4000 bytes,
// truncating on a UTF-8 char boundary.
fn truncate_for_prompt(output: &str) -> &str {
    truncate_utf8(output, 4000)
}
356
// Truncate at a char boundary to avoid panics on multibyte UTF-8 sequences.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Walk back from max_bytes to the nearest boundary; index 0 is always
    // a boundary, so `find` cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
368
// LLMs often wrap TOML replies in ```toml ... ``` markdown fences; unwrap
// them if present, otherwise return the trimmed input.
fn strip_fences(s: &str) -> String {
    let trimmed = s.trim();
    let unfenced = trimmed
        .strip_prefix("```toml")
        .or_else(|| trimmed.strip_prefix("```"))
        .map(|rest| rest.strip_suffix("```").unwrap_or(rest));
    match unfenced {
        Some(body) => body.trim().to_string(),
        None => trimmed.to_string(),
    }
}
379
380fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
381    // Try to parse as our pattern format
382    #[derive(Deserialize)]
383    struct Check {
384        command_match: String,
385        // Deserialization target: field must exist for TOML parsing even if not read in code
386        #[allow(dead_code)] // used only for TOML deserialization validation
387        success: Option<SuccessCheck>,
388        failure: Option<FailureSection>,
389    }
390    #[derive(Deserialize)]
391    struct SuccessCheck {
392        pattern: String,
393        // Deserialization target: field must exist for TOML parsing even if not read in code
394        #[allow(dead_code)] // used only for TOML deserialization validation
395        summary: String,
396    }
397
398    let check: Check =
399        toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;
400
401    // Verify regexes compile
402    regex::Regex::new(&check.command_match)
403        .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;
404
405    if let Some(s) = &check.success {
406        regex::Regex::new(&s.pattern)
407            .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
408    }
409
410    if let Some(f) = &check.failure {
411        match f.strategy.as_deref().unwrap_or("tail") {
412            "grep" => {
413                let pat = f.grep_pattern.as_deref().ok_or_else(|| {
414                    Error::Learn("failure grep strategy requires a 'grep' field".into())
415                })?;
416                if pat.is_empty() {
417                    return Err(Error::Learn("failure grep regex must not be empty".into()));
418                }
419                regex::Regex::new(pat)
420                    .map_err(|e| Error::Learn(format!("invalid failure grep regex: {e}")))?;
421            }
422            "between" => {
423                let start = f.start.as_deref().ok_or_else(|| {
424                    Error::Learn("between strategy requires 'start' field".into())
425                })?;
426                if start.is_empty() {
427                    return Err(Error::Learn("between 'start' must not be empty".into()));
428                }
429                regex::Regex::new(start)
430                    .map_err(|e| Error::Learn(format!("invalid start regex: {e}")))?;
431                let end = f
432                    .end
433                    .as_deref()
434                    .ok_or_else(|| Error::Learn("between strategy requires 'end' field".into()))?;
435                if end.is_empty() {
436                    return Err(Error::Learn("between 'end' must not be empty".into()));
437                }
438                regex::Regex::new(end)
439                    .map_err(|e| Error::Learn(format!("invalid end regex: {e}")))?;
440            }
441            "tail" | "head" => {} // no regex to validate
442            other => {
443                return Err(Error::Learn(format!("unknown failure strategy: {other}")));
444            }
445        }
446    }
447
448    Ok(())
449}
450
451// Tests live in separate files to keep this module under 500 lines.
452#[cfg(test)]
453#[path = "learn_tests.rs"]
454mod tests;
455
456#[cfg(test)]
457#[path = "learn_prompt_tests.rs"]
458mod prompt_tests;