Skip to main content

double_o/
learn.rs

1use std::path::{Path, PathBuf};
2
3use serde::Deserialize;
4
5use crate::error::Error;
6
7// ---------------------------------------------------------------------------
8// Config
9// ---------------------------------------------------------------------------
10
// Top-level shape of the config file parsed by `load_learn_config`.
// Only the optional `learn` table is read from it here.
#[derive(Deserialize)]
struct ConfigFile {
    // `None` when the file has no `learn` table; the loader then falls back
    // to `LearnConfig::default()`.
    learn: Option<LearnConfig>,
}
15
/// Settings for the learn flow: which LLM provider to call, which model to
/// request, and which environment variable holds the API key.
#[derive(Deserialize, Clone)]
pub struct LearnConfig {
    /// Provider name; `run_learn` dispatches on "anthropic", "openai" and
    /// "cerebras" and rejects anything else.
    pub provider: String,
    /// Model identifier placed in the request body.
    pub model: String,
    /// Name of the environment variable the API key is read from
    /// (the variable's name, not the key itself).
    pub api_key_env: String,
}
22
// The default configuration is not a fixed value: it is derived from the
// process environment via `detect_provider`.
impl Default for LearnConfig {
    fn default() -> Self {
        detect_provider()
    }
}
28
29// Auto-detect provider from available API keys (checked in priority order).
30fn detect_provider() -> LearnConfig {
31    detect_provider_with(|key| std::env::var(key).ok())
32}
33
34// Testable variant — accepts a closure for env lookup to avoid env mutation in tests.
35fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
36    if env_lookup("ANTHROPIC_API_KEY").is_some() {
37        LearnConfig {
38            provider: "anthropic".into(),
39            model: "claude-haiku-4-5-20251001".into(),
40            api_key_env: "ANTHROPIC_API_KEY".into(),
41        }
42    } else if env_lookup("OPENAI_API_KEY").is_some() {
43        LearnConfig {
44            provider: "openai".into(),
45            model: "gpt-4o-mini".into(),
46            api_key_env: "OPENAI_API_KEY".into(),
47        }
48    } else if env_lookup("CEREBRAS_API_KEY").is_some() {
49        LearnConfig {
50            provider: "cerebras".into(),
51            // Cerebras model ID: check https://cloud.cerebras.ai/models for current model catalog
52            model: "zai-glm-4.7".into(),
53            api_key_env: "CEREBRAS_API_KEY".into(),
54        }
55    } else {
56        // Default to anthropic; will fail at runtime if no key is set.
57        LearnConfig {
58            provider: "anthropic".into(),
59            model: "claude-haiku-4-5-20251001".into(),
60            api_key_env: "ANTHROPIC_API_KEY".into(),
61        }
62    }
63}
64
65fn config_dir() -> PathBuf {
66    dirs::config_dir()
67        .unwrap_or_else(|| PathBuf::from("/tmp"))
68        .join("oo")
69}
70
71pub fn patterns_dir() -> PathBuf {
72    config_dir().join("patterns")
73}
74
75pub fn load_learn_config() -> Result<LearnConfig, Error> {
76    let path = config_dir().join("config.toml");
77    if !path.exists() {
78        return Ok(LearnConfig::default());
79    }
80    let content = std::fs::read_to_string(&path)
81        .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
82    let cf: ConfigFile =
83        toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
84    Ok(cf.learn.unwrap_or_default())
85}
86
87// ---------------------------------------------------------------------------
88// Background learning
89// ---------------------------------------------------------------------------
90
// Instruction text sent with every learn request: as the `system` field for
// the Anthropic call and as the `system`-role message for the OpenAI-style
// call. It asks the model to answer with a bare TOML pattern.
const SYSTEM_PROMPT: &str = r#"You generate tool output classification patterns for a command runner.

Given a shell command, its stdout, stderr, and exit code, produce a TOML
pattern file that captures:

1. A regex to match this command (command_match)
2. For success (exit 0): a regex with named capture groups to extract a
   one-line summary, and a summary template using those groups
3. For failure (exit ≠ 0): a strategy to extract the actionable part of
   the output (tail N lines, head N lines, grep for pattern, or extract
   between markers)

Be aggressive about compression. A 1000-line passing test suite should
become "47 passed, 3.2s". A failing build should show only the first
error and its context, not the full cascade.

Respond with ONLY the TOML block. No explanation, no markdown fences."#;
108
109/// Run the learn flow: call LLM, validate + save pattern.
110pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
111    let config = load_learn_config()?;
112
113    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
114        Error::Learn(format!(
115            "Set {} environment variable to use `oo learn`",
116            config.api_key_env
117        ))
118    })?;
119
120    let user_msg = format!(
121        "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
122        truncate_for_prompt(output)
123    );
124
125    eprintln!("  [learning pattern for \"{}\"]", label(command));
126
127    let toml_response = match config.provider.as_str() {
128        "anthropic" => call_anthropic(
129            "https://api.anthropic.com/v1/messages",
130            &api_key,
131            &config.model,
132            &user_msg,
133        )?,
134        "openai" => call_openai(
135            "https://api.openai.com/v1/chat/completions",
136            &api_key,
137            &config.model,
138            &user_msg,
139        )?,
140        "cerebras" => call_openai(
141            "https://api.cerebras.ai/v1/chat/completions",
142            &api_key,
143            &config.model,
144            &user_msg,
145        )?,
146        other => return Err(Error::Learn(format!("unknown provider: {other}"))),
147    };
148
149    // Strip markdown fences if present
150    let toml_clean = strip_fences(&toml_response);
151
152    // Validate: parse as pattern
153    validate_pattern_toml(&toml_clean)?;
154
155    // Save
156    let dir = patterns_dir();
157    std::fs::create_dir_all(&dir).map_err(|e| Error::Learn(e.to_string()))?;
158    let filename = format!("{}.toml", label(command));
159    let path = dir.join(&filename);
160    std::fs::write(&path, &toml_clean).map_err(|e| Error::Learn(e.to_string()))?;
161
162    eprintln!("  [saved pattern to {}]", path.display());
163    Ok(())
164}
165
166/// Spawn the learning process in the background by re-exec'ing ourselves.
167pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
168    let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
169
170    // Write data to a temp file for the child to read
171    let tmp = std::env::temp_dir().join(format!("oo-learn-{}", std::process::id()));
172    let data = serde_json::json!({
173        "command": command,
174        "output": output,
175        "exit_code": exit_code,
176    });
177    std::fs::write(&tmp, data.to_string()).map_err(|e| Error::Learn(e.to_string()))?;
178
179    // Spawn detached child
180    std::process::Command::new(exe)
181        .arg("_learn_bg")
182        .arg(&tmp)
183        .stdin(std::process::Stdio::null())
184        .stdout(std::process::Stdio::null())
185        .stderr(std::process::Stdio::null())
186        .spawn()
187        .map_err(|e| Error::Learn(e.to_string()))?;
188
189    Ok(())
190}
191
192/// Entry point for the background learn child process.
193pub fn run_background(data_path: &str) -> Result<(), Error> {
194    let path = Path::new(data_path);
195    let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
196    let data: serde_json::Value =
197        serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
198
199    let command = data["command"].as_str().unwrap_or("");
200    let output = data["output"].as_str().unwrap_or("");
201    let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
202
203    let result = run_learn(command, output, exit_code);
204
205    // Clean up temp file
206    let _ = std::fs::remove_file(path);
207
208    result
209}
210
211// ---------------------------------------------------------------------------
212// LLM API calls
213// ---------------------------------------------------------------------------
214
215fn call_anthropic(
216    base_url: &str,
217    api_key: &str,
218    model: &str,
219    user_msg: &str,
220) -> Result<String, Error> {
221    let body = serde_json::json!({
222        "model": model,
223        "max_tokens": 1024,
224        "system": SYSTEM_PROMPT,
225        "messages": [{"role": "user", "content": user_msg}],
226    });
227
228    let response: serde_json::Value = ureq::post(base_url)
229        .header("x-api-key", api_key)
230        .header("anthropic-version", "2023-06-01")
231        .header("content-type", "application/json")
232        .send_json(&body)
233        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
234        .body_mut()
235        .read_json()
236        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
237
238    response["content"][0]["text"]
239        .as_str()
240        .map(|s| s.to_string())
241        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
242}
243
244fn call_openai(
245    base_url: &str,
246    api_key: &str,
247    model: &str,
248    user_msg: &str,
249) -> Result<String, Error> {
250    let body = serde_json::json!({
251        "model": model,
252        "messages": [
253            {"role": "system", "content": SYSTEM_PROMPT},
254            {"role": "user", "content": user_msg},
255        ],
256    });
257
258    let response: serde_json::Value = ureq::post(base_url)
259        .header("Authorization", &format!("Bearer {api_key}"))
260        .header("Content-Type", "application/json")
261        .send_json(&body)
262        .map_err(|e| Error::Learn(format!("OpenAI API error: {e}")))?
263        .body_mut()
264        .read_json()
265        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
266
267    response["choices"][0]["message"]["content"]
268        .as_str()
269        .map(|s| s.to_string())
270        .ok_or_else(|| Error::Learn("unexpected OpenAI response format".into()))
271}
272
273// ---------------------------------------------------------------------------
274// Helpers
275// ---------------------------------------------------------------------------
276
// Short name for a command: the last non-empty `/`-separated segment of its
// first whitespace-separated token (e.g. "/usr/bin/cargo test" -> "cargo").
//
// Returns "unknown" when the command is blank or the first token consists
// only of slashes. Trailing slashes are ignored, so the result is never
// empty; callers use it as a file stem.
fn label(command: &str) -> String {
    command
        .split_whitespace()
        .next()
        .and_then(|program| program.rsplit('/').find(|segment| !segment.is_empty()))
        .unwrap_or("unknown")
        .to_string()
}
287
288fn truncate_for_prompt(output: &str) -> &str {
289    truncate_utf8(output, 4000)
290}
291
// Return the longest prefix of `s` that is at most `max_bytes` bytes long and
// ends on a char boundary, so slicing never panics on multibyte UTF-8.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if max_bytes >= s.len() {
        return s;
    }
    // Walk back from the byte limit to the nearest boundary; index 0 is
    // always a boundary, so the search cannot come up empty.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
303
// Remove a surrounding markdown code fence from `s`, if there is one.
//
// The input is trimmed; an opening "```toml" or bare "```" is dropped, then a
// closing "```" if present, and the remainder is trimmed again. Input without
// an opening fence is returned trimmed but otherwise unchanged.
fn strip_fences(s: &str) -> String {
    let text = s.trim();
    // Try the tagged fence first so "toml" is not left behind as content.
    let after_open = text
        .strip_prefix("```toml")
        .or_else(|| text.strip_prefix("```"));
    match after_open {
        Some(inner) => inner
            .strip_suffix("```")
            .unwrap_or(inner)
            .trim()
            .to_string(),
        None => text.to_string(),
    }
}
314
315fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
316    // Try to parse as our pattern format
317    #[derive(Deserialize)]
318    struct Check {
319        command_match: String,
320        #[allow(dead_code)]
321        success: Option<SuccessCheck>,
322        #[allow(dead_code)]
323        failure: Option<serde_json::Value>,
324    }
325    #[derive(Deserialize)]
326    struct SuccessCheck {
327        pattern: String,
328        #[allow(dead_code)]
329        summary: String,
330    }
331
332    let check: Check =
333        toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;
334
335    // Verify regexes compile
336    regex::Regex::new(&check.command_match)
337        .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;
338
339    if let Some(s) = &check.success {
340        regex::Regex::new(&s.pattern)
341            .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
342    }
343
344    Ok(())
345}
346
347// Tests live in a separate file to keep this module under 500 lines.
348#[cfg(test)]
349#[path = "learn_tests.rs"]
350mod tests;