1use crate::{BearerToken, McpError};
9use protocol_transport_core::{ProtocolError, UniversalRequest};
10
11pub trait AuthHandler: Send + Sync {
13 fn validate_request(&self, request: &UniversalRequest) -> Result<(), ProtocolError>;
15
16 fn add_auth_headers(&self, request: &mut UniversalRequest) -> Result<(), ProtocolError>;
18}
19
20pub struct BearerAuthHandler {
22 required_token: Option<String>,
24 client_token: Option<BearerToken>,
26}
27
28impl BearerAuthHandler {
29 pub fn new() -> Self {
31 Self {
32 required_token: None,
33 client_token: None,
34 }
35 }
36
37 pub fn with_required_token(mut self, token: &str) -> Self {
39 self.required_token = Some(token.to_string());
40 self
41 }
42
43 pub fn with_client_token(mut self, token: BearerToken) -> Self {
45 self.client_token = Some(token);
46 self
47 }
48
49 fn extract_bearer_token(&self, request: &UniversalRequest) -> Option<String> {
51 request
52 .headers
53 .get("authorization")
54 .or_else(|| request.headers.get("Authorization"))
55 .and_then(|auth_header| {
56 if auth_header.starts_with("Bearer ") {
57 Some(auth_header[7..].to_string())
58 } else {
59 None
60 }
61 })
62 }
63}
64
65impl AuthHandler for BearerAuthHandler {
66 fn validate_request(&self, request: &UniversalRequest) -> Result<(), ProtocolError> {
67 let required_token = match &self.required_token {
69 Some(token) => token,
70 None => return Ok(()),
71 };
72
73 let provided_token = self.extract_bearer_token(request).ok_or_else(|| {
75 ProtocolError::Internal(
76 McpError::Authentication("Missing or invalid Authorization header".to_string())
77 .to_string(),
78 )
79 })?;
80
81 if provided_token != *required_token {
83 return Err(ProtocolError::Internal(
84 McpError::Authentication("Invalid bearer token".to_string()).to_string(),
85 ));
86 }
87
88 Ok(())
89 }
90
91 fn add_auth_headers(&self, request: &mut UniversalRequest) -> Result<(), ProtocolError> {
92 if let Some(client_token) = &self.client_token {
93 request.headers.insert(
94 "Authorization".to_string(),
95 client_token.to_authorization_header(),
96 );
97 }
98 Ok(())
99 }
100}
101
102impl Default for BearerAuthHandler {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
/// Authentication handler that accepts every request and attaches no
/// credentials to outgoing ones.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoAuthHandler;
110
111impl AuthHandler for NoAuthHandler {
112 fn validate_request(&self, _request: &UniversalRequest) -> Result<(), ProtocolError> {
113 Ok(())
114 }
115
116 fn add_auth_headers(&self, _request: &mut UniversalRequest) -> Result<(), ProtocolError> {
117 Ok(())
118 }
119}
120
/// Authentication handler assembled from two caller-supplied closures.
///
/// The `Fn` bounds are deliberately not repeated on the struct definition;
/// they live on the `impl` blocks that need them (Rust API Guidelines,
/// C-STRUCT-BOUNDS). Removing them from the struct only relaxes requirements,
/// so every existing use still compiles.
pub struct CustomAuthHandler<F, G> {
    /// Invoked by `validate_request` to accept or reject an incoming request.
    validate_fn: F,
    /// Invoked by `add_auth_headers` to decorate an outgoing request.
    add_auth_fn: G,
}
130
131impl<F, G> CustomAuthHandler<F, G>
132where
133 F: Fn(&UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
134 G: Fn(&mut UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
135{
136 pub fn new(validate_fn: F, add_auth_fn: G) -> Self {
138 Self {
139 validate_fn,
140 add_auth_fn,
141 }
142 }
143}
144
145impl<F, G> AuthHandler for CustomAuthHandler<F, G>
146where
147 F: Fn(&UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
148 G: Fn(&mut UniversalRequest) -> Result<(), ProtocolError> + Send + Sync,
149{
150 fn validate_request(&self, request: &UniversalRequest) -> Result<(), ProtocolError> {
151 (self.validate_fn)(request)
152 }
153
154 fn add_auth_headers(&self, request: &mut UniversalRequest) -> Result<(), ProtocolError> {
155 (self.add_auth_fn)(request)
156 }
157}
158
/// Namespace of convenience constructors for the built-in auth handlers.
#[derive(Debug, Clone, Copy, Default)]
pub struct AuthBuilder;
161
162impl AuthBuilder {
163 pub fn none() -> NoAuthHandler {
165 NoAuthHandler
166 }
167
168 pub fn bearer_server(required_token: &str) -> BearerAuthHandler {
170 BearerAuthHandler::new().with_required_token(required_token)
171 }
172
173 pub fn bearer_client(token: &str) -> BearerAuthHandler {
175 BearerAuthHandler::new().with_client_token(BearerToken::new(token))
176 }
177
178 pub fn bearer_both(required_token: &str, client_token: &str) -> BearerAuthHandler {
180 BearerAuthHandler::new()
181 .with_required_token(required_token)
182 .with_client_token(BearerToken::new(client_token))
183 }
184}