use std::io::Write;
use std::path::{Path, PathBuf};
use serde::Deserialize;
use crate::error::Error;
pub use crate::learn_prompt::SYSTEM_PROMPT;
/// Maximum accepted length of the `--hint` text given to `oo learn`.
const MAX_HINT_LENGTH: usize = 1_000;
/// Length cap applied to the command line before it is embedded in the prompt.
const MAX_COMMAND_LENGTH: usize = 100;
/// Maximum number of characters kept from each sanitized word of a pattern file name.
const MAX_FILENAME_COMPONENT: usize = 50;
/// Deserialized form of `config.toml`; only the optional `[learn]` table is modelled.
#[derive(Deserialize)]
struct ConfigFile {
    /// Contents of the `[learn]` table, or `None` when the table is absent.
    learn: Option<LearnConfig>,
}
/// Settings for `oo learn`: which LLM provider and model to call, and which
/// environment variable holds the API key.
///
/// Read from the `[learn]` table of `config.toml`; when that table (or the
/// whole file) is absent, [`Default`] is used instead.
///
/// `Debug` is derived deliberately: none of the fields is a secret
/// (`api_key_env` is only the *name* of the variable holding the key).
#[derive(Debug, Clone, PartialEq, Eq, Deserialize)]
pub struct LearnConfig {
    /// Provider identifier; the learn flow currently dispatches only `"anthropic"`.
    pub provider: String,
    /// Model name sent with each provider request.
    pub model: String,
    /// Name of the environment variable that holds the API key.
    pub api_key_env: String,
}
/// Inputs for `run_learn_with_config`, all borrowed from the caller.
///
/// Bundling them lets callers supply the endpoint and the output locations
/// explicitly rather than having the learn flow read them from the environment.
/// NOTE(review): `api_key` is the raw secret — avoid deriving `Debug` for, or
/// logging, this struct.
pub(crate) struct LearnParams<'a> {
    /// Provider, model and API-key variable name.
    pub config: &'a LearnConfig,
    /// API key sent with each provider request (the `x-api-key` header for Anthropic).
    pub api_key: &'a str,
    /// Endpoint URL the provider request is POSTed to.
    pub base_url: &'a str,
    /// Directory that receives the learned `<label>.toml` pattern file.
    pub patterns_dir: &'a Path,
    /// Status file handed to `write_learn_status` once a pattern has been saved.
    pub learn_status_path: &'a Path,
    /// Optional user hint included in the prompt; rejected when over `MAX_HINT_LENGTH`.
    pub hint: Option<&'a str>,
}
/// JSON request handed from `spawn_background` to `run_background` through a temp file.
#[derive(Deserialize)]
struct LearnData {
    /// Command line the learn request is about.
    command: String,
    /// Output text recorded for that command.
    output: String,
    /// Exit code; read as `i64` and range-checked into `i32` by `run_background`.
    exit_code: i64,
    /// Optional hint text; the `hint` key is omitted from the JSON when there is none.
    hint: Option<String>,
}
impl Default for LearnConfig {
    /// Picks provider settings based on the current process environment.
    fn default() -> Self {
        detect_provider()
    }
}
/// Runs provider detection against the real process environment.
///
/// Variables that are unset (or not valid Unicode) are treated as missing.
fn detect_provider() -> LearnConfig {
    let from_process_env = |name: &str| std::env::var(name).ok();
    detect_provider_with(from_process_env)
}
/// Chooses default learn settings, consulting environment variables via `_env_lookup`.
///
/// Anthropic is currently the only supported provider. The previous version
/// branched on `ANTHROPIC_API_KEY`, but both branches returned the same
/// configuration, so the lookup result never mattered. The lookup hook is
/// kept in the signature so that detecting additional providers later needs
/// no caller changes.
fn detect_provider_with<F: Fn(&str) -> Option<String>>(_env_lookup: F) -> LearnConfig {
    LearnConfig {
        provider: "anthropic".into(),
        model: "claude-haiku-4-5".into(),
        api_key_env: "ANTHROPIC_API_KEY".into(),
    }
}
/// Directory that holds oo's configuration, patterns and learn status.
///
/// `OO_CONFIG_DIR` overrides the location (presumably for tests — the original
/// binding was named `test_dir`). Otherwise it is `<platform config dir>/oo`,
/// falling back to `/tmp/oo` when no platform config dir is known.
/// NOTE(review): `/tmp` is shared between users; confirm this fallback is acceptable.
fn config_dir() -> PathBuf {
    match std::env::var_os("OO_CONFIG_DIR") {
        Some(override_dir) => PathBuf::from(override_dir),
        None => dirs::config_dir()
            .unwrap_or_else(|| PathBuf::from("/tmp"))
            .join("oo"),
    }
}
/// Directory where learned pattern files are stored.
pub fn patterns_dir() -> PathBuf {
    let mut dir = config_dir();
    dir.push("patterns");
    dir
}
/// Location of the log that records learn outcomes.
pub fn learn_status_path() -> PathBuf {
    let mut file = config_dir();
    file.push("learn-status.log");
    file
}
pub fn load_learn_config() -> Result<LearnConfig, Error> {
let path = config_dir().join("config.toml");
if !path.exists() {
return Ok(LearnConfig::default());
}
let content = std::fs::read_to_string(&path)
.map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
let cf: ConfigFile =
toml::from_str(&content).map_err(|e| Error::Config(format!("{}: {e}", path.display())))?;
Ok(cf.learn.unwrap_or_default())
}
pub(crate) fn run_learn_with_config(
params: &LearnParams,
command: &str,
output: &str,
exit_code: i32,
) -> Result<(), Error> {
let hint = match params.hint {
Some(h) if h.len() > MAX_HINT_LENGTH => {
return Err(Error::Learn(format!(
"--hint too long ({} > {} chars)",
h.len(),
MAX_HINT_LENGTH
)));
}
h => h,
};
let truncated_command = crate::learn_utils::truncate_utf8(command, MAX_COMMAND_LENGTH);
let user_msg = if let Some(h) = hint {
format!(
"Command: {truncated_command}\nExit code: {exit_code}\nHint: {h}\nOutput:\n{}",
truncate_for_prompt(output)
)
} else {
format!(
"Command: {truncated_command}\nExit code: {exit_code}\nOutput:\n{}",
truncate_for_prompt(output)
)
};
let get_response = |msg: &str| -> Result<String, Error> {
match params.config.provider.as_str() {
"anthropic" => {
call_anthropic(params.base_url, params.api_key, ¶ms.config.model, msg)
}
other => Err(Error::Learn(format!("unknown provider: {other}"))),
}
};
let mut last_err;
let toml = get_response(&user_msg)?;
let clean = crate::learn_utils::strip_fences(&toml);
if validate_pattern_toml_with_limits(&clean).is_ok() {
std::fs::create_dir_all(params.patterns_dir).map_err(|e| Error::Learn(e.to_string()))?;
let filename = format!("{}.toml", label(command));
let path = params.patterns_dir.join(&filename);
std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
let _ =
crate::commands::write_learn_status(params.learn_status_path, &label(command), &path);
return Ok(());
}
last_err = "initial TOML validation failed".to_string();
for _ in 0..2 {
let retry_msg = format!(
"Your previous TOML was invalid: {last_err}. Here is what you returned:\n{clean}\nOutput ONLY the corrected TOML, nothing else."
);
let toml = get_response(&retry_msg)?;
let clean = crate::learn_utils::strip_fences(&toml);
if validate_pattern_toml_with_limits(&clean).is_ok() {
std::fs::create_dir_all(params.patterns_dir)
.map_err(|e| Error::Learn(e.to_string()))?;
let filename = format!("{}.toml", label(command));
let path = params.patterns_dir.join(&filename);
std::fs::write(&path, &clean).map_err(|e| Error::Learn(e.to_string()))?;
let _ = crate::commands::write_learn_status(
params.learn_status_path,
&label(command),
&path,
);
return Ok(());
}
last_err = "retry TOML validation failed".to_string();
}
Err(Error::Learn(format!("failed after 3 attempts: {last_err}")))
}
/// Learns a pattern for `command` in the current process, without a user hint.
///
/// # Errors
/// Any error from the hinted variant: configuration, credentials, endpoint
/// validation, or the learn flow itself.
pub fn run_learn(command: &str, output: &str, exit_code: i32) -> Result<(), Error> {
    let no_hint: Option<&str> = None;
    run_learn_with_hint(command, output, exit_code, no_hint)
}
/// Gathers configuration and credentials, then learns a pattern for `command`.
///
/// Inputs:
/// - The API key is read from the environment variable named by
///   `LearnConfig::api_key_env`.
/// - The endpoint comes from `ANTHROPIC_API_URL`, defaulting to
///   `https://api.anthropic.com/v1/messages`, and must pass `validate_anthropic_url`.
/// - Patterns and status are written to the default config-dir locations.
///
/// # Errors
/// Propagates config-loading errors. Returns `Error::Learn` when the API-key
/// variable is unset or the endpoint URL is rejected. Also propagates any error
/// from `run_learn_with_config`.
fn run_learn_with_hint(
    command: &str,
    output: &str,
    exit_code: i32,
    hint: Option<&str>,
) -> Result<(), Error> {
    let config = load_learn_config()?;
    let api_key = std::env::var(&config.api_key_env).map_err(|_| {
        Error::Learn(format!(
            "Set {} environment variable to use `oo learn`",
            config.api_key_env
        ))
    })?;
    let base_url = std::env::var("ANTHROPIC_API_URL")
        .unwrap_or_else(|_| "https://api.anthropic.com/v1/messages".to_string());
    validate_anthropic_url(&base_url)?;
    let patterns = patterns_dir();
    let status_path = learn_status_path();
    let params = LearnParams {
        config: &config,
        api_key: &api_key,
        base_url: &base_url,
        patterns_dir: &patterns,
        learn_status_path: &status_path,
        hint,
    };
    run_learn_with_config(&params, command, output, exit_code)
}
pub fn spawn_background(
command: &str,
output: &str,
exit_code: i32,
hint: Option<&str>,
) -> Result<(), Error> {
let exe = std::env::current_exe().map_err(|e| Error::Learn(e.to_string()))?;
let mut tmp = tempfile::NamedTempFile::new().map_err(|e| Error::Learn(e.to_string()))?;
let mut data = serde_json::json!({
"command": command,
"output": output,
"exit_code": exit_code,
});
if let Some(h) = hint {
data["hint"] = serde_json::Value::String(h.to_string());
}
tmp.write_all(data.to_string().as_bytes())
.map_err(|e| Error::Learn(e.to_string()))?;
let tmp_path = tmp.into_temp_path();
std::process::Command::new(exe)
.arg("_learn_bg")
.arg(&tmp_path)
.stdin(std::process::Stdio::null())
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.spawn()
.map_err(|e| Error::Learn(e.to_string()))?;
tmp_path.keep().map_err(|e| Error::Learn(e.to_string()))?;
Ok(())
}
pub fn run_background(data_path: &str) -> Result<(), Error> {
let path = Path::new(data_path);
let content = std::fs::read_to_string(path).map_err(|e| Error::Learn(e.to_string()))?;
let data: LearnData =
serde_json::from_str(&content).map_err(|e| Error::Learn(e.to_string()))?;
let command = &data.command;
let output = &data.output;
let exit_code = i32::try_from(data.exit_code).map_err(|_| {
Error::Learn(format!(
"exit_code out of range for i32: {}",
data.exit_code
))
})?;
let hint = data.hint.as_deref();
let result = run_learn_with_hint(command, output, exit_code, hint);
let _ = std::fs::remove_file(path);
if let Err(ref e) = result {
let cmd_label = label(command);
let status_path = learn_status_path();
let _ =
crate::commands::write_learn_status_failure(&status_path, &cmd_label, &e.to_string());
}
result
}
fn call_anthropic(
base_url: &str,
api_key: &str,
model: &str,
user_msg: &str,
) -> Result<String, Error> {
let body = serde_json::json!({
"model": model,
"max_tokens": 1024,
"temperature": 0.0,
"system": SYSTEM_PROMPT,
"messages": [{"role": "user", "content": user_msg}],
});
use std::time::Duration;
let agent: ureq::Agent = ureq::Agent::config_builder()
.timeout_global(Some(Duration::from_secs(30)))
.timeout_connect(Some(Duration::from_secs(10)))
.build()
.into();
let response: serde_json::Value = agent
.post(base_url)
.header("x-api-key", api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.send_json(&body)
.map_err(|e| Error::Learn(format!("Anthropic API error: {e}")))?
.body_mut()
.read_json()
.map_err(|e| Error::Learn(format!("response parse error: {e}")))?;
response["content"][0]["text"]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| Error::Learn("unexpected Anthropic response format".into()))
}
fn label(command: &str) -> String {
let mut words = command.split_whitespace();
let first = words
.next()
.unwrap_or("unknown")
.rsplit('/')
.next()
.unwrap_or("unknown");
let sanitized_first: String = first
.chars()
.filter(|c| c.is_ascii_alphanumeric() || *c == '-')
.take(MAX_FILENAME_COMPONENT)
.collect();
if sanitized_first.is_empty() {
return "unknown".to_string();
}
match words.next() {
Some(second) if !second.starts_with('-') => {
let sanitized_second: String = second
.chars()
.filter(|c| c.is_ascii_alphanumeric() || *c == '-')
.take(MAX_FILENAME_COMPONENT)
.collect();
if sanitized_second.is_empty() {
sanitized_first
} else {
format!("{sanitized_first}-{sanitized_second}")
}
}
_ => sanitized_first,
}
}
/// Caps the command output embedded in the prompt at 4000 units, as measured
/// by `learn_utils::truncate_utf8`.
/// NOTE(review): presumably bytes cut on a char boundary — confirm in `learn_utils`.
fn truncate_for_prompt(output: &str) -> &str {
    const PROMPT_OUTPUT_LIMIT: usize = 4000;
    crate::learn_utils::truncate_utf8(output, PROMPT_OUTPUT_LIMIT)
}
/// Rejects API endpoints that would send the API key over plain HTTP to a
/// remote host.
///
/// Any `https://` URL is accepted. `http://` is accepted only when the
/// authority's host is exactly `localhost` or `127.0.0.1` (any port).
/// Authorities containing userinfo (`user@host`) are rejected. In
/// `http://localhost:80@evil.com` the real host is `evil.com`; a naive
/// "text before the first `:`" check would wrongly accept it.
///
/// # Errors
/// Returns `Error::Learn` describing the HTTPS requirement.
fn validate_anthropic_url(url: &str) -> Result<(), Error> {
    if url.starts_with("https://") {
        return Ok(());
    }
    if let Some(rest) = url.strip_prefix("http://") {
        // The authority ends at the first '/', '?' or '#'.
        let authority = rest.split(['/', '?', '#']).next().unwrap_or("");
        if !authority.contains('@') {
            let host = authority.split(':').next().unwrap_or("");
            if host == "localhost" || host == "127.0.0.1" {
                return Ok(());
            }
        }
    }
    Err(Error::Learn(format!(
        "ANTHROPIC_API_URL must use HTTPS (got: {url}). HTTP is only allowed for localhost/127.0.0.1."
    )))
}
fn validate_pattern_toml_with_limits(toml_str: &str) -> Result<(), Error> {
crate::pattern::validate_pattern_regexes(toml_str)
.map_err(|e| Error::Learn(format!("pattern validation: {e}")))
}
#[cfg(test)]
#[path = "learn_tests.rs"]
mod tests;
#[cfg(test)]
#[path = "learn_prompt_tests.rs"]
mod prompt_tests;