Skip to main content

mcp_proxy/
access_log.rs

1//! Structured access logging middleware.
2//!
3//! Logs each MCP request with structured fields including method, tool/resource
4//! name, backend name, duration, and status. Uses the `mcp::access` tracing
5//! target so operators can filter access logs independently.
6//!
//! The backend name is derived from the namespace prefix of the request target
//! (a tool or prompt name, or a resource URI). For example, with separator `/`,
//! a tool named `math/add` belongs to backend `math`.
10
11use std::convert::Infallible;
12use std::future::Future;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Instant;
16
17use tower::{Layer, Service};
18use tower_mcp::protocol::McpRequest;
19use tower_mcp::{RouterRequest, RouterResponse};
20
/// Tower layer that produces an [`AccessLogService`].
///
/// The layer only carries the namespace separator; every service it builds
/// receives its own copy of it. The [`Default`] layer uses `/`.
#[derive(Clone, Debug)]
pub struct AccessLogLayer {
    /// Separator between the backend prefix and the rest of a target name.
    separator: String,
}
26
27impl Default for AccessLogLayer {
28    fn default() -> Self {
29        Self {
30            separator: "/".to_string(),
31        }
32    }
33}
34
35impl AccessLogLayer {
36    /// Create a new access log layer with the given namespace separator.
37    pub fn new(separator: impl Into<String>) -> Self {
38        Self {
39            separator: separator.into(),
40        }
41    }
42}
43
44impl<S> Layer<S> for AccessLogLayer {
45    type Service = AccessLogService<S>;
46
47    fn layer(&self, inner: S) -> Self::Service {
48        AccessLogService::new(inner, self.separator.clone())
49    }
50}
51
/// Tower service that emits structured access log entries.
///
/// Every request handled by the wrapped service produces one event on the
/// `mcp::access` tracing target. The event includes the backend name derived
/// from the namespace prefix of the request target (tool/prompt name or
/// resource URI), when one can be derived.
#[derive(Clone, Debug)]
pub struct AccessLogService<S> {
    /// The wrapped service that actually handles requests.
    inner: S,
    /// Separator between the backend prefix and the rest of a target name.
    separator: String,
}
60
61impl<S> AccessLogService<S> {
62    /// Create a new access log service wrapping `inner`.
63    ///
64    /// The `separator` is used to split namespaced tool names into backend and
65    /// tool components (e.g. with separator `/`, `math/add` yields backend
66    /// `math`).
67    pub fn new(inner: S, separator: impl Into<String>) -> Self {
68        Self {
69            inner,
70            separator: separator.into(),
71        }
72    }
73}
74
75/// Extract the tool, resource, or prompt name from an MCP request.
76fn request_target(req: &McpRequest) -> Option<&str> {
77    match req {
78        McpRequest::CallTool(params) => Some(&params.name),
79        McpRequest::ReadResource(params) => Some(&params.uri),
80        McpRequest::GetPrompt(params) => Some(&params.name),
81        _ => None,
82    }
83}
84
/// Extract the backend name from a namespaced target string.
///
/// The backend is everything before the first occurrence of `separator`.
/// Given `"math/add"` and separator `"/"` this returns `Some("math")`, and
/// `"a/b/c"` yields `Some("a")`.
///
/// Returns `None` in these cases:
/// - the target does not contain the separator;
/// - the separator is empty (an empty pattern matches at offset 0 of every
///   string, which would attribute every request to a backend named `""`);
/// - the prefix before the separator is empty (e.g. `"/add"`), since an
///   empty backend name carries no information.
fn extract_backend<'a>(target: &'a str, separator: &str) -> Option<&'a str> {
    if separator.is_empty() {
        return None;
    }
    target
        .split_once(separator)
        .map(|(backend, _rest)| backend)
        .filter(|backend| !backend.is_empty())
}
92
93impl<S> Service<RouterRequest> for AccessLogService<S>
94where
95    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
96        + Clone
97        + Send
98        + 'static,
99    S::Future: Send,
100{
101    type Response = RouterResponse;
102    type Error = Infallible;
103    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
104
105    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106        self.inner.poll_ready(cx)
107    }
108
109    fn call(&mut self, req: RouterRequest) -> Self::Future {
110        let method = req.inner.method_name().to_string();
111        let target = request_target(&req.inner).map(|s| s.to_string());
112        let backend = target
113            .as_deref()
114            .and_then(|t| extract_backend(t, &self.separator))
115            .map(|s| s.to_string());
116        let start = Instant::now();
117        let fut = self.inner.call(req);
118
119        Box::pin(async move {
120            let result = fut.await;
121            let duration_ms = start.elapsed().as_millis() as u64;
122
123            let status = match &result {
124                Ok(resp) => {
125                    if resp.inner.is_ok() {
126                        "ok"
127                    } else {
128                        "error"
129                    }
130                }
131                Err(_) => "error",
132            };
133
134            match (target, backend) {
135                (Some(name), Some(be)) => {
136                    tracing::info!(
137                        target: "mcp::access",
138                        method = %method,
139                        target = %name,
140                        backend = %be,
141                        duration_ms = duration_ms,
142                        status = %status,
143                    );
144                }
145                (Some(name), None) => {
146                    tracing::info!(
147                        target: "mcp::access",
148                        method = %method,
149                        target = %name,
150                        duration_ms = duration_ms,
151                        status = %status,
152                    );
153                }
154                _ => {
155                    tracing::info!(
156                        target: "mcp::access",
157                        method = %method,
158                        duration_ms = duration_ms,
159                        status = %status,
160                    );
161                }
162            }
163
164            result
165        })
166    }
167}
168
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{AccessLogService, extract_backend};
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// Build a `CallTool` request for `name` with empty JSON arguments.
    fn call_tool(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    #[test]
    fn test_extract_backend_with_separator() {
        let cases = [
            ("math/add", "/", "math"),
            ("db/query", "/", "db"),
            ("math::add", "::", "math"),
        ];
        for &(target, separator, backend) in &cases {
            assert_eq!(extract_backend(target, separator), Some(backend));
        }
    }

    #[test]
    fn test_extract_backend_no_separator() {
        for &(target, separator) in &[("add", "/"), ("tool", "::")] {
            assert_eq!(extract_backend(target, separator), None);
        }
    }

    #[test]
    fn test_extract_backend_multiple_separators() {
        // Only the segment before the first separator is the backend.
        assert_eq!(extract_backend("a/b/c", "/"), Some("a"));
    }

    #[tokio::test]
    async fn test_access_log_passes_through_list() {
        let mut svc = AccessLogService::new(MockService::with_tools(&["tool"]), "/");
        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_access_log_passes_through_tool_call() {
        let mut svc = AccessLogService::new(MockService::with_tools(&["tool"]), "/");
        let resp = call_service(&mut svc, call_tool("tool")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_access_log_passes_through_namespaced_tool_call() {
        let mut svc = AccessLogService::new(MockService::with_tools(&["math/add"]), "/");
        let resp = call_service(&mut svc, call_tool("math/add")).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_access_log_handles_errors() {
        let mut svc = AccessLogService::new(ErrorMockService, "/");
        let resp = call_service(&mut svc, McpRequest::ListTools(Default::default())).await;
        assert!(resp.inner.is_err());
    }

    #[tokio::test]
    async fn test_access_log_handles_ping() {
        let mut svc = AccessLogService::new(MockService::with_tools(&[]), "/");
        let resp = call_service(&mut svc, McpRequest::Ping).await;
        assert!(resp.inner.is_ok());
    }
}