1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7
/// On-disk shape of `config.toml` under `config_dir()`; only the `[learn]`
/// table is consumed here.
#[derive(Deserialize)]
struct ConfigFile {
    // A missing `[learn]` section falls back to `LearnConfig::default()`.
    learn: Option<LearnConfig>,
}
16
/// LLM provider settings used by `oo learn`.
#[derive(Deserialize, Clone)]
pub struct LearnConfig {
    /// Provider id; `run_learn` dispatches on "anthropic", "openai", or "cerebras".
    pub provider: String,
    /// Model identifier sent verbatim to the provider API.
    pub model: String,
    /// Name of the environment variable that holds the API key (not the key itself).
    pub api_key_env: String,
}
23
impl Default for LearnConfig {
    /// Default is chosen by probing which provider API key is set in the
    /// environment (see `detect_provider`).
    fn default() -> Self {
        detect_provider()
    }
}
29
/// Probe the real process environment to pick a provider; the testable core
/// lives in `detect_provider_with`.
fn detect_provider() -> LearnConfig {
    detect_provider_with(|key| std::env::var(key).ok())
}
34
35fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
37 if env_lookup("ANTHROPIC_API_KEY").is_some() {
38 LearnConfig {
39 provider: "anthropic".into(),
40 model: "claude-haiku-4-5-20251001".into(),
41 api_key_env: "ANTHROPIC_API_KEY".into(),
42 }
43 } else if env_lookup("OPENAI_API_KEY").is_some() {
44 LearnConfig {
45 provider: "openai".into(),
46 model: "gpt-4o-mini".into(),
47 api_key_env: "OPENAI_API_KEY".into(),
48 }
49 } else if env_lookup("CEREBRAS_API_KEY").is_some() {
50 LearnConfig {
51 provider: "cerebras".into(),
52 model: "zai-glm-4.7".into(),
54 api_key_env: "CEREBRAS_API_KEY".into(),
55 }
56 } else {
57 LearnConfig {
59 provider: "anthropic".into(),
60 model: "claude-haiku-4-5-20251001".into(),
61 api_key_env: "ANTHROPIC_API_KEY".into(),
62 }
63 }
64}
65
66fn config_dir() -> PathBuf {
67 dirs::config_dir()
68 .unwrap_or_else(|| PathBuf::from("/tmp"))
69 .join("oo")
70}
71
/// Directory where learned pattern files are written (`<config_dir>/patterns`).
pub fn patterns_dir() -> PathBuf {
    config_dir().join("patterns")
}
75
76pub fn load_learn_config() -> Result<LearnConfig, Error> {
77 let path = config_dir().join("config.toml");
78 if !path.exists() {
79 return Ok(LearnConfig::default());
80 }
81 let content = std::fs::read_to_string(&path)
82 .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
83 let cf: ConfigFile =
84 toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
85 Ok(cf.learn.unwrap_or_default())
86}
87
/// System prompt for the learn call. The model must answer with a single TOML
/// pattern block; the response is later checked by `validate_pattern_toml`.
/// NOTE: this string is runtime behavior — editing it changes what the model produces.
const SYSTEM_PROMPT: &str = r#"You generate output classification patterns for `oo`, a shell command runner used by an LLM coding agent.

The agent reads your pattern to decide its next action. Returning nothing is the WORST outcome — an empty summary forces a costly recall cycle that wastes more tokens than a slightly verbose summary would.

## oo's 4-tier system

- Passthrough: output <4 KB passes through unchanged
- Failure: failed commands get ✗ prefix with filtered error output
- Success: successful commands get ✓ prefix with a pattern-extracted summary (your patterns target this tier)
- Large: if your regex fails to match, output falls through to this tier (FTS5 indexed for recall) — not catastrophic

## Output format

Respond with ONLY a TOML block. Fences optional.

    command_match = "^pytest"
    [success]
    pattern = '(?P<n>\d+) passed'
    summary = "{n} passed"
    [failure]
    strategy = "grep"
    grep = "error|Error|FAILED"

## Rules

- For build/test commands: compress aggressively (e.g. "47 passed, 3.2s" or "error: …first error only")
- For large tabular output (ls, docker ps, git log): omit the success section — let it fall through to Large tier (FTS5 indexed)
- A regex that's too broad is better than one that matches and returns empty"#;
120
/// Ask the configured LLM to generate a classification pattern for `command`'s
/// output, validate it, and save it to `patterns_dir()/<label>.toml`.
///
/// # Errors
/// `Error::Learn` when the API key env var is unset, the provider id is
/// unknown, the HTTP call fails, or the returned TOML/regexes are invalid;
/// config-loading errors propagate as `Error::Config`.
pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
    let config = load_learn_config()?;

    // The key is read from the env var *named* in the config, never stored in it.
    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
        Error::Learn(format!(
            "Set {} environment variable to use `oo learn`",
            config.api_key_env
        ))
    })?;

    // Truncate the captured output so a huge command dump can't blow the prompt.
    let user_msg = format!(
        "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
        truncate_for_prompt(output)
    );

    eprintln!(" [learning pattern for \"{}\"]", label(command));

    // Cerebras reuses the OpenAI chat-completions wire format, hence call_openai.
    let toml_response = match config.provider.as_str() {
        "anthropic" => call_anthropic(
            "https://api.anthropic.com/v1/messages",
            &api_key,
            &config.model,
            &user_msg,
        )?,
        "openai" => call_openai(
            "https://api.openai.com/v1/chat/completions",
            &api_key,
            &config.model,
            &user_msg,
        )?,
        "cerebras" => call_openai(
            "https://api.cerebras.ai/v1/chat/completions",
            &api_key,
            &config.model,
            &user_msg,
        )?,
        other => return Err(Error::Learn(format!("unknown provider: {other}"))),
    };

    // Models often wrap the TOML in a ``` fence despite instructions; strip it.
    let toml_clean = strip_fences(&toml_response);

    // Reject malformed TOML or non-compiling regexes before touching disk.
    validate_pattern_toml(&toml_clean)?;

    let dir = patterns_dir();
    std::fs::create_dir_all(&dir).map_err(|e| Error::Learn(e.to_string()))?;
    // One file per command basename (e.g. "pytest.toml"); re-learning overwrites it.
    let filename = format!("{}.toml", label(command));
    let path = dir.join(&filename);
    std::fs::write(&path, &toml_clean).map_err(|e| Error::Learn(e.to_string()))?;

    eprintln!(" [saved pattern to {}]", path.display());
    Ok(())
}
177
178pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
180 let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
181
182 let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
185 let data = serde_json::json!({
186 "command": command,
187 "output": output,
188 "exit_code": exit_code,
189 });
190 tmp.write_all(data.to_string().as_bytes())
191 .map_err(|e| Error::Learn(e.to_string()))?;
192
193 let tmp_path = tmp.into_temp_path();
196
197 std::process::Command::new(exe)
199 .arg("_learn_bg")
200 .arg(&tmp_path)
201 .stdin(std::process::Stdio::null())
202 .stdout(std::process::Stdio::null())
203 .stderr(std::process::Stdio::null())
204 .spawn()
205 .map_err(|e| Error::Learn(e.to_string()))?;
206
207 tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
212
213 Ok(())
214}
215
216pub fn run_background(data_path: &str) -> Result<(), Error> {
218 let path = Path::new(data_path);
219 let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
220 let data: serde_json::Value =
221 serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
222
223 let command = data["command"].as_str().unwrap_or("");
224 let output = data["output"].as_str().unwrap_or("");
225 let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
226
227 let result = run_learn(command, output, exit_code);
228
229 let _ = std::fs::remove_file(path);
231
232 result
233}
234
/// POST a single-turn request to the Anthropic Messages API and return the
/// text of the first content block.
///
/// `base_url` is a parameter rather than hard-coded — NOTE(review): presumably
/// so tests can point it at a stub server; confirm against the test file.
///
/// # Errors
/// `Error::Learn` on transport failure, unparseable JSON, or a response
/// missing `content[0].text`.
fn call_anthropic(
    base_url: &str,
    api_key: &str,
    model: &str,
    user_msg: &str,
) -> Result<String, Error> {
    let body = serde_json::json!({
        "model": model,
        "max_tokens": 1024,
        "system": SYSTEM_PROMPT,
        "messages": [{"role": "user", "content": user_msg}],
    });

    let response: serde_json::Value = ureq::post(base_url)
        .header("x-api-key", api_key)
        .header("anthropic-version", "2023-06-01")
        .header("content-type", "application/json")
        .send_json(&body)
        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
        .body_mut()
        .read_json()
        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;

    response["content"][0]["text"]
        .as_str()
        .map(|s| s.to_string())
        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
}
267
/// POST a chat-completions request in the OpenAI wire format and return the
/// first choice's message content. Also used for Cerebras, which shares this
/// API shape (see the dispatch in `run_learn`).
///
/// # Errors
/// `Error::Learn` on transport failure, unparseable JSON, or a response
/// missing `choices[0].message.content`.
fn call_openai(
    base_url: &str,
    api_key: &str,
    model: &str,
    user_msg: &str,
) -> Result<String, Error> {
    let body = serde_json::json!({
        "model": model,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ],
    });

    let response: serde_json::Value = ureq::post(base_url)
        .header("Authorization", &format!("Bearer {api_key}"))
        .header("Content-Type", "application/json")
        .send_json(&body)
        .map_err(|e| Error::Learn(format!("OpenAI API error: {e}")))?
        .body_mut()
        .read_json()
        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;

    response["choices"][0]["message"]["content"]
        .as_str()
        .map(|s| s.to_string())
        .ok_or_else(|| Error::Learn("unexpected OpenAI response format".into()))
}
296
/// Short label for a command line: the basename of its first token
/// (e.g. "/usr/bin/pytest -v" -> "pytest"), or "unknown" for an empty/blank
/// command. Used both for log messages and as the pattern file stem.
fn label(command: &str) -> String {
    let first_token = match command.split_whitespace().next() {
        Some(token) => token,
        None => "unknown",
    };
    // rsplit always yields at least one item, so this keeps the original
    // "unknown" fallback purely defensively.
    let basename = first_token.rsplit('/').next().unwrap_or("unknown");
    basename.to_string()
}
311
/// Cap captured command output at 4000 bytes before embedding it in the
/// LLM prompt (UTF-8 safe; see `truncate_utf8`).
fn truncate_for_prompt(output: &str) -> &str {
    truncate_utf8(output, 4000)
}
315
/// Truncate `s` to at most `max_bytes` bytes without splitting a UTF-8
/// character: the cut point is the largest char boundary <= `max_bytes`.
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Scan downward for a boundary; index 0 is always a boundary, so this
    // always finds one.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
327
/// Remove an optional Markdown code fence (```toml or bare ```) wrapping the
/// model's response, returning the trimmed inner text. Unfenced input is
/// just trimmed.
fn strip_fences(s: &str) -> String {
    let trimmed = s.trim();
    // Try the language-tagged fence first so "toml" isn't left in the body.
    let opened = trimmed
        .strip_prefix("```toml")
        .or_else(|| trimmed.strip_prefix("```"));
    match opened {
        Some(inner) => inner.strip_suffix("```").unwrap_or(inner).trim().to_string(),
        None => trimmed.to_string(),
    }
}
338
/// Validate a model-generated pattern before it is written to disk: the TOML
/// must parse into the expected shape and every regex must compile.
///
/// # Errors
/// `Error::Learn` describing the first problem found (bad TOML, bad
/// `command_match` regex, or bad `[success].pattern` regex).
fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
    // Minimal mirror of the pattern-file schema, used only for validation —
    // deserializing into it proves required keys exist with the right types.
    #[derive(Deserialize)]
    struct Check {
        command_match: String,
        #[allow(dead_code)]
        success: Option<SuccessCheck>,
        #[allow(dead_code)]
        failure: Option<serde_json::Value>,
    }
    #[derive(Deserialize)]
    struct SuccessCheck {
        pattern: String,
        #[allow(dead_code)]
        summary: String,
    }

    let check: Check =
        toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;

    // Compile the regexes now so a bad pattern fails here, not at match time.
    regex::Regex::new(&check.command_match)
        .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;

    if let Some(s) = &check.success {
        regex::Regex::new(&s.pattern)
            .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
    }

    Ok(())
}
373
// Unit tests live in a sibling file (learn_tests.rs) to keep this module focused.
#[cfg(test)]
#[path = "learn_tests.rs"]
mod tests;