use async_trait::async_trait;
use bytes::Bytes;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::{Mutex, Notify};
use crate::error::{Error, Result};
use crate::packet::{new_call_cancel, new_call_data_full, new_call_start, Validate};
use crate::proto::{packet::Body, CallData, CallStart, Packet};
use crate::stream::{Context, Stream};
use crate::transport::decode_optional_data;
/// Sink for outgoing RPC packets.
///
/// The RPC types in this module hold the writer as `Arc<dyn PacketWriter>`
/// and call it from `&self` async methods, hence the `Send + Sync` bound.
#[async_trait]
pub trait PacketWriter: Send + Sync {
    /// Sends one packet toward the remote side.
    async fn write_packet(&self, packet: Packet) -> Result<()>;
    /// Closes the underlying writer.
    ///
    /// Several teardown paths in this module call this and discard the result,
    /// and it can be reached more than once for the same RPC (e.g. a stream
    /// close followed by a local close). NOTE(review): implementations
    /// presumably need to tolerate repeated calls — confirm.
    async fn close(&self) -> Result<()>;
}
/// State and behavior shared by the client and server sides of a single RPC.
pub struct CommonRpc {
    // Governs the RPC lifetime; cancelled on teardown.
    ctx: Context,
    // Target service name (sent in CallStart by the client).
    service: String,
    // Target method name (sent in CallStart by the client).
    method: String,
    // Set once the local side has finished sending (complete, error, cancel, or close).
    local_completed: AtomicBool,
    // Outgoing packet sink, shared with whoever constructed this RPC.
    writer: Arc<dyn PacketWriter>,
    // Wakes `wait`/`read_one` after `state` changes (signalled via `notify_waiters`).
    notify: Notify,
    // Incoming-side state, guarded by an async mutex.
    state: Mutex<RpcState>,
}
/// Mutable incoming-side state of an RPC, protected by `CommonRpc::state`.
struct RpcState {
    // Payloads received from the remote and not yet consumed by `read_one`.
    data_queue: VecDeque<Bytes>,
    // True once no more incoming data will be accepted (complete, error, or close).
    data_closed: bool,
    // Error reported by the remote (or synthesized on stream close), if any.
    remote_err: Option<String>,
}
impl CommonRpc {
    /// Creates the shared RPC state for `service`/`method`, bound to `ctx`,
    /// sending outgoing packets through `writer`.
    ///
    /// The incoming queue starts empty and open, no remote error is recorded,
    /// and the local side starts not completed.
    pub fn new(
        ctx: Context,
        service: String,
        method: String,
        writer: Arc<dyn PacketWriter>,
    ) -> Self {
        Self {
            ctx,
            service,
            method,
            local_completed: AtomicBool::new(false),
            writer,
            notify: Notify::new(),
            state: Mutex::new(RpcState {
                data_queue: VecDeque::new(),
                data_closed: false,
                remote_err: None,
            }),
        }
    }

    /// Returns the context governing this RPC.
    pub fn context(&self) -> &Context {
        &self.ctx
    }

    /// Returns the target service name.
    pub fn service(&self) -> &str {
        &self.service
    }

    /// Returns the target method name.
    pub fn method(&self) -> &str {
        &self.method
    }

    /// Reports whether the local side has finished sending.
    pub fn is_local_completed(&self) -> bool {
        self.local_completed.load(Ordering::SeqCst)
    }

    /// Waits until the incoming side of the RPC is closed.
    ///
    /// # Errors
    /// - [`Error::Cancelled`] if the context is or becomes cancelled.
    /// - [`Error::Remote`] if a remote error has been recorded.
    ///
    /// Returns `Ok(())` once incoming data is closed without a remote error.
    pub async fn wait(&self) -> Result<()> {
        loop {
            // Register for wake-ups BEFORE inspecting state. `notify_waiters`
            // only wakes `Notified` futures that already exist; creating the
            // future after releasing the lock would miss a notification sent
            // in between and could block until the context is cancelled.
            let notified = self.notify.notified();
            {
                let state = self.state.lock().await;
                if self.ctx.is_cancelled() {
                    return Err(Error::Cancelled);
                }
                if let Some(ref err) = state.remote_err {
                    return Err(Error::Remote(err.clone()));
                }
                if state.data_closed {
                    return Ok(());
                }
            }
            tokio::select! {
                _ = notified => continue,
                _ = self.ctx.cancelled() => return Err(Error::Cancelled),
            }
        }
    }

    /// Returns the next incoming payload, waiting for one if necessary.
    ///
    /// Queued data is always drained before any closure or cancellation is
    /// reported.
    ///
    /// # Errors
    /// - [`Error::Remote`] if the stream is closed with a remote error.
    /// - [`Error::StreamClosed`] if the stream is closed cleanly.
    /// - [`Error::Cancelled`] if the context is cancelled. On the first such
    ///   observation the incoming side is closed, the writer is closed
    ///   (best-effort), and other waiters are woken.
    pub async fn read_one(&self) -> Result<Bytes> {
        loop {
            // Same lost-wakeup guard as in `wait`: subscribe before checking.
            let notified = self.notify.notified();
            {
                let mut state = self.state.lock().await;
                if let Some(data) = state.data_queue.pop_front() {
                    return Ok(data);
                }
                if state.data_closed {
                    if let Some(ref err) = state.remote_err {
                        return Err(Error::Remote(err.clone()));
                    }
                    return Err(Error::StreamClosed);
                }
            }
            if self.ctx.is_cancelled() {
                let mut state = self.state.lock().await;
                if !state.data_closed {
                    state.data_closed = true;
                    drop(state);
                    let _ = self.writer.close().await;
                    self.ctx.cancel();
                    self.notify.notify_waiters();
                }
                return Err(Error::Cancelled);
            }
            tokio::select! {
                _ = notified => continue,
                // Loop again so the cancellation branch above does the cleanup.
                _ = self.ctx.cancelled() => {
                    continue;
                }
            }
        }
    }

    /// Sends a CallData packet carrying `data`, `complete`, and `error`.
    ///
    /// A packet with `complete` set or `error` present completes the local
    /// side. A completion after the local side already completed fails with
    /// [`Error::Completed`], except a bare `complete` (no data, no error), which
    /// is accepted as an idempotent no-op. Non-completing writes after
    /// completion also fail with [`Error::Completed`].
    ///
    /// The local side is marked completed before the packet is written, so it
    /// stays completed even if the write fails.
    pub async fn write_call_data(
        &self,
        data: Option<Bytes>,
        complete: bool,
        error: Option<String>,
    ) -> Result<()> {
        let should_complete = complete || error.is_some();
        if should_complete {
            if self
                .local_completed
                .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
                .is_err()
            {
                if complete && data.is_none() && error.is_none() {
                    return Ok(());
                }
                return Err(Error::Completed);
            }
        } else if self.local_completed.load(Ordering::SeqCst) {
            return Err(Error::Completed);
        }
        let packet = new_call_data_full(data, complete, error);
        self.writer.write_packet(packet).await
    }

    /// Marks the local side completed and sends a CallCancel packet.
    ///
    /// # Errors
    /// [`Error::Completed`] if the local side had already completed, in which
    /// case no packet is sent. Otherwise, any error from the writer.
    pub async fn write_call_cancel(&self) -> Result<()> {
        if self.local_completed.swap(true, Ordering::SeqCst) {
            return Err(Error::Completed);
        }
        self.writer.write_packet(new_call_cancel()).await
    }

    /// Applies an incoming CallData packet.
    ///
    /// Any decoded payload is queued. A non-empty `error` is recorded as the
    /// remote error and closes the incoming side, as does `complete`. Waiters
    /// are woken afterwards.
    ///
    /// # Errors
    /// Once the incoming side is closed, a packet with `complete` set is
    /// silently accepted; any other packet yields [`Error::Completed`].
    pub async fn handle_call_data(&self, call_data: CallData) -> Result<()> {
        let mut state = self.state.lock().await;
        if state.data_closed {
            if call_data.complete {
                return Ok(());
            }
            return Err(Error::Completed);
        }
        if let Some(data) = decode_optional_data(call_data.data, call_data.data_is_zero) {
            state.data_queue.push_back(data);
        }
        if !call_data.error.is_empty() {
            state.remote_err = Some(call_data.error);
            state.data_closed = true;
        } else if call_data.complete {
            state.data_closed = true;
        }
        drop(state);
        self.notify.notify_waiters();
        Ok(())
    }

    /// Handles a remote CallCancel by closing the stream with the error "cancelled".
    pub async fn handle_call_cancel(&self) -> Result<()> {
        self.handle_stream_close(Some("cancelled".to_string())).await
    }

    /// Tears down the RPC after the underlying stream closed.
    ///
    /// `err` is recorded as the remote error unless one was already recorded.
    /// The incoming side is then closed, the writer is closed (best-effort,
    /// result ignored), the context is cancelled, and waiters are woken.
    /// Always returns `Ok(())`.
    pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
        let mut state = self.state.lock().await;
        if let Some(e) = err {
            if state.remote_err.is_none() {
                state.remote_err = Some(e);
            }
        }
        state.data_closed = true;
        drop(state);
        let _ = self.writer.close().await;
        self.ctx.cancel();
        self.notify.notify_waiters();
        Ok(())
    }

    /// Locally closes both directions: the incoming side is closed, the local
    /// side is marked completed, the writer is closed (best-effort), waiters
    /// are woken, and the context is cancelled.
    ///
    /// Despite its name, this method acquires the state lock itself.
    async fn close_locked(&self) {
        let mut state = self.state.lock().await;
        state.data_closed = true;
        self.local_completed.store(true, Ordering::SeqCst);
        drop(state);
        let _ = self.writer.close().await;
        self.notify.notify_waiters();
        self.ctx.cancel();
    }
}
/// Client side of an RPC: sends CallStart, then exchanges CallData packets.
pub struct ClientRpc {
    // Shared RPC state and packet handling.
    common: CommonRpc,
    // Set on the first `start` attempt; later attempts fail with `Error::Completed`.
    start_sent: AtomicBool,
}
impl ClientRpc {
pub fn new(
ctx: Context,
service: String,
method: String,
writer: Arc<dyn PacketWriter>,
) -> Self {
Self {
common: CommonRpc::new(ctx, service, method, writer),
start_sent: AtomicBool::new(false),
}
}
pub fn context(&self) -> &Context {
self.common.context()
}
pub fn service(&self) -> &str {
self.common.service()
}
pub fn method(&self) -> &str {
self.common.method()
}
pub async fn wait(&self) -> Result<()> {
self.common.wait().await
}
pub async fn start(&self, data: Option<Bytes>) -> Result<()> {
if self
.start_sent
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_err()
{
return Err(Error::Completed);
}
if self.common.ctx.is_cancelled() {
self.common.ctx.cancel();
let _ = self.common.writer.close().await;
return Err(Error::Cancelled);
}
let packet = new_call_start(
self.common.service.clone(),
self.common.method.clone(),
data,
);
if let Err(e) = self.common.writer.write_packet(packet).await {
self.common.ctx.cancel();
let _ = self.common.writer.close().await;
return Err(e);
}
Ok(())
}
pub async fn handle_packet(&self, packet: Packet) -> Result<()> {
packet.validate()?;
match packet.body {
Some(Body::CallData(call_data)) => self.common.handle_call_data(call_data).await,
Some(Body::CallCancel(true)) => self.common.handle_call_cancel().await,
Some(Body::CallCancel(false)) => Ok(()),
Some(Body::CallStart(_)) => {
Err(Error::UnrecognizedPacket)
}
None => Err(Error::EmptyPacket),
}
}
pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
self.common.handle_stream_close(err).await
}
pub async fn close(&self) {
if !self.start_sent.load(Ordering::SeqCst) {
return;
}
let _ = self.common.write_call_cancel().await;
self.common.close_locked().await;
}
}
/// [`Stream`] view of a client RPC, delegating to the shared [`CommonRpc`].
#[async_trait]
impl Stream for ClientRpc {
    fn context(&self) -> &Context {
        self.common.context()
    }

    /// Sends one payload as a non-completing CallData packet.
    async fn send_bytes(&self, data: Bytes) -> Result<()> {
        let payload = Some(data);
        self.common.write_call_data(payload, false, None).await
    }

    /// Receives the next payload from the server.
    async fn recv_bytes(&self) -> Result<Bytes> {
        self.common.read_one().await
    }

    /// Half-closes the client's sending direction with a bare `complete`.
    async fn close_send(&self) -> Result<()> {
        let (no_data, no_error) = (None, None);
        self.common.write_call_data(no_data, true, no_error).await
    }

    /// Closes the call via the inherent [`ClientRpc::close`]; never fails.
    async fn close(&self) -> Result<()> {
        // Fully-qualified to target the inherent method, not this trait method.
        ClientRpc::close(self).await;
        Ok(())
    }
}
/// Server side of an RPC, created from the client's CallStart packet.
pub struct ServerRpc {
    // Shared RPC state and packet handling.
    common: CommonRpc,
    // Payload carried by the CallStart, if any; handed out by the first `recv_bytes`.
    initial_data: Mutex<Option<Bytes>>,
}
impl ServerRpc {
pub fn from_call_start(
ctx: Context,
call_start: CallStart,
writer: Arc<dyn PacketWriter>,
) -> Self {
let initial_data = decode_optional_data(call_start.data, call_start.data_is_zero);
Self {
common: CommonRpc::new(ctx, call_start.rpc_service, call_start.rpc_method, writer),
initial_data: Mutex::new(initial_data),
}
}
pub fn context(&self) -> &Context {
self.common.context()
}
pub fn service(&self) -> &str {
self.common.service()
}
pub fn method(&self) -> &str {
self.common.method()
}
pub async fn wait(&self) -> Result<()> {
self.common.wait().await
}
pub async fn handle_packet(&self, packet: Packet) -> Result<()> {
packet.validate()?;
match packet.body {
Some(Body::CallData(call_data)) => self.common.handle_call_data(call_data).await,
Some(Body::CallCancel(true)) => self.common.handle_call_cancel().await,
Some(Body::CallCancel(false)) => Ok(()),
Some(Body::CallStart(_)) => {
Err(Error::DuplicateCallStart)
}
None => Err(Error::EmptyPacket),
}
}
pub async fn handle_stream_close(&self, err: Option<String>) -> Result<()> {
self.common.handle_stream_close(err).await
}
pub async fn send_error(&self, error: String) -> Result<()> {
self.common.write_call_data(None, true, Some(error)).await
}
}
/// [`Stream`] view of a server RPC, delegating to the shared [`CommonRpc`].
#[async_trait]
impl Stream for ServerRpc {
    fn context(&self) -> &Context {
        &self.common.ctx
    }

    /// Sends one payload as a non-completing CallData packet.
    async fn send_bytes(&self, data: Bytes) -> Result<()> {
        self.common
            .write_call_data(Some(data), false, None)
            .await
    }

    /// Receives the next payload from the client.
    ///
    /// The CallStart payload, if present, is returned first; after that,
    /// payloads come from the shared incoming queue.
    async fn recv_bytes(&self) -> Result<Bytes> {
        {
            let mut initial = self.initial_data.lock().await;
            if let Some(data) = initial.take() {
                return Ok(data);
            }
        }
        self.common.read_one().await
    }

    /// Half-closes the server's sending direction with a bare `complete`.
    async fn close_send(&self) -> Result<()> {
        self.common.write_call_data(None, true, None).await
    }

    /// Marks the local side completed, closes the writer, and cancels the context.
    ///
    /// # Errors
    /// Returns the writer's close error, if any. The context is cancelled
    /// regardless, so tasks blocked in `recv_bytes`/`wait` are always released.
    async fn close(&self) -> Result<()> {
        self.common.local_completed.store(true, Ordering::SeqCst);
        // Do not short-circuit with `?`: skipping the cancel on a writer error
        // would leave readers selecting on `ctx.cancelled()` hanging forever.
        let closed = self.common.writer.close().await;
        self.common.ctx.cancel();
        closed
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex as StdMutex;

    /// `PacketWriter` double that records written packets and whether it was closed.
    struct MockWriter {
        packets: StdMutex<Vec<Packet>>,
        closed: AtomicBool,
    }

    impl MockWriter {
        fn new() -> Self {
            Self {
                packets: StdMutex::new(Vec::new()),
                closed: AtomicBool::new(false),
            }
        }

        /// Snapshot of every packet written so far, in write order.
        fn packets(&self) -> Vec<Packet> {
            let recorded = self.packets.lock().unwrap();
            recorded.clone()
        }

        fn is_closed(&self) -> bool {
            self.closed.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl PacketWriter for MockWriter {
        async fn write_packet(&self, packet: Packet) -> Result<()> {
            self.packets.lock().unwrap().push(packet);
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// A `CommonRpc` for "svc"/"method" on a fresh context, writing into `writer`.
    fn common_with(writer: &Arc<MockWriter>) -> CommonRpc {
        CommonRpc::new(Context::new(), "svc".into(), "method".into(), writer.clone())
    }

    /// A `ClientRpc` for "test.Service"/"TestMethod" on a fresh context, writing into `writer`.
    fn client_with(writer: &Arc<MockWriter>) -> ClientRpc {
        ClientRpc::new(
            Context::new(),
            "test.Service".into(),
            "TestMethod".into(),
            writer.clone(),
        )
    }

    /// A non-zero-flagged CallData packet with the given fields.
    fn call_data(data: Vec<u8>, complete: bool, error: &str) -> CallData {
        CallData {
            data,
            data_is_zero: false,
            complete,
            error: error.to_owned(),
        }
    }

    #[tokio::test]
    async fn test_client_rpc_start() {
        let writer = Arc::new(MockWriter::new());
        let rpc = client_with(&writer);
        rpc.start(Some(Bytes::from(vec![1, 2, 3]))).await.unwrap();

        let sent = writer.packets();
        assert_eq!(sent.len(), 1);
        if let Some(Body::CallStart(cs)) = &sent[0].body {
            assert_eq!(cs.rpc_service, "test.Service");
            assert_eq!(cs.rpc_method, "TestMethod");
            assert_eq!(cs.data, vec![1, 2, 3]);
            assert!(!cs.data_is_zero);
        } else {
            panic!("expected CallStart");
        }
    }

    #[tokio::test]
    async fn test_client_rpc_double_start_fails() {
        let writer = Arc::new(MockWriter::new());
        let rpc = client_with(&writer);
        rpc.start(None).await.unwrap();

        let second = rpc.start(None).await;
        assert!(matches!(second, Err(Error::Completed)));
    }

    #[tokio::test]
    async fn test_client_rpc_close_sends_cancel() {
        let writer = Arc::new(MockWriter::new());
        let rpc = client_with(&writer);
        rpc.start(None).await.unwrap();
        rpc.close().await;

        let sent = writer.packets();
        assert_eq!(sent.len(), 2);
        assert!(sent[1].is_call_cancel());
        assert!(writer.is_closed());
    }

    #[tokio::test]
    async fn test_server_rpc_from_call_start() {
        let call_start = CallStart {
            rpc_service: "test.Service".into(),
            rpc_method: "TestMethod".into(),
            data: vec![1, 2, 3],
            data_is_zero: false,
        };
        let writer = Arc::new(MockWriter::new());
        let rpc = ServerRpc::from_call_start(Context::new(), call_start, writer);

        assert_eq!(rpc.service(), "test.Service");
        assert_eq!(rpc.method(), "TestMethod");
        let first = rpc.recv_bytes().await.unwrap();
        assert_eq!(&first[..], &[1, 2, 3]);
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_with_data() {
        let writer = Arc::new(MockWriter::new());
        let rpc = common_with(&writer);
        rpc.handle_call_data(call_data(vec![1, 2, 3], false, ""))
            .await
            .unwrap();

        let got = rpc.read_one().await.unwrap();
        assert_eq!(&got[..], &[1, 2, 3]);
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_stream_closed() {
        let writer = Arc::new(MockWriter::new());
        let rpc = common_with(&writer);
        rpc.handle_call_data(call_data(vec![], true, ""))
            .await
            .unwrap();

        let outcome = rpc.read_one().await;
        assert!(matches!(outcome, Err(Error::StreamClosed)));
    }

    #[tokio::test]
    async fn test_common_rpc_read_one_with_error() {
        let writer = Arc::new(MockWriter::new());
        let rpc = common_with(&writer);
        rpc.handle_call_data(call_data(vec![], true, "test error"))
            .await
            .unwrap();

        if let Err(Error::Remote(msg)) = rpc.read_one().await {
            assert_eq!(msg, "test error");
        } else {
            panic!("expected Remote error");
        }
    }

    #[tokio::test]
    async fn test_write_call_data_after_complete() {
        let writer = Arc::new(MockWriter::new());
        let rpc = common_with(&writer);
        rpc.write_call_data(None, true, None).await.unwrap();

        let late = rpc
            .write_call_data(Some(Bytes::from(vec![1])), false, None)
            .await;
        assert!(matches!(late, Err(Error::Completed)));
    }

    #[tokio::test]
    async fn test_write_call_cancel() {
        let writer = Arc::new(MockWriter::new());
        let rpc = common_with(&writer);
        rpc.write_call_cancel().await.unwrap();

        let sent = writer.packets();
        assert_eq!(sent.len(), 1);
        assert!(sent[0].is_call_cancel());

        let again = rpc.write_call_cancel().await;
        assert!(matches!(again, Err(Error::Completed)));
    }
}