pmcp_server_toolkit/
auth.rs1use async_trait::async_trait;
15use pmcp::error::ErrorCode;
16use pmcp::server::auth::{AuthContext, AuthProvider};
17use pmcp::Result;
18
/// Authentication provider that accepts exactly one pre-configured bearer token.
///
/// The token is supplied once through [`StaticAuthProvider::new`] and is then
/// compared against the credentials presented in each request's
/// `Authorization` header.
pub struct StaticAuthProvider {
    /// The secret a request must present to be accepted.
    expected_token: String,
}

/// Hand-written `Debug` so the provider can appear in logs and `{:?}` output
/// without ever printing the secret token (a derived impl would print it).
impl std::fmt::Debug for StaticAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticAuthProvider")
            .field("expected_token", &"<redacted>")
            .finish()
    }
}
43
44impl StaticAuthProvider {
45 pub fn new(expected_token: impl Into<String>) -> Self {
54 Self {
55 expected_token: expected_token.into(),
56 }
57 }
58}
59
/// Compares two byte slices without stopping at the first differing byte.
///
/// Slices of different lengths are rejected immediately, so the *length* of
/// the inputs is not hidden. For equal-length inputs every byte pair is
/// XOR-ed and OR-accumulated, so the amount of work does not depend on where
/// (or whether) the inputs differ.
///
/// Returns `true` only when both slices have the same length and contents.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Any differing bit in any position leaves a non-zero bit in the accumulator.
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs))
            == 0
}
76
77#[async_trait]
78impl AuthProvider for StaticAuthProvider {
79 async fn validate_request(
80 &self,
81 authorization_header: Option<&str>,
82 ) -> Result<Option<AuthContext>> {
83 let header = match authorization_header {
86 Some(h) => h,
87 None => {
88 return Err(pmcp::Error::protocol(
89 ErrorCode::INVALID_REQUEST,
90 "Missing Authorization header",
91 ));
92 },
93 };
94
95 let token = header
97 .strip_prefix("Bearer ")
98 .or_else(|| header.strip_prefix("bearer "))
99 .ok_or_else(|| {
100 pmcp::Error::protocol(
101 ErrorCode::INVALID_REQUEST,
102 "Authorization scheme must be Bearer",
103 )
104 })?;
105
106 if !constant_time_eq(token.as_bytes(), self.expected_token.as_bytes()) {
107 return Err(pmcp::Error::protocol(
108 ErrorCode::INVALID_REQUEST,
109 "Invalid bearer token",
110 ));
111 }
112
113 let mut ctx = AuthContext::new("static-bearer");
114 ctx.token = Some(token.to_string());
115 ctx.client_id = Some("static-bearer".to_string());
116 Ok(Some(ctx))
117 }
118
119 fn auth_scheme(&self) -> &'static str {
120 "Bearer"
121 }
122
123 fn is_required(&self) -> bool {
124 true
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the provider used by every request-validation test; it is
    /// configured to accept `secret-token`.
    fn fixture() -> StaticAuthProvider {
        StaticAuthProvider::new("secret-token")
    }

    /// A correctly formed header carrying the right token is accepted.
    #[tokio::test]
    async fn valid_bearer_token_returns_some_auth_context() {
        let auth = fixture();
        let outcome = auth.validate_request(Some("Bearer secret-token")).await;
        let ctx = outcome
            .expect("expected Ok")
            .expect("expected Some(AuthContext)");
        assert_eq!(ctx.user_id(), "static-bearer");
        assert!(ctx.authenticated);
    }

    /// A Bearer header carrying a different token is rejected.
    #[tokio::test]
    async fn invalid_bearer_token_returns_err() {
        let auth = fixture();
        let outcome = auth.validate_request(Some("Bearer wrong-token")).await;
        assert!(outcome.is_err(), "expected Err for mismatched token");
    }

    /// A request without any Authorization header is rejected.
    #[tokio::test]
    async fn missing_authorization_header_returns_err() {
        let auth = fixture();
        let outcome = auth.validate_request(None).await;
        assert!(outcome.is_err(), "expected Err for missing header");
    }

    /// Schemes other than Bearer are rejected.
    #[tokio::test]
    async fn non_bearer_scheme_returns_err() {
        let auth = fixture();
        let outcome = auth.validate_request(Some("Basic dXNlcjpwYXNz")).await;
        assert!(outcome.is_err(), "expected Err for non-Bearer scheme");
    }

    /// The lowercase spelling of the scheme is accepted as well.
    #[tokio::test]
    async fn case_insensitive_bearer_prefix() {
        let auth = fixture();
        let outcome = auth.validate_request(Some("bearer secret-token")).await;
        let maybe_ctx = outcome.expect("expected Ok");
        assert!(maybe_ctx.is_some());
    }

    /// Inputs of different lengths never compare equal.
    #[test]
    fn constant_time_eq_handles_mismatched_lengths() {
        for (lhs, rhs) in [(&b"abc"[..], &b"abcd"[..]), (&b""[..], &b"x"[..])] {
            assert!(!constant_time_eq(lhs, rhs));
        }
    }

    /// Identical inputs, including two empty slices, compare equal.
    #[test]
    fn constant_time_eq_handles_equal_inputs() {
        for bytes in [&b"hunter2"[..], &b""[..]] {
            assert!(constant_time_eq(bytes, bytes));
        }
    }

    /// Same-length inputs that differ in one byte compare unequal.
    #[test]
    fn constant_time_eq_detects_mismatch() {
        let (lhs, rhs) = (b"hunter2", b"hunter3");
        assert!(!constant_time_eq(lhs, rhs));
    }
}