use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
/// Errors produced by the cryptographic helpers in this module.
///
/// Every variant carries a free-form detail message that is appended to a
/// fixed prefix when the error is displayed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CryptoError {
    /// The `window` object or its `crypto` property could not be obtained.
    CryptoUnavailable(String),
    /// Filling a buffer with random bytes failed.
    RandomError(String),
    /// A digest call failed, or an unsupported PKCE method was requested.
    HashError(String),
    /// Signing failure. Not constructed by the functions in this section;
    /// presumably used elsewhere in the crate.
    SigningError(String),
    /// Base64 failure. Not constructed by the functions in this section;
    /// presumably used elsewhere in the crate.
    Base64Error(String),
}

impl std::fmt::Display for CryptoError {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            CryptoError::CryptoUnavailable(ref detail) => {
                write!(out, "Crypto unavailable: {}", detail)
            }
            CryptoError::RandomError(ref detail) => write!(out, "Random error: {}", detail),
            CryptoError::HashError(ref detail) => write!(out, "Hash error: {}", detail),
            CryptoError::SigningError(ref detail) => write!(out, "Signing error: {}", detail),
            CryptoError::Base64Error(ref detail) => write!(out, "Base64 error: {}", detail),
        }
    }
}

impl std::error::Error for CryptoError {}

/// Result type used by every fallible helper in this module.
pub type CryptoResult<T> = Result<T, CryptoError>;
/// Returns `length` cryptographically secure random bytes from the browser's
/// `window.crypto.getRandomValues`.
///
/// `getRandomValues` rejects requests larger than 65,536 bytes with a
/// `QuotaExceededError`, so the buffer is filled in chunks of at most that size.
/// A zero `length` yields an empty vector without calling into JS.
///
/// # Errors
/// - [`CryptoError::CryptoUnavailable`] if `window` or `window.crypto` is missing.
/// - [`CryptoError::RandomError`] if `getRandomValues` throws.
pub fn generate_random_bytes(length: usize) -> CryptoResult<Vec<u8>> {
    /// Per-call byte limit imposed by the Web Crypto spec on `getRandomValues`.
    const MAX_RANDOM_CHUNK: usize = 65_536;

    let crypto = web_sys::window()
        .ok_or_else(|| CryptoError::CryptoUnavailable("No window object".to_string()))?
        .crypto()
        .map_err(|_| CryptoError::CryptoUnavailable("No crypto object".to_string()))?;

    let mut bytes = vec![0u8; length];
    // Fill the buffer in place. The former extra `Uint8Array` copy was never read
    // and only duplicated secret material on the JS heap.
    for chunk in bytes.chunks_mut(MAX_RANDOM_CHUNK) {
        crypto
            .get_random_values_with_u8_array(chunk)
            .map_err(|e| CryptoError::RandomError(format!("{:?}", e)))?;
    }
    Ok(bytes)
}
/// Produces an opaque token: `byte_length` random bytes rendered as
/// base64url text without `=` padding.
///
/// # Errors
/// Propagates any error from [`generate_random_bytes`].
pub fn generate_token(byte_length: usize) -> CryptoResult<String> {
    generate_random_bytes(byte_length).map(|raw| URL_SAFE_NO_PAD.encode(raw))
}
/// Generates a fresh authorization code from 32 random bytes
/// (43 unpadded base64url characters).
///
/// # Errors
/// Propagates any error from [`generate_token`].
pub fn generate_authorization_code() -> CryptoResult<String> {
    const AUTH_CODE_BYTES: usize = 32;
    generate_token(AUTH_CODE_BYTES)
}
/// Generates a fresh refresh token from 32 random bytes
/// (43 unpadded base64url characters).
///
/// # Errors
/// Propagates any error from [`generate_token`].
pub fn generate_refresh_token() -> CryptoResult<String> {
    const REFRESH_TOKEN_BYTES: usize = 32;
    generate_token(REFRESH_TOKEN_BYTES)
}
/// Generates a random family identifier from 16 random bytes
/// (22 unpadded base64url characters).
///
/// # Errors
/// Propagates any error from [`generate_token`].
pub fn generate_family_id() -> CryptoResult<String> {
    const FAMILY_ID_BYTES: usize = 16;
    generate_token(FAMILY_ID_BYTES)
}
/// Computes SHA-256 over the UTF-8 bytes of `input` with Web Crypto's
/// `SubtleCrypto.digest`, and returns the digest base64url-encoded without padding.
///
/// Shared by [`hash_token`] and [`generate_code_challenge`], which previously
/// carried identical copies of this pipeline.
///
/// # Errors
/// - [`CryptoError::CryptoUnavailable`] if `window` or `window.crypto` is missing.
/// - [`CryptoError::HashError`] if `digest` throws synchronously or its promise rejects.
async fn sha256_base64url(input: &str) -> CryptoResult<String> {
    let crypto = web_sys::window()
        .ok_or_else(|| CryptoError::CryptoUnavailable("No window object".to_string()))?
        .crypto()
        .map_err(|_| CryptoError::CryptoUnavailable("No crypto object".to_string()))?;
    let data = js_sys::Uint8Array::from(input.as_bytes());
    let promise = crypto
        .subtle()
        .digest_with_str_and_buffer_source("SHA-256", &data)
        .map_err(|e| CryptoError::HashError(format!("{:?}", e)))?;
    let digest = JsFuture::from(promise)
        .await
        .map_err(|e| CryptoError::HashError(format!("{:?}", e)))?;
    // Copy the resolved digest buffer out of JS into Rust-owned memory.
    let array = js_sys::Uint8Array::new(&digest);
    let mut hash = vec![0u8; array.length() as usize];
    array.copy_to(&mut hash);
    Ok(URL_SAFE_NO_PAD.encode(&hash))
}

/// Hashes `token` with SHA-256 and returns the digest as unpadded base64url text.
///
/// # Errors
/// - [`CryptoError::CryptoUnavailable`] if `window` or `window.crypto` is missing.
/// - [`CryptoError::HashError`] if the digest operation fails.
pub async fn hash_token(token: &str) -> CryptoResult<String> {
    sha256_base64url(token).await
}

/// Checks a PKCE `code_verifier` against a previously stored `code_challenge`.
///
/// `method` selects the transformation applied to the verifier:
/// - `"S256"`: the verifier is hashed with [`generate_code_challenge`] first.
/// - `"plain"`: the verifier is compared directly.
///
/// Both paths compare with [`constant_time_compare`]. The verifier's format is
/// not checked here; see [`validate_code_verifier`].
///
/// # Errors
/// Returns [`CryptoError::HashError`] for any other `method` value. For
/// `"S256"`, errors from [`generate_code_challenge`] are also propagated.
pub async fn verify_pkce(
    code_verifier: &str,
    code_challenge: &str,
    method: &str,
) -> CryptoResult<bool> {
    match method {
        "S256" => {
            let verifier_challenge = generate_code_challenge(code_verifier).await?;
            Ok(constant_time_compare(&verifier_challenge, code_challenge))
        }
        "plain" => Ok(constant_time_compare(code_verifier, code_challenge)),
        _ => Err(CryptoError::HashError(format!(
            "Unknown PKCE method: {}",
            method
        ))),
    }
}

/// Derives the `S256` PKCE code challenge for `code_verifier`.
///
/// The challenge is the SHA-256 digest of the verifier's bytes, base64url-encoded
/// without padding. For a valid (ASCII-only) verifier this matches
/// `BASE64URL-ENCODE(SHA256(ASCII(code_verifier)))` from RFC 7636 §4.2.
///
/// # Errors
/// - [`CryptoError::CryptoUnavailable`] if `window` or `window.crypto` is missing.
/// - [`CryptoError::HashError`] if the digest operation fails.
pub async fn generate_code_challenge(code_verifier: &str) -> CryptoResult<String> {
    sha256_base64url(code_verifier).await
}
/// Compares two strings for byte equality without exiting early on the first
/// mismatching byte.
///
/// The loop always runs over the longer input, treating the missing bytes of
/// the shorter one as `0`. A separate length flag ensures that zero-padding
/// cannot make inputs of different lengths compare equal. The length itself
/// is not hidden.
pub fn constant_time_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    let longest = lhs.len().max(rhs.len());

    // `delta | -delta` has its sign bit set iff `delta != 0`. The arithmetic
    // shift smears that bit into all-ones or all-zeros, and `& 1` reduces it
    // to 1 or 0.
    let delta = (lhs.len() as isize).wrapping_sub(rhs.len() as isize);
    let length_flag = (((delta | delta.wrapping_neg()) >> (isize::BITS - 1)) as u8) & 1;

    let accumulated = (0..longest).fold(length_flag, |acc, idx| {
        let left = lhs.get(idx).copied().unwrap_or(0);
        let right = rhs.get(idx).copied().unwrap_or(0);
        acc | (left ^ right)
    });
    accumulated == 0
}
/// Returns `true` if `verifier` has the shape of a PKCE code verifier.
///
/// The verifier must be 43 to 128 bytes long, and every byte must be an ASCII
/// letter, digit, or one of `-`, `.`, `_`, `~`. Any non-ASCII byte fails the
/// check, so the byte length equals the character length for every accepted input.
pub fn validate_code_verifier(verifier: &str) -> bool {
    const UNRESERVED_PUNCT: &[u8] = b"-._~";
    let length_ok = (43..=128).contains(&verifier.len());
    length_ok
        && verifier
            .bytes()
            .all(|byte| byte.is_ascii_alphanumeric() || UNRESERVED_PUNCT.contains(&byte))
}
/// Current Unix time in whole seconds.
///
/// `Date.now()` returns milliseconds as an `f64`. The quotient is truncated by
/// the `as u64` cast, which also saturates negative or NaN values to `0`.
pub fn now_secs() -> u64 {
    let millis = js_sys::Date::now();
    let secs = millis / 1000.0;
    secs as u64
}
/// Native unit tests for the helpers that do not need a browser environment.
/// Anything calling `window.crypto` requires a wasm test runner and is not covered here.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("hello", "hello"));
        assert!(constant_time_compare("", ""));
        assert!(!constant_time_compare("hello", "world"));
        assert!(!constant_time_compare("hello", "hell"));
        assert!(!constant_time_compare("", "hello"));
        // The shorter input is zero-padded inside the loop. A trailing NUL must
        // still be rejected, which relies on the separate length check.
        assert!(!constant_time_compare("abc\0", "abc"));
        assert!(!constant_time_compare("abc", "abc\0"));
    }

    #[test]
    fn test_validate_code_verifier() {
        // Example verifier from RFC 7636 Appendix B.
        assert!(validate_code_verifier(
            "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
        ));
        assert!(validate_code_verifier(&"a".repeat(43)));
        assert!(validate_code_verifier(&"a".repeat(128)));
        assert!(!validate_code_verifier(&"a".repeat(42)));
        assert!(!validate_code_verifier(&"a".repeat(129)));
        assert!(!validate_code_verifier(""));
        assert!(!validate_code_verifier(&format!("{}!", "a".repeat(42))));
        assert!(!validate_code_verifier(&format!("{}@", "a".repeat(42))));
        assert!(!validate_code_verifier(&format!("{} ", "a".repeat(42))));
        // Non-ASCII characters are rejected even when the byte length is in range.
        assert!(!validate_code_verifier(&format!("{}é", "a".repeat(42))));
        assert!(validate_code_verifier(&format!(
            "abc-._~{}",
            "x".repeat(39)
        )));
    }

    #[test]
    fn test_base64_url_encoding() {
        // These bytes encode to `+//+` in standard base64. The URL-safe,
        // unpadded alphabet must produce `-__-` instead.
        let data: [u8; 3] = [0xfb, 0xff, 0xfe];
        let encoded = URL_SAFE_NO_PAD.encode(data);
        assert_eq!(encoded, "-__-");
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(!encoded.contains('='));
    }

    #[test]
    fn test_crypto_error_display() {
        assert_eq!(
            CryptoError::CryptoUnavailable("No window object".to_string()).to_string(),
            "Crypto unavailable: No window object"
        );
        assert_eq!(
            CryptoError::HashError("Unknown PKCE method: foo".to_string()).to_string(),
            "Hash error: Unknown PKCE method: foo"
        );
    }
}