aioduct 0.1.0

Async-native HTTP client built directly on hyper 1.x — no hyper-util, no legacy
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};

use http::header::{HeaderValue, WWW_AUTHENTICATE};
use http::{HeaderMap, Method, StatusCode, Uri};

/// Credentials and request counter used to answer HTTP Digest authentication challenges.
///
/// Cloning copies the credentials but shares the counter: the `Arc` is cloned, so every
/// clone advances the same `AtomicU32`.
#[derive(Clone)]
pub(crate) struct DigestAuth {
    /// User name sent in the `username` parameter and hashed into the response digest.
    username: String,
    /// Password; used only as hash input and shown as `[redacted]` by the `Debug` impl.
    password: String,
    /// Source of the `nc` (nonce count) value; initialised to 1 by `new` and advanced by
    /// `authorize` with `fetch_add`.
    nonce_count: Arc<AtomicU32>,
}

impl DigestAuth {
    /// Creates a Digest authenticator for the given credentials.
    ///
    /// The nonce counter starts at 1 and is shared by all clones of the returned value.
    pub(crate) fn new(username: String, password: String) -> Self {
        Self {
            username,
            password,
            nonce_count: Arc::new(AtomicU32::new(1)),
        }
    }

    /// Returns `true` when a response is a `401 Unauthorized` that carries at least one
    /// `WWW-Authenticate` header, i.e. when calling [`Self::authorize`] may be worthwhile.
    pub(crate) fn needs_retry(&self, status: StatusCode, headers: &HeaderMap) -> bool {
        status == StatusCode::UNAUTHORIZED && headers.contains_key(WWW_AUTHENTICATE)
    }

    /// Builds the `Authorization` header value answering a Digest challenge.
    ///
    /// * `method` / `uri` describe the request that will be retried; only the path and
    ///   query of `uri` are hashed and sent (`/` when the URI has neither).
    /// * `headers` are the headers of the `401` response containing the challenge.
    ///
    /// Returns `None` when the challenge cannot be answered:
    /// * no `WWW-Authenticate` value uses the `Digest` scheme,
    /// * `realm` or `nonce` is missing,
    /// * the server asks for an algorithm other than MD5 (the only one implemented here),
    /// * the resulting text is not a valid header value.
    ///
    /// `qop=auth` is used only when the server lists exactly that option; otherwise the
    /// response is computed without `qop`, `nc` and `cnonce`.
    ///
    /// NOTE(review): the `nc` counter is shared for the lifetime of this value and is not
    /// reset when the server issues a new nonce.
    pub(crate) fn authorize(
        &self,
        method: &Method,
        uri: &Uri,
        headers: &HeaderMap,
    ) -> Option<HeaderValue> {
        // Wraps `raw` in double quotes, backslash-escaping `"` and `\` so the result is a
        // well-formed quoted-string even for unusual user names or realms.
        fn quoted(raw: &str) -> String {
            let mut out = String::with_capacity(raw.len() + 2);
            out.push('"');
            for ch in raw.chars() {
                if matches!(ch, '"' | '\\') {
                    out.push('\\');
                }
                out.push(ch);
            }
            out.push('"');
            out
        }

        // A response may carry several `WWW-Authenticate` values (for example Basic and
        // Digest); use the first one whose scheme is Digest rather than only the first value.
        let challenge = headers
            .get_all(WWW_AUTHENTICATE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .find_map(|value| {
                let (scheme, rest) = (value.get(..7)?, value.get(7..)?);
                scheme.eq_ignore_ascii_case("digest ").then_some(rest)
            })?;

        let params = parse_challenge(challenge);
        let realm = params.get("realm")?;
        let nonce = params.get("nonce")?;

        // Only plain MD5 is implemented. Answering e.g. a SHA-256 challenge with an MD5
        // hash labelled as SHA-256 can never succeed, so decline instead.
        let algorithm = params
            .get("algorithm")
            .map(String::as_str)
            .unwrap_or("MD5");
        if !algorithm.eq_ignore_ascii_case("MD5") {
            return None;
        }

        // `qop` is a comma-separated list; match the `auth` token exactly so that a
        // server offering only `auth-int` is not answered with `qop=auth`.
        let offers_auth = params.get("qop").is_some_and(|options| {
            options
                .split(',')
                .any(|option| option.trim().eq_ignore_ascii_case("auth"))
        });

        let path = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
        let ha1 = md5_hex(&format!("{}:{}:{}", self.username, realm, self.password));
        let ha2 = md5_hex(&format!("{}:{}", method.as_str(), path));

        let (response, qop_params) = if offers_auth {
            // The counter and client nonce are only meaningful together with `qop`.
            let nc = self.nonce_count.fetch_add(1, Ordering::Relaxed);
            let cnonce = format!("{:016x}", rand_u64());
            let response = md5_hex(&format!("{ha1}:{nonce}:{nc:08x}:{cnonce}:auth:{ha2}"));
            (
                response,
                format!(", qop=auth, nc={nc:08x}, cnonce=\"{cnonce}\""),
            )
        } else {
            (md5_hex(&format!("{ha1}:{nonce}:{ha2}")), String::new())
        };

        let mut value = format!(
            "Digest username={}, realm={}, nonce={}, uri=\"{}\", response=\"{}\"",
            quoted(&self.username),
            quoted(realm),
            quoted(nonce),
            path,
            response
        );
        value.push_str(&qop_params);

        if let Some(opaque) = params.get("opaque") {
            value.push_str(", opaque=");
            value.push_str(&quoted(opaque));
        }

        // Echo the algorithm only when the server spelled it differently from the
        // default "MD5" (at this point that can only be a case variant).
        if algorithm != "MD5" {
            value.push_str(&format!(", algorithm={algorithm}"));
        }

        HeaderValue::from_str(&value).ok()
    }
}

/// Parses the comma-separated `key=value` parameters of an authentication challenge.
///
/// `s` is the text following the scheme name (for example everything after `Digest `).
/// Keys are trimmed and lower-cased. Values are either:
/// * a quoted-string, returned without the surrounding quotes and with backslash escapes
///   resolved (`\"` becomes `"`); an unterminated quoted-string runs to the end of input;
/// * a bare token, running up to the next comma and trimmed.
///
/// Whitespace around `=` is tolerated. Parsing stops at the first parameter that has no
/// `=`. When a key occurs more than once, the last value wins.
fn parse_challenge(s: &str) -> HashMap<String, String> {
    let mut params = HashMap::new();
    let mut remaining = s.trim();

    loop {
        // Skip separators left over from the previous parameter.
        remaining = remaining.trim_start_matches([',', ' ', '\t']);
        let Some((raw_key, after_eq)) = remaining.split_once('=') else {
            break;
        };
        let key = raw_key.trim().to_ascii_lowercase();
        // Optional whitespace may follow `=`; skip it before deciding whether the value
        // is quoted, otherwise `realm = "a, b"` would be cut at the inner comma.
        let after_eq = after_eq.trim_start_matches([' ', '\t']);

        let value = if let Some(body) = after_eq.strip_prefix('"') {
            let mut unescaped = String::new();
            // Number of bytes of `body` used by the value, including the closing quote.
            let mut consumed = body.len();
            let mut chars = body.char_indices();
            while let Some((idx, ch)) = chars.next() {
                match ch {
                    // A backslash makes the next character literal (quoted-pair).
                    '\\' => {
                        if let Some((_, escaped)) = chars.next() {
                            unescaped.push(escaped);
                        }
                    }
                    '"' => {
                        consumed = idx + 1;
                        break;
                    }
                    _ => unescaped.push(ch),
                }
            }
            remaining = &body[consumed..];
            unescaped
        } else {
            let end = after_eq.find(',').unwrap_or(after_eq.len());
            remaining = &after_eq[end..];
            after_eq[..end].trim().to_string()
        };

        params.insert(key, value);
    }

    params
}

/// Returns the MD5 digest of `input` as 32 lowercase hexadecimal characters.
fn md5_hex(input: &str) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let digest = md5_compute(input.as_bytes());
    let mut hex = String::with_capacity(digest.len() * 2);
    for &byte in digest.iter() {
        // High nibble first, then low nibble.
        hex.push(DIGITS[usize::from(byte >> 4)] as char);
        hex.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    hex
}

/// Computes the 16-byte MD5 digest of `data`.
///
/// The message is copied, padded with a `0x80` byte and zeros up to 56 bytes modulo 64,
/// followed by the message length in bits as a little-endian `u64`; each 64-byte block is
/// then folded into the state by `md5_round`.
fn md5_compute(data: &[u8]) -> [u8; 16] {
    const BLOCK: usize = 64;

    let message_bits = (data.len() as u64) * 8;

    // Zero bytes needed after the 0x80 marker so that 8 length bytes complete a block.
    let zero_fill = (BLOCK + 56 - (data.len() + 1) % BLOCK) % BLOCK;
    let mut buffer = Vec::with_capacity(data.len() + 1 + zero_fill + 8);
    buffer.extend_from_slice(data);
    buffer.push(0x80);
    buffer.resize(buffer.len() + zero_fill, 0);
    buffer.extend_from_slice(&message_bits.to_le_bytes());

    let mut state: [u32; 4] = [0x67452301, 0xefcdab89, 0x98badcfe, 0x10325476];
    for block in buffer.chunks_exact(BLOCK) {
        // Each block is read as sixteen little-endian 32-bit words.
        let mut words = [0u32; 16];
        for (slot, quad) in words.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]);
        }
        md5_round(&mut state, &words);
    }

    // The digest is the four state words serialised little-endian, in order.
    let mut digest = [0u8; 16];
    for (out, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        out.copy_from_slice(&word.to_le_bytes());
    }
    digest
}

/// Folds one 16-word message block `m` into the running MD5 `state`.
///
/// Runs the 64 steps (four rounds of sixteen) and then adds the resulting working
/// variables back into `state` with wrapping arithmetic.
fn md5_round(state: &mut [u32; 4], m: &[u32; 16]) {
    // Left-rotation amounts: one row per round, cycled every four steps.
    const S: [[u32; 4]; 4] = [
        [7, 12, 17, 22],
        [5, 9, 14, 20],
        [4, 11, 16, 23],
        [6, 10, 15, 21],
    ];

    // Additive constant for each of the 64 steps.
    const K: [u32; 64] = [
        0xd76aa478, 0xe8c7b756, 0x242070db, 0xc1bdceee, 0xf57c0faf, 0x4787c62a, 0xa8304613,
        0xfd469501, 0x698098d8, 0x8b44f7af, 0xffff5bb1, 0x895cd7be, 0x6b901122, 0xfd987193,
        0xa679438e, 0x49b40821, 0xf61e2562, 0xc040b340, 0x265e5a51, 0xe9b6c7aa, 0xd62f105d,
        0x02441453, 0xd8a1e681, 0xe7d3fbc8, 0x21e1cde6, 0xc33707d6, 0xf4d50d87, 0x455a14ed,
        0xa9e3e905, 0xfcefa3f8, 0x676f02d9, 0x8d2a4c8a, 0xfffa3942, 0x8771f681, 0x6d9d6122,
        0xfde5380c, 0xa4beea44, 0x4bdecfa9, 0xf6bb4b60, 0xbebfbc70, 0x289b7ec6, 0xeaa127fa,
        0xd4ef3085, 0x04881d05, 0xd9d4d039, 0xe6db99e5, 0x1fa27cf8, 0xc4ac5665, 0xf4292244,
        0x432aff97, 0xab9423a7, 0xfc93a039, 0x655b59c3, 0x8f0ccc92, 0xffeff47d, 0x85845dd1,
        0x6fa87e4f, 0xfe2ce6e0, 0xa3014314, 0x4e0811a1, 0xf7537e82, 0xbd3af235, 0x2ad7d2bb,
        0xeb86d391,
    ];

    let [mut a, mut b, mut c, mut d] = *state;

    for (step, &constant) in K.iter().enumerate() {
        let round = step / 16;
        // Each round uses its own mixing function and message-word schedule.
        let (mixed, word_index) = match round {
            0 => ((b & c) | (!b & d), step),
            1 => ((d & b) | (!d & c), (5 * step + 1) % 16),
            2 => (b ^ c ^ d, (3 * step + 5) % 16),
            _ => (c ^ (b | !d), (7 * step) % 16),
        };
        let rotated = a
            .wrapping_add(mixed)
            .wrapping_add(constant)
            .wrapping_add(m[word_index])
            .rotate_left(S[round][step % 4]);
        // Rotate the working variables; the right-hand side is evaluated before assignment.
        (a, b, c, d) = (d, b.wrapping_add(rotated), b, c);
    }

    for (slot, delta) in state.iter_mut().zip([a, b, c, d]) {
        *slot = (*slot).wrapping_add(delta);
    }
}

/// Returns a `u64` obtained by finishing a fresh `RandomState` hasher with no input.
///
/// Used for the client nonce; this relies on `RandomState::new()` being seeded
/// differently per instance and is not a cryptographic random source.
fn rand_u64() -> u64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let seeded = RandomState::new();
    let hasher = seeded.build_hasher();
    hasher.finish()
}

/// Debug output that shows the user name but never the password.
impl std::fmt::Debug for DigestAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("DigestAuth");
        out.field("username", &self.username);
        // The real password is deliberately replaced by a fixed placeholder.
        out.field("password", &"[redacted]");
        out.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `md5_hex` against known digests of four short inputs.
    #[test]
    fn md5_rfc1321_test_vectors() {
        let vectors = [
            ("", "d41d8cd98f00b204e9800998ecf8427e"),
            ("a", "0cc175b9c0f1b6a831c399e269772661"),
            ("abc", "900150983cd24fb0d6963f7d28e17f72"),
            ("message digest", "f96b697d7cb7938d525a2f31aaf161d0"),
        ];
        for &(input, expected) in vectors.iter() {
            assert_eq!(md5_hex(input), expected);
        }
    }

    /// Quoted parameters are returned without their surrounding quotes.
    #[test]
    fn parse_challenge_basic() {
        let params = parse_challenge(r#"realm="test", nonce="abc123", qop="auth", opaque="xyz""#);
        let expected = [
            ("realm", "test"),
            ("nonce", "abc123"),
            ("qop", "auth"),
            ("opaque", "xyz"),
        ];
        for &(key, want) in expected.iter() {
            assert_eq!(params.get(key).map(String::as_str), Some(want));
        }
    }

    /// A `qop="auth"` challenge yields a Digest header echoing user, realm and qop.
    #[test]
    fn digest_response_generation() {
        let challenge = HeaderValue::from_static(
            r#"Digest realm="testrealm@host.com", nonce="dcd98b", qop="auth""#,
        );
        let mut headers = HeaderMap::new();
        headers.insert(WWW_AUTHENTICATE, challenge);

        let uri: Uri = "http://example.com/dir/index.html".parse().unwrap();
        let auth = DigestAuth::new(String::from("user"), String::from("pass"));

        let header = auth.authorize(&Method::GET, &uri, &headers).unwrap();
        let text = header.to_str().unwrap();

        assert!(text.starts_with("Digest "));
        for needle in [
            "username=\"user\"",
            "realm=\"testrealm@host.com\"",
            "qop=auth",
        ] {
            assert!(text.contains(needle));
        }
    }
}