// async_mcp/sse/middleware.rs

use actix_web::{
    body::EitherBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error, HttpResponse,
};
use futures::future::LocalBoxFuture;
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::future::{ready, Ready};

/// Claim set that a bearer token's payload is deserialized into during verification.
///
/// Only the two time-related claims below are modelled; any other claims in the
/// token are not captured by this type.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Value of the token's `exp` claim.
    // NOTE(review): presumably seconds since the Unix epoch, as in the JWT spec — confirm.
    pub exp: usize,
    /// Value of the token's `iat` claim.
    // NOTE(review): presumably seconds since the Unix epoch, as in the JWT spec — confirm.
    pub iat: usize,
}

/// Settings that switch on JWT verification in the middleware.
#[derive(Clone)]
pub struct AuthConfig {
    /// Shared secret; its raw bytes are used as the key when decoding incoming tokens.
    pub jwt_secret: String,
}

/// Middleware factory for bearer-token (JWT) authentication.
///
/// Wraps an optional [`AuthConfig`]. The configuration is cloned into every
/// middleware instance this factory creates; with `None`, the created
/// middleware forwards requests without checking them.
pub struct JwtAuth(Option<AuthConfig>);

impl JwtAuth {
    /// Builds the factory from an optional configuration.
    ///
    /// Pass `Some(config)` to require a valid token on each request, or `None`
    /// to leave authentication disabled.
    pub fn new(config: Option<AuthConfig>) -> Self {
        Self(config)
    }
}

/// Lets [`JwtAuth`] be installed as middleware: each call to `new_transform`
/// wraps the given service in a [`JwtAuthMiddleware`].
impl<S, B> Transform<S, ServiceRequest> for JwtAuth
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type InitError = ();
    type Transform = JwtAuthMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    /// Wraps `service` together with a clone of this factory's configuration.
    ///
    /// Construction cannot fail, so the returned future is already resolved
    /// to `Ok`.
    fn new_transform(&self, service: S) -> Self::Future {
        let auth_config = self.0.clone();
        let middleware = JwtAuthMiddleware {
            service,
            auth_config,
        };
        ready(Ok(middleware))
    }
}

/// Service wrapper that guards `service` with a bearer-token check.
pub struct JwtAuthMiddleware<S> {
    /// The wrapped downstream service; it receives the requests that are let through.
    service: S,
    /// Verification settings; when `None`, every request is forwarded unchecked.
    auth_config: Option<AuthConfig>,
}

impl<S, B> Service<ServiceRequest> for JwtAuthMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<EitherBody<B>>;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        if let Some(config) = &self.auth_config {
            let auth_header = req
                .headers()
                .get("Authorization")
                .and_then(|h| h.to_str().ok());

            match auth_header {
                Some(auth) if auth.starts_with("Bearer ") => {
                    let token = &auth[7..];
                    match decode::<Claims>(
                        token,
                        &DecodingKey::from_secret(config.jwt_secret.as_bytes()),
                        &Validation::default(),
                    ) {
                        Ok(_) => {
                            let fut = self.service.call(req);
                            Box::pin(
                                async move { fut.await.map(ServiceResponse::map_into_left_body) },
                            )
                        }
                        Err(_) => {
                            let (req, _) = req.into_parts();
                            Box::pin(async move {
                                Ok(
                                    ServiceResponse::new(
                                        req,
                                        HttpResponse::Unauthorized().finish(),
                                    )
                                    .map_into_right_body(),
                                )
                            })
                        }
                    }
                }
                _ => {
                    let (req, _) = req.into_parts();
                    Box::pin(async move {
                        Ok(
                            ServiceResponse::new(req, HttpResponse::Unauthorized().finish())
                                .map_into_right_body(),
                        )
                    })
                }
            }
        } else {
            let fut = self.service.call(req);
            Box::pin(async move { fut.await.map(ServiceResponse::map_into_left_body) })
        }
    }
}