Skip to main content

mcp_proxy/
token.rs

1//! Token passthrough middleware for forwarding client credentials to backends.
2//!
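//! When a backend has `forward_auth = true`, the identity claims of the client's
//! validated inbound bearer token are read from `RouterRequest.extensions` and
//! stored as a [`ClientToken`] for downstream middleware and backend services to
//! consume. The claims do not carry the raw token string, so
//! `ClientToken::raw_token` is currently always `None`.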
6//!
7//! # Configuration
8//!
9//! ```toml
10//! [[backends]]
11//! name = "github"
12//! transport = "http"
13//! url = "http://github-mcp.internal:8080"
14//! forward_auth = true  # forward client's token to this backend
15//!
16//! [[backends]]
17//! name = "db"
18//! transport = "http"
19//! url = "http://db-mcp.internal:8080"
20//! bearer_token = "${DB_API_KEY}"  # static token for this backend
21//! ```
22//!
23//! # How it works
24//!
25//! 1. The proxy's auth layer (JWT/bearer) validates the inbound token and
26//!    stores [`TokenClaims`](tower_mcp::oauth::token::TokenClaims) in request extensions.
27//! 2. This middleware reads the `TokenClaims` and stores the subject (`sub` claim)
28//!    and any available identity info as a [`ClientToken`] in extensions.
29//! 3. Backend-specific middleware or future transport enhancements can read
30//!    `ClientToken` to forward credentials.
31
32use std::collections::HashSet;
33use std::convert::Infallible;
34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::task::{Context, Poll};
38
39use tower::Service;
40
41use tower_mcp::router::{RouterRequest, RouterResponse};
42use tower_mcp_types::protocol::McpRequest;
43
/// A client's identity token extracted from inbound authentication.
///
/// Stored in `RouterRequest.extensions` by the [`TokenPassthroughService`]
/// for downstream consumption.
///
/// `Debug` is implemented by hand rather than derived so that the bearer
/// credential in [`raw_token`](Self::raw_token) is never written to logs:
/// the output only reveals whether a raw token is present.
#[derive(Clone)]
pub struct ClientToken {
    /// The subject (user/client identifier) from the token.
    pub subject: Option<String>,
    /// Space-delimited scopes from the token.
    pub scope: Option<String>,
    /// The raw bearer token string, if available.
    pub raw_token: Option<String>,
}

impl std::fmt::Debug for ClientToken {
    /// Formats like the derived implementation, except that the contents of
    /// `raw_token` are replaced with a `<redacted>` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` distinction visible for diagnostics while
        // withholding the secret itself.
        let raw_token = self.raw_token.as_ref().map(|_| "<redacted>");
        f.debug_struct("ClientToken")
            .field("subject", &self.subject)
            .field("scope", &self.scope)
            .field("raw_token", &raw_token)
            .finish()
    }
}
57
/// Tower middleware that derives a `ClientToken` from the request's auth
/// claims for backends configured with `forward_auth = true`.
///
/// Wraps an inner service and holds the set of namespace prefixes whose
/// requests should receive the derived token. The set sits behind an [`Arc`],
/// so cloning the service shares it rather than copying it.
#[derive(Clone)]
pub struct TokenPassthroughService<S> {
    /// The wrapped downstream service.
    inner: S,
    /// Namespace prefixes (e.g. `"github/"`) that opt in to token forwarding.
    forward_namespaces: Arc<HashSet<String>>,
}
65
66impl<S> TokenPassthroughService<S> {
67    /// Create a new token passthrough service.
68    ///
69    /// `forward_namespaces` is the set of backend namespace prefixes (e.g. `"github/"`)
70    /// that should receive forwarded tokens.
71    pub fn new(inner: S, forward_namespaces: HashSet<String>) -> Self {
72        Self {
73            inner,
74            forward_namespaces: Arc::new(forward_namespaces),
75        }
76    }
77}
78
79/// Check if a request targets a namespace that wants token forwarding.
80fn request_targets_namespace(req: &McpRequest, namespaces: &HashSet<String>) -> bool {
81    let name = match req {
82        McpRequest::CallTool(params) => Some(params.name.as_str()),
83        McpRequest::ReadResource(params) => Some(params.uri.as_str()),
84        McpRequest::GetPrompt(params) => Some(params.name.as_str()),
85        _ => None,
86    };
87    if let Some(name) = name {
88        namespaces.iter().any(|ns| name.starts_with(ns))
89    } else {
90        false
91    }
92}
93
94impl<S> Service<RouterRequest> for TokenPassthroughService<S>
95where
96    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
97        + Clone
98        + Send
99        + 'static,
100    S::Future: Send,
101{
102    type Response = RouterResponse;
103    type Error = Infallible;
104    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
105
106    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
107        self.inner.poll_ready(cx)
108    }
109
110    fn call(&mut self, mut req: RouterRequest) -> Self::Future {
111        // Only inject ClientToken for requests targeting forward_auth backends
112        if !self.forward_namespaces.is_empty()
113            && request_targets_namespace(&req.inner, &self.forward_namespaces)
114        {
115            let client_token = req
116                .extensions
117                .get::<tower_mcp::oauth::token::TokenClaims>()
118                .map(|claims| ClientToken {
119                    subject: claims.sub.clone(),
120                    scope: claims.scope.clone(),
121                    raw_token: None, // Raw token not available from TokenClaims
122                });
123            if let Some(token) = client_token {
124                tracing::debug!(
125                    subject = ?token.subject,
126                    "Injected ClientToken for forward_auth backend"
127                );
128                req.extensions.insert(token);
129            }
130        }
131
132        let fut = self.inner.call(req);
133        Box::pin(fut)
134    }
135}
136
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use tower::Service;
    use tower_mcp::protocol::{CallToolParams, McpRequest, RequestId};
    use tower_mcp::router::{Extensions, RouterRequest};

    use super::{TokenPassthroughService, request_targets_namespace};
    use crate::test_util::{MockService, call_service};

    /// Namespace set containing only the `github/` prefix.
    fn github_only() -> HashSet<String> {
        HashSet::from(["github/".to_string()])
    }

    /// Build a `CallTool` request for `tool` with empty JSON arguments.
    fn call_tool(tool: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: tool.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    #[test]
    fn test_request_targets_namespace_match() {
        let namespaces = github_only();
        let req = call_tool("github/search");
        assert!(request_targets_namespace(&req, &namespaces));
    }

    #[test]
    fn test_request_targets_namespace_no_match() {
        let namespaces = github_only();
        let req = call_tool("db/query");
        assert!(!request_targets_namespace(&req, &namespaces));
    }

    #[test]
    fn test_request_targets_namespace_list_tools() {
        let namespaces = github_only();
        let req = McpRequest::ListTools(Default::default());
        assert!(!request_targets_namespace(&req, &namespaces));
    }

    #[tokio::test]
    async fn test_passthrough_injects_client_token() {
        let mock = MockService::with_tools(&["github/search"]);
        let mut svc = TokenPassthroughService::new(mock, github_only());

        // Seed the request extensions with claims for the middleware to read.
        let mut extensions = Extensions::new();
        extensions.insert(tower_mcp::oauth::token::TokenClaims {
            sub: Some("user-123".to_string()),
            scope: Some("mcp:read".to_string()),
            iss: None,
            aud: None,
            exp: None,
            client_id: None,
            extra: Default::default(),
        });

        let req = RouterRequest {
            id: RequestId::Number(1),
            inner: call_tool("github/search"),
            extensions,
        };

        let resp = svc.call(req).await.unwrap();
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_passthrough_skips_non_forward_backends() {
        let mock = MockService::with_tools(&["db/query"]);
        let mut svc = TokenPassthroughService::new(mock, github_only());

        let resp = call_service(&mut svc, call_tool("db/query")).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_passthrough_no_claims_passes_through() {
        let mock = MockService::with_tools(&["github/search"]);
        let mut svc = TokenPassthroughService::new(mock, github_only());

        // No TokenClaims in extensions -- the request must still go through.
        let resp = call_service(&mut svc, call_tool("github/search")).await;

        assert!(resp.inner.is_ok());
    }
}