oy-cli 0.10.1

Local AI coding CLI for inspecting, editing, running commands, and auditing repositories
Documentation
//! Minimal reqwest-backed webfetch tool.

use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::sync::LazyLock;

use anyhow::{Context, Result, bail};
use regex::Regex;
use reqwest::header::{CONTENT_TYPE, COOKIE, USER_AGENT};
use serde::Serialize;
use serde_json::Value;
use url::Url;

use super::args::{ReturnFormat, WebfetchArgs};
use super::{NetworkAccess, ToolContext};

/// JSON payload produced by the webfetch tool.
///
/// Field names (and their order) are the serialized keys seen by consumers of the tool output.
#[derive(Debug, Serialize)]
pub(super) struct WebfetchOutput {
    /// URL reported by the HTTP response.
    pub url: String,
    /// Numeric HTTP status code of the response.
    pub status_code: u16,
    /// Response body after conversion to the requested return format.
    pub content: String,
    /// Absolute URLs of links found in the body (empty for non-HTML bodies).
    pub links: Vec<String>,
}

/// Fetches `args.url` over http(s) and returns the serialized `WebfetchOutput`.
///
/// The request is made only when the tool policy enables network access and the URL passes
/// `validate_public_url`. Non-blank `user_agent` and `cookie` arguments are sent, trimmed,
/// as request headers. Redirects are not followed, so the reported `status_code` may be a
/// 3xx. The body is converted according to `args.return_format`, and links are extracted
/// from HTML bodies.
///
/// # Errors
/// Fails when the policy denies network access, the URL is rejected, the HTTP client cannot
/// be built, the request fails or exceeds `REQUEST_TIMEOUT`, the body cannot be read, or the
/// output cannot be serialized.
pub(super) async fn tool_webfetch(ctx: &ToolContext, args: WebfetchArgs) -> Result<Value> {
    // Caps the whole exchange (connect, headers and body) so an unresponsive or
    // slow-dripping server cannot stall the tool call indefinitely.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    if ctx.policy.network != NetworkAccess::Enabled {
        bail!("tool denied by policy: webfetch");
    }

    let url = validate_public_url(&args.url).await?;
    let client = reqwest::Client::builder()
        // Only the initial URL goes through `validate_public_url`; refusing to follow
        // redirects means a 3xx can never send the request to an unvalidated target.
        .redirect(reqwest::redirect::Policy::none())
        .timeout(REQUEST_TIMEOUT)
        .build()
        .context("failed to build webfetch HTTP client")?;
    let mut request = client.get(url.clone());
    if let Some(user_agent) = args
        .user_agent
        .as_deref()
        .filter(|value| !value.trim().is_empty())
    {
        request = request.header(USER_AGENT, user_agent.trim());
    }
    if let Some(cookie) = args
        .cookie
        .as_deref()
        .filter(|value| !value.trim().is_empty())
    {
        request = request.header(COOKIE, cookie.trim());
    }

    let response = request
        .send()
        .await
        .with_context(|| format!("failed to fetch {}", url.as_str()))?;
    let status_code = response.status().as_u16();
    let response_url = response.url().clone();
    // A missing or non-ASCII Content-Type is treated as unknown; `is_html_content`
    // then falls back to sniffing the body.
    let content_type = response
        .headers()
        .get(CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("")
        .to_string();
    let raw = response.text().await.with_context(|| {
        format!(
            "failed to read response body from {}",
            response_url.as_str()
        )
    })?;
    let links = extract_links(&raw, &content_type, &response_url);
    let content = transform_scraped_content(&raw, &content_type, args.return_format);

    Ok(serde_json::to_value(WebfetchOutput {
        url: response_url.to_string(),
        status_code,
        content,
        links,
    })?)
}

async fn validate_public_url(input: &str) -> Result<Url> {
    let url = Url::parse(&normalize_scrape_url(input)).context("Invalid URL")?;
    if !matches!(url.scheme(), "http" | "https") {
        bail!("webfetch only supports http(s) URLs");
    }
    let host = url.host_str().context("URL must include a host")?;
    validate_public_host(host)?;
    if let Ok(ip) = host.parse::<IpAddr>() {
        validate_public_ip(ip)?;
        return Ok(url);
    }

    let port = url
        .port_or_known_default()
        .context("URL must include a valid port")?;
    let mut resolved_any = false;
    for addr in tokio::net::lookup_host((host, port))
        .await
        .with_context(|| format!("failed to resolve {host}"))?
    {
        resolved_any = true;
        validate_public_ip(addr.ip())?;
    }
    if !resolved_any {
        bail!("failed to resolve {host}");
    }
    Ok(url)
}

/// Rejects hostnames that name the local machine, before any DNS lookup happens.
///
/// Matches `localhost` and any `*.localhost` subdomain. ASCII case and trailing dots
/// (`localhost.` is the fully-qualified form) are ignored for both checks.
///
/// # Errors
/// Returns an error for localhost-like names; every other name passes.
fn validate_public_host(host: &str) -> Result<()> {
    // Normalize once so the exact-name and suffix checks agree on case handling.
    let host = host.trim_end_matches('.').to_ascii_lowercase();
    if host == "localhost" || host.ends_with(".localhost") {
        bail!("webfetch blocks localhost targets");
    }
    Ok(())
}

/// Turns the `is_public_ip` predicate into an error for use with `?`.
///
/// # Errors
/// Returns an error when `ip` falls in a loopback, private, or otherwise blocked range.
fn validate_public_ip(ip: IpAddr) -> Result<()> {
    if !is_public_ip(ip) {
        bail!("webfetch blocks localhost and private IP targets");
    }
    Ok(())
}

/// Returns `true` when `ip` is outside every range webfetch refuses to contact.
fn is_public_ip(ip: IpAddr) -> bool {
    match ip {
        IpAddr::V4(ip) => is_public_ipv4(ip),
        IpAddr::V6(ip) => is_public_ipv6(ip),
    }
}

/// Returns `false` for IPv4 "this network", private, loopback, shared, link-local,
/// IETF/documentation/benchmarking, multicast and reserved addresses.
fn is_public_ipv4(ip: Ipv4Addr) -> bool {
    let [a, b, c, _] = ip.octets();
    let blocked = a == 0 // 0.0.0.0/8: "this network"
        || a == 10 // 10.0.0.0/8: private
        || a == 127 // 127.0.0.0/8: loopback
        || (a == 100 && (64..=127).contains(&b)) // 100.64.0.0/10: shared address space
        || (a == 169 && b == 254) // 169.254.0.0/16: link-local
        || (a == 172 && (16..=31).contains(&b)) // 172.16.0.0/12: private
        || (a == 192 && b == 0 && c == 0) // 192.0.0.0/24: IETF protocol assignments
        || (a == 192 && b == 0 && c == 2) // 192.0.2.0/24: TEST-NET-1
        || (a == 192 && b == 168) // 192.168.0.0/16: private
        || (a == 198 && (b == 18 || b == 19)) // 198.18.0.0/15: benchmarking
        || (a == 198 && b == 51 && c == 100) // 198.51.100.0/24: TEST-NET-2
        || (a == 203 && b == 0 && c == 113) // 203.0.113.0/24: TEST-NET-3
        || a >= 224; // 224.0.0.0/4 multicast, 240.0.0.0/4 reserved (incl. broadcast)
    !blocked
}

/// Returns `false` for IPv6 unspecified, loopback, unique-local, link-local, site-local,
/// multicast and documentation addresses, and for IPv4-mapped addresses whose embedded
/// IPv4 address is not public.
fn is_public_ipv6(ip: Ipv6Addr) -> bool {
    // `::ffff:a.b.c.d` reaches the IPv4 host a.b.c.d through dual-stack sockets, so it
    // must obey the IPv4 rules. Otherwise `::ffff:127.0.0.1` (e.g. served in an AAAA
    // record) would slip past the loopback check.
    if let Some(v4) = ip.to_ipv4_mapped() {
        return is_public_ipv4(v4);
    }
    let segments = ip.segments();
    let (first, second) = (segments[0], segments[1]);
    !(ip.is_unspecified()
        || ip.is_loopback()
        || (first & 0xfe00) == 0xfc00 // fc00::/7: unique local
        || (first & 0xffc0) == 0xfe80 // fe80::/10: link-local
        || (first & 0xffc0) == 0xfec0 // fec0::/10: deprecated site-local
        || (first & 0xff00) == 0xff00 // ff00::/8: multicast
        || (first == 0x2001 && second == 0x0db8)) // 2001:db8::/32: documentation
}

/// Turns user input into an absolute URL string suitable for `Url::parse`.
///
/// Surrounding whitespace is trimmed. Input that already names a scheme (`scheme://…`, any
/// scheme) is returned as-is, so `validate_public_url` can reject non-http(s) schemes.
/// Anything else is treated as a bare host or host/path and prefixed with `https://`.
/// Detecting an explicit scheme, rather than a leading `http`, keeps hosts such as
/// `httpbin.org` from being mistaken for URLs that already carry one.
fn normalize_scrape_url(input: &str) -> String {
    let trimmed = input.trim();
    if has_explicit_scheme(trimmed) {
        trimmed.to_string()
    } else {
        format!("https://{trimmed}")
    }
}

/// Returns `true` when `input` begins with an RFC 3986 scheme name followed by `://`.
fn has_explicit_scheme(input: &str) -> bool {
    input.split_once("://").is_some_and(|(scheme, _)| {
        let mut chars = scheme.chars();
        chars.next().is_some_and(|first| first.is_ascii_alphabetic())
            && chars.all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    })
}

/// Reports whether a response body should be treated as HTML.
///
/// True when the `Content-Type` value mentions `html` (ignoring ASCII case). Otherwise the
/// body is sniffed: after leading whitespace it must begin with `<!DOCTYPE html` or `<html`.
/// The sniff also ignores ASCII case, because HTML5 pages commonly open with the lowercase
/// `<!doctype html>`.
fn is_html_content(content_type: &str, content: &str) -> bool {
    /// Byte-wise, ASCII-case-insensitive `starts_with`.
    fn starts_with_ignore_ascii_case(text: &str, prefix: &str) -> bool {
        text.as_bytes()
            .get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix.as_bytes()))
    }

    let body = content.trim_start();
    content_type.to_ascii_lowercase().contains("html")
        || starts_with_ignore_ascii_case(body, "<!doctype html")
        || starts_with_ignore_ascii_case(body, "<html")
}

/// Matches the `href` attribute of an `<a>` tag.
///
/// The value is captured as double-quoted (group 1), single-quoted (group 2), or unquoted
/// (group 3). `(?i)` makes the tag and attribute names case-insensitive. `[^>]` also matches
/// newlines, so tags split across lines still match. `href` must be preceded by whitespace,
/// so look-alike attributes such as `data-href` are not mistaken for it.
static HREF_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r#"(?is)<a\b[^>]*\shref\s*=\s*(?:"([^"]*)"|'([^']*)'|([^\s"'<>`]+))"#)
        .expect("valid href regex")
});

/// Collects the absolute URLs of `<a href>` links in an HTML body.
///
/// Each href is trimmed, has `&amp;` decoded to `&`, and is resolved against `base_url`.
/// Empty or unresolvable hrefs are skipped. Duplicates are dropped, keeping the order of
/// first appearance. Non-HTML bodies (per `is_html_content`) yield no links.
fn extract_links(content: &str, content_type: &str, base_url: &Url) -> Vec<String> {
    if !is_html_content(content_type, content) {
        return Vec::new();
    }
    let mut seen = std::collections::HashSet::new();
    HREF_RE
        .captures_iter(content)
        .filter_map(|captures| {
            captures
                .get(1)
                .or_else(|| captures.get(2))
                .or_else(|| captures.get(3))
        })
        .map(|value| value.as_str().trim())
        .filter(|href| !href.is_empty())
        // Attribute values are HTML-escaped, and `&amp;` (query-string separators) is by
        // far the most common entity in hrefs. Left encoded, it would reach the server
        // as a literal `amp;` parameter prefix.
        .map(|href| href.replace("&amp;", "&"))
        .filter_map(|href| base_url.join(&href).ok())
        .map(|url| url.to_string())
        .filter(|link| seen.insert(link.clone()))
        .collect()
}

/// Renders a fetched body in the caller-requested `ReturnFormat`.
///
/// - `Raw`: the body unchanged.
/// - `Markdown`: HTML bodies converted with `html2md`; other bodies unchanged.
/// - `Text`: delegated to `html_to_text`.
/// - `Xml`: the body wrapped in a `<page><content>` CDATA section. Any `]]>` inside the body
///   is split across two CDATA sections so it cannot terminate the section early.
fn transform_scraped_content(
    content: &str,
    content_type: &str,
    return_format: ReturnFormat,
) -> String {
    let looks_like_html = || is_html_content(content_type, content);
    match return_format {
        ReturnFormat::Raw => content.to_owned(),
        ReturnFormat::Markdown if looks_like_html() => html2md::parse_html(content),
        ReturnFormat::Markdown => content.to_owned(),
        ReturnFormat::Text => html_to_text(content, content_type),
        ReturnFormat::Xml => {
            let escaped = content.replace("]]>", "]]]]><![CDATA[>");
            format!("<page><content><![CDATA[{}]]></content></page>", escaped)
        }
    }
}

/// Produces compact text for the `Text` return format.
///
/// HTML bodies (per `is_html_content`) are rendered with `html2md`. Every line is then
/// trimmed and blank lines are removed. Non-HTML bodies are returned unchanged.
fn html_to_text(content: &str, content_type: &str) -> String {
    if !is_html_content(content_type, content) {
        return content.to_string();
    }
    let rendered = html2md::parse_html(content);
    let mut text = String::with_capacity(rendered.len());
    for line in rendered.lines().map(str::trim).filter(|line| !line.is_empty()) {
        // Lines are non-empty, so `text` is empty only before the first one.
        if !text.is_empty() {
            text.push('\n');
        }
        text.push_str(line);
    }
    text
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an IP literal fixture.
    fn addr(text: &str) -> IpAddr {
        text.parse().expect("fixture IP literal must parse")
    }

    #[test]
    fn public_ip_filter_blocks_local_and_private_ranges() {
        let blocked = [
            "127.0.0.1",
            "10.0.0.1",
            "172.16.0.1",
            "192.168.0.1",
            "169.254.1.1",
            "::1",
            "fc00::1",
            "fe80::1",
        ];
        for ip in blocked.map(addr) {
            assert!(!is_public_ip(ip), "{ip} should be blocked");
        }

        let allowed = ["93.184.216.34", "2606:2800:220:1:248:1893:25c8:1946"];
        for ip in allowed.map(addr) {
            assert!(is_public_ip(ip), "{ip} should be allowed");
        }
    }

    #[test]
    fn localhost_hostnames_are_blocked_before_resolution() {
        for name in ["localhost", "api.localhost"] {
            assert!(validate_public_host(name).is_err(), "{name} should be blocked");
        }
        assert!(validate_public_host("example.com").is_ok());
    }

    #[test]
    fn extracts_absolute_and_relative_links_from_html() {
        let base = Url::parse("https://example.com/docs/page.html").unwrap();
        let html = r#"<html><a href="/root">root</a><a href='next.html'>next</a><a href=https://other.test/>other</a></html>"#;
        let expected: Vec<String> = [
            "https://example.com/root",
            "https://example.com/docs/next.html",
            "https://other.test/",
        ]
        .iter()
        .map(|link| link.to_string())
        .collect();

        let links = extract_links(html, "text/html; charset=utf-8", &base);
        assert_eq!(links, expected);
    }

    #[test]
    fn text_content_has_no_links() {
        let base = Url::parse("https://example.com/").unwrap();
        let links = extract_links("<a href='/x'>x</a>", "text/plain", &base);
        assert!(links.is_empty());
    }
}