use crate::errors::{AuthError, Result};
use crate::security::secure_jwt::{SecureJwtClaims, SecureJwtValidator};
use crate::server::token_exchange::token_exchange_common::{
ServiceComplexityLevel, TokenExchangeCapabilities, TokenExchangeService, TokenValidationResult,
ValidationUtils,
};
use async_trait::async_trait;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Parameters of a token exchange request, as consumed by
/// `TokenExchangeManager::exchange_token`.
///
/// The field names and URN values used in this file match the OAuth 2.0 Token
/// Exchange vocabulary (presumably RFC 8693 — confirm against the endpoint that
/// deserializes this type).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenExchangeRequest {
/// Must equal `urn:ietf:params:oauth:grant-type:token-exchange`; any other
/// value is rejected by `exchange_token`.
pub grant_type: String,
/// The token that is validated as the subject of the exchange.
pub subject_token: String,
/// Token-type URN for `subject_token`; must be accepted by `parse_token_type`
/// and listed in the client's policy.
pub subject_token_type: String,
/// Optional actor token; when present it is validated and the exchange is
/// classified as `ExchangeScenario::OnBehalfOf`.
pub actor_token: Option<String>,
/// Token-type URN for `actor_token`; required whenever `actor_token` is set.
pub actor_token_type: Option<String>,
/// Echoed back as `issued_token_type`; when absent the access-token URN is used.
pub requested_token_type: Option<String>,
/// When set, replaces the `aud` claim of the issued token.
pub audience: Option<String>,
/// Space-separated scope string requested for the issued token.
pub scope: Option<String>,
/// Copied into the exchange context; not otherwise interpreted in this file.
pub resource: Option<String>,
}
/// Result of a successful token exchange, as produced by
/// `TokenExchangeManager::generate_exchanged_token`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenExchangeResponse {
/// The newly issued token string.
pub access_token: String,
/// Token usage type; this file always sets it to `"Bearer"`.
pub token_type: String,
/// Lifetime of the issued token in seconds (taken from the policy's
/// `max_token_lifetime` in this file).
pub expires_in: Option<i64>,
/// Never populated by the code in this file (always `None`).
pub refresh_token: Option<String>,
/// Space-separated scope carried by the issued token.
pub scope: Option<String>,
/// Token-type URN of the issued token.
pub issued_token_type: Option<String>,
}
/// Token types recognised by the exchange manager.
///
/// Each variant (de)serializes as the URN given in its `serde(rename)`
/// attribute; `TokenExchangeManager::parse_token_type` maps the same URN
/// strings to these variants.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
#[serde(rename = "urn:ietf:params:oauth:token-type:access_token")]
AccessToken,
#[serde(rename = "urn:ietf:params:oauth:token-type:refresh_token")]
RefreshToken,
#[serde(rename = "urn:ietf:params:oauth:token-type:id_token")]
IdToken,
#[serde(rename = "urn:ietf:params:oauth:token-type:saml2")]
Saml2,
#[serde(rename = "urn:ietf:params:oauth:token-type:saml1")]
Saml1,
#[serde(rename = "urn:ietf:params:oauth:token-type:jwt")]
Jwt,
}
/// Validated inputs of a single exchange, assembled by `exchange_token` and
/// stored in the manager's `active_exchanges` map afterwards.
#[derive(Debug, Clone)]
pub struct TokenExchangeContext {
/// Claims obtained from validating the subject token.
pub subject_claims: SecureJwtClaims,
/// Claims obtained from validating the actor token, if one was supplied.
pub actor_claims: Option<SecureJwtClaims>,
/// Identifier of the client that requested the exchange.
pub client_id: String,
/// Audience requested for the issued token, copied from the request.
pub audience: Option<String>,
/// Requested scope, split into individual entries.
pub scope: Option<Vec<String>>,
/// Resource parameter copied from the request.
pub resource: Option<String>,
}
/// Per-client rules that constrain which exchanges are permitted.
#[derive(Debug, Clone)]
pub struct TokenExchangePolicy {
/// Subject token types the client may present.
pub allowed_subject_token_types: Vec<TokenType>,
/// Actor token types the client may present.
pub allowed_actor_token_types: Vec<TokenType>,
/// Scenarios the client may perform; any other detected scenario is rejected.
pub allowed_scenarios: Vec<ExchangeScenario>,
/// Lifetime applied to issued tokens and reported as `expires_in`.
pub max_token_lifetime: Duration,
/// When true, an `OnBehalfOf` exchange without actor claims is rejected.
pub require_actor_for_delegation: bool,
/// Audiences allowed for `AudienceRestriction`; an empty list allows any audience.
pub allowed_audiences: Vec<String>,
/// Maps a requested scope string (the whole string is the key) to the scope
/// entries that are issued instead.
pub scope_mapping: HashMap<String, Vec<String>>,
}
/// Classification of an exchange, used to check it against
/// `TokenExchangePolicy::allowed_scenarios`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExchangeScenario {
/// Fallback classification when no other scenario is detected.
ActingAs,
/// Chosen whenever an actor token accompanies the request.
OnBehalfOf,
/// NOTE(review): never produced by `determine_exchange_scenario` in this file.
TokenConversion,
/// The requested audience differs from the subject token's audience.
AudienceRestriction,
/// Fewer scope entries are requested than the subject token carries.
ScopeReduction,
}
/// Intermediate values pulled out of a SAML token by `validate_saml_token`
/// before they are converted into `SecureJwtClaims`.
#[derive(Debug, Clone)]
struct SamlClaims {
/// Text of the NameID element, or a placeholder when it is absent.
pub subject: String,
/// Text of the Issuer element, or a placeholder when it is absent.
pub issuer: String,
/// Text of the Audience element, if present.
pub audience: Option<String>,
/// Expiry as a Unix timestamp in seconds.
pub expiry: Option<i64>,
/// Not-before as a Unix timestamp in seconds.
pub not_before: Option<i64>,
/// Session identifier generated for the converted token.
pub session_id: Option<String>,
/// Scope entries derived from markers found in the token text.
pub scopes: Vec<String>,
}
/// Performs policy-controlled token exchanges for registered clients.
pub struct TokenExchangeManager {
// Validator used for every JWT-based subject/actor token.
jwt_validator: SecureJwtValidator,
// Exchange policies keyed by client id (see `register_policy`).
policies: tokio::sync::RwLock<HashMap<String, TokenExchangePolicy>>,
// Contexts of completed exchanges keyed by a generated UUID string.
// NOTE(review): nothing in this file removes entries from this map.
active_exchanges: tokio::sync::RwLock<HashMap<String, TokenExchangeContext>>,
}
impl TokenExchangeManager {
// Token-type URNs reported by `supported_subject_token_types`.
const SUBJECT_TOKEN_TYPES: &'static [&'static str] = &[
"urn:ietf:params:oauth:token-type:jwt",
"urn:ietf:params:oauth:token-type:access_token",
"urn:ietf:params:oauth:token-type:id_token",
"urn:ietf:params:oauth:token-type:saml2",
];
// Token-type URNs reported by `supported_requested_token_types`.
const REQUESTED_TOKEN_TYPES: &'static [&'static str] = &[
"urn:ietf:params:oauth:token-type:jwt",
"urn:ietf:params:oauth:token-type:access_token",
"urn:ietf:params:oauth:token-type:refresh_token",
];
/// Creates a manager that validates JWTs with `jwt_validator` and starts
/// with no registered policies and no recorded exchanges.
pub fn new(jwt_validator: SecureJwtValidator) -> Self {
    let policies = tokio::sync::RwLock::new(HashMap::new());
    let active_exchanges = tokio::sync::RwLock::new(HashMap::new());
    Self {
        jwt_validator,
        policies,
        active_exchanges,
    }
}
/// Registers `policy` for `client_id`, replacing any policy previously
/// stored under the same id.
pub async fn register_policy(&self, client_id: String, policy: TokenExchangePolicy) {
    // The write guard is a temporary and is released at the end of the statement.
    self.policies.write().await.insert(client_id, policy);
}
/// Performs a token exchange for `client_id`.
///
/// Steps: check the grant type, look up the client's policy, validate the
/// subject token (and the actor token when supplied), classify the exchange
/// scenario, check it against the policy, issue the new token and record the
/// exchange context.
///
/// # Errors
///
/// Returns an `AuthError` when the grant type is not the token-exchange URN,
/// no policy is registered for `client_id`, a token fails validation, or the
/// detected scenario is not permitted by the policy.
pub async fn exchange_token(
    &self,
    request: TokenExchangeRequest,
    client_id: &str,
) -> Result<TokenExchangeResponse> {
    if request.grant_type != "urn:ietf:params:oauth:grant-type:token-exchange" {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Invalid grant type for token exchange",
        ));
    }

    let policies = self.policies.read().await;
    let policy = policies.get(client_id).ok_or_else(|| {
        AuthError::auth_method("token_exchange", "No token exchange policy for client")
    })?;

    let subject_claims = self.validate_subject_token(&request, policy).await?;
    let actor_claims = match request.actor_token.as_deref() {
        Some(actor_token) => Some(
            self.validate_actor_token(actor_token, &request.actor_token_type, policy)
                .await?,
        ),
        None => None,
    };

    // Split on any run of whitespace so repeated, leading or trailing spaces
    // do not create empty scope entries (which would skew the scope-count
    // comparison used to detect scope reduction).
    let scope: Option<Vec<String>> = request
        .scope
        .as_deref()
        .map(|raw| raw.split_whitespace().map(String::from).collect());

    let context = TokenExchangeContext {
        subject_claims,
        actor_claims,
        client_id: client_id.to_string(),
        audience: request.audience.clone(),
        scope,
        resource: request.resource.clone(),
    };

    let scenario = self.determine_exchange_scenario(&context, policy)?;
    self.validate_exchange_scenario(&scenario, &context, policy)?;
    let response = self
        .generate_exchanged_token(&context, &request, policy)
        .await?;

    // The policy is no longer needed; release the read lock before taking the
    // write lock on the exchange registry.
    drop(policies);

    // NOTE(review): entries are never evicted from `active_exchanges` in this
    // file, so the map grows with every successful exchange.
    let exchange_id = uuid::Uuid::new_v4().to_string();
    self.active_exchanges
        .write()
        .await
        .insert(exchange_id, context);

    Ok(response)
}
/// Validates the request's subject token according to its declared type.
///
/// The declared type must parse and be listed in
/// `policy.allowed_subject_token_types`; SAML types go through
/// `validate_saml_token`, every other type is validated as a JWT.
async fn validate_subject_token(
    &self,
    request: &TokenExchangeRequest,
    policy: &TokenExchangePolicy,
) -> Result<SecureJwtClaims> {
    let declared_type = self.parse_token_type(&request.subject_token_type)?;

    let permitted = policy
        .allowed_subject_token_types
        .iter()
        .any(|allowed| *allowed == declared_type);
    if !permitted {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Subject token type not allowed",
        ));
    }

    let subject_token = request.subject_token.as_str();
    match declared_type {
        TokenType::Saml2 | TokenType::Saml1 => {
            self.validate_saml_token(subject_token, &declared_type).await
        }
        TokenType::AccessToken
        | TokenType::RefreshToken
        | TokenType::IdToken
        | TokenType::Jwt => self.validate_jwt_token(subject_token).await,
    }
}
/// Validates an actor token.
///
/// The token type must be supplied, must parse, and must be listed in
/// `policy.allowed_actor_token_types`.
///
/// NOTE(review): the token is always validated as a JWT, whatever type was
/// declared; a SAML actor token permitted by policy will therefore fail JWT
/// validation.
async fn validate_actor_token(
    &self,
    actor_token: &str,
    actor_token_type: &Option<String>,
    policy: &TokenExchangePolicy,
) -> Result<SecureJwtClaims> {
    let Some(declared_type) = actor_token_type.as_ref() else {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Actor token type required",
        ));
    };

    let parsed_type = self.parse_token_type(declared_type)?;
    let forbidden = policy
        .allowed_actor_token_types
        .iter()
        .all(|allowed| *allowed != parsed_type);
    if forbidden {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Actor token type not allowed",
        ));
    }

    self.validate_jwt_token(actor_token).await
}
/// Validates `token` with the manager's JWT validator and its own decoding
/// key, converting any validator failure into a token-exchange `AuthError`.
async fn validate_jwt_token(&self, token: &str) -> Result<SecureJwtClaims> {
    let key = self.jwt_validator.get_decoding_key();
    match self.jwt_validator.validate_token(token, &key, true) {
        Ok(claims) => Ok(claims),
        Err(e) => Err(AuthError::auth_method(
            "token_exchange",
            format!("JWT validation failed: {}", e),
        )),
    }
}
async fn validate_saml_token(
&self,
token: &str,
token_type: &TokenType,
) -> Result<SecureJwtClaims> {
if token.trim().is_empty() {
return Err(AuthError::auth_method(
"token_exchange",
"Empty SAML token provided",
));
}
let has_saml_markers = token.contains("<saml:")
|| token.contains("<saml2:")
|| token.contains("urn:oasis:names:tc:SAML");
if !has_saml_markers {
return Err(AuthError::auth_method(
"token_exchange",
"Invalid SAML token format - missing SAML namespace markers",
));
}
let saml_claims = SamlClaims {
subject: token
.find("<saml:NameID")
.and_then(|start| {
let content_start = token[start..].find('>').map(|pos| start + pos + 1)?;
let content_end = token[content_start..]
.find("</saml:NameID>")
.map(|pos| content_start + pos)?;
Some(token[content_start..content_end].trim().to_string())
})
.unwrap_or_else(|| "saml_subject".to_string()),
issuer: token
.find("<saml:Issuer")
.and_then(|start| {
let content_start = token[start..].find('>').map(|pos| start + pos + 1)?;
let content_end = token[content_start..]
.find("</saml:Issuer>")
.map(|pos| content_start + pos)?;
Some(token[content_start..content_end].trim().to_string())
})
.unwrap_or_else(|| "saml_identity_provider".to_string()),
audience: token.find("<saml:Audience").and_then(|start| {
let content_start = token[start..].find('>').map(|pos| start + pos + 1)?;
let content_end = token[content_start..]
.find("</saml:Audience>")
.map(|pos| content_start + pos)?;
Some(token[content_start..content_end].trim().to_string())
}),
expiry: Some(chrono::Utc::now().timestamp() + 3600), not_before: Some(chrono::Utc::now().timestamp()),
session_id: Some(format!("saml_session_{}", uuid::Uuid::new_v4())),
scopes: {
let mut scopes = Vec::new();
if token.contains("emailaddress") {
scopes.push("email".to_string());
}
if token.contains("identity/claims/name") {
scopes.push("profile".to_string());
}
if token.contains("claims/groups") || token.contains("role") {
scopes.push("groups".to_string());
}
if scopes.is_empty() {
scopes.push("saml_authenticated".to_string());
}
scopes
},
};
let now = chrono::Utc::now().timestamp();
let claims = SecureJwtClaims {
iss: saml_claims.issuer,
sub: saml_claims.subject,
aud: saml_claims
.audience
.unwrap_or_else(|| "target_audience".to_string()),
exp: saml_claims.expiry.unwrap_or(now + 3600), nbf: saml_claims.not_before.unwrap_or(now),
iat: now,
jti: format!("saml_token_{}", uuid::Uuid::new_v4()),
scope: saml_claims.scopes.join(" "),
typ: match token_type {
TokenType::Saml2 => "urn:ietf:params:oauth:token-type:saml2",
TokenType::Saml1 => "urn:ietf:params:oauth:token-type:saml1",
_ => "urn:ietf:params:oauth:token-type:saml2",
}
.to_string(),
sid: saml_claims.session_id,
client_id: None,
auth_ctx_hash: Some(format!("saml_ctx_{}", uuid::Uuid::new_v4())),
};
tracing::info!(
"SAML token validation completed - parsed subject: {}, issuer: {}, scopes: {}",
claims.sub,
claims.iss,
claims.scope
);
Ok(claims)
}
/// Maps a token-type URN to its `TokenType` variant.
///
/// # Errors
///
/// Returns an `AuthError` for any string that is not one of the six known
/// `urn:ietf:params:oauth:token-type:*` values.
fn parse_token_type(&self, token_type: &str) -> Result<TokenType> {
    const URN_PREFIX: &str = "urn:ietf:params:oauth:token-type:";

    let parsed = match token_type.strip_prefix(URN_PREFIX) {
        Some("access_token") => Some(TokenType::AccessToken),
        Some("refresh_token") => Some(TokenType::RefreshToken),
        Some("id_token") => Some(TokenType::IdToken),
        Some("saml2") => Some(TokenType::Saml2),
        Some("saml1") => Some(TokenType::Saml1),
        Some("jwt") => Some(TokenType::Jwt),
        _ => None,
    };

    parsed.ok_or_else(|| AuthError::auth_method("token_exchange", "Unknown token type"))
}
/// Classifies an exchange from its context.
///
/// Precedence: an actor token means `OnBehalfOf`; otherwise a requested
/// audience different from the subject's audience means
/// `AudienceRestriction`; otherwise requesting fewer scope entries than the
/// subject token carries means `ScopeReduction`; anything else is `ActingAs`.
/// The policy argument is currently unused.
fn determine_exchange_scenario(
    &self,
    context: &TokenExchangeContext,
    _policy: &TokenExchangePolicy,
) -> Result<ExchangeScenario> {
    let subject = &context.subject_claims;

    let audience_differs =
        matches!(&context.audience, Some(requested) if *requested != subject.aud);

    // Only the number of entries is compared, not their contents.
    let scope_shrinks = context
        .scope
        .as_ref()
        .is_some_and(|requested| requested.len() < subject.scope.split(' ').count());

    let scenario = if context.actor_claims.is_some() {
        ExchangeScenario::OnBehalfOf
    } else if audience_differs {
        ExchangeScenario::AudienceRestriction
    } else if scope_shrinks {
        ExchangeScenario::ScopeReduction
    } else {
        ExchangeScenario::ActingAs
    };
    Ok(scenario)
}
/// Checks a classified exchange against the client's policy.
///
/// # Errors
///
/// Returns an `AuthError` when the scenario is not in
/// `policy.allowed_scenarios`, when delegation requires an actor and none is
/// present, or when the requested audience is not in a non-empty
/// `policy.allowed_audiences` list.
fn validate_exchange_scenario(
    &self,
    scenario: &ExchangeScenario,
    context: &TokenExchangeContext,
    policy: &TokenExchangePolicy,
) -> Result<()> {
    let permitted = policy
        .allowed_scenarios
        .iter()
        .any(|allowed| allowed == scenario);
    if !permitted {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Exchange scenario not allowed",
        ));
    }

    // Scenario-specific rule; `None` means nothing is violated.
    let violation = match scenario {
        ExchangeScenario::OnBehalfOf
            if policy.require_actor_for_delegation && context.actor_claims.is_none() =>
        {
            Some("Actor token required for delegation")
        }
        ExchangeScenario::AudienceRestriction => context
            .audience
            .as_ref()
            .filter(|requested| {
                // An empty allow-list places no restriction on the audience.
                !policy.allowed_audiences.is_empty()
                    && !policy.allowed_audiences.contains(*requested)
            })
            .map(|_| "Audience not allowed"),
        _ => None,
    };

    match violation {
        Some(message) => Err(AuthError::auth_method("token_exchange", message)),
        None => Ok(()),
    }
}
async fn generate_exchanged_token(
&self,
context: &TokenExchangeContext,
request: &TokenExchangeRequest,
policy: &TokenExchangePolicy,
) -> Result<TokenExchangeResponse> {
let now = Utc::now();
let expires_in = policy.max_token_lifetime.num_seconds();
let exp = now + policy.max_token_lifetime;
let mut new_claims = context.subject_claims.clone();
new_claims.exp = exp.timestamp();
new_claims.iat = now.timestamp();
new_claims.jti = uuid::Uuid::new_v4().to_string();
if let Some(ref audience) = request.audience {
new_claims.aud = audience.clone();
}
if let Some(ref requested_scope) = request.scope {
if let Some(mapped_scopes) = policy.scope_mapping.get(requested_scope) {
new_claims.scope = mapped_scopes.join(" ");
} else {
new_claims.scope = requested_scope.clone();
}
}
if let Some(ref actor_claims) = context.actor_claims {
new_claims.client_id = Some(actor_claims.sub.clone());
}
let access_token = format!(
"exchanged_token_{}_{}",
new_claims.jti,
URL_SAFE_NO_PAD.encode(&new_claims.sub)
);
let issued_token_type = request
.requested_token_type
.clone()
.unwrap_or_else(|| "urn:ietf:params:oauth:token-type:access_token".to_string());
Ok(TokenExchangeResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: Some(expires_in),
refresh_token: None, scope: Some(new_claims.scope),
issued_token_type: Some(issued_token_type),
})
}
fn get_jwt_decoding_key(&self, token: &str) -> Result<jsonwebtoken::DecodingKey> {
use jsonwebtoken::DecodingKey;
let token_parts: Vec<&str> = token.split('.').collect();
if token_parts.len() < 2 {
return Err(AuthError::InvalidToken("Invalid JWT format".to_string()));
}
let header_b64 = token_parts[0];
let header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(header_b64)
.map_err(|_| AuthError::InvalidToken("Invalid JWT header encoding".to_string()))?;
let header: serde_json::Value = serde_json::from_slice(&header_bytes)
.map_err(|_| AuthError::InvalidToken("Invalid JWT header JSON".to_string()))?;
let algorithm = header
.get("alg")
.and_then(|a| a.as_str())
.unwrap_or("HS256");
match algorithm {
"HS256" => {
let secret = std::env::var("JWT_HMAC_SECRET")
.unwrap_or_else(|_| "default_hmac_secret_for_development".to_string());
Ok(DecodingKey::from_secret(secret.as_bytes()))
}
"RS256" => {
let public_key_pem = std::env::var("JWT_RSA_PUBLIC_KEY")
.unwrap_or_else(|_| include_str!("../../../public.pem").to_string());
DecodingKey::from_rsa_pem(public_key_pem.as_bytes())
.map_err(|e| AuthError::InvalidToken(format!("Invalid RSA key: {}", e)))
}
_ => {
Ok(DecodingKey::from_secret("fallback_secret".as_bytes()))
}
}
}
}
impl Default for TokenExchangePolicy {
fn default() -> Self {
Self {
allowed_subject_token_types: vec![
TokenType::AccessToken,
TokenType::RefreshToken,
TokenType::IdToken,
],
allowed_actor_token_types: vec![TokenType::AccessToken, TokenType::IdToken],
allowed_scenarios: vec![
ExchangeScenario::ActingAs,
ExchangeScenario::OnBehalfOf,
ExchangeScenario::AudienceRestriction,
ExchangeScenario::ScopeReduction,
],
max_token_lifetime: Duration::hours(1),
require_actor_for_delegation: true,
allowed_audiences: Vec::new(), scope_mapping: HashMap::new(),
}
}
}
#[async_trait]
impl TokenExchangeService for TokenExchangeManager {
// A request pairs the exchange parameters with the requesting client's id
// (see `exchange_token` below, which destructures it in that order).
type Request = (TokenExchangeRequest, String); type Response = TokenExchangeResponse;
// NOTE(review): `Config` is not referenced by any method in this impl;
// presumably consumed by code using the trait — confirm.
type Config = SecureJwtValidator;
/// Trait entry point: unpacks the `(request, client_id)` pair and forwards
/// to the inherent `TokenExchangeManager::exchange_token`.
async fn exchange_token(&self, request: Self::Request) -> Result<Self::Response> {
    let (exchange_request, requesting_client) = request;
    // Resolves to the inherent two-argument method, not this trait method.
    self.exchange_token(exchange_request, requesting_client.as_str())
        .await
}
/// Validates `token` as `token_type` and reports the outcome.
///
/// JWT-family types (`jwt`, `access_token`, `id_token`) are verified with the
/// JWT validator; on success the result carries the token's subject, issuer,
/// audience, scopes, expiry and selected claims as metadata. A failed
/// verification is reported as `is_valid: false` with a message rather than
/// as an error.
///
/// SAML types are accepted only when they pass the structural checks of
/// `validate_saml_token`, so empty or non-SAML input is reported invalid
/// instead of unconditionally valid. The "not fully implemented" message is
/// kept because no signature verification takes place.
///
/// # Errors
///
/// Returns an error when `token_type` is not a supported subject token type,
/// cannot be parsed, is a type that cannot be validated here, or when no
/// decoding key can be selected for a JWT.
async fn validate_token(&self, token: &str, token_type: &str) -> Result<TokenValidationResult> {
    let supported_types = self.supported_subject_token_types();
    ValidationUtils::validate_token_type(token_type, &supported_types)?;

    // Shared shape of every negative outcome.
    let rejected = |message: String| TokenValidationResult {
        is_valid: false,
        subject: None,
        issuer: None,
        audience: Vec::new(),
        scopes: Vec::new(),
        expires_at: None,
        metadata: HashMap::new(),
        validation_messages: vec![message],
    };

    let parsed_type = self.parse_token_type(token_type)?;
    match parsed_type {
        TokenType::Jwt | TokenType::AccessToken | TokenType::IdToken => {
            let decoding_key = self.get_jwt_decoding_key(token)?;
            let claims = match self
                .jwt_validator
                .validate_token(token, &decoding_key, true)
            {
                Ok(claims) => claims,
                Err(e) => return Ok(rejected(format!("JWT validation failed: {}", e))),
            };

            use chrono::{TimeZone, Utc};
            let expires_at = Utc.timestamp_opt(claims.exp, 0).single();
            let audience = if claims.aud.is_empty() {
                Vec::new()
            } else {
                vec![claims.aud.clone()]
            };
            // `split_whitespace` yields nothing for an empty scope string.
            let scopes = claims
                .scope
                .split_whitespace()
                .map(|s| s.to_string())
                .collect();

            let mut metadata = HashMap::new();
            let required_claims = [
                ("sub", &claims.sub),
                ("iss", &claims.iss),
                ("aud", &claims.aud),
                ("scope", &claims.scope),
                ("typ", &claims.typ),
            ];
            for (key, value) in required_claims {
                metadata.insert(key.to_string(), serde_json::Value::String(value.clone()));
            }
            if let Some(ref sid) = claims.sid {
                metadata.insert("sid".to_string(), serde_json::Value::String(sid.clone()));
            }
            if let Some(ref client_id) = claims.client_id {
                metadata.insert(
                    "client_id".to_string(),
                    serde_json::Value::String(client_id.clone()),
                );
            }

            Ok(TokenValidationResult {
                is_valid: true,
                subject: Some(claims.sub),
                issuer: Some(claims.iss),
                audience,
                scopes,
                expires_at,
                metadata,
                validation_messages: Vec::new(),
            })
        }
        TokenType::Saml2 | TokenType::Saml1 => {
            // Gate on the structural checks so blank or non-SAML input is
            // never reported as valid.
            match self.validate_saml_token(token, &parsed_type).await {
                Ok(_) => Ok(TokenValidationResult {
                    is_valid: true,
                    subject: None,
                    issuer: None,
                    audience: Vec::new(),
                    scopes: Vec::new(),
                    expires_at: None,
                    metadata: HashMap::new(),
                    validation_messages: vec![
                        "SAML validation not fully implemented".to_string(),
                    ],
                }),
                Err(e) => Ok(rejected(format!("SAML validation failed: {:?}", e))),
            }
        }
        TokenType::RefreshToken => Err(AuthError::InvalidRequest(format!(
            "Token validation not supported for type: {}",
            token_type
        ))),
    }
}
/// Returns the subject token-type URNs listed in `SUBJECT_TOKEN_TYPES` as
/// owned strings.
fn supported_subject_token_types(&self) -> Vec<String> {
    let mut urns = Vec::with_capacity(Self::SUBJECT_TOKEN_TYPES.len());
    for urn in Self::SUBJECT_TOKEN_TYPES {
        urns.push(String::from(*urn));
    }
    urns
}
/// Returns the requested token-type URNs listed in `REQUESTED_TOKEN_TYPES`
/// as owned strings.
fn supported_requested_token_types(&self) -> Vec<String> {
    let mut urns = Vec::with_capacity(Self::REQUESTED_TOKEN_TYPES.len());
    for urn in Self::REQUESTED_TOKEN_TYPES {
        urns.push(String::from(*urn));
    }
    urns
}
/// Describes this service as a basic implementation: only basic exchange and
/// policy control are advertised, with a maximum delegation depth of 3.
fn capabilities(&self) -> TokenExchangeCapabilities {
    TokenExchangeCapabilities {
        complexity_level: ServiceComplexityLevel::Basic,
        max_delegation_depth: 3,
        // Advertised features.
        basic_exchange: true,
        policy_control: true,
        // Features this implementation does not advertise.
        multi_party_chains: false,
        context_preservation: false,
        audit_trail: false,
        session_integration: false,
        jwt_operations: false,
        cross_domain_exchange: false,
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::secure_jwt::SecureJwtConfig;

    /// Builds a manager backed by a validator with the default JWT config.
    fn create_test_manager() -> TokenExchangeManager {
        TokenExchangeManager::new(SecureJwtValidator::new(SecureJwtConfig::default()))
    }

    /// Builds a well-formed exchange request carrying a dummy subject token.
    fn create_test_request() -> TokenExchangeRequest {
        let access_token_urn = "urn:ietf:params:oauth:token-type:access_token";
        TokenExchangeRequest {
            grant_type: "urn:ietf:params:oauth:grant-type:token-exchange".to_string(),
            subject_token: "dummy.jwt.token".to_string(),
            subject_token_type: access_token_urn.to_string(),
            actor_token: None,
            actor_token_type: None,
            requested_token_type: Some(access_token_urn.to_string()),
            audience: Some("api.example.com".to_string()),
            scope: Some("read write".to_string()),
            resource: None,
        }
    }

    /// Registering a default policy on a fresh manager must not panic.
    #[tokio::test]
    async fn test_token_exchange_manager_creation() {
        let manager = create_test_manager();
        manager
            .register_policy("test_client".to_string(), TokenExchangePolicy::default())
            .await;
    }

    /// Known URNs map to their variants; unknown strings are rejected.
    #[test]
    fn test_token_type_parsing() {
        let manager = create_test_manager();

        let access = manager
            .parse_token_type("urn:ietf:params:oauth:token-type:access_token")
            .unwrap();
        assert_eq!(access, TokenType::AccessToken);

        let id = manager
            .parse_token_type("urn:ietf:params:oauth:token-type:id_token")
            .unwrap();
        assert_eq!(id, TokenType::IdToken);

        assert!(manager.parse_token_type("invalid_token_type").is_err());
    }

    /// A requested audience that differs from the subject's audience is
    /// classified as an audience restriction.
    #[test]
    fn test_exchange_scenario_determination() {
        let manager = create_test_manager();
        let policy = TokenExchangePolicy::default();
        let now = chrono::Utc::now().timestamp();

        let subject_claims = SecureJwtClaims {
            sub: "user123".to_string(),
            iss: "auth.example.com".to_string(),
            aud: "api.example.com".to_string(),
            exp: now + 3600,
            nbf: now,
            iat: now,
            jti: "token123".to_string(),
            scope: "read write".to_string(),
            typ: "access".to_string(),
            sid: None,
            client_id: None,
            auth_ctx_hash: None,
        };
        let context = TokenExchangeContext {
            subject_claims,
            actor_claims: None,
            client_id: "test_client".to_string(),
            audience: Some("different.api.com".to_string()),
            scope: None,
            resource: None,
        };

        let scenario = manager
            .determine_exchange_scenario(&context, &policy)
            .unwrap();
        assert_eq!(scenario, ExchangeScenario::AudienceRestriction);
    }

    /// A request with the wrong grant type is rejected.
    #[tokio::test]
    async fn test_invalid_grant_type() {
        let manager = create_test_manager();
        manager
            .register_policy("test_client".to_string(), TokenExchangePolicy::default())
            .await;

        let request = TokenExchangeRequest {
            grant_type: "invalid_grant_type".to_string(),
            ..create_test_request()
        };

        let result = manager.exchange_token(request, "test_client").await;
        assert!(result.is_err());
    }
}