// File: mcp_runner/sse_proxy/auth.rs

//! Authentication middleware for the SSE proxy.
//!
//! This module provides bearer-token authentication for the Actix Web-based SSE proxy,
//! rejecting requests that do not present the configured token.

use crate::config::SSEProxyConfig;
use crate::error::Error;
use crate::sse_proxy::actix_error::ApiError;

use actix_web::{
    Error as ActixError,
    dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready},
    http::{Method, header},
};
use futures::future::{LocalBoxFuture, Ready, ready};
use std::sync::Arc;
use tracing;
17
18/// Authentication middleware factory
19pub struct Authentication {
20    config: Arc<SSEProxyConfig>,
21}
22
23impl Authentication {
24    /// Create a new Authentication middleware
25    pub fn new(config: Arc<SSEProxyConfig>) -> Self {
26        Self { config }
27    }
28}
29
30impl<S, B> Transform<S, ServiceRequest> for Authentication
31where
32    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
33    B: 'static,
34{
35    type Response = ServiceResponse<B>;
36    type Error = ActixError;
37    type Transform = AuthenticationMiddleware<S>;
38    type InitError = ();
39    type Future = Ready<Result<Self::Transform, Self::InitError>>;
40
41    fn new_transform(&self, service: S) -> Self::Future {
42        ready(Ok(AuthenticationMiddleware {
43            service,
44            config: self.config.clone(),
45        }))
46    }
47}
48
49/// Authentication middleware implementation
50pub struct AuthenticationMiddleware<S> {
51    service: S,
52    config: Arc<SSEProxyConfig>,
53}
54
55impl<S, B> Service<ServiceRequest> for AuthenticationMiddleware<S>
56where
57    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = ActixError> + 'static,
58    B: 'static,
59{
60    type Response = ServiceResponse<B>;
61    type Error = ActixError;
62    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
63
64    forward_ready!(service);
65
66    fn call(&self, req: ServiceRequest) -> Self::Future {
67        // Skip authentication for OPTIONS requests (CORS preflight)
68        if req.method() == "OPTIONS" {
69            let fut = self.service.call(req);
70            return Box::pin(async move {
71                let res = fut.await?;
72                Ok(res)
73            });
74        }
75
76        // Check if authentication is required
77        if let Some(auth_config) = &self.config.authenticate {
78            if let Some(bearer_config) = &auth_config.bearer {
79                let expected_token = &bearer_config.token;
80
81                // Extract the Authorization header
82                if let Some(auth_header) = req.headers().get("Authorization") {
83                    if let Ok(auth_str) = auth_header.to_str() {
84                        if let Some(token) = auth_str.strip_prefix("Bearer ") {
85                            if token == expected_token {
86                                // Token is valid, proceed with the request
87                                let fut = self.service.call(req);
88                                return Box::pin(async move {
89                                    let res = fut.await?;
90                                    Ok(res)
91                                });
92                            }
93                        }
94                    }
95                }
96
97                // Invalid or missing token
98                tracing::warn!("Authentication failed: Invalid or missing bearer token");
99                return Box::pin(async move {
100                    // Convert to ApiError first, then into ActixError
101                    Err(ApiError::from(Error::Unauthorized(
102                        "Invalid or missing bearer token".to_string(),
103                    ))
104                    .into())
105                });
106            }
107        }
108
109        // No authentication required, proceed with the request
110        let fut = self.service.call(req);
111        Box::pin(async move {
112            let res = fut.await?;
113            Ok(res)
114        })
115    }
116}