Skip to main content

pmcp_server_toolkit/
auth.rs

// Originated from pmcp-run/built-in/shared/mcp-server-common/src/auth.rs
// (https://github.com/guyernest/pmcp-run)
// Promoted to rust-mcp-sdk workspace for Phase 83 toolkit lift (P83-02).

//! `AuthProvider` implementations for the toolkit — static, bearer-token-based
//! auth suitable for dev/test environments. Production callers should use pmcp's
//! OAuth/JWT providers instead.
//!
//! The headline type is [`StaticAuthProvider`], which validates inbound
//! `Authorization: Bearer <token>` headers against a single expected token.
//! Use it for tests, smoke deployments, and `cargo pmcp pentest`-style local
//! servers. **Never put a static bearer token in a production server.**
13
14use async_trait::async_trait;
15use pmcp::error::ErrorCode;
16use pmcp::server::auth::{AuthContext, AuthProvider};
17use pmcp::Result;
18
/// Static bearer-token auth provider, suitable for dev and tests.
///
/// Validates that incoming `Authorization` headers carry exactly one configured
/// bearer token. [`AuthProvider::validate_request`] returns:
/// - `Ok(Some(AuthContext))` with subject `"static-bearer"` when the token matches;
/// - an `Err` built from `ErrorCode::INVALID_REQUEST` when the token does not
///   match, the scheme is not `Bearer`, or the header is missing.
///
/// A missing header is rejected directly by `validate_request` rather than
/// yielding `Ok(None)`. The provider also reports [`AuthProvider::is_required`]
/// as `true`.
///
/// # Example
/// ```no_run
/// use pmcp_server_toolkit::auth::StaticAuthProvider;
/// let provider = StaticAuthProvider::new("dev-token-do-not-use-in-prod");
/// # let _ = provider;
/// ```
///
/// # Security note
/// Tokens are compared with a private helper that XOR-accumulates byte
/// differences and does not stop at the first mismatching byte. This limits
/// timing side channels during dev/test use. Token length is not hidden.
/// Production callers should use pmcp's OAuth2 + JWT validator pipeline instead.
pub struct StaticAuthProvider {
    /// The single expected bearer token. Compared without early exit on content.
    expected_token: String,
}
43
44impl StaticAuthProvider {
45    /// Create a new `StaticAuthProvider` that accepts exactly one bearer token.
46    ///
47    /// # Example
48    /// ```no_run
49    /// use pmcp_server_toolkit::auth::StaticAuthProvider;
50    /// let provider = StaticAuthProvider::new("dev-token");
51    /// # let _ = provider;
52    /// ```
53    pub fn new(expected_token: impl Into<String>) -> Self {
54        Self {
55            expected_token: expected_token.into(),
56        }
57    }
58}
59
/// Compares two byte slices for equality without short-circuiting on content.
///
/// Returns `true` iff `a` and `b` have the same length and every byte matches.
///
/// Timing behaviour:
/// - If the lengths differ, this returns `false` immediately. Token length is
///   therefore treated as non-secret and may be observable through timing.
/// - If the lengths are equal, every byte pair is XOR-ed into an accumulator.
///   The scan never stops at the first mismatching byte, so the position of a
///   mismatch is not revealed by an early exit.
///
/// This is best-effort: Rust gives no formal constant-time guarantee for
/// ordinary code. It is adequate for the dev/test bearer-token check, not for
/// production cryptography.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR-accumulate the XOR of each pair: any differing bit survives to the end,
    // and no byte position can terminate the scan early.
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
76
77#[async_trait]
78impl AuthProvider for StaticAuthProvider {
79    async fn validate_request(
80        &self,
81        authorization_header: Option<&str>,
82    ) -> Result<Option<AuthContext>> {
83        // Missing header → unauthenticated. is_required() defaults true, so
84        // the caller treats this as a 401.
85        let header = match authorization_header {
86            Some(h) => h,
87            None => {
88                return Err(pmcp::Error::protocol(
89                    ErrorCode::INVALID_REQUEST,
90                    "Missing Authorization header",
91                ));
92            },
93        };
94
95        // Strip the "Bearer " prefix (case-insensitive scheme name per RFC 6750).
96        let token = header
97            .strip_prefix("Bearer ")
98            .or_else(|| header.strip_prefix("bearer "))
99            .ok_or_else(|| {
100                pmcp::Error::protocol(
101                    ErrorCode::INVALID_REQUEST,
102                    "Authorization scheme must be Bearer",
103                )
104            })?;
105
106        if !constant_time_eq(token.as_bytes(), self.expected_token.as_bytes()) {
107            return Err(pmcp::Error::protocol(
108                ErrorCode::INVALID_REQUEST,
109                "Invalid bearer token",
110            ));
111        }
112
113        let mut ctx = AuthContext::new("static-bearer");
114        ctx.token = Some(token.to_string());
115        ctx.client_id = Some("static-bearer".to_string());
116        Ok(Some(ctx))
117    }
118
119    fn auth_scheme(&self) -> &'static str {
120        "Bearer"
121    }
122
123    fn is_required(&self) -> bool {
124        true
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the provider that every async test in this module exercises.
    fn provider() -> StaticAuthProvider {
        StaticAuthProvider::new("secret-token")
    }

    /// Runs `validate_request` with `header` and reports whether it failed.
    async fn is_rejected(header: Option<&str>) -> bool {
        provider().validate_request(header).await.is_err()
    }

    #[tokio::test]
    async fn valid_bearer_token_returns_some_auth_context() {
        let ctx = provider()
            .validate_request(Some("Bearer secret-token"))
            .await
            .expect("expected Ok")
            .expect("expected Some(AuthContext)");
        assert_eq!(ctx.user_id(), "static-bearer");
        assert!(ctx.authenticated);
    }

    #[tokio::test]
    async fn invalid_bearer_token_returns_err() {
        assert!(
            is_rejected(Some("Bearer wrong-token")).await,
            "expected Err for mismatched token"
        );
    }

    #[tokio::test]
    async fn missing_authorization_header_returns_err() {
        assert!(is_rejected(None).await, "expected Err for missing header");
    }

    #[tokio::test]
    async fn non_bearer_scheme_returns_err() {
        assert!(
            is_rejected(Some("Basic dXNlcjpwYXNz")).await,
            "expected Err for non-Bearer scheme"
        );
    }

    #[tokio::test]
    async fn case_insensitive_bearer_prefix() {
        let maybe_ctx = provider()
            .validate_request(Some("bearer secret-token"))
            .await
            .expect("expected Ok");
        assert!(maybe_ctx.is_some());
    }

    #[test]
    fn constant_time_eq_handles_mismatched_lengths() {
        let cases: [(&[u8], &[u8]); 2] = [(b"abc", b"abcd"), (b"", b"x")];
        for (lhs, rhs) in cases {
            assert!(!constant_time_eq(lhs, rhs));
        }
    }

    #[test]
    fn constant_time_eq_handles_equal_inputs() {
        let cases: [&[u8]; 2] = [b"hunter2", b""];
        for bytes in cases {
            assert!(constant_time_eq(bytes, bytes));
        }
    }

    #[test]
    fn constant_time_eq_detects_mismatch() {
        let (left, right): (&[u8], &[u8]) = (b"hunter2", b"hunter3");
        assert!(!constant_time_eq(left, right));
    }
}