use std::sync::{Arc, Mutex};
use std::time::Duration;
use async_trait::async_trait;
use rmcp::model::{CreateMessageRequestParams, CreateMessageResult, Role, SamplingMessage};
/// Failure a `FakeMcpClient` can be scripted to produce from `create_message`.
#[derive(Debug, Clone)]
pub enum FakeSamplingError {
    /// The client declined the sampling request; `reason` is its explanation.
    Refused { reason: String },
    /// Simulated transport failure; also produced when no response is configured.
    Transport { message: String },
    /// The client replied with something that could not be interpreted.
    MalformedResponse { message: String },
}

impl std::fmt::Display for FakeSamplingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Each variant renders as "<label>: <detail>".
        let (label, detail) = match self {
            Self::Refused { reason } => ("client refused", reason),
            Self::Transport { message } => ("transport", message),
            Self::MalformedResponse { message } => ("malformed response", message),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for FakeSamplingError {}
#[derive(Debug, Clone)]
pub enum FakeResponse {
Text { text: String, model: String },
Slow {
text: String,
model: String,
duration: Duration,
},
EmptyContent,
Error(FakeSamplingError),
}
impl FakeResponse {
pub fn text(text: impl Into<String>) -> Self {
Self::Text {
text: text.into(),
model: "fake-claude".to_string(),
}
}
pub fn refused(reason: impl Into<String>) -> Self {
Self::Error(FakeSamplingError::Refused {
reason: reason.into(),
})
}
pub fn slow(text: impl Into<String>, duration: Duration) -> Self {
Self::Slow {
text: text.into(),
model: "fake-claude".to_string(),
duration,
}
}
}
#[derive(Clone, Default)]
pub struct FakeMcpClient {
responses: Arc<Mutex<Vec<FakeResponse>>>,
next_idx: Arc<Mutex<usize>>,
requests: Arc<Mutex<Vec<CreateMessageRequestParams>>>,
}
impl FakeMcpClient {
pub fn new(response: FakeResponse) -> Self {
Self {
responses: Arc::new(Mutex::new(vec![response])),
next_idx: Arc::new(Mutex::new(0)),
requests: Arc::new(Mutex::new(Vec::new())),
}
}
pub fn respond_with(&self, response: FakeResponse) {
*self.responses.lock().expect("FakeMcpClient mutex poisoned") = vec![response];
*self.next_idx.lock().expect("FakeMcpClient mutex poisoned") = 0;
}
pub fn respond_each(&self, responses: Vec<FakeResponse>) {
assert!(
!responses.is_empty(),
"FakeMcpClient::respond_each: pass at least one response"
);
*self.responses.lock().expect("FakeMcpClient mutex poisoned") = responses;
*self.next_idx.lock().expect("FakeMcpClient mutex poisoned") = 0;
}
pub fn reject_with(&self, reason: impl Into<String>) {
self.respond_with(FakeResponse::refused(reason));
}
pub fn record_requests(&self) -> Vec<CreateMessageRequestParams> {
self.requests
.lock()
.expect("FakeMcpClient mutex poisoned")
.clone()
}
fn next_response(&self) -> FakeResponse {
let responses = self.responses.lock().expect("FakeMcpClient mutex poisoned");
if responses.is_empty() {
return FakeResponse::Error(FakeSamplingError::Transport {
message: "FakeMcpClient: no response configured".to_string(),
});
}
let mut idx = self.next_idx.lock().expect("FakeMcpClient mutex poisoned");
let r = responses[(*idx).min(responses.len() - 1)].clone();
if *idx < responses.len() - 1 {
*idx += 1;
}
r
}
}
#[async_trait]
impl crate::llm::sampling::SamplingClient for FakeMcpClient {
async fn create_message(
&self,
params: CreateMessageRequestParams,
) -> Result<CreateMessageResult, crate::llm::sampling::SamplingError> {
self.requests
.lock()
.expect("FakeMcpClient mutex poisoned")
.push(params.clone());
match self.next_response() {
FakeResponse::Text { text, model } => Ok(CreateMessageResult::new(
SamplingMessage::assistant_text(text),
model,
)
.with_stop_reason(CreateMessageResult::STOP_REASON_END_TURN)),
FakeResponse::Slow {
text,
model,
duration,
} => {
tokio::time::sleep(duration).await;
Ok(
CreateMessageResult::new(SamplingMessage::assistant_text(text), model)
.with_stop_reason(CreateMessageResult::STOP_REASON_END_TURN),
)
}
FakeResponse::EmptyContent => {
Ok(CreateMessageResult::new(
SamplingMessage::new_multiple(Role::Assistant, Vec::new()),
"fake-claude".to_string(),
)
.with_stop_reason(CreateMessageResult::STOP_REASON_END_TURN))
}
FakeResponse::Error(err) => Err(crate::llm::sampling::SamplingError::Fake(err)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::sampling::SamplingClient;

    /// Baseline request: a single user message with `max_tokens = 512`.
    fn req() -> CreateMessageRequestParams {
        CreateMessageRequestParams::new(vec![SamplingMessage::user_text("hi")], 512)
    }

    /// Returns the text of the first content block, panicking if it is missing or not text.
    fn first_text(result: CreateMessageResult) -> String {
        let content = result.message.content.into_vec();
        content[0].as_text().expect("text content").text.clone()
    }

    #[tokio::test]
    async fn happy_path_returns_canned_text() {
        let fake = FakeMcpClient::new(FakeResponse::text("hello world"));
        let result = fake.create_message(req()).await.expect("ok");
        assert_eq!(result.model, "fake-claude");
        assert_eq!(first_text(result), "hello world");
    }

    #[tokio::test]
    async fn respond_with_replaces_response() {
        let fake = FakeMcpClient::new(FakeResponse::text("first"));
        fake.respond_with(FakeResponse::text("second"));
        let result = fake.create_message(req()).await.expect("ok");
        assert_eq!(first_text(result), "second");
    }

    #[tokio::test]
    async fn respond_each_sequences_and_wraps_to_last() {
        let fake = FakeMcpClient::default();
        fake.respond_each(vec![FakeResponse::text("a"), FakeResponse::text("b")]);
        let mut seen = Vec::new();
        for _ in 0..3 {
            seen.push(first_text(fake.create_message(req()).await.expect("ok")));
        }
        assert_eq!(seen, ["a", "b", "b"]);
    }

    #[tokio::test]
    async fn respond_each_restarts_cursor_for_new_script() {
        let fake = FakeMcpClient::default();
        fake.respond_each(vec![FakeResponse::text("a"), FakeResponse::text("b")]);
        let _ = fake.create_message(req()).await.expect("ok");
        fake.respond_each(vec![FakeResponse::text("x"), FakeResponse::text("y")]);
        let result = fake.create_message(req()).await.expect("ok");
        assert_eq!(first_text(result), "x");
    }

    #[test]
    #[should_panic(expected = "pass at least one response")]
    fn respond_each_rejects_empty_script() {
        FakeMcpClient::default().respond_each(Vec::new());
    }

    #[tokio::test]
    async fn reject_with_returns_refused_error() {
        let fake = FakeMcpClient::new(FakeResponse::text("won't see this"));
        fake.reject_with("user dismissed");
        let err = fake.create_message(req()).await.unwrap_err();
        match err {
            crate::llm::sampling::SamplingError::Fake(FakeSamplingError::Refused { reason }) => {
                assert_eq!(reason, "user dismissed");
            }
            other => panic!("expected Refused, got {other:?}"),
        }
        // Failed calls are still logged.
        assert_eq!(fake.record_requests().len(), 1);
    }

    #[tokio::test]
    async fn empty_content_returns_zero_content_blocks() {
        let fake = FakeMcpClient::new(FakeResponse::EmptyContent);
        let result = fake.create_message(req()).await.expect("ok");
        let content = result.message.content.into_vec();
        assert!(content.is_empty(), "EmptyContent must produce zero blocks");
    }

    #[tokio::test]
    async fn slow_response_actually_sleeps() {
        let fake = FakeMcpClient::new(FakeResponse::slow("late", Duration::from_millis(40)));
        let start = std::time::Instant::now();
        let _ = fake.create_message(req()).await.expect("ok");
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(35),
            "slow response should sleep at least ~40ms; observed {:?}",
            elapsed
        );
    }

    #[tokio::test]
    async fn record_requests_captures_each_call() {
        let fake = FakeMcpClient::new(FakeResponse::text("ok"));
        let _ = fake.create_message(req()).await;
        let mut p2 = req();
        p2.max_tokens = 1024;
        let _ = fake.create_message(p2.clone()).await;
        let recorded = fake.record_requests();
        assert_eq!(recorded.len(), 2);
        assert_eq!(recorded[0].max_tokens, 512);
        assert_eq!(recorded[1].max_tokens, 1024);
    }

    #[tokio::test]
    async fn default_with_no_response_returns_transport_error() {
        // Deliberately left unscripted: this exercises the `Default` state itself.
        let fake = FakeMcpClient::default();
        let err = fake.create_message(req()).await.unwrap_err();
        match err {
            crate::llm::sampling::SamplingError::Fake(FakeSamplingError::Transport { .. }) => {}
            other => panic!("expected Transport error, got {other:?}"),
        }
    }
}