Skip to main content

mcp_protocol/
auth.rs

1//! # MCP Authentication
2//!
3//! **OAuth 2.1 Bearer token authentication** for MCP protocol
4//! - Standard Authorization header handling
5//! - Custom authentication strategies support
6//! - External MCP server authentication
7
8use crate::{BearerToken, McpError};
9use protocol_transport_core::{ProtocolError, UniversalRequest};
10
11/// **Authentication Handler Trait**
12pub trait AuthHandler: Send + Sync {
13    /// Validate request authentication
14    fn validate_request(&self, request: &UniversalRequest) -> Result<(), ProtocolError>;
15
16    /// Add authentication to outgoing request
17    fn add_auth_headers(&self, request: &mut UniversalRequest) -> Result<(), ProtocolError>;
18}
19
20/// **Bearer Token Authentication Handler**
21pub struct BearerAuthHandler {
22    /// Required bearer token (for server mode)
23    required_token: Option<String>,
24    /// Bearer token for client requests (for client mode)
25    client_token: Option<BearerToken>,
26}
27
28impl BearerAuthHandler {
29    /// Create new bearer auth handler
30    pub fn new() -> Self {
31        Self {
32            required_token: None,
33            client_token: None,
34        }
35    }
36
37    /// Configure required token for server mode
38    pub fn with_required_token(mut self, token: &str) -> Self {
39        self.required_token = Some(token.to_string());
40        self
41    }
42
43    /// Configure client token for outgoing requests
44    pub fn with_client_token(mut self, token: BearerToken) -> Self {
45        self.client_token = Some(token);
46        self
47    }
48
49    /// Extract bearer token from Authorization header
50    fn extract_bearer_token(&self, request: &UniversalRequest) -> Option<String> {
51        request
52            .headers
53            .get("authorization")
54            .or_else(|| request.headers.get("Authorization"))
55            .and_then(|auth_header| {
56                if auth_header.starts_with("Bearer ") {
57                    Some(auth_header[7..].to_string())
58                } else {
59                    None
60                }
61            })
62    }
63}
64
65impl AuthHandler for BearerAuthHandler {
66    fn validate_request(&self, request: &UniversalRequest) -> Result<(), ProtocolError> {
67        // If no token is required, allow all requests
68        let required_token = match &self.required_token {
69            Some(token) => token,
70            None => return Ok(()),
71        };
72
73        // Extract token from request
74        let provided_token = self.extract_bearer_token(request).ok_or_else(|| {
75            ProtocolError::Internal(
76                McpError::Authentication("Missing or invalid Authorization header".to_string())
77                    .to_string(),
78            )
79        })?;
80
81        // Validate token
82        if provided_token != *required_token {
83            return Err(ProtocolError::Internal(
84                McpError::Authentication("Invalid bearer token".to_string()).to_string(),
85            ));
86        }
87
88        Ok(())
89    }
90
91    fn add_auth_headers(&self, request: &mut UniversalRequest) -> Result<(), ProtocolError> {
92        if let Some(client_token) = &self.client_token {
93            request.headers.insert(
94                "Authorization".to_string(),
95                client_token.to_authorization_header(),
96            );
97        }
98        Ok(())
99    }
100}
101
102impl Default for BearerAuthHandler {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108/// **No Authentication Handler** - Allows all requests
109pub struct NoAuthHandler;
110
111impl AuthHandler for NoAuthHandler {
112    fn validate_request(&self, _request: &UniversalRequest) -> Result<(), ProtocolError> {
113        Ok(())
114    }
115
116    fn add_auth_headers(&self, _request: &mut UniversalRequest) -> Result<(), ProtocolError> {
117        Ok(())
118    }
119}
120
121/// **Custom Authentication Handler** - User-defined validation
122pub struct CustomAuthHandler<F, G>
123where
124    F: Fn(&UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
125    G: Fn(&mut UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
126{
127    validate_fn: F,
128    add_auth_fn: G,
129}
130
131impl<F, G> CustomAuthHandler<F, G>
132where
133    F: Fn(&UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
134    G: Fn(&mut UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
135{
136    /// Create custom auth handler with validation and auth addition functions
137    pub fn new(validate_fn: F, add_auth_fn: G) -> Self {
138        Self {
139            validate_fn,
140            add_auth_fn,
141        }
142    }
143}
144
145impl<F, G> AuthHandler for CustomAuthHandler<F, G>
146where
147    F: Fn(&UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
148    G: Fn(&mut UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
149{
150    fn validate_request(&self, request: &UniversalRequest) -> Result<(), ProtocolError> {
151        (self.validate_fn)(request)
152    }
153
154    fn add_auth_headers(&self, request: &mut UniversalRequest) -> Result<(), ProtocolError> {
155        (self.add_auth_fn)(request)
156    }
157}
158
159/// **Authentication Builder** - Convenient auth handler creation
160pub struct AuthBuilder;
161
162impl AuthBuilder {
163    /// Create no authentication handler
164    pub fn none() -> NoAuthHandler {
165        NoAuthHandler
166    }
167
168    /// Create bearer token auth for server
169    pub fn bearer_server(required_token: &str) -> BearerAuthHandler {
170        BearerAuthHandler::new().with_required_token(required_token)
171    }
172
173    /// Create bearer token auth for client
174    pub fn bearer_client(token: &str) -> BearerAuthHandler {
175        BearerAuthHandler::new().with_client_token(BearerToken::new(token))
176    }
177
178    /// Create bearer token auth for both server and client
179    pub fn bearer_both(required_token: &str, client_token: &str) -> BearerAuthHandler {
180        BearerAuthHandler::new()
181            .with_required_token(required_token)
182            .with_client_token(BearerToken::new(client_token))
183    }
184}