Skip to main content

double_o/
learn.rs

1use std::io::Write;
2use std::path::{Path, PathBuf};
3
4use serde::Deserialize;
5
6use crate::error::Error;
7pub use crate::learn_prompt::SYSTEM_PROMPT;
8
9// ---------------------------------------------------------------------------
10// Config & Validation limits
11// ---------------------------------------------------------------------------
12
/// Maximum allowed length for hint text to prevent payload bloat.
///
/// Compared against `str::len()` in `run_learn_with_config`, so the limit is
/// applied to the hint's UTF-8 byte length.
const MAX_HINT_LENGTH: usize = 1000;

/// Maximum allowed length for command text used in LLM prompt.
///
/// Passed as the limit argument to `learn_utils::truncate_utf8`; the exact
/// unit is whatever that helper defines.
const MAX_COMMAND_LENGTH: usize = 100;

/// Maximum allowed length for a single filename component after sanitization.
///
/// Applied in `label` via `.take(..)` on the filtered characters. Only ASCII
/// alphanumerics and `-` survive that filter, so characters and bytes coincide.
const MAX_FILENAME_COMPONENT: usize = 50;
21
/// Shape of `config.toml` as deserialized by `load_learn_config`.
///
/// Only the optional `learn` table is modelled here.
#[derive(Deserialize)]
struct ConfigFile {
    /// Contents of the `learn` table; `None` when the table is absent, in
    /// which case `load_learn_config` falls back to `LearnConfig::default()`.
    learn: Option<LearnConfig>,
}
26
27/// Configuration for the `oo learn` LLM integration.
28///
29/// Specifies the LLM provider, model, and environment variable for the API key.
30#[derive(Deserialize, Clone)]
31pub struct LearnConfig {
32    /// LLM provider name (currently only "anthropic" is supported).
33    pub provider: String,
34
35    /// Model identifier (e.g., "claude-haiku-4-5").
36    pub model: String,
37
38    /// Environment variable containing the API key.
39    pub api_key_env: String,
40}
41
/// Testable variant of learn config and paths — avoids env var mutation.
///
/// Bundles everything `run_learn_with_config` needs so that callers (and
/// tests) can supply explicit values instead of reading the environment.
pub(crate) struct LearnParams<'a> {
    /// Provider and model selection.
    pub config: &'a LearnConfig,
    /// API key value sent to the provider.
    pub api_key: &'a str,
    /// Endpoint URL the request is POSTed to.
    pub base_url: &'a str,
    /// Directory the generated `<label>.toml` pattern is written into
    /// (created on demand).
    pub patterns_dir: &'a Path,
    /// File handed to `commands::write_learn_status` after a successful save.
    pub learn_status_path: &'a Path,
    /// Optional free-form hint added to the prompt; rejected when longer than
    /// `MAX_HINT_LENGTH`.
    pub hint: Option<&'a str>,
}
51
/// Typed struct for the JSON data passed to the background learn process.
///
/// Mirrors the object `spawn_background` writes to the temp file and
/// `run_background` reads back.
#[derive(Deserialize)]
struct LearnData {
    /// The command line that was executed.
    command: String,
    /// Captured output of that command.
    output: String,
    /// Exit status; carried as `i64` in JSON and range-checked down to `i32`
    /// by `run_background`.
    exit_code: i64,
    /// Optional hint; the key is only written when a hint was supplied.
    hint: Option<String>,
}
60
/// The default configuration is not a fixed value: it is produced by
/// `detect_provider`, which inspects the process environment.
impl Default for LearnConfig {
    fn default() -> Self {
        detect_provider()
    }
}
66
67// Auto-detect provider from available API keys (checked in priority order).
68fn detect_provider() -> LearnConfig {
69    detect_provider_with(|key| std::env::var(key).ok())
70}
71
72// Testable variant — accepts a closure for env lookup to avoid env mutation in tests.
73fn detect_provider_with<F: Fn(&str) -> Option<String>>(env_lookup: F) -> LearnConfig {
74    if env_lookup("ANTHROPIC_API_KEY").is_some() {
75        LearnConfig {
76            provider: "anthropic".into(),
77            model: "claude-haiku-4-5".into(),
78            api_key_env: "ANTHROPIC_API_KEY".into(),
79        }
80    } else {
81        // Default to anthropic; will fail at runtime if no key is set.
82        LearnConfig {
83            provider: "anthropic".into(),
84            model: "claude-haiku-4-5".into(),
85            api_key_env: "ANTHROPIC_API_KEY".into(),
86        }
87    }
88}
89
90fn config_dir() -> PathBuf {
91    if let Some(test_dir) = std::env::var_os("OO_CONFIG_DIR") {
92        return PathBuf::from(test_dir);
93    }
94    dirs::config_dir()
95        .unwrap_or_else(|| PathBuf::from("/tmp"))
96        .join("oo")
97}
98
99/// Get the directory containing user-defined patterns.
100///
101/// Returns `~/.config/oo/patterns` or the overridden `OO_CONFIG_DIR/patterns`.
102pub fn patterns_dir() -> PathBuf {
103    config_dir().join("patterns")
104}
105
106/// Path to the one-line status file written by the background learn process.
107pub fn learn_status_path() -> PathBuf {
108    config_dir().join("learn-status.log")
109}
110
111/// Load learn configuration from `~/.config/oo/config.toml`.
112///
113/// Returns the default configuration if the file doesn't exist.
114pub fn load_learn_config() -> Result<LearnConfig, Error> {
115    let path = config_dir().join("config.toml");
116    if !path.exists() {
117        return Ok(LearnConfig::default());
118    }
119    let content = std::fs::read_to_string(&path)
120        .map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
121    let cf: ConfigFile =
122        toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
123    Ok(cf.learn.unwrap_or_default())
124}
125
// ---------------------------------------------------------------------------
// Background learning
// ---------------------------------------------------------------------------

128/// Run the learn flow with explicit config and base URL — testable variant.
129///
130/// This internal function bypasses `load_learn_config()` and env var lookup,
131/// making it suitable for testing without environment mutation.
132pub(crate) fn run_learn_with_config(
133    params: &LearnParams,
134    command: &str,
135    output: &str,
136    exit_code: i32,
137) -> Result<(), Error> {
138    let hint = match params.hint {
139        Some(h) if h.len() > MAX_HINT_LENGTH => {
140            return Err(Error::Learn(format!(
141                "--hint too long ({} > {} chars)",
142                h.len(),
143                MAX_HINT_LENGTH
144            )));
145        }
146        h => h,
147    };
148
149    let truncated_command = crate::learn_utils::truncate_utf8(command, MAX_COMMAND_LENGTH);
150
151    let user_msg = if let Some(h) = hint {
152        format!(
153            "Command: {truncated_command}\nExit code: {exit_code}\nHint: {h}\nOutput:\n{}",
154            truncate_for_prompt(output)
155        )
156    } else {
157        format!(
158            "Command: {truncated_command}\nExit code: {exit_code}\nOutput:\n{}",
159            truncate_for_prompt(output)
160        )
161    };
162
163    let get_response = |msg: &str| -> Result<String, Error> {
164        match params.config.provider.as_str() {
165            "anthropic" => {
166                call_anthropic(params.base_url, params.api_key, &params.config.model, msg)
167            }
168            other => Err(Error::Learn(format!("unknown provider: {other}"))),
169        }
170    };
171
172    // First attempt
173    let mut last_err;
174    let toml = get_response(&user_msg)?;
175    let clean = crate::learn_utils::strip_fences(&toml);
176    if validate_pattern_toml_with_limits(&clean).is_ok() {
177        std::fs::create_dir_all(params.patterns_dir).map_err(|e| Error::Learn(e.to_string()))?;
178        let filename = format!("{}.toml", label(command));
179        let path = params.patterns_dir.join(&filename);
180        std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
181        let _ =
182            crate::commands::write_learn_status(params.learn_status_path, &label(command), &path);
183        return Ok(());
184    }
185    last_err = "initial TOML validation failed".to_string();
186
187    // Up to 2 retries
188    for _ in 0..2 {
189        let retry_msg = format!(
190            "Your previous TOML was invalid: {last_err}. Here is what you returned:\n{clean}\nOutput ONLY the corrected TOML, nothing else."
191        );
192        let toml = get_response(&retry_msg)?;
193        let clean = crate::learn_utils::strip_fences(&toml);
194        if validate_pattern_toml_with_limits(&clean).is_ok() {
195            std::fs::create_dir_all(params.patterns_dir)
196                .map_err(|e| Error::Learn(e.to_string()))?;
197            let filename = format!("{}.toml", label(command));
198            let path = params.patterns_dir.join(&filename);
199            std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
200            let _ = crate::commands::write_learn_status(
201                params.learn_status_path,
202                &label(command),
203                &path,
204            );
205            return Ok(());
206        }
207        last_err = "retry TOML validation failed".to_string();
208    }
209
210    Err(Error::Learn(format!("failed after 3 attempts: {last_err}")))
211}
212
213/// Run the learn flow: call LLM, validate + save pattern.
214///
215/// Loads configuration from environment, calls the LLM to generate a pattern,
216/// validates the result, and saves the pattern to disk. Retries up to 2 times
217/// if the LLM returns invalid TOML.
218pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
219    run_learn_with_hint(command, output, exit_code, None)
220}
221
222/// Internal variant of run_learn that accepts an optional hint.
223///
224/// Used by run_background to pass through agent-provided hints.
225fn run_learn_with_hint(
226    command: &str,
227    output: &str,
228    exit_code: i32,
229    hint: Option<&str>,
230) -> Result<(), Error> {
231    let config = load_learn_config()?;
232
233    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
234        Error::Learn(format!(
235            "Set {} environment variable to use `oo learn`",
236            config.api_key_env
237        ))
238    })?;
239
240    let base_url = std::env::var("ANTHROPIC_API_URL")
241        .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
242
243    validate_anthropic_url(&base_url)?;
244
245    let params = LearnParams {
246        config: &config,
247        api_key: &api_key,
248        base_url: &base_url,
249        patterns_dir: &patterns_dir(),
250        learn_status_path: &learn_status_path(),
251        hint,
252    };
253
254    run_learn_with_config(&params, command, output, exit_code)
255}
256
257/// Spawn the learning process in the background by re-exec'ing ourselves.
258pub fn spawn_background(
259    command: &str,
260    output: &str,
261    exit_code: i32,
262    hint: Option<&str>,
263) -> Result<(), Error> {
264    let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
265
266    // Use a secure named temp file to avoid PID-based predictable filenames
267    // (symlink/TOCTOU attacks). The file is kept alive until the child spawns.
268    let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
269    let mut data = serde_json::json!({
270        "command": command,
271        "output": output,
272        "exit_code": exit_code,
273    });
274    if let Some(h) = hint {
275        data["hint"] = serde_json::Value::String(h.to_string());
276    }
277    tmp.write_all(data.to_string().as_bytes())
278        .map_err(|e| Error::Learn(e.to_string()))?;
279
280    // Convert to TempPath: closes the file handle but keeps the file on disk
281    // until the TempPath is dropped — after the child has been spawned.
282    let tmp_path = tmp.into_temp_path();
283
284    // Spawn detached child
285    std::process::Command::new(exe)
286        .arg("_learn_bg")
287        .arg(&tmp_path)
288        .stdin(std::process::Stdio::null())
289        .stdout(std::process::Stdio::null())
290        .stderr(std::process::Stdio::null())
291        .spawn()
292        .map_err(|e| Error::Learn(e.to_string()))?;
293
294    // Prevent the parent from deleting the temp file on drop. On a loaded
295    // system the child process may not have opened the file yet by the time
296    // the parent exits this function. `keep()` makes the file persist on disk
297    // until the child cleans it up at run_background (line ~218 below).
298    tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
299
300    Ok(())
301}
302
303/// Entry point for the background learn child process.
304pub fn run_background(data_path: &str) -> Result<(), Error> {
305    let path = Path::new(data_path);
306    let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
307    let data: LearnData =
308        serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
309
310    let command = &data.command;
311    let output = &data.output;
312
313    // Explicit bounds check to avoid silent truncation i64→i32
314    let exit_code = i32::try_from(data.exit_code).map_err(|_| {
315        Error::Learn(format!(
316            "exit_code out of range for i32: {}",
317            data.exit_code
318        ))
319    })?;
320
321    let hint = data.hint.as_deref();
322
323    let result = run_learn_with_hint(command, output, exit_code, hint);
324
325    // Clean up temp file
326    let _ = std::fs::remove_file(path);
327
328    if let Err(ref e) = result {
329        let cmd_label = label(command);
330        let status_path = learn_status_path();
331        let _ =
332            crate::commands::write_learn_status_failure(&status_path, &cmd_label, &e.to_string());
333    }
334
335    result
336}
337
338// ---------------------------------------------------------------------------
339// LLM API calls
340// ---------------------------------------------------------------------------
341
342fn call_anthropic(
343    base_url: &str,
344    api_key: &str,
345    model: &str,
346    user_msg: &str,
347) -> Result<String, Error> {
348    let body = serde_json::json!({
349        "model": model,
350        "max_tokens": 1024,
351        "temperature": 0.0,
352        "system": SYSTEM_PROMPT,
353        "messages": [{"role": "user", "content": user_msg}],
354    });
355
356    use std::time::Duration;
357
358    // Configure Agent with explicit timeout to prevent hanging on API calls
359    let agent: ureq::Agent = ureq::Agent::config_builder()
360        .timeout_global(Some(Duration::from_secs(30)))
361        .timeout_connect(Some(Duration::from_secs(10)))
362        .build()
363        .into();
364
365    let response: serde_json::Value = agent
366        .post(base_url)
367        .header("x-api-key", api_key)
368        .header("anthropic-version", "2023-06-01")
369        .header("content-type", "application/json")
370        .send_json(&body)
371        .map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
372        .body_mut()
373        .read_json()
374        .map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
375
376    response["content"][0]["text"]
377        .as_str()
378        .map(|s| s.to_string())
379        .ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
380}
381
382// ---------------------------------------------------------------------------
383// Helpers
384// ---------------------------------------------------------------------------
385
386fn label(command: &str) -> String {
387    let mut words = command.split_whitespace();
388    let first = words
389        .next()
390        .unwrap_or("unknown")
391        .rsplit('/')
392        .next()
393        .unwrap_or("unknown");
394
395    // Sanitize first word: keep only ASCII alphanumeric and hyphens,
396    // prevent path traversal, dotfiles, special chars, overlong filenames.
397    let sanitized_first: String = first
398        .chars()
399        .filter(|c| c.is_ascii_alphanumeric() || *c == '-')
400        .take(MAX_FILENAME_COMPONENT)
401        .collect();
402
403    if sanitized_first.is_empty() {
404        return "unknown".to_string();
405    }
406
407    // Include the second word only when it is a subcommand (not a flag).
408    match words.next() {
409        Some(second) if !second.starts_with('-') => {
410            // Sanitize: keep only ASCII alphanumeric and hyphens to ensure
411            // the label is safe as a filename component.
412            let sanitized_second: String = second
413                .chars()
414                .filter(|c| c.is_ascii_alphanumeric() || *c == '-')
415                .take(MAX_FILENAME_COMPONENT)
416                .collect();
417            if sanitized_second.is_empty() {
418                sanitized_first
419            } else {
420                format!("{sanitized_first}-{sanitized_second}")
421            }
422        }
423        _ => sanitized_first,
424    }
425}
426
427fn truncate_for_prompt(output: &str) -> &str {
428    crate::learn_utils::truncate_utf8(output, 4000)
429}
430
431/// Validate ANTHROPIC_API_URL uses HTTPS (with localhost exceptions).
432fn validate_anthropic_url(url: &str) -> Result<(), Error> {
433    if url.starts_with("https://") {
434        return Ok(());
435    }
436    // HTTP only allowed for localhost/127.0.0.1
437    // Extract host portion: "http://HOST:port/path" or "http://HOST/path"
438    if let Some(rest) = url.strip_prefix("http://") {
439        let host = rest.split([':', '/']).next().unwrap_or("");
440        if host == "localhost" || host == "127.0.0.1" {
441            return Ok(());
442        }
443    }
444    Err(Error::Learn(format!(
445        "ANTHROPIC_API_URL must use HTTPS (got: {url}). HTTP is only allowed for localhost/127.0.0.1."
446    )))
447}
448
449/// Validate a TOML pattern string using the same regex limits as TOML loading.
450///
451/// Uses `validate_pattern_regexes` from pattern::toml module for consistency.
452fn validate_pattern_toml_with_limits(toml_str: &str) -> Result<(), Error> {
453    crate::pattern::validate_pattern_regexes(toml_str)
454        .map_err(|e| Error::Learn(format!("pattern validation: {e}")))
455}
456
457// Tests live in separate files to keep this module under 500 lines.
458#[cfg(test)]
459#[path = "learn_tests.rs"]
460mod tests;
461
462#[cfg(test)]
463#[path = "learn_prompt_tests.rs"]
464mod prompt_tests;