use super::circuit_breaker::CircuitBreaker;
use crate::config::AppleConfig;
use crate::errors::AppError;
use jsonwebtoken::errors::ErrorKind;
use jsonwebtoken::jwk::{Jwk, JwkSet};
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Deserializer};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
fn deserialize_bool_or_string_opt<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
where
D: Deserializer<'de>,
{
match Option::<serde_json::Value>::deserialize(deserializer)? {
None | Some(serde_json::Value::Null) => Ok(None),
Some(serde_json::Value::Bool(b)) => Ok(Some(b)),
Some(serde_json::Value::String(s)) => Ok(Some(s.eq_ignore_ascii_case("true"))),
Some(other) => {
tracing::warn!(value = %other, "Unexpected type for Apple email_verified claim; treating as unverified");
Ok(None)
}
}
}
/// Claims decoded from an Apple ID token.
///
/// Instances returned by `AppleService::verify_id_token` have already had their
/// RS256 signature, `aud`, and `iss` checked.
#[derive(Debug, Clone, Deserialize)]
pub struct AppleTokenClaims {
/// JWT subject (`sub` claim).
pub sub: String,
/// Email address, when the token carries one.
pub email: Option<String>,
/// `email_verified` claim. Accepts a JSON bool, or a string where a
/// case-insensitive `"true"` means `true` and any other string means `false`.
/// An absent claim, `null`, or any other JSON type yields `None`.
#[serde(default, deserialize_with = "deserialize_bool_or_string_opt")]
pub email_verified: Option<bool>,
/// Audience; `verify_id_token` requires it to match the caller's `client_id`.
pub aud: String,
/// Issuer; `verify_id_token` requires `https://appleid.apple.com`.
pub iss: String,
/// Expiration (`exp` claim). An expired-signature error during verification is
/// reported as `AppError::TokenExpired`.
pub exp: i64,
/// Raw `real_user_status` claim; interpreted by `is_likely_real`.
pub real_user_status: Option<i64>,
/// `nonce` claim. This module does not check it; callers that send a nonce
/// must compare it themselves.
pub nonce: Option<String>,
}
impl AppleTokenClaims {
    /// Returns `true` only when the token explicitly marks the email as verified.
    ///
    /// A missing or unrecognised `email_verified` claim (`None`) counts as
    /// unverified, as does an explicit `false`.
    pub fn is_email_verified(&self) -> bool {
        matches!(self.email_verified, Some(true))
    }

    /// Returns `false` only when `real_user_status` is `1` ("unknown").
    ///
    /// Every other value — `0` ("unsupported"), `2` ("likely real"), any other
    /// number, or an absent claim — is treated as likely real.
    pub fn is_likely_real(&self) -> bool {
        !matches!(self.real_user_status, Some(1))
    }
}
/// Timeout in seconds, used for the HTTP client and for each awaited step of a JWKS fetch.
const APPLE_API_TIMEOUT_SECS: u64 = 5;
/// Endpoint serving Apple's public signing keys as a JWK set.
const APPLE_JWKS_URL: &str = "https://appleid.apple.com/auth/keys";
/// Seconds a fetched JWK set is served from cache before `get_jwks` tries to refetch (1 hour).
const APPLE_JWKS_CACHE_TTL_SECS: u64 = 3600;
/// Required `iss` value for Apple ID tokens.
const APPLE_ISSUER: &str = "https://appleid.apple.com";
/// Verifies Apple ID tokens against Apple's published JWK set.
///
/// Clones share the JWKS cache and the circuit breaker, because both are held
/// behind `Arc`.
#[derive(Clone)]
pub struct AppleService {
// Copied from config but only read by this file's unit tests; the audience
// checked by `verify_id_token` is passed in by the caller instead.
#[allow(dead_code)] client_id: Option<String>,
// Copied from config; not read anywhere in this file outside the tests.
#[allow(dead_code)] team_id: Option<String>,
/// HTTP client used for JWKS downloads.
http_client: reqwest::Client,
/// Most recently fetched key set; `None` until the first successful fetch.
jwks_cache: Arc<RwLock<Option<JwksCache>>>,
/// Breaker guarding calls to the JWKS endpoint (registered as `"apple_jwks"`).
circuit_breaker: Arc<RwLock<CircuitBreaker>>,
}
/// A cached JWK set plus the timestamps used to judge its freshness.
#[derive(Debug, Clone)]
struct JwksCache {
/// Shared key set; handed out via `Arc::clone` so readers never copy the keys.
keys: Arc<JwkSet>,
/// After this instant `get_jwks` treats the entry as expired and tries to refetch.
expires_at: Instant,
/// When the set was downloaded; the circuit breaker's `is_fallback_valid`
/// uses it to decide whether an expired entry may still be served.
fetched_at: Instant,
}
impl AppleService {
    /// Minimum age the cached JWK set must reach before a token whose `kid` is
    /// missing from it may force another download from Apple.
    ///
    /// The `kid` is read from the *unverified* token header, so without this
    /// floor any caller could make every verification hit Apple's endpoint.
    /// Trade-off: a key Apple publishes less than this long after our last
    /// fetch is only picked up once the window has elapsed (or the TTL expires).
    const FORCED_REFRESH_MIN_INTERVAL: Duration = Duration::from_secs(60);

    /// Creates a service from `config`.
    ///
    /// The HTTP client is built with an `APPLE_API_TIMEOUT_SECS` timeout. If
    /// building it fails, the error is logged and `reqwest::Client::new()` is
    /// used instead; `fetch_jwks` still bounds each request with its own timeouts.
    pub fn new(config: &AppleConfig) -> Self {
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_secs(APPLE_API_TIMEOUT_SECS))
            .build()
            .unwrap_or_else(|e| {
                tracing::error!(error = %e, "Failed to build Apple HTTP client; falling back to defaults");
                reqwest::Client::new()
            });
        Self {
            client_id: config.client_id.clone(),
            team_id: config.team_id.clone(),
            http_client,
            jwks_cache: Arc::new(RwLock::new(None)),
            circuit_breaker: Arc::new(RwLock::new(CircuitBreaker::new("apple_jwks"))),
        }
    }

    /// Downloads Apple's JWK set, bypassing both the cache and the circuit breaker.
    ///
    /// Sending the request and reading/parsing the body are each bounded by
    /// `APPLE_API_TIMEOUT_SECS`.
    ///
    /// # Errors
    ///
    /// `AppError::Internal` on timeout, transport error, a non-success HTTP
    /// status, or a body that does not parse as a `JwkSet`.
    async fn fetch_jwks(&self) -> Result<JwkSet, AppError> {
        let timeout = Duration::from_secs(APPLE_API_TIMEOUT_SECS);
        let response = tokio::time::timeout(timeout, self.http_client.get(APPLE_JWKS_URL).send())
            .await
            .map_err(|_| {
                AppError::Internal(anyhow::anyhow!(
                    "Failed to fetch Apple JWKS: request timed out after {}s",
                    APPLE_API_TIMEOUT_SECS
                ))
            })?
            .map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to fetch Apple JWKS: {}", e)))?;
        if !response.status().is_success() {
            return Err(AppError::Internal(anyhow::anyhow!(
                "Failed to fetch Apple JWKS: {}",
                response.status()
            )));
        }
        tokio::time::timeout(timeout, async move { response.json::<JwkSet>().await })
            .await
            .map_err(|_| {
                AppError::Internal(anyhow::anyhow!(
                    "Failed to parse Apple JWKS: request timed out after {}s",
                    APPLE_API_TIMEOUT_SECS
                ))
            })?
            .map_err(|e| AppError::Internal(anyhow::anyhow!("Failed to parse Apple JWKS: {}", e)))
    }

    /// Replaces the cache with a freshly downloaded key set, stamped "now" and
    /// valid for `APPLE_JWKS_CACHE_TTL_SECS`, and returns a shared handle to it.
    async fn store_jwks(&self, jwks: JwkSet) -> Arc<JwkSet> {
        let jwks = Arc::new(jwks);
        let now = Instant::now();
        *self.jwks_cache.write().await = Some(JwksCache {
            keys: Arc::clone(&jwks),
            expires_at: now + Duration::from_secs(APPLE_JWKS_CACHE_TTL_SECS),
            fetched_at: now,
        });
        jwks
    }

    /// Returns Apple's JWK set, preferring the cache.
    ///
    /// 1. An unexpired cache entry is returned without touching the breaker.
    /// 2. If the breaker rejects the request, an expired entry is served while
    ///    `is_fallback_valid` accepts its age; otherwise the service is unavailable.
    /// 3. Otherwise the set is fetched. Success refreshes the cache and records a
    ///    breaker success. Failure records a breaker failure and falls back to
    ///    the expired entry under the same age rule.
    ///
    /// # Errors
    ///
    /// `AppError::ServiceUnavailable` when the breaker is open and no usable
    /// fallback exists. Otherwise, the `fetch_jwks` error when the fetch failed
    /// and no usable fallback exists.
    async fn get_jwks(&self) -> Result<Arc<JwkSet>, AppError> {
        {
            let cache = self.jwks_cache.read().await;
            if let Some(cached) = cache.as_ref() {
                if Instant::now() < cached.expires_at {
                    return Ok(Arc::clone(&cached.keys));
                }
            }
        }
        let mut cb = self.circuit_breaker.write().await;
        let should_fetch = cb.should_allow_request();
        // Snapshot whatever is cached (possibly expired) to use as a fallback.
        let stale_cache = {
            let cache = self.jwks_cache.read().await;
            cache.clone()
        };
        if !should_fetch {
            if let Some(cached) = stale_cache {
                if cb.is_fallback_valid(cached.fetched_at) {
                    tracing::debug!(
                        service = "apple_jwks",
                        age_secs = cached.fetched_at.elapsed().as_secs(),
                        "Serving stale JWKS (circuit open)"
                    );
                    return Ok(Arc::clone(&cached.keys));
                }
            }
            return Err(AppError::ServiceUnavailable(
                "Apple JWKS service temporarily unavailable".into(),
            ));
        }
        // Release the breaker before the network call so other tasks are not blocked on it.
        drop(cb);
        match self.fetch_jwks().await {
            Ok(jwks) => {
                let jwks = self.store_jwks(jwks).await;
                self.circuit_breaker.write().await.record_success();
                Ok(jwks)
            }
            Err(e) => {
                self.circuit_breaker.write().await.record_failure();
                let cb = self.circuit_breaker.read().await;
                if let Some(cached) = stale_cache {
                    if cb.is_fallback_valid(cached.fetched_at) {
                        tracing::warn!(
                            service = "apple_jwks",
                            error = %e,
                            age_secs = cached.fetched_at.elapsed().as_secs(),
                            "JWKS fetch failed, serving stale cache"
                        );
                        return Ok(Arc::clone(&cached.keys));
                    }
                }
                Err(e)
            }
        }
    }

    /// Refreshes the JWK set after a token referenced a `kid` missing from it.
    ///
    /// Returns `Ok(None)` without contacting Apple when the cached set is
    /// younger than `FORCED_REFRESH_MIN_INTERVAL`: the set is recent, so an
    /// unknown `kid` is treated as bogus rather than as a rotation. Otherwise the
    /// download goes through the circuit breaker, just like `get_jwks`.
    ///
    /// # Errors
    ///
    /// * `AppError::ServiceUnavailable` if the circuit breaker rejects the request.
    /// * The `fetch_jwks` error if the download fails (recorded as a breaker failure).
    async fn refresh_jwks_for_unknown_kid(&self) -> Result<Option<Arc<JwkSet>>, AppError> {
        {
            let cache = self.jwks_cache.read().await;
            if let Some(cached) = cache.as_ref() {
                if cached.fetched_at.elapsed() < Self::FORCED_REFRESH_MIN_INTERVAL {
                    return Ok(None);
                }
            }
        }
        if !self.circuit_breaker.write().await.should_allow_request() {
            return Err(AppError::ServiceUnavailable(
                "Apple JWKS service temporarily unavailable".into(),
            ));
        }
        match self.fetch_jwks().await {
            Ok(jwks) => {
                let jwks = self.store_jwks(jwks).await;
                self.circuit_breaker.write().await.record_success();
                Ok(Some(jwks))
            }
            Err(e) => {
                self.circuit_breaker.write().await.record_failure();
                Err(e)
            }
        }
    }

    /// Reads the `kid` from the token header without verifying the signature.
    ///
    /// # Errors
    ///
    /// `AppError::InvalidToken` if the header cannot be decoded or has no `kid`.
    fn extract_kid(&self, id_token: &str) -> Result<String, AppError> {
        let header = decode_header(id_token).map_err(|_| AppError::InvalidToken)?;
        header.kid.ok_or(AppError::InvalidToken)
    }

    /// Returns the first key in `jwks` whose key id equals `kid`.
    fn select_jwk<'a>(&self, jwks: &'a JwkSet, kid: &str) -> Option<&'a Jwk> {
        jwks.keys
            .iter()
            .find(|jwk| jwk.common.key_id.as_deref() == Some(kid))
    }

    /// Verifies an Apple ID token and returns its claims.
    ///
    /// The token must be RS256-signed by a key from Apple's JWK set, with
    /// `aud == client_id` and `iss == APPLE_ISSUER`. If the header's `kid` is not
    /// in the current key set, one throttled, breaker-guarded refresh is
    /// attempted (see `refresh_jwks_for_unknown_kid`) to cover key rotation.
    ///
    /// # Errors
    ///
    /// * `AppError::InvalidToken` — malformed header, missing or unknown `kid`,
    ///   unusable JWK, or any failed validation other than expiry.
    /// * `AppError::TokenExpired` — the signature is valid but the token has expired.
    /// * `AppError::ServiceUnavailable` / `AppError::Internal` — the key set could
    ///   not be obtained.
    pub async fn verify_id_token(
        &self,
        id_token: &str,
        client_id: &str,
    ) -> Result<AppleTokenClaims, AppError> {
        let kid = self.extract_kid(id_token)?;
        let jwks = self.get_jwks().await?;
        let decoding_key = match self.select_jwk(&jwks, &kid) {
            Some(jwk) => DecodingKey::from_jwk(jwk).map_err(|_| AppError::InvalidToken)?,
            None => {
                // Apple may have rotated keys since the cache was filled.
                let fresh = match self.refresh_jwks_for_unknown_kid().await? {
                    Some(fresh) => fresh,
                    None => {
                        tracing::debug!(
                            kid = ?kid,
                            "Apple ID token kid not in recently fetched JWKS; skipping refresh"
                        );
                        return Err(AppError::InvalidToken);
                    }
                };
                let jwk = self
                    .select_jwk(&fresh, &kid)
                    .ok_or(AppError::InvalidToken)?;
                DecodingKey::from_jwk(jwk).map_err(|_| AppError::InvalidToken)?
            }
        };
        let mut validation = Validation::new(Algorithm::RS256);
        validation.set_audience(&[client_id]);
        validation.set_issuer(&[APPLE_ISSUER]);
        let token_data =
            decode::<AppleTokenClaims>(id_token, &decoding_key, &validation).map_err(|err| {
                tracing::warn!(error = %err, kind = ?err.kind(), "Apple ID token verification failed");
                match err.kind() {
                    ErrorKind::ExpiredSignature => AppError::TokenExpired,
                    _ => AppError::InvalidToken,
                }
            })?;
        Ok(token_data.claims)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;

    /// Builds a service whose config carries both Apple identifiers.
    fn configured_service(client_id: &str, team_id: &str) -> AppleService {
        AppleService::new(&AppleConfig {
            enabled: true,
            client_id: Some(client_id.to_string()),
            team_id: Some(team_id.to_string()),
            ..AppleConfig::default()
        })
    }

    /// Baseline claims for the helper-method tests; override fields with
    /// struct-update syntax.
    fn baseline_claims() -> AppleTokenClaims {
        AppleTokenClaims {
            sub: "001234.abc".to_string(),
            email: Some("test@example.com".to_string()),
            email_verified: Some(true),
            aud: "com.example.app".to_string(),
            iss: "https://appleid.apple.com".to_string(),
            exp: 9999999999,
            real_user_status: None,
            nonce: None,
        }
    }

    /// Deserializes a claims payload, panicking if serde rejects it.
    fn parse_claims(json: &str) -> AppleTokenClaims {
        serde_json::from_str(json).expect("claims JSON should deserialize")
    }

    #[test]
    fn test_apple_service_creation() {
        let service = configured_service("com.example.service", "ABCD123456");
        assert!(service.client_id.is_some());
        assert!(service.team_id.is_some());
    }

    #[test]
    fn test_apple_service_no_config() {
        let unconfigured = AppleConfig {
            enabled: true,
            client_id: None,
            team_id: None,
            ..AppleConfig::default()
        };
        assert!(AppleService::new(&unconfigured).client_id.is_none());
    }

    #[test]
    fn test_apple_claims_email_verified() {
        let verified = AppleTokenClaims {
            real_user_status: Some(2),
            ..baseline_claims()
        };
        assert!(verified.is_email_verified());

        // Both an explicit `false` and a missing claim must read as unverified.
        for &flag in &[Some(false), None] {
            let unverified = AppleTokenClaims {
                email_verified: flag,
                ..verified.clone()
            };
            assert!(!unverified.is_email_verified());
        }
    }

    #[test]
    fn test_apple_claims_is_likely_real() {
        // (real_user_status, expected is_likely_real): only `1` ("unknown") is flagged.
        let cases: [(Option<i64>, bool); 4] = [
            (None, true),
            (Some(0), true),
            (Some(1), false),
            (Some(2), true),
        ];
        for &(status, expected) in &cases {
            let claims = AppleTokenClaims {
                real_user_status: status,
                ..baseline_claims()
            };
            assert_eq!(
                claims.is_likely_real(),
                expected,
                "real_user_status = {:?}",
                status
            );
        }
    }

    #[test]
    fn test_extract_kid_requires_header_kid() {
        let service = configured_service("client-id", "team-id");
        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let header_without_kid = jsonwebtoken::Header {
            alg: Algorithm::RS256,
            kid: None,
            ..Default::default()
        };
        let encoded_header = engine.encode(serde_json::to_string(&header_without_kid).unwrap());
        let encoded_payload = engine.encode("{}");
        // The signature segment is left empty: extract_kid only decodes the header.
        let unsigned_token = format!("{}.{}.", encoded_header, encoded_payload);
        assert!(service.extract_kid(&unsigned_token).is_err());
    }

    #[test]
    fn test_select_jwk_by_kid() {
        let service = configured_service("client-id", "team-id");
        let jwks: JwkSet = serde_json::from_str(
            r#"{
                "keys": [
                    {
                        "kty": "RSA",
                        "kid": "test-kid",
                        "use": "sig",
                        "alg": "RS256",
                        "n": "AQAB",
                        "e": "AQAB"
                    }
                ]
            }"#,
        )
        .unwrap();
        assert!(service.select_jwk(&jwks, "test-kid").is_some());
    }

    #[test]
    fn test_email_verified_deserializes_from_bool() {
        let claims = parse_claims(
            r#"{
                "sub": "001234.abc",
                "email": "test@example.com",
                "email_verified": true,
                "aud": "com.example.app",
                "iss": "https://appleid.apple.com",
                "exp": 9999999999
            }"#,
        );
        assert_eq!(claims.email_verified, Some(true));
        assert!(claims.is_email_verified());
    }

    #[test]
    fn test_email_verified_deserializes_from_string() {
        let json_true = r#"{
            "sub": "001234.abc",
            "email": "test@example.com",
            "email_verified": "true",
            "aud": "com.example.app",
            "iss": "https://appleid.apple.com",
            "exp": 9999999999
        }"#;
        assert_eq!(parse_claims(json_true).email_verified, Some(true));

        let json_false = json_true.replace("\"true\"", "\"false\"");
        assert_eq!(parse_claims(&json_false).email_verified, Some(false));
    }

    #[test]
    fn test_email_verified_deserializes_from_missing() {
        let claims = parse_claims(
            r#"{
                "sub": "001234.abc",
                "aud": "com.example.app",
                "iss": "https://appleid.apple.com",
                "exp": 9999999999
            }"#,
        );
        assert_eq!(claims.email_verified, None);
        assert!(!claims.is_email_verified());
    }

    #[test]
    fn test_email_verified_unexpected_type_does_not_fail() {
        // A numeric value is neither bool nor string: tolerated, but unverified.
        let claims = parse_claims(
            r#"{
                "sub": "001234.abc",
                "email_verified": 1,
                "aud": "com.example.app",
                "iss": "https://appleid.apple.com",
                "exp": 9999999999
            }"#,
        );
        assert_eq!(claims.email_verified, None);
        assert!(!claims.is_email_verified());
    }
}