use crate::errors::{AuthError, Result};
#[cfg(feature = "saml")]
use crate::methods::saml::SamlSignatureValidator;
use crate::security::secure_jwt::{SecureJwtClaims, SecureJwtValidator};
use crate::server::token_exchange::token_exchange_common::{
ServiceComplexityLevel, TokenExchangeCapabilities, TokenExchangeService, TokenValidationResult,
ValidationUtils,
};
use async_trait::async_trait;
use base64::Engine as _;
use chrono::{Duration, Utc};
use jsonwebtoken::{Algorithm, Header, encode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Incoming token-exchange request.
///
/// Field names mirror the RFC 8693 token-exchange request parameters; `exchange_token`
/// only accepts `grant_type == "urn:ietf:params:oauth:grant-type:token-exchange"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenExchangeRequest {
/// Must be the token-exchange grant-type URN; anything else is rejected.
pub grant_type: String,
/// The token being exchanged: a JWT, or SAML as raw XML or base64-encoded XML.
pub subject_token: String,
/// Token-type URN of `subject_token`, e.g. `urn:ietf:params:oauth:token-type:access_token`.
pub subject_token_type: String,
/// Optional token identifying the acting party.
pub actor_token: Option<String>,
/// Token-type URN of `actor_token`; required whenever `actor_token` is present.
pub actor_token_type: Option<String>,
/// Token-type URN the client wants issued; treated as access_token when absent.
pub requested_token_type: Option<String>,
/// Desired audience; replaces the subject token's `aud` in the issued token.
pub audience: Option<String>,
/// Space-delimited scopes requested for the issued token.
pub scope: Option<String>,
/// Resource indicator; in this file it is only copied into the exchange context.
pub resource: Option<String>,
}
/// Successful token-exchange response produced by `TokenExchangeManager::exchange_token`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenExchangeResponse {
/// The newly issued token; `generate_exchanged_token` signs it as a JWT.
pub access_token: String,
/// Token type for the HTTP `Authorization` header; set to `"Bearer"` in this module.
pub token_type: String,
/// Lifetime of `access_token` in seconds.
pub expires_in: Option<i64>,
/// Not populated by `generate_exchanged_token` (always `None` there).
pub refresh_token: Option<String>,
/// Space-delimited scope carried by the issued token.
pub scope: Option<String>,
/// Token-type URN describing `access_token`.
pub issued_token_type: Option<String>,
}
/// Token kinds understood by the exchange.
///
/// Each variant serializes as its `urn:ietf:params:oauth:token-type:*` URN; the same URNs
/// are mapped back to variants by `TokenExchangeManager::parse_token_type`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
/// OAuth access token (validated as a JWT here).
#[serde(rename = "urn:ietf:params:oauth:token-type:access_token")]
AccessToken,
/// OAuth refresh token (validated as a JWT when used as a subject token).
#[serde(rename = "urn:ietf:params:oauth:token-type:refresh_token")]
RefreshToken,
/// OpenID Connect ID token (validated as a JWT here).
#[serde(rename = "urn:ietf:params:oauth:token-type:id_token")]
IdToken,
/// SAML 2.0 assertion.
#[serde(rename = "urn:ietf:params:oauth:token-type:saml2")]
Saml2,
/// SAML 1.1 assertion.
#[serde(rename = "urn:ietf:params:oauth:token-type:saml1")]
Saml1,
/// Generic JWT.
#[serde(rename = "urn:ietf:params:oauth:token-type:jwt")]
Jwt,
}
/// Validated inputs of a single exchange, assembled by `exchange_token`.
#[derive(Debug, Clone)]
pub struct TokenExchangeContext {
/// Claims of the validated subject token.
pub subject_claims: SecureJwtClaims,
/// Claims of the validated actor token, if one was supplied.
pub actor_claims: Option<SecureJwtClaims>,
/// Client performing the exchange (the key used to look up its policy).
pub client_id: String,
/// Audience requested by the client, if any.
pub audience: Option<String>,
/// Requested scopes, split from the request's space-delimited `scope`.
pub scope: Option<Vec<String>>,
/// Resource indicator copied from the request.
pub resource: Option<String>,
}
/// Per-client rules governing which exchanges `TokenExchangeManager` will perform.
#[derive(Debug, Clone)]
pub struct TokenExchangePolicy {
/// Subject token types accepted by `validate_subject_token`.
pub allowed_subject_token_types: Vec<TokenType>,
/// Actor token types accepted by `validate_actor_token`.
pub allowed_actor_token_types: Vec<TokenType>,
/// Scenarios this client may perform; an empty list permits every scenario.
pub allowed_scenarios: Vec<ExchangeScenario>,
/// Maximum lifetime given to exchanged tokens.
pub max_token_lifetime: Duration,
/// When true, an `OnBehalfOf` exchange must carry an actor token.
pub require_actor_for_delegation: bool,
/// Audiences allowed for `AudienceRestriction` exchanges; empty allows any audience.
pub allowed_audiences: Vec<String>,
/// Maps a requested `scope` string (matched as a whole) to the scopes actually issued.
pub scope_mapping: HashMap<String, Vec<String>>,
}
impl TokenExchangePolicy {
    /// Starts a [`TokenExchangePolicyBuilder`] seeded with the default policy.
    pub fn builder() -> TokenExchangePolicyBuilder {
        let inner = Self::default();
        TokenExchangePolicyBuilder { inner }
    }

    /// Preset that only accepts JWT / access-token subjects, access-token actors, and the
    /// `OnBehalfOf` / `AudienceRestriction` scenarios, with a one-hour token lifetime.
    pub fn jwt_only() -> Self {
        let allowed_subject_token_types = vec![TokenType::Jwt, TokenType::AccessToken];
        let allowed_actor_token_types = vec![TokenType::AccessToken];
        let allowed_scenarios = vec![
            ExchangeScenario::OnBehalfOf,
            ExchangeScenario::AudienceRestriction,
        ];
        Self {
            allowed_subject_token_types,
            allowed_actor_token_types,
            allowed_scenarios,
            max_token_lifetime: Duration::hours(1),
            require_actor_for_delegation: true,
            allowed_audiences: Vec::new(),
            scope_mapping: HashMap::new(),
        }
    }
}
/// Fluent builder for [`TokenExchangePolicy`]; obtain one via `TokenExchangePolicy::builder()`.
pub struct TokenExchangePolicyBuilder {
/// Policy being built; starts as `TokenExchangePolicy::default()`.
inner: TokenExchangePolicy,
}
impl TokenExchangePolicyBuilder {
pub fn subject_token_types(mut self, types: Vec<TokenType>) -> Self {
self.inner.allowed_subject_token_types = types;
self
}
pub fn actor_token_types(mut self, types: Vec<TokenType>) -> Self {
self.inner.allowed_actor_token_types = types;
self
}
pub fn scenarios(mut self, scenarios: Vec<ExchangeScenario>) -> Self {
self.inner.allowed_scenarios = scenarios;
self
}
pub fn max_token_lifetime(mut self, lifetime: Duration) -> Self {
self.inner.max_token_lifetime = lifetime;
self
}
pub fn require_actor_for_delegation(mut self, required: bool) -> Self {
self.inner.require_actor_for_delegation = required;
self
}
pub fn audience(mut self, aud: impl Into<String>) -> Self {
self.inner.allowed_audiences.push(aud.into());
self
}
pub fn audiences(mut self, auds: Vec<String>) -> Self {
self.inner.allowed_audiences = auds;
self
}
pub fn scope_map(mut self, source: impl Into<String>, targets: Vec<String>) -> Self {
self.inner.scope_mapping.insert(source.into(), targets);
self
}
pub fn build(self) -> TokenExchangePolicy {
self.inner
}
}
/// Classification of an exchange, chosen by `determine_exchange_scenario` and checked
/// against `TokenExchangePolicy::allowed_scenarios`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExchangeScenario {
/// Fallback classification when no more specific scenario applies.
ActingAs,
/// An actor token was supplied alongside the subject token.
OnBehalfOf,
/// Not produced by `determine_exchange_scenario` in this file.
TokenConversion,
/// The requested audience differs from the subject token's audience.
AudienceRestriction,
/// The requested scopes narrow the subject token's scopes.
ScopeReduction,
}
/// Performs token exchanges for clients that have a registered [`TokenExchangePolicy`].
pub struct TokenExchangeManager {
/// Validates incoming JWTs and supplies the key used to sign exchanged tokens.
jwt_validator: SecureJwtValidator,
/// Exchange policies keyed by client id; populated by `register_policy`.
policies: tokio::sync::RwLock<HashMap<String, TokenExchangePolicy>>,
/// Contexts of completed exchanges keyed by a random UUID, written by `exchange_token`.
/// NOTE(review): no lookup by exchange id is performed in the code visible here.
active_exchanges: tokio::sync::RwLock<HashMap<String, TokenExchangeContext>>,
}
impl TokenExchangeManager {
/// Subject token-type URNs advertised by `supported_subject_token_types` and enforced
/// by `validate_token` through `ValidationUtils::validate_token_type`.
const SUBJECT_TOKEN_TYPES: &'static [&'static str] = &[
"urn:ietf:params:oauth:token-type:jwt",
"urn:ietf:params:oauth:token-type:access_token",
"urn:ietf:params:oauth:token-type:id_token",
"urn:ietf:params:oauth:token-type:saml1",
"urn:ietf:params:oauth:token-type:saml2",
];
/// Token-type URNs advertised by `supported_requested_token_types`.
const REQUESTED_TOKEN_TYPES: &'static [&'static str] = &[
"urn:ietf:params:oauth:token-type:jwt",
"urn:ietf:params:oauth:token-type:access_token",
"urn:ietf:params:oauth:token-type:refresh_token",
];
/// Creates a manager with no registered policies and an empty exchange log.
pub fn new(jwt_validator: SecureJwtValidator) -> Self {
    let policies = tokio::sync::RwLock::new(HashMap::new());
    let active_exchanges = tokio::sync::RwLock::new(HashMap::new());
    Self {
        jwt_validator,
        policies,
        active_exchanges,
    }
}
/// Stores `policy` for `client_id`, replacing any policy previously registered for it.
pub async fn register_policy(&self, client_id: String, policy: TokenExchangePolicy) {
    self.policies.write().await.insert(client_id, policy);
}
pub async fn exchange_token(
&self,
request: TokenExchangeRequest,
client_id: &str,
) -> Result<TokenExchangeResponse> {
if request.grant_type != "urn:ietf:params:oauth:grant-type:token-exchange" {
return Err(AuthError::auth_method(
"token_exchange",
"Invalid grant type for token exchange",
));
}
let policies = self.policies.read().await;
let policy = policies.get(client_id).ok_or_else(|| {
AuthError::auth_method("token_exchange", "No token exchange policy for client")
})?;
let subject_claims = self.validate_subject_token(&request, policy).await?;
let actor_claims = if let Some(ref actor_token) = request.actor_token {
Some(
self.validate_actor_token(actor_token, &request.actor_token_type, policy)
.await?,
)
} else {
None
};
let context = TokenExchangeContext {
subject_claims,
actor_claims,
client_id: client_id.to_string(),
audience: request.audience.clone(),
scope: request
.scope
.as_ref()
.map(|s| s.split(' ').map(String::from).collect()),
resource: request.resource.clone(),
};
let scenario = self.determine_exchange_scenario(&context, policy)?;
self.validate_exchange_scenario(&scenario, &context, policy)?;
let response = self
.generate_exchanged_token(&context, &request, policy)
.await?;
let exchange_id = uuid::Uuid::new_v4().to_string();
let mut exchanges = self.active_exchanges.write().await;
exchanges.insert(exchange_id, context);
Ok(response)
}
/// Checks the subject token type against the policy, then validates the token with the
/// validator for its type: SAML types go to the SAML path, everything else to the JWT path.
async fn validate_subject_token(
    &self,
    request: &TokenExchangeRequest,
    policy: &TokenExchangePolicy,
) -> Result<SecureJwtClaims> {
    let token_type = self.parse_token_type(&request.subject_token_type)?;
    let permitted = policy.allowed_subject_token_types.contains(&token_type);
    if !permitted {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Subject token type not allowed",
        ));
    }
    let raw_token = request.subject_token.as_str();
    match token_type {
        TokenType::Saml1 | TokenType::Saml2 => {
            self.validate_saml_token(raw_token, &token_type).await
        }
        TokenType::AccessToken
        | TokenType::RefreshToken
        | TokenType::IdToken
        | TokenType::Jwt => self.validate_jwt_token(raw_token).await,
    }
}
async fn validate_actor_token(
&self,
actor_token: &str,
actor_token_type: &Option<String>,
policy: &TokenExchangePolicy,
) -> Result<SecureJwtClaims> {
let token_type_str = actor_token_type
.as_ref()
.ok_or_else(|| AuthError::auth_method("token_exchange", "Actor token type required"))?;
let token_type = self.parse_token_type(token_type_str)?;
if !policy.allowed_actor_token_types.contains(&token_type) {
return Err(AuthError::auth_method(
"token_exchange",
"Actor token type not allowed",
));
}
self.validate_jwt_token(actor_token).await
}
/// Validates `token` with the configured JWT validator, converting failures into a
/// `token_exchange` auth-method error.
async fn validate_jwt_token(&self, token: &str) -> Result<SecureJwtClaims> {
    match self.jwt_validator.validate(token) {
        Ok(claims) => Ok(claims),
        Err(e) => Err(AuthError::auth_method(
            "token_exchange",
            format!("JWT validation failed: {}", e),
        )),
    }
}
/// Turns a SAML token into XML text.
///
/// A token whose first non-whitespace character is `<` is treated as raw XML and returned
/// as-is. Anything else is treated as standard base64. All ASCII whitespace is removed
/// before decoding, because base64 SAML payloads are commonly line-wrapped and the
/// decoder rejects embedded whitespace.
///
/// # Errors
/// Returns an `AuthError` in any of these cases:
/// - the token is empty or whitespace-only;
/// - the base64 or UTF-8 decoding fails;
/// - the result contains none of the expected SAML markers.
fn extract_saml_xml(&self, token: &str) -> Result<String> {
    let trimmed = token.trim();
    if trimmed.is_empty() {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Empty SAML token provided",
        ));
    }
    let decoded = if trimmed.starts_with('<') {
        token.to_string()
    } else {
        let compact: String = trimmed
            .chars()
            .filter(|c| !c.is_ascii_whitespace())
            .collect();
        let bytes = base64::engine::general_purpose::STANDARD
            .decode(compact.as_bytes())
            .map_err(|e| {
                AuthError::auth_method(
                    "token_exchange",
                    format!("Invalid base64-encoded SAML token: {}", e),
                )
            })?;
        String::from_utf8(bytes).map_err(|e| {
            AuthError::auth_method(
                "token_exchange",
                format!("Invalid UTF-8 in SAML token: {}", e),
            )
        })?
    };
    // Cheap sanity check only; real validation happens in validate_saml_assertion_xml.
    let has_saml_markers = decoded.contains("<saml:")
        || decoded.contains("<saml2:")
        || decoded.contains("<Assertion")
        || decoded.contains("<Response")
        || decoded.contains("urn:oasis:names:tc:SAML");
    if !has_saml_markers {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Invalid SAML token format - missing SAML namespace markers",
        ));
    }
    Ok(decoded)
}
/// Returns the trimmed text content of the first matching `local_name` element, if any.
///
/// Opening tags are searched in this order: `<name>`, `<saml:name>`, `<saml2:name>`, then
/// the same three forms followed by a space (element carries attributes). The first pattern
/// that occurs anywhere in `xml` wins, so pattern order, not document order, decides
/// precedence. For that opening tag, the closing tags `</name>`, `</saml:name>` and
/// `</saml2:name>` are tried in that order against the rest of the document.
///
/// NOTE(review): known limitations:
/// - the closing-tag prefix is not tied to the opening-tag prefix;
/// - other namespace prefixes are not recognised;
/// - matching is not confined to the signed assertion.
///
/// If an opening tag is found with no `>` after it, the function returns `None` at once.
#[cfg(feature = "saml")]
fn extract_xml_text(&self, xml: &str, local_name: &str) -> Option<String> {
let patterns = [
format!("<{local_name}>"),
format!("<saml:{local_name}>"),
format!("<saml2:{local_name}>"),
format!("<{local_name} "),
format!("<saml:{local_name} "),
format!("<saml2:{local_name} "),
];
for pattern in patterns {
if let Some(start) = xml.find(&pattern) {
// Content begins right after the `>` that closes the opening tag.
let content_start = xml[start..].find('>').map(|index| start + index + 1)?;
let end_patterns = [
format!("</{local_name}>"),
format!("</saml:{local_name}>"),
format!("</saml2:{local_name}>"),
];
for end_pattern in end_patterns {
if let Some(relative_end) = xml[content_start..].find(&end_pattern) {
return Some(
xml[content_start..content_start + relative_end]
.trim()
.to_string(),
);
}
}
// No closing tag found for this opening pattern: fall through to the next pattern.
}
}
None
}
/// Returns the value of the first `attribute_name="…"` (or `attribute_name='…'`) in `xml`.
///
/// Double-quoted occurrences are searched before single-quoted ones.
///
/// A match counts only when the name is preceded by whitespace or starts the string.
/// Without this, `NotOnOrAfter` would also match inside `SessionNotOnOrAfter`.
///
/// The value ends at the *same* quote character that opened it. For example, the value
/// `"O'Brien"` is not cut at the apostrophe.
#[cfg(feature = "saml")]
fn extract_xml_attribute(&self, xml: &str, attribute_name: &str) -> Option<String> {
    for quote in ['"', '\''] {
        let pattern = format!("{attribute_name}={quote}");
        for (start, _) in xml.match_indices(pattern.as_str()) {
            let at_boundary = xml[..start]
                .chars()
                .next_back()
                .is_none_or(char::is_whitespace);
            if !at_boundary {
                continue;
            }
            let value_start = start + pattern.len();
            if let Some(relative_end) = xml[value_start..].find(quote) {
                return Some(xml[value_start..value_start + relative_end].to_string());
            }
        }
    }
    None
}
/// Parses a SAML `xs:dateTime` timestamp into Unix seconds.
///
/// Parsing is tried in two steps:
/// 1. RFC 3339 with an explicit offset or `Z`.
/// 2. Fallback: a timestamp without a zone designator (optionally with fractional seconds),
///    which is interpreted as UTC.
///
/// The previous fallbacks used `DateTime::parse_from_str` with a literal `Z` and no `%z`.
/// That can never succeed, because it requires an offset.
///
/// # Errors
/// Returns an `AuthError` naming the original input when neither form parses.
#[cfg(feature = "saml")]
fn parse_saml_timestamp(&self, timestamp: &str) -> Result<i64> {
    let trimmed = timestamp.trim();
    chrono::DateTime::parse_from_rfc3339(trimmed)
        .map(|dt| dt.timestamp())
        .or_else(|_| {
            let naive_part = trimmed.strip_suffix('Z').unwrap_or(trimmed);
            chrono::NaiveDateTime::parse_from_str(naive_part, "%Y-%m-%dT%H:%M:%S%.f")
                .map(|naive| naive.and_utc().timestamp())
        })
        .map_err(|_| {
            AuthError::auth_method(
                "token_exchange",
                format!("Invalid SAML timestamp: {}", timestamp),
            )
        })
}
/// Reads `attribute_name` from `xml` and parses it as a SAML timestamp.
///
/// Returns `Ok(None)` when the attribute is absent and an error when it is present but
/// unparseable. This also repairs the mis-encoded `×tamp` (originally `&timestamp`),
/// which did not compile with the `saml` feature enabled.
#[cfg(feature = "saml")]
fn extract_saml_timestamp(&self, xml: &str, attribute_name: &str) -> Result<Option<i64>> {
    self.extract_xml_attribute(xml, attribute_name)
        .map(|timestamp| self.parse_saml_timestamp(&timestamp))
        .transpose()
}
/// Validates a SAML assertion's XML signature and time window, then maps it to
/// `SecureJwtClaims` so it can flow through the JWT-based exchange pipeline.
///
/// NOTE(review): security caveats to confirm before relying on this path:
/// - The signing certificate comes from `extract_embedded_certificate(xml)`, presumably
///   the assertion's own `KeyInfo`. If so, a valid signature proves only self-consistency.
///   Trust still requires pinning or chaining to a configured IdP certificate.
/// - `NameID`, `Issuer`, `Audience` and the timestamps are located by string search over
///   the whole document. Nothing ties them to the signed element (signature-wrapping risk).
/// - `Audience` is not compared against any expected value here.
///
/// # Errors
/// Returns an `AuthError` in any of these cases:
/// - the certificate cannot be extracted;
/// - the signature check errors or fails;
/// - `NameID` or `Issuer` is missing;
/// - a timestamp is unparseable;
/// - the assertion is outside its validity window.
#[cfg(feature = "saml")]
fn validate_saml_assertion_xml(
&self,
xml: &str,
token_type: &TokenType,
) -> Result<SecureJwtClaims> {
let validator = SamlSignatureValidator;
let certificate = validator.extract_embedded_certificate(xml).map_err(|e| {
AuthError::auth_method(
"token_exchange",
format!(
"SAML assertion is missing a usable embedded certificate: {}",
e
),
)
})?;
let signature_valid = validator
.validate_xml_signature(xml, &certificate)
.map_err(|e| {
AuthError::auth_method(
"token_exchange",
format!("SAML signature validation failed: {}", e),
)
})?;
if !signature_valid {
return Err(AuthError::auth_method(
"token_exchange",
"SAML signature validation failed",
));
}
let subject = self.extract_xml_text(xml, "NameID").ok_or_else(|| {
AuthError::auth_method("token_exchange", "SAML assertion is missing NameID")
})?;
let issuer = self.extract_xml_text(xml, "Issuer").ok_or_else(|| {
AuthError::auth_method("token_exchange", "SAML assertion is missing Issuer")
})?;
let audience = self.extract_xml_text(xml, "Audience");
let issue_instant = self.extract_saml_timestamp(xml, "IssueInstant")?;
let not_before = self.extract_saml_timestamp(xml, "NotBefore")?;
let not_on_or_after = self.extract_saml_timestamp(xml, "NotOnOrAfter")?;
let session_id = self.extract_xml_attribute(xml, "SessionIndex");
let now = chrono::Utc::now().timestamp();
// IssueInstant window: at most 30 s in the future (clock skew) and at most 300 s old.
if let Some(value) = issue_instant {
if value > now + 30 {
return Err(AuthError::auth_method(
"token_exchange",
"SAML assertion issue instant is in the future",
));
}
if now - value > 300 {
return Err(AuthError::auth_method(
"token_exchange",
"SAML assertion is too old",
));
}
}
// NotBefore / NotOnOrAfter are enforced with no skew allowance.
if let Some(value) = not_before
&& now < value
{
return Err(AuthError::auth_method(
"token_exchange",
"SAML assertion is not yet valid",
));
}
if let Some(value) = not_on_or_after
&& now >= value
{
return Err(AuthError::auth_method(
"token_exchange",
"SAML assertion has expired",
));
}
// Heuristic scope derivation from substrings anywhere in the XML.
// NOTE(review): "role" matches any text containing it, not only role attributes.
let mut scopes = Vec::new();
if xml.contains("emailaddress") {
scopes.push("email".to_string());
}
if xml.contains("identity/claims/name") || xml.contains("givenname") {
scopes.push("profile".to_string());
}
if xml.contains("claims/groups") || xml.contains("role") || xml.contains("Role") {
scopes.push("groups".to_string());
}
if scopes.is_empty() {
scopes.push("saml_authenticated".to_string());
}
// Missing SAML fields fall back to placeholders: audience "target_audience",
// expiry one hour from now, nbf/iat = now.
Ok(SecureJwtClaims {
iss: issuer,
sub: subject,
aud: audience.unwrap_or_else(|| "target_audience".to_string()),
exp: not_on_or_after.unwrap_or(now + 3600),
nbf: not_before.unwrap_or(now),
iat: issue_instant.unwrap_or(now),
jti: format!("saml_token_{}", uuid::Uuid::new_v4()),
scope: scopes.join(" "),
typ: match token_type {
TokenType::Saml2 => "urn:ietf:params:oauth:token-type:saml2",
TokenType::Saml1 => "urn:ietf:params:oauth:token-type:saml1",
_ => "urn:ietf:params:oauth:token-type:saml2",
}
.to_string(),
sid: session_id,
client_id: None,
auth_ctx_hash: Some(format!("saml_ctx_{}", uuid::Uuid::new_v4())),
})
}
/// Fallback used when the crate is built without the `saml` feature: every SAML
/// assertion is rejected because no SAML signature validator is compiled in.
#[cfg(not(feature = "saml"))]
fn validate_saml_assertion_xml(
    &self,
    _xml: &str,
    _token_type: &TokenType,
) -> Result<SecureJwtClaims> {
    let reason = "SAML validation requires the 'saml' feature to be enabled";
    Err(AuthError::auth_method("token_exchange", reason))
}
/// Decodes a SAML token to XML, validates the assertion, and logs the resulting claims.
async fn validate_saml_token(
    &self,
    token: &str,
    token_type: &TokenType,
) -> Result<SecureJwtClaims> {
    let claims = self
        .extract_saml_xml(token)
        .and_then(|xml| self.validate_saml_assertion_xml(&xml, token_type))?;
    tracing::info!(
        "SAML token validation completed - parsed subject: {}, issuer: {}, scopes: {}",
        claims.sub,
        claims.iss,
        claims.scope
    );
    Ok(claims)
}
/// Maps a `urn:ietf:params:oauth:token-type:*` URN to its [`TokenType`] variant.
///
/// # Errors
/// Returns an `AuthError` for any URN not in the table.
fn parse_token_type(&self, token_type: &str) -> Result<TokenType> {
    // URNs are distinct, so table order does not affect the result.
    const KNOWN_TYPES: [(&str, TokenType); 6] = [
        ("urn:ietf:params:oauth:token-type:access_token", TokenType::AccessToken),
        ("urn:ietf:params:oauth:token-type:refresh_token", TokenType::RefreshToken),
        ("urn:ietf:params:oauth:token-type:id_token", TokenType::IdToken),
        ("urn:ietf:params:oauth:token-type:saml2", TokenType::Saml2),
        ("urn:ietf:params:oauth:token-type:saml1", TokenType::Saml1),
        ("urn:ietf:params:oauth:token-type:jwt", TokenType::Jwt),
    ];
    KNOWN_TYPES
        .iter()
        .find(|(uri, _)| *uri == token_type)
        .map(|(_, kind)| kind.clone())
        .ok_or_else(|| AuthError::auth_method("token_exchange", "Unknown token type"))
}
fn determine_exchange_scenario(
&self,
context: &TokenExchangeContext,
_policy: &TokenExchangePolicy,
) -> Result<ExchangeScenario> {
if context.actor_claims.is_some() {
return Ok(ExchangeScenario::OnBehalfOf);
}
if let Some(audience) = context.audience.as_ref() {
if audience != &context.subject_claims.aud {
return Ok(ExchangeScenario::AudienceRestriction);
}
}
if let Some(requested_scope) = &context.scope {
let current_scope: Vec<&str> = context.subject_claims.scope.split(' ').collect();
if requested_scope.len() < current_scope.len() {
return Ok(ExchangeScenario::ScopeReduction);
}
}
Ok(ExchangeScenario::ActingAs)
}
/// Checks the scenario against the policy.
///
/// The scenario must be listed in `allowed_scenarios`; an empty list allows all. Two
/// scenario-specific rules are then applied:
/// - `OnBehalfOf` needs an actor token when the policy requires one.
/// - `AudienceRestriction` needs an allowed audience when the list is non-empty.
fn validate_exchange_scenario(
    &self,
    scenario: &ExchangeScenario,
    context: &TokenExchangeContext,
    policy: &TokenExchangePolicy,
) -> Result<()> {
    let scenario_permitted =
        policy.allowed_scenarios.is_empty() || policy.allowed_scenarios.contains(scenario);
    if !scenario_permitted {
        return Err(AuthError::auth_method(
            "token_exchange",
            "Exchange scenario is not permitted by policy",
        ));
    }
    let violation = match scenario {
        ExchangeScenario::OnBehalfOf
            if policy.require_actor_for_delegation && context.actor_claims.is_none() =>
        {
            Some("Actor token required for delegation")
        }
        ExchangeScenario::AudienceRestriction => match context.audience.as_ref() {
            Some(audience)
                if !policy.allowed_audiences.is_empty()
                    && !policy.allowed_audiences.contains(audience) =>
            {
                Some("Audience not allowed")
            }
            _ => None,
        },
        _ => None,
    };
    match violation {
        Some(message) => Err(AuthError::auth_method("token_exchange", message)),
        None => Ok(()),
    }
}
async fn generate_exchanged_token(
&self,
context: &TokenExchangeContext,
request: &TokenExchangeRequest,
policy: &TokenExchangePolicy,
) -> Result<TokenExchangeResponse> {
let now = Utc::now();
let expires_in = policy.max_token_lifetime.num_seconds();
let exp = now + policy.max_token_lifetime;
let mut new_claims = context.subject_claims.clone();
new_claims.exp = exp.timestamp();
new_claims.iat = now.timestamp();
new_claims.jti = uuid::Uuid::new_v4().to_string();
if let Some(ref audience) = request.audience {
new_claims.aud = audience.clone();
}
if let Some(ref requested_scope) = request.scope {
if let Some(mapped_scopes) = policy.scope_mapping.get(requested_scope) {
new_claims.scope = mapped_scopes.join(" ");
} else {
new_claims.scope = requested_scope.clone();
}
}
if let Some(ref actor_claims) = context.actor_claims {
new_claims.client_id = Some(actor_claims.sub.clone());
}
let access_token = encode(
&Header::new(Algorithm::HS256),
&new_claims,
&self.jwt_validator.get_encoding_key(),
)
.map_err(|e| AuthError::internal(format!("Failed to sign exchanged token: {}", e)))?;
let issued_token_type = request
.requested_token_type
.clone()
.unwrap_or_else(|| "urn:ietf:params:oauth:token-type:access_token".to_string());
Ok(TokenExchangeResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: Some(expires_in),
refresh_token: None, scope: Some(new_claims.scope),
issued_token_type: Some(issued_token_type),
})
}
}
impl Default for TokenExchangePolicy {
    /// Permissive default policy:
    /// - Subjects: access, refresh and ID tokens.
    /// - Actors: access and ID tokens.
    /// - Scenarios: every scenario except `TokenConversion`.
    /// - Lifetime: one hour, with an actor required for delegation.
    /// - No audience restriction and no scope mapping.
    fn default() -> Self {
        let allowed_subject_token_types = Vec::from([
            TokenType::AccessToken,
            TokenType::RefreshToken,
            TokenType::IdToken,
        ]);
        let allowed_actor_token_types = Vec::from([TokenType::AccessToken, TokenType::IdToken]);
        let allowed_scenarios = Vec::from([
            ExchangeScenario::ActingAs,
            ExchangeScenario::OnBehalfOf,
            ExchangeScenario::AudienceRestriction,
            ExchangeScenario::ScopeReduction,
        ]);
        Self {
            allowed_subject_token_types,
            allowed_actor_token_types,
            allowed_scenarios,
            max_token_lifetime: Duration::hours(1),
            require_actor_for_delegation: true,
            allowed_audiences: Vec::new(),
            scope_mapping: HashMap::new(),
        }
    }
}
#[async_trait]
impl TokenExchangeService for TokenExchangeManager {
// A request is the RFC 8693 payload paired with the id of the calling client.
type Request = (TokenExchangeRequest, String);
type Response = TokenExchangeResponse;
type Config = SecureJwtValidator;
async fn exchange_token(&self, request: Self::Request) -> Result<Self::Response> {
    // Unpack (request, client_id) and forward to the inherent two-argument
    // `exchange_token`, which method resolution prefers over this trait method.
    let (exchange_request, requesting_client) = request;
    self.exchange_token(exchange_request, requesting_client.as_str())
        .await
}
async fn validate_token(&self, token: &str, token_type: &str) -> Result<TokenValidationResult> {
let supported_types = self.supported_subject_token_types();
ValidationUtils::validate_token_type(token_type, &supported_types)?;
match self.parse_token_type(token_type)? {
TokenType::Jwt | TokenType::AccessToken | TokenType::IdToken => {
match self.jwt_validator.validate(token) {
Ok(claims) => {
use chrono::{TimeZone, Utc};
let expires_at = Utc.timestamp_opt(claims.exp, 0).single();
let audience = if claims.aud.is_empty() {
Vec::new()
} else {
vec![claims.aud.clone()]
};
let scopes = if claims.scope.is_empty() {
Vec::new()
} else {
claims
.scope
.split_whitespace()
.map(|s| s.to_string())
.collect()
};
let mut metadata = HashMap::new();
metadata.insert(
"sub".to_string(),
serde_json::Value::String(claims.sub.clone()),
);
metadata.insert(
"iss".to_string(),
serde_json::Value::String(claims.iss.clone()),
);
metadata.insert(
"aud".to_string(),
serde_json::Value::String(claims.aud.clone()),
);
metadata.insert(
"scope".to_string(),
serde_json::Value::String(claims.scope.clone()),
);
metadata.insert(
"typ".to_string(),
serde_json::Value::String(claims.typ.clone()),
);
if let Some(ref sid) = claims.sid {
metadata
.insert("sid".to_string(), serde_json::Value::String(sid.clone()));
}
if let Some(ref client_id) = claims.client_id {
metadata.insert(
"client_id".to_string(),
serde_json::Value::String(client_id.clone()),
);
}
Ok(TokenValidationResult {
is_valid: true,
subject: Some(claims.sub),
issuer: Some(claims.iss),
audience,
scopes,
expires_at,
metadata,
validation_messages: Vec::new(),
})
}
Err(e) => Ok(TokenValidationResult {
is_valid: false,
subject: None,
issuer: None,
audience: Vec::new(),
scopes: Vec::new(),
expires_at: None,
metadata: HashMap::new(),
validation_messages: vec![format!("JWT validation failed: {}", e)],
}),
}
}
TokenType::Saml2 | TokenType::Saml1 => {
match self
.validate_saml_token(token, &self.parse_token_type(token_type)?)
.await
{
Ok(claims) => {
use chrono::{TimeZone, Utc};
let expires_at = Utc.timestamp_opt(claims.exp, 0).single();
let audience = if claims.aud.is_empty() {
Vec::new()
} else {
vec![claims.aud.clone()]
};
let scopes = if claims.scope.is_empty() {
Vec::new()
} else {
claims
.scope
.split_whitespace()
.map(|scope| scope.to_string())
.collect()
};
let mut metadata = HashMap::new();
metadata.insert(
"sub".to_string(),
serde_json::Value::String(claims.sub.clone()),
);
metadata.insert(
"iss".to_string(),
serde_json::Value::String(claims.iss.clone()),
);
metadata.insert(
"aud".to_string(),
serde_json::Value::String(claims.aud.clone()),
);
metadata.insert(
"scope".to_string(),
serde_json::Value::String(claims.scope.clone()),
);
metadata.insert(
"typ".to_string(),
serde_json::Value::String(claims.typ.clone()),
);
if let Some(ref sid) = claims.sid {
metadata
.insert("sid".to_string(), serde_json::Value::String(sid.clone()));
}
Ok(TokenValidationResult {
is_valid: true,
subject: Some(claims.sub),
issuer: Some(claims.iss),
audience,
scopes,
expires_at,
metadata,
validation_messages: Vec::new(),
})
}
Err(error) => Ok(TokenValidationResult {
is_valid: false,
subject: None,
issuer: None,
audience: Vec::new(),
scopes: Vec::new(),
expires_at: None,
metadata: HashMap::new(),
validation_messages: vec![error.to_string()],
}),
}
}
_ => Err(AuthError::InvalidRequest(format!(
"Token validation not supported for type: {}",
token_type
))),
}
}
/// Returns owned copies of `SUBJECT_TOKEN_TYPES`, in declaration order.
fn supported_subject_token_types(&self) -> Vec<String> {
    Self::SUBJECT_TOKEN_TYPES
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
/// Returns owned copies of `REQUESTED_TOKEN_TYPES`, in declaration order.
fn supported_requested_token_types(&self) -> Vec<String> {
    Self::REQUESTED_TOKEN_TYPES
        .iter()
        .copied()
        .map(String::from)
        .collect()
}
/// Advertises this service as a basic, policy-controlled exchange; every advanced
/// capability flag is reported as unsupported.
fn capabilities(&self) -> TokenExchangeCapabilities {
    TokenExchangeCapabilities {
        complexity_level: ServiceComplexityLevel::Basic,
        basic_exchange: true,
        policy_control: true,
        max_delegation_depth: 3,
        multi_party_chains: false,
        context_preservation: false,
        audit_trail: false,
        session_integration: false,
        jwt_operations: false,
        cross_domain_exchange: false,
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::secure_jwt::SecureJwtConfig;

    /// Builds a manager backed by a default-configured JWT validator.
    fn create_test_manager() -> TokenExchangeManager {
        let jwt_config = SecureJwtConfig::default();
        let jwt_validator = SecureJwtValidator::new(jwt_config).expect("test JWT config");
        TokenExchangeManager::new(jwt_validator)
    }

    /// A well-formed exchange request whose subject token is not a real JWT.
    fn create_test_request() -> TokenExchangeRequest {
        TokenExchangeRequest {
            grant_type: "urn:ietf:params:oauth:grant-type:token-exchange".to_string(),
            subject_token: "dummy.jwt.token".to_string(),
            subject_token_type: "urn:ietf:params:oauth:token-type:access_token".to_string(),
            actor_token: None,
            actor_token_type: None,
            requested_token_type: Some(
                "urn:ietf:params:oauth:token-type:access_token".to_string(),
            ),
            audience: Some("api.example.com".to_string()),
            scope: Some("read write".to_string()),
            resource: None,
        }
    }

    #[tokio::test]
    async fn test_token_exchange_manager_creation() {
        let manager = create_test_manager();
        let policy = TokenExchangePolicy::default();
        manager
            .register_policy("test_client".to_string(), policy)
            .await;
        // Registration must actually store the policy under the client id.
        assert!(manager.policies.read().await.contains_key("test_client"));
    }

    #[test]
    fn test_token_type_parsing() {
        let manager = create_test_manager();
        assert_eq!(
            manager
                .parse_token_type("urn:ietf:params:oauth:token-type:access_token")
                .unwrap(),
            TokenType::AccessToken
        );
        assert_eq!(
            manager
                .parse_token_type("urn:ietf:params:oauth:token-type:id_token")
                .unwrap(),
            TokenType::IdToken
        );
        assert!(manager.parse_token_type("invalid_token_type").is_err());
    }

    #[test]
    fn test_exchange_scenario_determination() {
        let manager = create_test_manager();
        let policy = TokenExchangePolicy::default();
        let context = TokenExchangeContext {
            subject_claims: SecureJwtClaims {
                sub: "user123".to_string(),
                iss: "auth.example.com".to_string(),
                aud: "api.example.com".to_string(),
                exp: chrono::Utc::now().timestamp() + 3600,
                nbf: chrono::Utc::now().timestamp(),
                iat: chrono::Utc::now().timestamp(),
                jti: "token123".to_string(),
                scope: "read write".to_string(),
                typ: "access".to_string(),
                sid: None,
                client_id: None,
                auth_ctx_hash: None,
            },
            actor_claims: None,
            client_id: "test_client".to_string(),
            audience: Some("different.api.com".to_string()),
            scope: None,
            resource: None,
        };
        let scenario = manager
            .determine_exchange_scenario(&context, &policy)
            .unwrap();
        assert_eq!(scenario, ExchangeScenario::AudienceRestriction);
    }

    #[tokio::test]
    async fn test_invalid_grant_type() {
        let manager = create_test_manager();
        let policy = TokenExchangePolicy::default();
        manager
            .register_policy("test_client".to_string(), policy)
            .await;
        let mut request = create_test_request();
        request.grant_type = "invalid_grant_type".to_string();
        let result = manager.exchange_token(request, "test_client").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unregistered_client_rejected() {
        let manager = create_test_manager();
        let result = manager
            .exchange_token(create_test_request(), "unknown_client")
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unknown_subject_token_type_rejected() {
        let manager = create_test_manager();
        manager
            .register_policy("test_client".to_string(), TokenExchangePolicy::default())
            .await;
        let mut request = create_test_request();
        request.subject_token_type = "urn:example:not-a-token-type".to_string();
        let result = manager.exchange_token(request, "test_client").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_disallowed_subject_token_type_rejected() {
        let manager = create_test_manager();
        manager
            .register_policy("test_client".to_string(), TokenExchangePolicy::jwt_only())
            .await;
        let mut request = create_test_request();
        // jwt_only does not allow SAML subjects, so this fails before any SAML parsing.
        request.subject_token_type = "urn:ietf:params:oauth:token-type:saml2".to_string();
        let result = manager.exchange_token(request, "test_client").await;
        assert!(result.is_err());
    }

    #[test]
    fn test_token_exchange_policy_builder() {
        let policy = TokenExchangePolicy::builder()
            .max_token_lifetime(Duration::minutes(30))
            .require_actor_for_delegation(false)
            .audience("api.example.com")
            .audience("internal.example.com")
            .scope_map("admin", vec!["read".into(), "write".into()])
            .build();
        assert_eq!(policy.max_token_lifetime, Duration::minutes(30));
        assert!(!policy.require_actor_for_delegation);
        assert_eq!(policy.allowed_audiences.len(), 2);
        assert_eq!(policy.scope_mapping.get("admin").unwrap().len(), 2);
        assert!(!policy.allowed_subject_token_types.is_empty());
    }

    #[test]
    fn test_token_exchange_policy_jwt_only_preset() {
        let policy = TokenExchangePolicy::jwt_only();
        assert_eq!(policy.allowed_subject_token_types.len(), 2);
        assert!(policy.allowed_subject_token_types.contains(&TokenType::Jwt));
        assert!(policy.allowed_subject_token_types.contains(&TokenType::AccessToken));
        assert!(policy.require_actor_for_delegation);
    }

    #[tokio::test]
    async fn test_token_exchange_policy_builder_register() {
        let manager = create_test_manager();
        let policy = TokenExchangePolicy::builder()
            .scenarios(vec![
                ExchangeScenario::OnBehalfOf,
                ExchangeScenario::AudienceRestriction,
            ])
            .max_token_lifetime(Duration::minutes(15))
            .build();
        manager
            .register_policy("test_client".to_string(), policy)
            .await;
        // The stored policy must be exactly the one built above.
        let policies = manager.policies.read().await;
        let stored = policies.get("test_client").expect("policy registered");
        assert_eq!(stored.max_token_lifetime, Duration::minutes(15));
        assert_eq!(
            stored.allowed_scenarios,
            vec![
                ExchangeScenario::OnBehalfOf,
                ExchangeScenario::AudienceRestriction
            ]
        );
    }
}