avl_console/middleware/
auth.rs1use crate::{error::ConsoleError, state::AppState};
4use axum::{
5 extract::{Request, State},
6 http::StatusCode,
7 middleware::Next,
8 response::{IntoResponse, Response},
9};
10use std::sync::Arc;
11use tower::{Layer, Service};
12
13#[derive(Clone)]
15pub struct AuthLayer {
16 state: Arc<AppState>,
17}
18
19impl AuthLayer {
20 pub fn new(state: Arc<AppState>) -> Self {
21 Self { state }
22 }
23}
24
25impl<S> Layer<S> for AuthLayer {
26 type Service = AuthMiddleware<S>;
27
28 fn layer(&self, inner: S) -> Self::Service {
29 AuthMiddleware {
30 inner,
31 state: self.state.clone(),
32 }
33 }
34}
35
36#[derive(Clone)]
37pub struct AuthMiddleware<S> {
38 inner: S,
39 state: Arc<AppState>,
40}
41
42impl<S> Service<Request> for AuthMiddleware<S>
43where
44 S: Service<Request, Response = Response> + Send + 'static,
45 S::Future: Send + 'static,
46{
47 type Response = S::Response;
48 type Error = S::Error;
49 type Future = std::pin::Pin<
50 Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
51 >;
52
53 fn poll_ready(
54 &mut self,
55 cx: &mut std::task::Context<'_>,
56 ) -> std::task::Poll<Result<(), Self::Error>> {
57 self.inner.poll_ready(cx)
58 }
59
60 fn call(&mut self, req: Request) -> Self::Future {
61 let state = self.state.clone();
62
63 let path = req.uri().path().to_string(); let cookies = req
66 .headers()
67 .get("cookie")
68 .and_then(|v| v.to_str().ok())
69 .unwrap_or("")
70 .to_string();
71
72 let future = self.inner.call(req);
73
74 Box::pin(async move {
75 if path.starts_with("/static") || path == "/login" || path == "/health" {
76 return future.await;
77 }
78
79 let session_id = extract_session_id(&cookies);
80
81 if let Some(sid) = session_id {
82 if let Some(_user_id) = state.get_session(&sid).await {
83 return future.await;
85 }
86 }
87
88 Ok(ConsoleError::Authentication("Session expired or invalid".to_string())
90 .into_response())
91 })
92 }
93}
94
/// Returns the value of the `avl_session` cookie from a Cookie header string.
///
/// `cookies` is split on `;`, and each pair is trimmed and split at its first
/// `=`. Pairs without `=` are ignored, and the cookie name is compared
/// exactly. The first NON-EMPTY `avl_session` value is returned. An empty
/// `avl_session=` is skipped rather than yielding `Some("")`, which would only
/// trigger a useless session lookup. Returns `None` if no usable value exists.
fn extract_session_id(cookies: &str) -> Option<String> {
    cookies
        .split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find_map(|(name, value)| {
            (name == "avl_session" && !value.is_empty()).then(|| value.to_string())
        })
}
107
108pub async fn auth_middleware(
110 State(state): State<Arc<AppState>>,
111 req: Request,
112 next: Next,
113) -> Result<Response, StatusCode> {
114 let path = req.uri().path();
116 if path.starts_with("/static") || path == "/login" || path == "/health" {
117 return Ok(next.run(req).await);
118 }
119
120 let cookies = req
122 .headers()
123 .get("cookie")
124 .and_then(|v| v.to_str().ok())
125 .unwrap_or("");
126
127 let session_id = extract_session_id(cookies);
128
129 if let Some(sid) = session_id {
130 if let Some(_user_id) = state.get_session(&sid).await {
131 return Ok(next.run(req).await);
132 }
133 }
134
135 Err(StatusCode::UNAUTHORIZED)
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_session_id() {
        // (cookie header, expected session ID)
        let cases = [
            ("avl_session=abc123; other=value", Some("abc123")),
            ("other=value", None),
        ];

        for (header, expected) in cases {
            assert_eq!(
                extract_session_id(header),
                expected.map(str::to_string),
                "cookie header: {header:?}"
            );
        }
    }
}