tmcp 0.4.0

Complete, ergonomic implementation of the Model Context Protocol (MCP)
Documentation
//! Client/server ping integration tests.

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use async_trait::async_trait;
    use tmcp::{
        ClientCtx, ClientHandler, Result, ServerCtx, ServerHandler,
        schema::*,
        testutils::{
            connected_client_and_server_with_conn, shutdown_client_and_server, test_client_ctx,
        },
    };
    use tokio::sync::mpsc;
    use tracing_subscriber::fmt;

    /// Client-side handler used by the tests in this module.
    ///
    /// Each callback appends a label to `calls`. The log sits behind an
    /// `Arc<Mutex<_>>`, so clones of the handler share a single log and a test
    /// can keep its own handle to inspect what was recorded.
    #[derive(Clone, Default)]
    struct TestClientHandler {
        /// Labels of the callbacks invoked so far, in invocation order.
        calls: Arc<Mutex<Vec<String>>>,
    }

    impl TestClientHandler {
        /// Appends `event` to the shared call log.
        ///
        /// Panics if the log mutex is poisoned.
        fn record(&self, event: &str) {
            let mut log = self.calls.lock().unwrap();
            log.push(event.to_owned());
        }
    }

    /// Every callback records a label in the call log and then returns a
    /// fixed, successful response.
    #[async_trait]
    impl ClientHandler for TestClientHandler {
        /// Records `"on_connect"`.
        async fn on_connect(&self, _ctx: &ClientCtx) -> Result<()> {
            self.record("on_connect");
            Ok(())
        }

        /// Records `"on_shutdown"`.
        async fn on_shutdown(&self, _ctx: &ClientCtx) -> Result<()> {
            self.record("on_shutdown");
            Ok(())
        }

        /// Records `"ping"` (not `"pong"`): that is the label the tests in this
        /// file look for.
        async fn pong(&self, _ctx: &ClientCtx) -> Result<()> {
            self.record("ping");
            Ok(())
        }

        /// Records `"create_message"` and replies with a canned assistant
        /// message, ignoring the method name and the request parameters.
        async fn create_message(
            &self,
            _ctx: &ClientCtx,
            _method: &str,
            _params: CreateMessageParams,
        ) -> Result<CreateMessageResult> {
            self.record("create_message");
            let reply = CreateMessageResult {
                message: SamplingMessage::assistant_text("Test response"),
                model: "test-model".into(),
                stop_reason: None,
            };
            Ok(reply)
        }

        /// Records `"list_roots"` and reports a single root, `test://root`.
        async fn list_roots(&self, _ctx: &ClientCtx) -> Result<ListRootsResult> {
            self.record("list_roots");
            let only_root = Root {
                uri: "test://root".into(),
                name: Some("Test Root".into()),
                _meta: None,
            };
            Ok(ListRootsResult {
                roots: vec![only_root],
                _meta: None,
            })
        }
    }

    /// Stateless server-side handler used by the tests in this module: it
    /// reports fixed server info on initialization and its `pong` callback
    /// always succeeds.
    struct TestServerHandler;

    #[async_trait]
    impl ServerHandler for TestServerHandler {
        /// Ignores everything the client sent and identifies the server as
        /// `test-server`, version `1.0.0`.
        async fn initialize(
            &self,
            _ctx: &ServerCtx,
            _protocol_version: String,
            _capabilities: ClientCapabilities,
            _client_info: Implementation,
        ) -> Result<InitializeResult> {
            let server_info = InitializeResult::new("test-server").with_version("1.0.0");
            Ok(server_info)
        }

        /// Succeeds without doing anything.
        async fn pong(&self, _ctx: &ServerCtx) -> Result<()> {
            Ok(())
        }
    }

    /// Calls the `ClientHandler` methods of `TestClientHandler` directly and
    /// checks both their return values and the labels they record.
    #[tokio::test]
    async fn client_connection_trait_methods() {
        let handler = TestClientHandler::default();

        // Bind the receiver to a named variable: a bare `_` pattern would drop
        // it immediately, closing the channel before the context is ever used.
        let (tx, _rx) = mpsc::unbounded_channel();
        let ctx = test_client_ctx(tx);

        handler.pong(&ctx).await.expect("Ping failed");

        let params = CreateMessageParams::user_message("Hello").with_max_tokens(1000);
        let result = handler
            .create_message(&ctx, "test", params)
            .await
            .expect("Create message failed");
        assert_eq!(result.model, "test-model");

        let roots = handler.list_roots(&ctx).await.expect("List roots failed");
        assert_eq!(roots.roots.len(), 1);

        // The handler starts with an empty log and the callbacks above ran one
        // after another, so the log must hold exactly these labels in this
        // order. This also catches spurious or misordered entries.
        let calls = handler.calls.lock().unwrap();
        assert_eq!(*calls, ["ping", "create_message", "list_roots"]);
    }

    /// End-to-end check: connects a client and a server, initializes the
    /// client, sends a ping, and verifies that the client handler's
    /// `on_connect` callback was recorded.
    #[tokio::test]
    async fn client_server_ping() {
        // Best-effort tracing setup; any error from `try_init` is discarded.
        fmt::try_init().ok();

        // Keep one handle to the log here and give the handler another.
        let calls: Arc<Mutex<Vec<String>>> = Arc::default();
        let client_handler = TestClientHandler {
            calls: Arc::clone(&calls),
        };

        let (mut client, handle) = connected_client_and_server_with_conn(
            || Box::new(TestServerHandler),
            client_handler,
        )
        .await
        .expect("setup");

        client.init().await.expect("client init");
        client.ping().await.expect("client ping");

        // The guard is a temporary of this statement, so the lock is released
        // before the `.await` below.
        let saw_connect = calls.lock().unwrap().iter().any(|call| call == "on_connect");
        assert!(saw_connect);

        shutdown_client_and_server(client, handle).await;
    }
}