1use std::convert::Infallible;
12use std::future::Future;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use std::time::Instant;
16
17use tower::{Layer, Service};
18use tower_mcp::protocol::McpRequest;
19use tower_mcp::{RouterRequest, RouterResponse};
20
/// Tower [`Layer`] that wraps a service in an [`AccessLogService`].
///
/// The layer only carries configuration: the separator string that is used to
/// split a backend prefix off a request target (for example `"/"` turns the
/// target `math/add` into the backend `math`).
#[derive(Clone, Debug)]
pub struct AccessLogLayer {
    /// Separator between the backend prefix and the rest of a target name.
    separator: String,
}
26
27impl Default for AccessLogLayer {
28 fn default() -> Self {
29 Self {
30 separator: "/".to_string(),
31 }
32 }
33}
34
35impl AccessLogLayer {
36 pub fn new(separator: impl Into<String>) -> Self {
38 Self {
39 separator: separator.into(),
40 }
41 }
42}
43
44impl<S> Layer<S> for AccessLogLayer {
45 type Service = AccessLogService<S>;
46
47 fn layer(&self, inner: S) -> Self::Service {
48 AccessLogService::new(inner, self.separator.clone())
49 }
50}
51
/// Middleware service that emits one access-log event per request handled by
/// the wrapped service.
///
/// The `Debug` implementation is only available when the inner service `S`
/// is itself `Debug`.
#[derive(Clone, Debug)]
pub struct AccessLogService<S> {
    /// The wrapped service that actually handles the request.
    inner: S,
    /// Separator between the backend prefix and the rest of a target name.
    separator: String,
}
60
61impl<S> AccessLogService<S> {
62 pub fn new(inner: S, separator: impl Into<String>) -> Self {
68 Self {
69 inner,
70 separator: separator.into(),
71 }
72 }
73}
74
75fn request_target(req: &McpRequest) -> Option<&str> {
77 match req {
78 McpRequest::CallTool(params) => Some(¶ms.name),
79 McpRequest::ReadResource(params) => Some(¶ms.uri),
80 McpRequest::GetPrompt(params) => Some(¶ms.name),
81 _ => None,
82 }
83}
84
/// Returns the backend prefix of `target`: the text before the first
/// occurrence of `separator`.
///
/// Returns `None` when `target` does not contain `separator`, and also when
/// the prefix would be empty. The latter happens when `separator` is the empty
/// string (which matches at index 0 of every target) or when `target` starts
/// with the separator; an empty string is not a meaningful backend name.
///
/// Only the first separator is significant: `a/b/c` with `/` yields `a`.
fn extract_backend<'a>(target: &'a str, separator: &str) -> Option<&'a str> {
    target
        .split_once(separator)
        .map(|(backend, _rest)| backend)
        .filter(|backend| !backend.is_empty())
}
92
93impl<S> Service<RouterRequest> for AccessLogService<S>
94where
95 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
96 + Clone
97 + Send
98 + 'static,
99 S::Future: Send,
100{
101 type Response = RouterResponse;
102 type Error = Infallible;
103 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
104
105 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106 self.inner.poll_ready(cx)
107 }
108
109 fn call(&mut self, req: RouterRequest) -> Self::Future {
110 let method = req.inner.method_name().to_string();
111 let target = request_target(&req.inner).map(|s| s.to_string());
112 let backend = target
113 .as_deref()
114 .and_then(|t| extract_backend(t, &self.separator))
115 .map(|s| s.to_string());
116 let start = Instant::now();
117 let fut = self.inner.call(req);
118
119 Box::pin(async move {
120 let result = fut.await;
121 let duration_ms = start.elapsed().as_millis() as u64;
122
123 let status = match &result {
124 Ok(resp) => {
125 if resp.inner.is_ok() {
126 "ok"
127 } else {
128 "error"
129 }
130 }
131 Err(_) => "error",
132 };
133
134 match (target, backend) {
135 (Some(name), Some(be)) => {
136 tracing::info!(
137 target: "mcp::access",
138 method = %method,
139 target = %name,
140 backend = %be,
141 duration_ms = duration_ms,
142 status = %status,
143 );
144 }
145 (Some(name), None) => {
146 tracing::info!(
147 target: "mcp::access",
148 method = %method,
149 target = %name,
150 duration_ms = duration_ms,
151 status = %status,
152 );
153 }
154 _ => {
155 tracing::info!(
156 target: "mcp::access",
157 method = %method,
158 duration_ms = duration_ms,
159 status = %status,
160 );
161 }
162 }
163
164 result
165 })
166 }
167}
168
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::{AccessLogService, extract_backend};
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// Builds a `CallTool` request for `name` with an empty JSON object as
    /// arguments and no metadata or task.
    fn call_tool_request(name: &str) -> McpRequest {
        McpRequest::CallTool(CallToolParams {
            name: name.to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        })
    }

    #[test]
    fn test_extract_backend_with_separator() {
        let cases = [
            ("math/add", "/", "math"),
            ("db/query", "/", "db"),
            ("math::add", "::", "math"),
        ];
        for &(target, separator, expected) in cases.iter() {
            assert_eq!(extract_backend(target, separator), Some(expected));
        }
    }

    #[test]
    fn test_extract_backend_no_separator() {
        let cases = [("add", "/"), ("tool", "::")];
        for &(target, separator) in cases.iter() {
            assert_eq!(extract_backend(target, separator), None);
        }
    }

    #[test]
    fn test_extract_backend_multiple_separators() {
        // Only the first separator delimits the backend.
        let backend = extract_backend("a/b/c", "/");
        assert_eq!(backend, Some("a"));
    }

    #[tokio::test]
    async fn test_access_log_passes_through_list() {
        let mut svc = AccessLogService::new(MockService::with_tools(&["tool"]), "/");

        let request = McpRequest::ListTools(Default::default());
        let resp = call_service(&mut svc, request).await;
        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_access_log_passes_through_tool_call() {
        let mut svc = AccessLogService::new(MockService::with_tools(&["tool"]), "/");

        let resp = call_service(&mut svc, call_tool_request("tool")).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_access_log_passes_through_namespaced_tool_call() {
        let mut svc = AccessLogService::new(MockService::with_tools(&["math/add"]), "/");

        let resp = call_service(&mut svc, call_tool_request("math/add")).await;

        assert!(resp.inner.is_ok());
    }

    #[tokio::test]
    async fn test_access_log_handles_errors() {
        let mut svc = AccessLogService::new(ErrorMockService, "/");

        let request = McpRequest::ListTools(Default::default());
        let resp = call_service(&mut svc, request).await;
        assert!(resp.inner.is_err());
    }

    #[tokio::test]
    async fn test_access_log_handles_ping() {
        let mut svc = AccessLogService::new(MockService::with_tools(&[]), "/");

        let resp = call_service(&mut svc, McpRequest::Ping).await;
        assert!(resp.inner.is_ok());
    }
}