1#[cfg(feature = "oauth2")]
57use std::collections::HashMap;
58use std::future::Future;
59use std::pin::Pin;
60use std::sync::Arc;
61use std::task::{Context, Poll};
62#[cfg(feature = "oauth2")]
63use std::time::Duration;
64
65use axum::extract::FromRequestParts;
66use axum::response::{IntoResponse, Response};
67use http::StatusCode;
68use http::request::Parts;
69#[cfg(feature = "oauth2")]
70use jsonwebtoken::jwk::JwkSet;
71#[cfg(feature = "oauth2")]
72use serde::Deserialize;
73#[cfg(feature = "oauth2")]
74use url::Url;
75
/// bcrypt work factor used by [`hash_password`]; also the serde default for
/// [`AuthConfig::bcrypt_cost`].
const DEFAULT_BCRYPT_COST: u32 = 12;
80
81pub async fn hash_password(password: &str) -> crate::AutumnResult<String> {
100 let password = password.to_string();
101 tokio::task::spawn_blocking(move || {
102 bcrypt::hash(password, DEFAULT_BCRYPT_COST)
103 .map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))
104 })
105 .await
106 .map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))?
107}
108
109pub async fn verify_password(password: &str, hash: &str) -> crate::AutumnResult<bool> {
129 let password = password.to_string();
130
131 let is_valid_format = hash.len() == 60 && hash.starts_with('$');
134
135 let hash_to_verify = if is_valid_format {
136 hash.to_string()
137 } else {
138 "$2b$12$KIXe8K4j1sH6/xH.x9d71uJ5Jk8t6O4m6Q110g4H8y1r6J6O6O6O6".to_string()
140 };
141
142 let result = tokio::task::spawn_blocking(move || bcrypt::verify(&password, &hash_to_verify))
143 .await
144 .map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))?;
145
146 if !is_valid_format {
147 return Ok(false);
148 }
149
150 result.map_err(|e| crate::AutumnError::from(std::io::Error::other(e.to_string())))
151}
152
153#[doc(hidden)]
166pub async fn __check_secured(
167 session: &crate::session::Session,
168 roles: &[&str],
169) -> crate::AutumnResult<()> {
170 __check_secured_with_key(session, "user_id", roles).await
171}
172
173#[doc(hidden)]
179pub async fn __check_secured_with_key(
180 session: &crate::session::Session,
181 auth_session_key: &str,
182 roles: &[&str],
183) -> crate::AutumnResult<()> {
184 if session.get(auth_session_key).await.is_none() {
186 return Err(crate::AutumnError::unauthorized_msg(
187 "authentication required",
188 ));
189 }
190
191 if !roles.is_empty() {
194 let user_role = session.get("role").await.unwrap_or_default();
195 if !roles.iter().any(|&r| r == user_role) {
196 return Err(crate::AutumnError::forbidden_msg(
197 "insufficient permissions",
198 ));
199 }
200 }
201
202 Ok(())
203}
204
/// Extractor that yields a clone of the `T` stored in the request extensions.
///
/// Rejects with [`AuthRejection`] (401) when no `T` is present; some earlier
/// middleware is expected to insert it.
pub struct Auth<T>(pub T);
234
235impl<T, S> FromRequestParts<S> for Auth<T>
236where
237 T: Clone + Send + Sync + 'static,
238 S: Send + Sync,
239{
240 type Rejection = AuthRejection;
241
242 fn from_request_parts(
243 parts: &mut Parts,
244 _state: &S,
245 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
246 let user = parts.extensions.get::<T>().cloned();
247 async move { user.map_or_else(|| Err(AuthRejection), |user| Ok(Self(user))) }
248 }
249}
250
251#[derive(Debug)]
253pub struct AuthRejection;
254
255impl IntoResponse for AuthRejection {
256 fn into_response(self) -> Response {
257 crate::AutumnError::unauthorized_msg("authentication required").into_response()
258 }
259}
260
261impl std::fmt::Display for AuthRejection {
262 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263 f.write_str("authentication required")
264 }
265}
266
267#[derive(Clone)]
288pub struct RequireAuth {
289 session_key: Arc<str>,
290}
291
292impl RequireAuth {
293 pub fn new(session_key: impl Into<String>) -> Self {
295 Self {
296 session_key: Arc::from(session_key.into()),
297 }
298 }
299}
300
301impl<S> tower::Layer<S> for RequireAuth {
302 type Service = RequireAuthService<S>;
303
304 fn layer(&self, inner: S) -> Self::Service {
305 RequireAuthService {
306 inner,
307 session_key: Arc::clone(&self.session_key),
308 }
309 }
310}
311
312#[derive(Clone)]
314pub struct RequireAuthService<S> {
315 inner: S,
316 session_key: Arc<str>,
317}
318
319impl<S, ResBody> tower::Service<axum::extract::Request> for RequireAuthService<S>
320where
321 S: tower::Service<axum::extract::Request, Response = Response<ResBody>>
322 + Clone
323 + Send
324 + 'static,
325 S::Future: Send + 'static,
326 S::Error: Send + 'static,
327 ResBody: From<String> + Default + Send + 'static,
328{
329 type Response = Response<ResBody>;
330 type Error = S::Error;
331 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
332
333 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334 self.inner.poll_ready(cx)
335 }
336
337 fn call(&mut self, req: axum::extract::Request) -> Self::Future {
338 let session_key = Arc::clone(&self.session_key);
339 let mut inner = self.inner.clone();
340 std::mem::swap(&mut self.inner, &mut inner);
341
342 Box::pin(async move {
343 let session = req.extensions().get::<crate::session::Session>().cloned();
345
346 let is_authenticated = if let Some(ref session) = session {
347 session.contains_key(&session_key).await
348 } else {
349 false
350 };
351
352 if is_authenticated {
353 inner.call(req).await
354 } else {
355 let body = crate::error::problem_details_json_string(
356 StatusCode::UNAUTHORIZED,
357 "authentication required",
358 None,
359 None,
360 req.extensions()
361 .get::<crate::middleware::RequestId>()
362 .map(std::string::ToString::to_string),
363 Some(req.uri().path().to_owned()),
364 true,
365 );
366 let response = Response::builder()
367 .status(StatusCode::UNAUTHORIZED)
368 .header(http::header::CONTENT_TYPE, "application/problem+json")
369 .body(ResBody::from(body))
370 .unwrap_or_default();
371 Ok(response)
372 }
373 })
374 }
375}
376
/// Authentication settings deserialized from configuration.
///
/// Missing fields fall back to `default_bcrypt_cost`, `default_session_key`
/// and `OAuth2Config::default()` respectively.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct AuthConfig {
    /// bcrypt work factor; defaults to `DEFAULT_BCRYPT_COST` (12).
    ///
    /// NOTE(review): `hash_password` in this file uses `DEFAULT_BCRYPT_COST`
    /// directly and does not read this field — confirm where it is consumed.
    #[serde(default = "default_bcrypt_cost")]
    pub bcrypt_cost: u32,

    /// Session key used to mark/check the authenticated user; defaults to
    /// `"user_id"`.
    #[serde(default = "default_session_key")]
    pub session_key: String,

    /// Per-provider OAuth2 / OIDC settings (`oauth2` feature only).
    #[cfg(feature = "oauth2")]
    #[serde(default)]
    pub oauth2: OAuth2Config,
}
403
404const fn default_bcrypt_cost() -> u32 {
405 DEFAULT_BCRYPT_COST
406}
407
408fn default_session_key() -> String {
409 "user_id".to_owned()
410}
411
/// Serde default for [`OAuth2ProviderConfig::scope`]: an empty scope, which
/// makes `oauth2_authorize_url` omit the `scope` parameter entirely.
#[cfg(feature = "oauth2")]
const fn default_provider_scope() -> String {
    String::new()
}

/// Timeout, in seconds, applied by `oauth_http_client` to the token,
/// userinfo and JWKS requests.
#[cfg(feature = "oauth2")]
const OAUTH_HTTP_TIMEOUT_SECS: u64 = 15;
419
/// OAuth2 provider table keyed by provider name.
///
/// The map is flattened, so a `[auth.oauth2.github]` table becomes the
/// `"github"` entry of `providers`.
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone, Default, serde::Deserialize)]
pub struct OAuth2Config {
    /// Provider name -> provider settings.
    #[serde(flatten)]
    pub providers: HashMap<String, OAuth2ProviderConfig>,
}
441
/// Settings for one OAuth2 / OIDC provider.
///
/// `Debug` is implemented by hand so that `client_secret` is never printed.
/// It would otherwise also leak through `AuthConfig`'s derived `Debug`, e.g.
/// when configuration is logged.
#[cfg(feature = "oauth2")]
#[derive(Clone, serde::Deserialize)]
pub struct OAuth2ProviderConfig {
    /// Client id; sent on the authorize and token requests and used as the
    /// expected `aud` of id_tokens.
    pub client_id: String,
    /// Client secret; sent in the token-exchange form body.
    pub client_secret: String,
    /// Authorization endpoint; query parameters are appended to it.
    pub authorize_url: String,
    /// Token endpoint used for the authorization-code exchange.
    pub token_url: String,
    /// Userinfo endpoint, used when the token response carries no id_token.
    #[serde(default)]
    pub userinfo_url: Option<String>,
    /// Redirect URI sent on both the authorize and token requests.
    pub redirect_uri: String,
    /// Space-separated scopes; omitted from the authorize URL when blank.
    #[serde(default = "default_provider_scope")]
    pub scope: String,
    /// Expected `iss` of id_tokens; required when an id_token is returned.
    #[serde(default)]
    pub issuer: Option<String>,
    /// JWKS endpoint for id_token signature keys; required with an id_token.
    #[serde(default)]
    pub jwks_url: Option<String>,
}

#[cfg(feature = "oauth2")]
impl std::fmt::Debug for OAuth2ProviderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2ProviderConfig")
            .field("client_id", &self.client_id)
            .field("client_secret", &"<redacted>")
            .field("authorize_url", &self.authorize_url)
            .field("token_url", &self.token_url)
            .field("userinfo_url", &self.userinfo_url)
            .field("redirect_uri", &self.redirect_uri)
            .field("scope", &self.scope)
            .field("issuer", &self.issuer)
            .field("jwks_url", &self.jwks_url)
            .finish()
    }
}
469
/// Query parameters the provider sends to the redirect URI.
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone, Deserialize)]
pub struct OAuth2Callback {
    /// Authorization code to exchange at the token endpoint.
    pub code: String,
    /// State echoed by the provider; must equal the value stored in the
    /// session by [`oauth2_authorize_url`].
    pub state: String,
}

/// Identity returned by [`oauth2_finish_login`].
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone)]
pub struct OidcIdentity {
    /// `sub` claim, or (userinfo responses only) the fallback `id` claim.
    pub subject: String,
    /// `email` claim, if present as a string.
    pub email: Option<String>,
    /// `name` claim, if present as a string.
    pub name: Option<String>,
    /// `preferred_username` claim, if present as a string.
    pub preferred_username: Option<String>,
    /// Every claim from the validated id_token or the userinfo response.
    pub raw_claims: serde_json::Value,
}
495
/// Subset of the token endpoint response used by this module.
///
/// `Debug` is implemented by hand because both tokens are bearer credentials
/// and must not end up in logs or test failure output.
#[cfg(feature = "oauth2")]
#[derive(Deserialize)]
struct OAuth2TokenResponse {
    /// Access token used for the userinfo request.
    access_token: String,
    /// Token type reported by the provider; parsed but otherwise unused.
    // Only read by the Debug impl below; the allow keeps builds quiet if that changes.
    #[allow(dead_code)]
    token_type: Option<String>,
    /// OIDC id_token (a JWT), when the provider issues one.
    id_token: Option<String>,
}

#[cfg(feature = "oauth2")]
impl std::fmt::Debug for OAuth2TokenResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2TokenResponse")
            .field("access_token", &"<redacted>")
            .field("token_type", &self.token_type)
            .field("id_token", &self.id_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
504
/// Builds the provider authorization URL for an authorization-code login.
///
/// Fresh random `state` and `nonce` values are stored in the session under
/// `oauth2:{provider_name}:state` / `oauth2:{provider_name}:nonce`, where
/// [`oauth2_finish_login`] later checks them.
///
/// The authorize URL is parsed *before* the session is touched. A
/// misconfigured provider therefore fails without leaving stale state/nonce
/// entries behind.
///
/// # Errors
///
/// Returns a bad-request error if `provider.authorize_url` is not a valid URL.
#[cfg(feature = "oauth2")]
pub async fn oauth2_authorize_url(
    session: &crate::session::Session,
    provider_name: &str,
    provider: &OAuth2ProviderConfig,
) -> crate::AutumnResult<String> {
    let mut url = Url::parse(&provider.authorize_url)
        .map_err(|e| crate::AutumnError::bad_request_msg(format!("invalid authorize_url: {e}")))?;

    let state = uuid::Uuid::new_v4().to_string();
    let nonce = uuid::Uuid::new_v4().to_string();

    {
        let mut q = url.query_pairs_mut();
        q.append_pair("response_type", "code");
        q.append_pair("client_id", &provider.client_id);
        q.append_pair("redirect_uri", &provider.redirect_uri);
        // Some providers reject an empty `scope`, so omit it when blank.
        if !provider.scope.trim().is_empty() {
            q.append_pair("scope", &provider.scope);
        }
        q.append_pair("state", &state);
        q.append_pair("nonce", &nonce);
    }

    session
        .insert(format!("oauth2:{provider_name}:state"), state)
        .await;
    session
        .insert(format!("oauth2:{provider_name}:nonce"), nonce)
        .await;

    Ok(url.into())
}
540
/// Completes an OAuth2 / OIDC login from the provider callback.
///
/// The steps run in order:
/// 1. check the callback `state` against the session;
/// 2. exchange the `code` for tokens;
/// 3. load claims from the id_token (validated against JWKS) or from the
///    userinfo endpoint;
/// 4. check the OIDC nonce when an id_token was used;
/// 5. derive the subject;
/// 6. store it under `session_key` (and the provider under `auth_provider`)
///    and rotate the session id.
///
/// # Errors
///
/// Propagates the error of the first step that fails.
#[cfg(feature = "oauth2")]
pub async fn oauth2_finish_login(
    session: &crate::session::Session,
    session_key: &str,
    provider_name: &str,
    provider: &OAuth2ProviderConfig,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<OidcIdentity> {
    validate_callback_state(session, provider_name, callback).await?;

    let tokens = exchange_oauth2_token(provider, callback).await?;
    let (claims, origin) = load_identity_claims(provider, &tokens).await?;
    validate_oidc_nonce(session, provider_name, &claims, origin).await?;

    let subject = extract_subject(&claims, origin)?;
    let identity =
        finalize_oauth2_session(session, session_key, provider_name, subject, claims).await?;
    Ok(identity)
}
566
/// Verifies the callback `state` against the value stored in the session.
///
/// Uses a constant-time comparison. The stored state is removed only after a
/// successful match, making it single-use for successful logins.
#[cfg(feature = "oauth2")]
async fn validate_callback_state(
    session: &crate::session::Session,
    provider_name: &str,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<()> {
    use subtle::ConstantTimeEq as _;

    let state_key = format!("oauth2:{provider_name}:state");
    let Some(expected) = session.get(&state_key).await else {
        return Err(crate::AutumnError::unauthorized_msg(
            "oauth2 state missing; restart login",
        ));
    };

    let matches: bool = expected.as_bytes().ct_eq(callback.state.as_bytes()).into();
    if !matches {
        return Err(crate::AutumnError::unauthorized_msg(
            "oauth2 state mismatch",
        ));
    }

    session.remove(&state_key).await;
    Ok(())
}
592
/// Exchanges the authorization `code` for tokens at `provider.token_url`.
///
/// Posts a form-encoded authorization-code grant with client credentials in
/// the body and asks for JSON; the body is then parsed by
/// `parse_oauth2_token_response`.
///
/// # Errors
///
/// - service-unavailable when the request cannot be sent;
/// - unauthorized on a non-2xx status;
/// - bad-request when the body cannot be read or parsed.
#[cfg(feature = "oauth2")]
async fn exchange_oauth2_token(
    provider: &OAuth2ProviderConfig,
    callback: &OAuth2Callback,
) -> crate::AutumnResult<OAuth2TokenResponse> {
    let client = oauth_http_client()?;
    let form = [
        ("grant_type", "authorization_code"),
        ("code", callback.code.as_str()),
        ("redirect_uri", provider.redirect_uri.as_str()),
        ("client_id", provider.client_id.as_str()),
        ("client_secret", provider.client_secret.as_str()),
    ];

    let sent = client
        .post(&provider.token_url)
        .header(reqwest::header::ACCEPT, "application/json")
        .form(&form)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("token request failed: {e}"))
        })?;
    let response = sent
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("token exchange failed: {e}")))?;

    let content_type = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .map(String::from);
    let body = response.text().await.map_err(|e| {
        crate::AutumnError::bad_request_msg(format!("invalid token response body: {e}"))
    })?;

    parse_oauth2_token_response(content_type.as_deref(), &body)
}
626
/// Loads identity claims, preferring a validated id_token over userinfo.
///
/// # Errors
///
/// - id_token validation errors;
/// - userinfo request/status/payload errors;
/// - bad-request when neither an id_token nor a `userinfo_url` is available.
#[cfg(feature = "oauth2")]
async fn load_identity_claims(
    provider: &OAuth2ProviderConfig,
    token: &OAuth2TokenResponse,
) -> crate::AutumnResult<(serde_json::Value, IdentitySource)> {
    match (token.id_token.as_deref(), provider.userinfo_url.as_deref()) {
        (Some(id_token), _) => {
            let claims = validate_and_decode_id_token(id_token, provider).await?;
            Ok((claims, IdentitySource::IdToken))
        }
        (None, Some(userinfo_url)) => {
            let claims = fetch_userinfo_claims(userinfo_url, &token.access_token).await?;
            Ok((claims, IdentitySource::UserInfo))
        }
        (None, None) => Err(crate::AutumnError::bad_request_msg(
            "provider must return id_token or configure userinfo_url",
        )),
    }
}

/// GETs `userinfo_url` with `access_token` as a bearer token and returns the
/// JSON body.
#[cfg(feature = "oauth2")]
async fn fetch_userinfo_claims(
    userinfo_url: &str,
    access_token: &str,
) -> crate::AutumnResult<serde_json::Value> {
    let response = oauth_http_client()?
        .get(userinfo_url)
        .header(
            reqwest::header::USER_AGENT,
            concat!("autumn-web/", env!("CARGO_PKG_VERSION")),
        )
        .bearer_auth(access_token)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("userinfo request failed: {e}"))
        })?
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("userinfo failed: {e}")))?;

    response.json().await.map_err(|e| {
        crate::AutumnError::bad_request_msg(format!("invalid userinfo payload: {e}"))
    })
}
664
/// Checks the id_token `nonce` claim against the nonce stored at authorize time.
///
/// The stored nonce is always removed from the session, even when the claims
/// came from userinfo (no nonce check applies then), so it cannot be reused.
#[cfg(feature = "oauth2")]
async fn validate_oidc_nonce(
    session: &crate::session::Session,
    provider_name: &str,
    claims: &serde_json::Value,
    source: IdentitySource,
) -> crate::AutumnResult<()> {
    use subtle::ConstantTimeEq as _;

    let stored = session
        .remove(&format!("oauth2:{provider_name}:nonce"))
        .await;
    if source != IdentitySource::IdToken {
        return Ok(());
    }

    let expected = stored.ok_or_else(|| {
        crate::AutumnError::unauthorized_msg("oauth2 nonce missing from session")
    })?;
    let Some(actual) = claims.get("nonce").and_then(serde_json::Value::as_str) else {
        return Err(crate::AutumnError::unauthorized_msg("missing oidc nonce claim"));
    };

    if bool::from(expected.as_bytes().ct_eq(actual.as_bytes())) {
        Ok(())
    } else {
        Err(crate::AutumnError::unauthorized_msg("oidc nonce mismatch"))
    }
}
694
/// Records the login in the session and builds the returned identity.
///
/// Writes the subject under `session_key` and the provider name under
/// `auth_provider`, then rotates the session id. Rotating the id on login is
/// presumably a session-fixation defence.
#[cfg(feature = "oauth2")]
async fn finalize_oauth2_session(
    session: &crate::session::Session,
    session_key: &str,
    provider_name: &str,
    subject: String,
    claims: serde_json::Value,
) -> crate::AutumnResult<OidcIdentity> {
    session.insert(session_key, subject.clone()).await;
    session.insert("auth_provider", provider_name).await;
    session.rotate_id().await;

    // Copy a claim out only when it is a JSON string.
    let string_claim = |claim: &str| {
        claims
            .get(claim)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };
    let email = string_claim("email");
    let name = string_claim("name");
    let preferred_username = string_claim("preferred_username");

    Ok(OidcIdentity {
        subject,
        email,
        name,
        preferred_username,
        raw_claims: claims,
    })
}
723
/// Parses a token endpoint response body.
///
/// The body is treated as JSON when `Content-Type` mentions
/// `application/json` (case-insensitively) or the body starts with `{`.
/// Otherwise it is parsed as `application/x-www-form-urlencoded`.
///
/// RFC 6749 §5.2 error bodies (`error`, optional `error_description`) are
/// reported as an unauthorized "token exchange failed" error. This matches
/// the non-2xx path in `exchange_oauth2_token`. Some providers send such
/// bodies with a 2xx status; previously they surfaced as a confusing
/// "missing access_token" or JSON decode error.
///
/// # Errors
///
/// - unauthorized when the body carries an `error` field;
/// - bad-request when the JSON is malformed or `access_token` is absent.
#[cfg(feature = "oauth2")]
fn parse_oauth2_token_response(
    content_type: Option<&str>,
    body: &str,
) -> crate::AutumnResult<OAuth2TokenResponse> {
    // Every field optional so that error bodies still parse.
    #[derive(Default, Deserialize)]
    struct RawTokenResponse {
        access_token: Option<String>,
        token_type: Option<String>,
        id_token: Option<String>,
        error: Option<String>,
        error_description: Option<String>,
    }

    let looks_like_json = content_type
        .is_some_and(|v| v.to_ascii_lowercase().contains("application/json"))
        || body.trim_start().starts_with('{');

    let raw = if looks_like_json {
        serde_json::from_str::<RawTokenResponse>(body).map_err(|e| {
            crate::AutumnError::bad_request_msg(format!("invalid json token response: {e}"))
        })?
    } else {
        let mut raw = RawTokenResponse::default();
        for (key, value) in url::form_urlencoded::parse(body.as_bytes()) {
            let slot = match key.as_ref() {
                "access_token" => &mut raw.access_token,
                "token_type" => &mut raw.token_type,
                "id_token" => &mut raw.id_token,
                "error" => &mut raw.error,
                "error_description" => &mut raw.error_description,
                _ => continue,
            };
            *slot = Some(value.into_owned());
        }
        raw
    };

    if let Some(code) = raw.error {
        let message = match raw.error_description.as_deref().map(str::trim) {
            Some(description) if !description.is_empty() => {
                format!("token exchange failed: {code}: {description}")
            }
            _ => format!("token exchange failed: {code}"),
        };
        return Err(crate::AutumnError::unauthorized_msg(message));
    }

    let access_token = raw.access_token.ok_or_else(|| {
        crate::AutumnError::bad_request_msg("token response missing access_token")
    })?;

    Ok(OAuth2TokenResponse {
        access_token,
        token_type: raw.token_type,
        id_token: raw.id_token,
    })
}
760
/// Where the identity claims of a login came from.
#[cfg(feature = "oauth2")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IdentitySource {
    /// Claims decoded from a JWKS-validated id_token; the nonce is checked.
    IdToken,
    /// Claims fetched from the provider's userinfo endpoint; `id` may stand in
    /// for `sub`.
    UserInfo,
}
767
/// Derives the stable subject identifier from the identity claims.
///
/// The `sub` claim wins. For userinfo responses (which may not follow OIDC),
/// an `id` claim is accepted as a fallback: any JSON integer, including
/// values above `i64::MAX`, or a string.
///
/// Blank values are treated as missing. An empty subject would otherwise be
/// written to the session as the authenticated user id.
///
/// # Errors
///
/// Returns bad-request when no usable subject is present.
#[cfg(feature = "oauth2")]
fn extract_subject(
    claims: &serde_json::Value,
    source: IdentitySource,
) -> crate::AutumnResult<String> {
    /// Owned copy of `value` unless it is empty or whitespace only.
    fn non_blank(value: &str) -> Option<String> {
        (!value.trim().is_empty()).then(|| value.to_owned())
    }

    if let Some(sub) = claims
        .get("sub")
        .and_then(serde_json::Value::as_str)
        .and_then(non_blank)
    {
        return Ok(sub);
    }

    match source {
        IdentitySource::UserInfo => claims
            .get("id")
            .and_then(|id| match id {
                serde_json::Value::Number(n) if n.is_i64() || n.is_u64() => Some(n.to_string()),
                serde_json::Value::String(s) => non_blank(s.as_str()),
                _ => None,
            })
            .ok_or_else(|| {
                crate::AutumnError::bad_request_msg(
                    "missing identity claim: expected sub or id from userinfo",
                )
            }),
        IdentitySource::IdToken => Err(crate::AutumnError::bad_request_msg("missing sub claim")),
    }
}
791
/// Verifies an OIDC id_token against the provider's JWKS and returns its claims.
///
/// Requires `provider.issuer` and `provider.jwks_url`. The token must:
/// - carry a `kid` that matches a JWKS key;
/// - be signed with an asymmetric algorithm;
/// - have `iss` equal to `provider.issuer` and `aud` containing
///   `provider.client_id`;
/// - include `exp`, `iss`, `aud` and `sub`; `exp`/`nbf` are validated.
///
/// # Errors
///
/// - bad-request for missing provider metadata or an unparsable JWKS body;
/// - service-unavailable when the JWKS request cannot be sent;
/// - unauthorized for any token, key or claim validation failure.
#[cfg(feature = "oauth2")]
async fn validate_and_decode_id_token(
    token: &str,
    provider: &OAuth2ProviderConfig,
) -> crate::AutumnResult<serde_json::Value> {
    let issuer = provider
        .issuer
        .as_deref()
        .ok_or_else(|| crate::AutumnError::bad_request_msg("provider.issuer required for oidc"))?;
    let jwks_url = provider.jwks_url.as_deref().ok_or_else(|| {
        crate::AutumnError::bad_request_msg("provider.jwks_url required for oidc")
    })?;

    let header = jsonwebtoken::decode_header(token).map_err(|e| {
        crate::AutumnError::unauthorized_msg(format!("invalid id_token header: {e}"))
    })?;
    let kid = header
        .kid
        .as_deref()
        .ok_or_else(|| crate::AutumnError::unauthorized_msg("id_token header missing kid"))?;

    // `alg` comes from the untrusted token header. Keys published via JWKS
    // are public, so an HMAC ("HS*") algorithm would mean using public
    // material as a shared secret (algorithm confusion, RFC 8725 §3.1).
    // Reject it before fetching keys.
    let alg = header.alg;
    if matches!(
        alg,
        jsonwebtoken::Algorithm::HS256
            | jsonwebtoken::Algorithm::HS384
            | jsonwebtoken::Algorithm::HS512
    ) {
        return Err(crate::AutumnError::unauthorized_msg(
            "id_token must be signed with an asymmetric algorithm",
        ));
    }

    let jwks: JwkSet = oauth_http_client()?
        .get(jwks_url)
        .send()
        .await
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!("jwks request failed: {e}"))
        })?
        .error_for_status()
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("jwks fetch failed: {e}")))?
        .json()
        .await
        .map_err(|e| crate::AutumnError::bad_request_msg(format!("invalid jwks response: {e}")))?;

    let jwk = jwks
        .keys
        .iter()
        .find(|k| k.common.key_id.as_deref() == Some(kid))
        .ok_or_else(|| crate::AutumnError::unauthorized_msg("no jwk matched id_token kid"))?;
    let decoding_key = jsonwebtoken::DecodingKey::from_jwk(jwk)
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("invalid jwk key: {e}")))?;

    let mut validation = jsonwebtoken::Validation::new(alg);
    validation.set_issuer(&[issuer]);
    validation.set_audience(std::slice::from_ref(&provider.client_id));
    validation.required_spec_claims = ["exp", "iss", "aud", "sub"]
        .into_iter()
        .map(str::to_owned)
        .collect();
    validation.validate_exp = true;
    validation.validate_nbf = true;

    let claims = jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation)
        .map_err(|e| crate::AutumnError::unauthorized_msg(format!("invalid id_token: {e}")))?;
    Ok(claims.claims)
}
849
/// Builds the HTTP client used for the OAuth2/OIDC token, userinfo and JWKS
/// requests.
///
/// Each client applies a [`OAUTH_HTTP_TIMEOUT_SECS`]-second timeout and a
/// default `User-Agent: autumn-web/<version>` header. Previously only the
/// userinfo request identified itself; some endpoints reject requests
/// without a User-Agent. A `User-Agent` set on an individual request still
/// takes precedence over this default.
///
/// NOTE(review): a new client (and connection pool) is built per call.
/// Caching one process-wide would be cheaper, but a shared client's pool
/// would then span tokio runtimes (e.g. across `#[tokio::test]`s) — confirm
/// that is safe before changing.
///
/// # Errors
///
/// Returns service-unavailable if the client cannot be constructed.
#[cfg(feature = "oauth2")]
fn oauth_http_client() -> crate::AutumnResult<reqwest::Client> {
    reqwest::Client::builder()
        .timeout(Duration::from_secs(OAUTH_HTTP_TIMEOUT_SECS))
        .user_agent(concat!("autumn-web/", env!("CARGO_PKG_VERSION")))
        .build()
        .map_err(|e| {
            crate::AutumnError::service_unavailable_msg(format!(
                "failed to build oauth http client: {e}"
            ))
        })
}
861
862impl Default for AuthConfig {
863 fn default() -> Self {
864 Self {
865 bcrypt_cost: default_bcrypt_cost(),
866 session_key: default_session_key(),
867 #[cfg(feature = "oauth2")]
868 oauth2: OAuth2Config::default(),
869 }
870 }
871}
872
/// Storage backend for opaque API bearer tokens.
///
/// Both implementations in this file store only the [`hash_api_token`] digest
/// mapped to a principal id. The raw token is handed out once, by `issue`.
pub trait ApiTokenStore: Send + Sync + 'static {
    /// Creates a new token for `principal_id` and returns the raw token.
    ///
    /// # Errors
    ///
    /// Implementation-defined (e.g. storage failures).
    fn issue<'a>(
        &'a self,
        principal_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<String>> + Send + 'a>>;

    /// Resolves `raw_token` to its principal id.
    ///
    /// Returns `Ok(None)` for unknown tokens (both stores here also do so for
    /// revoked ones).
    fn verify<'a>(
        &'a self,
        raw_token: &'a str,
    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>>;

    /// Invalidates `raw_token` so later `verify` calls no longer accept it.
    fn revoke<'a>(
        &'a self,
        raw_token: &'a str,
    ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'a>>;
}
908
909#[must_use]
924pub fn hash_api_token(raw: &str) -> String {
925 use sha2::Digest as _;
926 sha2::Sha256::digest(raw.as_bytes())
927 .iter()
928 .fold(String::with_capacity(64), |mut s, b| {
929 use std::fmt::Write as _;
930 let _ = write!(s, "{b:02x}");
931 s
932 })
933}
934
935fn generate_raw_token() -> String {
939 let u1 = uuid::Uuid::new_v4();
940 let u2 = uuid::Uuid::new_v4();
941 format!("{}{}", u1.simple(), u2.simple())
942}
943
/// In-process [`ApiTokenStore`] keeping token hashes in a `HashMap`.
///
/// Clones share one map (it sits behind an `Arc`); contents live only as long
/// as the process.
#[derive(Clone, Default)]
pub struct InMemoryApiTokenStore {
    /// `hash_api_token(raw)` -> principal id.
    tokens: Arc<std::sync::RwLock<std::collections::HashMap<String, String>>>,
}
977
978impl ApiTokenStore for InMemoryApiTokenStore {
979 fn issue<'a>(
980 &'a self,
981 principal_id: &'a str,
982 ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<String>> + Send + 'a>> {
983 Box::pin(async move {
984 let raw = generate_raw_token();
985 let hash = hash_api_token(&raw);
986 self.tokens
987 .write()
988 .expect("api token store lock poisoned")
989 .insert(hash, principal_id.to_owned());
990 Ok(raw)
991 })
992 }
993
994 fn verify<'a>(
995 &'a self,
996 raw_token: &'a str,
997 ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>> {
998 Box::pin(async move {
999 let hash = hash_api_token(raw_token);
1000 Ok(self
1001 .tokens
1002 .read()
1003 .expect("api token store lock poisoned")
1004 .get(&hash)
1005 .cloned())
1006 })
1007 }
1008
1009 fn revoke<'a>(
1010 &'a self,
1011 raw_token: &'a str,
1012 ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'a>> {
1013 Box::pin(async move {
1014 let hash = hash_api_token(raw_token);
1015 self.tokens
1016 .write()
1017 .expect("api token store lock poisoned")
1018 .remove(&hash);
1019 Ok(())
1020 })
1021 }
1022}
1023
1024pub async fn issue_api_token(
1032 store: &dyn ApiTokenStore,
1033 principal_id: &str,
1034) -> crate::AutumnResult<String> {
1035 store.issue(principal_id).await
1036}
1037
1038pub async fn revoke_api_token(
1046 store: &dyn ApiTokenStore,
1047 raw_token: &str,
1048) -> crate::AutumnResult<()> {
1049 store.revoke(raw_token).await
1050}
1051
/// Request extension inserted by `RequireApiTokenService` holding the
/// principal id of the verified bearer token.
#[derive(Clone)]
struct ApiTokenPrincipal(String);

/// Extractor yielding the principal id of the bearer token verified by
/// [`RequireApiToken`].
///
/// Rejects with [`AuthRejection`] (401) when no verified principal is present
/// in the request extensions.
#[derive(Debug, Clone)]
pub struct ApiToken(pub String);
1076
1077impl<S> FromRequestParts<S> for ApiToken
1078where
1079 S: Send + Sync,
1080{
1081 type Rejection = AuthRejection;
1082
1083 fn from_request_parts(
1084 parts: &mut Parts,
1085 _state: &S,
1086 ) -> impl Future<Output = Result<Self, Self::Rejection>> + Send {
1087 let principal = parts.extensions.get::<ApiTokenPrincipal>().cloned();
1088 async move { principal.map(|p| Self(p.0)).ok_or(AuthRejection) }
1089 }
1090}
1091
1092#[derive(Clone)]
1117pub struct RequireApiToken {
1118 store: Arc<dyn ApiTokenStore>,
1119}
1120
1121impl RequireApiToken {
1122 #[must_use]
1127 pub fn new<S: ApiTokenStore + 'static>(store: Arc<S>) -> Self {
1128 Self { store }
1129 }
1130}
1131
1132impl<S> tower::Layer<S> for RequireApiToken {
1133 type Service = RequireApiTokenService<S>;
1134
1135 fn layer(&self, inner: S) -> Self::Service {
1136 RequireApiTokenService {
1137 inner,
1138 store: Arc::clone(&self.store),
1139 }
1140 }
1141}
1142
1143#[derive(Clone)]
1145pub struct RequireApiTokenService<S> {
1146 inner: S,
1147 store: Arc<dyn ApiTokenStore>,
1148}
1149
1150impl<S, ResBody> tower::Service<axum::extract::Request> for RequireApiTokenService<S>
1151where
1152 S: tower::Service<axum::extract::Request, Response = Response<ResBody>>
1153 + Clone
1154 + Send
1155 + 'static,
1156 S::Future: Send + 'static,
1157 S::Error: Send + 'static,
1158 ResBody: From<String> + Default + Send + 'static,
1159{
1160 type Response = Response<ResBody>;
1161 type Error = S::Error;
1162 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
1163
1164 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1165 self.inner.poll_ready(cx)
1166 }
1167
1168 fn call(&mut self, mut req: axum::extract::Request) -> Self::Future {
1169 let store = Arc::clone(&self.store);
1170 let mut inner = self.inner.clone();
1171 std::mem::swap(&mut self.inner, &mut inner);
1172
1173 Box::pin(async move {
1174 let raw_token = req
1176 .headers()
1177 .get(http::header::AUTHORIZATION)
1178 .and_then(|v| v.to_str().ok())
1179 .and_then(parse_bearer_token)
1180 .map(str::to_owned);
1181
1182 let Some(raw_token) = raw_token else {
1183 let (request_id, instance) = api_token_problem_context(&req);
1184 return Ok(api_token_unauthorized_response(request_id, instance));
1185 };
1186
1187 match store.verify(&raw_token).await {
1188 Ok(Some(principal_id)) => {
1189 req.extensions_mut().insert(ApiTokenPrincipal(principal_id));
1190 inner.call(req).await
1191 }
1192 Ok(None) => {
1193 let (request_id, instance) = api_token_problem_context(&req);
1194 Ok(api_token_unauthorized_response(request_id, instance))
1195 }
1196 Err(err) => {
1197 let (request_id, instance) = api_token_problem_context(&req);
1198 Ok(api_token_error_response(&err, request_id, instance))
1199 }
1200 }
1201 })
1202 }
1203}
1204
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// The scheme is matched case-insensitively.
/// - Whitespace around the value and between the scheme and the token is
///   ignored (RFC 7235 allows one or more spaces there).
/// - `None` is returned for other schemes or an empty token.
///
/// Previously `"Bearer  abc"` yielded `" abc"` and `"Bearer "` yielded an
/// empty token.
fn parse_bearer_token(header: &str) -> Option<&str> {
    let (scheme, rest) = header.trim().split_once(' ')?;
    let token = rest.trim_start();
    (scheme.eq_ignore_ascii_case("Bearer") && !token.is_empty()).then_some(token)
}
1209
1210fn api_token_unauthorized_response<ResBody: From<String> + Default>(
1212 request_id: Option<String>,
1213 instance: Option<String>,
1214) -> Response<ResBody> {
1215 let body = crate::error::problem_details_json_string(
1216 StatusCode::UNAUTHORIZED,
1217 "authentication required",
1218 None,
1219 None,
1220 request_id,
1221 instance,
1222 true,
1223 );
1224 Response::builder()
1225 .status(StatusCode::UNAUTHORIZED)
1226 .header(http::header::CONTENT_TYPE, "application/problem+json")
1227 .body(ResBody::from(body))
1228 .unwrap_or_default()
1229}
1230
1231fn api_token_error_response<ResBody: From<String> + Default>(
1233 err: &crate::AutumnError,
1234 request_id: Option<String>,
1235 instance: Option<String>,
1236) -> Response<ResBody> {
1237 let status = err.status();
1238 let message = err.to_string();
1239 let body = crate::error::problem_details_json_string(
1240 status,
1241 message.clone(),
1242 None,
1243 None,
1244 request_id,
1245 instance,
1246 true,
1247 );
1248 let mut response = Response::builder()
1249 .status(status)
1250 .header(http::header::CONTENT_TYPE, "application/problem+json")
1251 .body(ResBody::from(body))
1252 .unwrap_or_default();
1253 response
1254 .extensions_mut()
1255 .insert(crate::middleware::AutumnErrorInfo {
1256 status,
1257 message,
1258 details: None,
1259 problem_type: None,
1260 });
1261 response
1262}
1263
1264fn api_token_problem_context(req: &axum::extract::Request) -> (Option<String>, Option<String>) {
1265 (
1266 req.extensions()
1267 .get::<crate::middleware::RequestId>()
1268 .map(std::string::ToString::to_string),
1269 Some(req.uri().path().to_owned()),
1270 )
1271}
1272
/// Diesel migrations embedded from the crate's `migrations` directory.
///
/// Presumably these create the `api_tokens` table used by `DbApiTokenStore`
/// — confirm against the migration files.
#[cfg(feature = "db")]
pub const API_TOKEN_MIGRATIONS: diesel_migrations::EmbeddedMigrations =
    diesel_migrations::embed_migrations!("migrations");
1299
#[cfg(feature = "db")]
mod db_store {
    //! Postgres-backed API token store built on `diesel-async` and a deadpool
    //! connection pool.

    use std::future::Future;
    use std::pin::Pin;

    use diesel::OptionalExtension as _;
    use diesel::prelude::*;
    use diesel_async::AsyncPgConnection;
    use diesel_async::RunQueryDsl;
    use diesel_async::pooled_connection::deadpool::Pool;

    use super::{ApiTokenStore, generate_raw_token, hash_api_token};
    use crate::error::AutumnError;

    diesel::table! {
        api_tokens (id) {
            id -> Int8,
            token_hash -> Text,
            principal_id -> Text,
            created_at -> Timestamp,
            revoked_at -> Nullable<Timestamp>,
        }
    }

    /// Row written by `issue`. `id`, `created_at` and `revoked_at` are left to
    /// the database — presumably defaults / NULL; confirm against the migration.
    #[derive(Insertable)]
    #[diesel(table_name = api_tokens)]
    struct NewApiToken<'a> {
        token_hash: &'a str,
        principal_id: &'a str,
    }

    /// Maps a pool or query error to a 500 carrying its message.
    fn internal_error(err: impl std::fmt::Display) -> AutumnError {
        AutumnError::internal_server_error_msg(err.to_string())
    }

    /// [`ApiTokenStore`] persisting SHA-256 token hashes in `api_tokens`.
    ///
    /// Revocation sets `revoked_at`; rows with a non-NULL `revoked_at` no
    /// longer verify.
    #[derive(Clone)]
    pub struct DbApiTokenStore {
        pool: Pool<AsyncPgConnection>,
    }

    impl DbApiTokenStore {
        /// Creates a store that checks connections out of `pool`.
        #[must_use]
        pub const fn new(pool: Pool<AsyncPgConnection>) -> Self {
            Self { pool }
        }
    }

    impl ApiTokenStore for DbApiTokenStore {
        /// Inserts the hash of a freshly generated token; returns the raw token.
        fn issue<'a>(
            &'a self,
            principal_id: &'a str,
        ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<String>> + Send + 'a>> {
            Box::pin(async move {
                let raw = generate_raw_token();
                let hash = hash_api_token(&raw);
                let mut conn = self.pool.get().await.map_err(internal_error)?;
                diesel::insert_into(api_tokens::table)
                    .values(NewApiToken {
                        token_hash: &hash,
                        principal_id,
                    })
                    .execute(&mut conn)
                    .await
                    .map_err(internal_error)?;
                Ok(raw)
            })
        }

        /// Returns the principal of a non-revoked row matching the token hash.
        fn verify<'a>(
            &'a self,
            raw_token: &'a str,
        ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>>
        {
            Box::pin(async move {
                let hash = hash_api_token(raw_token);
                let mut conn = self.pool.get().await.map_err(internal_error)?;
                let principal: Option<String> = api_tokens::table
                    .filter(api_tokens::token_hash.eq(&hash))
                    .filter(api_tokens::revoked_at.is_null())
                    .select(api_tokens::principal_id)
                    .first(&mut conn)
                    .await
                    .optional()
                    .map_err(internal_error)?;
                Ok(principal)
            })
        }

        /// Marks the token revoked.
        ///
        /// Only rows that are still active are updated, so revoking twice
        /// keeps the original `revoked_at` timestamp instead of overwriting
        /// it. Unknown tokens are a no-op.
        fn revoke<'a>(
            &'a self,
            raw_token: &'a str,
        ) -> Pin<Box<dyn Future<Output = crate::AutumnResult<()>> + Send + 'a>> {
            Box::pin(async move {
                let hash = hash_api_token(raw_token);
                let mut conn = self.pool.get().await.map_err(internal_error)?;
                diesel::update(api_tokens::table)
                    .filter(api_tokens::token_hash.eq(&hash))
                    .filter(api_tokens::revoked_at.is_null())
                    .set(api_tokens::revoked_at.eq(diesel::dsl::now.nullable()))
                    .execute(&mut conn)
                    .await
                    .map_err(internal_error)?;
                Ok(())
            })
        }
    }
}
1438
// Re-export the Postgres-backed token store at this module's root (`db` feature only).
#[cfg(feature = "db")]
pub use db_store::DbApiTokenStore;
1441
1442#[cfg(test)]
1443mod tests {
1444 use super::*;
1445
1446 #[tokio::test]
1447 async fn hash_and_verify_password() {
1448 let hash = hash_password("test_password").await.unwrap();
1449 assert!(hash.starts_with("$2b$"));
1450 assert!(verify_password("test_password", &hash).await.unwrap());
1451 assert!(!verify_password("wrong_password", &hash).await.unwrap());
1452 }
1453
1454 #[tokio::test]
1455 async fn verify_invalid_hash_returns_false() {
1456 let result = verify_password("test", "not-a-valid-hash").await;
1457 assert!(result.is_ok());
1458 assert!(!result.unwrap());
1459 }
1460
1461 #[tokio::test]
1462 async fn verify_password_rejects_invalid_hash_format_safely() {
1463 let result = verify_password("test", "short").await;
1465 assert!(result.is_ok());
1466 assert!(!result.unwrap());
1467
1468 let bad_prefix = "a".repeat(60);
1470 let result = verify_password("test", &bad_prefix).await;
1471 assert!(result.is_ok());
1472 assert!(!result.unwrap());
1473
1474 let bad_length = "$2b$12$short";
1476 let result = verify_password("test", bad_length).await;
1477 assert!(result.is_ok());
1478 assert!(!result.unwrap());
1479 }
1480
1481 #[test]
1482 fn auth_config_defaults() {
1483 let config = AuthConfig::default();
1484 assert_eq!(config.bcrypt_cost, 12);
1485 assert_eq!(config.session_key, "user_id");
1486 #[cfg(feature = "oauth2")]
1487 assert!(config.oauth2.providers.is_empty());
1488 }
1489
1490 #[cfg(feature = "oauth2")]
1491 #[test]
1492 fn oauth2_config_deserializes_provider_tables() {
1493 let cfg: crate::config::AutumnConfig = toml::from_str(
1494 r#"
1495 [auth.oauth2.github]
1496 client_id = "cid"
1497 client_secret = "secret"
1498 authorize_url = "https://github.com/login/oauth/authorize"
1499 token_url = "https://github.com/login/oauth/access_token"
1500 redirect_uri = "http://localhost:3000/auth/github/callback"
1501 "#,
1502 )
1503 .unwrap();
1504 let provider = cfg.auth.oauth2.providers.get("github").unwrap();
1505 assert_eq!(provider.client_id, "cid");
1506 assert_eq!(provider.scope, "");
1507 assert!(provider.issuer.is_none());
1508 assert!(provider.jwks_url.is_none());
1509 }
1510
1511 #[cfg(feature = "oauth2")]
1512 #[tokio::test]
1513 async fn oauth2_authorize_url_sets_state_and_nonce() {
1514 let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1515 let provider = OAuth2ProviderConfig {
1516 client_id: "cid".into(),
1517 client_secret: "secret".into(),
1518 authorize_url: "https://idp.example/authorize".into(),
1519 token_url: "https://idp.example/token".into(),
1520 userinfo_url: None,
1521 redirect_uri: "http://localhost:3000/callback".into(),
1522 scope: "openid profile".into(),
1523 issuer: None,
1524 jwks_url: None,
1525 };
1526 let url = oauth2_authorize_url(&session, "github", &provider)
1527 .await
1528 .unwrap();
1529 assert!(url.contains("response_type=code"));
1530 assert!(session.get("oauth2:github:state").await.is_some());
1531 assert!(session.get("oauth2:github:nonce").await.is_some());
1532 }
1533
1534 #[cfg(feature = "oauth2")]
1535 #[tokio::test]
1536 async fn oauth2_authorize_url_omits_scope_when_empty() {
1537 let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1538 let provider = OAuth2ProviderConfig {
1539 client_id: "cid".into(),
1540 client_secret: "secret".into(),
1541 authorize_url: "https://idp.example/authorize".into(),
1542 token_url: "https://idp.example/token".into(),
1543 userinfo_url: None,
1544 redirect_uri: "http://localhost:3000/callback".into(),
1545 scope: String::new(),
1546 issuer: None,
1547 jwks_url: None,
1548 };
1549 let url = oauth2_authorize_url(&session, "github", &provider)
1550 .await
1551 .unwrap();
1552 assert!(!url.contains("scope="));
1553 }
1554
1555 #[cfg(feature = "oauth2")]
1556 #[tokio::test]
1557 async fn validate_id_token_requires_oidc_metadata() {
1558 let provider = OAuth2ProviderConfig {
1559 client_id: "cid".into(),
1560 client_secret: "secret".into(),
1561 authorize_url: "https://idp.example/authorize".into(),
1562 token_url: "https://idp.example/token".into(),
1563 userinfo_url: None,
1564 redirect_uri: "http://localhost:3000/callback".into(),
1565 scope: "openid profile".into(),
1566 issuer: None,
1567 jwks_url: None,
1568 };
1569 let err = validate_and_decode_id_token("bad.token.value", &provider)
1570 .await
1571 .unwrap_err();
1572 assert_eq!(err.to_string(), "provider.issuer required for oidc");
1573 }
1574
1575 #[cfg(feature = "oauth2")]
1576 #[test]
1577 fn parse_oauth2_token_response_supports_form_encoded_payload() {
1578 let token = parse_oauth2_token_response(
1579 Some("application/x-www-form-urlencoded"),
1580 "access_token=abc123&token_type=bearer&id_token=xyz789&extra_field=ignored",
1581 )
1582 .unwrap();
1583 assert_eq!(token.access_token, "abc123");
1584 assert_eq!(token.token_type.as_deref(), Some("bearer"));
1585 assert_eq!(token.id_token.as_deref(), Some("xyz789"));
1586 }
1587
1588 #[cfg(feature = "oauth2")]
1589 #[test]
1590 fn parse_oauth2_token_response_fails_without_access_token() {
1591 let err = parse_oauth2_token_response(
1592 Some("application/x-www-form-urlencoded"),
1593 "token_type=bearer&id_token=xyz789",
1594 )
1595 .unwrap_err();
1596 assert_eq!(err.to_string(), "token response missing access_token");
1597 }
1598
1599 #[cfg(feature = "oauth2")]
1600 #[test]
1601 fn extract_subject_allows_userinfo_id_fallback() {
1602 let claims = serde_json::json!({ "id": 42 });
1603 let subject = extract_subject(&claims, IdentitySource::UserInfo).unwrap();
1604 assert_eq!(subject, "42");
1605 }
1606
1607 #[cfg(feature = "oauth2")]
1608 #[tokio::test]
1609 async fn validate_callback_state_preserves_state_on_mismatch() {
1610 let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1614 session
1615 .insert("oauth2:github:state".to_owned(), "real-state".to_owned())
1616 .await;
1617 let bad_callback = OAuth2Callback {
1618 code: "c".into(),
1619 state: "wrong-state".into(),
1620 };
1621 let err = validate_callback_state(&session, "github", &bad_callback)
1622 .await
1623 .unwrap_err();
1624 assert!(err.to_string().contains("state mismatch"));
1625 assert_eq!(
1627 session.get("oauth2:github:state").await.as_deref(),
1628 Some("real-state")
1629 );
1630 }
1631
1632 #[cfg(feature = "oauth2")]
1633 #[tokio::test]
1634 async fn validate_oidc_nonce_rejects_missing_nonce_for_id_token() {
1635 let session = crate::session::Session::new_for_test("s1".into(), HashMap::new());
1638 let claims = serde_json::json!({ "nonce": "any" });
1640 let err = validate_oidc_nonce(&session, "github", &claims, IdentitySource::IdToken)
1641 .await
1642 .unwrap_err();
1643 assert!(err.to_string().contains("nonce missing from session"));
1644 }
1645
1646 #[cfg(feature = "oauth2")]
1647 #[test]
1648 fn extract_subject_requires_sub_for_id_token() {
1649 let claims = serde_json::json!({ "id": "abc" });
1650 let err = extract_subject(&claims, IdentitySource::IdToken).unwrap_err();
1651 assert_eq!(err.to_string(), "missing sub claim");
1652 }
1653
1654 #[test]
1655 fn auth_rejection_is_401() {
1656 let rejection = AuthRejection;
1657 let response = rejection.into_response();
1658 assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
1659 }
1660
1661 #[test]
1662 fn auth_rejection_display() {
1663 assert_eq!(AuthRejection.to_string(), "authentication required");
1664 }
1665
1666 #[tokio::test]
1667 async fn auth_extractor_returns_401_when_no_user() {
1668 use crate::state::AppState;
1669 use axum::Router;
1670 use axum::body::Body;
1671 use axum::routing::get;
1672 use tower::ServiceExt;
1673
1674 #[derive(Clone)]
1675 struct TestUser {
1676 name: String,
1677 }
1678
1679 async fn handler(Auth(user): Auth<TestUser>) -> String {
1680 user.name
1681 }
1682
1683 let state = AppState {
1684 extensions: std::sync::Arc::new(std::sync::RwLock::new(
1685 std::collections::HashMap::new(),
1686 )),
1687 #[cfg(feature = "db")]
1688 pool: None,
1689 #[cfg(feature = "db")]
1690 replica_pool: None,
1691 profile: None,
1692 started_at: std::time::Instant::now(),
1693 health_detailed: false,
1694 probes: crate::probe::ProbeState::ready_for_test(),
1695 metrics: crate::middleware::MetricsCollector::new(),
1696 log_levels: crate::actuator::LogLevels::new("info"),
1697 task_registry: crate::actuator::TaskRegistry::new(),
1698 job_registry: crate::actuator::JobRegistry::new(),
1699 config_props: crate::actuator::ConfigProperties::default(),
1700 #[cfg(feature = "ws")]
1701 channels: crate::channels::Channels::new(32),
1702 #[cfg(feature = "ws")]
1703 shutdown: tokio_util::sync::CancellationToken::new(),
1704 policy_registry: crate::authorization::PolicyRegistry::default(),
1705 forbidden_response: crate::authorization::ForbiddenResponse::default(),
1706 auth_session_key: "user_id".to_owned(),
1707 shared_cache: None,
1708 };
1709
1710 let app = Router::new().route("/", get(handler)).with_state(state);
1711
1712 let response = app
1713 .oneshot(
1714 http::Request::builder()
1715 .uri("/")
1716 .body(Body::empty())
1717 .unwrap(),
1718 )
1719 .await
1720 .unwrap();
1721
1722 assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
1723 }
1724
1725 #[tokio::test]
1726 async fn auth_extractor_returns_user_when_present() {
1727 use crate::state::AppState;
1728 use axum::Router;
1729 use axum::body::Body;
1730 use axum::routing::get;
1731 use tower::ServiceExt;
1732
1733 #[derive(Clone)]
1734 struct TestUser {
1735 name: String,
1736 }
1737
1738 async fn handler(Auth(user): Auth<TestUser>) -> String {
1739 user.name
1740 }
1741
1742 let state = AppState {
1743 extensions: std::sync::Arc::new(std::sync::RwLock::new(
1744 std::collections::HashMap::new(),
1745 )),
1746 #[cfg(feature = "db")]
1747 pool: None,
1748 #[cfg(feature = "db")]
1749 replica_pool: None,
1750 profile: None,
1751 started_at: std::time::Instant::now(),
1752 health_detailed: false,
1753 probes: crate::probe::ProbeState::ready_for_test(),
1754 metrics: crate::middleware::MetricsCollector::new(),
1755 log_levels: crate::actuator::LogLevels::new("info"),
1756 task_registry: crate::actuator::TaskRegistry::new(),
1757 job_registry: crate::actuator::JobRegistry::new(),
1758 config_props: crate::actuator::ConfigProperties::default(),
1759 #[cfg(feature = "ws")]
1760 channels: crate::channels::Channels::new(32),
1761 #[cfg(feature = "ws")]
1762 shutdown: tokio_util::sync::CancellationToken::new(),
1763 policy_registry: crate::authorization::PolicyRegistry::default(),
1764 forbidden_response: crate::authorization::ForbiddenResponse::default(),
1765 auth_session_key: "user_id".to_owned(),
1766 shared_cache: None,
1767 };
1768
1769 let app = Router::new()
1771 .route("/", get(handler))
1772 .layer(axum::middleware::from_fn(
1773 |mut req: axum::extract::Request, next: axum::middleware::Next| async move {
1774 req.extensions_mut().insert(TestUser {
1775 name: "alice".into(),
1776 });
1777 next.run(req).await
1778 },
1779 ))
1780 .with_state(state);
1781
1782 let response = app
1783 .oneshot(
1784 http::Request::builder()
1785 .uri("/")
1786 .body(Body::empty())
1787 .unwrap(),
1788 )
1789 .await
1790 .unwrap();
1791
1792 assert_eq!(response.status(), StatusCode::OK);
1793 let body = axum::body::to_bytes(response.into_body(), usize::MAX)
1794 .await
1795 .unwrap();
1796 assert_eq!(std::str::from_utf8(&body).unwrap(), "alice");
1797 }
1798
1799 #[tokio::test]
1800 async fn require_auth_rejects_unauthenticated() {
1801 use axum::Router;
1802 use axum::body::Body;
1803 use axum::routing::get;
1804 use tower::ServiceExt;
1805
1806 use crate::session::{MemoryStore, SessionConfig, SessionLayer};
1807 use crate::state::AppState;
1808
1809 let state = AppState {
1810 extensions: std::sync::Arc::new(std::sync::RwLock::new(
1811 std::collections::HashMap::new(),
1812 )),
1813 #[cfg(feature = "db")]
1814 pool: None,
1815 #[cfg(feature = "db")]
1816 replica_pool: None,
1817 profile: None,
1818 started_at: std::time::Instant::now(),
1819 health_detailed: false,
1820 probes: crate::probe::ProbeState::ready_for_test(),
1821 metrics: crate::middleware::MetricsCollector::new(),
1822 log_levels: crate::actuator::LogLevels::new("info"),
1823 task_registry: crate::actuator::TaskRegistry::new(),
1824 job_registry: crate::actuator::JobRegistry::new(),
1825 config_props: crate::actuator::ConfigProperties::default(),
1826 #[cfg(feature = "ws")]
1827 channels: crate::channels::Channels::new(32),
1828 #[cfg(feature = "ws")]
1829 shutdown: tokio_util::sync::CancellationToken::new(),
1830 policy_registry: crate::authorization::PolicyRegistry::default(),
1831 forbidden_response: crate::authorization::ForbiddenResponse::default(),
1832 auth_session_key: "user_id".to_owned(),
1833 shared_cache: None,
1834 };
1835
1836 let app = Router::new()
1837 .route("/protected", get(|| async { "secret" }))
1838 .layer(RequireAuth::new("user_id"))
1839 .layer(SessionLayer::new(
1840 MemoryStore::new(),
1841 SessionConfig::default(),
1842 ))
1843 .with_state(state);
1844
1845 let response = app
1846 .oneshot(
1847 http::Request::builder()
1848 .uri("/protected")
1849 .body(Body::empty())
1850 .unwrap(),
1851 )
1852 .await
1853 .unwrap();
1854
1855 assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
1856 }
1857
1858 #[tokio::test]
1861 async fn check_secured_rejects_unauthenticated() {
1862 let session =
1863 crate::session::Session::new_for_test(String::new(), std::collections::HashMap::new());
1864 let result = __check_secured(&session, &[]).await;
1865 assert!(result.is_err());
1866 let err = result.unwrap_err();
1867 assert_eq!(err.status(), StatusCode::UNAUTHORIZED);
1868 assert_eq!(err.to_string(), "authentication required");
1869 }
1870
1871 #[tokio::test]
1872 async fn check_secured_allows_authenticated() {
1873 let data = std::collections::HashMap::from([("user_id".into(), "42".into())]);
1874 let session = crate::session::Session::new_for_test("sess".into(), data);
1875 let result = __check_secured(&session, &[]).await;
1876 assert!(result.is_ok());
1877 }
1878
1879 #[tokio::test]
1880 async fn check_secured_rejects_wrong_role() {
1881 let data = std::collections::HashMap::from([
1882 ("user_id".into(), "42".into()),
1883 ("role".into(), "viewer".into()),
1884 ]);
1885 let session = crate::session::Session::new_for_test("sess".into(), data);
1886 let result = __check_secured(&session, &["admin"]).await;
1887 assert!(result.is_err());
1888 let err = result.unwrap_err();
1889 assert_eq!(err.status(), StatusCode::FORBIDDEN);
1890 assert_eq!(err.to_string(), "insufficient permissions");
1891 }
1892
1893 #[tokio::test]
1894 async fn check_secured_allows_matching_role() {
1895 let data = std::collections::HashMap::from([
1896 ("user_id".into(), "42".into()),
1897 ("role".into(), "admin".into()),
1898 ]);
1899 let session = crate::session::Session::new_for_test("sess".into(), data);
1900 let result = __check_secured(&session, &["admin"]).await;
1901 assert!(result.is_ok());
1902 }
1903
1904 #[tokio::test]
1905 async fn check_secured_allows_any_of_multiple_roles() {
1906 let data = std::collections::HashMap::from([
1907 ("user_id".into(), "42".into()),
1908 ("role".into(), "editor".into()),
1909 ]);
1910 let session = crate::session::Session::new_for_test("sess".into(), data);
1911 let result = __check_secured(&session, &["admin", "editor"]).await;
1912 assert!(result.is_ok());
1913 }
1914
1915 #[tokio::test]
1918 async fn secured_macro_rejects_unauthenticated() {
1919 use axum::Router;
1920 use axum::body::Body;
1921 use axum::routing::get;
1922 use tower::ServiceExt;
1923
1924 use crate::session::{MemoryStore, SessionConfig, SessionLayer};
1925 use crate::state::AppState;
1926
1927 #[autumn_macros::secured]
1928 async fn protected_handler() -> crate::AutumnResult<&'static str> {
1929 Ok("secret")
1930 }
1931
1932 let state = AppState {
1933 extensions: std::sync::Arc::new(std::sync::RwLock::new(
1934 std::collections::HashMap::new(),
1935 )),
1936 #[cfg(feature = "db")]
1937 pool: None,
1938 #[cfg(feature = "db")]
1939 replica_pool: None,
1940 profile: None,
1941 started_at: std::time::Instant::now(),
1942 health_detailed: false,
1943 probes: crate::probe::ProbeState::ready_for_test(),
1944 metrics: crate::middleware::MetricsCollector::new(),
1945 log_levels: crate::actuator::LogLevels::new("info"),
1946 task_registry: crate::actuator::TaskRegistry::new(),
1947 job_registry: crate::actuator::JobRegistry::new(),
1948 config_props: crate::actuator::ConfigProperties::default(),
1949 #[cfg(feature = "ws")]
1950 channels: crate::channels::Channels::new(32),
1951 #[cfg(feature = "ws")]
1952 shutdown: tokio_util::sync::CancellationToken::new(),
1953 policy_registry: crate::authorization::PolicyRegistry::default(),
1954 forbidden_response: crate::authorization::ForbiddenResponse::default(),
1955 auth_session_key: "user_id".to_owned(),
1956 shared_cache: None,
1957 };
1958
1959 let app = Router::new()
1960 .route("/", get(protected_handler))
1961 .layer(SessionLayer::new(
1962 MemoryStore::new(),
1963 SessionConfig::default(),
1964 ))
1965 .with_state(state);
1966
1967 let response = app
1968 .oneshot(
1969 http::Request::builder()
1970 .uri("/")
1971 .body(Body::empty())
1972 .unwrap(),
1973 )
1974 .await
1975 .unwrap();
1976
1977 assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
1978 }
1979
1980 #[tokio::test]
1981 async fn secured_macro_allows_authenticated() {
1982 use axum::Router;
1983 use axum::body::Body;
1984 use axum::routing::get;
1985 use http::header::COOKIE;
1986 use tower::ServiceExt;
1987
1988 use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
1989 use crate::state::AppState;
1990
1991 #[autumn_macros::secured]
1992 async fn protected_handler() -> crate::AutumnResult<&'static str> {
1993 Ok("secret")
1994 }
1995
1996 let store = MemoryStore::new();
1997 store
1998 .save(
1999 "sess1",
2000 std::collections::HashMap::from([("user_id".into(), "42".into())]),
2001 )
2002 .await
2003 .unwrap();
2004
2005 let state = AppState {
2006 extensions: std::sync::Arc::new(std::sync::RwLock::new(
2007 std::collections::HashMap::new(),
2008 )),
2009 #[cfg(feature = "db")]
2010 pool: None,
2011 #[cfg(feature = "db")]
2012 replica_pool: None,
2013 profile: None,
2014 started_at: std::time::Instant::now(),
2015 health_detailed: false,
2016 probes: crate::probe::ProbeState::ready_for_test(),
2017 metrics: crate::middleware::MetricsCollector::new(),
2018 log_levels: crate::actuator::LogLevels::new("info"),
2019 task_registry: crate::actuator::TaskRegistry::new(),
2020 job_registry: crate::actuator::JobRegistry::new(),
2021 config_props: crate::actuator::ConfigProperties::default(),
2022 #[cfg(feature = "ws")]
2023 channels: crate::channels::Channels::new(32),
2024 #[cfg(feature = "ws")]
2025 shutdown: tokio_util::sync::CancellationToken::new(),
2026 policy_registry: crate::authorization::PolicyRegistry::default(),
2027 forbidden_response: crate::authorization::ForbiddenResponse::default(),
2028 auth_session_key: "user_id".to_owned(),
2029 shared_cache: None,
2030 };
2031
2032 let app = Router::new()
2033 .route("/", get(protected_handler))
2034 .layer(SessionLayer::new(store, SessionConfig::default()))
2035 .with_state(state);
2036
2037 let response = app
2038 .oneshot(
2039 http::Request::builder()
2040 .uri("/")
2041 .header(COOKIE, "autumn.sid=sess1")
2042 .body(Body::empty())
2043 .unwrap(),
2044 )
2045 .await
2046 .unwrap();
2047
2048 assert_eq!(response.status(), StatusCode::OK);
2049 let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2050 .await
2051 .unwrap();
2052 assert_eq!(std::str::from_utf8(&body).unwrap(), "secret");
2053 }
2054
2055 #[tokio::test]
2056 async fn secured_macro_honors_configured_auth_session_key() {
2057 use axum::Router;
2058 use axum::body::Body;
2059 use axum::routing::get;
2060 use http::header::COOKIE;
2061 use tower::ServiceExt;
2062
2063 use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2064 use crate::state::AppState;
2065
2066 #[autumn_macros::secured]
2067 async fn account_handler() -> crate::AutumnResult<&'static str> {
2068 Ok("account")
2069 }
2070
2071 let store = MemoryStore::new();
2072 store
2073 .save(
2074 "sess1",
2075 std::collections::HashMap::from([
2076 ("uid".into(), "42".into()),
2077 ("account_id".into(), "42".into()),
2078 ]),
2079 )
2080 .await
2081 .unwrap();
2082
2083 let state = AppState {
2084 extensions: std::sync::Arc::new(std::sync::RwLock::new(
2085 std::collections::HashMap::new(),
2086 )),
2087 #[cfg(feature = "db")]
2088 pool: None,
2089 #[cfg(feature = "db")]
2090 replica_pool: None,
2091 profile: None,
2092 started_at: std::time::Instant::now(),
2093 health_detailed: false,
2094 probes: crate::probe::ProbeState::ready_for_test(),
2095 metrics: crate::middleware::MetricsCollector::new(),
2096 log_levels: crate::actuator::LogLevels::new("info"),
2097 task_registry: crate::actuator::TaskRegistry::new(),
2098 job_registry: crate::actuator::JobRegistry::new(),
2099 config_props: crate::actuator::ConfigProperties::default(),
2100 #[cfg(feature = "ws")]
2101 channels: crate::channels::Channels::new(32),
2102 #[cfg(feature = "ws")]
2103 shutdown: tokio_util::sync::CancellationToken::new(),
2104 policy_registry: crate::authorization::PolicyRegistry::default(),
2105 forbidden_response: crate::authorization::ForbiddenResponse::default(),
2106 auth_session_key: "uid".to_owned(),
2107 shared_cache: None,
2108 };
2109
2110 let app = Router::new()
2111 .route("/account", get(account_handler))
2112 .layer(SessionLayer::new(store, SessionConfig::default()))
2113 .with_state(state);
2114
2115 let response = app
2116 .oneshot(
2117 http::Request::builder()
2118 .uri("/account")
2119 .header(COOKIE, "autumn.sid=sess1")
2120 .body(Body::empty())
2121 .unwrap(),
2122 )
2123 .await
2124 .unwrap();
2125
2126 assert_eq!(response.status(), StatusCode::OK);
2127 let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2128 .await
2129 .unwrap();
2130 assert_eq!(std::str::from_utf8(&body).unwrap(), "account");
2131 }
2132
2133 #[tokio::test]
2134 async fn secured_macro_with_role_rejects_wrong_role() {
2135 use axum::Router;
2136 use axum::body::Body;
2137 use axum::routing::get;
2138 use http::header::COOKIE;
2139 use tower::ServiceExt;
2140
2141 use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2142 use crate::state::AppState;
2143
2144 #[autumn_macros::secured("admin")]
2145 async fn admin_only() -> crate::AutumnResult<&'static str> {
2146 Ok("admin area")
2147 }
2148
2149 let store = MemoryStore::new();
2150 store
2151 .save(
2152 "sess1",
2153 std::collections::HashMap::from([
2154 ("user_id".into(), "42".into()),
2155 ("role".into(), "viewer".into()),
2156 ]),
2157 )
2158 .await
2159 .unwrap();
2160
2161 let state = AppState {
2162 extensions: std::sync::Arc::new(std::sync::RwLock::new(
2163 std::collections::HashMap::new(),
2164 )),
2165 #[cfg(feature = "db")]
2166 pool: None,
2167 #[cfg(feature = "db")]
2168 replica_pool: None,
2169 profile: None,
2170 started_at: std::time::Instant::now(),
2171 health_detailed: false,
2172 probes: crate::probe::ProbeState::ready_for_test(),
2173 metrics: crate::middleware::MetricsCollector::new(),
2174 log_levels: crate::actuator::LogLevels::new("info"),
2175 task_registry: crate::actuator::TaskRegistry::new(),
2176 job_registry: crate::actuator::JobRegistry::new(),
2177 config_props: crate::actuator::ConfigProperties::default(),
2178 #[cfg(feature = "ws")]
2179 channels: crate::channels::Channels::new(32),
2180 #[cfg(feature = "ws")]
2181 shutdown: tokio_util::sync::CancellationToken::new(),
2182 policy_registry: crate::authorization::PolicyRegistry::default(),
2183 forbidden_response: crate::authorization::ForbiddenResponse::default(),
2184 auth_session_key: "user_id".to_owned(),
2185 shared_cache: None,
2186 };
2187
2188 let app = Router::new()
2189 .route("/", get(admin_only))
2190 .layer(SessionLayer::new(store, SessionConfig::default()))
2191 .with_state(state);
2192
2193 let response = app
2194 .oneshot(
2195 http::Request::builder()
2196 .uri("/")
2197 .header(COOKIE, "autumn.sid=sess1")
2198 .body(Body::empty())
2199 .unwrap(),
2200 )
2201 .await
2202 .unwrap();
2203
2204 assert_eq!(response.status(), StatusCode::FORBIDDEN);
2205 }
2206
2207 #[tokio::test]
2208 async fn secured_macro_with_multiple_roles_allows_match() {
2209 use axum::Router;
2210 use axum::body::Body;
2211 use axum::routing::get;
2212 use http::header::COOKIE;
2213 use tower::ServiceExt;
2214
2215 use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2216 use crate::state::AppState;
2217
2218 #[autumn_macros::secured("admin", "editor")]
2219 async fn content_handler() -> crate::AutumnResult<&'static str> {
2220 Ok("content")
2221 }
2222
2223 let store = MemoryStore::new();
2224 store
2225 .save(
2226 "sess1",
2227 std::collections::HashMap::from([
2228 ("user_id".into(), "42".into()),
2229 ("role".into(), "editor".into()),
2230 ]),
2231 )
2232 .await
2233 .unwrap();
2234
2235 let state = AppState {
2236 extensions: std::sync::Arc::new(std::sync::RwLock::new(
2237 std::collections::HashMap::new(),
2238 )),
2239 #[cfg(feature = "db")]
2240 pool: None,
2241 #[cfg(feature = "db")]
2242 replica_pool: None,
2243 profile: None,
2244 started_at: std::time::Instant::now(),
2245 health_detailed: false,
2246 probes: crate::probe::ProbeState::ready_for_test(),
2247 metrics: crate::middleware::MetricsCollector::new(),
2248 log_levels: crate::actuator::LogLevels::new("info"),
2249 task_registry: crate::actuator::TaskRegistry::new(),
2250 job_registry: crate::actuator::JobRegistry::new(),
2251 config_props: crate::actuator::ConfigProperties::default(),
2252 #[cfg(feature = "ws")]
2253 channels: crate::channels::Channels::new(32),
2254 #[cfg(feature = "ws")]
2255 shutdown: tokio_util::sync::CancellationToken::new(),
2256 policy_registry: crate::authorization::PolicyRegistry::default(),
2257 forbidden_response: crate::authorization::ForbiddenResponse::default(),
2258 auth_session_key: "user_id".to_owned(),
2259 shared_cache: None,
2260 };
2261
2262 let app = Router::new()
2263 .route("/", get(content_handler))
2264 .layer(SessionLayer::new(store, SessionConfig::default()))
2265 .with_state(state);
2266
2267 let response = app
2268 .oneshot(
2269 http::Request::builder()
2270 .uri("/")
2271 .header(COOKIE, "autumn.sid=sess1")
2272 .body(Body::empty())
2273 .unwrap(),
2274 )
2275 .await
2276 .unwrap();
2277
2278 assert_eq!(response.status(), StatusCode::OK);
2279 let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2280 .await
2281 .unwrap();
2282 assert_eq!(std::str::from_utf8(&body).unwrap(), "content");
2283 }
2284
2285 #[tokio::test]
2286 async fn require_auth_allows_authenticated() {
2287 use axum::Router;
2288 use axum::body::Body;
2289 use axum::routing::get;
2290 use http::header::COOKIE;
2291 use tower::ServiceExt;
2292
2293 use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};
2294 use crate::state::AppState;
2295
2296 let store = MemoryStore::new();
2297 let mut session_data = std::collections::HashMap::new();
2299 session_data.insert("user_id".into(), "42".into());
2300 store.save("valid-session", session_data).await.unwrap();
2301
2302 let state = AppState {
2303 extensions: std::sync::Arc::new(std::sync::RwLock::new(
2304 std::collections::HashMap::new(),
2305 )),
2306 #[cfg(feature = "db")]
2307 pool: None,
2308 #[cfg(feature = "db")]
2309 replica_pool: None,
2310 profile: None,
2311 started_at: std::time::Instant::now(),
2312 health_detailed: false,
2313 probes: crate::probe::ProbeState::ready_for_test(),
2314 metrics: crate::middleware::MetricsCollector::new(),
2315 log_levels: crate::actuator::LogLevels::new("info"),
2316 task_registry: crate::actuator::TaskRegistry::new(),
2317 job_registry: crate::actuator::JobRegistry::new(),
2318 config_props: crate::actuator::ConfigProperties::default(),
2319 #[cfg(feature = "ws")]
2320 channels: crate::channels::Channels::new(32),
2321 #[cfg(feature = "ws")]
2322 shutdown: tokio_util::sync::CancellationToken::new(),
2323 policy_registry: crate::authorization::PolicyRegistry::default(),
2324 forbidden_response: crate::authorization::ForbiddenResponse::default(),
2325 auth_session_key: "user_id".to_owned(),
2326 shared_cache: None,
2327 };
2328
2329 let app = Router::new()
2330 .route("/protected", get(|| async { "secret" }))
2331 .layer(RequireAuth::new("user_id"))
2332 .layer(SessionLayer::new(store, SessionConfig::default()))
2333 .with_state(state);
2334
2335 let response = app
2336 .oneshot(
2337 http::Request::builder()
2338 .uri("/protected")
2339 .header(COOKIE, "autumn.sid=valid-session")
2340 .body(Body::empty())
2341 .unwrap(),
2342 )
2343 .await
2344 .unwrap();
2345
2346 assert_eq!(response.status(), StatusCode::OK);
2347 let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2348 .await
2349 .unwrap();
2350 assert_eq!(std::str::from_utf8(&body).unwrap(), "secret");
2351 }
2352
2353 #[tokio::test]
2354 async fn require_auth_poll_ready_propagates() {
2355 use std::task::{Context, Poll};
2356 use tower::{Layer, Service};
2357
2358 #[derive(Clone)]
2359 struct MockService {
2360 ready: bool,
2361 poll_count: std::sync::Arc<std::sync::atomic::AtomicUsize>,
2362 }
2363
2364 impl Service<axum::extract::Request> for MockService {
2365 type Response = axum::response::Response;
2366 type Error = std::convert::Infallible;
2367 type Future = std::future::Ready<Result<Self::Response, Self::Error>>;
2368
2369 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
2370 self.poll_count
2371 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
2372 if self.ready {
2373 Poll::Ready(Ok(()))
2374 } else {
2375 Poll::Pending
2376 }
2377 }
2378
2379 fn call(&mut self, _req: axum::extract::Request) -> Self::Future {
2380 std::future::ready(Ok(axum::response::Response::new(axum::body::Body::empty())))
2381 }
2382 }
2383
2384 let layer = RequireAuth::new("user_id");
2385 let poll_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
2386 let mock_service = MockService {
2387 ready: false,
2388 poll_count: poll_count.clone(),
2389 };
2390 let mut service = layer.layer(mock_service);
2391
2392 let waker = futures::task::noop_waker();
2393 let mut cx = Context::from_waker(&waker);
2394
2395 let poll = service.poll_ready(&mut cx);
2397 assert!(poll.is_pending());
2398 assert_eq!(poll_count.load(std::sync::atomic::Ordering::SeqCst), 1);
2399
2400 let mock_service_ready = MockService {
2402 ready: true,
2403 poll_count: poll_count.clone(),
2404 };
2405 let mut service_ready = layer.layer(mock_service_ready);
2406 let poll_ready = service_ready.poll_ready(&mut cx);
2407 assert!(poll_ready.is_ready());
2408 assert_eq!(poll_count.load(std::sync::atomic::Ordering::SeqCst), 2);
2409 }
2410
2411 #[tokio::test]
2412 async fn auth_rejection_into_response() {
2413 let rejection = AuthRejection;
2414 let response = rejection.into_response();
2415 assert_eq!(response.status(), axum::http::StatusCode::UNAUTHORIZED);
2416 let body = axum::body::to_bytes(response.into_body(), usize::MAX)
2417 .await
2418 .unwrap();
2419 let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
2420 assert_eq!(json["status"], 401);
2421 assert_eq!(json["detail"], "authentication required");
2422 assert_eq!(json["code"], "autumn.unauthorized");
2423 }
2424
2425 #[test]
2426 fn test_auth_config_defaults() {
2427 let config = AuthConfig::default();
2428 assert_eq!(config.bcrypt_cost, DEFAULT_BCRYPT_COST);
2429 assert_eq!(config.session_key, "user_id");
2430 }
2431
2432 #[tokio::test]
2433 async fn test_hash_password() {
2434 let test_input = uuid::Uuid::new_v4().to_string();
2435
2436 let hash = super::hash_password(&test_input)
2438 .await
2439 .expect("Failed to hash password");
2440 assert!(hash.starts_with("$2b$"));
2441
2442 let is_valid = super::verify_password(&test_input, &hash)
2444 .await
2445 .expect("Failed to verify password");
2446 assert!(is_valid, "Password should be verified successfully");
2447
2448 let is_invalid = super::verify_password(&uuid::Uuid::new_v4().to_string(), &hash)
2450 .await
2451 .expect("Failed to verify wrong password");
2452 assert!(!is_invalid, "Wrong password should not be verified");
2453 }
2454
2455 #[tokio::test]
2456 async fn test_hash_password_empty() {
2457 let test_input = String::new();
2458 let hash = super::hash_password(&test_input)
2459 .await
2460 .expect("Failed to hash empty password");
2461 assert!(hash.starts_with("$2b$"));
2462
2463 let is_valid = super::verify_password(&test_input, &hash)
2464 .await
2465 .expect("Failed to verify empty password");
2466 assert!(is_valid, "Empty password should be verified successfully");
2467 }
2468
2469 #[tokio::test]
2470 async fn test_hash_password_long() {
2471 let test_input = "a".repeat(100);
2473 let hash = super::hash_password(&test_input)
2474 .await
2475 .expect("Failed to hash long password");
2476 assert!(hash.starts_with("$2b$"));
2477
2478 let is_valid = super::verify_password(&test_input, &hash)
2479 .await
2480 .expect("Failed to verify long password");
2481 assert!(is_valid, "Long password should be verified successfully");
2482 }
2483
2484 #[tokio::test]
2485 async fn test_hash_password_unicode() {
2486 let test_input = format!("{}🚀my_secrët_passwörd🔑", uuid::Uuid::new_v4());
2488 let hash = super::hash_password(&test_input)
2489 .await
2490 .expect("Failed to hash unicode password");
2491 assert!(hash.starts_with("$2b$"));
2492
2493 let is_valid = super::verify_password(&test_input, &hash)
2494 .await
2495 .expect("Failed to verify unicode password");
2496 assert!(is_valid, "Unicode password should be verified successfully");
2497 }
2498
2499 #[tokio::test]
2500 async fn test_verify_password_invalid_hash() {
2501 let test_input = uuid::Uuid::new_v4().to_string();
2503
2504 let result = super::verify_password(&test_input, "invalid_hash_string").await;
2506 assert!(result.is_err() || !result.unwrap());
2507
2508 let result2 = super::verify_password(&test_input, "$2b$04$").await;
2510 assert!(result2.is_err() || !result2.unwrap());
2511 }
2512}
2513
/// Tests for bearer API tokens: digest hashing, the in-memory store, the
/// `RequireApiToken` layer, and the `ApiToken` extractor.
#[cfg(test)]
mod api_token_tests {
    use std::sync::Arc;

    use axum::body::Body;
    use axum::response::Response;
    use http::StatusCode;
    use tower::ServiceExt;

    use super::{
        ApiToken, ApiTokenStore, InMemoryApiTokenStore, RequireApiToken, hash_api_token,
        issue_api_token, revoke_api_token,
    };

    /// Store double whose every operation fails with a "service unavailable"
    /// error, used to check that the layer surfaces store failures instead of
    /// turning them into 401s.
    struct FailingApiTokenStore;

    impl ApiTokenStore for FailingApiTokenStore {
        fn issue<'a>(
            &'a self,
            _principal_id: &'a str,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = crate::AutumnResult<String>> + Send + 'a>,
        > {
            Box::pin(async {
                Err(crate::AutumnError::service_unavailable_msg(
                    "api token store unavailable",
                ))
            })
        }

        fn verify<'a>(
            &'a self,
            _raw_token: &'a str,
        ) -> std::pin::Pin<
            Box<dyn std::future::Future<Output = crate::AutumnResult<Option<String>>> + Send + 'a>,
        > {
            Box::pin(async {
                Err(crate::AutumnError::service_unavailable_msg(
                    "api token store unavailable",
                ))
            })
        }

        fn revoke<'a>(
            &'a self,
            _raw_token: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = crate::AutumnResult<()>> + Send + 'a>>
        {
            Box::pin(async {
                Err(crate::AutumnError::service_unavailable_msg(
                    "api token store unavailable",
                ))
            })
        }
    }

    // ---- request / response helpers -------------------------------------------

    /// Router with a single `GET /` route answering `"ok"`; callers add layers.
    fn ok_router() -> axum::Router {
        axum::Router::new().route("/", axum::routing::get(|| async { "ok" }))
    }

    /// Builds a body-less `GET` request for `uri`. When `authorization` is
    /// `Some`, its value is sent verbatim as the `Authorization` header.
    fn get_request(uri: &str, authorization: Option<&str>) -> http::Request<Body> {
        let mut builder = http::Request::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header(http::header::AUTHORIZATION, value);
        }
        builder
            .body(Body::empty())
            .expect("test request must be well-formed")
    }

    /// Drives one request through `app` and returns its response.
    async fn send(app: axum::Router, uri: &str, authorization: Option<&str>) -> Response {
        app.oneshot(get_request(uri, authorization))
            .await
            .expect("axum routers are infallible services")
    }

    /// The response's `Content-Type` header as text, if the header is present.
    fn content_type(response: &Response) -> Option<&str> {
        response
            .headers()
            .get(http::header::CONTENT_TYPE)
            .map(|value| value.to_str().unwrap_or_default())
    }

    /// Reads the whole response body and parses it as JSON.
    async fn body_json(response: Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("response body must be readable");
        serde_json::from_slice(&bytes).expect("response body must be valid JSON")
    }

    /// Reads the whole response body as UTF-8 text.
    async fn body_text(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("response body must be readable");
        String::from_utf8(bytes.to_vec()).expect("response body must be UTF-8")
    }

    // ---- hash_api_token -------------------------------------------------------

    #[test]
    fn hash_api_token_is_deterministic() {
        let h1 = hash_api_token("abc123");
        let h2 = hash_api_token("abc123");
        assert_eq!(h1, h2);
    }

    #[test]
    fn hash_api_token_produces_64_char_hex() {
        let hash = hash_api_token("any_raw_token");
        assert_eq!(hash.len(), 64, "SHA-256 hex must be 64 chars");
        // `char::is_ascii_hexdigit` also accepts `A-F`; pin the lowercase form
        // the message promises so the stored digest format cannot drift.
        assert!(
            hash.chars().all(|c| matches!(c, '0'..='9' | 'a'..='f')),
            "hash must be lowercase hex digits"
        );
    }

    #[test]
    fn hash_api_token_differs_from_input() {
        let raw = "my_raw_token";
        assert_ne!(hash_api_token(raw), raw);
    }

    #[test]
    fn hash_api_token_different_inputs_produce_different_hashes() {
        assert_ne!(hash_api_token("token_a"), hash_api_token("token_b"));
    }

    // ---- InMemoryApiTokenStore ------------------------------------------------

    #[tokio::test]
    async fn in_memory_store_issue_returns_unique_tokens() {
        let store = InMemoryApiTokenStore::default();
        let t1 = store.issue("user:1").await.unwrap();
        let t2 = store.issue("user:1").await.unwrap();
        assert_ne!(t1, t2, "each issued token must be unique");
        assert!(t1.len() >= 32, "token must have sufficient entropy");
    }

    #[tokio::test]
    async fn in_memory_store_verify_returns_principal_for_valid_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:42").await.unwrap();
        let principal = store.verify(&raw).await.unwrap();
        assert_eq!(principal, Some("user:42".to_owned()));
    }

    #[tokio::test]
    async fn in_memory_store_verify_returns_none_for_unknown_token() {
        let store = InMemoryApiTokenStore::default();
        let result = store.verify("not_a_real_token").await.unwrap();
        assert_eq!(result, None);
    }

    #[tokio::test]
    async fn in_memory_store_revoke_invalidates_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:7").await.unwrap();
        assert_eq!(
            store.verify(&raw).await.unwrap(),
            Some("user:7".to_owned()),
            "token must be valid before revoking"
        );
        store.revoke(&raw).await.unwrap();
        assert_eq!(store.verify(&raw).await.unwrap(), None);
    }

    #[tokio::test]
    async fn in_memory_store_raw_token_not_stored_verbatim() {
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:1").await.unwrap();
        // The store's internal map is not reachable from here. This only checks,
        // through the trait, that a token with extra trailing characters does
        // not verify, so verification requires an exact token.
        let tampered = format!("{raw}x");
        assert_eq!(store.verify(&tampered).await.unwrap(), None);
    }

    #[tokio::test]
    async fn issue_api_token_helper_issues_verifiable_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = issue_api_token(&store, "user:5").await.unwrap();
        assert_eq!(store.verify(&raw).await.unwrap(), Some("user:5".to_owned()));
    }

    #[tokio::test]
    async fn revoke_api_token_helper_revokes_token() {
        let store = InMemoryApiTokenStore::default();
        let raw = store.issue("user:6").await.unwrap();
        revoke_api_token(&store, &raw).await.unwrap();
        assert_eq!(store.verify(&raw).await.unwrap(), None);
    }

    // ---- RequireApiToken layer -------------------------------------------------

    #[tokio::test]
    async fn require_api_token_rejects_missing_authorization_header() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router().layer(RequireApiToken::new(store));

        let response = send(app, "/", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_rejects_non_bearer_scheme() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router().layer(RequireApiToken::new(store));

        let response = send(app, "/", Some("Basic dXNlcjpwYXNz")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_rejects_unknown_bearer_token() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router().layer(RequireApiToken::new(store));

        let response = send(app, "/", Some("Bearer unknown_token_xyz")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_propagates_store_verify_errors() {
        let store = Arc::new(FailingApiTokenStore);
        let app = ok_router().layer(RequireApiToken::new(store));

        let response = send(app, "/", Some("Bearer valid_client_token")).await;

        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
        assert_eq!(content_type(&response), Some("application/problem+json"));

        let json = body_json(response).await;
        assert_eq!(json["status"], 503);
        assert_eq!(json["code"], "autumn.service_unavailable");
        assert_eq!(json["detail"], "api token store unavailable");
    }

    #[tokio::test]
    async fn require_api_token_allows_valid_bearer_token() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:1").await.unwrap();
        let app = ok_router().layer(RequireApiToken::new(Arc::clone(&store)));

        let authorization = format!("Bearer {raw}");
        let response = send(app, "/", Some(authorization.as_str())).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_api_token_accepts_case_insensitive_bearer_scheme() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:1").await.unwrap();

        for scheme in ["bearer", "bEaReR"] {
            let app = ok_router().layer(RequireApiToken::new(Arc::clone(&store)));

            let authorization = format!("{scheme} {raw}");
            let response = send(app, "/", Some(authorization.as_str())).await;

            assert_eq!(response.status(), StatusCode::OK, "scheme {scheme}");
        }
    }

    #[tokio::test]
    async fn require_api_token_rejects_revoked_token() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:1").await.unwrap();
        store.revoke(&raw).await.unwrap();
        let app = ok_router().layer(RequireApiToken::new(Arc::clone(&store)));

        let authorization = format!("Bearer {raw}");
        let response = send(app, "/", Some(authorization.as_str())).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_api_token_401_response_has_problem_details() {
        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = ok_router().layer(RequireApiToken::new(store));

        let response = send(app, "/", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(content_type(&response), Some("application/problem+json"));
        let json = body_json(response).await;
        assert_eq!(json["status"], 401);
        assert_eq!(json["code"], "autumn.unauthorized");
        assert!(json["detail"].as_str().is_some());
    }

    #[tokio::test]
    async fn require_api_token_401_problem_details_include_request_context() {
        use crate::middleware::RequestIdLayer;

        let store = Arc::new(InMemoryApiTokenStore::default());
        let app = axum::Router::new()
            .route("/api/private", axum::routing::get(|| async { "ok" }))
            .layer(RequireApiToken::new(store))
            .layer(RequestIdLayer);

        let response = send(app, "/api/private", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let request_id = response
            .headers()
            .get("x-request-id")
            .and_then(|value| value.to_str().ok())
            .expect("request id header should be present")
            .to_owned();
        let json = body_json(response).await;
        assert_eq!(json["request_id"], request_id);
        assert_eq!(json["instance"], "/api/private");
    }

    // ---- ApiToken extractor ----------------------------------------------------

    #[tokio::test]
    async fn api_token_extractor_yields_principal_id_to_handler() {
        async fn handler(ApiToken(principal): ApiToken) -> String {
            principal
        }

        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("user:99").await.unwrap();
        let app = axum::Router::new()
            .route("/", axum::routing::get(handler))
            .layer(RequireApiToken::new(Arc::clone(&store)));

        let authorization = format!("Bearer {raw}");
        let response = send(app, "/", Some(authorization.as_str())).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_text(response).await, "user:99");
    }

    #[tokio::test]
    async fn api_token_extractor_rejects_when_no_principal_in_extensions() {
        async fn handler(ApiToken(principal): ApiToken) -> String {
            principal
        }

        // No `RequireApiToken` layer, so nothing inserts a principal for the extractor.
        let app = axum::Router::new().route("/", axum::routing::get(handler));

        let response = send(app, "/", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    // ---- composition with session auth -----------------------------------------

    #[tokio::test]
    async fn api_token_and_session_auth_compose_without_conflict() {
        use crate::session::{MemoryStore, SessionConfig, SessionLayer, SessionStore};

        async fn api_handler(ApiToken(principal): ApiToken) -> String {
            principal
        }

        let store = Arc::new(InMemoryApiTokenStore::default());
        let raw = store.issue("api_user").await.unwrap();

        let session_store = MemoryStore::new();
        session_store
            .save(
                "sess1",
                std::collections::HashMap::from([("user_id".into(), "session_user".into())]),
            )
            .await
            .unwrap();

        let app = axum::Router::new()
            .route(
                "/api",
                axum::routing::get(api_handler).layer(RequireApiToken::new(Arc::clone(&store))),
            )
            .layer(SessionLayer::new(session_store, SessionConfig::default()));

        let authorization = format!("Bearer {raw}");
        let response = send(app, "/api", Some(authorization.as_str())).await;

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_text(response).await, "api_user");
    }

    // ---- tower::Service plumbing ---------------------------------------------

    /// `RequireApiToken`'s service must report the wrapped service's readiness
    /// rather than claiming to be ready unconditionally.
    #[tokio::test]
    async fn require_api_token_poll_ready_propagates_to_inner() {
        use std::task::{Context, Poll};
        use tower::{Layer, Service};

        /// Inner service whose readiness is fixed by the `ready` flag.
        #[derive(Clone)]
        struct MockService {
            ready: bool,
        }

        impl tower::Service<axum::extract::Request> for MockService {
            type Response = axum::response::Response;
            type Error = std::convert::Infallible;
            type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

            fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready {
                    Poll::Ready(Ok(()))
                } else {
                    Poll::Pending
                }
            }

            fn call(&mut self, _req: axum::extract::Request) -> Self::Future {
                std::future::ready(Ok(axum::response::Response::new(axum::body::Body::empty())))
            }
        }

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let store = Arc::new(InMemoryApiTokenStore::default());
        let layer = RequireApiToken::new(store);
        let mut svc = layer.layer(MockService { ready: false });
        assert!(svc.poll_ready(&mut cx).is_pending());

        let store2 = Arc::new(InMemoryApiTokenStore::default());
        let layer2 = RequireApiToken::new(store2);
        let mut svc2 = layer2.layer(MockService { ready: true });
        assert!(svc2.poll_ready(&mut cx).is_ready());
    }
}