use std::sync::Arc;
use reqwest::header::AUTHORIZATION;
use serde::Deserialize;
use url::Url;
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as B64_STANDARD;
use crate::delegated::error::DelegatedError;
use axess_factors::ZeroizedString;
/// Token type identifier URIs used by OAuth 2.0 token exchange (RFC 8693 §3).
///
/// Pass these as `subject_token_type`, `actor_token_type` or
/// `requested_token_type`, or compare them against `issued_token_type`.
pub mod token_types {
    // --- Tokens issued by OAuth 2.0 / OpenID Connect servers ---

    /// An OAuth 2.0 access token.
    pub const ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
    /// An OAuth 2.0 refresh token.
    pub const REFRESH_TOKEN: &str = "urn:ietf:params:oauth:token-type:refresh_token";
    /// An OpenID Connect ID token.
    pub const ID_TOKEN: &str = "urn:ietf:params:oauth:token-type:id_token";

    // --- Assertion formats ---

    /// A JSON Web Token (RFC 7519).
    pub const JWT: &str = "urn:ietf:params:oauth:token-type:jwt";
    /// A base64url-encoded SAML 2.0 assertion.
    pub const SAML2: &str = "urn:ietf:params:oauth:token-type:saml2";
}

/// `grant_type` value that selects token exchange at the token endpoint (RFC 8693 §2.1).
const GRANT_TYPE_TOKEN_EXCHANGE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Parameters of an RFC 8693 token-exchange request.
///
/// Usually built with `TokenExchangeRequest::new` plus the `with_*` builder
/// methods. Fields are public so they can also be set directly.
///
/// `Debug` is implemented by hand, not derived, so that `subject_token` and
/// `actor_token` (bearer credentials) are never written to logs or error
/// reports. All other fields are printed as usual.
#[derive(Clone)]
pub struct TokenExchangeRequest {
    /// Token representing the party on whose behalf the request is made.
    /// Sent as `subject_token`.
    pub subject_token: String,
    /// Type identifier of `subject_token` (e.g. a `token_types` constant).
    pub subject_token_type: String,
    /// Optional token representing the acting party. Sent as `actor_token`.
    pub actor_token: Option<String>,
    /// Type identifier of `actor_token`. Should be set exactly when
    /// `actor_token` is set.
    pub actor_token_type: Option<String>,
    /// Optional logical name of the target service. Sent as `audience`.
    pub audience: Option<String>,
    /// Optional URI of the target resource. Sent as `resource`.
    pub resource: Option<Url>,
    /// Requested scopes. Sent space-delimited as `scope` when non-empty.
    pub scopes: Vec<String>,
    /// Optional type identifier of the token the caller wants issued.
    pub requested_token_type: Option<String>,
}

impl std::fmt::Debug for TokenExchangeRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the credential fields are masked; the presence of an actor
        // token is still visible (Some/None) to aid debugging.
        const REDACTED: &str = "[REDACTED]";
        f.debug_struct("TokenExchangeRequest")
            .field("subject_token", &REDACTED)
            .field("subject_token_type", &self.subject_token_type)
            .field("actor_token", &self.actor_token.as_ref().map(|_| REDACTED))
            .field("actor_token_type", &self.actor_token_type)
            .field("audience", &self.audience)
            .field("resource", &self.resource)
            .field("scopes", &self.scopes)
            .field("requested_token_type", &self.requested_token_type)
            .finish()
    }
}
impl TokenExchangeRequest {
    /// Creates a request carrying only the mandatory subject token and its type.
    ///
    /// All optional parameters start unset and `scopes` starts empty.
    #[must_use]
    pub fn new(subject_token: impl Into<String>, subject_token_type: impl Into<String>) -> Self {
        Self {
            subject_token: subject_token.into(),
            subject_token_type: subject_token_type.into(),
            actor_token: None,
            actor_token_type: None,
            audience: None,
            resource: None,
            scopes: Vec::new(),
            requested_token_type: None,
        }
    }

    /// Sets the actor token together with its type identifier.
    ///
    /// Both are set at once so that the token and its type cannot get out of sync.
    #[must_use]
    pub fn with_actor_token(
        mut self,
        actor_token: impl Into<String>,
        actor_token_type: impl Into<String>,
    ) -> Self {
        self.actor_token = Some(actor_token.into());
        self.actor_token_type = Some(actor_token_type.into());
        self
    }

    /// Sets the `audience` parameter, replacing any previous value.
    #[must_use]
    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }

    /// Sets the `resource` parameter, replacing any previous value.
    #[must_use]
    pub fn with_resource(mut self, resource: Url) -> Self {
        self.resource = Some(resource);
        self
    }

    /// Replaces the requested scopes with the given ones.
    ///
    /// Scopes set by an earlier call are discarded, not appended to.
    #[must_use]
    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.scopes = scopes.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the `requested_token_type` parameter, replacing any previous value.
    #[must_use]
    pub fn with_requested_token_type(mut self, t: impl Into<String>) -> Self {
        self.requested_token_type = Some(t.into());
        self
    }
}
/// Successful token-exchange result, normalized from the token endpoint's
/// JSON response by `TokenExchangeClient::exchange`.
///
/// NOTE(review): `Debug` is derived. Whether it hides `access_token` and
/// `refresh_token` depends on `ZeroizedString`'s own `Debug` impl — confirm.
#[derive(Debug)]
pub struct TokenExchangeResponse {
    /// The issued token (never empty: `exchange` rejects an empty one).
    pub access_token: ZeroizedString,
    /// `issued_token_type` from the response. `exchange` substitutes
    /// `token_types::ACCESS_TOKEN` when the server omits it.
    pub issued_token_type: String,
    /// `token_type` from the response. `exchange` substitutes `"Bearer"`
    /// when the server omits it.
    pub token_type: String,
    /// `expires_in` exactly as returned by the server, if present
    /// (presumably seconds per RFC 6749 §5.1; not validated here).
    pub expires_in: Option<u64>,
    /// The response's `scope` value split on whitespace. Empty when the
    /// server omitted `scope`.
    pub scopes: Vec<String>,
    /// Refresh token, if the server returned one.
    pub refresh_token: Option<ZeroizedString>,
}
/// Client for an OAuth 2.0 token endpoint that performs RFC 8693 token
/// exchange via `TokenExchangeClient::exchange`.
///
/// Cloning shares the client secret through an `Arc` rather than copying it.
///
/// NOTE(review): `Debug` is derived and includes `client_secret`. Whether
/// the secret is hidden depends on `ZeroizedString`'s `Debug` impl — confirm.
#[derive(Debug, Clone)]
pub struct TokenExchangeClient {
    /// Token endpoint URL that exchange requests are POSTed to.
    token_endpoint: Url,
    /// Sent as the `client_id` form parameter. Also used as the HTTP Basic
    /// username when a secret is configured.
    client_id: String,
    /// When present, used as the HTTP Basic password for client authentication.
    client_secret: Option<Arc<ZeroizedString>>,
    /// Underlying HTTP client (replaceable via `with_http_client`).
    http: reqwest::Client,
}
impl TokenExchangeClient {
    /// Creates a client for `token_endpoint`, identified as `client_id`.
    ///
    /// When `client_secret` is `Some`, every request authenticates with HTTP
    /// Basic (`client_id:client_secret`). Otherwise the client is identified
    /// only by the `client_id` form parameter.
    ///
    /// A default `reqwest::Client` is created. Use
    /// [`TokenExchangeClient::with_http_client`] to supply one configured with
    /// timeouts, proxies or custom TLS.
    ///
    /// # Panics
    ///
    /// Per reqwest's documentation, `reqwest::Client::new` panics if a TLS
    /// backend cannot be initialized or the resolver cannot load the system
    /// configuration.
    #[must_use]
    pub fn new(
        token_endpoint: Url,
        client_id: impl Into<String>,
        client_secret: Option<ZeroizedString>,
    ) -> Self {
        Self {
            token_endpoint,
            client_id: client_id.into(),
            client_secret: client_secret.map(Arc::new),
            http: reqwest::Client::new(),
        }
    }

    /// Replaces the HTTP client used for token-endpoint requests.
    #[must_use]
    pub fn with_http_client(mut self, http: reqwest::Client) -> Self {
        self.http = http;
        self
    }

    /// Performs an RFC 8693 token exchange for `request`.
    ///
    /// POSTs a form-encoded token-exchange grant to the token endpoint,
    /// authenticating with HTTP Basic when a client secret is configured.
    /// It then normalizes the JSON response. A missing `issued_token_type`
    /// defaults to `token_types::ACCESS_TOKEN` and a missing `token_type`
    /// defaults to `"Bearer"`.
    ///
    /// # Errors
    ///
    /// - [`DelegatedError::Transport`] if the request could not be sent or no
    ///   response was received.
    /// - [`DelegatedError::TokenEndpoint`] for a non-2xx status. It carries
    ///   the status code and, best effort, the response body.
    /// - [`DelegatedError::MalformedResponse`] if a 2xx body is not the
    ///   expected JSON or its `access_token` is empty.
    pub async fn exchange(
        &self,
        request: &TokenExchangeRequest,
    ) -> Result<TokenExchangeResponse, DelegatedError> {
        let form = self.build_form(request);

        let mut req = self.http.post(self.token_endpoint.clone());
        if let Some(auth) = self.basic_auth_header() {
            req = req.header(AUTHORIZATION, auth);
        }

        let response = req
            .form(&form)
            .send()
            .await
            .map_err(|e| DelegatedError::Transport(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable error body still yields a
            // TokenEndpoint error (with an empty body) rather than masking
            // the status.
            let body = response.text().await.unwrap_or_default();
            return Err(DelegatedError::TokenEndpoint {
                status: status.as_u16(),
                body,
            });
        }

        let parsed: TokenExchangeResponseBody = response
            .json()
            .await
            .map_err(|e| DelegatedError::MalformedResponse(e.to_string()))?;
        Self::convert_response(parsed)
    }

    /// Builds the `application/x-www-form-urlencoded` parameters for `request`.
    fn build_form(&self, request: &TokenExchangeRequest) -> Vec<(&'static str, String)> {
        let mut form = vec![
            ("grant_type", GRANT_TYPE_TOKEN_EXCHANGE.to_string()),
            ("subject_token", request.subject_token.clone()),
            ("subject_token_type", request.subject_token_type.clone()),
            ("client_id", self.client_id.clone()),
        ];
        if let Some(actor_token) = &request.actor_token {
            form.push(("actor_token", actor_token.clone()));
            // RFC 8693 §2.1: actor_token_type MUST NOT be sent unless
            // actor_token is present, so it is only emitted inside this branch.
            if let Some(actor_token_type) = &request.actor_token_type {
                form.push(("actor_token_type", actor_token_type.clone()));
            }
        }
        if let Some(audience) = &request.audience {
            form.push(("audience", audience.clone()));
        }
        if let Some(resource) = &request.resource {
            form.push(("resource", resource.to_string()));
        }
        // Empty entries would produce empty scope tokens ("a  b"), which
        // RFC 6749 §3.3 does not allow, so they are skipped.
        let scope = request
            .scopes
            .iter()
            .map(String::as_str)
            .filter(|s| !s.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        if !scope.is_empty() {
            form.push(("scope", scope));
        }
        if let Some(t) = &request.requested_token_type {
            form.push(("requested_token_type", t.clone()));
        }
        form
    }

    /// Builds the HTTP Basic `Authorization` value, or `None` when no client
    /// secret is configured.
    ///
    /// NOTE(review): RFC 6749 §2.3.1 says the client id and secret should be
    /// form-urlencoded before Basic encoding. They are sent raw here, as
    /// before. Credentials containing `:` or non-ASCII characters may need
    /// encoding — confirm against the authorization server.
    fn basic_auth_header(&self) -> Option<reqwest::header::HeaderValue> {
        let secret = self.client_secret.as_ref()?;
        // Every intermediate holding the secret (the raw pair and its
        // reversible base64 form) is kept in `ZeroizedString`, like the other
        // secrets in this module, instead of in a plain `String`.
        let creds = ZeroizedString::from(format!("{}:{}", self.client_id, &***secret));
        let encoded = ZeroizedString::from(B64_STANDARD.encode(creds.as_bytes()));
        let header = ZeroizedString::from(format!("Basic {}", &*encoded));
        let mut value = reqwest::header::HeaderValue::from_str(&header)
            .expect("\"Basic \" followed by standard base64 is always a valid header value");
        // Marks the credential as sensitive: `HeaderValue`'s Debug output
        // hides it and the HTTP/2 encoder does not index it.
        value.set_sensitive(true);
        Some(value)
    }

    /// Validates the decoded JSON body and converts it into the public response type.
    fn convert_response(
        parsed: TokenExchangeResponseBody,
    ) -> Result<TokenExchangeResponse, DelegatedError> {
        if parsed.access_token.is_empty() {
            return Err(DelegatedError::MalformedResponse(
                "access_token field is empty".into(),
            ));
        }
        let scopes: Vec<String> = parsed
            .scope
            .map(|s| s.split_whitespace().map(str::to_string).collect())
            .unwrap_or_default();
        Ok(TokenExchangeResponse {
            access_token: ZeroizedString::from(parsed.access_token),
            issued_token_type: parsed
                .issued_token_type
                .unwrap_or_else(|| token_types::ACCESS_TOKEN.to_string()),
            token_type: parsed.token_type.unwrap_or_else(|| "Bearer".to_string()),
            expires_in: parsed.expires_in,
            scopes,
            refresh_token: parsed.refresh_token.map(ZeroizedString::from),
        })
    }
}
/// Raw JSON body of a successful (2xx) token-endpoint response.
///
/// `access_token` is the only field required for deserialization to succeed.
/// Every other field is `None` when absent. Unknown fields are ignored, since
/// there is no `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
struct TokenExchangeResponseBody {
    // Required. An empty string is rejected later by `exchange`.
    access_token: String,
    #[serde(default)]
    issued_token_type: Option<String>,
    #[serde(default)]
    token_type: Option<String>,
    #[serde(default)]
    expires_in: Option<u64>,
    // Space-delimited scope list, as returned by the server.
    #[serde(default)]
    scope: Option<String>,
    #[serde(default)]
    refresh_token: Option<String>,
}
// Unit tests live in a separate `tests` module file.
#[cfg(test)]
mod tests;