keyclaw 0.2.1

Local MITM proxy that keeps secrets out of LLM traffic
Documentation
use std::sync::Arc;
use std::sync::atomic::Ordering;

use http_body_util::BodyExt;
use hudsucker::{
    Body, HttpContext, HttpHandler, RequestOrResponse,
    hyper::{
        Method, Request, Response, StatusCode,
        header::{CONTENT_LENGTH, CONTENT_TYPE, HeaderMap, TRANSFER_ENCODING},
    },
};

use crate::errors::{CODE_BODY_TOO_LARGE, CODE_INVALID_JSON, CODE_REQUEST_TIMEOUT, code_of};

use super::KeyclawHttpHandler;
use super::common::{
    allowed, body_from_vec, header_value, is_json, is_json_payload, json_error_response, log_debug,
    log_replacements, log_warn, request_host, response_is_sse,
};
use super::streaming::SseStreamResolver;

impl HttpHandler for KeyclawHttpHandler {
    async fn handle_request(
        &mut self,
        _ctx: &HttpContext,
        req: Request<Body>,
    ) -> RequestOrResponse {
        let host = match request_host(&req) {
            Some(host) => host,
            None => {
                return json_error_response(
                    StatusCode::BAD_REQUEST,
                    CODE_INVALID_JSON,
                    "unable to determine target host",
                )
                .into();
            }
        };

        if !allowed(&self.allowed_hosts, &host) {
            return req.into();
        }
        self.intercepted.fetch_add(1, Ordering::SeqCst);
        log_debug(format!(
            "intercept {} {} (host={})",
            req.method(),
            req.uri().path(),
            host
        ));

        if req.method() == Method::CONNECT {
            return req.into();
        }

        if req
            .headers()
            .get("upgrade")
            .and_then(|value| value.to_str().ok())
            .map(|value| value.eq_ignore_ascii_case("websocket"))
            .unwrap_or(false)
        {
            let (mut parts, body) = req.into_parts();
            parts.headers.remove("sec-websocket-extensions");
            log_debug(
                "ws upgrade: stripped sec-websocket-extensions to disable compression".to_string(),
            );
            return Request::from_parts(parts, body).into();
        }

        let content_type = header_value(&req, CONTENT_TYPE.as_str()).unwrap_or_default();
        let content_type_is_json = is_json(&content_type);

        if self.max_body_bytes > 0 {
            if let Some(content_len) =
                header_value(&req, CONTENT_LENGTH.as_str()).and_then(|v| v.parse::<i64>().ok())
            {
                if content_len > self.max_body_bytes {
                    return json_error_response(
                        StatusCode::PAYLOAD_TOO_LARGE,
                        CODE_BODY_TOO_LARGE,
                        "request body exceeded maximum size",
                    )
                    .into();
                }
            }
        }

        let (parts, body) = req.into_parts();
        let collected = match tokio::time::timeout(self.body_timeout, body.collect()).await {
            Ok(Ok(collected)) => collected,
            Ok(Err(_)) => {
                return json_error_response(
                    StatusCode::BAD_REQUEST,
                    CODE_INVALID_JSON,
                    "cannot read request body",
                )
                .into();
            }
            Err(_) => {
                log_warn("body read timeout - returning timeout error".to_string());
                return json_error_response(
                    StatusCode::REQUEST_TIMEOUT,
                    CODE_REQUEST_TIMEOUT,
                    "request body read timed out",
                )
                .into();
            }
        };

        let body_bytes = collected.to_bytes();
        if self.max_body_bytes > 0 && (body_bytes.len() as i64) > self.max_body_bytes {
            return json_error_response(
                StatusCode::PAYLOAD_TOO_LARGE,
                CODE_BODY_TOO_LARGE,
                "request body exceeded maximum size",
            )
            .into();
        }

        let original_payload = body_bytes.to_vec();
        if original_payload.is_empty()
            || (!content_type_is_json && !is_json_payload(&original_payload))
        {
            return Request::from_parts(parts, body_from_vec(original_payload)).into();
        }

        let processor = Arc::clone(&self.processor);
        let payload = original_payload.clone();
        let rewritten = match tokio::time::timeout(
            self.body_timeout,
            tokio::task::spawn_blocking(move || processor.rewrite_and_evaluate(&payload)),
        )
        .await
        {
            Ok(Ok(Ok(result))) => result,
            Ok(Ok(Err(err))) => {
                let code = code_of(&err).unwrap_or("unknown");
                log_warn(format!("rewrite error ({code}): {err}"));
                if self.processor.strict_mode {
                    return request_rewrite_error_response(&err).into();
                }
                return Request::from_parts(parts, body_from_vec(original_payload)).into();
            }
            Ok(Err(err)) => {
                log_warn(format!("request processing failed: {err}"));
                if self.processor.strict_mode {
                    return json_error_response(
                        StatusCode::BAD_GATEWAY,
                        "rewrite_failed",
                        "request processing failed before forwarding upstream",
                    )
                    .into();
                }
                return Request::from_parts(parts, body_from_vec(original_payload)).into();
            }
            Err(_) => {
                log_warn("rewrite timeout".to_string());
                if self.processor.strict_mode {
                    return json_error_response(
                        StatusCode::REQUEST_TIMEOUT,
                        CODE_REQUEST_TIMEOUT,
                        "request rewrite timed out",
                    )
                    .into();
                }
                return Request::from_parts(parts, body_from_vec(original_payload)).into();
            }
        };

        let request_had_secrets = !rewritten.replacements.is_empty();
        if request_had_secrets {
            log_debug(format!(
                "request rewritten for host {host}: {}",
                self.processor.replacement_summary(&rewritten.replacements)
            ));
            log_replacements(&host, &original_payload, &rewritten.replacements);
            if !self.processor.dry_run {
                if let Err(err) = crate::audit::append_redactions(
                    self.audit_log_path.as_deref(),
                    &host,
                    &rewritten.replacements,
                ) {
                    log_warn(format!("audit log write failed: {err}"));
                }
            }
        }

        let rewritten_len = rewritten.body.len();
        let mut rewritten_req = Request::from_parts(parts, body_from_vec(rewritten.body));
        set_fixed_body_headers(rewritten_req.headers_mut(), rewritten_len);
        if !self.processor.dry_run && !rewritten.replacements.is_empty() {
            if let Ok(value) = crate::placeholder::CONTRACT_MARKER_VALUE.parse() {
                rewritten_req
                    .headers_mut()
                    .insert(crate::placeholder::CONTRACT_MARKER_KEY, value);
            }
        }
        rewritten_req.headers_mut().remove("accept-encoding");

        rewritten_req.into()
    }

    async fn handle_response(&mut self, _ctx: &HttpContext, res: Response<Body>) -> Response<Body> {
        let is_sse = response_is_sse(&res);

        if res.status() == StatusCode::SWITCHING_PROTOCOLS || res.status().is_informational() {
            return res;
        }

        if is_sse {
            let processor = Arc::clone(&self.processor);
            let (mut parts, body) = res.into_parts();
            let mut sse_resolver = SseStreamResolver::new(processor);
            parts.headers.remove(CONTENT_LENGTH);
            let new_body = body
                .map_frame(move |frame| match frame.into_data() {
                    Ok(data) => {
                        let rewritten = sse_resolver.process_frame(&data);
                        hudsucker::hyper::body::Frame::data(hudsucker::hyper::body::Bytes::from(
                            rewritten,
                        ))
                    }
                    Err(frame) => frame,
                })
                .boxed();
            return Response::from_parts(parts, new_body.into());
        }

        let (mut parts, body) = res.into_parts();
        let collected = match body.collect().await {
            Ok(collected) => collected,
            Err(_) => return Response::from_parts(parts, Body::empty()),
        };
        let mut body_bytes = collected.to_bytes().to_vec();

        let text = String::from_utf8_lossy(&body_bytes);
        if crate::placeholder::contains_placeholder_prefix(&text) {
            let processor = Arc::clone(&self.processor);
            let payload = body_bytes.clone();
            if let Ok(Ok(resolved)) =
                tokio::task::spawn_blocking(move || processor.resolve_text(&payload)).await
            {
                if resolved != body_bytes {
                    log_debug("response: resolved placeholders back to secrets".to_string());
                    body_bytes = resolved;
                }
            }
        }

        parts.headers.remove(TRANSFER_ENCODING);
        let body_len = body_bytes.len();
        let mut resp = Response::from_parts(parts, body_from_vec(body_bytes));
        set_fixed_body_headers(resp.headers_mut(), body_len);
        resp
    }

    /// Returns `true` only when the request's host can be determined and is
    /// accepted by `allowed` for `self.allowed_hosts`; a request without a
    /// resolvable host yields `false`.
    async fn should_intercept(&mut self, _ctx: &HttpContext, req: &Request<Body>) -> bool {
        match request_host(req) {
            Some(target) => allowed(&self.allowed_hosts, &target),
            None => false,
        }
    }
}

fn request_rewrite_error_response(err: &crate::errors::KeyclawError) -> Response<Body> {
    let code = code_of(err).unwrap_or("rewrite_failed");
    let message = err.display_without_code();

    let status = match code {
        CODE_BODY_TOO_LARGE => StatusCode::PAYLOAD_TOO_LARGE,
        CODE_REQUEST_TIMEOUT => StatusCode::REQUEST_TIMEOUT,
        CODE_INVALID_JSON => StatusCode::BAD_REQUEST,
        _ => StatusCode::BAD_GATEWAY,
    };

    json_error_response(status, code, &message)
}

/// Marks a message as carrying a fixed-size body of `len` bytes: removes any
/// `Transfer-Encoding` header and sets `Content-Length` to `len`, replacing a
/// previous value if one was present.
fn set_fixed_body_headers(headers: &mut HeaderMap, len: usize) {
    headers.remove(TRANSFER_ENCODING);
    // `HeaderValue` converts directly from integers, yielding the same decimal
    // text as formatting and re-parsing the number.
    headers.insert(CONTENT_LENGTH, len.into());
}

#[cfg(test)]
mod tests {
    use hudsucker::hyper::header::{CONTENT_LENGTH, HeaderMap, HeaderValue, TRANSFER_ENCODING};

    use super::set_fixed_body_headers;

    /// A chunked message switched to a fixed body must lose `Transfer-Encoding`
    /// and gain a `Content-Length` equal to the given size.
    #[test]
    fn set_fixed_body_headers_removes_transfer_encoding_and_sets_content_length() {
        let mut header_map = HeaderMap::new();
        header_map.insert(TRANSFER_ENCODING, HeaderValue::from_static("chunked"));

        set_fixed_body_headers(&mut header_map, 42);

        assert!(!header_map.contains_key(TRANSFER_ENCODING));

        let content_length = header_map
            .get(CONTENT_LENGTH)
            .and_then(|raw| raw.to_str().ok());
        assert_eq!(content_length, Some("42"));
    }
}