// keyclaw 0.2.1
//
// Local MITM proxy that keeps secrets out of LLM traffic.
// Documentation
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::net::SocketAddr;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};

use http_body_util::Full;
use hudsucker::{
    Body,
    certificate_authority::RcgenAuthority,
    hyper::{
        Request, Response, StatusCode,
        header::{CONTENT_TYPE, HOST, HeaderValue},
    },
    rcgen::{Issuer, KeyPair},
    rustls::crypto::aws_lc_rs,
};

use crate::errors::KeyclawError;

/// Process-wide switch read by `unsafe_log_enabled`; off until `set_unsafe_log(true)` is called.
static UNSAFE_LOG: AtomicBool = AtomicBool::new(false);
/// Optional log sink installed by `set_log_file`; while `None`, log output goes to stderr.
static LOG_FILE: Mutex<Option<File>> = Mutex::new(None);

/// Turns the unsafe replacement log on or off for the whole process.
///
/// The flag is stored in the `UNSAFE_LOG` atomic and consulted by
/// `log_replacements` before it emits anything.
pub fn set_unsafe_log(enabled: bool) {
    UNSAFE_LOG.store(enabled, Ordering::SeqCst);
}

/// Redirects log output to `path`, creating the file if needed and appending to it.
///
/// # Errors
///
/// Returns the I/O error from opening the file. If the `LOG_FILE` mutex is
/// poisoned the file is opened but not installed, and `Ok(())` is still
/// returned.
pub fn set_log_file(path: &std::path::Path) -> std::io::Result<()> {
    let mut options = OpenOptions::new();
    options.create(true).append(true);
    let file = options.open(path)?;

    match LOG_FILE.lock() {
        Ok(mut slot) => *slot = Some(file),
        // Poisoned lock: keep whatever sink was configured before.
        Err(_) => {}
    }
    Ok(())
}

/// Builds the certificate authority used by the proxy from a PEM certificate and key.
///
/// # Errors
///
/// Returns an uncoded `KeyclawError` when either the private key or the CA
/// certificate cannot be parsed; the message carries the underlying error.
pub(super) fn build_ca_authority(
    cert_pem: &str,
    key_pem: &str,
) -> Result<RcgenAuthority, KeyclawError> {
    let key_pair = match KeyPair::from_pem(key_pem) {
        Ok(pair) => pair,
        Err(e) => {
            return Err(KeyclawError::uncoded(format!(
                "parse CA private key failed: {e}"
            )));
        }
    };
    let issuer = match Issuer::from_ca_cert_pem(cert_pem, key_pair) {
        Ok(issuer) => issuer,
        Err(e) => {
            return Err(KeyclawError::uncoded(format!(
                "parse CA certificate failed: {e}"
            )));
        }
    };

    // NOTE(review): 1_000 is presumably the authority's certificate cache size — confirm
    // against the hudsucker `RcgenAuthority::new` signature.
    let authority = RcgenAuthority::new(issuer, 1_000, aws_lc_rs::default_provider());
    Ok(authority)
}

/// Wraps an owned byte buffer in a complete (non-streaming) proxy `Body`.
pub(super) fn body_from_vec(bytes: Vec<u8>) -> Body {
    let payload = hudsucker::hyper::body::Bytes::from(bytes);
    let full = Full::new(payload);
    full.into()
}

/// Reports whether the response's `Content-Type` media type is `text/event-stream`.
///
/// Parameters after `;` are ignored and the comparison is ASCII
/// case-insensitive. A missing or non-UTF-8 header yields `false`.
pub(super) fn response_is_sse(res: &Response<Body>) -> bool {
    let Some(raw) = res.headers().get(CONTENT_TYPE) else {
        return false;
    };
    let Ok(content_type) = raw.to_str() else {
        return false;
    };
    // `split` always yields at least one item; the fallback only keeps this total.
    let media_type = content_type.split(';').next().unwrap_or_default();
    media_type.trim().eq_ignore_ascii_case("text/event-stream")
}

/// Returns the normalized target host of a request.
///
/// The URI authority wins when present; otherwise the `Host` header is used.
/// Returns `None` when neither is available (or the header is not valid UTF-8).
pub(super) fn request_host(req: &Request<Body>) -> Option<String> {
    match req.uri().authority() {
        Some(authority) => Some(normalize_host(authority.as_str())),
        None => header_value(req, HOST.as_str()).map(|raw| normalize_host(&raw)),
    }
}

/// Returns the value of header `name` as an owned string.
///
/// Yields `None` when the header is absent or its value is not representable
/// as a string.
pub(super) fn header_value(req: &Request<Body>, name: &str) -> Option<String> {
    let raw = req.headers().get(name)?;
    raw.to_str().ok().map(str::to_owned)
}

/// Normalizes every allow-list entry and drops the ones that end up empty.
///
/// Entry order is preserved.
pub(super) fn normalize_hosts(hosts: &[String]) -> Vec<String> {
    hosts
        .iter()
        .filter_map(|raw| {
            let entry = normalize_allowed_entry(raw);
            (!entry.is_empty()).then_some(entry)
        })
        .collect()
}

/// Module-visible entry point to `normalize_host`: returns the normalized form of `host`.
pub(super) fn normalize_host_value(host: &str) -> String {
    normalize_host(host)
}

/// Decides whether `host` is covered by `allowed_hosts`.
///
/// * An empty list allows every host.
/// * Entries containing `*` or `?` are matched as glob patterns against the
///   normalized host.
/// * Any other entry matches the host itself and any subdomain of it
///   (`openai.com` covers `api.openai.com` but not `badopenai.com`).
///
/// Entries are compared as given, so they should already have been passed
/// through `normalize_hosts`.
pub(super) fn allowed(allowed_hosts: &[String], host: &str) -> bool {
    if allowed_hosts.is_empty() {
        return true;
    }

    let candidate = normalize_host(host);
    for entry in allowed_hosts {
        let hit = if contains_glob_pattern(entry) {
            glob_matches(entry, &candidate)
        } else {
            // Either an exact match (nothing left in front of the entry) or the
            // entry is preceded by a label separator.
            candidate
                .strip_suffix(entry.as_str())
                .is_some_and(|prefix| prefix.is_empty() || prefix.ends_with('.'))
        };
        if hit {
            return true;
        }
    }
    false
}

/// Reports whether a `Content-Type` value should be treated as JSON.
///
/// An empty (or whitespace-only) value counts as JSON; otherwise the value
/// must contain `application/json` or a `+json` suffix, case-insensitively.
pub(super) fn is_json(content_type: &str) -> bool {
    let lowered = content_type.trim().to_lowercase();
    if lowered.is_empty() {
        return true;
    }
    ["application/json", "+json"]
        .iter()
        .any(|needle| lowered.contains(*needle))
}

/// Reports whether `payload` parses as a complete JSON document.
pub(super) fn is_json_payload(payload: &[u8]) -> bool {
    let parsed: Result<serde_json::Value, _> = serde_json::from_slice(payload);
    parsed.is_ok()
}

/// Builds a JSON error response of the form
/// `{"error": {"code": <code>, "message": <msg>}}` with the given status.
///
/// The `Content-Type` header is set to `application/json`. If serialization
/// fails the body falls back to `{}`.
pub(super) fn json_error_response(status: StatusCode, code: &str, msg: &str) -> Response<Body> {
    let body = serde_json::to_vec(&serde_json::json!({
        "error": {"code": code, "message": msg}
    }))
    .unwrap_or_else(|_| b"{}".to_vec());

    let mut response = Response::new(body_from_vec(body));
    response
        .headers_mut()
        .insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
    *response.status_mut() = status;
    response
}

/// Writes an `[UNSAFE]` diagnostic report of the secrets replaced in `original`.
///
/// Nothing is emitted unless unsafe logging is enabled and `replacements` is
/// non-empty. Output goes to the configured log file when one is installed,
/// otherwise to stderr.
///
/// For every replacement the report shows the first bytes of the secret, the
/// placeholder it was replaced with and, when the secret can be located in
/// the (lossily decoded) payload, up to 100 bytes of context on each side.
///
/// * `host` - host the payload was destined for; only used in the header line.
/// * `original` - payload bytes before replacement.
/// * `replacements` - secret/placeholder pairs that were applied.
pub(super) fn log_replacements(
    host: &str,
    original: &[u8],
    replacements: &[crate::placeholder::Replacement],
) {
    // Maximum number of payload bytes shown on each side of a secret.
    const CONTEXT_BYTES: usize = 100;
    // Maximum number of leading secret bytes echoed into the log.
    const SECRET_PREVIEW_BYTES: usize = 8;

    if !unsafe_log_enabled() || replacements.is_empty() {
        return;
    }
    let use_file = LOG_FILE
        .lock()
        .ok()
        .as_ref()
        .is_some_and(|guard| guard.is_some());
    macro_rules! log_out {
        ($($arg:tt)*) => {
            if use_file {
                if let Ok(mut guard) = LOG_FILE.lock() {
                    if let Some(ref mut file) = *guard {
                        let _ = writeln!(file, $($arg)*);
                    }
                }
            } else {
                eprintln!($($arg)*);
            }
        }
    }

    let text = String::from_utf8_lossy(original);
    log_out!(
        "keyclaw [UNSAFE] INTERCEPTIONS for {host} ({} found):",
        replacements.len()
    );
    for replacement in replacements {
        let secret: &str = &replacement.secret;
        // Cut on a char boundary: slicing at a fixed byte offset panics when
        // the secret contains multi-byte characters.
        let preview = truncate_utf8(secret, SECRET_PREVIEW_BYTES);

        if let Some(pos) = text.find(secret) {
            let secret_end = pos + secret.len();

            // `pos` and `secret_end` are match edges and therefore char
            // boundaries, but the outer window edges are plain byte offsets.
            // Shrink the window inwards until both edges are sliceable; the
            // loops stop at `pos` / `secret_end` at the latest.
            let mut ctx_start = pos.saturating_sub(CONTEXT_BYTES);
            while !text.is_char_boundary(ctx_start) {
                ctx_start += 1;
            }
            let mut after_end = std::cmp::min(secret_end + CONTEXT_BYTES, text.len());
            while !text.is_char_boundary(after_end) {
                after_end -= 1;
            }

            let before = &text[ctx_start..pos];
            let after = &text[secret_end..after_end];
            log_out!(
                "  ...{}[SECRET:{} -> {}]{}...",
                before,
                preview,
                replacement.placeholder,
                after
            );
        } else {
            log_out!("  {} -> {}", preview, replacement.placeholder);
        }
    }
    log_out!("---");
}

pub(super) fn log_debug(line: String) {
    log_with_level(crate::logging::LogLevel::Debug, line);
}

pub(super) fn log_warn(line: String) {
    log_with_level(crate::logging::LogLevel::Warn, line);
}

/// Reads the process-wide unsafe-logging flag set by `set_unsafe_log`.
fn unsafe_log_enabled() -> bool {
    UNSAFE_LOG.load(Ordering::Relaxed)
}

/// Returns the longest prefix of `s` that is at most `max` bytes long and
/// ends on a UTF-8 character boundary.
///
/// The input is returned unchanged when it already fits.
fn truncate_utf8(s: &str, max: usize) -> &str {
    if max >= s.len() {
        return s;
    }
    // Index 0 is always a boundary, so the search cannot come up empty.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}

/// Reduces a host string to a canonical, port-less, lowercase form.
///
/// Handles, in order:
///
/// 1. `ip:port` and `[ipv6]:port` socket addresses, which yield the IP in its
///    canonical textual form.
/// 2. Bracketed literals `[...]` with or without a trailing `:port`, which
///    yield the bracket contents (canonicalized when they parse as an IP).
///    Without this step `[::1]` would be split at its last inner colon.
/// 3. `name:port`, where the port is dropped when the base is clearly a host
///    name or IPv4 address. A bare IPv6 literal such as `2001:db8::1` is left
///    intact because its "base" still contains a colon.
///
/// Surrounding whitespace and leading/trailing dots are removed first.
fn normalize_host(host: &str) -> String {
    let trimmed = host.trim().trim_matches('.').to_lowercase();

    if let Ok(addr) = trimmed.parse::<SocketAddr>() {
        return addr.ip().to_string();
    }

    // Bracketed literal, optionally followed by a port that did not parse
    // above (or by no port at all).
    if let Some((inner, _port)) = trimmed
        .strip_prefix('[')
        .and_then(|rest| rest.split_once(']'))
    {
        return inner
            .parse::<std::net::IpAddr>()
            .map_or_else(|_| inner.to_string(), |ip| ip.to_string());
    }

    if let Some((base, port)) = trimmed.rsplit_once(':') {
        // A single-label host such as `ollama:11434` carries a real port; no
        // valid IPv6 literal has exactly one colon, so this is unambiguous.
        let single_label_with_port = !base.is_empty()
            && !base.contains(':')
            && !port.is_empty()
            && port.bytes().all(|b| b.is_ascii_digit());

        if base.contains('.')
            || base.contains('[')
            || base == "localhost"
            || base.parse::<std::net::IpAddr>().is_ok()
            || single_label_with_port
        {
            return base.trim_matches('[').trim_matches(']').to_string();
        }
    }

    trimmed.trim_matches('[').trim_matches(']').to_string()
}

/// Reports whether `host` contains a glob wildcard (`*` or `?`).
fn contains_glob_pattern(host: &str) -> bool {
    host.chars().any(|c| matches!(c, '*' | '?'))
}

/// Normalizes one allow-list entry.
///
/// Glob patterns are only trimmed and lowercased so their wildcards survive;
/// plain host names go through the full `normalize_host` treatment.
fn normalize_allowed_entry(host: &str) -> String {
    let lowered = host.trim().to_lowercase();
    if !contains_glob_pattern(&lowered) {
        return normalize_host(&lowered);
    }
    lowered
}

/// Matches `host` against a glob `pattern`, byte by byte.
///
/// `?` matches exactly one byte and `*` matches any run of bytes, including
/// an empty one. The whole of `host` must be consumed for a match.
fn glob_matches(pattern: &str, host: &str) -> bool {
    let pat = pattern.as_bytes();
    let text = host.as_bytes();
    let mut p = 0usize;
    let mut t = 0usize;
    // Most recent `*`: (its index in `pat`, text index it currently absorbs up to).
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pat.get(p).copied() {
            // Literal or single-byte wildcard: advance both cursors.
            Some(c) if c == b'?' || c == text[t] => {
                p += 1;
                t += 1;
            }
            // Remember the star and first try matching it against nothing.
            Some(b'*') => {
                backtrack = Some((p, t));
                p += 1;
            }
            // Mismatch: let the last star swallow one more byte, or give up.
            _ => match backtrack {
                Some((star, absorbed)) => {
                    let resume = absorbed + 1;
                    backtrack = Some((star, resume));
                    p = star + 1;
                    t = resume;
                }
                None => return false,
            },
        }
    }

    // Only trailing stars may remain once the host is exhausted.
    pat[p..].iter().all(|&b| b == b'*')
}

fn log_with_level(level: crate::logging::LogLevel, line: String) {
    if !crate::logging::enabled(level) {
        return;
    }
    let msg = crate::logging::render(level, &line);
    if let Ok(mut guard) = LOG_FILE.lock() {
        if let Some(ref mut file) = *guard {
            let _ = writeln!(file, "{}", msg);
            return;
        }
    }
    eprintln!("{}", msg);
}

// Unit tests for the host-normalization, allow-list and error-response helpers.
#[cfg(test)]
mod tests {
    use hudsucker::Body;
    use hudsucker::hyper::{Request, StatusCode, Uri, header::CONTENT_TYPE, header::HOST};

    use super::{allowed, json_error_response, normalize_hosts, request_host};

    // The error response must carry the requested status and a JSON content type.
    #[test]
    fn json_error_response_sets_status_and_json_content_type() {
        let response = json_error_response(StatusCode::BAD_REQUEST, "invalid_json", "bad input");

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        assert_eq!(
            response
                .headers()
                .get(CONTENT_TYPE)
                .and_then(|value| value.to_str().ok()),
            Some("application/json")
        );
    }

    // An absolute-form URI supplies the host; the Host header is ignored, and
    // the result is lowercased with the port removed.
    #[test]
    fn request_host_prefers_uri_authority_over_host_header() {
        let req = Request::builder()
            .uri("https://Api.OpenAI.com:443/v1/responses")
            .header(HOST, "ignored.example.com")
            .body(Body::empty())
            .expect("request");

        assert_eq!(request_host(&req).as_deref(), Some("api.openai.com"));
    }

    // `*` and `?` in allow-list entries are treated as wildcards.
    #[test]
    fn allowed_matches_glob_patterns() {
        assert!(allowed(
            &[
                String::from("*openai.com"),
                String::from("api.anthropic.com")
            ],
            "api.openai.com"
        ));
        assert!(allowed(&[String::from("api.groq.?om")], "api.groq.com"));
        assert!(!allowed(&[String::from("*mistral.ai")], "api.openai.com"));
    }

    // With an origin-form URI (no authority) the Host header is used, and a
    // bracketed IPv6 literal with port is reduced to the bare address.
    #[test]
    fn request_host_falls_back_to_host_header_and_normalizes_ipv6() {
        let req = Request::builder()
            .uri(Uri::from_static("/v1/responses"))
            .header(HOST, " [2001:db8::1]:443 ")
            .body(Body::empty())
            .expect("request");

        assert_eq!(request_host(&req).as_deref(), Some("2001:db8::1"));
    }

    // Whitespace, case and ports are stripped; blank entries are dropped.
    #[test]
    fn normalize_hosts_trims_case_ports_and_empty_entries() {
        let hosts = vec![
            " Api.OpenAI.com ".to_string(),
            "localhost:8080".to_string(),
            " [2001:db8::1]:443 ".to_string(),
            "   ".to_string(),
        ];

        assert_eq!(
            normalize_hosts(&hosts),
            vec![
                "api.openai.com".to_string(),
                "localhost".to_string(),
                "2001:db8::1".to_string(),
            ]
        );
    }

    // Plain entries match the host itself and its subdomains, but not a host
    // that merely shares a textual suffix ("badopenai.com").
    #[test]
    fn allowed_matches_exact_suffix_localhost_and_ipv6_hosts() {
        let allowed_hosts = normalize_hosts(&[
            "api.openai.com".to_string(),
            "localhost".to_string(),
            "[2001:db8::1]:443".to_string(),
        ]);

        assert!(allowed(&allowed_hosts, "api.openai.com"));
        assert!(allowed(&allowed_hosts, "chat.api.openai.com"));
        assert!(allowed(&allowed_hosts, "LOCALHOST:8877"));
        assert!(allowed(&allowed_hosts, "[2001:db8::1]:443"));
        assert!(!allowed(&allowed_hosts, "badopenai.com"));
    }
}