use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Utc};
use http::header::{HeaderMap, SET_COOKIE};
use ring::rand::SecureRandom;
use thiserror::Error;
/// Decodes a URL-safe, unpadded base64 string into raw bytes.
///
/// # Errors
/// Returns [`UtilError::Format`] when `input` is not accepted by the `URL_SAFE_NO_PAD`
/// engine.
pub(crate) fn base64url_decode(input: &str) -> Result<Vec<u8>, UtilError> {
    match URL_SAFE_NO_PAD.decode(input) {
        Ok(bytes) => Ok(bytes),
        Err(_) => Err(UtilError::Format("Failed to decode base64url".to_string())),
    }
}
/// Encodes `input` as URL-safe base64 without `=` padding.
///
/// The current implementation never returns `Err`.
pub(crate) fn base64url_encode(input: Vec<u8>) -> Result<String, UtilError> {
    let encoded = URL_SAFE_NO_PAD.encode(input);
    Ok(encoded)
}
/// Draws `len` bytes from ring's `SystemRandom` and returns them as unpadded base64url.
///
/// `len` is the number of *random bytes*, not the length of the returned string
/// (which is `ceil(4 * len / 3)` characters).
///
/// # Errors
/// Returns [`UtilError::Crypto`] if the RNG fails to fill the buffer or if encoding
/// reports an error.
pub(crate) fn gen_random_string(len: usize) -> Result<String, UtilError> {
    let mut bytes = vec![0u8; len];
    ring::rand::SystemRandom::new()
        .fill(&mut bytes)
        .map_err(|_| UtilError::Crypto("Failed to generate random string".to_string()))?;
    base64url_encode(bytes)
        .map_err(|_| UtilError::Crypto("Failed to encode random string".to_string()))
}
/// Like [`gen_random_string`], but regenerates any byte sequence rejected by
/// `validate_entropy`, giving up after a fixed number of attempts.
///
/// `len` is the number of random bytes drawn per attempt, not the length of the
/// returned (unpadded base64url) string.
///
/// # Errors
/// Returns [`UtilError::Crypto`] if `len` is zero, if the RNG fails, if encoding fails,
/// or if every attempt is rejected by `validate_entropy`.
pub(crate) fn gen_random_string_with_entropy_validation(len: usize) -> Result<String, UtilError> {
    const MAX_GENERATION_ATTEMPTS: usize = 10;

    // `validate_entropy` always rejects an empty buffer, so retrying would only burn
    // attempts and log misleading "low entropy" warnings. Fail immediately instead.
    if len == 0 {
        return Err(UtilError::Crypto(
            "Cannot generate a random string of zero length".to_string(),
        ));
    }

    // The RNG handle and the buffer are reused across attempts; `fill` overwrites the
    // whole buffer on each call.
    let rng = ring::rand::SystemRandom::new();
    let mut random_bytes = vec![0u8; len];

    for attempt in 1..=MAX_GENERATION_ATTEMPTS {
        rng.fill(&mut random_bytes)
            .map_err(|_| UtilError::Crypto("Failed to generate random bytes".to_string()))?;
        if validate_entropy(&random_bytes) {
            return base64url_encode(random_bytes)
                .map_err(|_| UtilError::Crypto("Failed to encode random string".to_string()));
        }
        tracing::warn!(
            "Low entropy detected in random generation, attempt {}/{}",
            attempt,
            MAX_GENERATION_ATTEMPTS
        );
    }
    Err(UtilError::Crypto(
        "Failed to generate sufficiently random string after max attempts".to_string(),
    ))
}
/// Cheap check that a freshly generated byte buffer is not trivially degenerate.
///
/// Returns `false` when:
/// - the buffer is empty;
/// - it is a single zero byte;
/// - it has two or more bytes and every byte is identical (this includes all-zero buffers).
///
/// This rejects only obviously degenerate patterns; it does not measure entropy
/// statistically. A single non-zero byte is accepted. A lone byte has no repetition to
/// detect, and treating it as "uniform" would make every 1-byte sample fail.
fn validate_entropy(bytes: &[u8]) -> bool {
    match bytes.split_first() {
        // Nothing to judge: an empty buffer is never acceptable.
        None => false,
        // A single byte can only be checked against the all-zero pattern.
        Some((&only, [])) => only != 0,
        // Accept as soon as any byte differs from the first one.
        Some((&first, rest)) => rest.iter().any(|&b| b != first),
    }
}
/// Appends a `Set-Cookie` header for `name=value` to `headers` and returns the map.
///
/// The cookie always carries `SameSite=Lax; Secure; HttpOnly; Path=/` and
/// `Max-Age={max_age}`, plus `Domain={domain}` when `domain` is `Some`. Any existing
/// `Set-Cookie` headers are kept: the new one is appended, not replaced.
///
/// `_expires_at` is currently unused; expiry is conveyed only through `Max-Age`.
///
/// # Errors
/// Returns [`UtilError::Cookie`] in two cases:
/// - `name`, `value`, or `domain` contains `;`, which would end the cookie pair early and
///   let the remaining text be parsed as extra cookie attributes;
/// - the assembled string is not a valid HTTP header value.
pub(crate) fn header_set_cookie<'a>(
    headers: &'a mut HeaderMap,
    name: String,
    value: String,
    _expires_at: DateTime<Utc>,
    max_age: i64,
    domain: Option<&str>,
) -> Result<&'a HeaderMap, UtilError> {
    // Header-value parsing rejects CR/LF but not `;`, so attribute injection has to be
    // blocked explicitly before the cookie string is assembled.
    let injects_attribute = name.contains(';')
        || value.contains(';')
        || matches!(domain, Some(d) if d.contains(';'));
    if injects_attribute {
        return Err(UtilError::Cookie(
            "Cookie name, value, or domain must not contain ';'".to_string(),
        ));
    }

    let domain_attr = domain.map(|d| format!("; Domain={d}")).unwrap_or_default();
    let cookie = format!(
        "{name}={value}; SameSite=Lax; Secure; HttpOnly; Path=/; Max-Age={max_age}{domain_attr}"
    );
    headers.append(
        SET_COOKIE,
        cookie
            .parse()
            .map_err(|_| UtilError::Cookie("Failed to parse cookie".to_string()))?,
    );
    Ok(headers)
}
/// Installs ring's rustls `CryptoProvider` as the process-wide default, on a best-effort
/// basis.
fn ensure_ring_provider() {
    let ring_provider = rustls::crypto::ring::default_provider();
    // The result is ignored on purpose. rustls returns `Err` when a process default has
    // already been installed, and in that case the existing provider stays in effect.
    let _ = ring_provider.install_default();
}
/// Builds a rustls client configuration on the ring provider that trusts the root
/// certificates bundled in `webpki_roots` and advertises `h2` then `http/1.1` via ALPN.
///
/// # Panics
/// Panics if rustls rejects its safe default protocol versions for the ring provider.
#[cfg(feature = "bundled-tls")]
fn rustls_config_with_webpki_roots() -> rustls::ClientConfig {
    let mut trusted_roots = rustls::RootCertStore::empty();
    trusted_roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let crypto = std::sync::Arc::new(rustls::crypto::ring::default_provider());
    let mut client_config = rustls::ClientConfig::builder_with_provider(crypto)
        .with_safe_default_protocol_versions()
        .expect("Failed to set TLS protocol versions")
        .with_root_certificates(trusted_roots)
        .with_no_client_auth();

    // ALPN offer list, in the client's order of preference.
    client_config.alpn_protocols = ["h2", "http/1.1"]
        .iter()
        .map(|proto| proto.as_bytes().to_vec())
        .collect();
    client_config
}
/// Builds a fresh `reqwest::Client` with these settings:
/// - a 30-second overall timeout;
/// - idle pooled connections kept for up to 90 seconds;
/// - at most 32 idle connections per host.
///
/// Ensures the ring rustls provider is installed first. With the `bundled-tls` feature,
/// TLS uses the preconfigured rustls setup from `rustls_config_with_webpki_roots`.
///
/// # Panics
/// Panics if the client cannot be built.
pub(crate) fn get_client() -> reqwest::Client {
    use std::time::Duration;

    ensure_ring_provider();

    let client_builder = reqwest::Client::builder()
        .timeout(Duration::from_secs(30))
        .pool_idle_timeout(Duration::from_secs(90))
        .pool_max_idle_per_host(32);

    #[cfg(feature = "bundled-tls")]
    let client_builder = client_builder.use_preconfigured_tls(rustls_config_with_webpki_roots());

    client_builder
        .build()
        .expect("Failed to create reqwest client")
}
#[derive(Debug, Error, Clone)]
pub enum UtilError {
#[error("Crypto error: {0}")]
Crypto(String),
#[error("Cookie error: {0}")]
Cookie(String),
#[error("Invalid format: {0}")]
Format(String),
}
#[cfg(test)]
mod tests;