// mcp_core/sse/middleware.rs

use actix_web::{
    body::EitherBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    http::header,
    Error, HttpResponse,
};
use futures::future::LocalBoxFuture;
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::future::{ready, Ready};

11#[derive(Debug, Serialize, Deserialize)]
12pub struct Claims {
13    pub exp: usize,
14    pub iat: usize,
15}
16
/// Settings needed to verify incoming JWTs.
#[derive(Clone)]
pub struct AuthConfig {
    /// Shared secret used as the key for verifying token signatures.
    pub jwt_secret: String,
}

22pub struct JwtAuth(Option<AuthConfig>);
23
24impl JwtAuth {
25    pub fn new(config: Option<AuthConfig>) -> Self {
26        JwtAuth(config)
27    }
28}
29
30impl<S, B> Transform<S, ServiceRequest> for JwtAuth
31where
32    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
33    S::Future: 'static,
34    B: 'static,
35{
36    type Response = ServiceResponse<EitherBody<B>>;
37    type Error = Error;
38    type InitError = ();
39    type Transform = JwtAuthMiddleware<S>;
40    type Future = Ready<Result<Self::Transform, Self::InitError>>;
41
42    fn new_transform(&self, service: S) -> Self::Future {
43        ready(Ok(JwtAuthMiddleware {
44            service,
45            auth_config: self.0.clone(),
46        }))
47    }
48}
49
50pub struct JwtAuthMiddleware<S> {
51    service: S,
52    auth_config: Option<AuthConfig>,
53}
54
55impl<S, B> Service<ServiceRequest> for JwtAuthMiddleware<S>
56where
57    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
58    S::Future: 'static,
59    B: 'static,
60{
61    type Response = ServiceResponse<EitherBody<B>>;
62    type Error = Error;
63    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
64
65    forward_ready!(service);
66
67    fn call(&self, req: ServiceRequest) -> Self::Future {
68        if let Some(config) = &self.auth_config {
69            let auth_header = req
70                .headers()
71                .get("Authorization")
72                .and_then(|h| h.to_str().ok());
73
74            match auth_header {
75                Some(auth) if auth.starts_with("Bearer ") => {
76                    let token = &auth[7..];
77                    match decode::<Claims>(
78                        token,
79                        &DecodingKey::from_secret(config.jwt_secret.as_bytes()),
80                        &Validation::default(),
81                    ) {
82                        Ok(_) => {
83                            let fut = self.service.call(req);
84                            Box::pin(
85                                async move { fut.await.map(ServiceResponse::map_into_left_body) },
86                            )
87                        }
88                        Err(_) => {
89                            let (req, _) = req.into_parts();
90                            Box::pin(async move {
91                                Ok(
92                                    ServiceResponse::new(
93                                        req,
94                                        HttpResponse::Unauthorized().finish(),
95                                    )
96                                    .map_into_right_body(),
97                                )
98                            })
99                        }
100                    }
101                }
102                _ => {
103                    let (req, _) = req.into_parts();
104                    Box::pin(async move {
105                        Ok(
106                            ServiceResponse::new(req, HttpResponse::Unauthorized().finish())
107                                .map_into_right_body(),
108                        )
109                    })
110                }
111            }
112        } else {
113            let fut = self.service.call(req);
114            Box::pin(async move { fut.await.map(ServiceResponse::map_into_left_body) })
115        }
116    }
117}