use std::time::{Duration, SystemTime};
/// Settings that govern proactive refresh of time-limited credentials.
///
/// Build one from [`TokenRefreshConfig::default`] and adjust it with the
/// `set_*` builder methods, each of which consumes and returns the config.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TokenRefreshConfig {
    /// Fraction of a token's lifetime (from receipt to expiry) that is multiplied
    /// into that lifetime to obtain the refresh threshold. Defaults to `0.8`.
    ///
    /// NOTE(review): any `f64` is accepted here, including values outside
    /// `0.0..=1.0`; confirm whether such values should be rejected.
    pub expiration_refresh_ratio: f64,
    /// Retry policy associated with token refresh. Defaults to
    /// [`RetryConfig::default`].
    pub retry_config: RetryConfig,
}

impl TokenRefreshConfig {
    /// Returns this config with
    /// [`expiration_refresh_ratio`](Self::expiration_refresh_ratio) set to `ratio`.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn set_expiration_refresh_ratio(mut self, ratio: f64) -> Self {
        self.expiration_refresh_ratio = ratio;
        self
    }

    /// Returns this config with [`retry_config`](Self::retry_config) replaced.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn set_retry_config(mut self, retry_config: RetryConfig) -> Self {
        self.retry_config = retry_config;
        self
    }
}

impl Default for TokenRefreshConfig {
    /// A refresh ratio of `0.8` combined with the default [`RetryConfig`].
    fn default() -> Self {
        Self {
            expiration_refresh_ratio: 0.8,
            retry_config: RetryConfig::default(),
        }
    }
}
/// Retry parameters, built from [`RetryConfig::default`] via the `set_*` methods.
///
/// NOTE(review): the values are only stored here; how they are turned into a
/// delay schedule is decided by code outside this file. The field names suggest
/// exponential backoff bounded by `min_delay`/`max_delay` — confirm against the
/// consumer.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct RetryConfig {
    /// Exponent base for the delay growth (name-inferred). Default: `2.0`.
    pub(crate) exponent_base: f32,
    /// Lower delay bound (name-inferred). Default: 100 ms.
    pub(crate) min_delay: Duration,
    /// Optional upper delay bound (name-inferred). Default: `Some(30 s)`.
    pub(crate) max_delay: Option<Duration>,
    /// Number of retries to attempt. Default: `3`.
    pub(crate) number_of_retries: usize,
}

impl Default for RetryConfig {
    /// 3 retries, 100 ms minimum delay, 30 s maximum delay, exponent base 2.0.
    fn default() -> Self {
        Self {
            number_of_retries: 3,
            min_delay: Duration::from_millis(100),
            max_delay: Some(Duration::from_secs(30)),
            exponent_base: 2.0,
        }
    }
}

impl RetryConfig {
    /// Returns this config with the retry count set to `number_of_retries`.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn set_number_of_retries(mut self, number_of_retries: usize) -> Self {
        self.number_of_retries = number_of_retries;
        self
    }

    /// Returns this config with the minimum delay set to `min_delay`.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn set_min_delay(mut self, min_delay: Duration) -> Self {
        self.min_delay = min_delay;
        self
    }

    /// Returns this config with the maximum delay set to `Some(max_delay)`.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn set_max_delay(mut self, max_delay: Duration) -> Self {
        self.max_delay = Some(max_delay);
        self
    }

    /// Returns this config with the exponent base set to `exponent_base`.
    #[must_use = "builder methods consume the config and return the updated one"]
    pub fn set_exponent_base(mut self, exponent_base: f32) -> Self {
        self.exponent_base = exponent_base;
        self
    }
}
/// Internal helpers shared by the credential-management code.
pub(crate) mod credentials_management_utils {
    use super::*;

    /// Computes the refresh threshold `(expires_at - received_at) * refresh_ratio`,
    /// i.e. an offset measured from `received_at`.
    ///
    /// Returns `None` when `expires_at` is earlier than `received_at`, or when the
    /// scaled lifetime is not a valid [`Duration`] (negative, NaN or infinite
    /// `refresh_ratio`, or a product that overflows `Duration`).
    // NOTE(review): presumably the non-test callers are feature-gated, hence the
    // dead-code suppression — confirm.
    #[allow(dead_code)]
    pub(crate) fn calculate_refresh_threshold(
        received_at: SystemTime,
        expires_at: SystemTime,
        refresh_ratio: f64,
    ) -> Option<Duration> {
        let total_lifetime = expires_at.duration_since(received_at).ok()?;
        // `try_from_secs_f64` reports negative/non-finite/overflowing inputs as an
        // error, whereas `from_secs_f64` would panic on them.
        Duration::try_from_secs_f64(total_lifetime.as_secs_f64() * refresh_ratio).ok()
    }

    /// Extracts the `oid` claim from a `header.payload.signature` JWT.
    ///
    /// The token must consist of exactly three `.`-separated segments; the middle
    /// one is base64url-decoded (without padding) and interpreted as UTF-8. The
    /// header and signature segments are ignored, so the token is **not**
    /// verified.
    ///
    /// The payload is scanned rather than parsed as JSON: the first `"oid"` that
    /// is immediately followed (modulo whitespace) by `:` and a quoted string is
    /// taken as the claim. JSON escape sequences inside the value are not decoded,
    /// and a value containing an escaped quote would be truncated at it.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if the token does not have three
    /// segments, the payload is not valid base64url or UTF-8, or no string-valued
    /// `oid` claim is found.
    #[cfg(all(feature = "token-based-authentication", feature = "entra-id"))]
    pub(crate) fn extract_oid_from_jwt(jwt: &str) -> Result<String, String> {
        use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

        const OID_KEY: &str = "\"oid\"";

        let mut segments = jwt.split('.');
        let (Some(_header), Some(payload), Some(_signature), None) = (
            segments.next(),
            segments.next(),
            segments.next(),
            segments.next(),
        ) else {
            return Err("Invalid JWT: must have 3 parts".to_string());
        };

        let payload_bytes = URL_SAFE_NO_PAD
            .decode(payload)
            .map_err(|e| format!("Failed to decode payload: {e}"))?;
        let payload_str = String::from_utf8(payload_bytes)
            .map_err(|e| format!("Payload is not valid UTF-8: {e}"))?;

        payload_str
            .match_indices(OID_KEY)
            .find_map(|(key_idx, _)| {
                // Only a `"oid"` directly followed by `:` is a key; an `"oid"`
                // string in value position is skipped so a later real key can match.
                let value = payload_str[key_idx + OID_KEY.len()..]
                    .trim_start()
                    .strip_prefix(':')?
                    .trim_start()
                    .strip_prefix('"')?;
                let end_quote = value.find('"')?;
                Some(value[..end_quote].to_string())
            })
            .ok_or_else(|| "OID claim not found".to_string())
    }
}
#[cfg(all(feature = "token-based-authentication", test))]
mod auth_management_tests {
    use super::{TokenRefreshConfig, credentials_management_utils};
    use std::time::{Duration, SystemTime};

    /// Fixed reference instant so threshold tests do not depend on the wall clock.
    fn reference_instant() -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(1_700_000_000)
    }

    #[test]
    fn test_token_refresh_config() {
        let config = TokenRefreshConfig::default();
        assert_eq!(config.expiration_refresh_ratio, 0.8);
        let custom_config = TokenRefreshConfig::default().set_expiration_refresh_ratio(0.9);
        assert_eq!(custom_config.expiration_refresh_ratio, 0.9);
    }

    #[test]
    fn test_refresh_threshold_calculation() {
        let config = TokenRefreshConfig::default();
        let received_at = reference_instant();
        let expires_at = received_at + Duration::from_secs(3600);
        let threshold = credentials_management_utils::calculate_refresh_threshold(
            received_at,
            expires_at,
            config.expiration_refresh_ratio,
        );
        // The default ratio (0.8) applied to a one-hour lifetime.
        assert_eq!(threshold, Some(Duration::from_secs(2880)));
    }

    #[test]
    fn test_refresh_threshold_when_expiry_precedes_receipt() {
        let expires_at = reference_instant();
        let received_at = expires_at + Duration::from_secs(60);
        let threshold =
            credentials_management_utils::calculate_refresh_threshold(received_at, expires_at, 0.8);
        assert_eq!(threshold, None);
    }

    /// JWT tests. Their fixtures live here so builds without `entra-id` do not
    /// get dead-code / unused-import warnings for them.
    #[cfg(feature = "entra-id")]
    mod entra_id {
        use super::credentials_management_utils;
        use std::sync::LazyLock;

        const TOKEN_HEADER: &str = "header";
        /// base64url of `{"oid":"12345678-9abc-def-1234-56789abcdef0"}`.
        const TOKEN_PAYLOAD: &str =
            "eyJvaWQiOiIxMjM0NTY3OC05YWJjLWRlZi0xMjM0LTU2Nzg5YWJjZGVmMCJ9";
        /// base64url of `{"sub":"1234567890","name":"John Doe","iat":1735689600}`.
        const TOKEN_PAYLOAD_NO_OID: &str =
            "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNzM1Njg5NjAwfQ";
        const TOKEN_SIGNATURE: &str = "signature";
        const OID_CLAIM_VALUE: &str = "12345678-9abc-def-1234-56789abcdef0";

        static TOKEN: LazyLock<String> =
            LazyLock::new(|| format!("{TOKEN_HEADER}.{TOKEN_PAYLOAD}.{TOKEN_SIGNATURE}"));
        static TOKEN_WITH_NO_OID: LazyLock<String> =
            LazyLock::new(|| format!("{TOKEN_HEADER}.{TOKEN_PAYLOAD_NO_OID}.{TOKEN_SIGNATURE}"));
        /// Only two segments: the signature is missing.
        static INVALID_TOKEN: LazyLock<String> =
            LazyLock::new(|| format!("{TOKEN_HEADER}.{TOKEN_PAYLOAD}"));

        #[test]
        fn test_extract_oid_from_jwt() {
            let result = credentials_management_utils::extract_oid_from_jwt(TOKEN.as_str());
            assert_eq!(result.as_deref(), Ok(OID_CLAIM_VALUE));
        }

        #[test]
        fn test_extract_oid_from_jwt_with_invalid_token() {
            let result =
                credentials_management_utils::extract_oid_from_jwt(INVALID_TOKEN.as_str());
            assert_eq!(result.unwrap_err(), "Invalid JWT: must have 3 parts");
        }

        #[test]
        fn test_extract_oid_from_jwt_with_no_oid_claim() {
            let result =
                credentials_management_utils::extract_oid_from_jwt(TOKEN_WITH_NO_OID.as_str());
            assert_eq!(result.unwrap_err(), "OID claim not found");
        }
    }
}