//! Provider auto-detection, config loading, and the background LLM-driven
//! pattern-learning flow for `oo`.

1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7use crate::pattern::FailureSection;
8
9// ---------------------------------------------------------------------------
10// Config
11// ---------------------------------------------------------------------------
12
/// Top-level shape of `config.toml`; only the `[learn]` table is read here.
#[derive(Deserialize)]
struct ConfigFile {
    // An absent `[learn]` section falls back to `LearnConfig::default()`.
    learn: Option<LearnConfig>,
}
17
/// LLM provider settings used by `oo learn`.
#[derive(Deserialize, Clone)]
pub struct LearnConfig {
    /// Provider id: "anthropic", "openai", or "cerebras".
    pub provider: String,
    /// Model identifier passed to the provider's API.
    pub model: String,
    /// Name of the environment variable holding the API key.
    pub api_key_env: String,
}
24
impl Default for LearnConfig {
    /// Defaults are derived from which API keys are present in the environment.
    fn default() -> Self {
        detect_provider()
    }
}
30
31// Auto-detect provider from available API keys (checked in priority order).
32fn detect_provider() -> LearnConfig {
33    detect_provider_with(|key| std::env::var(key).ok())
34}
35
36// Testable variant — accepts a closure for env lookup to avoid env mutation in tests.
37fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
38    if env_lookup("ANTHROPIC_API_KEY").is_some() {
39        LearnConfig {
40            provider: "anthropic".into(),
41            model: "claude-haiku-4-5-20251001".into(),
42            api_key_env: "ANTHROPIC_API_KEY".into(),
43        }
44    } else if env_lookup("OPENAI_API_KEY").is_some() {
45        LearnConfig {
46            provider: "openai".into(),
47            model: "gpt-4o-mini".into(),
48            api_key_env: "OPENAI_API_KEY".into(),
49        }
50    } else if env_lookup("CEREBRAS_API_KEY").is_some() {
51        LearnConfig {
52            provider: "cerebras".into(),
53            // Cerebras model ID: check https://cloud.cerebras.ai/models for current model catalog
54            model: "zai-glm-4.7".into(),
55            api_key_env: "CEREBRAS_API_KEY".into(),
56        }
57    } else {
58        // Default to anthropic; will fail at runtime if no key is set.
59        LearnConfig {
60            provider: "anthropic".into(),
61            model: "claude-haiku-4-5-20251001".into(),
62            api_key_env: "ANTHROPIC_API_KEY".into(),
63        }
64    }
65}
66
67fn config_dir() -> PathBuf {
68    dirs::config_dir()
69        .unwrap_or_else(|| PathBuf::from("/tmp"))
70        .join("oo")
71}
72
73pub fn patterns_dir() -> PathBuf {
74    config_dir().join("patterns")
75}
76
77/// Path to the one-line status file written by the background learn process.
78pub fn learn_status_path() -> PathBuf {
79    config_dir().join("learn-status.log")
80}
81
82pub fn load_learn_config() -> Result<LearnConfig, Error> {
83    let path = config_dir().join("config.toml");
84    if !path.exists() {
85        return Ok(LearnConfig::default());
86    }
87    let content = std::fs::read_to_string(&path)
88        .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
89    let cf: ConfigFile =
90        toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
91    Ok(cf.learn.unwrap_or_default())
92}
93
94// ---------------------------------------------------------------------------
95// Background learning
96// ---------------------------------------------------------------------------
97
// System prompt sent verbatim to the LLM on every `oo learn` call. The model
// must reply with a bare TOML pattern block (see "Output format" inside the
// prompt), which is then checked by `validate_pattern_toml` before saving.
const SYSTEM_PROMPT: &str = r#"You generate output classification patterns for `oo`, a shell command runner used by an LLM coding agent.

The agent reads your pattern to decide its next action. Returning nothing is the WORST outcome — an empty summary forces a costly recall cycle that wastes more tokens than a slightly verbose summary would.

## oo's 4-tier system

- Passthrough: output <4 KB passes through unchanged
- Failure: failed commands get ✗ prefix with filtered error output
- Success: successful commands get ✓ prefix with a pattern-extracted summary (your patterns target this tier)
- Large: if your regex fails to match, output falls through to this tier (FTS5 indexed for recall) — not catastrophic

## Output format

Respond with ONLY a TOML block. Fences optional.

    command_match = "^pytest"
    [success]
    pattern = '(?P<n>\d+) passed'
    summary = "{n} passed"
    [failure]
    strategy = "grep"
    grep = "error|Error|FAILED"

## Rules

- For build/test commands: compress aggressively (e.g. "47 passed, 3.2s" or "error: …first error only")
- For large tabular output (ls, docker ps, git log): omit the success section — let it fall through to Large tier (FTS5 indexed)
- A regex that's too broad is better than one that matches and returns empty"#;
126
127/// Run the learn flow: call LLM, validate + save pattern.
128pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
129    let config = load_learn_config()?;
130
131    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
132        Error::Learn(format!(
133            "Set {} environment variable to use `oo learn`",
134            config.api_key_env
135        ))
136    })?;
137
138    let user_msg = format!(
139        "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
140        truncate_for_prompt(output)
141    );
142
143    let toml_response = match config.provider.as_str() {
144        "anthropic" => call_anthropic(
145            "https://api.anthropic.com/v1/messages",
146            &api_key,
147            &config.model,
148            &user_msg,
149        )?,
150        "openai" => call_openai(
151            "https://api.openai.com/v1/chat/completions",
152            &api_key,
153            &config.model,
154            &user_msg,
155        )?,
156        "cerebras" => call_openai(
157            "https://api.cerebras.ai/v1/chat/completions",
158            &api_key,
159            &config.model,
160            &user_msg,
161        )?,
162        other => return Err(Error::Learn(format!("unknown provider: {other}"))),
163    };
164
165    // Strip markdown fences if present
166    let toml_clean = strip_fences(&toml_response);
167
168    // Validate: parse as pattern
169    validate_pattern_toml(&toml_clean)?;
170
171    // Save
172    let dir = patterns_dir();
173    std::fs::create_dir_all(&dir).map_err(|e| Error::Learn(e.to_string()))?;
174    let filename = format!("{}.toml", label(command));
175    let path = dir.join(&filename);
176    std::fs::write(&path, &toml_clean).map_err(|e| Error::Learn(e.to_string()))?;
177
178    // Write status file for the foreground process to display on next invocation
179    let status_path = learn_status_path();
180    let cmd_label = label(command);
181    let _ = crate::commands::write_learn_status(&status_path, &cmd_label, &path);
182
183    Ok(())
184}
185
186/// Spawn the learning process in the background by re-exec'ing ourselves.
187pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
188    let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
189
190    // Use a secure named temp file to avoid PID-based predictable filenames
191    // (symlink/TOCTOU attacks). The file is kept alive until the child spawns.
192    let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
193    let data = serde_json::json!({
194        "command": command,
195        "output": output,
196        "exit_code": exit_code,
197    });
198    tmp.write_all(data.to_string().as_bytes())
199        .map_err(|e| Error::Learn(e.to_string()))?;
200
201    // Convert to TempPath: closes the file handle but keeps the file on disk
202    // until the TempPath is dropped — after the child has been spawned.
203    let tmp_path = tmp.into_temp_path();
204
205    // Spawn detached child
206    std::process::Command::new(exe)
207        .arg("_learn_bg")
208        .arg(&tmp_path)
209        .stdin(std::process::Stdio::null())
210        .stdout(std::process::Stdio::null())
211        .stderr(std::process::Stdio::null())
212        .spawn()
213        .map_err(|e| Error::Learn(e.to_string()))?;
214
215    // Prevent the parent from deleting the temp file on drop. On a loaded
216    // system the child process may not have opened the file yet by the time
217    // the parent exits this function. `keep()` makes the file persist on disk
218    // until the child cleans it up at run_background (line ~218 below).
219    tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
220
221    Ok(())
222}
223
224/// Entry point for the background learn child process.
225pub fn run_background(data_path: &str) -> Result<(), Error> {
226    let path = Path::new(data_path);
227    let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
228    let data: serde_json::Value =
229        serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
230
231    let command = data["command"].as_str().unwrap_or("");
232    let output = data["output"].as_str().unwrap_or("");
233    let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
234
235    let result = run_learn(command, output, exit_code);
236
237    // Clean up temp file
238    let _ = std::fs::remove_file(path);
239
240    result
241}
242
243// ---------------------------------------------------------------------------
244// LLM API calls
245// ---------------------------------------------------------------------------
246
247fn call_anthropic(
248    base_url: &str,
249    api_key: &str,
250    model: &str,
251    user_msg: &str,
252) -> Result<String, Error> {
253    let body = serde_json::json!({
254        "model": model,
255        "max_tokens": 1024,
256        "system": SYSTEM_PROMPT,
257        "messages": [{"role": "user", "content": user_msg}],
258    });
259
260    let response: serde_json::Value = ureq::post(base_url)
261        .header("x-api-key", api_key)
262        .header("anthropic-version", "2023-06-01")
263        .header("content-type", "application/json")
264        .send_json(&body)
265        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
266        .body_mut()
267        .read_json()
268        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
269
270    response["content"][0]["text"]
271        .as_str()
272        .map(|s| s.to_string())
273        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
274}
275
276fn call_openai(
277    base_url: &str,
278    api_key: &str,
279    model: &str,
280    user_msg: &str,
281) -> Result<String, Error> {
282    let body = serde_json::json!({
283        "model": model,
284        "messages": [
285            {"role": "system", "content": SYSTEM_PROMPT},
286            {"role": "user", "content": user_msg},
287        ],
288    });
289
290    let response: serde_json::Value = ureq::post(base_url)
291        .header("Authorization", &format!("Bearer {api_key}"))
292        .header("Content-Type", "application/json")
293        .send_json(&body)
294        .map_err(|e| Error::Learn(format!("OpenAI API error: {e}")))?
295        .body_mut()
296        .read_json()
297        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
298
299    response["choices"][0]["message"]["content"]
300        .as_str()
301        .map(|s| s.to_string())
302        .ok_or_else(|| Error::Learn("unexpected OpenAI response format".into()))
303}
304
305// ---------------------------------------------------------------------------
306// Helpers
307// ---------------------------------------------------------------------------
308
/// Derive a short label for a command line: the basename of its first word.
/// The label names the pattern file (`{label}.toml`), so it must never be
/// empty — degenerate inputs (empty command, trailing slash) yield "unknown".
fn label(command: &str) -> String {
    let first_word = command.split_whitespace().next().unwrap_or("unknown");
    // Basename: text after the last '/' (the whole word when there is none).
    let base = first_word.rsplit('/').next().unwrap_or(first_word);
    if base.is_empty() {
        // e.g. "foo/" — avoid producing a hidden ".toml" pattern file.
        "unknown".to_string()
    } else {
        base.to_string()
    }
}
319
320fn truncate_for_prompt(output: &str) -> &str {
321    truncate_utf8(output, 4000)
322}
323
// Truncate at a char boundary to avoid panics on multibyte UTF-8 sequences.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Walk back from the limit to the nearest char boundary. Index 0 is
    // always a boundary, so the search cannot fail.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
335
/// Remove an optional surrounding markdown code fence (```` ```toml ```` or
/// bare ```` ``` ````) and trim whitespace from the LLM response.
fn strip_fences(s: &str) -> String {
    let mut body = s.trim();
    // Check the longer fence first so "```toml" is not consumed as "```".
    for fence in ["```toml", "```"] {
        if let Some(inner) = body.strip_prefix(fence) {
            body = inner.strip_suffix("```").unwrap_or(inner);
            break;
        }
    }
    body.trim().to_string()
}
346
347fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
348    // Try to parse as our pattern format
349    #[derive(Deserialize)]
350    struct Check {
351        command_match: String,
352        // Deserialization target: field must exist for TOML parsing even if not read in code
353        #[allow(dead_code)] // used only for TOML deserialization validation
354        success: Option<SuccessCheck>,
355        failure: Option<FailureSection>,
356    }
357    #[derive(Deserialize)]
358    struct SuccessCheck {
359        pattern: String,
360        // Deserialization target: field must exist for TOML parsing even if not read in code
361        #[allow(dead_code)] // used only for TOML deserialization validation
362        summary: String,
363    }
364
365    let check: Check =
366        toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;
367
368    // Verify regexes compile
369    regex::Regex::new(&check.command_match)
370        .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;
371
372    if let Some(s) = &check.success {
373        regex::Regex::new(&s.pattern)
374            .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
375    }
376
377    if let Some(f) = &check.failure {
378        match f.strategy.as_deref().unwrap_or("tail") {
379            "grep" => {
380                let pat = f.grep_pattern.as_deref().ok_or_else(|| {
381                    Error::Learn("failure grep strategy requires a 'grep' field".into())
382                })?;
383                if pat.is_empty() {
384                    return Err(Error::Learn("failure grep regex must not be empty".into()));
385                }
386                regex::Regex::new(pat)
387                    .map_err(|e| Error::Learn(format!("invalid failure grep regex: {e}")))?;
388            }
389            "between" => {
390                let start = f.start.as_deref().ok_or_else(|| {
391                    Error::Learn("between strategy requires 'start' field".into())
392                })?;
393                if start.is_empty() {
394                    return Err(Error::Learn("between 'start' must not be empty".into()));
395                }
396                regex::Regex::new(start)
397                    .map_err(|e| Error::Learn(format!("invalid start regex: {e}")))?;
398                let end = f
399                    .end
400                    .as_deref()
401                    .ok_or_else(|| Error::Learn("between strategy requires 'end' field".into()))?;
402                if end.is_empty() {
403                    return Err(Error::Learn("between 'end' must not be empty".into()));
404                }
405                regex::Regex::new(end)
406                    .map_err(|e| Error::Learn(format!("invalid end regex: {e}")))?;
407            }
408            "tail" | "head" => {} // no regex to validate
409            other => {
410                return Err(Error::Learn(format!("unknown failure strategy: {other}")));
411            }
412        }
413    }
414
415    Ok(())
416}
417
418// Tests live in a separate file to keep this module under 500 lines.
419#[cfg(test)]
420#[path = "learn_tests.rs"]
421mod tests;