klieo-mcp-server 2.2.0

Expose any klieo ToolInvoker or Agent as an MCP server over stdio or HTTP. The inverse of klieo-tools-mcp.
Documentation
//! Typed extension methods over `klieo_core::ServerOutbound`.
//!
//! Provides MCP-specific helpers (currently `sample()` for the
//! `sampling/createMessage` method) layered as a trait so tool
//! handlers obtain the outbound channel from `ToolCtx.server_outbound`
//! and call the typed methods without knowing transport details.

use crate::sampling::{SamplingRequest, SamplingResponse};
use crate::McpServerError;
use async_trait::async_trait;
use klieo_core::ServerOutbound;
use std::time::Duration;

/// Deadline handed to the transport for every `sampling/createMessage`
/// request issued by [`McpOutboundExt::sample`]. When no peer response
/// arrives within this window the helper fails with
/// [`McpServerError::OutboundTimeout`].
///
/// One minute leaves room for an interactive LLM completion on the client
/// side while still keeping the waiting call site bounded.
const SAMPLING_TIMEOUT: Duration = Duration::from_millis(60_000);

/// Typed outbound MCP helpers on top of [`ServerOutbound`].
///
/// Implemented for the bare `dyn ServerOutbound` trait object so
/// callers reach the helpers through any concrete outbound
/// implementation (today: `klieo_mcp_server::outbound::OutboundRequests`)
/// without widening the trait-object marker bounds at the call site.
/// The supertrait clause on `ServerOutbound: Send + Sync` already
/// guarantees every implementor carries those marker traits.
#[async_trait]
pub trait McpOutboundExt {
    /// Issue `sampling/createMessage` and parse the typed response.
    ///
    /// `req` is serialised to JSON and sent as the request parameters of a
    /// server-initiated `sampling/createMessage` request. The peer's reply
    /// is decoded into a [`SamplingResponse`].
    ///
    /// # Errors
    ///
    /// For the `dyn ServerOutbound` implementation provided in this module:
    ///
    /// - [`McpServerError::SamplingSerialise`] if `req` cannot be
    ///   serialised to JSON.
    /// - [`McpServerError::OutboundUnsupported`] when the transport
    ///   carrying this `ServerOutbound` advertises no server-initiated
    ///   request support.
    /// - [`McpServerError::OutboundTimeout`] when the peer does not answer
    ///   within the helper's fixed sampling timeout.
    /// - [`McpServerError::ClientReturnedError`] when the peer answers with
    ///   an error (`ServerOutboundError::PeerError`); its `code` and
    ///   `message` are carried through unchanged.
    /// - [`McpServerError::SamplingDeserialise`] if the peer's reply does
    ///   not decode as a [`SamplingResponse`].
    /// - Any other `ServerOutboundError` (for example a closed transport)
    ///   is converted through `McpServerError`'s `From<ServerOutboundError>`
    ///   implementation.
    async fn sample(&self, req: SamplingRequest) -> Result<SamplingResponse, McpServerError>;
}

#[async_trait]
impl McpOutboundExt for dyn ServerOutbound {
    /// Round-trips a [`SamplingRequest`] through the peer as a
    /// `sampling/createMessage` server-initiated request.
    async fn sample(&self, req: SamplingRequest) -> Result<SamplingResponse, McpServerError> {
        // Encode the typed request into the parameter value for the call.
        let request_json = serde_json::to_value(&req).map_err(McpServerError::SamplingSerialise)?;

        // Transport-level failures are lifted into `McpServerError` by the
        // `?` operator, which applies the same `From<ServerOutboundError>`
        // conversion an explicit `map_err(McpServerError::from)` would.
        let reply_json = self
            .outbound_request("sampling/createMessage", request_json, SAMPLING_TIMEOUT)
            .await?;

        // Decode the peer's reply; schema mismatches surface as
        // `SamplingDeserialise`.
        serde_json::from_value::<SamplingResponse>(reply_json)
            .map_err(McpServerError::SamplingDeserialise)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling::SamplingContent;
    use klieo_core::ServerOutboundError;
    use std::sync::{Arc, Mutex};

    /// In-memory [`ServerOutbound`] used by the extension-trait tests.
    ///
    /// Returns a single canned response (or a typed error) on every
    /// `outbound_request` call. It also records the `(method, timeout)`
    /// pair of each call, so tests can pin the wire-level contract of
    /// the helpers.
    struct MockOutbound {
        response: Result<serde_json::Value, ServerOutboundError>,
        calls: Mutex<Vec<(String, Duration)>>,
    }

    impl MockOutbound {
        /// Build a mock that replays `response` and starts with an empty
        /// call log.
        fn new(response: Result<serde_json::Value, ServerOutboundError>) -> Self {
            Self {
                response,
                calls: Mutex::new(Vec::new()),
            }
        }
    }

    #[async_trait]
    impl ServerOutbound for MockOutbound {
        async fn outbound_request(
            &self,
            method: &str,
            _params: serde_json::Value,
            timeout: Duration,
        ) -> Result<serde_json::Value, ServerOutboundError> {
            // The guard is a statement temporary and is dropped before the
            // function returns; nothing is held across an `.await`.
            self.calls
                .lock()
                .expect("mock call log poisoned")
                .push((method.to_owned(), timeout));

            // The stored error is only reachable through `&self`, so each
            // call rebuilds an owned copy variant by variant.
            match &self.response {
                Ok(v) => Ok(v.clone()),
                Err(ServerOutboundError::Timeout) => Err(ServerOutboundError::Timeout),
                Err(ServerOutboundError::TransportClosed) => {
                    Err(ServerOutboundError::TransportClosed)
                }
                Err(ServerOutboundError::Unsupported) => Err(ServerOutboundError::Unsupported),
                Err(ServerOutboundError::PeerError { code, message }) => {
                    Err(ServerOutboundError::PeerError {
                        code: *code,
                        message: message.clone(),
                    })
                }
                // Fail loudly instead of substituting a different variant,
                // which would make a mis-specified test pass or fail for the
                // wrong reason.
                Err(_) => unimplemented!(
                    "MockOutbound cannot replay this ServerOutboundError variant"
                ),
            }
        }
    }

    fn ok_outbound(payload: serde_json::Value) -> Arc<dyn ServerOutbound> {
        Arc::new(MockOutbound::new(Ok(payload)))
    }

    fn err_outbound(err: ServerOutboundError) -> Arc<dyn ServerOutbound> {
        Arc::new(MockOutbound::new(Err(err)))
    }

    fn sample_request() -> SamplingRequest {
        SamplingRequest {
            messages: vec![],
            model_preferences: None,
            system_prompt: None,
            max_tokens: 32,
            temperature: None,
            stop_sequences: None,
        }
    }

    #[tokio::test]
    async fn sample_returns_typed_response() {
        let payload = serde_json::json!({
            "role": "assistant",
            "content": {"type": "text", "text": "42"},
            "model": "test-model",
            "stopReason": "endTurn"
        });
        let outbound = ok_outbound(payload);
        let resp = outbound
            .sample(sample_request())
            .await
            .expect("typed response");
        assert_eq!(resp.role, "assistant");
        assert_eq!(resp.model, "test-model");
        assert_eq!(resp.stop_reason.as_deref(), Some("endTurn"));
        match resp.content {
            SamplingContent::Text { text } => assert_eq!(text, "42"),
        }
    }

    #[tokio::test]
    async fn sample_issues_create_message_with_sampling_timeout() {
        let mock = Arc::new(MockOutbound::new(Err(ServerOutboundError::Unsupported)));
        // Method-call `clone` keeps `Arc<MockOutbound>` as the source type,
        // so the annotation coerces it to the trait object.
        let outbound: Arc<dyn ServerOutbound> = mock.clone();
        let outcome = outbound.sample(sample_request()).await;
        assert!(outcome.is_err(), "mock was configured to fail");

        let calls = mock.calls.lock().expect("mock call log poisoned");
        assert_eq!(
            *calls,
            vec![("sampling/createMessage".to_owned(), SAMPLING_TIMEOUT)]
        );
    }

    #[tokio::test]
    async fn sample_maps_outbound_unsupported_to_mcp_error() {
        let outbound = err_outbound(ServerOutboundError::Unsupported);
        let outcome = outbound.sample(sample_request()).await;
        assert!(
            matches!(outcome, Err(McpServerError::OutboundUnsupported)),
            "Unsupported must map to OutboundUnsupported; got {outcome:?}"
        );
    }

    #[tokio::test]
    async fn sample_maps_timeout_to_mcp_error() {
        let outbound = err_outbound(ServerOutboundError::Timeout);
        let outcome = outbound.sample(sample_request()).await;
        assert!(
            matches!(outcome, Err(McpServerError::OutboundTimeout)),
            "Timeout must map to OutboundTimeout; got {outcome:?}"
        );
    }

    #[tokio::test]
    async fn sample_maps_peer_error_to_client_returned_error() {
        let outbound = err_outbound(ServerOutboundError::PeerError {
            code: -32601,
            message: "method not found".into(),
        });
        let outcome = outbound.sample(sample_request()).await;
        match outcome {
            Err(McpServerError::ClientReturnedError { code, message }) => {
                assert_eq!(code, -32601);
                assert_eq!(message, "method not found");
            }
            other => panic!("expected ClientReturnedError; got {other:?}"),
        }
    }

    #[tokio::test]
    async fn sample_maps_malformed_response_to_deserialise_error() {
        // A bare number can never decode into the `SamplingResponse` struct.
        let outbound = ok_outbound(serde_json::json!(42));
        let outcome = outbound.sample(sample_request()).await;
        assert!(
            matches!(outcome, Err(McpServerError::SamplingDeserialise(_))),
            "a non-object payload must surface as SamplingDeserialise; got {outcome:?}"
        );
    }
}