dirge-agent 0.13.7

Minimalistic coding agent written in Rust, optimized for memory footprint and performance
use anyhow::Context;
use base64::Engine;
use sha2::{Digest, Sha256};
use std::io::{Read, Write};
use std::net::TcpListener;

pub(crate) fn verifier() -> String {
    format!(
        "{}{}{}",
        uuid::Uuid::new_v4().simple(),
        uuid::Uuid::new_v4().simple(),
        uuid::Uuid::new_v4().simple()
    )
}

/// Derives the PKCE `S256` code challenge for `verifier`.
///
/// Computes the SHA-256 digest of the verifier's bytes and encodes it with the
/// URL-safe base64 alphabet, without `=` padding.
pub(crate) fn challenge(verifier: &str) -> String {
    let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let hash = Sha256::digest(verifier.as_bytes());
    engine.encode(hash)
}

/// Accepts one connection on `listener`, parses it as an OAuth redirect
/// callback and answers the browser with a plain-text page.
///
/// The request head is read until the blank line that ends the HTTP headers
/// (or until EOF, or until 8 KiB have been buffered). It is then validated by
/// [`parse_callback_request_with_state`] using `options.error_context` and
/// `options.expected_state`.
///
/// On success the browser receives `200 OK` with `options.success_body`. On
/// any parse or state error it receives `400 Bad Request` with
/// `options.failure_body`.
///
/// # Errors
///
/// Returns I/O errors from accepting, reading or writing the connection.
/// Otherwise returns the parse result (e.g. a missing `code` or a state
/// mismatch).
pub(crate) fn wait_for_callback(
    listener: TcpListener,
    options: &CallbackOptions<'_>,
) -> anyhow::Result<(String, String)> {
    // Upper bound on the request head we are willing to buffer.
    const MAX_REQUEST_HEAD: usize = 8192;

    let (mut stream, _) = listener.accept()?;

    // A single `read` may return only part of the request, because TCP
    // delivers a byte stream rather than messages. That would truncate the
    // request line and lose `code`/`state`, so keep reading until the header
    // terminator arrives.
    //
    // Draining the whole head also avoids closing the socket with unread data
    // still queued. That situation can make the OS reset the connection before
    // the browser sees our response.
    let mut buf = [0_u8; MAX_REQUEST_HEAD];
    let mut len = 0;
    while len < buf.len() {
        match stream.read(&mut buf[len..]) {
            Ok(0) => break,
            Ok(n) => {
                len += n;
                let head = &buf[..len];
                if head.windows(4).any(|w| w == b"\r\n\r\n")
                    || head.windows(2).any(|w| w == b"\n\n")
                {
                    break;
                }
            }
            // Interrupted reads carry no data and are safe to retry.
            Err(err) if err.kind() == std::io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err.into()),
        }
    }

    let request = String::from_utf8_lossy(&buf[..len]);
    let result =
        parse_callback_request_with_state(&request, options.error_context, options.expected_state);
    let (status, body) = if result.is_ok() {
        ("200 OK", options.success_body)
    } else {
        ("400 Bad Request", options.failure_body)
    };

    // Build the whole response first so it goes out in one `write_all`
    // instead of the many small writes `write!` issues on an unbuffered socket.
    let response = format!(
        "HTTP/1.1 {status}\r\nContent-Type: text/plain; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
        body.len()
    );
    stream.write_all(response.as_bytes())?;
    stream.flush()?;
    result
}

/// Presentation and validation settings for [`wait_for_callback`].
///
/// All fields are borrowed, so the struct is cheap to copy and to print for
/// diagnostics.
#[derive(Debug, Clone, Copy)]
pub(crate) struct CallbackOptions<'a> {
    /// Plain-text body sent with `200 OK` when the callback is accepted.
    pub(crate) success_body: &'a str,
    /// Plain-text body sent with `400 Bad Request` when the callback is rejected.
    pub(crate) failure_body: &'a str,
    /// Label prefixed to error messages (e.g. `"OAuth"`).
    pub(crate) error_context: &'a str,
    /// When `Some`, the callback's `state` parameter must equal this value.
    pub(crate) expected_state: Option<&'a str>,
}

pub(crate) fn parse_callback_request(
    request: &str,
    error_context: &str,
) -> anyhow::Result<(String, String)> {
    let line = request
        .lines()
        .next()
        .with_context(|| format!("empty {error_context} callback request"))?;
    let target = line
        .split_whitespace()
        .nth(1)
        .with_context(|| format!("malformed {error_context} callback request"))?;
    let url = url::Url::parse(&format!("http://localhost{target}"))?;
    let code = url
        .query_pairs()
        .find(|(key, _)| key == "code")
        .map(|(_, value)| value.into_owned())
        .with_context(|| format!("{error_context} callback missing code"))?;
    let state = url
        .query_pairs()
        .find(|(key, _)| key == "state")
        .map(|(_, value)| value.into_owned())
        .with_context(|| format!("{error_context} callback missing state"))?;
    Ok((code, state))
}

/// Parses a callback request with [`parse_callback_request`] and, when
/// `expected_state` is `Some`, rejects it unless the returned `state` equals
/// that value exactly. A `None` expectation skips the state check.
///
/// # Errors
///
/// Propagates every error from [`parse_callback_request`]. Additionally fails
/// with "`{error_context} state mismatch`" when the states differ.
pub(crate) fn parse_callback_request_with_state(
    request: &str,
    error_context: &str,
    expected_state: Option<&str>,
) -> anyhow::Result<(String, String)> {
    let parsed = parse_callback_request(request, error_context)?;
    match expected_state {
        Some(expected) if parsed.1 != expected => {
            anyhow::bail!("{error_context} state mismatch")
        }
        _ => Ok(parsed),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpStream;

    /// Builds a minimal callback request with `code=AUTH-CODE` and the given `state`.
    fn callback_request(state: &str) -> String {
        format!(
            "GET /auth/callback?code=AUTH-CODE&state={state} HTTP/1.1\r\nHost: localhost\r\n\r\n"
        )
    }

    /// RFC 7636 Appendix B test vector for the `S256` transform.
    #[test]
    fn pkce_challenge_uses_s256_url_safe_no_pad() {
        let verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
        let expected = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
        assert_eq!(challenge(verifier), expected);
    }

    #[test]
    fn parse_callback_request_validates_expected_state() {
        let request = callback_request("STATE");

        let parsed =
            parse_callback_request_with_state(&request, "OAuth", Some("STATE")).unwrap();
        assert_eq!(parsed, ("AUTH-CODE".to_string(), "STATE".to_string()));

        let mismatch = parse_callback_request_with_state(&request, "OAuth", Some("OTHER"));
        assert!(mismatch.unwrap_err().to_string().contains("state mismatch"));
    }

    #[test]
    fn wait_for_callback_renders_failure_for_state_mismatch() {
        let listener = TcpListener::bind(("127.0.0.1", 0)).unwrap();
        let addr = listener.local_addr().unwrap();

        // Play the browser: send a callback with the wrong state, then collect
        // everything the server answers until it closes the connection.
        let browser = std::thread::spawn(move || {
            let mut stream = TcpStream::connect(addr).unwrap();
            stream
                .write_all(callback_request("BAD").as_bytes())
                .unwrap();
            let mut response = String::new();
            stream.read_to_string(&mut response).unwrap();
            response
        });

        let options = CallbackOptions {
            success_body: "success",
            failure_body: "failure",
            error_context: "OAuth",
            expected_state: Some("GOOD"),
        };
        let err = wait_for_callback(listener, &options).unwrap_err();
        let response = browser.join().unwrap();

        assert!(err.to_string().contains("state mismatch"));
        assert!(response.starts_with("HTTP/1.1 400 Bad Request"));
        assert!(response.ends_with("failure"));
    }
}