bucketwarden-auth 0.1.0

BucketWarden local identity, access key, and session credential store.
Documentation
use crate::AuthError;
use serde::{Deserialize, Serialize};

/// JOSE header of a compact JWT.
///
/// Field declaration order is significant: serde serialises fields in this
/// order, which fixes the JSON key order (and therefore the bytes) of tokens
/// produced by signing.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct JwtHeader {
    /// Signing algorithm (`alg`).
    pub alg: String,
    /// Media type (`typ`). The header parameter is OPTIONAL (RFC 7515 §4.1.9,
    /// RFC 7519 §5.1), so a header without it deserialises with an empty
    /// string instead of being rejected.
    #[serde(default)]
    pub typ: String,
    /// Key identifier (`kid`), absent unless the token carries one.
    #[serde(default)]
    pub kid: Option<String>,
}

#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct JwtClaims {
    pub iss: String,
    pub sub: String,
    #[serde(default)]
    pub aud: Vec<String>,
    #[serde(default)]
    pub exp: Option<u64>,
    #[serde(default)]
    pub iat: Option<u64>,
}

/// Verifies an HS256-signed compact JWT and returns its decoded header and claims.
///
/// The token must consist of exactly three `.`-separated segments. The header
/// is decoded first so that a non-`HS256` `alg` is reported as
/// [`AuthError::UnsupportedWebIdentityAlgorithm`]. The claims segment is only
/// decoded and parsed *after* the HMAC-SHA256 signature over
/// `"<header>.<claims>"` has been checked against `secret` (compared in
/// constant time), so no unauthenticated payload is ever parsed.
///
/// Claims such as `exp` are returned as-is; this function does not check them.
///
/// # Errors
/// - `UnsupportedWebIdentityAlgorithm` if the header's `alg` is not `"HS256"`.
/// - `InvalidWebIdentityToken` for a wrong segment count, an undecodable
///   segment, invalid header/claims JSON, or a signature mismatch.
pub(crate) fn verify_hs256_jwt(
    token: &str,
    secret: &[u8],
) -> Result<(JwtHeader, JwtClaims), AuthError> {
    let mut parts = token.split('.');
    let header_part = parts
        .next()
        .ok_or_else(|| AuthError::InvalidWebIdentityToken("missing header".to_string()))?;
    let claims_part = parts
        .next()
        .ok_or_else(|| AuthError::InvalidWebIdentityToken("missing claims".to_string()))?;
    let signature_part = parts
        .next()
        .ok_or_else(|| AuthError::InvalidWebIdentityToken("missing signature".to_string()))?;
    if parts.next().is_some() {
        return Err(AuthError::InvalidWebIdentityToken(
            "too many JWT segments".to_string(),
        ));
    }

    // The header has to be read before authentication to learn `alg`.
    let header: JwtHeader = serde_json::from_slice(&base64url_decode(header_part)?)
        .map_err(|error| AuthError::InvalidWebIdentityToken(error.to_string()))?;
    if header.alg != "HS256" {
        return Err(AuthError::UnsupportedWebIdentityAlgorithm(header.alg));
    }

    // Authenticate before touching the payload. The MAC covers the raw
    // segment text exactly as received.
    let signing_input = format!("{header_part}.{claims_part}");
    let expected = base64url_encode(&hmac_sha256(secret, signing_input.as_bytes()));
    if !constant_time_eq(signature_part.as_bytes(), expected.as_bytes()) {
        return Err(AuthError::InvalidWebIdentityToken(
            "signature verification failed".to_string(),
        ));
    }

    let claims: JwtClaims = serde_json::from_slice(&base64url_decode(claims_part)?)
        .map_err(|error| AuthError::InvalidWebIdentityToken(error.to_string()))?;
    Ok((header, claims))
}

pub fn sign_hs256_jwt(
    kid: Option<&str>,
    claims: &JwtClaims,
    secret: &[u8],
) -> Result<String, AuthError> {
    let header = JwtHeader {
        alg: "HS256".to_string(),
        typ: "JWT".to_string(),
        kid: kid.map(str::to_string),
    };
    let header = serde_json::to_vec(&header)
        .map_err(|error| AuthError::InvalidWebIdentityToken(error.to_string()))?;
    let claims = serde_json::to_vec(claims)
        .map_err(|error| AuthError::InvalidWebIdentityToken(error.to_string()))?;
    let signing_input = format!(
        "{}.{}",
        base64url_encode(&header),
        base64url_encode(&claims)
    );
    let signature = base64url_encode(&hmac_sha256(secret, signing_input.as_bytes()));
    Ok(format!("{signing_input}.{signature}"))
}

/// Encodes `bytes` as unpadded base64url (RFC 4648 §5 alphabet, no `=`).
///
/// Each group of up to three input bytes yields one more output character
/// than it has bytes: 1 → 2, 2 → 3, 3 → 4.
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
    const SHIFTS: [u32; 4] = [18, 12, 6, 0];

    let mut encoded = String::new();
    for group in bytes.chunks(3) {
        // Zero-fill a short trailing group. Only the characters it actually
        // needs are emitted below, so the filler never leaks into the output.
        let mut triple = [0u8; 3];
        triple[..group.len()].copy_from_slice(group);
        let packed = u32::from_be_bytes([0, triple[0], triple[1], triple[2]]);
        for &shift in SHIFTS.iter().take(group.len() + 1) {
            encoded.push(ALPHABET[((packed >> shift) & 0x3f) as usize] as char);
        }
    }
    encoded
}

/// Decodes one unpadded base64url segment (RFC 4648 §5) of a compact JWS.
///
/// JWS segments never contain the standard-alphabet characters `+` and `/`,
/// nor `=` padding (RFC 7515 §2). These are rejected so that each segment has
/// exactly one accepted spelling. Valid input is translated to the standard
/// alphabet, re-padded to a multiple of four, and handed to the standard
/// decoder.
///
/// # Errors
/// Returns `InvalidWebIdentityToken("invalid base64url segment")` for
/// characters outside the base64url alphabet, an impossible length, or an
/// empty segment.
fn base64url_decode(value: &str) -> Result<Vec<u8>, AuthError> {
    let invalid =
        || AuthError::InvalidWebIdentityToken("invalid base64url segment".to_string());

    if value.bytes().any(|byte| matches!(byte, b'+' | b'/' | b'=')) {
        return Err(invalid());
    }

    let mut input = String::with_capacity(value.len() + 3);
    input.extend(value.chars().map(|ch| match ch {
        '-' => '+',
        '_' => '/',
        other => other,
    }));
    while input.len() % 4 != 0 {
        input.push('=');
    }
    base64_decode(&input).ok_or_else(invalid)
}

/// Decodes padded standard base64 (RFC 4648 §4).
///
/// Returns `None` in any of these cases:
/// - the input is empty or its length is not a multiple of four;
/// - it contains a character outside `A-Z a-z 0-9 + / =`;
/// - `=` padding appears anywhere other than as a trailing run of one or two
///   characters in the final 4-character group.
///
/// Non-zero bits left over in the last data character are not checked.
fn base64_decode(value: &str) -> Option<Vec<u8>> {
    let bytes = value.as_bytes();
    if bytes.is_empty() || bytes.len() % 4 != 0 {
        return None;
    }
    let last_group = bytes.len() / 4 - 1;
    let mut out = Vec::with_capacity(bytes.len() / 4 * 3);
    for (group_index, group) in bytes.chunks_exact(4).enumerate() {
        let mut n = 0u32;
        let mut padding = 0usize;
        for &byte in group {
            let sextet = match byte {
                b'A'..=b'Z' => byte - b'A',
                b'a'..=b'z' => byte - b'a' + 26,
                b'0'..=b'9' => byte - b'0' + 52,
                b'+' => 62,
                b'/' => 63,
                b'=' => {
                    padding += 1;
                    0
                }
                _ => return None,
            };
            // Once padding has started, nothing but more padding may follow.
            if padding > 0 && byte != b'=' {
                return None;
            }
            n = (n << 6) | u32::from(sextet);
        }
        // Padding is only legal at the very end of the input.
        if padding > 2 || (padding > 0 && group_index != last_group) {
            return None;
        }
        out.push(((n >> 16) & 0xff) as u8);
        if padding < 2 {
            out.push(((n >> 8) & 0xff) as u8);
        }
        if padding < 1 {
            out.push((n & 0xff) as u8);
        }
    }
    Some(out)
}

/// Compares two byte strings without early exit on the first mismatch.
///
/// Every index up to the longer length is visited. A length difference and
/// every differing byte are OR-ed into one accumulator, and the result is
/// true only when that accumulator is zero. Missing bytes on the shorter side
/// read as `0`.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let span = a.len().max(b.len());
    let byte_at = |side: &[u8], position: usize| side.get(position).copied().unwrap_or_default();
    let accumulated = (0..span).fold(a.len() ^ b.len(), |acc, position| {
        acc | usize::from(byte_at(a, position) ^ byte_at(b, position))
    });
    accumulated == 0
}

/// HMAC-SHA256 of `message` under `key` (RFC 2104), built on `sha256`.
///
/// A key longer than the 64-byte block is first hashed down to 32 bytes. The
/// key is then zero-extended to a full block and XOR-ed with the inner
/// (`0x36`) and outer (`0x5c`) pads.
fn hmac_sha256(key: &[u8], message: &[u8]) -> [u8; 32] {
    const BLOCK_LEN: usize = 64;

    let mut block_key = [0u8; BLOCK_LEN];
    if key.len() <= BLOCK_LEN {
        block_key[..key.len()].copy_from_slice(key);
    } else {
        block_key[..32].copy_from_slice(&sha256(key));
    }

    let inner_pad: Vec<u8> = block_key.iter().map(|byte| byte ^ 0x36).collect();
    let outer_pad: Vec<u8> = block_key.iter().map(|byte| byte ^ 0x5c).collect();

    let inner_digest = sha256(&[inner_pad.as_slice(), message].concat());
    sha256(&[outer_pad.as_slice(), &inner_digest[..]].concat())
}

/// SHA-256 digest of `input` (FIPS 180-4), implemented without dependencies.
///
/// The message is padded with `0x80`, zero bytes, and the original length in
/// bits as a big-endian `u64`, up to a multiple of 64 bytes. Each 64-byte
/// block is then expanded into a 64-word schedule and mixed into the eight
/// state words over 64 rounds.
fn sha256(input: &[u8]) -> [u8; 32] {
    const ROUND_CONSTANTS: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];
    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padded length is the smallest multiple of 64 that fits the input, the
    // 0x80 marker byte and the 8-byte length field.
    let bit_len = (input.len() as u64) * 8;
    let padded_len = (input.len() + 9 + 63) / 64 * 64;
    let mut padded = Vec::with_capacity(padded_len);
    padded.extend_from_slice(input);
    padded.push(0x80);
    padded.resize(padded_len - 8, 0);
    padded.extend_from_slice(&bit_len.to_be_bytes());

    for block in padded.chunks_exact(64) {
        let mut schedule = [0u32; 64];
        for (slot, word) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..64 {
            let back15 = schedule[t - 15];
            let back2 = schedule[t - 2];
            let sigma0 = back15.rotate_right(7) ^ back15.rotate_right(18) ^ (back15 >> 3);
            let sigma1 = back2.rotate_right(17) ^ back2.rotate_right(19) ^ (back2 >> 10);
            schedule[t] = schedule[t - 16]
                .wrapping_add(sigma0)
                .wrapping_add(schedule[t - 7])
                .wrapping_add(sigma1);
        }

        let mut working = state;
        for (&constant, &word) in ROUND_CONSTANTS.iter().zip(schedule.iter()) {
            let [a, b, c, d, e, f, g, h] = working;
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ (!e & g);
            let t1 = h
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(constant)
                .wrapping_add(word);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let t2 = big_sigma0.wrapping_add(majority);
            // Rotate the eight working variables one position down.
            working = [t1.wrapping_add(t2), a, b, c, d.wrapping_add(t1), e, f, g];
        }

        for (word, &delta) in state.iter_mut().zip(working.iter()) {
            *word = word.wrapping_add(delta);
        }
    }

    let mut digest = [0u8; 32];
    for (bytes, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        bytes.copy_from_slice(&word.to_be_bytes());
    }
    digest
}