// mcp_core/sse/middleware.rs
use actix_web::{
    body::EitherBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    http::header::{AUTHORIZATION, WWW_AUTHENTICATE},
    Error, HttpResponse,
};
use futures::future::LocalBoxFuture;
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::future::{ready, Ready};

11#[derive(Debug, Serialize, Deserialize)]
12pub struct Claims {
13 pub exp: usize,
14 pub iat: usize,
15}
16
/// Settings for JWT bearer-token authentication.
#[derive(Clone)]
pub struct AuthConfig {
    /// Shared secret used as the HMAC key when verifying token signatures.
    pub jwt_secret: String,
}

// Written by hand rather than derived so the secret never ends up in logs or
// panic messages when the config is debug-printed.
impl std::fmt::Debug for AuthConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthConfig")
            .field("jwt_secret", &"<redacted>")
            .finish()
    }
}

22pub struct JwtAuth(Option<AuthConfig>);
23
24impl JwtAuth {
25 pub fn new(config: Option<AuthConfig>) -> Self {
26 JwtAuth(config)
27 }
28}
29
30impl<S, B> Transform<S, ServiceRequest> for JwtAuth
31where
32 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
33 S::Future: 'static,
34 B: 'static,
35{
36 type Response = ServiceResponse<EitherBody<B>>;
37 type Error = Error;
38 type InitError = ();
39 type Transform = JwtAuthMiddleware<S>;
40 type Future = Ready<Result<Self::Transform, Self::InitError>>;
41
42 fn new_transform(&self, service: S) -> Self::Future {
43 ready(Ok(JwtAuthMiddleware {
44 service,
45 auth_config: self.0.clone(),
46 }))
47 }
48}
49
50pub struct JwtAuthMiddleware<S> {
51 service: S,
52 auth_config: Option<AuthConfig>,
53}
54
55impl<S, B> Service<ServiceRequest> for JwtAuthMiddleware<S>
56where
57 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
58 S::Future: 'static,
59 B: 'static,
60{
61 type Response = ServiceResponse<EitherBody<B>>;
62 type Error = Error;
63 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
64
65 forward_ready!(service);
66
67 fn call(&self, req: ServiceRequest) -> Self::Future {
68 if let Some(config) = &self.auth_config {
69 let auth_header = req
70 .headers()
71 .get("Authorization")
72 .and_then(|h| h.to_str().ok());
73
74 match auth_header {
75 Some(auth) if auth.starts_with("Bearer ") => {
76 let token = &auth[7..];
77 match decode::<Claims>(
78 token,
79 &DecodingKey::from_secret(config.jwt_secret.as_bytes()),
80 &Validation::default(),
81 ) {
82 Ok(_) => {
83 let fut = self.service.call(req);
84 Box::pin(
85 async move { fut.await.map(ServiceResponse::map_into_left_body) },
86 )
87 }
88 Err(_) => {
89 let (req, _) = req.into_parts();
90 Box::pin(async move {
91 Ok(
92 ServiceResponse::new(
93 req,
94 HttpResponse::Unauthorized().finish(),
95 )
96 .map_into_right_body(),
97 )
98 })
99 }
100 }
101 }
102 _ => {
103 let (req, _) = req.into_parts();
104 Box::pin(async move {
105 Ok(
106 ServiceResponse::new(req, HttpResponse::Unauthorized().finish())
107 .map_into_right_body(),
108 )
109 })
110 }
111 }
112 } else {
113 let fut = self.service.call(req);
114 Box::pin(async move { fut.await.map(ServiceResponse::map_into_left_body) })
115 }
116 }
117}