axum_keycloak_auth/
extract.rs1use std::{borrow::Cow, sync::Arc};
2
3use axum::extract::Request;
4use nonempty::NonEmpty;
5
6use crate::error::AuthError;
7
/// A token extracted from a request: a string slice borrowed for `'a` or an owned `String`.
pub type ExtractedToken<'a> = Cow<'a, str>;
11
/// Strategy for obtaining a token from an incoming request.
///
/// Implementors must be `Send + Sync` and implement `Debug`.
pub trait TokenExtractor: Send + Sync + std::fmt::Debug {
    /// Extracts a token from `request`.
    ///
    /// The returned token may borrow from `request` (lifetime `'a`) or be owned.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when no token could be obtained from the request.
    fn extract<'a>(&self, request: &'a Request) -> Result<ExtractedToken<'a>, AuthError>;
}
22
/// Token extractor that reads a bearer token from the `Authorization` request header.
#[derive(Debug, Clone, Default)]
pub struct AuthHeaderTokenExtractor {}
26
27impl TokenExtractor for AuthHeaderTokenExtractor {
28 fn extract<'a>(&self, request: &'a Request) -> Result<ExtractedToken<'a>, AuthError> {
29 request
30 .headers()
31 .get(http::header::AUTHORIZATION)
32 .ok_or(AuthError::MissingAuthorizationHeader)?
33 .to_str()
34 .map_err(|err| AuthError::InvalidAuthorizationHeader {
35 reason: err.to_string(),
36 })?
37 .strip_prefix("Bearer ")
38 .ok_or(AuthError::MissingBearerToken)
39 .map(Cow::Borrowed)
40 }
41}
42
/// Token extractor configured with the name of the query parameter that carries the token.
#[derive(Debug, Clone)]
pub struct QueryParamTokenExtractor {
    /// Name of the query parameter to look up.
    pub key: String,
}

impl QueryParamTokenExtractor {
    /// Creates an extractor that looks up the query parameter named `key`.
    pub fn extracting_key(key: impl Into<String>) -> Self {
        let key: String = key.into();
        Self { key }
    }
}

impl Default for QueryParamTokenExtractor {
    /// Uses `token` as the query parameter name.
    fn default() -> Self {
        Self {
            key: String::from("token"),
        }
    }
}
64
65impl TokenExtractor for QueryParamTokenExtractor {
66 fn extract<'a>(&self, request: &'a Request) -> Result<ExtractedToken<'a>, AuthError> {
67 let query = request.uri().query().ok_or(AuthError::MissingQueryParams)?;
68
69 let mut tokens = serde_querystring::DuplicateQS::parse(query.as_bytes())
70 .values(self.key.as_bytes())
71 .unwrap_or_default()
72 .into_iter();
73
74 let first_token = tokens
75 .next()
76 .ok_or(AuthError::MissingTokenQueryParam)?
77 .ok_or(AuthError::EmptyTokenQueryParam)?;
78
79 let first_token = std::str::from_utf8(first_token.as_ref()).expect("Valid UTF-8");
80
81 Ok(ExtractedToken::Owned(first_token.to_owned()))
82 }
83}
84
85pub(crate) fn extract_jwt<'a>(
86 request: &'a Request<axum::body::Body>,
87 extractors: &NonEmpty<Arc<dyn TokenExtractor>>,
88) -> Option<ExtractedToken<'a>> {
89 for extractor in extractors {
90 match extractor.extract(request) {
91 Ok(jwt) => return Some(jwt),
92 Err(err) => {
93 tracing::debug!(?extractor, ?err, "Extractor failed");
94 }
95 }
96 }
97 None
98}