tokio_ipc 0.1.0

Multi-protocol RPC framework built on top of tokio
Documentation
// tokio_ipc/tests/test_multi_protocol.rs - Tests for multi-protocol support

use anyhow::Result;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::oneshot;
use tokio_ipc::{
    RpcServerHandler, RpcClient, RpcServer, protocol, protocol_handler, protocol_sender,
};
use tokio_socket::SocketAddr;

// Import Sender traits for all protocols
use counter_protocol::Sender as _;
use math_protocol::Sender as _;
use string_protocol::Sender as _;

// ============================================================================
// PROTOCOL 1: Math Operations
// ============================================================================

// Declares `math_protocol`, a module with two request/response RPCs over `i32`.
// Judging by how the rest of this file uses it, the macro generates (among other
// items) a `Receive` trait that server-side handlers implement, a `Sender` trait
// used on the client side, and a `<rpc>::Response` struct for each RPC.
// NOTE(review): the `#[derive(Debug)]` before each RPC name presumably targets the
// generated request type and the one after `->` the response type — confirm
// against the `protocol!` macro documentation.
protocol! {
    pub mod math_protocol {
        // add(a, b) -> { result }: handlers are expected to return the sum.
        #[derive(Debug)]
        add {
            a: i32,
            b: i32,
        } -> #[derive(Debug)] {
            result: i32,
        };

        // multiply(a, b) -> { result }: handlers are expected to return the product.
        #[derive(Debug)]
        multiply {
            a: i32,
            b: i32,
        } -> #[derive(Debug)] {
            result: i32,
        };
    }
}

// ============================================================================
// PROTOCOL 2: String Operations
// ============================================================================

// Declares `string_protocol`, a module with two RPCs that take and return owned
// `String`s. As with the other protocols in this file, it provides the
// `string_protocol::Receive` (server) and `string_protocol::Sender` (client)
// traits plus a `<rpc>::Response` struct per RPC.
protocol! {
    pub mod string_protocol {
        // concat(a, b) -> { result }: handlers are expected to return `a` followed by `b`.
        #[derive(Debug)]
        concat {
            a: String,
            b: String,
        } -> #[derive(Debug)] {
            result: String,
        };

        // to_uppercase(text) -> { result }: handlers are expected to upper-case `text`.
        #[derive(Debug)]
        to_uppercase {
            text: String,
        } -> #[derive(Debug)] {
            result: String,
        };
    }
}

// ============================================================================
// PROTOCOL 3: Counter Operations
// ============================================================================

// Declares `counter_protocol`, whose RPCs carry no request fields.
// `increment` and `get_value` return a `{ value: u64 }` response. `reset` declares
// no response body at all; its `Receive` method accordingly returns `Result<()>`.
protocol! {
    pub mod counter_protocol {
        // increment() -> { value }: handlers are expected to return the new value.
        #[derive(Debug)]
        increment -> #[derive(Debug)] {
            value: u64,
        };

        // get_value() -> { value }: reads the current value without changing it.
        #[derive(Debug)]
        get_value -> #[derive(Debug)] {
            value: u64,
        };

        // reset(): no request fields and no response struct.
        #[derive(Debug)]
        reset;
    }
}

// ============================================================================
// HANDLER IMPLEMENTATIONS
// ============================================================================

/// Stateless implementation of [`math_protocol::Receive`].
///
/// Results are computed in `i64` and narrowed back to `i32`. Inputs whose sum or
/// product does not fit in an `i32` therefore produce an error response, rather
/// than panicking (debug builds) or silently wrapping (release builds) on
/// client-supplied values.
#[derive(Clone)]
struct MathHandler;

impl math_protocol::Receive for MathHandler {
    /// Returns `a + b`.
    ///
    /// # Errors
    /// Fails if the sum does not fit in an `i32`.
    async fn add(&self, a: i32, b: i32) -> Result<math_protocol::add::Response> {
        // The sum of two i32 values always fits in i64. `try_from` rejects anything
        // outside the i32 range, and `?` converts the error into `anyhow::Error`.
        let result = i32::try_from(i64::from(a) + i64::from(b))?;
        Ok(math_protocol::add::Response { result })
    }

    /// Returns `a * b`.
    ///
    /// # Errors
    /// Fails if the product does not fit in an `i32`.
    async fn multiply(&self, a: i32, b: i32) -> Result<math_protocol::multiply::Response> {
        // |a * b| <= 2^62 for i32 inputs, so the i64 product itself cannot overflow.
        let result = i32::try_from(i64::from(a) * i64::from(b))?;
        Ok(math_protocol::multiply::Response { result })
    }
}

/// Stateless implementation of [`string_protocol::Receive`].
#[derive(Clone)]
struct StringHandler;

impl string_protocol::Receive for StringHandler {
    /// Returns `a` with `b` appended, reusing `a`'s allocation.
    async fn concat(&self, a: String, b: String) -> Result<string_protocol::concat::Response> {
        let mut joined = a;
        joined.push_str(&b);
        Ok(string_protocol::concat::Response { result: joined })
    }

    /// Returns `text` converted with [`str::to_uppercase`] (Unicode-aware).
    async fn to_uppercase(&self, text: String) -> Result<string_protocol::to_uppercase::Response> {
        let upper = text.to_uppercase();
        Ok(string_protocol::to_uppercase::Response { result: upper })
    }
}

/// Implementation of [`counter_protocol::Receive`] backed by an atomic counter.
///
/// The counter lives behind an `Arc`, so clones of a `CounterHandler` all read
/// and modify the same value.
#[derive(Clone)]
struct CounterHandler {
    /// Current counter value; every access below uses `SeqCst` ordering.
    counter: Arc<AtomicU64>,
}

impl counter_protocol::Receive for CounterHandler {
    /// Adds one to the counter and reports the value *after* the increment.
    async fn increment(&self) -> Result<counter_protocol::increment::Response> {
        let previous = self.counter.fetch_add(1, Ordering::SeqCst);
        Ok(counter_protocol::increment::Response {
            value: previous + 1,
        })
    }

    /// Reports the current counter value without modifying it.
    async fn get_value(&self) -> Result<counter_protocol::get_value::Response> {
        Ok(counter_protocol::get_value::Response {
            value: self.counter.load(Ordering::SeqCst),
        })
    }

    /// Sets the counter back to zero.
    async fn reset(&self) -> Result<()> {
        self.counter.store(0, Ordering::SeqCst);
        Ok(())
    }
}

// ============================================================================
// MULTI-PROTOCOL SERVER
// ============================================================================

/// Aggregate handler that serves all three protocols from a single value.
///
/// Each protocol's `Receive` implementation below forwards to the matching field.
/// The forwarding calls use fully qualified trait paths
/// (e.g. `math_protocol::Receive::add(&self.math, ..)`). Method-call syntax only
/// resolves trait methods whose trait is in scope, so the explicit paths keep the
/// delegation compiling regardless of which `Receive` traits are imported.
#[derive(Clone)]
struct MultiProtocolHandler {
    /// Backend for `math_protocol` requests.
    math: MathHandler,
    /// Backend for `string_protocol` requests.
    string: StringHandler,
    /// Backend for `counter_protocol` requests; holds the counter state.
    counter: CounterHandler,
}

// Math protocol: forward every call to the `math` field.
impl math_protocol::Receive for MultiProtocolHandler {
    async fn add(&self, a: i32, b: i32) -> Result<math_protocol::add::Response> {
        math_protocol::Receive::add(&self.math, a, b).await
    }

    async fn multiply(&self, a: i32, b: i32) -> Result<math_protocol::multiply::Response> {
        math_protocol::Receive::multiply(&self.math, a, b).await
    }
}

// String protocol: forward every call to the `string` field.
impl string_protocol::Receive for MultiProtocolHandler {
    async fn concat(&self, a: String, b: String) -> Result<string_protocol::concat::Response> {
        string_protocol::Receive::concat(&self.string, a, b).await
    }

    async fn to_uppercase(&self, text: String) -> Result<string_protocol::to_uppercase::Response> {
        string_protocol::Receive::to_uppercase(&self.string, text).await
    }
}

// Counter protocol: forward every call to the `counter` field.
impl counter_protocol::Receive for MultiProtocolHandler {
    async fn increment(&self) -> Result<counter_protocol::increment::Response> {
        counter_protocol::Receive::increment(&self.counter).await
    }

    async fn get_value(&self) -> Result<counter_protocol::get_value::Response> {
        counter_protocol::Receive::get_value(&self.counter).await
    }

    async fn reset(&self) -> Result<()> {
        counter_protocol::Receive::reset(&self.counter).await
    }
}

// Use the `protocol_handler!` macro to create the receiver.
//
// This generates `MultiProtocolReceiver`. It wraps a `MultiProtocolHandler`
// (built with `MultiProtocolReceiver::new(handler)`) and serves as the server's
// `ReceiveRpc` type. It presumably routes each incoming request to the `Receive`
// impl of the protocol the request belongs to — confirm against the macro docs.
protocol_handler!(
    MultiProtocolReceiver impl [math_protocol, string_protocol, counter_protocol] with MultiProtocolHandler
);

// Use the `protocol_sender!` macro to create the sender.
//
// This generates `MultiProtocolSender`, which is both the server's `SendRpc` type
// and the type parameter of `RpcClient` in the tests. The tests call `add`,
// `concat`, `increment`, and so on directly on `client.sender`. Presumably the
// generated type implements each listed protocol's `Sender` trait, which is why
// those traits are imported (as `_`) at the top of the file.
protocol_sender! {
    MultiProtocolSender impl [math_protocol, string_protocol, counter_protocol]
}

// ============================================================================
// SERVER HANDLER
// ============================================================================

/// Server-side connection hook shared by every test in this file.
///
/// Every `on_rpc_connect` call builds a brand-new `MultiProtocolHandler` with its
/// own zero-initialised counter. Counter state is therefore never shared between
/// calls, which the tests rely on when they expect the counter to start at 0.
struct TestMultiProtocolServer;

impl RpcServerHandler for TestMultiProtocolServer {
    type ReceiveRpc = MultiProtocolReceiver;
    type SendRpc = MultiProtocolSender;

    async fn on_rpc_connect(&self, _rpc_sender: &Self::SendRpc) -> Self::ReceiveRpc {
        // A fresh Arc per call keeps counter state scoped to this connection.
        let counter = CounterHandler {
            counter: Arc::new(AtomicU64::new(0)),
        };

        MultiProtocolReceiver::new(MultiProtocolHandler {
            math: MathHandler,
            string: StringHandler,
            counter,
        })
    }
}

// ============================================================================
// TESTS
// ============================================================================

/// Math RPCs (`add`, `multiply`) round-trip through the multi-protocol server.
#[tokio::test]
async fn test_multi_protocol_math_operations() {
    let socket_name = format!("test-math-{}", uuid::Uuid::new_v4());
    let (ready_tx, ready_rx) = oneshot::channel();

    // Start server. The task parks forever instead of sleeping a fixed 5 s, so the
    // server cannot be dropped mid-test on a slow machine; it is aborted at the end.
    let server_socket = socket_name.clone();
    let server_task = tokio::spawn(async move {
        let addr = SocketAddr::abstract_uds(&server_socket);
        let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer)
            .expect("failed to bind test RPC server");
        let _ = ready_tx.send(());
        std::future::pending::<()>().await;
    });

    // If binding panicked, `ready_tx` was dropped without sending.
    ready_rx
        .await
        .expect("server task exited before signalling readiness");
    tokio::time::sleep(Duration::from_millis(100)).await;

    // Create client with empty receiver (client doesn't need to receive anything)
    let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();

    let addr = SocketAddr::abstract_uds(&socket_name);
    let client: RpcClient<MultiProtocolSender> = RpcClient::connect(addr, empty_handler)
        .await
        .expect("client failed to connect");

    // Test math protocol
    let add_result = client.sender.add(10, 5).await.expect("add RPC failed");
    assert_eq!(add_result.result, 15);

    let multiply_result = client
        .sender
        .multiply(7, 6)
        .await
        .expect("multiply RPC failed");
    assert_eq!(multiply_result.result, 42);

    server_task.abort();
}

/// String RPCs (`concat`, `to_uppercase`) round-trip through the multi-protocol server.
#[tokio::test]
async fn test_multi_protocol_string_operations() {
    let socket_name = format!("test-string-{}", uuid::Uuid::new_v4());
    let (ready_tx, ready_rx) = oneshot::channel();

    // The server task parks forever (no fixed 5 s lifetime) and is aborted at the end.
    let server_socket = socket_name.clone();
    let server_task = tokio::spawn(async move {
        let addr = SocketAddr::abstract_uds(&server_socket);
        let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer)
            .expect("failed to bind test RPC server");
        let _ = ready_tx.send(());
        std::future::pending::<()>().await;
    });

    // If binding panicked, `ready_tx` was dropped without sending.
    ready_rx
        .await
        .expect("server task exited before signalling readiness");
    tokio::time::sleep(Duration::from_millis(100)).await;

    let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();

    let addr = SocketAddr::abstract_uds(&socket_name);
    let client: RpcClient<MultiProtocolSender> = RpcClient::connect(addr, empty_handler)
        .await
        .expect("client failed to connect");

    // Test string protocol
    let concat_result = client
        .sender
        .concat("Hello, ".to_string(), "World!".to_string())
        .await
        .expect("concat RPC failed");
    assert_eq!(concat_result.result, "Hello, World!");

    let uppercase_result = client
        .sender
        .to_uppercase("test string".to_string())
        .await
        .expect("to_uppercase RPC failed");
    assert_eq!(uppercase_result.result, "TEST STRING");

    server_task.abort();
}

/// Counter RPCs keep state across calls on one connection, starting from 0.
#[tokio::test]
async fn test_multi_protocol_counter_operations() {
    let socket_name = format!("test-counter-{}", uuid::Uuid::new_v4());
    let (ready_tx, ready_rx) = oneshot::channel();

    // The server task parks forever (no fixed 5 s lifetime) and is aborted at the end.
    let server_socket = socket_name.clone();
    let server_task = tokio::spawn(async move {
        let addr = SocketAddr::abstract_uds(&server_socket);
        let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer)
            .expect("failed to bind test RPC server");
        let _ = ready_tx.send(());
        std::future::pending::<()>().await;
    });

    // If binding panicked, `ready_tx` was dropped without sending.
    ready_rx
        .await
        .expect("server task exited before signalling readiness");
    tokio::time::sleep(Duration::from_millis(100)).await;

    let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();

    let addr = SocketAddr::abstract_uds(&socket_name);
    let client: RpcClient<MultiProtocolSender> = RpcClient::connect(addr, empty_handler)
        .await
        .expect("client failed to connect");

    // Test counter protocol; `on_rpc_connect` gives this connection a fresh counter.
    let value1 = client.sender.get_value().await.expect("get_value RPC failed");
    assert_eq!(value1.value, 0);

    let inc1 = client.sender.increment().await.expect("increment RPC failed");
    assert_eq!(inc1.value, 1);

    let inc2 = client.sender.increment().await.expect("increment RPC failed");
    assert_eq!(inc2.value, 2);

    let value2 = client.sender.get_value().await.expect("get_value RPC failed");
    assert_eq!(value2.value, 2);

    client.sender.reset().await.expect("reset RPC failed");

    let value3 = client.sender.get_value().await.expect("get_value RPC failed");
    assert_eq!(value3.value, 0);

    server_task.abort();
}

/// Interleaved calls from all three protocols share one connection correctly.
#[tokio::test]
async fn test_multi_protocol_mixed_operations() {
    let socket_name = format!("test-mixed-{}", uuid::Uuid::new_v4());
    let (ready_tx, ready_rx) = oneshot::channel();

    // The server task parks forever (no fixed 5 s lifetime) and is aborted at the end.
    let server_socket = socket_name.clone();
    let server_task = tokio::spawn(async move {
        let addr = SocketAddr::abstract_uds(&server_socket);
        let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer)
            .expect("failed to bind test RPC server");
        let _ = ready_tx.send(());
        std::future::pending::<()>().await;
    });

    // If binding panicked, `ready_tx` was dropped without sending.
    ready_rx
        .await
        .expect("server task exited before signalling readiness");
    tokio::time::sleep(Duration::from_millis(100)).await;

    let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();

    let addr = SocketAddr::abstract_uds(&socket_name);
    let client: RpcClient<MultiProtocolSender> = RpcClient::connect(addr, empty_handler)
        .await
        .expect("client failed to connect");

    // Mix operations from all protocols
    let add_result = client.sender.add(5, 3).await.expect("add RPC failed");
    assert_eq!(add_result.result, 8);

    let concat_result = client
        .sender
        .concat("A".to_string(), "B".to_string())
        .await
        .expect("concat RPC failed");
    assert_eq!(concat_result.result, "AB");

    let inc1 = client.sender.increment().await.expect("increment RPC failed");
    assert_eq!(inc1.value, 1);

    let multiply_result = client
        .sender
        .multiply(4, 5)
        .await
        .expect("multiply RPC failed");
    assert_eq!(multiply_result.result, 20);

    let uppercase_result = client
        .sender
        .to_uppercase("xyz".to_string())
        .await
        .expect("to_uppercase RPC failed");
    assert_eq!(uppercase_result.result, "XYZ");

    // Counter state persisted across the interleaved math/string calls.
    let inc2 = client.sender.increment().await.expect("increment RPC failed");
    assert_eq!(inc2.value, 2);

    server_task.abort();
}