solo_api/auth/
middleware.rs1use super::{
15 AuthConfig, AuthError, bearer::BearerValidator, oidc::OidcConfig, oidc::OidcValidator,
16};
17use axum::extract::{Request, State};
18use axum::http::{HeaderValue, StatusCode, header};
19use axum::middleware::Next;
20use axum::response::{IntoResponse, Response};
21use solo_core::TenantId;
22use std::sync::Arc;
23
24#[derive(Debug, Clone)]
28pub enum AuthValidator {
29 Bearer(BearerValidator),
30 Oidc(OidcValidator),
31}
32
33impl AuthValidator {
34 pub fn from_config(config: &AuthConfig, default_tenant: TenantId) -> Self {
48 match config {
49 AuthConfig::Bearer { token } => {
50 if token.is_empty() {
51 panic!(
52 "auth: bearer mode requires a non-empty token in [auth].token. \
53 Set a real token or remove the [auth] block to use \
54 --bearer-token-file instead."
55 );
56 }
57 Self::Bearer(BearerValidator::new(token.clone(), default_tenant))
58 }
59 AuthConfig::Oidc {
60 discovery_url,
61 audience,
62 tenant_claim_name,
63 } => Self::Oidc(OidcValidator::new(OidcConfig {
64 discovery_url: discovery_url.clone(),
65 audience: audience.clone(),
66 tenant_claim_name: tenant_claim_name.clone(),
67 })),
68 }
69 }
70}
71
72pub async fn auth_middleware(
76 State(validator): State<Arc<AuthValidator>>,
77 mut req: Request,
78 next: Next,
79) -> Response {
80 let auth_header = req
81 .headers()
82 .get(header::AUTHORIZATION)
83 .and_then(|h| h.to_str().ok())
84 .map(|s| s.to_string());
85
86 let principal_result = match validator.as_ref() {
87 AuthValidator::Bearer(v) => v.validate(auth_header.as_deref()),
88 AuthValidator::Oidc(v) => v.validate(auth_header.as_deref()).await,
89 };
90
91 let principal = match principal_result {
92 Ok(p) => p,
93 Err(e) => return error_response(&e),
94 };
95
96 req.extensions_mut().insert(principal);
97 next.run(req).await
98}
99
100fn error_response(err: &AuthError) -> Response {
104 let status = match err {
105 AuthError::MissingAuthHeader
106 | AuthError::MalformedAuthHeader
107 | AuthError::InvalidBearer
108 | AuthError::InvalidOidcToken { .. } => StatusCode::UNAUTHORIZED,
109 AuthError::MissingTenantClaim { .. } | AuthError::InvalidTenantClaim(_) => {
110 StatusCode::FORBIDDEN
111 }
112 AuthError::Discovery(_) | AuthError::Jwks(_) => StatusCode::INTERNAL_SERVER_ERROR,
113 };
114 let body = axum::Json(serde_json::json!({
115 "error": err.to_string(),
116 "status": status.as_u16(),
117 }));
118 let mut resp = (status, body).into_response();
119 if status == StatusCode::UNAUTHORIZED {
120 resp.headers_mut().insert(
121 axum::http::header::WWW_AUTHENTICATE,
122 HeaderValue::from_static(r#"Bearer realm="solo""#),
123 );
124 }
125 resp
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::{AuthConfig, AuthenticatedPrincipal};
    use axum::Extension;
    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Test handler: echoes the principal the middleware inserted.
    async fn echo_principal(Extension(p): Extension<AuthenticatedPrincipal>) -> String {
        format!("subject={};tenant={:?}", p.subject, p.tenant_claim)
    }

    /// Wraps `/echo` in `auth_middleware` driven by `validator`.
    fn router_with_validator(validator: Arc<AuthValidator>) -> Router {
        Router::new().route("/echo", get(echo_principal)).layer(
            axum::middleware::from_fn_with_state(validator, auth_middleware),
        )
    }

    /// Router guarded by a bearer validator configured with `token`.
    fn bearer_router(token: &str) -> Router {
        let cfg = AuthConfig::Bearer {
            token: token.to_string(),
        };
        let validator = AuthValidator::from_config(&cfg, TenantId::default_tenant());
        router_with_validator(Arc::new(validator))
    }

    /// The response's `WWW-Authenticate` value, or `""` if absent/unreadable.
    fn www_authenticate(resp: &Response) -> &str {
        resp.headers()
            .get("www-authenticate")
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
    }

    #[tokio::test]
    async fn bearer_inserts_principal_into_extension() {
        let req = Request::builder()
            .uri("/echo")
            .header("authorization", "Bearer abc")
            .body(Body::empty())
            .unwrap();
        let resp = bearer_router("abc").oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let body = resp.into_body().collect().await.unwrap().to_bytes();
        let s = String::from_utf8_lossy(&body);
        assert!(s.starts_with("subject=bearer;"), "got {s}");
    }

    #[tokio::test]
    async fn bearer_missing_returns_401_with_www_authenticate() {
        let req = Request::builder().uri("/echo").body(Body::empty()).unwrap();
        let resp = bearer_router("abc").oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let www = www_authenticate(&resp);
        assert!(www.starts_with("Bearer"), "got {www}");
    }

    #[tokio::test]
    async fn bearer_wrong_token_returns_401() {
        let req = Request::builder()
            .uri("/echo")
            .header("authorization", "Bearer wrong")
            .body(Body::empty())
            .unwrap();
        let resp = bearer_router("abc").oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        let www = www_authenticate(&resp);
        assert!(www.starts_with("Bearer"), "got {www}");
    }

    #[tokio::test]
    async fn bearer_non_ascii_header_returns_401() {
        // 0xFF is a legal header byte but not visible ASCII, so the header
        // cannot be read as &str and must never authenticate.
        let req = Request::builder()
            .uri("/echo")
            .header("authorization", &b"Bearer \xff"[..])
            .body(Body::empty())
            .unwrap();
        let resp = bearer_router("abc").oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    #[should_panic(expected = "requires a non-empty token")]
    fn bearer_empty_token_panics() {
        let cfg = AuthConfig::Bearer {
            token: String::new(),
        };
        let _ = AuthValidator::from_config(&cfg, TenantId::default_tenant());
    }
}