cortex-agent 0.2.1

Self-learning AI agent with persistent memory, tools, plugins, and a beautiful terminal UI
use std::sync::Arc;

use crate::tool::{param_string, ToolSpec};

pub fn core_tools() -> Vec<ToolSpec> {
    vec![
        current_time_tool(),
        web_fetch_tool(),
        web_search_tool(),
        execute_python_tool(),
        execute_bash_tool(),
        read_file_tool(),
        git_status_tool(),
        git_diff_tool(),
        write_file_tool(),
    ]
}

/// Tool `current_time`: report the current date and time.
///
/// The time is always computed and labeled as UTC (`YYYY-MM-DD HH:MM:SS UTC`). No timezone
/// database is available here, so a requested timezone other than a UTC alias is not
/// converted. Instead, the output carries an explicit note saying so, rather than attaching
/// that zone's name to a UTC timestamp. An empty or missing `timezone` is treated as UTC.
fn current_time_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[("timezone", param_string("Timezone name (e.g. UTC, US/Eastern, Asia/Tokyo)"))]);
    ToolSpec::new("current_time", "Get the current date and time", params,
        Arc::new(|args| {
            let requested = args
                .get("timezone")
                .and_then(|v| v.as_str())
                .map(str::trim)
                .unwrap_or("");
            let now = chrono::Utc::now();
            let stamp = now.format("%Y-%m-%d %H:%M:%S");
            let is_utc = requested.is_empty()
                || ["UTC", "GMT", "Z", "Etc/UTC", "Etc/GMT"]
                    .iter()
                    .any(|alias| requested.eq_ignore_ascii_case(alias));
            if is_utc {
                Ok(format!("{} UTC", stamp))
            } else {
                // Be explicit instead of labeling a UTC timestamp with another zone's name.
                Ok(format!(
                    "{} UTC (conversion to '{}' is not supported; time shown is UTC)",
                    stamp, requested
                ))
            }
        }),
    )
}

/// Tool `web_fetch`: GET a URL and return its body as text.
///
/// How the body is handled depends on its `Content-Type`:
/// - If the type mentions `text`, `json` or `html`, the body is returned as text.
/// - Anything else is summarized as `[Binary content: <type>, <n> bytes]`.
///
/// Text longer than 10,000 bytes is cut at the nearest UTF-8 character boundary at or below
/// that limit, and a truncation note is appended. The request times out after 15 seconds.
fn web_fetch_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[("url", param_string("The full URL to fetch (https://...)"))]);
    ToolSpec::new("web_fetch", "Fetch a URL and return its text content (truncated to 10KB)", params,
        Arc::new(|args| {
            // Maximum number of bytes of body text returned.
            const MAX_LEN: usize = 10_000;
            let url = args.get("url").and_then(|v| v.as_str()).unwrap_or("").to_string();
            if url.is_empty() { return Err("URL is required".into()); }
            let client = reqwest::blocking::Client::builder()
                .timeout(std::time::Duration::from_secs(15))
                .build()
                .map_err(|e| format!("Failed to create HTTP client: {}", e))?;
            let resp = client.get(&url).send().map_err(|e| format!("Error fetching {}: {}", url, e))?;
            let content_type = resp.headers().get(reqwest::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()).unwrap_or("").to_string();
            let text = if content_type.contains("text") || content_type.contains("json") || content_type.contains("html") {
                resp.text().map_err(|e| format!("Error reading response: {}", e))?
            } else {
                let bytes = resp.bytes().map_err(|e| format!("Error: {}", e))?;
                format!("[Binary content: {}, {} bytes]", content_type, bytes.len())
            };
            if text.len() <= MAX_LEN {
                return Ok(text);
            }
            // `&text[..MAX_LEN]` panics when byte MAX_LEN falls inside a multi-byte character,
            // so back off to the nearest preceding char boundary (index 0 always is one).
            let mut cut = MAX_LEN;
            while !text.is_char_boundary(cut) {
                cut -= 1;
            }
            Ok(format!("{}...\n[truncated from {} total bytes]", &text[..cut], text.len()))
        }),
    )
}

/// Tool `web_search`: query DuckDuckGo's HTML endpoint and return up to `limit` results.
///
/// `limit` is clamped to `1..=10`. Each result is rendered as `• <title>: <link>`, followed
/// on the next line by an indented snippet when one is found. Results are pulled out of the
/// HTML with plain substring matching, not a real parser.
fn web_search_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[
        ("query", param_string("Search query")),
        ("limit", serde_json::json!({"type": "integer", "description": "Max results", "default": 5})),
    ]);
    ToolSpec::new("web_search", "Search the web using DuckDuckGo (text-only results)", params,
        Arc::new(|args| {
            let query = args.get("query").and_then(|v| v.as_str()).unwrap_or("").to_string();
            // Clamp before converting: without a lower bound, a negative limit would cast to a
            // huge `usize` and disable the result cap entirely.
            let limit = args.get("limit").and_then(|v| v.as_i64()).unwrap_or(5).clamp(1, 10) as usize;
            if query.is_empty() { return Err("Query is required".into()); }
            let url = format!("https://html.duckduckgo.com/html/?q={}", urlencoding(&query));
            let client = reqwest::blocking::Client::builder()
                .timeout(std::time::Duration::from_secs(10))
                .build().map_err(|e| format!("Client error: {}", e))?;
            let resp = client.get(&url)
                .header("User-Agent", "Mozilla/5.0")
                .send().map_err(|e| format!("Search error: {}", e))?;
            let html = resp.text().map_err(|e| format!("Read error: {}", e))?;

            let mut results = Vec::new();
            // Every section after a split starts right at an href value. The text before the
            // first marker has no link, so it is skipped.
            for section in html.split("<a rel=\"nofollow\" href=\"").skip(1) {
                let link = match section.find('"') {
                    Some(end) => &section[..end],
                    None => continue,
                };
                let title = match anchor_text_after(section, "class=\"result__a\"") {
                    Some(title) => title,
                    None => continue,
                };
                let mut entry = format!("• {}: {}", title, link);
                if let Some(snippet) = anchor_text_after(section, "class=\"result__snippet\"") {
                    if !snippet.is_empty() {
                        entry.push_str("\n  ");
                        entry.push_str(&snippet);
                    }
                }
                results.push(entry);
                if results.len() >= limit { break; }
            }
            if results.is_empty() { Ok("No results found.".into()) }
            else { Ok(results.join("\n\n")) }
        }),
    )
}

/// Find `marker` in `section` and return the text between the next `>` and the following
/// `</a>`, with `<b>`/`</b>` highlight tags removed.
///
/// Returns `None` if the marker, the `>` or the `</a>` cannot be found.
fn anchor_text_after(section: &str, marker: &str) -> Option<String> {
    let after = &section[section.find(marker)?..];
    let open_end = after.find('>')?;
    // `after[open_end..]` starts with '>', so `close` is always > `open_end`.
    let close = open_end + after[open_end..].find("</a>")?;
    Some(after[open_end + 1..close].replace("</b>", "").replace("<b>", ""))
}

/// Upper bound, in seconds, on the `timeout` argument accepted by the execution tools.
const MAX_EXEC_TIMEOUT_SECS: u64 = 120;

/// Tool `execute_python`: run Python code with `python3 -c` in a child process.
///
/// The code is passed verbatim. `timeout` (default 10 s) is clamped to
/// `1..=MAX_EXEC_TIMEOUT_SECS`, and the process is killed when it expires. stdin is closed.
/// The result is stdout, then stderr under a `[stderr]` header. When both streams are
/// empty or whitespace, the result is `(no output)`.
fn execute_python_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[
        ("code", param_string("Python code to execute")),
        ("timeout", serde_json::json!({"type": "integer", "description": "Max execution time in seconds", "default": 10})),
    ]);
    ToolSpec::new("execute_python", "Execute Python code in an isolated subprocess", params,
        Arc::new(|args| {
            let code = args.get("code").and_then(|v| v.as_str()).unwrap_or("").to_string();
            let timeout = args
                .get("timeout")
                .and_then(|v| v.as_u64())
                .unwrap_or(10)
                .clamp(1, MAX_EXEC_TIMEOUT_SECS);
            if code.is_empty() { return Err("Code is required".into()); }
            // No wrapper is needed: Python already prints a traceback to stderr for uncaught
            // exceptions. Re-indenting the code into a function body would corrupt multi-line
            // string literals and reject `from __future__` imports.
            let mut cmd = std::process::Command::new("python3");
            cmd.arg("-c").arg(&code);
            match run_with_timeout(cmd, std::time::Duration::from_secs(timeout)) {
                Ok(Some(output)) => Ok(render_output(&output, "(no output)".to_string())),
                Ok(None) => Err(format!("Timed out after {}s", timeout)),
                Err(e) => Err(format!("Execution error: {}", e)),
            }
        }),
    )
}

/// Tool `execute_bash`: run a command with `bash -c` in a child process.
///
/// `timeout` (default 30 s) is clamped to `1..=MAX_EXEC_TIMEOUT_SECS`, and bash is killed
/// when it expires. stdin is closed. The result is stdout, then stderr under a `[stderr]`
/// header. When both streams are empty or whitespace, the result is `(exit code: N)`.
fn execute_bash_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[
        ("command", param_string("Shell command to run")),
        ("timeout", serde_json::json!({"type": "integer", "description": "Max execution time in seconds", "default": 30})),
    ]);
    ToolSpec::new("execute_bash", "Run a shell command and return stdout/stderr", params,
        Arc::new(|args| {
            let command = args.get("command").and_then(|v| v.as_str()).unwrap_or("").to_string();
            let timeout = args
                .get("timeout")
                .and_then(|v| v.as_u64())
                .unwrap_or(30)
                .clamp(1, MAX_EXEC_TIMEOUT_SECS);
            if command.is_empty() { return Err("Command is required".into()); }
            let mut cmd = std::process::Command::new("bash");
            cmd.arg("-c").arg(&command);
            match run_with_timeout(cmd, std::time::Duration::from_secs(timeout)) {
                Ok(Some(output)) => {
                    let fallback = format!("(exit code: {})", output.status.code().unwrap_or(-1));
                    Ok(render_output(&output, fallback))
                }
                Ok(None) => Err(format!("Timed out after {}s", timeout)),
                Err(e) => Err(format!("Execution error: {}", e)),
            }
        }),
    )
}

/// Run `cmd` with stdin closed and stdout/stderr captured, killing it once `timeout` elapses.
///
/// Returns one of:
/// - `Ok(Some(output))` if the process exited on its own;
/// - `Ok(None)` if it was killed at the deadline;
/// - `Err` if it could not be spawned or polled.
///
/// Only the direct child is killed. Processes it spawned itself (e.g. background jobs of a
/// `bash -c` command) are not.
fn run_with_timeout(
    mut cmd: std::process::Command,
    timeout: std::time::Duration,
) -> std::io::Result<Option<std::process::Output>> {
    use std::process::Stdio;
    use std::time::{Duration, Instant};

    // A closed stdin makes code that reads input fail fast instead of blocking on the
    // agent's terminal.
    let mut child = cmd
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()?;

    // Drain both pipes concurrently. Otherwise a child that fills a pipe buffer would
    // block forever while we poll for its exit.
    let stdout_reader = spawn_pipe_reader(child.stdout.take());
    let stderr_reader = spawn_pipe_reader(child.stderr.take());

    let deadline = Instant::now() + timeout;
    loop {
        match child.try_wait() {
            Ok(Some(status)) => {
                let stdout = stdout_reader.join().unwrap_or_default();
                let stderr = stderr_reader.join().unwrap_or_default();
                return Ok(Some(std::process::Output { status, stdout, stderr }));
            }
            Ok(None) if Instant::now() >= deadline => {
                let _ = child.kill();
                let _ = child.wait();
                // The reader threads are left detached rather than joined: a grandchild that
                // inherited the pipes could keep them open indefinitely.
                return Ok(None);
            }
            Ok(None) => std::thread::sleep(Duration::from_millis(20)),
            Err(e) => {
                let _ = child.kill();
                let _ = child.wait();
                return Err(e);
            }
        }
    }
}

/// Read `pipe` to EOF on a background thread and return the thread's handle.
///
/// Joining the handle yields the collected bytes; `None` yields an empty buffer.
fn spawn_pipe_reader<R: std::io::Read + Send + 'static>(
    pipe: Option<R>,
) -> std::thread::JoinHandle<Vec<u8>> {
    std::thread::spawn(move || {
        let mut buf = Vec::new();
        if let Some(mut pipe) = pipe {
            // A read error only truncates the captured output.
            let _ = std::io::Read::read_to_end(&mut pipe, &mut buf);
        }
        buf
    })
}

/// Combine a finished process's stdout and stderr into the text returned to the model.
///
/// stdout comes first, then stderr (if any) under a `[stderr]` header. The result is trimmed.
/// If nothing but whitespace remains, `fallback` is returned instead.
fn render_output(output: &std::process::Output, fallback: String) -> String {
    let mut text = String::from_utf8_lossy(&output.stdout).into_owned();
    if !output.stderr.is_empty() {
        text.push_str("\n[stderr]\n");
        text.push_str(&String::from_utf8_lossy(&output.stderr));
    }
    let trimmed = text.trim();
    if trimmed.is_empty() { fallback } else { trimmed.to_string() }
}

/// Tool `read_file`: return a window of a file's lines, each prefixed with its 1-based number.
///
/// Arguments:
/// - `offset` is the first line to show (1-based; 0 behaves like 1).
/// - `limit` caps the number of lines shown.
/// - A leading `~` in `path` is expanded.
///
/// A non-empty window ends with `[Lines a–b of total]`. An offset past the end of the file,
/// or a zero limit, produces a note instead of a panic.
fn read_file_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[
        ("path", param_string("Path to the file")),
        ("offset", serde_json::json!({"type": "integer", "description": "Starting line (1-indexed)", "default": 1})),
        ("limit", serde_json::json!({"type": "integer", "description": "Max lines", "default": 100})),
    ]);
    ToolSpec::new("read_file", "Read lines from a file with line numbers", params,
        Arc::new(|args| {
            use std::fmt::Write;

            let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("").to_string();
            if path.is_empty() { return Err("Path is required".into()); }
            // Saturate instead of truncating on targets where usize is narrower than u64.
            let offset = args.get("offset").and_then(|v| v.as_u64()).unwrap_or(1).min(usize::MAX as u64) as usize;
            let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(100).min(usize::MAX as u64) as usize;
            let path = shellexpand::tilde(&path).to_string();
            let content = std::fs::read_to_string(&path).map_err(|e| format!("Error reading file: {}", e))?;
            let lines: Vec<&str> = content.lines().collect();
            let total = lines.len();
            // Clamp both ends to `total`. This keeps `start <= end` when the offset is past
            // EOF (the slice below would panic otherwise) and avoids overflow on huge limits.
            let start = offset.saturating_sub(1).min(total);
            let end = start.saturating_add(limit).min(total);
            let mut result = String::new();
            for (i, line) in lines[start..end].iter().enumerate() {
                let _ = writeln!(result, "{:6}|{}", start + i + 1, line);
            }
            if start == end {
                let _ = write!(result, "[No lines at offset {} (file has {} lines)]", offset.max(1), total);
            } else {
                let _ = write!(result, "[Lines {}–{} of {}]", start + 1, end, total);
            }
            Ok(result)
        }),
    )
}

/// Tool `write_file`: write `content` to `path`, replacing any existing file.
///
/// A leading `~` in `path` is expanded, and missing parent directories are created. A missing
/// `path` or `content` argument is rejected; an explicit empty `content` writes an empty file.
/// On success it reports the number of bytes written.
fn write_file_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[
        ("path", param_string("Path to write to")),
        ("content", param_string("File content to write")),
    ]);
    ToolSpec::new("write_file", "Write content to a file (overwrites existing)", params,
        Arc::new(|args| {
            let path = args.get("path").and_then(|v| v.as_str()).unwrap_or("").to_string();
            if path.is_empty() { return Err("Path is required".into()); }
            // Treat a dropped `content` argument as a malformed call. Defaulting it to "" would
            // silently truncate an existing file.
            let content = match args.get("content").and_then(|v| v.as_str()) {
                Some(content) => content,
                None => return Err("Content is required".into()),
            };
            let path = shellexpand::tilde(&path).to_string();
            if let Some(parent) = std::path::Path::new(&path).parent() {
                // For a bare file name, `parent` is "", which `create_dir_all` accepts as a no-op.
                std::fs::create_dir_all(parent)
                    .map_err(|e| format!("Error creating directory {}: {}", parent.display(), e))?;
            }
            std::fs::write(&path, content).map_err(|e| format!("Error writing file: {}", e))?;
            Ok(format!("Wrote {} bytes to {}", content.len(), path))
        }),
    )
}

/// Tool `git_status`: run `git status --short -b` in `path` (default `.`) and return its output.
fn git_status_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[("path", serde_json::json!({"type": "string", "description": "Git repo path", "default": "."}))]);
    ToolSpec::new("git_status", "Show git status (branch, changes, untracked files)", params,
        Arc::new(|args| {
            let path = args.get("path").and_then(|v| v.as_str()).unwrap_or(".");
            run_git(path, &["status", "--short", "-b"])
        }),
    )
}

/// Tool `git_diff`: run `git diff` (or `git diff --staged` when `staged` is true) in `path`.
///
/// An empty diff is reported as `No changes.`.
fn git_diff_tool() -> ToolSpec {
    let (params, _) = crate::tool::required_params(&[
        ("path", serde_json::json!({"type": "string", "description": "Git repo path", "default": "."})),
        ("staged", serde_json::json!({"type": "boolean", "description": "Show staged diff", "default": false})),
    ]);
    ToolSpec::new("git_diff", "Show unstaged or staged git diff", params,
        Arc::new(|args| {
            let path = args.get("path").and_then(|v| v.as_str()).unwrap_or(".");
            let staged = args.get("staged").and_then(|v| v.as_bool()).unwrap_or(false);
            let mut diff_args = vec!["diff"];
            if staged { diff_args.push("--staged"); }
            let out = run_git(path, &diff_args)?;
            if out.is_empty() { Ok("No changes.".into()) }
            else { Ok(out) }
        }),
    )
}

/// Run `git -C <path> <args...>` and return its stdout.
///
/// Success is judged by git's exit status, not by whether stderr is empty, because git
/// prints advisory warnings there even on success. One example is line-ending conversion
/// notices during `diff`.
///
/// Errors:
/// - If git cannot be started, the error is `Git error: ...`.
/// - If git exits unsuccessfully, the error is its stderr, or its exit status when stderr
///   is empty.
fn run_git(path: &str, args: &[&str]) -> Result<String, String> {
    let output = std::process::Command::new("git")
        .arg("-C")
        .arg(path)
        .args(args)
        .output()
        .map_err(|e| format!("Git error: {}", e))?;
    if !output.status.success() {
        let err = String::from_utf8_lossy(&output.stderr).to_string();
        return Err(if err.trim().is_empty() {
            match output.status.code() {
                Some(code) => format!("git exited with status {}", code),
                None => "git was terminated by a signal".to_string(),
            }
        } else {
            err
        });
    }
    Ok(String::from_utf8_lossy(&output.stdout).to_string())
}

/// Percent-encode `input` for a URL query string, using form-style encoding.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged, and a space becomes
/// `+`. Every other byte becomes `%XX` with uppercase hex digits; this includes each byte of
/// a multi-byte UTF-8 character.
fn urlencoding(input: &str) -> String {
    input.bytes().fold(String::with_capacity(input.len()), |mut encoded, b| {
        if b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.' | b'~') {
            encoded.push(char::from(b));
        } else if b == b' ' {
            encoded.push('+');
        } else {
            encoded.push_str(&format!("%{:02X}", b));
        }
        encoded
    })
}