use async_trait::async_trait;
use pmcp::shared::TransportMessage;
use pmcp::types::{JSONRPCRequest, JSONRPCResponse};
use pmcp::{
Client, ClientCapabilities, LoggingMiddleware, Middleware, MiddlewareChain, StdioTransport,
};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Instant;
use tracing::{info, Level};
/// Middleware state used to count outgoing requests and time how long each
/// one takes to receive a response.
struct TimingMiddleware {
/// Number of requests observed so far; incremented once per `on_request` call.
request_count: AtomicU64,
/// Start instants of in-flight requests, keyed by the request id rendered as a string.
/// An entry is inserted in `on_request` and removed in `on_response`.
start_times: dashmap::DashMap<String, Instant>,
}
impl TimingMiddleware {
    /// Builds a timing middleware with a zeroed request counter and no
    /// in-flight requests recorded.
    fn new() -> Self {
        let request_count = AtomicU64::new(0);
        let start_times = dashmap::DashMap::new();
        Self {
            request_count,
            start_times,
        }
    }
}
#[async_trait]
impl Middleware for TimingMiddleware {
    /// Logs the 1-based sequence number and method of the request, then
    /// records the current instant under the request's id.
    async fn on_request(&self, request: &mut JSONRPCRequest) -> pmcp::Result<()> {
        // `fetch_add` yields the value *before* the increment, hence the `+ 1`
        // when reporting a human-friendly sequence number.
        let previous = self.request_count.fetch_add(1, Ordering::SeqCst);
        info!("Request #{}: {}", previous + 1, request.method);

        let key = request.id.to_string();
        self.start_times.insert(key, Instant::now());
        Ok(())
    }

    /// Looks up (and removes) the start instant recorded for this response's
    /// id and logs the elapsed time. Responses with no recorded start are
    /// passed through silently.
    async fn on_response(&self, response: &mut JSONRPCResponse) -> pmcp::Result<()> {
        let key = response.id.to_string();
        if let Some((_, started_at)) = self.start_times.remove(&key) {
            let took = started_at.elapsed();
            info!("Response for {} took {:?}", response.id, took);
        }
        Ok(())
    }
}
/// Middleware that logs a line identifying this client whenever a transport
/// message is sent or received.
struct MetadataMiddleware {
/// Identifier interpolated into the send/receive log messages.
client_id: String,
}
#[async_trait]
impl Middleware for MetadataMiddleware {
    /// Logs that this client is sending a message; the message itself is not inspected.
    async fn on_send(&self, _message: &TransportMessage) -> pmcp::Result<()> {
        let client = &self.client_id;
        info!("Client {} sending message", client);
        Ok(())
    }

    /// Logs that this client received a message; the message itself is not inspected.
    async fn on_receive(&self, _message: &TransportMessage) -> pmcp::Result<()> {
        let client = &self.client_id;
        info!("Client {} received message", client);
        Ok(())
    }
}
/// Example entry point: installs a tracing subscriber, assembles a middleware
/// chain (logging, timing, metadata), creates a stdio-backed client, and
/// attempts to initialize it, logging either the server identity or the error.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt::init();

    // Assemble the chain in the same order the middlewares should be registered.
    // NOTE(review): within this function the chain is never handed to the
    // client or transport — confirm whether it is meant to be attached.
    let mut chain = MiddlewareChain::new();
    chain.add(Arc::new(LoggingMiddleware::new(Level::DEBUG)));
    chain.add(Arc::new(TimingMiddleware::new()));
    let client_id = "example-client".to_string();
    chain.add(Arc::new(MetadataMiddleware { client_id }));

    info!("Creating client with middleware");
    let mut client = Client::new(StdioTransport::new());
    let capabilities = ClientCapabilities::default();

    info!("Initializing connection (middleware will track this)");
    let outcome = client.initialize(capabilities).await;
    match outcome {
        // A failure is logged at info level and does not abort the example.
        Err(err) => {
            info!("Connection failed (expected in example): {}", err);
        },
        Ok(init) => {
            let server = &init.server_info;
            info!("Connected to: {} v{}", server.name, server.version);
        },
    }

    Ok(())
}