use base64::Engine;
use blueprint_std::rand::{CryptoRng, RngCore};
use core::fmt::Display;
use std::collections::BTreeMap;
use crate::types::ServiceId;
/// Base64 engine used to encode both the plaintext token payload and its hash.
///
/// Uses the URL-safe alphabet (`-` and `_` in place of `+` and `/`) with no `=` padding.
pub const CUSTOM_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
&base64::alphabet::URL_SAFE,
base64::engine::general_purpose::NO_PAD,
);
/// Produces random API tokens, optionally prefixed with a fixed string.
///
/// See [`GeneratedApiToken`] for the shape of the result.
pub struct ApiTokenGenerator {
/// Prepended to every generated plaintext token; empty for [`ApiTokenGenerator::new`].
prefix: String,
}
/// A freshly generated API token.
///
/// Holds the plaintext secret (exposed through `plaintext(id)`) together with its hashed
/// form (`token`). The plaintext is redacted from the [`Debug`](core::fmt::Debug) output so
/// the secret cannot leak through `{:?}` logging.
#[derive(Clone, PartialEq, Eq)]
pub struct GeneratedApiToken {
    /// The secret: the generator's prefix followed by the URL-safe, unpadded base64
    /// encoding of 40 random bytes and their big-endian CRC32 checksum.
    plaintext: String,
    /// URL-safe, unpadded base64 encoding of the Keccak-256 hash of `plaintext`.
    pub(crate) token: String,
    /// The service this token was generated for.
    pub(crate) service_id: ServiceId,
    /// Expiration value supplied at generation time; `None` when `0` was passed.
    expires_at: Option<u64>,
    /// Extra header name/value pairs supplied at generation time.
    pub(crate) additional_headers: BTreeMap<String, String>,
}

impl core::fmt::Debug for GeneratedApiToken {
    /// Formats every field except `plaintext`, which is shown as `<redacted>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("GeneratedApiToken")
            .field("plaintext", &"<redacted>")
            .field("token", &self.token)
            .field("service_id", &self.service_id)
            .field("expires_at", &self.expires_at)
            .field("additional_headers", &self.additional_headers)
            .finish()
    }
}
impl Display for GeneratedApiToken {
    /// Writes the hashed `token` only; the plaintext secret is never rendered.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(&self.token)
    }
}
impl Default for ApiTokenGenerator {
fn default() -> Self {
ApiTokenGenerator::new()
}
}
impl ApiTokenGenerator {
    /// Creates a generator whose tokens carry no prefix.
    pub fn new() -> Self {
        Self {
            prefix: String::new(),
        }
    }

    /// Creates a generator that prepends `prefix` to every plaintext token it produces.
    ///
    /// The prefix is part of the hashed input, so the stored hash covers it as well.
    pub fn with_prefix(prefix: &str) -> Self {
        Self {
            prefix: prefix.to_owned(),
        }
    }

    /// Generates a token for `service_id` with no expiration and no additional headers.
    pub fn generate_token<R: RngCore + CryptoRng>(
        &self,
        service_id: ServiceId,
        rng: &mut R,
    ) -> GeneratedApiToken {
        self.generate_token_with_expiration(service_id, 0, rng)
    }

    /// Generates a token for `service_id` that records `expires_at`, with no extra headers.
    ///
    /// An `expires_at` of `0` means "no expiration" and is stored as `None`; any other value
    /// is stored unchanged (presumably Unix seconds — the tests use `SystemTime` seconds).
    pub fn generate_token_with_expiration<R: RngCore + CryptoRng>(
        &self,
        service_id: ServiceId,
        expires_at: u64,
        rng: &mut R,
    ) -> GeneratedApiToken {
        let no_headers = BTreeMap::new();
        self.generate_token_with_expiration_and_headers(service_id, expires_at, no_headers, rng)
    }

    /// Generates a token for `service_id` with an expiration and extra headers attached.
    ///
    /// The plaintext secret is the prefix followed by the URL-safe, unpadded base64 encoding
    /// of 40 bytes drawn from `rng` plus their big-endian CRC32 checksum. The stored `token`
    /// is the same base64 encoding applied to the Keccak-256 hash of that plaintext. An
    /// `expires_at` of `0` is stored as `None`.
    pub fn generate_token_with_expiration_and_headers<R: RngCore + CryptoRng>(
        &self,
        service_id: ServiceId,
        expires_at: u64,
        additional_headers: BTreeMap<String, String>,
        rng: &mut R,
    ) -> GeneratedApiToken {
        use tiny_keccak::Hasher;

        const RANDOM_LEN: usize = 40;

        // Random payload, followed by the big-endian CRC32 of that payload.
        let mut payload = vec![0u8; RANDOM_LEN];
        rng.fill_bytes(&mut payload);
        let crc = crc32fast::hash(&payload);
        payload.extend_from_slice(&crc.to_be_bytes());

        let plaintext = format!("{}{}", self.prefix, CUSTOM_ENGINE.encode(&payload));

        // Hash the full plaintext (prefix included).
        let mut keccak = tiny_keccak::Keccak::v256();
        keccak.update(plaintext.as_bytes());
        let mut digest = [0u8; 32];
        keccak.finalize(&mut digest);

        GeneratedApiToken {
            token: CUSTOM_ENGINE.encode(digest),
            plaintext,
            service_id,
            expires_at: (expires_at != 0).then_some(expires_at),
            additional_headers,
        }
    }
}
impl GeneratedApiToken {
    /// Returns the client-facing token string `"{id}|{plaintext}"`.
    ///
    /// This is the `id|token` shape that `ApiToken::from_str` parses back.
    pub fn plaintext(&self, id: u64) -> String {
        let mut rendered = id.to_string();
        rendered.push('|');
        rendered.push_str(&self.plaintext);
        rendered
    }

    /// Returns the hashed token (base64 of the Keccak-256 hash of the plaintext).
    pub fn token(&self) -> &str {
        self.token.as_str()
    }

    /// Returns the expiration supplied at generation time, or `None` if `0` was passed.
    pub fn expires_at(&self) -> Option<u64> {
        self.expires_at
    }

    /// Returns the extra header name/value pairs supplied at generation time.
    pub fn additional_headers(&self) -> &BTreeMap<String, String> {
        &self.additional_headers
    }
}
/// An API token presented by a client in `id|token` form.
///
/// Field `0` is the numeric id that precedes the `|`; field `1` is the token string that
/// follows it. The token string is redacted from the [`Debug`](core::fmt::Debug) output so
/// credentials do not leak through `{:?}` logging; `Display` still renders the full
/// `id|token` form.
#[derive(Clone)]
pub struct ApiToken(pub u64, pub String);

impl core::fmt::Debug for ApiToken {
    /// Shows the id and replaces the secret token string with `<redacted>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_tuple("ApiToken")
            .field(&self.0)
            .field(&"<redacted>")
            .finish()
    }
}
/// Error returned when an `id|token` string cannot be parsed into an [`ApiToken`].
#[derive(Debug, Clone, thiserror::Error)]
pub enum ParseApiTokenError {
/// The input is longer than 512 bytes, does not contain exactly one `|`, or has an empty
/// token after the `|`.
#[error("Malformed token; expected format is `id|token`")]
MalformedToken,
/// The part before the `|` is empty or does not parse as a `u64`.
#[error("Invalid token ID; expected a number")]
InvalidTokenId,
}
impl ApiToken {
fn new(id: u64, token: impl Into<String>) -> Self {
ApiToken(id, token.into())
}
pub(crate) fn from_str(s: &str) -> Result<ApiToken, ParseApiTokenError> {
if s.len() > 512 {
return Err(ParseApiTokenError::MalformedToken);
}
let separator_count = s.matches('|').count();
if separator_count != 1 {
return Err(ParseApiTokenError::MalformedToken);
}
let mut parts = s.splitn(2, '|');
let id_part = parts.next().ok_or(ParseApiTokenError::MalformedToken)?;
if id_part.is_empty() {
return Err(ParseApiTokenError::InvalidTokenId);
}
let id = id_part
.parse::<u64>()
.map_err(|_| ParseApiTokenError::InvalidTokenId)?;
let token_part = parts.next().ok_or(ParseApiTokenError::MalformedToken)?;
if token_part.is_empty() {
return Err(ParseApiTokenError::MalformedToken);
}
Ok(ApiToken::new(id, token_part))
}
}
impl Display for ApiToken {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "{}|{}", self.0, self.1)
}
}
/// Returns the credentials that follow a `Bearer` auth-scheme in an `Authorization` value.
///
/// The scheme name is matched case-insensitively and may be followed by one or more spaces,
/// as HTTP's `auth-scheme 1*SP credentials` grammar allows. Returns `None` when the value
/// does not use the Bearer scheme.
fn strip_bearer_scheme(value: &str) -> Option<&str> {
    let (scheme, rest) = value.split_once(' ')?;
    scheme
        .eq_ignore_ascii_case("Bearer")
        .then(|| rest.trim_start_matches(' '))
}

impl<S> axum::extract::FromRequestParts<S> for ApiToken
where
    S: Send + Sync,
{
    type Rejection = axum::response::Response;

    /// Extracts an [`ApiToken`] from an `Authorization: Bearer <id|token>` header.
    ///
    /// Rejects with `401 Unauthorized` when the header is missing, and with
    /// `400 Bad Request` when it is not a visible-ASCII string, does not use the Bearer
    /// scheme, or carries a value that `ApiToken::from_str` rejects.
    async fn from_request_parts(
        parts: &mut axum::http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        use axum::http::StatusCode;
        use axum::response::IntoResponse;

        let Some(header) = parts.headers.get(crate::types::headers::AUTHORIZATION) else {
            return Err((StatusCode::UNAUTHORIZED, "Missing Authorization header").into_response());
        };

        let Ok(header_str) = header.to_str() else {
            return Err((
                StatusCode::BAD_REQUEST,
                "Invalid Authorization header; not a valid UTF-8 string",
            )
                .into_response());
        };

        let Some(credentials) = strip_bearer_scheme(header_str) else {
            // Deliberately do not echo the header value back: it usually carries a
            // credential (e.g. `Basic ...`) that must not be reflected into responses/logs.
            return Err((
                StatusCode::BAD_REQUEST,
                "Invalid Authorization header; expected Bearer <api_token>",
            )
                .into_response());
        };

        ApiToken::from_str(credentials).map_err(|e| {
            (StatusCode::BAD_REQUEST, format!("Invalid API Token: {e}")).into_response()
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ServiceId;
    use axum::extract::FromRequestParts;
    use axum::http::{Request, header::AUTHORIZATION};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Number of random bytes the generator draws before appending the 4-byte CRC32.
    const RANDOM_LEN: usize = 40;

    /// Strips the `"{id}|"` prefix from `GeneratedApiToken::plaintext(id)` output.
    fn secret_part(plaintext: &str, id: u64) -> &str {
        plaintext
            .strip_prefix(&format!("{id}|"))
            .expect("plaintext must start with `id|`")
    }

    #[test]
    fn test_api_token_generator_new() {
        let generator = ApiTokenGenerator::new();
        let token =
            generator.generate_token(ServiceId::new(1), &mut blueprint_std::BlueprintRng::new());
        assert!(!token.token.is_empty());
    }

    #[test]
    fn test_api_token_generator_with_prefix() {
        let prefix = "test-prefix-";
        let mut rng = blueprint_std::BlueprintRng::new();
        let prefixed =
            ApiTokenGenerator::with_prefix(prefix).generate_token(ServiceId::new(1), &mut rng);
        let unprefixed = ApiTokenGenerator::new().generate_token(ServiceId::new(1), &mut rng);

        // The prefix must actually appear at the start of the secret handed to clients.
        let expected_start = format!("7|{prefix}");
        assert!(prefixed.plaintext(7).starts_with(&expected_start));
        assert!(!unprefixed.plaintext(7).starts_with(&expected_start));
        assert_ne!(prefixed.token, unprefixed.token);
    }

    #[test]
    fn test_token_expiration() {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let expiry = now + 3600;
        let generator = ApiTokenGenerator::new();
        let mut rng = blueprint_std::BlueprintRng::new();
        let token_with_expiry =
            generator.generate_token_with_expiration(ServiceId::new(1), expiry, &mut rng);
        let token_without_expiry = generator.generate_token(ServiceId::new(1), &mut rng);
        assert_eq!(token_with_expiry.expires_at(), Some(expiry));
        assert_eq!(token_without_expiry.expires_at(), None);
    }

    #[test]
    fn test_plaintext_token() {
        let generator = ApiTokenGenerator::new();
        let mut rng = blueprint_std::BlueprintRng::new();
        let token = generator.generate_token(ServiceId::new(1), &mut rng);
        let id = 42;
        let plaintext = token.plaintext(id);
        let secret = secret_part(&plaintext, id);

        // The secret decodes to the random bytes followed by their big-endian CRC32.
        let raw = base64::Engine::decode(&CUSTOM_ENGINE, secret)
            .expect("secret must be URL-safe, unpadded base64");
        assert_eq!(raw.len(), RANDOM_LEN + 4);
        let (random, checksum) = raw.split_at(RANDOM_LEN);
        assert_eq!(checksum, crc32fast::hash(random).to_be_bytes().as_slice());

        // The client-facing form round-trips through the parser.
        let parsed = ApiToken::from_str(&plaintext).expect("generated plaintext must parse");
        assert_eq!(parsed.0, id);
        assert_eq!(parsed.1, secret);
    }

    #[test]
    fn test_token_is_keccak_hash_of_plaintext() {
        use tiny_keccak::Hasher;

        let generator = ApiTokenGenerator::with_prefix("pfx_");
        let mut rng = blueprint_std::BlueprintRng::new();
        let token = generator.generate_token(ServiceId::new(1), &mut rng);
        let plaintext = token.plaintext(0);
        let secret = secret_part(&plaintext, 0);

        let mut hasher = tiny_keccak::Keccak::v256();
        hasher.update(secret.as_bytes());
        let mut digest = [0u8; 32];
        hasher.finalize(&mut digest);

        assert_eq!(token.token(), base64::Engine::encode(&CUSTOM_ENGINE, digest));
        // `Display` renders the hash, never the plaintext.
        assert_eq!(token.to_string(), token.token());
    }

    #[test]
    fn test_api_token_display() {
        let token = ApiToken(123, "test-token".to_string());
        assert_eq!(token.to_string(), "123|test-token");
    }

    #[test]
    fn test_api_token_from_str_rejects_malformed_input() {
        assert!(matches!(
            ApiToken::from_str("no-separator"),
            Err(ParseApiTokenError::MalformedToken)
        ));
        assert!(matches!(
            ApiToken::from_str("1|a|b"),
            Err(ParseApiTokenError::MalformedToken)
        ));
        assert!(matches!(
            ApiToken::from_str("1|"),
            Err(ParseApiTokenError::MalformedToken)
        ));
        assert!(matches!(
            ApiToken::from_str("|abc"),
            Err(ParseApiTokenError::InvalidTokenId)
        ));
        assert!(matches!(
            ApiToken::from_str("abc|def"),
            Err(ParseApiTokenError::InvalidTokenId)
        ));
        assert!(matches!(
            ApiToken::from_str("-1|def"),
            Err(ParseApiTokenError::InvalidTokenId)
        ));
        let oversized = format!("1|{}", "a".repeat(600));
        assert!(matches!(
            ApiToken::from_str(&oversized),
            Err(ParseApiTokenError::MalformedToken)
        ));
    }

    #[tokio::test]
    async fn test_api_token_from_request() {
        let req = Request::builder()
            .header(
                AUTHORIZATION,
                "Bearer 123|RmFrZVRva2VuVGhhdElzQmFzZTY0RW5jb2RlZA",
            )
            .body(())
            .unwrap();
        let (mut parts, _) = req.into_parts();
        let result: Result<ApiToken, _> = ApiToken::from_request_parts(&mut parts, &()).await;
        assert!(result.is_ok());
        let token = result.unwrap();
        assert_eq!(token.0, 123);
        assert_eq!(token.1, "RmFrZVRva2VuVGhhdElzQmFzZTY0RW5jb2RlZA");

        let req = Request::builder().body(()).unwrap();
        let (mut parts, _) = req.into_parts();
        let result: Result<ApiToken, _> = ApiToken::from_request_parts(&mut parts, &()).await;
        assert!(result.is_err());

        let req = Request::builder()
            .header(AUTHORIZATION, "Basic 123:password")
            .body(())
            .unwrap();
        let (mut parts, _) = req.into_parts();
        let result: Result<ApiToken, _> = ApiToken::from_request_parts(&mut parts, &()).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_base64_custom_engine() {
        let input = b"This is a test string for base64 encoding";
        let encoded = base64::Engine::encode(&CUSTOM_ENGINE, input);
        let decoded = base64::Engine::decode(&CUSTOM_ENGINE, &encoded).unwrap();
        assert_eq!(decoded, input);
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(!encoded.contains('='));
    }

    #[test]
    fn test_token_generation_with_headers() {
        use std::collections::BTreeMap;
        let generator = ApiTokenGenerator::new();
        let mut rng = blueprint_std::BlueprintRng::new();
        let service_id = ServiceId::new(1);
        let mut headers = BTreeMap::new();
        headers.insert("X-Tenant-Id".to_string(), "tenant123".to_string());
        headers.insert("X-User-Type".to_string(), "premium".to_string());
        let token = generator.generate_token_with_expiration_and_headers(
            service_id,
            0,
            headers.clone(),
            &mut rng,
        );
        assert_eq!(token.additional_headers(), &headers);
        assert!(!token.token.is_empty());
    }

    #[test]
    fn test_token_generation_without_headers() {
        let generator = ApiTokenGenerator::new();
        let mut rng = blueprint_std::BlueprintRng::new();
        let service_id = ServiceId::new(1);
        let token = generator.generate_token(service_id, &mut rng);
        assert!(token.additional_headers().is_empty());
    }
}