avl_console/middleware/auth.rs

1//! Authentication middleware
2
3use crate::{error::ConsoleError, state::AppState};
4use axum::{
5    extract::{Request, State},
6    http::StatusCode,
7    middleware::Next,
8    response::{IntoResponse, Response},
9};
10use std::sync::Arc;
11use tower::{Layer, Service};
12
13/// Authentication layer
14#[derive(Clone)]
15pub struct AuthLayer {
16    state: Arc<AppState>,
17}
18
19impl AuthLayer {
20    pub fn new(state: Arc<AppState>) -> Self {
21        Self { state }
22    }
23}
24
25impl<S> Layer<S> for AuthLayer {
26    type Service = AuthMiddleware<S>;
27
28    fn layer(&self, inner: S) -> Self::Service {
29        AuthMiddleware {
30            inner,
31            state: self.state.clone(),
32        }
33    }
34}
35
36#[derive(Clone)]
37pub struct AuthMiddleware<S> {
38    inner: S,
39    state: Arc<AppState>,
40}
41
42impl<S> Service<Request> for AuthMiddleware<S>
43where
44    S: Service<Request, Response = Response> + Send + 'static,
45    S::Future: Send + 'static,
46{
47    type Response = S::Response;
48    type Error = S::Error;
49    type Future = std::pin::Pin<
50        Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>,
51    >;
52
53    fn poll_ready(
54        &mut self,
55        cx: &mut std::task::Context<'_>,
56    ) -> std::task::Poll<Result<(), Self::Error>> {
57        self.inner.poll_ready(cx)
58    }
59
60    fn call(&mut self, req: Request) -> Self::Future {
61        let state = self.state.clone();
62
63        // Extract path before moving req
64        let path = req.uri().path().to_string();        // Extract session cookie before moving req
65        let cookies = req
66            .headers()
67            .get("cookie")
68            .and_then(|v| v.to_str().ok())
69            .unwrap_or("")
70            .to_string();
71
72        let future = self.inner.call(req);
73
74        Box::pin(async move {
75            if path.starts_with("/static") || path == "/login" || path == "/health" {
76                return future.await;
77            }
78
79            let session_id = extract_session_id(&cookies);
80
81            if let Some(sid) = session_id {
82                if let Some(_user_id) = state.get_session(&sid).await {
83                    // User is authenticated
84                    return future.await;
85                }
86            }
87
88            // Not authenticated - return 401
89            Ok(ConsoleError::Authentication("Session expired or invalid".to_string())
90                .into_response())
91        })
92    }
93}
94
/// Returns the value of the `avl_session` cookie from a `Cookie` header value.
///
/// `cookies` is a `;`-separated list of `name=value` pairs (RFC 6265 §4.2
/// syntax). Parsing works as follows:
/// - whitespace around each pair is ignored;
/// - the name must equal `avl_session` exactly;
/// - everything after the first `=` is the value.
///
/// The first matching pair with a non-empty value is returned. An empty value
/// (`avl_session=`) can never name a session, so it is treated as absent.
fn extract_session_id(cookies: &str) -> Option<String> {
    cookies
        .split(';')
        .filter_map(|pair| pair.trim().split_once('='))
        .find(|&(name, value)| name == "avl_session" && !value.is_empty())
        .map(|(_, value)| value.to_owned())
}
107
108/// Extract authenticated user from request
109pub async fn auth_middleware(
110    State(state): State<Arc<AppState>>,
111    req: Request,
112    next: Next,
113) -> Result<Response, StatusCode> {
114    // Skip auth for public routes
115    let path = req.uri().path();
116    if path.starts_with("/static") || path == "/login" || path == "/health" {
117        return Ok(next.run(req).await);
118    }
119
120    // Extract session cookie
121    let cookies = req
122        .headers()
123        .get("cookie")
124        .and_then(|v| v.to_str().ok())
125        .unwrap_or("");
126
127    let session_id = extract_session_id(cookies);
128
129    if let Some(sid) = session_id {
130        if let Some(_user_id) = state.get_session(&sid).await {
131            return Ok(next.run(req).await);
132        }
133    }
134
135    Err(StatusCode::UNAUTHORIZED)
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_session_id() {
        let cookies = "avl_session=abc123; other=value";
        assert_eq!(extract_session_id(cookies), Some("abc123".to_string()));

        let cookies = "other=value";
        assert_eq!(extract_session_id(cookies), None);
    }

    /// The session cookie may appear anywhere in the list, with surrounding spaces.
    #[test]
    fn extract_session_id_finds_cookie_in_any_position() {
        assert_eq!(
            extract_session_id("other=value;avl_session=xyz"),
            Some("xyz".to_string())
        );
        assert_eq!(
            extract_session_id("a=1;   avl_session=xyz  ; b=2"),
            Some("xyz".to_string())
        );
    }

    /// Only a pair named exactly `avl_session` with an `=` counts.
    #[test]
    fn extract_session_id_requires_exact_name() {
        assert_eq!(extract_session_id("xavl_session=1"), None);
        assert_eq!(extract_session_id("avl_session_old=1"), None);
        assert_eq!(extract_session_id("avl_session"), None);
    }

    /// Only the first `=` separates name from value.
    #[test]
    fn extract_session_id_keeps_equals_in_value() {
        assert_eq!(extract_session_id("avl_session=a=b"), Some("a=b".to_string()));
    }

    /// A request with an empty cookie header has no session.
    #[test]
    fn extract_session_id_handles_empty_header() {
        assert_eq!(extract_session_id(""), None);
    }
}