keyclaw 0.2.1

Local MITM proxy that keeps secrets out of LLM traffic
Documentation
use futures::{SinkExt, Stream, StreamExt};
use hudsucker::{
    WebSocketContext, WebSocketHandler,
    tokio_tungstenite::tungstenite::{self, Message},
};

use crate::errors::{CODE_INVALID_JSON, KeyclawError, code_of};

use super::KeyclawHttpHandler;
use super::common::{log_debug, log_warn, normalize_host_value};

fn is_expected_websocket_close_error(err: &tungstenite::Error) -> bool {
    match err {
        tungstenite::Error::ConnectionClosed | tungstenite::Error::AlreadyClosed => true,
        tungstenite::Error::Protocol(
            tungstenite::error::ProtocolError::ResetWithoutClosingHandshake,
        ) => true,
        tungstenite::Error::Io(io_err) => {
            io_err.kind() == std::io::ErrorKind::UnexpectedEof
                && io_err.to_string().contains("close_notify")
        }
        _ => false,
    }
}

fn uses_input_only_ws_rewrite_path(path: &str) -> bool {
    path.starts_with("/backend-api/codex/responses")
}

fn uses_input_only_ws_rewrite(ctx: &WebSocketContext) -> bool {
    matches!(
        ctx,
        WebSocketContext::ClientToServer { dst, .. }
            if uses_input_only_ws_rewrite_path(dst.path())
    )
}

enum WsRequestRewrite {
    Unchanged,
    Rewritten(String, Vec<crate::placeholder::Replacement>),
    Blocked(KeyclawError),
}

impl WebSocketHandler for KeyclawHttpHandler {
    async fn handle_websocket(
        mut self,
        ctx: WebSocketContext,
        mut stream: impl Stream<Item = Result<Message, tungstenite::Error>> + Unpin + Send + 'static,
        mut sink: impl futures::Sink<Message, Error = tungstenite::Error> + Unpin + Send + 'static,
    ) {
        while let Some(message) = stream.next().await {
            match message {
                Ok(message) => {
                    let Some(message) = self.handle_message(&ctx, message).await else {
                        continue;
                    };
                    let should_close = matches!(message, Message::Close(_));

                    match sink.send(message).await {
                        Err(tungstenite::Error::ConnectionClosed) => {}
                        Err(err) => log_warn(format!("ws send error: {err}")),
                        Ok(()) => {}
                    }
                    if should_close {
                        break;
                    }
                }
                Err(err) => {
                    if !is_expected_websocket_close_error(&err) {
                        log_warn(format!("ws message error: {err}"));
                        match sink.send(Message::Close(None)).await {
                            Err(tungstenite::Error::ConnectionClosed) => {}
                            Err(close_err) => {
                                log_warn(format!("ws close error: {close_err}"));
                            }
                            Ok(()) => {}
                        }
                    }

                    break;
                }
            }
        }
    }

    async fn handle_message(&mut self, ctx: &WebSocketContext, msg: Message) -> Option<Message> {
        if !matches!(ctx, WebSocketContext::ClientToServer { .. }) {
            if let Message::Text(text) = &msg {
                if let Some(value) = self.resolve_ws_response_text(text.as_str()) {
                    log_debug("ws response: resolved placeholders".to_string());
                    return Some(Message::Text(value.into()));
                }
            }
            return Some(msg);
        }

        if let Message::Text(text) = &msg {
            match self.rewrite_ws_request_text(text.as_str(), uses_input_only_ws_rewrite(ctx)) {
                WsRequestRewrite::Rewritten(value, replacements) => {
                    if let Some(host) = websocket_request_host(ctx) {
                        if !self.processor.dry_run {
                            if let Err(err) = crate::audit::append_redactions(
                                self.audit_log_path.as_deref(),
                                &host,
                                &replacements,
                            ) {
                                log_warn(format!("audit log write failed: {err}"));
                            }
                        }
                    }
                    log_debug(format!(
                        "ws request: redacted {} secret(s)",
                        replacements.len()
                    ));
                    return Some(Message::Text(value.into()));
                }
                WsRequestRewrite::Blocked(err) => {
                    let code = code_of(&err).unwrap_or("unknown");
                    log_warn(format!("ws rewrite error ({code}): {err}"));
                    return Some(Message::Close(None));
                }
                WsRequestRewrite::Unchanged => {}
            }
        }
        Some(msg)
    }
}

impl KeyclawHttpHandler {
    fn rewrite_ws_request_text(&self, text: &str, input_only: bool) -> WsRequestRewrite {
        let payload = text.as_bytes().to_vec();
        let rewrite = if input_only {
            self.processor.rewrite_and_evaluate_codex_ws(&payload)
        } else {
            self.processor.rewrite_and_evaluate(&payload)
        };
        match rewrite {
            Ok(result) if result.replacements.is_empty() => WsRequestRewrite::Unchanged,
            Ok(result) => match String::from_utf8(result.body) {
                Ok(value) => WsRequestRewrite::Rewritten(value, result.replacements),
                Err(err) => self.ws_rewrite_failure(KeyclawError::coded_with_source(
                    CODE_INVALID_JSON,
                    "rewrite produced non-utf8 websocket payload",
                    err,
                )),
            },
            Err(err) => self.ws_rewrite_failure(err),
        }
    }

    fn ws_rewrite_failure(&self, err: KeyclawError) -> WsRequestRewrite {
        if self.processor.strict_mode {
            WsRequestRewrite::Blocked(err)
        } else {
            let code = code_of(&err).unwrap_or("unknown");
            log_warn(format!("ws rewrite error ({code}): {err}"));
            WsRequestRewrite::Unchanged
        }
    }

    fn resolve_ws_response_text(&self, text: &str) -> Option<String> {
        if !crate::placeholder::contains_placeholder_prefix(text) {
            return None;
        }

        self.processor
            .resolve_text(text.as_bytes())
            .ok()
            .and_then(|resolved| String::from_utf8(resolved).ok())
    }
}

fn websocket_request_host(ctx: &WebSocketContext) -> Option<String> {
    match ctx {
        WebSocketContext::ClientToServer { dst, .. } => dst
            .authority()
            .map(|authority| normalize_host_value(authority.as_str())),
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicI64;
    use std::time::{Duration, SystemTime, UNIX_EPOCH};

    use hudsucker::tokio_tungstenite::tungstenite::{self, error::ProtocolError};

    use super::super::KeyclawHttpHandler;
    use super::{
        WsRequestRewrite, is_expected_websocket_close_error, uses_input_only_ws_rewrite_path,
    };
    use crate::errors::{CODE_INVALID_JSON, code_of};
    use crate::gitleaks_rules::RuleSet;
    use crate::pipeline::Processor;
    use crate::placeholder;
    use crate::vault::Store;

    fn empty_ruleset() -> Arc<RuleSet> {
        Arc::new(RuleSet::from_toml("rules = []").expect("empty ruleset"))
    }

    fn sample_secret_ruleset() -> Arc<RuleSet> {
        let mut rules = RuleSet::from_toml(
            r#"
[[rules]]
id = "generic-api-key"
regex = 'api_key = ([A-Za-z0-9]{30,})'
secretGroup = 1
entropy = 1
keywords = ["api_key"]
"#,
        )
        .expect("sample ruleset");
        rules.entropy_config.enabled = false;
        Arc::new(rules)
    }

    fn test_handler(strict_mode: bool, ruleset: Arc<RuleSet>) -> KeyclawHttpHandler {
        let unique = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system time")
            .as_nanos();
        let vault = Arc::new(Store::new(
            std::env::temp_dir().join(format!("keyclaw-ws-{unique}.enc")),
            "test-passphrase".to_string(),
        ));
        let processor = Arc::new(Processor {
            vault: Some(vault),
            ruleset,
            second_pass_scanner: None,
            max_body_size: 1 << 20,
            strict_mode,
            notice_mode: crate::redaction::NoticeMode::Verbose,
            dry_run: false,
        });

        KeyclawHttpHandler {
            allowed_hosts: vec!["api.openai.com".to_string()],
            processor,
            max_body_bytes: 1 << 20,
            body_timeout: Duration::from_secs(1),
            audit_log_path: None,
            intercepted: Arc::new(AtomicI64::new(0)),
        }
    }

    #[test]
    fn tls_unexpected_eof_is_treated_as_expected_ws_close() {
        let err = tungstenite::Error::Io(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "peer closed connection without sending TLS close_notify",
        ));

        assert!(is_expected_websocket_close_error(&err));
    }

    #[test]
    fn reset_without_closing_handshake_is_treated_as_expected_ws_close() {
        let err = tungstenite::Error::Protocol(ProtocolError::ResetWithoutClosingHandshake);

        assert!(is_expected_websocket_close_error(&err));
    }

    #[test]
    fn unrelated_io_errors_are_not_treated_as_expected_ws_close() {
        let err = tungstenite::Error::Io(std::io::Error::other("socket exploded"));

        assert!(!is_expected_websocket_close_error(&err));
    }

    #[test]
    fn codex_responses_ws_uses_input_only_scope() {
        assert!(uses_input_only_ws_rewrite_path(
            "/backend-api/codex/responses"
        ));
    }

    #[test]
    fn non_codex_ws_keeps_full_scope() {
        assert!(!uses_input_only_ws_rewrite_path("/backend-api/other"));
    }

    #[test]
    fn ws_request_text_messages_are_redacted() {
        let handler = test_handler(true, sample_secret_ruleset());
        let secret = "aB3dE5fG7hI9jK1lM3nO5pQ7rS9tU1v";
        let rewrite = handler.rewrite_ws_request_text(
            &format!(r#"{{"messages":[{{"content":"api_key = {}"}}]}}"#, secret),
            false,
        );
        let WsRequestRewrite::Rewritten(text, _) = rewrite else {
            panic!("expected rewritten websocket message");
        };

        assert!(!text.contains(secret), "text={text}");
        assert!(
            placeholder::contains_complete_placeholder(&text),
            "text={text}"
        );
    }

    #[test]
    fn ws_request_rewrite_errors_block_in_strict_mode() {
        let handler = test_handler(true, empty_ruleset());

        let rewrite = handler.rewrite_ws_request_text(r#"{"messages":[{"content":"oops"}"#, false);

        let WsRequestRewrite::Blocked(err) = rewrite else {
            panic!("strict mode should block malformed websocket payloads");
        };
        assert_eq!(code_of(&err), Some(CODE_INVALID_JSON));
    }

    #[test]
    fn ws_request_rewrite_errors_pass_through_when_fail_closed_disabled() {
        let handler = test_handler(false, empty_ruleset());

        let rewrite = handler.rewrite_ws_request_text(r#"{"messages":[{"content":"oops"}"#, false);

        assert!(
            matches!(rewrite, WsRequestRewrite::Unchanged),
            "non-strict mode should preserve pass-through behavior"
        );
    }

    #[test]
    fn ws_response_text_messages_are_resolved() {
        let temp = tempfile::tempdir().expect("tempdir");
        let vault = Arc::new(Store::new(
            temp.path().join("vault.enc"),
            "test-passphrase".to_string(),
        ));
        let processor = Arc::new(Processor {
            vault: Some(Arc::clone(&vault)),
            ruleset: empty_ruleset(),
            second_pass_scanner: None,
            max_body_size: 1 << 20,
            strict_mode: true,
            notice_mode: crate::redaction::NoticeMode::Verbose,
            dry_run: false,
        });
        let handler = KeyclawHttpHandler {
            allowed_hosts: vec!["api.openai.com".to_string()],
            processor,
            max_body_bytes: 1 << 20,
            body_timeout: Duration::from_secs(1),
            audit_log_path: None,
            intercepted: Arc::new(AtomicI64::new(0)),
        };
        let request_secret = "api_key = aB3dE5fG7hI9jK1lM3nO5pQ7rS9tU1v";
        let id = vault.store_secret(request_secret).expect("store secret");
        let placeholder = placeholder::make(&id);
        let text = handler
            .resolve_ws_response_text(&format!(r#"{{"content":"{}"}}"#, placeholder))
            .expect("resolved message");

        assert!(text.contains(request_secret), "text={text}");
        assert!(!text.contains(&placeholder), "text={text}");
    }
}