use crate::AuthError;
use serde::{Deserialize, Serialize};
/// JOSE header of a compact JWT, as carried in the first token segment.
///
/// Field order is significant: it determines the key order of the serialized
/// header and therefore the bytes that get signed.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct JwtHeader {
    /// Signature algorithm identifier (`alg`), e.g. `"HS256"`.
    pub alg: String,
    /// Token media type (`typ`).
    ///
    /// The `typ` header parameter is optional in JWS, so a header that omits
    /// it deserializes to an empty string instead of being rejected.
    #[serde(default)]
    pub typ: String,
    /// Optional key identifier (`kid`); `None` when the header omits it.
    #[serde(default)]
    pub kid: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct JwtClaims {
pub iss: String,
pub sub: String,
#[serde(default)]
pub aud: Vec<String>,
#[serde(default)]
pub exp: Option<u64>,
#[serde(default)]
pub iat: Option<u64>,
}
/// Verifies a compact HS256-signed JWT and returns its decoded header and claims.
///
/// `token` must consist of exactly three dot-separated base64url segments
/// (header, claims, signature). `secret` is the shared HMAC key.
///
/// Only the structure, the `alg` header and the signature are checked here.
/// Claim values such as `exp`, `iss` or `aud` are returned as-is and are not
/// validated by this function.
///
/// # Errors
///
/// * [`AuthError::InvalidWebIdentityToken`] when a segment is missing, there
///   are more than three segments, a segment is not valid base64url, the
///   header or claims are not valid JSON of the expected shape, or the
///   signature does not match.
/// * [`AuthError::UnsupportedWebIdentityAlgorithm`] when the header's `alg`
///   is anything other than `"HS256"`.
pub(crate) fn verify_hs256_jwt(
    token: &str,
    secret: &[u8],
) -> Result<(JwtHeader, JwtClaims), AuthError> {
    let invalid = |reason: &str| AuthError::InvalidWebIdentityToken(reason.to_string());

    let mut segments = token.split('.');
    let header_part = segments
        .next()
        .ok_or_else(|| invalid("missing header"))?;
    let claims_part = segments
        .next()
        .ok_or_else(|| invalid("missing claims"))?;
    let signature_part = segments
        .next()
        .ok_or_else(|| invalid("missing signature"))?;
    if segments.next().is_some() {
        return Err(invalid("too many JWT segments"));
    }

    // The header has to be parsed first: it selects the algorithm, and callers
    // rely on getting `UnsupportedWebIdentityAlgorithm` for non-HS256 tokens.
    let header: JwtHeader = serde_json::from_slice(&base64url_decode(header_part)?)
        .map_err(|error| invalid(&error.to_string()))?;
    if header.alg != "HS256" {
        return Err(AuthError::UnsupportedWebIdentityAlgorithm(header.alg));
    }

    // Authenticate before parsing: the MAC covers the raw `header.claims` text,
    // so the claims payload is only decoded once the signature has been
    // verified. The comparison is done on the base64url text of the MAC.
    let signing_input = format!("{header_part}.{claims_part}");
    let expected = base64url_encode(&hmac_sha256(secret, signing_input.as_bytes()));
    if !constant_time_eq(signature_part.as_bytes(), expected.as_bytes()) {
        return Err(invalid("signature verification failed"));
    }

    let claims: JwtClaims = serde_json::from_slice(&base64url_decode(claims_part)?)
        .map_err(|error| invalid(&error.to_string()))?;
    Ok((header, claims))
}
pub fn sign_hs256_jwt(
kid: Option<&str>,
claims: &JwtClaims,
secret: &[u8],
) -> Result<String, AuthError> {
let header = JwtHeader {
alg: "HS256".to_string(),
typ: "JWT".to_string(),
kid: kid.map(str::to_string),
};
let header = serde_json::to_vec(&header)
.map_err(|error| AuthError::InvalidWebIdentityToken(error.to_string()))?;
let claims = serde_json::to_vec(claims)
.map_err(|error| AuthError::InvalidWebIdentityToken(error.to_string()))?;
let signing_input = format!(
"{}.{}",
base64url_encode(&header),
base64url_encode(&claims)
);
let signature = base64url_encode(&hmac_sha256(secret, signing_input.as_bytes()));
Ok(format!("{signing_input}.{signature}"))
}
/// Encodes `bytes` with the URL-safe base64 alphabet (`-` and `_`), without `=` padding.
///
/// Empty input yields an empty string.
fn base64url_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::new();
    for group in bytes.chunks(3) {
        // Pack up to three bytes into the top of a 24-bit word; missing bytes stay zero.
        let packed = group
            .iter()
            .enumerate()
            .fold(0u32, |word, (position, &byte)| {
                word | (u32::from(byte) << (16 - 8 * position))
            });
        // k input bytes produce k + 1 output symbols (no padding is emitted).
        for slot in 0..=group.len() {
            let sextet = (packed >> (18 - 6 * slot)) & 0x3f;
            encoded.push(char::from(ALPHABET[sextet as usize]));
        }
    }
    encoded
}
/// Decodes an unpadded base64url segment into raw bytes.
///
/// The URL-safe symbols `-` and `_` are mapped to `+` and `/`, the text is
/// padded with `=` to a multiple of four bytes, and the result is handed to
/// `base64_decode`.
///
/// # Errors
///
/// Returns [`AuthError::InvalidWebIdentityToken`] with the message
/// `"invalid base64url segment"` when `base64_decode` rejects the input
/// (which includes the empty segment).
fn base64url_decode(value: &str) -> Result<Vec<u8>, AuthError> {
    let mut standard: String = value
        .chars()
        .map(|symbol| match symbol {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .collect();

    let remainder = standard.len() % 4;
    if remainder != 0 {
        standard.extend(std::iter::repeat('=').take(4 - remainder));
    }

    match base64_decode(&standard) {
        Some(decoded) => Ok(decoded),
        None => Err(AuthError::InvalidWebIdentityToken(
            "invalid base64url segment".to_string(),
        )),
    }
}
/// Decodes standard-alphabet base64 (`+`, `/`, `=` padding) into bytes.
///
/// Returns `None` when the input is empty, its length is not a multiple of
/// four, it contains a symbol outside the alphabet, or its padding is
/// malformed. Padding is only accepted as one or two trailing `=` in the
/// final quartet; `=` followed by data, or `=` in an earlier quartet, is
/// rejected.
///
/// Unused low bits in a padded final quartet are ignored rather than validated.
fn base64_decode(value: &str) -> Option<Vec<u8>> {
    let bytes = value.as_bytes();
    if bytes.is_empty() || bytes.len() % 4 != 0 {
        return None;
    }

    let quartet_count = bytes.len() / 4;
    let mut out = Vec::with_capacity(quartet_count * 3);
    for (index, quartet) in bytes.chunks_exact(4).enumerate() {
        let mut n = 0u32;
        let mut padding = 0usize;
        for &byte in quartet {
            n <<= 6;
            if byte == b'=' {
                padding += 1;
                continue;
            }
            // Once padding has started, only more padding may follow.
            if padding > 0 {
                return None;
            }
            n |= u32::from(match byte {
                b'A'..=b'Z' => byte - b'A',
                b'a'..=b'z' => byte - b'a' + 26,
                b'0'..=b'9' => byte - b'0' + 52,
                b'+' => 62,
                b'/' => 63,
                _ => return None,
            });
        }

        // At most two pad symbols, and only in the last quartet.
        let is_last = index + 1 == quartet_count;
        if padding > 2 || (padding > 0 && !is_last) {
            return None;
        }

        // `n` holds 24 bits, so the top byte of the big-endian form is always zero.
        let [_, first, second, third] = n.to_be_bytes();
        out.push(first);
        if padding < 2 {
            out.push(second);
        }
        if padding < 1 {
            out.push(third);
        }
    }
    Some(out)
}
/// Compares two byte slices without stopping at the first difference.
///
/// Every position up to the longer length is visited, with missing bytes
/// treated as zero, and all differences (including a length difference) are
/// OR-ed into one accumulator that is checked once at the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    let byte_at = |slice: &[u8], position: usize| slice.get(position).copied().unwrap_or(0);

    let span = a.len().max(b.len());
    // Seeding with the length XOR makes slices of different lengths unequal
    // even when the extra bytes happen to be zero.
    let mismatch = (0..span).fold(a.len() ^ b.len(), |accumulated, position| {
        accumulated | usize::from(byte_at(a, position) ^ byte_at(b, position))
    });
    mismatch == 0
}
/// Computes HMAC-SHA256 of `message` under `key`.
///
/// Keys longer than the 64-byte block are first replaced by their SHA-256
/// digest; shorter keys are zero-padded to the block size.
fn hmac_sha256(key: &[u8], message: &[u8]) -> [u8; 32] {
    const BLOCK_SIZE: usize = 64;

    let mut padded_key = [0u8; BLOCK_SIZE];
    if key.len() <= BLOCK_SIZE {
        padded_key[..key.len()].copy_from_slice(key);
    } else {
        let key_digest = sha256(key);
        padded_key[..key_digest.len()].copy_from_slice(&key_digest);
    }

    let inner_pad = padded_key.map(|byte| byte ^ 0x36);
    let outer_pad = padded_key.map(|byte| byte ^ 0x5c);

    // H((K ^ opad) || H((K ^ ipad) || message))
    let inner_digest = sha256(&[&inner_pad[..], message].concat());
    sha256(&[&outer_pad[..], &inner_digest[..]].concat())
}
/// Computes the SHA-256 digest of `input`.
///
/// The whole padded message is built in memory and then processed in
/// 64-byte blocks.
fn sha256(input: &[u8]) -> [u8; 32] {
    const ROUND_CONSTANTS: [u32; 64] = [
        0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4,
        0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe,
        0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f,
        0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7,
        0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc,
        0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b,
        0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116,
        0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
        0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7,
        0xc67178f2,
    ];

    let mut state: [u32; 8] = [
        0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab,
        0x5be0cd19,
    ];

    // Padding: a single 0x80 byte, zeros up to 56 mod 64, then the original
    // length in bits as a big-endian u64.
    let bit_len = (input.len() as u64) * 8;
    let mut padded = input.to_vec();
    padded.push(0x80);
    let zero_fill = (56 + 64 - padded.len() % 64) % 64;
    padded.resize(padded.len() + zero_fill, 0);
    padded.extend_from_slice(&bit_len.to_be_bytes());

    for block in padded.chunks_exact(64) {
        // Message schedule: the first 16 words come straight from the block,
        // the remaining 48 are derived from earlier words.
        let mut schedule = [0u32; 64];
        for (slot, word_bytes) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([
                word_bytes[0],
                word_bytes[1],
                word_bytes[2],
                word_bytes[3],
            ]);
        }
        for t in 16..64 {
            let near = schedule[t - 15];
            let far = schedule[t - 2];
            let small_sigma0 = near.rotate_right(7) ^ near.rotate_right(18) ^ (near >> 3);
            let small_sigma1 = far.rotate_right(17) ^ far.rotate_right(19) ^ (far >> 10);
            schedule[t] = schedule[t - 16]
                .wrapping_add(small_sigma0)
                .wrapping_add(schedule[t - 7])
                .wrapping_add(small_sigma1);
        }

        // Compression: 64 rounds over eight working variables.
        let [mut a, mut b, mut c, mut d, mut e, mut f, mut g, mut hh] = state;
        for (&constant, &word) in ROUND_CONSTANTS.iter().zip(schedule.iter()) {
            let big_sigma1 = e.rotate_right(6) ^ e.rotate_right(11) ^ e.rotate_right(25);
            let choose = (e & f) ^ (!e & g);
            let temp1 = hh
                .wrapping_add(big_sigma1)
                .wrapping_add(choose)
                .wrapping_add(constant)
                .wrapping_add(word);
            let big_sigma0 = a.rotate_right(2) ^ a.rotate_right(13) ^ a.rotate_right(22);
            let majority = (a & b) ^ (a & c) ^ (b & c);
            let temp2 = big_sigma0.wrapping_add(majority);

            hh = g;
            g = f;
            f = e;
            e = d.wrapping_add(temp1);
            d = c;
            c = b;
            b = a;
            a = temp1.wrapping_add(temp2);
        }

        // Fold this block's result back into the running state.
        let working = [a, b, c, d, e, f, g, hh];
        for (slot, value) in state.iter_mut().zip(working.iter()) {
            *slot = slot.wrapping_add(*value);
        }
    }

    let mut digest = [0u8; 32];
    for (out_bytes, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        out_bytes.copy_from_slice(&word.to_be_bytes());
    }
    digest
}