use crate::sampling::{SamplingRequest, SamplingResponse};
use crate::McpServerError;
use async_trait::async_trait;
use klieo_core::ServerOutbound;
use std::time::Duration;
/// Timeout handed to `ServerOutbound::outbound_request` for every
/// `sampling/createMessage` request issued by `McpOutboundExt::sample`
/// (one minute, expressed in milliseconds).
const SAMPLING_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Extension methods for issuing MCP server-to-client requests over a
/// [`ServerOutbound`] channel.
///
/// This module implements the trait for `dyn ServerOutbound`, so its methods
/// can be called through a trait object such as `Arc<dyn ServerOutbound>`.
#[async_trait]
pub trait McpOutboundExt {
    /// Sends `req` to the connected client as a `sampling/createMessage`
    /// request and returns the client's reply decoded as a
    /// [`SamplingResponse`].
    ///
    /// # Errors
    ///
    /// Returns an [`McpServerError`] if the request cannot be serialised to
    /// JSON, the outbound call itself fails, or the client's reply cannot be
    /// decoded as a [`SamplingResponse`].
    async fn sample(&self, req: SamplingRequest) -> Result<SamplingResponse, McpServerError>;
}
#[async_trait]
impl McpOutboundExt for dyn ServerOutbound {
    /// Serialises `req`, sends it to the client as `sampling/createMessage`
    /// with a [`SAMPLING_TIMEOUT`] timeout, and decodes the result into a
    /// [`SamplingResponse`].
    ///
    /// # Errors
    ///
    /// - [`McpServerError::SamplingSerialise`] if `req` cannot be turned into JSON.
    /// - Whatever `McpServerError` the outbound error converts into via `From`
    ///   (for example `OutboundTimeout`, `OutboundUnsupported`, or
    ///   `ClientReturnedError` for a JSON-RPC error returned by the client).
    /// - [`McpServerError::SamplingDeserialise`] if the client's result does not
    ///   have the shape of a [`SamplingResponse`].
    async fn sample(&self, req: SamplingRequest) -> Result<SamplingResponse, McpServerError> {
        let params = serde_json::to_value(&req).map_err(McpServerError::SamplingSerialise)?;
        // `?` already routes the outbound error through
        // `From<_> for McpServerError`; an explicit `map_err` would be redundant.
        let response = self
            .outbound_request("sampling/createMessage", params, SAMPLING_TIMEOUT)
            .await?;
        serde_json::from_value::<SamplingResponse>(response)
            .map_err(McpServerError::SamplingDeserialise)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingContent;
    use klieo_core::ServerOutboundError;
    use std::sync::{Arc, Mutex};

    /// One recorded `outbound_request` invocation: `(method, params, timeout)`.
    type RecordedCall = (String, serde_json::Value, Duration);

    /// Canned [`ServerOutbound`] that replays a fixed response and records
    /// every request it receives, so tests can assert on what `sample` sent.
    struct MockOutbound {
        response: Result<serde_json::Value, ServerOutboundError>,
        calls: Mutex<Vec<RecordedCall>>,
    }

    impl MockOutbound {
        fn new(response: Result<serde_json::Value, ServerOutboundError>) -> Self {
            Self {
                response,
                calls: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of every call recorded so far, in call order.
        fn recorded_calls(&self) -> Vec<RecordedCall> {
            self.calls.lock().expect("mock call log poisoned").clone()
        }
    }

    #[async_trait]
    impl ServerOutbound for MockOutbound {
        async fn outbound_request(
            &self,
            method: &str,
            params: serde_json::Value,
            timeout: Duration,
        ) -> Result<serde_json::Value, ServerOutboundError> {
            // The guard is a statement temporary, so it is released before
            // anything else runs.
            self.calls
                .lock()
                .expect("mock call log poisoned")
                .push((method.to_owned(), params, timeout));
            // The error is rebuilt variant by variant instead of cloned
            // (presumably `ServerOutboundError` is not `Clone` — TODO confirm).
            match &self.response {
                Ok(v) => Ok(v.clone()),
                Err(ServerOutboundError::Timeout) => Err(ServerOutboundError::Timeout),
                Err(ServerOutboundError::TransportClosed) => {
                    Err(ServerOutboundError::TransportClosed)
                }
                Err(ServerOutboundError::Unsupported) => Err(ServerOutboundError::Unsupported),
                Err(ServerOutboundError::PeerError { code, message }) => {
                    Err(ServerOutboundError::PeerError {
                        code: *code,
                        message: message.clone(),
                    })
                }
                // Fail loudly rather than silently substituting `Unsupported`,
                // which would let a test for another variant pass for the
                // wrong reason.
                Err(_) => panic!("MockOutbound cannot replay this ServerOutboundError variant"),
            }
        }
    }

    fn ok_outbound(payload: serde_json::Value) -> Arc<dyn ServerOutbound> {
        Arc::new(MockOutbound::new(Ok(payload)))
    }

    fn err_outbound(err: ServerOutboundError) -> Arc<dyn ServerOutbound> {
        Arc::new(MockOutbound::new(Err(err)))
    }

    fn sample_request() -> SamplingRequest {
        SamplingRequest {
            messages: vec![],
            model_preferences: None,
            system_prompt: None,
            max_tokens: 32,
            temperature: None,
            stop_sequences: None,
        }
    }

    fn text_reply(text: &str) -> serde_json::Value {
        serde_json::json!({
            "role": "assistant",
            "content": {"type": "text", "text": text},
            "model": "test-model",
            "stopReason": "endTurn"
        })
    }

    #[tokio::test]
    async fn sample_returns_typed_response() {
        let outbound = ok_outbound(text_reply("42"));
        let resp = outbound
            .sample(sample_request())
            .await
            .expect("typed response");
        assert_eq!(resp.role, "assistant");
        assert_eq!(resp.model, "test-model");
        assert_eq!(resp.stop_reason.as_deref(), Some("endTurn"));
        match resp.content {
            SamplingContent::Text { text } => assert_eq!(text, "42"),
        }
    }

    #[tokio::test]
    async fn sample_sends_create_message_with_serialised_request_and_timeout() {
        let mock = Arc::new(MockOutbound::new(Ok(text_reply("ok"))));
        let outbound: Arc<dyn ServerOutbound> = mock.clone();
        outbound
            .sample(sample_request())
            .await
            .expect("typed response");

        let calls = mock.recorded_calls();
        assert_eq!(calls.len(), 1, "sample must issue exactly one outbound request");
        let (method, params, timeout) = &calls[0];
        assert_eq!(method.as_str(), "sampling/createMessage");
        assert_eq!(
            params,
            &serde_json::to_value(sample_request()).expect("request serialises")
        );
        assert_eq!(*timeout, SAMPLING_TIMEOUT);
    }

    #[tokio::test]
    async fn sample_maps_malformed_payload_to_deserialise_error() {
        // A bare JSON string cannot match the struct-shaped `SamplingResponse`.
        let outbound = ok_outbound(serde_json::json!("not a sampling response"));
        let outcome = outbound.sample(sample_request()).await;
        assert!(
            matches!(outcome, Err(McpServerError::SamplingDeserialise(_))),
            "malformed payload must map to SamplingDeserialise; got {outcome:?}"
        );
    }

    #[tokio::test]
    async fn sample_maps_outbound_unsupported_to_mcp_error() {
        let outbound = err_outbound(ServerOutboundError::Unsupported);
        let outcome = outbound.sample(sample_request()).await;
        assert!(
            matches!(outcome, Err(McpServerError::OutboundUnsupported)),
            "Unsupported must map to OutboundUnsupported; got {outcome:?}"
        );
    }

    #[tokio::test]
    async fn sample_maps_timeout_to_mcp_error() {
        let outbound = err_outbound(ServerOutboundError::Timeout);
        let outcome = outbound.sample(sample_request()).await;
        assert!(
            matches!(outcome, Err(McpServerError::OutboundTimeout)),
            "Timeout must map to OutboundTimeout; got {outcome:?}"
        );
    }

    #[tokio::test]
    async fn sample_maps_peer_error_to_client_returned_error() {
        let outbound = err_outbound(ServerOutboundError::PeerError {
            code: -32601,
            message: "method not found".into(),
        });
        let outcome = outbound.sample(sample_request()).await;
        match outcome {
            Err(McpServerError::ClientReturnedError { code, message }) => {
                assert_eq!(code, -32601);
                assert_eq!(message, "method not found");
            }
            other => panic!("expected ClientReturnedError; got {other:?}"),
        }
    }
}