use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use crate::common::error::JwtError;
/// Read-only lookup of header values by name.
///
/// Extractors receive headers as `&dyn Headers`, so any container implementing
/// this trait can be used; an implementation for `HashMap<String, String>` is
/// provided below.
pub trait Headers {
    /// Returns the value stored under `name`, or `None` if there is no such entry.
    fn get(&self, name: &str) -> Option<&str>;
}

impl Headers for HashMap<String, String> {
    /// Exact-key lookup: `name` must equal a stored key byte-for-byte.
    // NOTE(review): HTTP header names are case-insensitive; confirm callers
    // normalise key case before building this map.
    fn get(&self, name: &str) -> Option<&str> {
        // Name the inherent `HashMap::get` explicitly so it is obvious this is not
        // a recursive call back into `Headers::get`.
        HashMap::get(self, name).map(String::as_str)
    }
}
/// A strategy for locating a token in a set of `Headers`.
///
/// Extractors are held as `Arc<dyn TokenExtractor>` in this module; the
/// `Send + Sync` bound is what allows those shared handles to cross threads.
pub trait TokenExtractor: Send + Sync {
    /// Short human-readable label, used when the extractor is `Debug`-formatted.
    fn description(&self) -> &str;

    /// Looks for a token in `headers`.
    ///
    /// The implementations in this module return `Ok(None)` when the token
    /// source is absent, and an error when it is present but malformed.
    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError>;
}
/// Extracts a token from a single named header, optionally requiring a fixed
/// prefix (such as `"Bearer "`) in front of it.
#[derive(Clone)]
pub struct HeaderTokenExtractor {
    /// Header to read, passed verbatim to `Headers::get`.
    header_name: String,
    /// Prefix the header value must start with, if any.
    token_prefix: Option<String>,
    /// Cached label of the form `HeaderTokenExtractor(<header>, <prefix or "none">)`.
    description: String,
}

impl fmt::Debug for HeaderTokenExtractor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `description` is built from the other two fields in `new`, so it is
        // left out to avoid repeating them.
        let mut out = f.debug_struct("HeaderTokenExtractor");
        out.field("header_name", &self.header_name);
        out.field("token_prefix", &self.token_prefix);
        out.finish()
    }
}

impl HeaderTokenExtractor {
    /// Creates an extractor for `header_name`.
    ///
    /// When `token_prefix` is `Some`, the header value is expected to start
    /// with that prefix.
    pub fn new(header_name: &str, token_prefix: Option<&str>) -> Self {
        let description = format!(
            "HeaderTokenExtractor({}, {})",
            header_name,
            token_prefix.unwrap_or("none")
        );
        Self {
            header_name: header_name.to_owned(),
            token_prefix: token_prefix.map(str::to_owned),
            description,
        }
    }

    /// Shorthand for an extractor expecting `"Bearer "` before the token.
    pub fn bearer(header_name: &str) -> Self {
        let scheme = Some("Bearer ");
        Self::new(header_name, scheme)
    }
}
impl TokenExtractor for HeaderTokenExtractor {
    /// Extracts the token carried by the `header_name` header.
    ///
    /// Returns `Ok(None)` when the header is absent. When a `token_prefix` is
    /// configured, the header value must begin with it. The comparison ignores
    /// ASCII case because HTTP authentication scheme names are case-insensitive
    /// (RFC 7235 §2.1), so `bearer abc` is accepted as well as `Bearer abc`.
    /// Surrounding whitespace is trimmed from the returned token.
    ///
    /// # Errors
    ///
    /// Returns `JwtError::InvalidToken` if the header value does not start with
    /// the configured prefix, or if nothing but whitespace remains once the
    /// prefix has been removed.
    fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
        let header_str = match headers.get(&self.header_name) {
            Some(value) => value,
            None => return Ok(None),
        };

        let raw_token = match &self.token_prefix {
            Some(prefix) => {
                // `str::get` yields `None` when `prefix.len()` is out of range or not
                // on a char boundary, so the slice taken below can never panic.
                let has_prefix = matches!(
                    header_str.get(..prefix.len()),
                    Some(head) if head.eq_ignore_ascii_case(prefix)
                );
                if !has_prefix {
                    return Err(JwtError::InvalidToken(format!(
                        "{} header must start with '{}'",
                        self.header_name, prefix
                    )));
                }
                &header_str[prefix.len()..]
            }
            None => header_str,
        };

        // An empty token is never valid. Rejecting it here gives a clearer error
        // than a later decode failure, and stops a chained extractor from treating
        // `Some("")` as a successfully found token.
        let token = raw_token.trim();
        if token.is_empty() {
            return Err(JwtError::InvalidToken(format!(
                "{} header does not contain a token",
                self.header_name
            )));
        }
        Ok(Some(token.to_string()))
    }

    fn description(&self) -> &str {
        &self.description
    }
}
/// Cloneable handle around a shared `TokenExtractor` that can be `Debug`-printed.
///
/// `dyn TokenExtractor` has no `Debug` impl of its own, so this newtype formats
/// itself as `TokenExtractor(<description>)`.
#[derive(Clone)]
pub struct DebugTokenExtractor(pub Arc<dyn TokenExtractor>);

impl fmt::Debug for DebugTokenExtractor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.0.description();
        f.write_str("TokenExtractor(")?;
        f.write_str(label)?;
        f.write_str(")")
    }
}
/// Consults a sequence of extractors in order and yields the first token found.
#[derive(Clone)]
pub struct ChainedTokenExtractor {
    /// Extractors, consulted in insertion order.
    extractors: Vec<DebugTokenExtractor>,
    /// Cached `ChainedTokenExtractor(<count>)` label, rebuilt whenever the list grows.
    description: String,
}

impl fmt::Debug for ChainedTokenExtractor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the extractor list is shown; the cached label is omitted.
        let mut out = f.debug_struct("ChainedTokenExtractor");
        out.field("extractors", &self.extractors);
        out.finish()
    }
}

impl ChainedTokenExtractor {
    /// Builds a chain that consults `extractors` in the given order.
    pub fn new(extractors: Vec<Arc<dyn TokenExtractor>>) -> Self {
        let wrapped: Vec<DebugTokenExtractor> =
            extractors.into_iter().map(DebugTokenExtractor).collect();
        let description = Self::describe(wrapped.len());
        Self {
            extractors: wrapped,
            description,
        }
    }

    /// Appends `extractor` to the end of the chain (builder style).
    pub fn add_extractor(mut self, extractor: Arc<dyn TokenExtractor>) -> Self {
        self.extractors.push(DebugTokenExtractor(extractor));
        self.description = Self::describe(self.extractors.len());
        self
    }

    /// Formats the description label for a chain holding `count` extractors.
    fn describe(count: usize) -> String {
        format!("ChainedTokenExtractor({})", count)
    }
}
impl TokenExtractor for ChainedTokenExtractor {
fn extract_token(&self, headers: &dyn Headers) -> Result<Option<String>, JwtError> {
for extractor in &self.extractors {
match extractor.0.extract_token(headers)? {
Some(token) => return Ok(Some(token)),
None => continue,
}
}
Ok(None)
}
fn description(&self) -> &str {
&self.description
}
}
#[derive(Clone)]
pub struct TokenExtractorConfig {
pub access_token_extractor: DebugTokenExtractor,
pub id_token_extractor: DebugTokenExtractor,
pub require_auth: bool,
}
impl fmt::Debug for TokenExtractorConfig {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TokenExtractorConfig")
.field("access_token_extractor", &self.access_token_extractor)
.field("id_token_extractor", &self.id_token_extractor)
.field("require_auth", &self.require_auth)
.finish()
}
}
impl Default for TokenExtractorConfig {
    /// Requires authentication and reads both tokens from an
    /// `Authorization: Bearer <token>` header.
    // NOTE(review): both extractors target the same `Authorization` header and
    // will therefore yield the same token — confirm this is intended for the ID token.
    fn default() -> Self {
        let bearer_auth =
            || DebugTokenExtractor(Arc::new(HeaderTokenExtractor::bearer("Authorization")));
        Self {
            require_auth: true,
            access_token_extractor: bearer_auth(),
            id_token_extractor: bearer_auth(),
        }
    }
}
impl TokenExtractorConfig {
    /// Builds a config from explicit extractors and the `require_auth` flag.
    pub fn new(
        access_token_extractor: Arc<dyn TokenExtractor>,
        id_token_extractor: Arc<dyn TokenExtractor>,
        require_auth: bool,
    ) -> Self {
        let access_token_extractor = DebugTokenExtractor(access_token_extractor);
        let id_token_extractor = DebugTokenExtractor(id_token_extractor);
        Self {
            access_token_extractor,
            id_token_extractor,
            require_auth,
        }
    }

    /// Returns `self` with the access-token extractor replaced by `extractor`.
    pub fn with_access_token_extractor(self, extractor: Arc<dyn TokenExtractor>) -> Self {
        Self {
            access_token_extractor: DebugTokenExtractor(extractor),
            ..self
        }
    }

    /// Returns `self` with the ID-token extractor replaced by `extractor`.
    pub fn with_id_token_extractor(self, extractor: Arc<dyn TokenExtractor>) -> Self {
        Self {
            id_token_extractor: DebugTokenExtractor(extractor),
            ..self
        }
    }

    /// Returns `self` with `require_auth` set to the given value.
    pub fn with_require_auth(self, require_auth: bool) -> Self {
        Self {
            require_auth,
            ..self
        }
    }

    /// Borrows the access-token extractor as a trait object.
    pub fn access_token_extractor(&self) -> &dyn TokenExtractor {
        self.access_token_extractor.0.as_ref()
    }

    /// Borrows the ID-token extractor as a trait object.
    pub fn id_token_extractor(&self) -> &dyn TokenExtractor {
        self.id_token_extractor.0.as_ref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a header map from `(name, value)` pairs.
    fn headers_from(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_string(), value.to_string()))
            .collect()
    }

    #[test]
    fn test_header_token_extractor_with_prefix() {
        let bearer = HeaderTokenExtractor::bearer("Authorization");

        let well_formed = headers_from(&[("Authorization", "Bearer test-token")]);
        assert_eq!(
            bearer.extract_token(&well_formed).unwrap(),
            Some("test-token".to_string())
        );

        // A present header with the wrong scheme is an error, not a miss.
        let wrong_scheme = headers_from(&[("Authorization", "Basic test-token")]);
        assert!(bearer.extract_token(&wrong_scheme).is_err());

        assert_eq!(bearer.extract_token(&headers_from(&[])).unwrap(), None);
    }

    #[test]
    fn test_header_token_extractor_without_prefix() {
        let raw = HeaderTokenExtractor::new("X-Token", None);

        let present = headers_from(&[("X-Token", "test-token")]);
        assert_eq!(
            raw.extract_token(&present).unwrap(),
            Some("test-token".to_string())
        );

        assert_eq!(raw.extract_token(&headers_from(&[])).unwrap(), None);
    }

    #[test]
    fn test_chained_token_extractor() {
        let bearer: Arc<dyn TokenExtractor> =
            Arc::new(HeaderTokenExtractor::bearer("Authorization"));
        let raw: Arc<dyn TokenExtractor> = Arc::new(HeaderTokenExtractor::new("X-Token", None));
        let chained = ChainedTokenExtractor::new(vec![bearer, raw]);

        let first_hit = headers_from(&[("Authorization", "Bearer test-token-1")]);
        assert_eq!(
            chained.extract_token(&first_hit).unwrap(),
            Some("test-token-1".to_string())
        );

        let fallback_hit = headers_from(&[("X-Token", "test-token-2")]);
        assert_eq!(
            chained.extract_token(&fallback_hit).unwrap(),
            Some("test-token-2".to_string())
        );

        assert_eq!(chained.extract_token(&headers_from(&[])).unwrap(), None);
    }
}