use anyhow::Result;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::sync::oneshot;
use tokio_ipc::{
RpcServerHandler, RpcClient, RpcServer, protocol, protocol_handler, protocol_sender,
};
use tokio_socket::SocketAddr;
use counter_protocol::Sender as _;
use math_protocol::Sender as _;
use string_protocol::Sender as _;
// Two-operand integer arithmetic RPCs.
// As used elsewhere in this file, this expands to a `math_protocol` module that provides:
// - a `Receive` trait implemented by server-side handlers,
// - a `Sender` trait whose methods are called through `client.sender`,
// - a `<method>::Response` type per method.
protocol! {
pub mod math_protocol {
// `add`: request fields `a`, `b`; response field `result`.
#[derive(Debug)]
add {
a: i32,
b: i32,
} -> #[derive(Debug)] {
result: i32,
};
// `multiply`: request fields `a`, `b`; response field `result`.
#[derive(Debug)]
multiply {
a: i32,
b: i32,
} -> #[derive(Debug)] {
result: i32,
};
}
}
// String-manipulation RPCs.
// Generates the `string_protocol` module, with its `Receive`/`Sender` traits and per-method
// `Response` types, used by the handlers and tests in this file.
protocol! {
pub mod string_protocol {
// `concat`: request fields `a`, `b`; response field `result`.
#[derive(Debug)]
concat {
a: String,
b: String,
} -> #[derive(Debug)] {
result: String,
};
// `to_uppercase`: request field `text`; response field `result`.
#[derive(Debug)]
to_uppercase {
text: String,
} -> #[derive(Debug)] {
result: String,
};
}
}
// Stateful counter RPCs. None of these methods declares request fields.
protocol! {
pub mod counter_protocol {
// `increment`: no request fields; response field `value`.
#[derive(Debug)]
increment -> #[derive(Debug)] {
value: u64,
};
// `get_value`: no request fields; response field `value`.
#[derive(Debug)]
get_value -> #[derive(Debug)] {
value: u64,
};
// `reset`: declares neither request nor response fields.
// Its `Receive` implementations in this file return `Result<()>`.
#[derive(Debug)]
reset;
}
}
/// Stateless handler for `math_protocol` requests.
#[derive(Clone)]
struct MathHandler;

impl math_protocol::Receive for MathHandler {
    /// Returns `a + b`, or an error if the sum does not fit in an `i32`.
    async fn add(&self, a: i32, b: i32) -> Result<math_protocol::add::Response> {
        // Plain `a + b` panics on overflow in debug builds and silently wraps in release;
        // return the overflow as an `Err` instead.
        let result = a.checked_add(b).ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, "add: i32 overflow")
        })?;
        Ok(math_protocol::add::Response { result })
    }

    /// Returns `a * b`, or an error if the product does not fit in an `i32`.
    async fn multiply(&self, a: i32, b: i32) -> Result<math_protocol::multiply::Response> {
        // Same overflow handling as `add`.
        let result = a.checked_mul(b).ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::InvalidInput, "multiply: i32 overflow")
        })?;
        Ok(math_protocol::multiply::Response { result })
    }
}
/// Stateless handler for `string_protocol` requests.
#[derive(Clone)]
struct StringHandler;

impl string_protocol::Receive for StringHandler {
    /// Returns `a` followed by `b`, appending into `a`'s existing buffer.
    async fn concat(&self, a: String, b: String) -> Result<string_protocol::concat::Response> {
        let mut joined = a;
        joined.push_str(&b);
        Ok(string_protocol::concat::Response { result: joined })
    }

    /// Returns `text` converted with `str::to_uppercase` (Unicode-aware).
    async fn to_uppercase(&self, text: String) -> Result<string_protocol::to_uppercase::Response> {
        let shouted = text.to_uppercase();
        Ok(string_protocol::to_uppercase::Response { result: shouted })
    }
}
/// Handler for `counter_protocol` requests, backed by an atomic counter.
///
/// The counter lives behind an `Arc`. Cloning the handler therefore shares the same value
/// rather than copying it.
#[derive(Clone)]
struct CounterHandler {
    counter: Arc<AtomicU64>,
}

impl counter_protocol::Receive for CounterHandler {
    /// Atomically adds one and reports the value *after* the increment.
    async fn increment(&self) -> Result<counter_protocol::increment::Response> {
        let previous = self.counter.fetch_add(1, Ordering::SeqCst);
        Ok(counter_protocol::increment::Response {
            value: previous + 1,
        })
    }

    /// Reports the current counter value.
    async fn get_value(&self) -> Result<counter_protocol::get_value::Response> {
        Ok(counter_protocol::get_value::Response {
            value: self.counter.load(Ordering::SeqCst),
        })
    }

    /// Sets the counter back to zero.
    async fn reset(&self) -> Result<()> {
        self.counter.store(0, Ordering::SeqCst);
        Ok(())
    }
}
/// Composite handler covering `math_protocol`, `string_protocol` and `counter_protocol`.
///
/// Every method forwards to the corresponding single-protocol handler. Trait paths are
/// spelled out explicitly because the `Sender` traits for the same protocols (which have
/// identically named methods) are also in scope.
#[derive(Clone)]
struct MultiProtocolHandler {
    // Target for `math_protocol` calls.
    math: MathHandler,
    // Target for `string_protocol` calls.
    string: StringHandler,
    // Target for `counter_protocol` calls; clones of this handler share its counter.
    counter: CounterHandler,
}

impl math_protocol::Receive for MultiProtocolHandler {
    async fn add(&self, a: i32, b: i32) -> Result<math_protocol::add::Response> {
        math_protocol::Receive::add(&self.math, a, b).await
    }
    async fn multiply(&self, a: i32, b: i32) -> Result<math_protocol::multiply::Response> {
        math_protocol::Receive::multiply(&self.math, a, b).await
    }
}

impl string_protocol::Receive for MultiProtocolHandler {
    async fn concat(&self, a: String, b: String) -> Result<string_protocol::concat::Response> {
        string_protocol::Receive::concat(&self.string, a, b).await
    }
    async fn to_uppercase(&self, text: String) -> Result<string_protocol::to_uppercase::Response> {
        string_protocol::Receive::to_uppercase(&self.string, text).await
    }
}

impl counter_protocol::Receive for MultiProtocolHandler {
    async fn increment(&self) -> Result<counter_protocol::increment::Response> {
        counter_protocol::Receive::increment(&self.counter).await
    }
    async fn get_value(&self) -> Result<counter_protocol::get_value::Response> {
        counter_protocol::Receive::get_value(&self.counter).await
    }
    async fn reset(&self) -> Result<()> {
        counter_protocol::Receive::reset(&self.counter).await
    }
}
// Generates `MultiProtocolReceiver` over `math_protocol`, `string_protocol` and
// `counter_protocol`, backed by a `MultiProtocolHandler`.
// It serves as `TestMultiProtocolServer::ReceiveRpc` and is built with
// `MultiProtocolReceiver::new(handler)`. Presumably it routes each incoming request to the
// handler's matching `Receive` impl (TODO confirm against tokio_ipc docs).
protocol_handler!(
MultiProtocolReceiver impl [math_protocol, string_protocol, counter_protocol] with MultiProtocolHandler
);
// Generates `MultiProtocolSender`, the sender type for all three protocols.
// The tests call `math_protocol`, `string_protocol` and `counter_protocol` methods on it
// through `client.sender`. It is also `TestMultiProtocolServer::SendRpc`.
protocol_sender! {
MultiProtocolSender impl [math_protocol, string_protocol, counter_protocol]
}
/// Server fixture whose connections are served by a [`MultiProtocolReceiver`].
struct TestMultiProtocolServer;

impl RpcServerHandler for TestMultiProtocolServer {
    type ReceiveRpc = MultiProtocolReceiver;
    type SendRpc = MultiProtocolSender;

    /// Builds the receiver for a new RPC connection.
    ///
    /// The outbound `_rpc_sender` is ignored; this fixture never calls back into the client.
    /// Every call allocates a fresh counter starting at 0, so counter state is never shared
    /// between calls. Presumably there is one call per client connection.
    async fn on_rpc_connect(&self, _rpc_sender: &Self::SendRpc) -> Self::ReceiveRpc {
        let fresh_counter = CounterHandler {
            counter: Arc::new(AtomicU64::new(0)),
        };
        MultiProtocolReceiver::new(MultiProtocolHandler {
            math: MathHandler,
            string: StringHandler,
            counter: fresh_counter,
        })
    }
}
#[tokio::test]
async fn test_multi_protocol_math_operations() {
let socket_name = format!("test-math-{}", uuid::Uuid::new_v4());
let (ready_tx, ready_rx) = oneshot::channel();
let server_socket = socket_name.clone();
let server_task = tokio::spawn(async move {
let addr = SocketAddr::abstract_uds(&server_socket);
let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer).unwrap();
let _ = ready_tx.send(());
tokio::time::sleep(Duration::from_secs(5)).await;
});
ready_rx.await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();
let addr = SocketAddr::abstract_uds(&socket_name);
let client: RpcClient<MultiProtocolSender> =
RpcClient::connect(addr, empty_handler).await.unwrap();
let add_result = client.sender.add(10, 5).await.unwrap();
assert_eq!(add_result.result, 15);
let multiply_result = client.sender.multiply(7, 6).await.unwrap();
assert_eq!(multiply_result.result, 42);
server_task.abort();
}
#[tokio::test]
async fn test_multi_protocol_string_operations() {
let socket_name = format!("test-string-{}", uuid::Uuid::new_v4());
let (ready_tx, ready_rx) = oneshot::channel();
let server_socket = socket_name.clone();
let server_task = tokio::spawn(async move {
let addr = SocketAddr::abstract_uds(&server_socket);
let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer).unwrap();
let _ = ready_tx.send(());
tokio::time::sleep(Duration::from_secs(5)).await;
});
ready_rx.await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();
let addr = SocketAddr::abstract_uds(&socket_name);
let client: RpcClient<MultiProtocolSender> =
RpcClient::connect(addr, empty_handler).await.unwrap();
let concat_result = client
.sender
.concat("Hello, ".to_string(), "World!".to_string())
.await
.unwrap();
assert_eq!(concat_result.result, "Hello, World!");
let uppercase_result = client
.sender
.to_uppercase("test string".to_string())
.await
.unwrap();
assert_eq!(uppercase_result.result, "TEST STRING");
server_task.abort();
}
#[tokio::test]
async fn test_multi_protocol_counter_operations() {
let socket_name = format!("test-counter-{}", uuid::Uuid::new_v4());
let (ready_tx, ready_rx) = oneshot::channel();
let server_socket = socket_name.clone();
let server_task = tokio::spawn(async move {
let addr = SocketAddr::abstract_uds(&server_socket);
let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer).unwrap();
let _ = ready_tx.send(());
tokio::time::sleep(Duration::from_secs(5)).await;
});
ready_rx.await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();
let addr = SocketAddr::abstract_uds(&socket_name);
let client: RpcClient<MultiProtocolSender> =
RpcClient::connect(addr, empty_handler).await.unwrap();
let value1 = client.sender.get_value().await.unwrap();
assert_eq!(value1.value, 0);
let inc1 = client.sender.increment().await.unwrap();
assert_eq!(inc1.value, 1);
let inc2 = client.sender.increment().await.unwrap();
assert_eq!(inc2.value, 2);
let value2 = client.sender.get_value().await.unwrap();
assert_eq!(value2.value, 2);
client.sender.reset().await.unwrap();
let value3 = client.sender.get_value().await.unwrap();
assert_eq!(value3.value, 0);
server_task.abort();
}
#[tokio::test]
async fn test_multi_protocol_mixed_operations() {
let socket_name = format!("test-mixed-{}", uuid::Uuid::new_v4());
let (ready_tx, ready_rx) = oneshot::channel();
let server_socket = socket_name.clone();
let server_task = tokio::spawn(async move {
let addr = SocketAddr::abstract_uds(&server_socket);
let _server = RpcServer::bind_unix(&addr, TestMultiProtocolServer).unwrap();
let _ = ready_tx.send(());
tokio::time::sleep(Duration::from_secs(5)).await;
});
ready_rx.await.unwrap();
tokio::time::sleep(Duration::from_millis(100)).await;
let empty_handler = tokio_ipc::ProtocolRegistry::builder().build();
let addr = SocketAddr::abstract_uds(&socket_name);
let client: RpcClient<MultiProtocolSender> =
RpcClient::connect(addr, empty_handler).await.unwrap();
let add_result = client.sender.add(5, 3).await.unwrap();
assert_eq!(add_result.result, 8);
let concat_result = client
.sender
.concat("A".to_string(), "B".to_string())
.await
.unwrap();
assert_eq!(concat_result.result, "AB");
let inc1 = client.sender.increment().await.unwrap();
assert_eq!(inc1.value, 1);
let multiply_result = client.sender.multiply(4, 5).await.unwrap();
assert_eq!(multiply_result.result, 20);
let uppercase_result = client.sender.to_uppercase("xyz".to_string()).await.unwrap();
assert_eq!(uppercase_result.result, "XYZ");
let inc2 = client.sender.increment().await.unwrap();
assert_eq!(inc2.value, 2);
server_task.abort();
}