use async_trait::async_trait;
use std::sync::Arc;
use tempfile::tempdir;
use tokio::sync::Mutex;
use tokio_util::compat::TokioAsyncReadCompatExt;
use mrpc::{Connection, Result as MrpcResult, RpcError, RpcSender, Server, Value};
/// Number of `ping` requests the compatibility test sends, and therefore the
/// number of increments it expects the service to record.
const PINGS: u32 = 3;

/// Minimal RPC service that tallies every accepted `ping` request.
///
/// The tally lives behind an `Arc`, so every clone of the service (and the
/// test that created it) observes the same count.
#[derive(Default, Clone)]
struct PingPongService {
    /// Shared number of `ping` requests handled so far.
    pong_counter: Arc<Mutex<u32>>,
}
#[async_trait]
impl Connection for PingPongService {
    /// Serves a single incoming request.
    ///
    /// Only `ping` is understood. A ping must carry a first parameter that
    /// converts to an `i64`; when it does, the shared counter is bumped and
    /// `true` is returned. Any other method, or a ping with a missing or
    /// non-integer first argument, yields `RpcError::Protocol`.
    async fn handle_request(
        &self,
        _client: RpcSender,
        method: &str,
        params: Vec<Value>,
    ) -> MrpcResult<Value> {
        if method != "ping" {
            return Err(RpcError::Protocol(format!(
                "PingPongService: Unknown method: {}",
                method
            )));
        }

        // The argument's value is never used; this only enforces the shape of
        // the call before it is counted.
        let has_integer_arg = params.first().and_then(|arg| arg.as_i64()).is_some();
        if !has_integer_arg {
            return Err(RpcError::Protocol(String::from(
                "Expected integer parameter for ping",
            )));
        }

        // The guard is a temporary, so the lock is released at the end of
        // this statement.
        *self.pong_counter.lock().await += 1;
        Ok(Value::Boolean(true))
    }
}
/// End-to-end compatibility check. An `mrpc` server listening on a Unix
/// socket must be callable from the independent `msgpack_rpc` client crate.
///
/// Sends `PINGS` sequential `ping` requests and checks both sides of the
/// exchange:
/// - every reply the client decodes is `true`;
/// - the server-side counter saw exactly `PINGS` requests.
///
/// Each request is bounded by a timeout, so a protocol mismatch that leaves
/// the client waiting for a reply fails the test instead of hanging it.
#[tokio::test]
async fn test_mrpc_compatibility_with_msgpack_rpc() -> Result<(), Box<dyn std::error::Error>> {
    // Upper bound on a single request/reply round trip.
    let request_timeout = tokio::time::Duration::from_secs(5);

    let temp_dir = tempdir()?;
    let socket_path = temp_dir.path().join("pingpong.sock");
    let pong_counter = Arc::new(Mutex::new(0_u32));
    let pong_counter_clone = pong_counter.clone();
    // Each service instance built by the factory shares the same counter,
    // so the test can read the total afterwards.
    let server: Server<PingPongService> = Server::from_fn(move || PingPongService {
        pong_counter: pong_counter_clone.clone(),
    })
    .unix(&socket_path)
    .await?;
    let server_task = tokio::spawn(async move {
        if let Err(e) = server.run().await {
            panic!("Server error: {}", e);
        }
    });
    // Give the server task a moment to start before connecting.
    // NOTE(review): if `unix()` already binds the listener, this may be
    // unnecessary — confirm against mrpc's `Server::unix`.
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

    let socket = tokio::net::UnixStream::connect(&socket_path).await?;
    // msgpack_rpc works with the `futures` I/O traits; adapt the tokio stream.
    let client = msgpack_rpc::Client::new(socket.compat());
    for i in 0..PINGS {
        let response = tokio::time::timeout(
            request_timeout,
            client.request("ping", &[msgpack_rpc::Value::Integer(i.into())]),
        )
        .await?
        .map_err(|e| format!("ping {i} returned an error response: {e:?}"))?;
        // The service answers every valid ping with `true`; the reply must
        // survive the mrpc -> msgpack_rpc round trip unchanged.
        assert_eq!(
            response,
            msgpack_rpc::Value::Boolean(true),
            "unexpected reply to ping {i}"
        );
    }

    // Every request above was awaited until its reply arrived, and the handler
    // increments the counter before returning its reply, so no extra settling
    // delay is needed here.
    let final_count = *pong_counter.lock().await;
    assert_eq!(
        final_count, PINGS,
        "Expected {} pongs, but got {}",
        PINGS, final_count
    );
    server_task.abort();
    Ok(())
}