mrpc 0.0.6

MessagePack-RPC for Rust
Documentation
use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tokio::task;
use tokio::time::timeout;

use mrpc::{Client, Connection, Result, RpcError, RpcSender, Server, ServiceError, Value};

#[derive(Clone)]
struct TestServer;

#[async_trait]
impl Connection for TestServer {
    async fn handle_request(
        &self,
        _: RpcSender,
        method: &str,
        params: Vec<Value>,
    ) -> Result<Value> {
        match method {
            "add" => {
                if let [a, b] = params.as_slice() {
                    let a = a.as_i64().ok_or_else(|| {
                        RpcError::Protocol("First parameter must be an integer".into())
                    })?;
                    let b = b.as_i64().ok_or_else(|| {
                        RpcError::Protocol("Second parameter must be an integer".into())
                    })?;
                    Ok(Value::from(a + b))
                } else {
                    Err(RpcError::Protocol("Expected two parameters".into()))
                }
            }
            _ => Err(RpcError::Service(ServiceError {
                name: "MethodNotFound".into(),
                value: Value::String(format!("Method '{}' not found", method).into()),
            })),
        }
    }
}

#[derive(Clone)]
struct TestClient;

impl Default for TestClient {
    fn default() -> Self {
        TestClient
    }
}

#[async_trait]
impl Connection for TestClient {}

#[derive(Clone)]
struct TestClientConnect {
    connected_success: Arc<Mutex<bool>>,
}

impl TestClientConnect {
    fn new() -> Self {
        TestClientConnect {
            connected_success: Arc::new(Mutex::new(false)),
        }
    }
}

impl Default for TestClientConnect {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl Connection for TestClientConnect {
    async fn connected(&self, client: RpcSender) -> Result<()> {
        // Send a message during connection
        let result = client
            .send_request("add", &[Value::from(10), Value::from(20)])
            .await?;
        assert_eq!(result, Value::from(30), "Connected method request failed");

        // Set the flag to indicate successful completion
        let mut success = self.connected_success.lock().await;
        *success = true;

        Ok(())
    }
}

async fn setup_server_and_client<T: Connection + Default>(
) -> Result<(Client<T>, Server<TestServer>)> {
    let server = Server::from_fn(|| TestServer).tcp("127.0.0.1:0").await?;
    let addr = server.local_addr()?;

    let _server_handle = tokio::spawn(async move {
        server.run().await.unwrap();
    });

    let client = Client::connect_tcp(&addr.to_string(), T::default()).await?;

    Ok((client, Server::from_fn(|| TestServer)))
}

async fn setup_server_and_client_with_connect() -> Result<(
    Client<TestClientConnect>,
    Server<TestServer>,
    Arc<Mutex<bool>>,
)> {
    let test_client = TestClientConnect::new();
    let connected_success = test_client.connected_success.clone();

    let server = Server::from_fn(|| TestServer).tcp("127.0.0.1:0").await?;
    let addr = server.local_addr()?;

    let _server_handle = tokio::spawn(async move {
        server.run().await.unwrap();
    });

    let client = Client::connect_tcp(&addr.to_string(), test_client).await?;

    Ok((client, Server::from_fn(|| TestServer), connected_success))
}

#[tokio::test]
async fn test_basic_request_response() -> Result<()> {
    let (client, _) = setup_server_and_client::<TestClient>().await?;

    let result = client
        .send_request("add", &[Value::from(5), Value::from(3)])
        .await?;
    assert_eq!(result, Value::from(8));

    Ok(())
}

#[tokio::test]
async fn test_method_not_found() -> Result<()> {
    let (client, _) = setup_server_and_client::<TestClient>().await?;

    let result = client
        .send_request("non_existent_method", &[Value::from(1)])
        .await;

    match result {
        Err(RpcError::Service(ServiceError { name, value })) => {
            assert_eq!(name, "MethodNotFound");
            assert_eq!(
                value,
                Value::String("Method 'non_existent_method' not found".into())
            );
        }
        _ => panic!("Expected Service error, got {:?}", result),
    }

    Ok(())
}

#[tokio::test]
async fn test_concurrent_requests() -> Result<()> {
    let (client, _) = setup_server_and_client::<TestClient>().await?;
    let client = std::sync::Arc::new(client);

    let num_requests = 100;
    let mut handles = vec![];

    for i in 0..num_requests {
        let client_clone = client.clone();
        let handle = task::spawn(async move {
            tokio::time::sleep(Duration::from_millis(i % 10)).await;
            let result = client_clone
                .send_request("add", &[Value::from(i), Value::from(i)])
                .await?;
            assert_eq!(result, Value::from(i * 2));
            Ok::<_, RpcError>(())
        });
        handles.push(handle);
    }

    for handle in handles {
        handle.await.unwrap()?;
    }

    Ok(())
}

#[tokio::test]
async fn test_client_request_from_connected() -> Result<()> {
    let timeout_duration = Duration::from_secs(5); // 5 second timeout

    let result = timeout(timeout_duration, async {
        let (_client, _, connected_success) = setup_server_and_client_with_connect().await?;

        // Wait for the connected method to complete or timeout
        for _ in 0..50 {
            // Check every 100ms for 5 seconds
            if *connected_success.lock().await {
                return Ok(());
            }
            tokio::time::sleep(Duration::from_millis(100)).await;
        }

        Err(RpcError::Protocol(
            "Connected method did not complete in time".into(),
        ))
    })
    .await;

    match result {
        Ok(Ok(())) => Ok(()),
        Ok(Err(e)) => Err(e),
        Err(_) => Err(RpcError::Protocol("Test timed out".into())),
    }
}