1use std::path::{Path, PathBuf};
2
3use serde::Deserialize;
4
5use crate::error::Error;
6
/// On-disk shape of `config.toml`; only the `[learn]` table is read here.
#[derive(Deserialize)]
struct ConfigFile {
    // Absent `[learn]` table falls back to `LearnConfig::default()`
    // in `load_learn_config`.
    learn: Option<LearnConfig>,
}
15
/// Settings for the `oo learn` LLM call, read from the `[learn]` table of
/// config.toml or auto-detected from the environment when absent.
#[derive(Deserialize, Clone)]
pub struct LearnConfig {
    // Provider id dispatched on in `run_learn`:
    // "anthropic", "openai", or "cerebras".
    pub provider: String,
    // Model identifier, passed through verbatim to the provider API.
    pub model: String,
    // Name of the environment variable holding the API key; the key itself
    // is never stored in the config file.
    pub api_key_env: String,
}
22
impl Default for LearnConfig {
    /// Defaults are not fixed values: they are chosen by probing the
    /// environment for known provider API keys (see `detect_provider`).
    fn default() -> Self {
        detect_provider()
    }
}
28
29fn detect_provider() -> LearnConfig {
31 detect_provider_with(|key| std::env::var(key).ok())
32}
33
34fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
36 if env_lookup("ANTHROPIC_API_KEY").is_some() {
37 LearnConfig {
38 provider: "anthropic".into(),
39 model: "claude-haiku-4-5-20251001".into(),
40 api_key_env: "ANTHROPIC_API_KEY".into(),
41 }
42 } else if env_lookup("OPENAI_API_KEY").is_some() {
43 LearnConfig {
44 provider: "openai".into(),
45 model: "gpt-4o-mini".into(),
46 api_key_env: "OPENAI_API_KEY".into(),
47 }
48 } else if env_lookup("CEREBRAS_API_KEY").is_some() {
49 LearnConfig {
50 provider: "cerebras".into(),
51 model: "zai-glm-4.7".into(),
53 api_key_env: "CEREBRAS_API_KEY".into(),
54 }
55 } else {
56 LearnConfig {
58 provider: "anthropic".into(),
59 model: "claude-haiku-4-5-20251001".into(),
60 api_key_env: "ANTHROPIC_API_KEY".into(),
61 }
62 }
63}
64
65fn config_dir() -> PathBuf {
66 dirs::config_dir()
67 .unwrap_or_else(|| PathBuf::from("/tmp"))
68 .join("oo")
69}
70
71pub fn patterns_dir() -> PathBuf {
72 config_dir().join("patterns")
73}
74
75pub fn load_learn_config() -> Result<LearnConfig, Error> {
76 let path = config_dir().join("config.toml");
77 if !path.exists() {
78 return Ok(LearnConfig::default());
79 }
80 let content = std::fs::read_to_string(&path)
81 .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
82 let cf: ConfigFile =
83 toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
84 Ok(cf.learn.unwrap_or_default())
85}
86
/// System prompt for the pattern-generation request. The model is asked for
/// a bare TOML document (no fences); `strip_fences` defensively removes any
/// fences it adds anyway, and `validate_pattern_toml` checks the result
/// before it is saved.
const SYSTEM_PROMPT: &str = r#"You generate tool output classification patterns for a command runner.

Given a shell command, its stdout, stderr, and exit code, produce a TOML
pattern file that captures:

1. A regex to match this command (command_match)
2. For success (exit 0): a regex with named capture groups to extract a
   one-line summary, and a summary template using those groups
3. For failure (exit ≠ 0): a strategy to extract the actionable part of
   the output (tail N lines, head N lines, grep for pattern, or extract
   between markers)

Be aggressive about compression. A 1000-line passing test suite should
become "47 passed, 3.2s". A failing build should show only the first
error and its context, not the full cascade.

Respond with ONLY the TOML block. No explanation, no markdown fences."#;
108
109pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
111 let config = load_learn_config()?;
112
113 let api_key = std::env::var(&config.api_key_env).map_err(|_| {
114 Error::Learn(format!(
115 "Set {} environment variable to use `oo learn`",
116 config.api_key_env
117 ))
118 })?;
119
120 let user_msg = format!(
121 "Command: {command}\nExit code: {exit_code}\nOutput:\n{}",
122 truncate_for_prompt(output)
123 );
124
125 eprintln!(" [learning pattern for \"{}\"]", label(command));
126
127 let toml_response = match config.provider.as_str() {
128 "anthropic" => call_anthropic(
129 "https://api.anthropic.com/v1/messages",
130 &api_key,
131 &config.model,
132 &user_msg,
133 )?,
134 "openai" => call_openai(
135 "https://api.openai.com/v1/chat/completions",
136 &api_key,
137 &config.model,
138 &user_msg,
139 )?,
140 "cerebras" => call_openai(
141 "https://api.cerebras.ai/v1/chat/completions",
142 &api_key,
143 &config.model,
144 &user_msg,
145 )?,
146 other => return Err(Error::Learn(format!("unknown provider: {other}"))),
147 };
148
149 let toml_clean = strip_fences(&toml_response);
151
152 validate_pattern_toml(&toml_clean)?;
154
155 let dir = patterns_dir();
157 std::fs::create_dir_all(&dir).map_err(|e| Error::Learn(e.to_string()))?;
158 let filename = format!("{}.toml", label(command));
159 let path = dir.join(&filename);
160 std::fs::write(&path, &toml_clean).map_err(|e| Error::Learn(e.to_string()))?;
161
162 eprintln!(" [saved pattern to {}]", path.display());
163 Ok(())
164}
165
166pub fn spawn_background(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
168 let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
169
170 let tmp = std::env::temp_dir().join(format!("oo-learn-{}", std::process::id()));
172 let data = serde_json::json!({
173 "command": command,
174 "output": output,
175 "exit_code": exit_code,
176 });
177 std::fs::write(&tmp, data.to_string()).map_err(|e| Error::Learn(e.to_string()))?;
178
179 std::process::Command::new(exe)
181 .arg("_learn_bg")
182 .arg(&tmp)
183 .stdin(std::process::Stdio::null())
184 .stdout(std::process::Stdio::null())
185 .stderr(std::process::Stdio::null())
186 .spawn()
187 .map_err(|e| Error::Learn(e.to_string()))?;
188
189 Ok(())
190}
191
192pub fn run_background(data_path: &str) -> Result<(), Error> {
194 let path = Path::new(data_path);
195 let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
196 let data: serde_json::Value =
197 serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
198
199 let command = data["command"].as_str().unwrap_or("");
200 let output = data["output"].as_str().unwrap_or("");
201 let exit_code = data["exit_code"].as_i64().unwrap_or(0) as i32;
202
203 let result = run_learn(command, output, exit_code);
204
205 let _ = std::fs::remove_file(path);
207
208 result
209}
210
211fn call_anthropic(
216 base_url: &str,
217 api_key: &str,
218 model: &str,
219 user_msg: &str,
220) -> Result<String, Error> {
221 let body = serde_json::json!({
222 "model": model,
223 "max_tokens": 1024,
224 "system": SYSTEM_PROMPT,
225 "messages": [{"role": "user", "content": user_msg}],
226 });
227
228 let response: serde_json::Value = ureq::post(base_url)
229 .header("x-api-key", api_key)
230 .header("anthropic-version", "2023-06-01")
231 .header("content-type", "application/json")
232 .send_json(&body)
233 .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
234 .body_mut()
235 .read_json()
236 .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
237
238 response["content"][0]["text"]
239 .as_str()
240 .map(|s| s.to_string())
241 .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
242}
243
244fn call_openai(
245 base_url: &str,
246 api_key: &str,
247 model: &str,
248 user_msg: &str,
249) -> Result<String, Error> {
250 let body = serde_json::json!({
251 "model": model,
252 "messages": [
253 {"role": "system", "content": SYSTEM_PROMPT},
254 {"role": "user", "content": user_msg},
255 ],
256 });
257
258 let response: serde_json::Value = ureq::post(base_url)
259 .header("Authorization", &format!("Bearer {api_key}"))
260 .header("Content-Type", "application/json")
261 .send_json(&body)
262 .map_err(|e| Error::Learn(format!("OpenAI API error: {e}")))?
263 .body_mut()
264 .read_json()
265 .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
266
267 response["choices"][0]["message"]["content"]
268 .as_str()
269 .map(|s| s.to_string())
270 .ok_or_else(|| Error::Learn("unexpected OpenAI response format".into()))
271}
272
/// Derive a short label for a command: the basename of its first word.
///
/// The label doubles as the pattern filename stem, so it must never be
/// empty. `"cargo build"` -> `"cargo"`, `"/usr/bin/python3 -m x"` ->
/// `"python3"`, `""` -> `"unknown"`.
fn label(command: &str) -> String {
    let first_word = command.split_whitespace().next().unwrap_or("unknown");
    // Take the last non-empty path segment. The previous `rsplit().next()`
    // could never be None (dead unwrap_or) but yielded "" for a trailing
    // slash, which would have produced a hidden ".toml" pattern file.
    first_word
        .rsplit('/')
        .find(|seg| !seg.is_empty())
        .unwrap_or("unknown")
        .to_string()
}
287
/// Cap command output at 4000 bytes before embedding it in the LLM prompt,
/// truncating on a UTF-8 character boundary.
fn truncate_for_prompt(output: &str) -> &str {
    truncate_utf8(output, 4000)
}
291
/// Return the longest prefix of `s` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary (never splits a multi-byte char).
fn truncate_utf8(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Walk back from the byte limit to the nearest boundary; index 0 is
    // always a boundary, so a valid cut point always exists.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..cut]
}
303
/// Remove surrounding markdown code fences (```toml … ``` or ``` … ```)
/// that models sometimes add despite being told not to. Unfenced input is
/// returned trimmed and otherwise untouched.
fn strip_fences(s: &str) -> String {
    let trimmed = s.trim();
    // Try the tagged fence first so the "toml" tag is not left behind.
    let inner = trimmed
        .strip_prefix("```toml")
        .or_else(|| trimmed.strip_prefix("```"));
    match inner {
        Some(body) => body.strip_suffix("```").unwrap_or(body).trim().to_string(),
        None => trimmed.to_string(),
    }
}
314
315fn validate_pattern_toml(toml_str: &str) -> Result<(), Error> {
316 #[derive(Deserialize)]
318 struct Check {
319 command_match: String,
320 #[allow(dead_code)]
321 success: Option<SuccessCheck>,
322 #[allow(dead_code)]
323 failure: Option<serde_json::Value>,
324 }
325 #[derive(Deserialize)]
326 struct SuccessCheck {
327 pattern: String,
328 #[allow(dead_code)]
329 summary: String,
330 }
331
332 let check: Check =
333 toml::from_str(toml_str).map_err(|e| Error::Learn(format!("invalid TOML: {e}")))?;
334
335 regex::Regex::new(&check.command_match)
337 .map_err(|e| Error::Learn(format!("invalid command_match regex: {e}")))?;
338
339 if let Some(s) = &check.success {
340 regex::Regex::new(&s.pattern)
341 .map_err(|e| Error::Learn(format!("invalid success pattern regex: {e}")))?;
342 }
343
344 Ok(())
345}
346
347#[cfg(test)]
349#[path = "learn_tests.rs"]
350mod tests;