use std::convert::Infallible;
use std::rc::Rc;
use std::time::Duration;
use cache_lib::builder::{CacheBuilder, CacheBuilderInstance};
use cache_lib::Cache;
use pdk_core::classy::extract::context::ConfigureContext;
use pdk_core::classy::extract::{Extract, FromContext};
use pdk_core::classy::hl::{HttpClient, Service};
use pdk_core::classy::{Clock, TimeUnit};
use pdk_core::log::{debug, warn};
use pdk_core::policy_context::api::Metadata;
use thiserror::Error;
use crate::error::{IntrospectionError, ValidationError};
use crate::scopes_validator::ScopesValidator;
use crate::{ExpirableToken, FixedTimeFrame, Object, OneTimeUseToken, ParsedToken};
/// Default timeout, in milliseconds, applied to introspection requests.
const DEFAULT_TIMEOUT_MS: u64 = 10000;
/// Default maximum number of entries in each validator's token cache.
const DEFAULT_MAX_CACHE_ENTRIES: usize = 1000;
/// Form field that carries the token in the introspection request body.
const TOKEN_FORM_PARAM: &str = "token";
/// Introspection response field holding the token's active flag.
const ACTIVE_FIELD: &str = "active";
/// Introspection response field holding the token's granted scopes.
const SCOPE_FIELD: &str = "scope";
/// Outcome of a successful token validation.
///
/// `Debug` is implemented by hand rather than derived so that formatting a
/// result with `{:?}` (e.g. in a log line) never prints the bearer token.
#[derive(Clone)]
pub struct IntrospectionResult {
    /// The parsed token, either freshly introspected or restored from the cache.
    pub token: ParsedToken,
    /// The access token that was validated.
    pub access_token: String,
}

impl std::fmt::Debug for IntrospectionResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("IntrospectionResult")
            .field("token", &self.token)
            // The access token is a live credential; never render it.
            .field("access_token", &"<redacted>")
            .finish()
    }
}
impl IntrospectionResult {
pub fn properties(&self) -> &Object {
self.token.properties()
}
pub fn client_id(&self) -> Option<String> {
self.token.client_id()
}
pub fn username(&self) -> Option<String> {
self.token.username()
}
pub fn raw_token_context(&self) -> &str {
self.token.raw_token_context()
}
pub fn scopes(&self) -> &[String] {
self.token.scopes()
}
}
#[non_exhaustive]
#[derive(Error, Debug)]
pub enum TokenValidatorBuildError {
#[error("Service is required but not provided. Call with_service() before build().")]
MissingService,
}
/// Factory for [`TokenValidatorBuilderInstance`]s, obtained from the policy's
/// configure context.
///
/// Holds the shared HTTP client, clock and cache builder, plus a cache-name
/// prefix derived from the policy's name and namespace.
pub struct TokenValidatorBuilder {
    /// Client shared (via `Rc`) by every validator built from this factory.
    http_client: Rc<HttpClient>,
    /// Time source shared by every validator built from this factory.
    clock: Rc<dyn Clock>,
    /// Used to create one named cache per validator.
    cache_builder: CacheBuilder,
    /// `token-validator-{policy_name}-{policy_namespace}`; joined with the id
    /// passed to `new` to name each validator's cache.
    prefix: String,
}
impl FromContext<ConfigureContext> for TokenValidatorBuilder {
    type Error = Infallible;

    /// Assembles the builder from the configure context.
    ///
    /// Extracts the HTTP client, clock, cache builder and policy metadata (in
    /// that order). The metadata is only used to derive the cache-name prefix
    /// `token-validator-{policy_name}-{policy_namespace}`.
    fn from_context(context: &ConfigureContext) -> Result<Self, Self::Error> {
        let client: HttpClient = context.extract()?;
        let clock: Rc<dyn Clock> = context.extract()?;
        let cache_builder: CacheBuilder = context.extract()?;
        let metadata: Metadata = context.extract()?;

        let policy = &metadata.policy_metadata;
        let prefix = format!(
            "token-validator-{}-{}",
            policy.policy_name, policy.policy_namespace
        );

        Ok(Self {
            http_client: Rc::new(client),
            clock,
            cache_builder,
            prefix,
        })
    }
}
impl TokenValidatorBuilder {
    /// Starts configuring a validator whose cache is named `{prefix}-{id}`.
    ///
    /// The returned instance shares this factory's HTTP client and clock and
    /// starts from [`TokenValidatorConfig::default`] with no service, no scope
    /// validator and the default cache size.
    // Named `new` even though it returns a different type: it is the entry
    // point that creates per-validator builders.
    #[allow(clippy::new_ret_no_self)]
    pub fn new(&self, id: impl Into<String>) -> TokenValidatorBuilderInstance {
        let cache_name = format!("{}-{}", self.prefix, id.into());
        let cache_builder = self.cache_builder.new(cache_name);

        TokenValidatorBuilderInstance {
            http_client: Rc::clone(&self.http_client),
            clock: Rc::clone(&self.clock),
            cache_builder,
            config: TokenValidatorConfig::default(),
            scopes_validator: None,
            service: None,
            max_cache_entries: DEFAULT_MAX_CACHE_ENTRIES,
        }
    }
}
/// Per-validator builder returned by `TokenValidatorBuilder::new`.
///
/// Collects introspection settings through its `with_*` methods, then
/// produces a [`TokenValidator`] via `build`.
pub struct TokenValidatorBuilderInstance {
    http_client: Rc<HttpClient>,
    clock: Rc<dyn Clock>,
    /// Cache builder already bound to this validator's cache name.
    cache_builder: CacheBuilderInstance,
    /// Introspection settings, starting from `TokenValidatorConfig::default()`.
    config: TokenValidatorConfig,
    /// Optional scope check applied to each token before it is accepted.
    scopes_validator: Option<ScopesValidator>,
    /// Introspection endpoint; `build` fails if it is still `None`.
    service: Option<Service>,
    /// Maximum number of entries in this validator's cache.
    max_cache_entries: usize,
}
impl TokenValidatorBuilderInstance {
pub fn with_path(mut self, path: impl Into<String>) -> Self {
self.config.path = path.into();
self
}
pub fn with_authorization_value(mut self, value: impl Into<String>) -> Self {
self.config.authorization_value = value.into();
self
}
pub fn with_expires_in_attribute(mut self, attr: impl Into<String>) -> Self {
self.config.expires_in_attribute = attr.into();
self
}
pub fn with_max_token_ttl(mut self, ttl: i64) -> Self {
self.config.max_token_ttl = ttl;
self
}
pub fn with_timeout_ms(mut self, timeout: u64) -> Self {
self.config.timeout_ms = timeout;
self
}
pub fn with_service(mut self, service: Service) -> Self {
self.service = Some(service);
self
}
pub fn with_scopes_validator(mut self, validator: ScopesValidator) -> Self {
self.scopes_validator = Some(validator);
self
}
pub fn with_max_cache_entries(mut self, max_entries: usize) -> Self {
self.max_cache_entries = max_entries;
self
}
pub fn build(self) -> Result<TokenValidator, TokenValidatorBuildError> {
let service = self
.service
.ok_or(TokenValidatorBuildError::MissingService)?;
let cache = self
.cache_builder
.max_entries(self.max_cache_entries)
.build();
Ok(TokenValidator {
config: self.config,
scopes_validator: self.scopes_validator,
http_client: self.http_client,
clock: self.clock,
cache: Box::new(cache),
service,
})
}
}
/// Settings used when calling the introspection endpoint and interpreting
/// its response.
#[derive(Clone)]
pub struct TokenValidatorConfig {
    /// Request path of the introspection endpoint on the configured service.
    pub path: String,
    /// Value sent as the `Authorization` header of introspection requests.
    pub authorization_value: String,
    /// Name of the response field read as the token's expiry time, in seconds;
    /// it is compared against the clock's current time in milliseconds.
    pub expires_in_attribute: String,
    /// Upper bound, in seconds, on a token's validity window; negative means
    /// no cap.
    pub max_token_ttl: i64,
    /// Introspection request timeout, in milliseconds.
    pub timeout_ms: u64,
}
impl Default for TokenValidatorConfig {
fn default() -> Self {
Self {
path: "/".to_string(),
authorization_value: String::new(),
expires_in_attribute: "exp".to_string(),
max_token_ttl: -1,
timeout_ms: DEFAULT_TIMEOUT_MS,
}
}
}
/// Validates access tokens by calling a token-introspection endpoint and
/// caching accepted tokens. Created through `TokenValidatorBuilderInstance::build`.
pub struct TokenValidator {
    /// Endpoint path, credentials, expiry handling and timeout.
    config: TokenValidatorConfig,
    /// Optional scope check applied to fresh and cached tokens.
    scopes_validator: Option<ScopesValidator>,
    /// Client used to reach the introspection endpoint.
    http_client: Rc<HttpClient>,
    /// Source of the current time for expiry checks.
    clock: Rc<dyn Clock>,
    /// Serialized tokens keyed by the raw access token.
    cache: Box<dyn Cache>,
    /// Service hosting the introspection endpoint.
    service: Service,
}
impl TokenValidator {
pub async fn validate(
&self,
access_token: &str,
) -> Result<IntrospectionResult, IntrospectionError> {
let current_time_ms = self.current_time_ms();
if let Some(result) = self.retrieve_cached_token(access_token, current_time_ms)? {
debug!("Token found in cache and valid");
return Ok(result);
}
debug!("Token not in cache, calling introspection endpoint");
let (status, body) = self.call_introspection(access_token).await?;
if status != 200 {
return Err(IntrospectionError::HttpError { status, body });
}
let parsed_token = self.parse_response(&body, current_time_ms)?;
self.validate_expiration(&parsed_token, current_time_ms)?;
self.validate_scopes(&parsed_token)?;
self.cache_token(access_token, &parsed_token);
Ok(IntrospectionResult {
token: parsed_token,
access_token: access_token.to_string(),
})
}
fn current_time_ms(&self) -> i64 {
self.clock.get_current_time_unit(TimeUnit::Milliseconds) as i64
}
fn retrieve_cached_token(
&self,
access_token: &str,
current_time_ms: i64,
) -> Result<Option<IntrospectionResult>, IntrospectionError> {
let cached_data = match self.cache.get(access_token) {
Some(data) => data,
None => return Ok(None),
};
let parsed_token = match ParsedToken::from_binary(cached_data) {
Ok(token) => token,
Err(e) => {
warn!("Failed to deserialize cached token: {e:?}");
return Ok(None);
}
};
if parsed_token.has_expired(current_time_ms) {
debug!("Cached token expired");
return Ok(None);
}
if let Err(e) = self.validate_scopes(&parsed_token) {
debug!("Cached token has invalid scopes: {e:?}");
self.cache.delete(access_token);
return Ok(None);
}
Ok(Some(IntrospectionResult {
token: parsed_token,
access_token: access_token.to_string(),
}))
}
fn cache_token(&self, access_token: &str, token: &ParsedToken) {
match token.to_binary() {
Ok(data) => {
if let Err(e) = self.cache.save(access_token, data) {
warn!("Failed to cache token: {e:?}");
}
}
Err(e) => {
warn!("Failed to serialize token for caching: {e:?}");
}
}
}
async fn call_introspection(&self, token: &str) -> Result<(u32, String), IntrospectionError> {
let body = serde_urlencoded::to_string([(TOKEN_FORM_PARAM, token)])
.unwrap_or_else(|_| format!("{TOKEN_FORM_PARAM}={token}"));
let headers = vec![
("Content-Type", "application/x-www-form-urlencoded"),
("Authorization", self.config.authorization_value.as_str()),
];
let timeout = Duration::from_millis(self.config.timeout_ms);
let response = self
.http_client
.request(&self.service)
.path(&self.config.path)
.headers(headers)
.body(body.as_bytes())
.timeout(timeout)
.post()
.await
.map_err(|e| IntrospectionError::RequestFailed(format!("{e:?}")))?;
let status = response.status_code();
let response_body = String::from_utf8_lossy(response.body()).to_string();
Ok((status, response_body))
}
fn parse_response(
&self,
body: &str,
current_time_ms: i64,
) -> Result<ParsedToken, IntrospectionError> {
let json: serde_json::Value = serde_json::from_str(body)
.map_err(|e| IntrospectionError::ParseError(e.to_string()))?;
let obj = json
.as_object()
.ok_or_else(|| IntrospectionError::ParseError("Response is not an object".to_string()))?
.clone();
let is_active = obj
.get(ACTIVE_FIELD)
.and_then(|v| v.as_bool())
.unwrap_or(true);
if !is_active {
return Err(IntrospectionError::Validation(
ValidationError::TokenRevoked,
));
}
let scopes = Self::extract_scopes(&obj);
if let Some(exp) = obj.get(&self.config.expires_in_attribute) {
if let Some(exp_secs) = exp.as_i64() {
let expiration_ms = self.calculate_expiration(current_time_ms, exp_secs);
return Ok(ParsedToken::ExpirableToken(ExpirableToken::new(
body.to_string(),
obj,
FixedTimeFrame::new(current_time_ms, expiration_ms),
scopes,
)));
}
}
Ok(ParsedToken::OneTimeUseToken(OneTimeUseToken::new(
body.to_string(),
obj,
scopes,
)))
}
fn extract_scopes(obj: &Object) -> Vec<String> {
match obj.get(SCOPE_FIELD) {
Some(value) => {
if let Some(scope_str) = value.as_str() {
if scope_str.is_empty() {
vec![]
} else {
scope_str.split_whitespace().map(String::from).collect()
}
} else if let Some(scope_arr) = value.as_array() {
scope_arr
.iter()
.filter_map(|v| v.as_str())
.flat_map(|s| s.split_whitespace().map(String::from))
.collect()
} else {
vec![]
}
}
None => vec![],
}
}
fn calculate_expiration(&self, start_time_ms: i64, exp_timestamp_secs: i64) -> i64 {
let exp_timestamp_ms = exp_timestamp_secs * 1000;
if exp_timestamp_ms <= start_time_ms {
return 0;
}
let expiration_ms = exp_timestamp_ms - start_time_ms;
if self.config.max_token_ttl < 0 || self.config.max_token_ttl * 1000 > expiration_ms {
expiration_ms
} else {
self.config.max_token_ttl * 1000
}
}
fn validate_expiration(
&self,
token: &ParsedToken,
current_time_ms: i64,
) -> Result<(), IntrospectionError> {
if let ParsedToken::ExpirableToken(_) = token {
if token.has_expired(current_time_ms) {
return Err(IntrospectionError::Validation(
ValidationError::TokenExpired,
));
}
}
Ok(())
}
fn validate_scopes(&self, token: &ParsedToken) -> Result<(), IntrospectionError> {
if let Some(validator) = &self.scopes_validator {
if !validator.valid_scopes(token.scopes()) {
return Err(IntrospectionError::Validation(
ValidationError::InvalidScopes,
));
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    fn create_config() -> TokenValidatorConfig {
        TokenValidatorConfig::default()
    }

    /// Runs `extract_scopes` on an object whose only field is `scope = value`.
    fn scopes_of(value: serde_json::Value) -> Vec<String> {
        let mut obj = Object::new();
        obj.insert(SCOPE_FIELD.to_string(), value);
        TokenValidator::extract_scopes(&obj)
    }

    #[test]
    fn config_has_correct_defaults() {
        let config = create_config();
        assert_eq!(config.path, "/");
        assert_eq!(config.authorization_value, "");
        assert_eq!(config.expires_in_attribute, "exp");
        assert_eq!(config.max_token_ttl, -1);
        assert_eq!(config.timeout_ms, 10000);
    }

    // Exercises struct-update syntax on the config; the builder instance
    // cannot be constructed without a configure context.
    #[test]
    fn config_struct_update_overrides_path() {
        let config = TokenValidatorConfig {
            path: "/introspect".to_string(),
            ..Default::default()
        };
        assert_eq!(config.path, "/introspect");
        assert_eq!(config.expires_in_attribute, "exp");
    }

    #[test]
    fn build_error_displays_correctly() {
        let err = TokenValidatorBuildError::MissingService;
        assert!(err.to_string().contains("Service is required"));
    }

    #[test]
    fn extract_scopes_from_string() {
        assert_eq!(
            scopes_of(serde_json::json!("read write admin")),
            vec!["read", "write", "admin"]
        );
    }

    #[test]
    fn extract_scopes_from_array() {
        assert_eq!(
            scopes_of(serde_json::json!(["read", "write"])),
            vec!["read", "write"]
        );
    }

    #[test]
    fn extract_scopes_splits_space_delimited_array_entries() {
        assert_eq!(
            scopes_of(serde_json::json!(["read write", "admin"])),
            vec!["read", "write", "admin"]
        );
    }

    #[test]
    fn extract_scopes_empty_when_missing() {
        let obj = Object::new();
        let scopes = TokenValidator::extract_scopes(&obj);
        assert!(scopes.is_empty());
    }

    #[test]
    fn extract_scopes_handles_empty_string() {
        assert!(scopes_of(serde_json::json!("")).is_empty());
    }

    #[test]
    fn extract_scopes_collapses_extra_whitespace() {
        assert_eq!(
            scopes_of(serde_json::json!("  read \t write  ")),
            vec!["read", "write"]
        );
        assert!(scopes_of(serde_json::json!("   ")).is_empty());
    }

    #[test]
    fn extract_scopes_ignores_non_string_values() {
        assert!(scopes_of(serde_json::json!(42)).is_empty());
        assert!(scopes_of(serde_json::json!(null)).is_empty());
        assert_eq!(
            scopes_of(serde_json::json!(["read", 7, null])),
            vec!["read"]
        );
    }
}