use async_trait::async_trait;
use pmcp::error::{Error, ErrorCode};
use pmcp::types::capabilities::ClientCapabilities;
use pmcp::{Client, StdioTransport, ToolHandler};
use serde_json::{json, Value};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use tokio::time::{sleep, Duration};
/// Demo tool that fails in configurable ways so a client can exercise its
/// error-handling paths.
///
/// Its only state is a shared counter used by the `"transient"` scenario to
/// fail the first calls and succeed afterwards.
// Nothing in this file constructs or registers the tool, hence the lint allowance.
#[allow(dead_code)]
struct ErrorDemoTool {
    /// Number of `"transient"` invocations handled so far.
    call_count: Arc<AtomicU32>,
}

// `new` is only needed when the tool is wired into a server, which this file does not do.
#[allow(dead_code)]
impl ErrorDemoTool {
    /// Creates a tool whose transient-failure counter starts at zero.
    fn new() -> Self {
        let counter = AtomicU32::new(0);
        Self {
            call_count: Arc::new(counter),
        }
    }
}
/// Simulates a range of failure modes, selected by the `scenario` argument
/// (defaults to `"success"` when absent or not a string).
///
/// Scenarios:
/// - `parse_error`, `invalid_request`, `method_not_found`, `invalid_params`,
///   `internal_error`: return the corresponding error immediately.
/// - `timeout`: sleeps 30 seconds before succeeding, to trip client-side timeouts.
/// - `rate_limit`: protocol error with code `-32001`; the limit details are
///   embedded as JSON text in the error message.
/// - `transient`: fails the first two calls over the tool's lifetime, then
///   succeeds and reports the attempt number.
/// - `validation`: requires `input` to be at least 5 characters long and
///   purely alphanumeric.
/// - anything else: succeeds and echoes the scenario name.
#[async_trait]
impl ToolHandler for ErrorDemoTool {
    async fn handle(
        &self,
        arguments: Value,
        _extra: pmcp::RequestHandlerExtra,
    ) -> pmcp::Result<Value> {
        let scenario = arguments
            .get("scenario")
            .and_then(|v| v.as_str())
            .unwrap_or("success");
        match scenario {
            "parse_error" => {
                Err(Error::parse("Invalid JSON: expected object, got array"))
            },
            "invalid_request" => {
                Err(Error::validation("Missing required parameter 'input'"))
            },
            "method_not_found" => {
                Err(Error::method_not_found("tools/unknown"))
            },
            "invalid_params" => {
                Err(Error::invalid_params("Parameter 'count' must be positive"))
            },
            "internal_error" => {
                Err(Error::internal("Database connection failed"))
            },
            "timeout" => {
                sleep(Duration::from_secs(30)).await;
                Ok(json!({"status": "should_timeout"}))
            },
            "rate_limit" => {
                Err(Error::protocol(
                    ErrorCode::other(-32001),
                    format!(
                        "Rate limit exceeded: {}",
                        json!({
                            "retry_after": 60,
                            "limit": 100,
                            "window": "1h"
                        })
                    ),
                ))
            },
            "transient" => {
                // Shared counter: the first two calls ever made fail, later ones succeed.
                let count = self.call_count.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err(Error::internal("Temporary failure, please retry"))
                } else {
                    Ok(json!({
                        "status": "success",
                        "attempts": count + 1
                    }))
                }
            },
            "validation" => {
                let input = arguments
                    .get("input")
                    .and_then(|v| v.as_str())
                    .unwrap_or("");
                // Count characters, not UTF-8 bytes: `input.len()` would accept
                // e.g. "ééé" (3 characters, 6 bytes) as long enough.
                if input.chars().count() < 5 {
                    Err(Error::validation(
                        "Input must be at least 5 characters long",
                    ))
                } else if !input.chars().all(|c| c.is_alphanumeric()) {
                    Err(Error::validation(
                        "Input must contain only alphanumeric characters",
                    ))
                } else {
                    Ok(json!({
                        "status": "validated",
                        "input": input
                    }))
                }
            },
            _ => {
                Ok(json!({
                    "status": "success",
                    "scenario": scenario
                }))
            },
        }
    }
}
/// Runs `operation` until it succeeds, retrying retryable failures with
/// exponential backoff.
///
/// The operation is attempted at most `max_retries + 1` times. After each
/// retryable failure other than the last, the helper sleeps for the current
/// delay. The delay starts at `initial_delay`, doubles after every sleep, and
/// saturates at `Duration::MAX` instead of panicking.
///
/// An error counts as retryable if it is an `Error::Transport`, or if its
/// lowercased message contains "timeout", "unavailable", "temporary" or
/// "retry". This is a text heuristic: a message such as "do not retry" also
/// matches.
///
/// Progress messages go to stderr. The client in this example uses
/// `StdioTransport`, whose stdout carries protocol messages.
///
/// # Errors
///
/// Returns the first non-retryable error immediately. Otherwise, returns the
/// error from the final attempt once all retries are exhausted.
async fn retry_with_backoff<F, T>(
    mut operation: F,
    max_retries: u32,
    initial_delay: Duration,
) -> Result<T, Error>
where
    F: FnMut() -> futures::future::BoxFuture<'static, Result<T, Error>>,
{
    const RETRYABLE_HINTS: [&str; 4] = ["timeout", "unavailable", "temporary", "retry"];

    let mut delay = initial_delay;
    let mut last_error = None;
    for attempt in 0..=max_retries {
        let e = match operation().await {
            Ok(result) => return Ok(result),
            Err(e) => e,
        };

        let is_retryable = match &e {
            Error::Transport(_) => true,
            // Only stringify the error when the variant alone is not conclusive.
            _ => {
                let msg = e.to_string().to_lowercase();
                RETRYABLE_HINTS.iter().any(|&hint| msg.contains(hint))
            },
        };
        if !is_retryable {
            return Err(e);
        }
        last_error = Some(e);

        if attempt < max_retries {
            eprintln!(
                "⏳ Attempt {} failed. Retrying in {:?}...",
                attempt + 1,
                delay
            );
            sleep(delay).await;
            // `Duration * u32` panics on overflow; saturate instead.
            delay = delay.saturating_mul(2);
        }
    }
    // The loop always runs at least once, so this fallback is purely defensive.
    Err(last_error.unwrap_or_else(|| Error::internal("All retry attempts failed")))
}
/// Walks through several client-side error-handling patterns against a
/// server reached over stdio:
/// - inspecting error kinds and codes
/// - input validation
/// - retry with backoff
/// - timeouts
/// - fallback
/// - a simple circuit breaker
/// - batch error aggregation
///
/// All human-readable output, including `tracing` logs, is written to
/// stderr. The client talks to the server through `StdioTransport`, so any
/// non-protocol text written to stdout would be mixed into the message
/// stream. Stderr is also unbuffered, so `eprint!` prefixes show up
/// immediately, before a slow call completes.
///
/// # Errors
///
/// Returns an error if client initialization fails or if a result cannot be
/// serialized for display.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    tracing_subscriber::fmt()
        .with_env_filter("pmcp=info")
        .with_writer(std::io::stderr)
        .init();
    eprintln!("=== MCP Error Handling Example ===\n");
    let transport = StdioTransport::new();
    let mut client = Client::new(transport);
    let capabilities = ClientCapabilities::minimal();
    eprintln!("Connecting to server...");
    let _server_info = client.initialize(capabilities).await?;
    eprintln!("✅ Connected!\n");

    // Example 1: show the error message, code and any attached data per failure kind.
    eprintln!("📋 Example 1: Different error types\n");
    let error_scenarios = [
        ("parse_error", "Parse Error"),
        ("invalid_request", "Invalid Request"),
        ("method_not_found", "Method Not Found"),
        ("invalid_params", "Invalid Parameters"),
        ("internal_error", "Internal Server Error"),
        ("rate_limit", "Rate Limiting"),
    ];
    for (scenario, description) in error_scenarios {
        eprint!("Testing {}: ", description);
        match client
            .call_tool("error_demo".to_string(), json!({"scenario": scenario}))
            .await
        {
            Ok(_) => eprintln!("✅ Unexpected success"),
            Err(e) => {
                eprintln!("❌ {}", e);
                if let Some(code) = e.error_code() {
                    eprintln!(" Error code: {:?}", code);
                }
                if let Error::Protocol {
                    data: Some(data), ..
                } = &e
                {
                    eprintln!(
                        " Additional data: {}",
                        serde_json::to_string_pretty(data)?
                    );
                }
            },
        }
        eprintln!();
    }

    // Example 2: server-side validation of user input.
    eprintln!("\n📋 Example 2: Input validation\n");
    let test_inputs = [
        ("abc", "Too short"),
        ("hello123", "Valid input"),
        ("hello world!", "Invalid characters"),
        ("valid1234", "Another valid input"),
    ];
    for (input, description) in test_inputs {
        eprint!("Testing '{}' ({}): ", input, description);
        match client
            .call_tool(
                "error_demo".to_string(),
                json!({
                    "scenario": "validation",
                    "input": input
                }),
            )
            .await
        {
            Ok(result) => {
                eprintln!(
                    "✅ Result: {}",
                    serde_json::to_string_pretty(&result.content)?
                );
            },
            Err(e) => {
                eprintln!("❌ {}", e);
            },
        }
    }

    // Example 3: each retry attempt gets its own client handle for the 'static future.
    eprintln!("\n\n📋 Example 3: Retry with exponential backoff\n");
    let client_clone = client.clone();
    let result = retry_with_backoff(
        || {
            let client = client_clone.clone();
            Box::pin(async move {
                client
                    .call_tool("error_demo".to_string(), json!({"scenario": "transient"}))
                    .await
            })
        },
        3,
        Duration::from_millis(500),
    )
    .await;
    match result {
        Ok(res) => {
            eprintln!(
                "\n✅ Success after retries: {}",
                serde_json::to_string_pretty(&res.content)?
            );
        },
        Err(e) => {
            eprintln!("\n❌ Failed after all retries: {}", e);
        },
    }

    // Example 4: bound a slow call with a client-side deadline.
    eprintln!("\n\n📋 Example 4: Timeout handling\n");
    eprint!("Testing timeout (5s limit): ");
    let start = std::time::Instant::now();
    match tokio::time::timeout(
        Duration::from_secs(5),
        client.call_tool("error_demo".to_string(), json!({"scenario": "timeout"})),
    )
    .await
    {
        Ok(Ok(_)) => eprintln!("✅ Unexpected success"),
        Ok(Err(e)) => eprintln!("❌ Error: {}", e),
        Err(_) => {
            let elapsed = start.elapsed();
            eprintln!("⏱️ Timed out after {:?}", elapsed);
        },
    }

    // Example 5a: fall back to an alternative call when the primary fails.
    eprintln!("\n\n📋 Example 5: Error recovery strategies\n");
    eprintln!("Strategy 1 - Fallback:");
    let primary_result = client
        .call_tool(
            "error_demo".to_string(),
            json!({
                "scenario": "internal_error"
            }),
        )
        .await;
    match primary_result {
        Ok(res) => {
            eprintln!(" Primary succeeded: {:?}", res.content);
        },
        Err(e) => {
            eprintln!(" Primary failed: {}", e);
            eprintln!(" Trying fallback...");
            match client
                .call_tool("error_demo".to_string(), json!({"scenario": "success"}))
                .await
            {
                Ok(_) => eprintln!(" ✅ Fallback succeeded"),
                Err(e) => eprintln!(" ❌ Fallback also failed: {}", e),
            }
        },
    }

    // Example 5b: stop calling once consecutive failures reach the threshold.
    eprintln!("\nStrategy 2 - Circuit breaker:");
    let mut failures = 0;
    let failure_threshold = 3;
    let mut circuit_open = false;
    for i in 0..5 {
        if circuit_open {
            eprintln!(" Attempt {}: Circuit open, skipping", i + 1);
            continue;
        }
        match client
            .call_tool(
                "error_demo".to_string(),
                json!({
                    "scenario": if i < 3 { "internal_error" } else { "success" }
                }),
            )
            .await
        {
            Ok(_) => {
                eprintln!(" Attempt {}: ✅ Success", i + 1);
                // Only consecutive failures count toward opening the circuit.
                failures = 0;
            },
            Err(e) => {
                failures += 1;
                eprintln!(" Attempt {}: ❌ Failed ({})", i + 1, e);
                if failures >= failure_threshold {
                    circuit_open = true;
                    eprintln!(" 🚫 Circuit breaker opened!");
                }
            },
        }
    }

    // Example 6: run every operation and report successes and failures together.
    eprintln!("\n\n📋 Example 6: Batch operations with error aggregation\n");
    let operations = [
        ("success", "Operation 1"),
        ("invalid_params", "Operation 2"),
        ("success", "Operation 3"),
        ("internal_error", "Operation 4"),
        ("success", "Operation 5"),
    ];
    let mut results = Vec::new();
    let mut errors = Vec::new();
    for (scenario, name) in operations {
        match client
            .call_tool("error_demo".to_string(), json!({"scenario": scenario}))
            .await
        {
            Ok(res) => results.push((name, res)),
            Err(e) => errors.push((name, e)),
        }
    }
    eprintln!("Batch results:");
    eprintln!(" ✅ Successful: {} operations", results.len());
    for (name, _) in &results {
        eprintln!(" - {}", name);
    }
    eprintln!(" ❌ Failed: {} operations", errors.len());
    for (name, error) in &errors {
        eprintln!(" - {}: {}", name, error);
    }
    let total = results.len() + errors.len();
    // Guard against dividing by zero if the batch is ever empty.
    if total > 0 {
        // Batch sizes are tiny, so the usize -> f64 conversion is exact.
        #[allow(clippy::cast_precision_loss)]
        let success_rate = (results.len() as f64 / total as f64) * 100.0;
        eprintln!("\n Success rate: {:.1}%", success_rate);
    }
    Ok(())
}