use super::*;
use crate::storage::CacheData;
use http::HeaderValue;
/// Test helper: stores `token` under the `token_type` cache prefix and returns the cache key
/// produced by `store_cache_auto`.
///
/// The cached `StoredToken` records `token`, `expires_at`, `user_agent` and `ttl`; the same
/// `ttl` is also handed to `store_cache_auto` as the cache entry's TTL.
///
/// # Errors
/// Returns an `OAuth2Error` if `token_type` is rejected by `CachePrefix::new` (mapped through
/// `OAuth2Error::convert_storage_error`) or if `store_cache_auto` fails.
async fn store_token_in_cache(
    token_type: &str,
    token: &str,
    ttl: u64,
    expires_at: DateTime<Utc>,
    user_agent: Option<String>,
) -> Result<String, OAuth2Error> {
    let prefix =
        CachePrefix::new(token_type.to_string()).map_err(OAuth2Error::convert_storage_error)?;
    let payload = StoredToken {
        token: token.to_string(),
        expires_at,
        user_agent,
        ttl,
    };
    let key = store_cache_auto::<_, OAuth2Error>(prefix, payload, ttl).await?;
    Ok(key.as_str().to_string())
}
/// Round-trips a fully populated `StateParams` through `encode_state` / `decode_state`.
///
/// It also checks that the encoded string contains none of `+`, `/` or `=`, which is
/// consistent with unpadded URL-safe base64.
#[test]
fn test_encode_decode_state() {
    let params = StateParams {
        csrf_id: "csrf123".to_string(),
        nonce_id: "nonce456".to_string(),
        pkce_id: "pkce789".to_string(),
        misc_id: Some("misc123".to_string()),
        mode_id: Some("mode456".to_string()),
        provider: "google".to_string(),
    };
    let wire = encode_state(params).unwrap();
    for forbidden in ['+', '/', '='] {
        assert!(!wire.contains(forbidden));
    }

    let back = decode_state(&wire).unwrap();
    assert_eq!(back.csrf_id, "csrf123");
    assert_eq!(back.nonce_id, "nonce456");
    assert_eq!(back.pkce_id, "pkce789");
    assert_eq!(back.misc_id, Some("misc123".to_string()));
    assert_eq!(back.mode_id, Some("mode456".to_string()));
    assert_eq!(back.provider, "google");
}
/// Round-trips a `StateParams` whose optional `misc_id` / `mode_id` are `None`.
///
/// It checks that every field, including the required `provider`, decodes unchanged.
#[test]
fn test_encode_decode_state_minimal() {
    let state_params = StateParams {
        csrf_id: "csrf123".to_string(),
        nonce_id: "nonce456".to_string(),
        pkce_id: "pkce789".to_string(),
        misc_id: None,
        mode_id: None,
        provider: "google".to_string(),
    };
    let encoded = encode_state(state_params).unwrap();
    let decoded = decode_state(&encoded).unwrap();
    assert_eq!(decoded.csrf_id, "csrf123");
    assert_eq!(decoded.nonce_id, "nonce456");
    assert_eq!(decoded.pkce_id, "pkce789");
    assert_eq!(decoded.misc_id, None);
    assert_eq!(decoded.mode_id, None);
    // `provider` was previously unchecked here although the fully-populated test checks it.
    assert_eq!(decoded.provider, "google");
}
/// `OAuth2State::new` must reject a string that is not valid base64 with
/// `OAuth2Error::DecodeState`.
#[test]
fn test_oauth2_state_validation_invalid_base64() {
    let result = crate::OAuth2State::new("this is not base64!!!".to_string());
    // Match directly. A preceding `assert!(result.is_err())` would fire first and turn the
    // descriptive `Ok` arm into dead code.
    match result {
        Err(OAuth2Error::DecodeState(_)) => {}
        Ok(_) => panic!("Expected DecodeState error but got Ok"),
        Err(err) => panic!("Expected DecodeState error, got {err:?}"),
    }
}
/// Input that is valid base64url but whose payload is not JSON must be rejected by
/// `OAuth2State::new` with `OAuth2Error::DecodeState`.
#[test]
fn test_oauth2_state_validation_invalid_json() {
    let invalid_json = "not valid json";
    let encoded = URL_SAFE_NO_PAD.encode(invalid_json);
    let result = crate::OAuth2State::new(encoded);
    // Match directly. A preceding `assert!(result.is_err())` would make the `Ok` arm dead code.
    match result {
        Err(OAuth2Error::DecodeState(_)) => {}
        Ok(_) => panic!("Expected DecodeState error but got Ok"),
        Err(err) => panic!("Expected DecodeState error, got {err:?}"),
    }
}
/// An `Origin` header equal to the callback URL's origin is accepted.
#[tokio::test]
async fn test_validate_origin_success() {
    const CALLBACK: &str = "https://example.com/oauth2/callback";
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Origin", HeaderValue::from_static("https://example.com"));
    let verdict = validate_origin(&request_headers, CALLBACK, &[]).await;
    assert!(verdict.is_ok());
}
/// With no `Origin` header, a same-origin `Referer` is enough to pass validation.
#[tokio::test]
async fn test_validate_origin_with_referer() {
    let referer = HeaderValue::from_static("https://example.com/login");
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Referer", referer);
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(verdict.is_ok());
}
/// `Origin: null` must fall back to the `Referer` header, and a matching `Referer` is accepted.
#[tokio::test]
async fn test_validate_origin_null_with_referer() {
    let mut request_headers = HeaderMap::new();
    for (name, value) in [("Origin", "null"), ("Referer", "https://example.com/login")] {
        request_headers.insert(name, HeaderValue::from_static(value));
    }
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(
        verdict.is_ok(),
        "Origin: null should fall back to matching Referer"
    );
}
/// `Origin: null` with no `Referer` to fall back on must be rejected with `InvalidOrigin`.
#[tokio::test]
async fn test_validate_origin_null_without_referer() {
    let mut headers = HeaderMap::new();
    headers.insert("Origin", HeaderValue::from_static("null"));
    let result = validate_origin(&headers, "https://example.com/oauth2/callback", &[]).await;
    // Match directly. A preceding `assert!(result.is_err())` would make the `Ok` arm dead code.
    match result {
        Err(OAuth2Error::InvalidOrigin(_)) => {}
        Ok(_) => panic!("Expected InvalidOrigin error but got Ok"),
        Err(err) => panic!("Expected InvalidOrigin error, got {err:?}"),
    }
}
/// An `Origin` that only *starts with* the trusted host (`example.com.attacker.example`)
/// must not be treated as same-origin.
#[tokio::test]
async fn test_validate_origin_rejects_subdomain_confusion() {
    let spoofed = HeaderValue::from_static("https://example.com.attacker.example");
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Origin", spoofed);
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(
        verdict.is_err(),
        "subdomain-confusion Origin must be rejected"
    );
}
/// The same prefix-spoofing trick, delivered through `Referer` instead of `Origin`, must also
/// be rejected.
#[tokio::test]
async fn test_validate_origin_rejects_subdomain_confusion_via_referer() {
    let spoofed = HeaderValue::from_static("https://example.com.attacker.example/path");
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Referer", spoofed);
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(
        verdict.is_err(),
        "subdomain-confusion Referer must be rejected"
    );
}
/// Prefix spoofing must also be rejected when the target is an entry in
/// `additional_allowed_origins` rather than the callback's own origin.
#[tokio::test]
async fn test_validate_origin_rejects_subdomain_confusion_in_additional() {
    let allowed = ["https://login.live.com".to_string()];
    let mut request_headers = HeaderMap::new();
    request_headers.insert(
        "Origin",
        HeaderValue::from_static("https://login.live.com.attacker.example"),
    );
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &allowed).await;
    assert!(
        verdict.is_err(),
        "subdomain confusion against additional_allowed_origins must be rejected"
    );
}
/// Host comparison ignores ASCII case: `EXAMPLE.com` matches `example.com`.
#[tokio::test]
async fn test_validate_origin_case_insensitive_host() {
    let upper_cased = HeaderValue::from_static("https://EXAMPLE.com");
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Origin", upper_cased);
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(verdict.is_ok(), "case-insensitive host must match");
}
/// An explicit `:443` on an `https` origin is equivalent to omitting the port.
#[tokio::test]
async fn test_validate_origin_default_port_normalization() {
    let explicit_port = HeaderValue::from_static("https://example.com:443");
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Origin", explicit_port);
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(
        verdict.is_ok(),
        "explicit default port :443 must match implicit form"
    );
}
/// A non-default port (`:8443`) makes the origin different and must be rejected.
#[tokio::test]
async fn test_validate_origin_different_port_rejected() {
    let other_port = HeaderValue::from_static("https://example.com:8443");
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Origin", other_port);
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(verdict.is_err(), "different port must not match");
}
/// A `Referer` that is not a parseable URL cannot establish an origin and is rejected.
#[tokio::test]
async fn test_validate_origin_invalid_url_rejected() {
    let garbage = HeaderValue::from_static("not-a-url");
    let mut request_headers = HeaderMap::new();
    request_headers.insert("Referer", garbage);
    let verdict =
        validate_origin(&request_headers, "https://example.com/oauth2/callback", &[]).await;
    assert!(verdict.is_err(), "unparseable candidate must be rejected");
}
/// Unparseable entries in `additional_allowed_origins` are dropped.
///
/// Such an entry must not stop a valid sibling entry from matching, and it must not
/// authorize anything on its own.
#[tokio::test]
async fn test_validate_origin_unparseable_additional_origin_dropped() {
    const CALLBACK: &str = "https://example.com/oauth2/callback";

    // Case 1: a bad entry next to a good one must not stop the good one from matching.
    let mut trusted_headers = HeaderMap::new();
    trusted_headers.insert("Origin", HeaderValue::from_static("https://login.live.com"));
    let mixed_allow_list = vec![
        "not-a-valid-url".to_string(),
        "https://login.live.com".to_string(),
    ];
    let accepted = validate_origin(&trusted_headers, CALLBACK, &mixed_allow_list).await;
    assert!(
        accepted.is_ok(),
        "valid sibling allowed origin must still match when an unparseable entry is present"
    );

    // Case 2: a bad entry on its own must not authorize an unrelated origin.
    let mut hostile_headers = HeaderMap::new();
    hostile_headers.insert(
        "Origin",
        HeaderValue::from_static("https://attacker.example"),
    );
    let bad_only = vec!["not-a-valid-url".to_string()];
    let rejected = validate_origin(&hostile_headers, CALLBACK, &bad_only).await;
    assert!(
        rejected.is_err(),
        "unparseable additional_allowed_origins entry must not authorize anything"
    );
}
/// An `Origin` on an unrelated host is rejected with `OAuth2Error::InvalidOrigin`.
#[tokio::test]
async fn test_validate_origin_mismatch() {
    let mut headers = HeaderMap::new();
    headers.insert("Origin", HeaderValue::from_static("https://attacker.com"));
    let result = validate_origin(&headers, "https://example.com/oauth2/callback", &[]).await;
    // Match directly. A preceding `assert!(result.is_err())` would make the `Ok` arm dead code.
    match result {
        Err(OAuth2Error::InvalidOrigin(_)) => {}
        Ok(_) => panic!("Expected InvalidOrigin error but got Ok"),
        Err(err) => panic!("Expected InvalidOrigin error, got {err:?}"),
    }
}
/// A request with neither `Origin` nor `Referer` is rejected with `OAuth2Error::InvalidOrigin`.
#[tokio::test]
async fn test_validate_origin_missing() {
    let headers = HeaderMap::new();
    let result = validate_origin(&headers, "https://example.com/oauth2/callback", &[]).await;
    // Match directly. A preceding `assert!(result.is_err())` would make the `Ok` arm dead code.
    match result {
        Err(OAuth2Error::InvalidOrigin(_)) => {}
        Ok(_) => panic!("Expected InvalidOrigin error but got Ok"),
        Err(err) => panic!("Expected InvalidOrigin error, got {err:?}"),
    }
}
/// With `Origin: null`, a `Referer` on a host listed in `additional_allowed_origins` is
/// accepted even though it differs from the token endpoint's host.
#[tokio::test]
async fn test_validate_origin_additional_allowed() {
    let token_endpoint = "https://login.microsoftonline.com/common/oauth2/v2.0/token";
    let allowed = vec!["https://login.live.com".to_string()];

    let mut request_headers = HeaderMap::new();
    request_headers.insert("Origin", HeaderValue::from_static("null"));
    request_headers.insert(
        "Referer",
        HeaderValue::from_static("https://login.live.com/oauth20_authorize.srf"),
    );

    let verdict = validate_origin(&request_headers, token_endpoint, &allowed).await;
    assert!(
        verdict.is_ok(),
        "Referer on an additional_allowed_origins host should be accepted"
    );
}
/// Stores a token via `store_token_in_cache`, reads it back with `get_data`, and checks that
/// every `StoredToken` field survives the round trip.
#[tokio::test]
async fn test_store_and_get_token_from_cache() {
    use crate::test_utils::init_test_environment;
    use chrono::{Duration, Utc};
    init_test_environment().await;
    let token_type = "test_token";
    let token = "test_token_value_12345";
    let ttl = 300;
    let expires_at = Utc::now() + Duration::seconds(ttl as i64);
    let user_agent = Some("Test User Agent".to_string());

    // `expect` (instead of `assert!(is_ok())` + `unwrap()`) keeps the underlying error in the
    // failure message.
    let token_id = store_token_in_cache(token_type, token, ttl, expires_at, user_agent.clone())
        .await
        .expect("Should successfully store token");
    assert!(!token_id.is_empty(), "Token ID should not be empty");
    assert_eq!(
        token_id.len(),
        43,
        "Token ID should be 43 characters long (32 bytes base64url encoded)"
    );

    let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
    let cache_key = CacheKey::new(token_id.clone()).unwrap();
    let stored_token = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
        .await
        .and_then(|opt| {
            opt.ok_or_else(|| {
                OAuth2Error::SecurityTokenNotFound("test_type-session not found".to_string())
            })
        })
        .expect("Should successfully retrieve token");
    assert_eq!(stored_token.token, token);
    assert_eq!(stored_token.user_agent, user_agent);
    assert_eq!(stored_token.ttl, ttl);
    let time_diff = (stored_token.expires_at - expires_at).num_seconds();
    assert!(time_diff.abs() < 1, "Expiration time should be preserved");
}
/// Looking up a key that was never stored is expected to yield `Ok(None)`.
///
/// The `ok_or_else` adapter turns that into `OAuth2Error::SecurityTokenNotFound`.
#[tokio::test]
async fn test_get_token_from_store_not_found() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let cache_prefix = CachePrefix::new("test_type".to_string()).unwrap();
    let cache_key = CacheKey::new("nonexistent_id".to_string()).unwrap();
    let result = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
        .await
        .and_then(|opt| {
            opt.ok_or_else(|| {
                OAuth2Error::SecurityTokenNotFound("test-session not found".to_string())
            })
        });
    // Match directly. A preceding `assert!(result.is_err())` would make the `Ok` arm dead code.
    match result {
        Err(OAuth2Error::SecurityTokenNotFound(msg)) => {
            assert!(msg.contains("test-session not found"));
        }
        Ok(_) => panic!("Expected SecurityTokenNotFound error but got Ok"),
        Err(err) => panic!("Expected SecurityTokenNotFound error, got {err:?}"),
    }
}
/// Stores a token, confirms it is readable, removes it with `remove_data`, and checks that a
/// later lookup reports `SecurityTokenNotFound`.
#[tokio::test]
async fn test_remove_token_from_store() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let token_type = "test_remove";
    let token_value = "test_token_value";
    let ttl = 300;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let user_agent = Some("test-agent".to_string());
    let token_id = store_token_in_cache(token_type, token_value, ttl, expires_at, user_agent)
        .await
        .unwrap();

    let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
    let cache_key = CacheKey::new(token_id.clone()).unwrap();
    let stored_token = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(stored_token.token, token_value);

    let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
    let cache_key = CacheKey::new(token_id.clone()).unwrap();
    remove_data::<OAuth2Error>(cache_prefix, cache_key)
        .await
        .expect("removing a stored token should succeed");

    let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
    let cache_key = CacheKey::new(token_id).unwrap();
    let get_result = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
        .await
        .and_then(|opt| {
            opt.ok_or_else(|| {
                OAuth2Error::SecurityTokenNotFound("test-session not found".to_string())
            })
        });
    // Match directly. A preceding `assert!(get_result.is_err())` would make the `Ok` arm dead.
    match get_result {
        Err(OAuth2Error::SecurityTokenNotFound(_)) => {}
        Ok(_) => panic!("Expected SecurityTokenNotFound error after removal but got Ok"),
        Err(err) => panic!("Expected SecurityTokenNotFound error after removal, got {err:?}"),
    }
}
/// `generate_store_token` returns a `(token, token_id)` pair of distinct 43-character strings.
///
/// The token is stored under the `token_type` prefix with the given TTL, expiry and user agent.
#[tokio::test]
async fn test_generate_store_token() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let token_type = TokenType::Csrf;
    let ttl = 600;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let user_agent = Some("test-generate-agent".to_string());

    // `expect` keeps the underlying error visible, unlike `assert!(is_ok())` + `unwrap()`.
    let (token, token_id) = generate_store_token(token_type, ttl, expires_at, user_agent.clone())
        .await
        .expect("generate_store_token should succeed");
    assert_eq!(
        token.len(),
        43,
        "Generated token should be 43 characters long"
    );
    assert_eq!(
        token_id.len(),
        43,
        "Generated token_id should be 43 characters long"
    );
    assert_ne!(token, token_id, "Token and token_id should be different");

    let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
    let cache_key = CacheKey::new(token_id.clone()).unwrap();
    let stored_token = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
        .await
        .unwrap()
        .unwrap();
    assert_eq!(stored_token.token, token);
    assert_eq!(stored_token.user_agent, user_agent);
    assert_eq!(stored_token.ttl, ttl);
    let time_diff = (stored_token.expires_at - expires_at).num_seconds();
    assert!(time_diff.abs() < 1, "Expiration time should be preserved");
}
#[tokio::test]
async fn test_generate_store_token_randomness() {
use crate::test_utils::init_test_environment;
init_test_environment().await;
let token_type = TokenType::Nonce;
let ttl = 300;
let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
let user_agent = None;
let (token1, token_id1) = generate_store_token(token_type, ttl, expires_at, user_agent.clone())
.await
.unwrap();
let (token2, token_id2) = generate_store_token(token_type, ttl, expires_at, user_agent.clone())
.await
.unwrap();
assert_ne!(token1, token2, "Generated tokens should be different");
assert_ne!(
token_id1, token_id2,
"Generated token IDs should be different"
);
let cache_prefix1 = CachePrefix::new(token_type.to_string()).unwrap();
let cache_key1 = CacheKey::new(token_id1.clone()).unwrap();
let cache_prefix2 = CachePrefix::new(token_type.to_string()).unwrap();
let cache_key2 = CacheKey::new(token_id2.clone()).unwrap();
let stored_token1 = get_data::<StoredToken, OAuth2Error>(cache_prefix1, cache_key1)
.await
.unwrap()
.unwrap();
let stored_token2 = get_data::<StoredToken, OAuth2Error>(cache_prefix2, cache_key2)
.await
.unwrap()
.unwrap();
assert_eq!(stored_token1.token, token1);
assert_eq!(stored_token2.token, token2);
}
/// When `misc_id` is `None`, `get_uid_from_stored_session_by_state_param` succeeds and yields
/// `None`.
#[tokio::test]
async fn test_get_uid_from_stored_session_no_misc_id() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let params = StateParams {
        provider: "google".to_string(),
        csrf_id: "csrf123".to_string(),
        nonce_id: "nonce456".to_string(),
        pkce_id: "pkce789".to_string(),
        misc_id: None,
        mode_id: None,
    };
    let lookup = get_uid_from_stored_session_by_state_param(&params).await;
    assert!(matches!(lookup, Ok(None)));
}
/// A `misc_id` that points at nothing in the store yields `Ok(None)`, not an error.
#[tokio::test]
async fn test_get_uid_from_stored_session_token_not_found() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let params = StateParams {
        provider: "google".to_string(),
        csrf_id: "csrf123".to_string(),
        nonce_id: "nonce456".to_string(),
        pkce_id: "pkce789".to_string(),
        misc_id: Some("nonexistent_misc_id".to_string()),
        mode_id: None,
    };
    let lookup = get_uid_from_stored_session_by_state_param(&params).await;
    assert!(matches!(lookup, Ok(None)));
}
/// `delete_session_and_misc_token_from_store` returns `Ok` when `misc_id` is `None`.
#[tokio::test]
async fn test_delete_session_and_misc_token_no_misc_id() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let params = StateParams {
        provider: "google".to_string(),
        csrf_id: "csrf123".to_string(),
        nonce_id: "nonce456".to_string(),
        pkce_id: "pkce789".to_string(),
        misc_id: None,
        mode_id: None,
    };
    let outcome = delete_session_and_misc_token_from_store(&params).await;
    assert!(outcome.is_ok());
}
/// `delete_session_and_misc_token_from_store` returns `Ok` even when `misc_id` names a token
/// that is not in the store.
#[tokio::test]
async fn test_delete_session_and_misc_token_token_not_found() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let params = StateParams {
        provider: "google".to_string(),
        csrf_id: "csrf123".to_string(),
        nonce_id: "nonce456".to_string(),
        pkce_id: "pkce789".to_string(),
        misc_id: Some("nonexistent_misc_id".to_string()),
        mode_id: None,
    };
    let outcome = delete_session_and_misc_token_from_store(&params).await;
    assert!(outcome.is_ok());
}
/// An unknown mode id yields `Ok(None)` from `get_mode_from_stored_session`.
#[tokio::test]
async fn test_get_mode_from_stored_session_not_found() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let lookup = get_mode_from_stored_session("nonexistent_mode_id").await;
    assert!(matches!(lookup, Ok(None)));
}
/// Stores `OAuth2Mode::Login.as_str()` under the `"mode"` prefix and checks that
/// `get_mode_from_stored_session` parses it back to `Some(OAuth2Mode::Login)`.
#[tokio::test]
async fn test_get_mode_from_stored_session_valid_mode() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let mode_type = "mode";
    let mode = OAuth2Mode::Login;
    let mode_value = mode.as_str();
    let ttl = 300;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let user_agent = None;
    let mode_id = store_token_in_cache(mode_type, mode_value, ttl, expires_at, user_agent)
        .await
        .unwrap();

    let retrieved_mode = get_mode_from_stored_session(&mode_id)
        .await
        .expect("mode lookup should succeed");
    // One comparison covers both "is Some" and "equals `mode`".
    assert_eq!(retrieved_mode, Some(mode));
}
/// A stored value that is not a recognised mode string is reported as `Ok(None)` by
/// `get_mode_from_stored_session`, not as an error.
#[tokio::test]
async fn test_get_mode_from_stored_session_invalid_mode() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let mode_type = "mode";
    let invalid_mode_value = "invalid_mode_value";
    let ttl = 300;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let user_agent = None;
    let mode_id = store_token_in_cache(mode_type, invalid_mode_value, ttl, expires_at, user_agent)
        .await
        .unwrap();

    let retrieved_mode = get_mode_from_stored_session(&mode_id)
        .await
        .expect("lookup of an invalid mode value should not error");
    assert!(retrieved_mode.is_none());
}
/// `store_token_in_cache` accepts a zero TTL, and the entry (with `ttl == 0`) is readable
/// immediately afterwards.
#[tokio::test]
async fn test_cache_token_with_zero_ttl() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let token_type = "test_zero_ttl";
    let token = "test_token_zero_ttl";
    let ttl = 0;
    let expires_at = Utc::now();
    let user_agent = Some("test-agent".to_string());

    let token_id = store_token_in_cache(token_type, token, ttl, expires_at, user_agent)
        .await
        .expect("Should successfully store token with zero TTL");

    let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
    let cache_key = CacheKey::new(token_id).unwrap();
    let token_data = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
        .await
        .expect("Should be able to retrieve token with zero TTL")
        .expect("token stored with zero TTL should still be present");
    assert_eq!(token_data.ttl, 0);
    assert_eq!(token_data.token, token);
}
/// A realistic large TTL (one year in seconds) is stored and round-trips unchanged.
#[tokio::test]
async fn test_cache_token_with_max_ttl() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let token_type = "test_max_ttl";
    let token = "test_token_max_ttl";
    let ttl = 31_536_000_u64; // 365 * 24 * 60 * 60
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let user_agent = None;

    let token_id = store_token_in_cache(token_type, token, ttl, expires_at, user_agent)
        .await
        .expect("Should handle realistic large TTL values");

    let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
    let cache_key = CacheKey::new(token_id).unwrap();
    let stored_token = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
        .await
        .expect("Should retrieve token with large TTL")
        .expect("token with large TTL should be present");
    assert_eq!(stored_token.ttl, ttl);
}
/// Ten concurrent `store_token_in_cache` calls each succeed and yield unique ids.
///
/// Each stored token and user agent can be read back under its own id.
#[tokio::test]
async fn test_concurrent_token_operations() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let token_type = "test_concurrent";
    let ttl = 300;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);

    // Spawn every writer before awaiting any of them so the stores overlap.
    let mut handles = Vec::with_capacity(10);
    for i in 0..10 {
        let user_agent = Some(format!("agent-{i}"));
        handles.push(tokio::spawn(async move {
            store_token_in_cache(token_type, &format!("token-{i}"), ttl, expires_at, user_agent)
                .await
        }));
    }

    let mut token_ids = Vec::with_capacity(handles.len());
    for handle in handles {
        token_ids.push(handle.await.unwrap().unwrap());
    }

    for (i, token_id) in token_ids.iter().enumerate() {
        let lookup = get_data::<StoredToken, OAuth2Error>(
            CachePrefix::new(token_type.to_string()).unwrap(),
            CacheKey::new(token_id.clone()).unwrap(),
        )
        .await;
        assert!(lookup.is_ok());
        let entry = lookup.unwrap().unwrap();
        assert_eq!(entry.token, format!("token-{i}"));
        assert_eq!(entry.user_agent, Some(format!("agent-{i}")));
    }

    let distinct: std::collections::HashSet<&String> = token_ids.iter().collect();
    assert_eq!(distinct.len(), 10, "All token IDs should be unique");
}
/// Identical token content stored under five different prefixes is readable under each prefix.
///
/// A token id issued under one prefix must not resolve under any other prefix.
#[tokio::test]
async fn test_token_storage_with_different_prefixes() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let ttl = 300;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let user_agent = Some("test-agent".to_string());
    let prefixes = ["csrf", "nonce", "pkce", "access", "refresh"];
    let shared_content = "same_token_content";

    let mut stored: Vec<(&str, String)> = Vec::with_capacity(prefixes.len());
    for &prefix in &prefixes {
        let id = store_token_in_cache(prefix, shared_content, ttl, expires_at, user_agent.clone())
            .await
            .unwrap();
        stored.push((prefix, id));
    }

    // Every id resolves under the prefix it was issued for.
    for (prefix, id) in &stored {
        let lookup = get_data::<StoredToken, OAuth2Error>(
            CachePrefix::new(prefix.to_string()).unwrap(),
            CacheKey::new(id.clone()).unwrap(),
        )
        .await;
        assert!(
            lookup.is_ok(),
            "Should retrieve token for prefix: {prefix}"
        );
        let entry = lookup.unwrap().unwrap();
        assert_eq!(entry.token, shared_content);
        assert_eq!(entry.user_agent, user_agent);
    }

    // No id resolves under any other prefix.
    for (prefix1, id) in &stored {
        for (prefix2, _) in &stored {
            if prefix1 == prefix2 {
                continue;
            }
            let wrong_retrieval = get_data::<StoredToken, OAuth2Error>(
                CachePrefix::new(prefix2.to_string()).unwrap(),
                CacheKey::new(id.clone()).unwrap(),
            )
            .await
            .and_then(|opt| {
                opt.ok_or_else(|| OAuth2Error::SecurityTokenNotFound("token not found".to_string()))
            });
            assert!(
                wrong_retrieval.is_err(),
                "Should not retrieve token for {prefix2} with {prefix1}'s token_id"
            );
        }
    }
}
#[tokio::test]
async fn test_token_storage_edge_cases() {
use crate::test_utils::init_test_environment;
init_test_environment().await;
let ttl = 300;
let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
let empty_token_result = store_token_in_cache("test", "", ttl, expires_at, None).await;
assert!(
empty_token_result.is_ok(),
"Should handle empty token content"
);
if let Ok(token_id) = empty_token_result {
let cache_prefix = CachePrefix::new("test".to_string()).unwrap();
let cache_key = CacheKey::new(token_id.clone()).unwrap();
let retrieved = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key).await;
assert!(retrieved.is_ok());
assert_eq!(retrieved.unwrap().unwrap().token, "");
}
let long_token = "a".repeat(10000); let long_token_result =
store_token_in_cache("test_long", &long_token, ttl, expires_at, None).await;
assert!(
long_token_result.is_ok(),
"Should handle large token content"
);
if let Ok(token_id) = long_token_result {
let cache_prefix = CachePrefix::new("test_long".to_string()).unwrap();
let cache_key = CacheKey::new(token_id.clone()).unwrap();
let retrieved = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key).await;
assert!(retrieved.is_ok());
assert_eq!(retrieved.unwrap().unwrap().token, long_token);
}
let special_token = "token_with_特殊字符_🔐_\n\t\r";
let special_result =
store_token_in_cache("test_special", special_token, ttl, expires_at, None).await;
assert!(
special_result.is_ok(),
"Should handle special characters in token"
);
if let Ok(token_id) = special_result {
let cache_prefix = CachePrefix::new("test_special".to_string()).unwrap();
let cache_key = CacheKey::new(token_id.clone()).unwrap();
let retrieved = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key).await;
assert!(retrieved.is_ok());
assert_eq!(retrieved.unwrap().unwrap().token, special_token);
}
}
/// Stores two tokens under the same prefix and checks that they receive different ids.
///
/// Each token and its user agent must remain retrievable under its own id.
///
/// NOTE(review): despite the name, no id is reused here. Ids come from
/// `store_token_in_cache`, so this exercises distinct-id storage rather than an overwrite.
#[tokio::test]
async fn test_token_overwrite_same_id() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let token_type = "test_overwrite";
    let ttl = 300;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let entries = [
        ("first_token", Some("agent1".to_string())),
        ("second_token", Some("agent2".to_string())),
    ];

    let mut ids = Vec::with_capacity(entries.len());
    for (token, agent) in &entries {
        let id = store_token_in_cache(token_type, token, ttl, expires_at, agent.clone())
            .await
            .unwrap();
        ids.push(id);
    }
    assert_ne!(
        ids[0], ids[1],
        "Different tokens should have different IDs"
    );

    for ((token, agent), id) in entries.iter().zip(&ids) {
        let retrieved = get_data::<StoredToken, OAuth2Error>(
            CachePrefix::new(token_type.to_string()).unwrap(),
            CacheKey::new(id.clone()).unwrap(),
        )
        .await
        .unwrap()
        .unwrap();
        assert_eq!(retrieved.token, *token);
        assert_eq!(retrieved.user_agent, *agent);
    }
}
#[tokio::test]
async fn test_multiple_remove_operations() {
use crate::test_utils::init_test_environment;
init_test_environment().await;
let token_type = "test_multiple_remove";
let ttl = 300;
let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
let token_id = store_token_in_cache(token_type, "test_token", ttl, expires_at, None)
.await
.unwrap();
let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
let cache_key = CacheKey::new(token_id.clone()).unwrap();
let retrieved = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key).await;
assert!(retrieved.is_ok());
let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
let cache_key = CacheKey::new(token_id.clone()).unwrap();
let remove_result1 = remove_data::<OAuth2Error>(cache_prefix, cache_key).await;
assert!(remove_result1.is_ok());
let cache_prefix = CachePrefix::new(token_type.to_string()).unwrap();
let cache_key = CacheKey::new(token_id.clone()).unwrap();
let get_after_remove = get_data::<StoredToken, OAuth2Error>(cache_prefix, cache_key)
.await
.and_then(|opt| {
opt.ok_or_else(|| OAuth2Error::SecurityTokenNotFound("token not found".to_string()))
});
assert!(get_after_remove.is_err());
let cache_prefix2 = CachePrefix::new(token_type.to_string()).unwrap();
let cache_key2 = CacheKey::new(token_id.clone()).unwrap();
let remove_result2 = remove_data::<OAuth2Error>(cache_prefix2, cache_key2).await;
assert!(remove_result2.is_ok(), "Second removal should not fail");
let remove_handles = (0..5)
.map(|_| {
let token_id_clone = token_id.clone();
let token_type_clone = token_type;
tokio::spawn(async move {
let (cache_prefix, cache_key) = (
CachePrefix::new(token_type_clone.to_string()).unwrap(),
CacheKey::new(token_id_clone.clone()).unwrap(),
);
remove_data::<OAuth2Error>(cache_prefix, cache_key).await
})
})
.collect::<Vec<_>>();
let mut remove_results = Vec::new();
for handle in remove_handles {
remove_results.push(handle.await);
}
for result in remove_results {
assert!(
result.unwrap().is_ok(),
"Concurrent removals should not fail"
);
}
}
/// `store_token_in_cache` accepts an `expires_at` that is already an hour in the past.
///
/// The entry is still readable and keeps that past timestamp. Presumably retention follows
/// the `ttl` handed to `store_cache_auto` rather than `expires_at`; confirm if this matters.
#[tokio::test]
async fn test_cache_operations_with_past_expiration() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let prefix_name = "test_past_expiration";
    let ttl = 300;
    let an_hour_ago = Utc::now() - chrono::Duration::hours(1);
    let id = store_token_in_cache(prefix_name, "expired_token", ttl, an_hour_ago, None)
        .await
        .unwrap();

    let lookup = get_data::<StoredToken, OAuth2Error>(
        CachePrefix::new(prefix_name.to_string()).unwrap(),
        CacheKey::new(id.clone()).unwrap(),
    )
    .await;
    assert!(lookup.is_ok());
    let entry = lookup.unwrap().unwrap();
    assert_eq!(entry.token, "expired_token");
    assert!(entry.expires_at < Utc::now());
}
/// Converting a `StoredToken` into `CacheData` and back with `TryFrom` preserves every field.
///
/// `expires_at` is compared at millisecond precision.
#[tokio::test]
async fn test_cache_serialization_round_trip() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    // The unused `_token_type` local was removed: this test never touches a cache prefix.
    let ttl = 3600;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let user_agent = Some("Mozilla/5.0 (Test) AppleWebKit/537.36".to_string());
    let original_token = StoredToken {
        token: "complex_token_12345!@#$%".to_string(),
        expires_at,
        user_agent: user_agent.clone(),
        ttl,
    };
    let cache_data = CacheData::from(original_token.clone());
    let recovered_token = StoredToken::try_from(cache_data);
    assert!(recovered_token.is_ok());
    let recovered = recovered_token.unwrap();
    assert_eq!(recovered.token, original_token.token);
    assert_eq!(
        recovered.expires_at.timestamp_millis(),
        original_token.expires_at.timestamp_millis()
    );
    assert_eq!(recovered.user_agent, original_token.user_agent);
    assert_eq!(recovered.ttl, original_token.ttl);
}
/// Ten successive `generate_store_token` calls each return distinct 43-character token and id
/// strings, and each stored entry matches the inputs.
#[tokio::test]
async fn test_generate_store_token_consistency() {
    use crate::test_utils::init_test_environment;
    init_test_environment().await;
    let kind = TokenType::Pkce;
    let ttl = 600;
    let expires_at = Utc::now() + chrono::Duration::seconds(ttl as i64);
    let agent = Some("consistency-test-agent".to_string());

    for _ in 0..10 {
        let (secret, id) = generate_store_token(kind, ttl, expires_at, agent.clone())
            .await
            .unwrap();
        assert_eq!(secret.len(), 43, "Generated token should be 43 characters");
        assert_eq!(
            id.len(),
            43,
            "Generated token ID should be 43 characters"
        );
        assert_ne!(secret, id, "Token and token ID should be different");

        let entry = get_data::<StoredToken, OAuth2Error>(
            CachePrefix::new(kind.to_string()).unwrap(),
            CacheKey::new(id.clone()).unwrap(),
        )
        .await
        .unwrap()
        .unwrap();
        assert_eq!(entry.token, secret);
        assert_eq!(entry.user_agent, agent);
        assert_eq!(entry.ttl, ttl);
        let drift = (entry.expires_at - expires_at).num_seconds().abs();
        assert!(drift <= 1, "Expiration time should be consistent");
    }
}