// double_o/learn.rs
1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7
8// ---------------------------------------------------------------------------
9// Config
10// ---------------------------------------------------------------------------
11
// Top-level shape of the user's `config.toml`; only the `[learn]`
// table is consumed by this module.
#[derive(Deserialize)]
struct ConfigFile {
    // Optional `[learn]` section; absent means "use auto-detected defaults".
    learn: Option<LearnConfig>,
}
16
/// LLM provider settings used by `oo learn`.
#[derive(Deserialize, Clone)]
pub struct LearnConfig {
    /// Provider id: "anthropic", "openai", or "cerebras" (see `run_learn`).
    pub provider: String,
    /// Model identifier passed verbatim to the provider's API.
    pub model: String,
    /// Name of the environment variable that holds the API key.
    pub api_key_env: String,
}
23
impl Default for LearnConfig {
    // The default is derived from which API keys are present in the
    // process environment rather than being a fixed value.
    fn default() -> Self {
        detect_provider()
    }
}
29
// Auto-detect provider from available API keys (checked in priority order).
// Thin wrapper binding the real env lookup; logic lives in
// `detect_provider_with` so tests can inject a fake environment.
fn detect_provider() -> LearnConfig {
    detect_provider_with(|key| std::env::var(key).ok())
}
34
35// Testable variant — accepts a closure for env lookup to avoid env mutation in tests.
36fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
37    if env_lookup("ANTHROPIC_API_KEY").is_some() {
38        LearnConfig {
39            provider: "anthropic".into(),
40            model: "claude-haiku-4-5-20251001".into(),
41            api_key_env: "ANTHROPIC_API_KEY".into(),
42        }
43    } else if env_lookup("OPENAI_API_KEY").is_some() {
44        LearnConfig {
45            provider: "openai".into(),
46            model: "gpt-4o-mini".into(),
47            api_key_env: "OPENAI_API_KEY".into(),
48        }
49    } else if env_lookup("CEREBRAS_API_KEY").is_some() {
50        LearnConfig {
51            provider: "cerebras".into(),
52            // Cerebras model ID: check https://cloud.cerebras.ai/models for current model catalog
53            model: "zai-glm-4.7".into(),
54            api_key_env: "CEREBRAS_API_KEY".into(),
55        }
56    } else {
57        // Default to anthropic; will fail at runtime if no key is set.
58        LearnConfig {
59            provider: "anthropic".into(),
60            model: "claude-haiku-4-5-20251001".into(),
61            api_key_env: "ANTHROPIC_API_KEY".into(),
62        }
63    }
64}
65
66fn config_dir() -> PathBuf {
67    dirs::config_dir()
68        .unwrap_or_else(|| PathBuf::from("/tmp"))
69        .join("oo")
70}
71
/// Directory where learned pattern TOML files are stored
/// (`<config dir>/oo/patterns`).
pub fn patterns_dir() -> PathBuf {
    config_dir().join("patterns")
}
75
76pub fn load_learn_config() -> Result<LearnConfig, Error> {
77    let path = config_dir().join("config.toml");
78    if !path.exists() {
79        return Ok(LearnConfig::default());
80    }
81    let content = std::fs::read_to_string(&path)
82        .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
83    let cf: ConfigFile =
84        toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
85    Ok(cf.learn.unwrap_or_default())
86}
87
88// ---------------------------------------------------------------------------
89// Background learning
90// ---------------------------------------------------------------------------
91
// System prompt sent to the LLM on every learn call. It instructs the model
// to emit a single TOML pattern block; the expected schema matches what
// `validate_pattern_toml` checks below. Keep the text in sync with that
// validator when changing the format.
const SYSTEM_PROMPT: &str = r#"You generate output classification patterns for `oo`, a shell command runner used by an LLM coding agent.

The agent reads your pattern to decide its next action. Returning nothing is the WORST outcome — an empty summary forces a costly recall cycle that wastes more tokens than a slightly verbose summary would.

## oo's 4-tier system

- Passthrough: output <4 KB passes through unchanged
- Failure: failed commands get ✗ prefix with filtered error output
- Success: successful commands get ✓ prefix with a pattern-extracted summary (your patterns target this tier)
- Large: if your regex fails to match, output falls through to this tier (FTS5 indexed for recall) — not catastrophic

## Output format

Respond with ONLY a TOML block. Fences optional.

    command_match = "^pytest"
    [success]
    pattern = '(?P<n>\d+) passed'
    summary = "{n} passed"
    [failure]
    strategy = "grep"
    grep = "error|Error|FAILED"

## Rules

- For build/test commands: compress aggressively (e.g. "47 passed, 3.2s" or "error: …first error only")
- For large tabular output (ls, docker ps, git log): omit the success section — let it fall through to Large tier (FTS5 indexed)
- A regex that's too broad is better than one that matches and returns empty"#;
120
121/// Run the learn flow: call LLM, validate + save pattern.
122pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
123    let config = load_learn_config()?;
124
125    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
126        Error::Learn(format!(
127            "Set {} environment variable to use `oo learn`",
128            config.api_key_env
129        ))
130    })?;
131
132    let user_msg = format!(
133        "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
134        truncate_for_prompt(output)
135    );
136
137    eprintln!("  [learning pattern for \"{}\"]", label(command));
138
139    let toml_response = match config.provider.as_str() {
140        "anthropic" => call_anthropic(
141            "https://api.anthropic.com/v1/messages",
142            &api_key,
143            &config.model,
144            &user_msg,
145        )?,
146        "openai" => call_openai(
147            "https://api.openai.com/v1/chat/completions",
148            &api_key,
149            &config.model,
150            &user_msg,
151        )?,
152        "cerebras" => call_openai(
153            "https://api.cerebras.ai/v1/chat/completions",
154            &api_key,
155            &config.model,
156            &user_msg,
157        )?,
158        other => return Err(Error::Learn(format!("unknown provider: {other}"))),
159    };
160
161    // Strip markdown fences if present
162    let toml_clean = strip_fences(&toml_response);
163
164    // Validate: parse as pattern
165    validate_pattern_toml(&toml_clean)?;
166
167    // Save
168    let dir = patterns_dir();
169    std::fs::create_dir_all(&dir).map_err(|e| Error::Learn(e.to_string()))?;
170    let filename = format!("{}.toml", label(command));
171    let path = dir.join(&filename);
172    std::fs::write(&path, &toml_clean).map_err(|e| Error::Learn(e.to_string()))?;
173
174    eprintln!("  [saved pattern to {}]", path.display());
175    Ok(())
176}
177
178/// Spawn the learning process in the background by re-exec'ing ourselves.
179pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
180    let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
181
182    // Use a secure named temp file to avoid PID-based predictable filenames
183    // (symlink/TOCTOU attacks). The file is kept alive until the child spawns.
184    let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
185    let data = serde_json::json!({
186        "command": command,
187        "output": output,
188        "exit_code": exit_code,
189    });
190    tmp.write_all(data.to_string().as_bytes())
191        .map_err(|e| Error::Learn(e.to_string()))?;
192
193    // Convert to TempPath: closes the file handle but keeps the file on disk
194    // until the TempPath is dropped — after the child has been spawned.
195    let tmp_path = tmp.into_temp_path();
196
197    // Spawn detached child
198    std::process::Command::new(exe)
199        .arg("_learn_bg")
200        .arg(&tmp_path)
201        .stdin(std::process::Stdio::null())
202        .stdout(std::process::Stdio::null())
203        .stderr(std::process::Stdio::null())
204        .spawn()
205        .map_err(|e| Error::Learn(e.to_string()))?;
206
207    // Prevent the parent from deleting the temp file on drop. On a loaded
208    // system the child process may not have opened the file yet by the time
209    // the parent exits this function. `keep()` makes the file persist on disk
210    // until the child cleans it up at run_background (line ~218 below).
211    tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
212
213    Ok(())
214}
215
216/// Entry point for the background learn child process.
217pub fn run_background(data_path: &str) -> Result<(), Error> {
218    let path = Path::new(data_path);
219    let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
220    let data: serde_json::Value =
221        serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
222
223    let command = data["command"].as_str().unwrap_or("");
224    let output = data["output"].as_str().unwrap_or("");
225    let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
226
227    let result = run_learn(command, output, exit_code);
228
229    // Clean up temp file
230    let _ = std::fs::remove_file(path);
231
232    result
233}
234
235// ---------------------------------------------------------------------------
236// LLM API calls
237// ---------------------------------------------------------------------------
238
239fn call_anthropic(
240    base_url: &str,
241    api_key: &str,
242    model: &str,
243    user_msg: &str,
244) -> Result<String, Error> {
245    let body = serde_json::json!({
246        "model": model,
247        "max_tokens": 1024,
248        "system": SYSTEM_PROMPT,
249        "messages": [{"role": "user", "content": user_msg}],
250    });
251
252    let response: serde_json::Value = ureq::post(base_url)
253        .header("x-api-key", api_key)
254        .header("anthropic-version", "2023-06-01")
255        .header("content-type", "application/json")
256        .send_json(&body)
257        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
258        .body_mut()
259        .read_json()
260        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
261
262    response["content"][0]["text"]
263        .as_str()
264        .map(|s| s.to_string())
265        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
266}
267
268fn call_openai(
269    base_url: &str,
270    api_key: &str,
271    model: &str,
272    user_msg: &str,
273) -> Result<String, Error> {
274    let body = serde_json::json!({
275        "model": model,
276        "messages": [
277            {"role": "system", "content": SYSTEM_PROMPT},
278            {"role": "user", "content": user_msg},
279        ],
280    });
281
282    let response: serde_json::Value = ureq::post(base_url)
283        .header("Authorization", &format!("Bearer {api_key}"))
284        .header("Content-Type", "application/json")
285        .send_json(&body)
286        .map_err(|e| Error::Learn(format!("OpenAI API error: {e}")))?
287        .body_mut()
288        .read_json()
289        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
290
291    response["choices"][0]["message"]["content"]
292        .as_str()
293        .map(|s| s.to_string())
294        .ok_or_else(|| Error::Learn("unexpected OpenAI response format".into()))
295}
296
297// ---------------------------------------------------------------------------
298// Helpers
299// ---------------------------------------------------------------------------
300
/// Short label for a command line: the basename of its first
/// whitespace-separated token (e.g. "/usr/bin/pytest -v" -> "pytest").
///
/// The result names the saved pattern file (`<label>.toml` in
/// `run_learn`), so it must never be empty: empty/whitespace-only
/// commands and first tokens that end in '/' fall back to "unknown".
fn label(command: &str) -> String {
    let base = command
        .split_whitespace()
        .next()
        .unwrap_or("unknown")
        .rsplit('/')
        .next()
        .unwrap_or("unknown");
    // A token like "foo/" has an empty basename; without this guard the
    // pattern file would be a hidden ".toml".
    if base.is_empty() {
        "unknown".to_string()
    } else {
        base.to_string()
    }
}
311
// Cap the command output included in the LLM prompt at 4000 bytes,
// truncated at a UTF-8 char boundary (see `truncate_utf8`).
fn truncate_for_prompt(output: &str) -> &str {
    truncate_utf8(output, 4000)
}
315
// Truncate at a char boundary to avoid panics on multibyte UTF-8 sequences.
// Returns `s` unchanged when it already fits; otherwise the longest prefix
// of at most `max_bytes` bytes that ends on a character boundary.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Largest index <= max_bytes that lands on a char boundary; index 0 is
    // always a boundary, so the fallback is unreachable in practice.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
327
// Remove an optional surrounding markdown fence ("```toml" or bare "```")
// and return the trimmed inner text.
fn strip_fences(s: &str) -> String {
    let trimmed = s.trim();
    // Prefer the language-tagged fence so "```toml" isn't left with "toml"
    // stuck to the payload by the bare "```" branch.
    let inner = trimmed
        .strip_prefix("```toml")
        .or_else(|| trimmed.strip_prefix("```"))
        .map(|rest| rest.strip_suffix("```").unwrap_or(rest))
        .unwrap_or(trimmed);
    inner.trim().to_string()
}
338
339fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
340    // Try to parse as our pattern format
341    #[derive(Deserialize)]
342    struct Check {
343        command_match: String,
344        // Deserialization target: field must exist for TOML parsing even if not read in code
345        #[allow(dead_code)]
346        success: Option<SuccessCheck>,
347        // Deserialization target: field must exist for TOML parsing even if not read in code
348        #[allow(dead_code)]
349        failure: Option<serde_json::Value>,
350    }
351    #[derive(Deserialize)]
352    struct SuccessCheck {
353        pattern: String,
354        // Deserialization target: field must exist for TOML parsing even if not read in code
355        #[allow(dead_code)]
356        summary: String,
357    }
358
359    let check: Check =
360        toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;
361
362    // Verify regexes compile
363    regex::Regex::new(&check.command_match)
364        .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;
365
366    if let Some(s) = &check.success {
367        regex::Regex::new(&s.pattern)
368            .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
369    }
370
371    Ok(())
372}
373
// Unit tests live in a sibling file to keep this module under 500 lines;
// `#[path]` pulls learn_tests.rs in as the `tests` submodule.
#[cfg(test)]
#[path = "learn_tests.rs"]
mod tests;