use base64ct::{Base64UrlUnpadded, Encoding};
use moka::sync::Cache;
use rand::RngCore;
use sha2::{Digest, Sha256};
use std::sync::Arc;
use std::time::Duration;
/// Data saved for an in-flight OAuth2 login, stored in a [`TokenStore`]
/// under the OAuth2 `state` value.
///
/// `Debug` is implemented by hand instead of derived so that the PKCE
/// `code_verifier` — a secret that should never leave the server — is not
/// written to logs when this value is formatted with `{:?}`.
#[derive(Clone)]
pub struct OAuth2State {
    /// PKCE code verifier (see `generate_code_verifier`). Secret.
    pub code_verifier: String,
    /// Where to redirect the user after the login, if one was requested.
    pub redirect_url: Option<String>,
    /// Project name associated with this login, if any.
    pub project_name: Option<String>,
    /// Custom-domain base URL associated with this login, if any.
    pub custom_domain_base_url: Option<String>,
}

impl std::fmt::Debug for OAuth2State {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2State")
            // Never print the verifier itself; only note that it exists.
            .field("code_verifier", &format_args!("<redacted>"))
            .field("redirect_url", &self.redirect_url)
            .field("project_name", &self.project_name)
            .field("custom_domain_base_url", &self.custom_domain_base_url)
            .finish()
    }
}
/// Result of a finished login, stored in a [`TokenStore`] under an opaque
/// token until it is retrieved.
///
/// `Debug` is implemented by hand instead of derived so that the issued JWT —
/// a bearer credential — is not written to logs when formatted with `{:?}`.
#[derive(Clone)]
pub struct CompletedAuthSession {
    /// The issued JWT. Secret.
    pub rise_jwt: String,
    /// Max-age to apply to the token (presumably seconds, as for a cookie
    /// `Max-Age` — TODO confirm with the caller).
    pub max_age: u64,
    /// Where to redirect the user once the session is picked up.
    pub redirect_url: String,
    /// Project name this session belongs to.
    pub project_name: String,
}

impl std::fmt::Debug for CompletedAuthSession {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CompletedAuthSession")
            // Never print the bearer token itself.
            .field("rise_jwt", &format_args!("<redacted>"))
            .field("max_age", &self.max_age)
            .field("redirect_url", &self.redirect_url)
            .field("project_name", &self.project_name)
            .finish()
    }
}
/// Storage for short-lived OAuth2 login data.
///
/// Holds two kinds of entries:
/// - an [`OAuth2State`], keyed by the OAuth2 `state` value;
/// - a [`CompletedAuthSession`], keyed by an opaque token.
///
/// Implementors must be `Send + Sync` so one store can be shared across threads.
pub trait TokenStore: Send + Sync {
    /// Stores `data` under the OAuth2 `state` value.
    fn save(&self, state: String, data: OAuth2State);

    /// Returns the data stored under `state`, or `None` if nothing is stored
    /// (e.g. it was never saved or has expired).
    fn get(&self, state: &str) -> Option<OAuth2State>;

    /// Stores a finished login `session` under `token`.
    fn save_completed_session(&self, token: String, session: CompletedAuthSession);

    /// Returns the session stored under `token`, or `None` if nothing is stored.
    ///
    /// Implementations may consume the entry on read; [`InMemoryTokenStore`]
    /// does, which makes each token single-use.
    fn get_completed_session(&self, token: &str) -> Option<CompletedAuthSession>;
}
/// [`TokenStore`] backed by two bounded, time-expiring `moka` caches held in
/// this process's memory.
///
/// Entries are not persisted and are not shared with other processes.
pub struct InMemoryTokenStore {
    /// Pending OAuth2 login data, keyed by the `state` value.
    cache: Arc<Cache<String, OAuth2State>>,
    /// Completed login sessions, keyed by their retrieval token.
    completed_sessions: Arc<Cache<String, CompletedAuthSession>>,
}
impl InMemoryTokenStore {
    /// Maximum number of entries each cache holds before moka starts evicting.
    pub const MAX_ENTRIES: u64 = 10_000;

    /// Completed-session lifetime used by [`InMemoryTokenStore::new`].
    pub const DEFAULT_COMPLETED_SESSION_TTL: Duration = Duration::from_secs(300);

    /// Creates a store whose OAuth2 state entries expire `ttl` after insertion.
    ///
    /// Completed sessions expire after [`Self::DEFAULT_COMPLETED_SESSION_TTL`]
    /// (5 minutes). Use [`Self::with_ttls`] to choose that lifetime too.
    pub fn new(ttl: Duration) -> Self {
        Self::with_ttls(ttl, Self::DEFAULT_COMPLETED_SESSION_TTL)
    }

    /// Creates a store with separate lifetimes for pending OAuth2 state
    /// (`state_ttl`) and completed sessions (`completed_session_ttl`).
    ///
    /// Both caches hold at most [`Self::MAX_ENTRIES`] entries each.
    pub fn with_ttls(state_ttl: Duration, completed_session_ttl: Duration) -> Self {
        let cache = Cache::builder()
            .time_to_live(state_ttl)
            .max_capacity(Self::MAX_ENTRIES)
            .build();
        let completed_sessions = Cache::builder()
            .time_to_live(completed_session_ttl)
            .max_capacity(Self::MAX_ENTRIES)
            .build();
        Self {
            cache: Arc::new(cache),
            completed_sessions: Arc::new(completed_sessions),
        }
    }
}
/// Maps each [`TokenStore`] operation onto the matching moka cache.
impl TokenStore for InMemoryTokenStore {
    /// Inserts the entry for `key`, replacing any existing one.
    fn save(&self, key: String, value: OAuth2State) {
        let states = &self.cache;
        states.insert(key, value);
    }

    /// Returns a clone of the entry for `key`, leaving it in the cache.
    ///
    /// NOTE(review): unlike `get_completed_session`, this read does not consume
    /// the entry, so a `state` can be looked up repeatedly until it expires.
    /// Confirm that callers do not expect single-use semantics here.
    fn get(&self, key: &str) -> Option<OAuth2State> {
        let states = &self.cache;
        states.get(key)
    }

    /// Inserts the session for `key`, replacing any existing one.
    fn save_completed_session(&self, key: String, value: CompletedAuthSession) {
        let sessions = &self.completed_sessions;
        sessions.insert(key, value);
    }

    /// Removes and returns the session for `key`, so each token can be
    /// redeemed at most once.
    fn get_completed_session(&self, key: &str) -> Option<CompletedAuthSession> {
        let sessions = &self.completed_sessions;
        sessions.remove(key)
    }
}
/// Generates a fresh PKCE code verifier (RFC 7636 §4.1).
///
/// 48 random bytes from the thread-local RNG are base64url-encoded without
/// padding. The result is a 64-character string over `[A-Za-z0-9_-]`, within
/// the 43–128 character range the RFC allows.
pub fn generate_code_verifier() -> String {
    const VERIFIER_BYTES: usize = 48;
    let mut entropy = [0u8; VERIFIER_BYTES];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut entropy);
    Base64UrlUnpadded::encode_string(&entropy)
}
/// Derives the PKCE `S256` code challenge for `verifier` (RFC 7636 §4.2).
///
/// The challenge is the unpadded base64url encoding of the SHA-256 digest of
/// the verifier's bytes. The same verifier always yields the same challenge.
pub fn generate_code_challenge(verifier: &str) -> String {
    let digest = Sha256::digest(verifier.as_bytes());
    Base64UrlUnpadded::encode_string(&digest)
}
/// Generates a random OAuth2 `state` value.
///
/// 32 bytes from the thread-local RNG are base64url-encoded without padding,
/// giving a 43-character token.
pub fn generate_state_token() -> String {
    let token_bytes = {
        let mut buf = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut buf);
        buf
    };
    Base64UrlUnpadded::encode_string(&token_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True for the ASCII characters unpadded base64url can emit.
    ///
    /// `char::is_alphanumeric` is not used because it also accepts non-ASCII
    /// letters and digits, which RFC 7636 verifiers must not contain.
    fn is_base64url_char(c: char) -> bool {
        c.is_ascii_alphanumeric() || c == '-' || c == '_'
    }

    #[test]
    fn test_code_verifier_length() {
        let verifier = generate_code_verifier();
        // 48 random bytes -> 64 base64url chars (RFC 7636 allows 43..=128).
        assert_eq!(verifier.len(), 64);
        assert!(verifier.chars().all(is_base64url_char));
    }

    #[test]
    fn test_code_challenge_rfc7636_vector() {
        // Test vector from RFC 7636, Appendix B.
        assert_eq!(
            generate_code_challenge("dBjftJeZ4CVP-mB92K2uhbUdRQcnTF7q5SYBGsBpd5s"),
            "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
        );
    }

    #[test]
    fn test_code_challenge_deterministic() {
        let verifier = "test_verifier_123";
        let challenge1 = generate_code_challenge(verifier);
        let challenge2 = generate_code_challenge(verifier);
        assert_eq!(challenge1, challenge2);
    }

    #[test]
    fn test_code_challenge_unique() {
        let challenge1 = generate_code_challenge("verifier1");
        let challenge2 = generate_code_challenge("verifier2");
        assert_ne!(challenge1, challenge2);
    }

    #[test]
    fn test_state_token_length() {
        let state = generate_state_token();
        // 32 random bytes -> 43 unpadded base64url chars.
        assert_eq!(state.len(), 43);
        assert!(state.chars().all(is_base64url_char));
    }

    #[test]
    fn test_state_token_randomness() {
        let state1 = generate_state_token();
        let state2 = generate_state_token();
        assert_ne!(state1, state2);
    }

    #[test]
    fn test_token_store_save_and_get() {
        let store = InMemoryTokenStore::new(Duration::from_secs(60));
        let state = "test_state";
        let data = OAuth2State {
            code_verifier: "test_verifier".to_string(),
            redirect_url: Some("https://example.com".to_string()),
            project_name: None,
            custom_domain_base_url: None,
        };
        store.save(state.to_string(), data.clone());
        let retrieved = store
            .get(state)
            .expect("saved state should be retrievable before it expires");
        assert_eq!(retrieved.code_verifier, data.code_verifier);
        assert_eq!(retrieved.redirect_url, data.redirect_url);
    }

    #[test]
    fn test_token_store_get_nonexistent() {
        let store = InMemoryTokenStore::new(Duration::from_secs(60));
        assert!(store.get("nonexistent").is_none());
    }

    #[test]
    fn test_completed_session_is_single_use() {
        let store = InMemoryTokenStore::new(Duration::from_secs(60));
        let session = CompletedAuthSession {
            rise_jwt: "jwt".to_string(),
            max_age: 3600,
            redirect_url: "https://example.com".to_string(),
            project_name: "proj".to_string(),
        };
        store.save_completed_session("tok".to_string(), session);
        let first = store
            .get_completed_session("tok")
            .expect("session should be present on first read");
        assert_eq!(first.rise_jwt, "jwt");
        assert_eq!(first.project_name, "proj");
        // The first read consumes the entry.
        assert!(store.get_completed_session("tok").is_none());
    }

    #[test]
    fn test_token_store_ttl() {
        let store = InMemoryTokenStore::new(Duration::from_millis(100));
        let state = "test_state";
        let data = OAuth2State {
            code_verifier: "test_verifier".to_string(),
            redirect_url: None,
            project_name: None,
            custom_domain_base_url: None,
        };
        store.save(state.to_string(), data);
        assert!(store.get(state).is_some());
        std::thread::sleep(Duration::from_millis(150));
        assert!(store.get(state).is_none());
    }
}