use std::collections::HashMap;
use std::convert::TryInto;
#[cfg(unix)]
use std::os::unix::io::RawFd;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use tokio::{self, sync::mpsc, task};
use crate::error::{get_rpc_status, Error, Result};
use crate::proto::{
Code, Codec, GenMessage, Message, MessageHeader, Request, Response, FLAG_NO_DATA,
FLAG_REMOTE_CLOSED, FLAG_REMOTE_OPEN, MESSAGE_TYPE_DATA, MESSAGE_TYPE_RESPONSE,
};
use crate::r#async::connection::*;
use crate::r#async::shutdown;
use crate::r#async::stream::{
Kind, MessageReceiver, MessageSender, ResultReceiver, ResultSender, StreamInner,
};
use super::stream::SendingMessage;
use super::transport::Socket;
#[derive(Clone)]
pub struct Client {
req_tx: MessageSender,
next_stream_id: Arc<AtomicU32>,
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
}
impl Client {
pub async fn connect(sockaddr: &str) -> Result<Client> {
let socket = Socket::connect(sockaddr)
.await
.map_err(err_to_others_err!(e, "Socket::connect error "))?;
Ok(Self::new(socket))
}
#[cfg(unix)]
pub unsafe fn from_raw_unix_socket_fd(fd: RawFd) -> Client {
let stream = unsafe { Socket::from_raw_unix_socket_fd(fd) }.unwrap();
Self::new(stream)
}
pub fn new(stream: Socket) -> Client {
let (req_tx, rx): (MessageSender, MessageReceiver) = mpsc::channel(100);
let req_map = Arc::new(Mutex::new(HashMap::new()));
let delegate = ClientBuilder {
rx: Some(rx),
streams: req_map.clone(),
};
let conn = Connection::new(stream, delegate);
tokio::spawn(async move { conn.run().await });
Client {
req_tx,
next_stream_id: Arc::new(AtomicU32::new(1)),
streams: req_map,
}
}
pub async fn request(&self, req: Request) -> Result<Response> {
let timeout_nano = req.timeout_nano;
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
let msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;
let (tx, mut rx): (ResultSender, ResultReceiver) = mpsc::channel(100);
self.streams
.lock()
.map_err(|_| Error::Others("Failed to acquire lock on streams".to_string()))?
.insert(stream_id, tx);
self.req_tx
.send(SendingMessage::new(msg))
.await
.map_err(|_| Error::LocalClosed)?;
let result = if timeout_nano == 0 {
rx.recv().await.ok_or_else(|| Error::RemoteClosed)?
} else {
tokio::time::timeout(
std::time::Duration::from_nanos(timeout_nano as u64),
rx.recv(),
)
.await
.map_err(|e| Error::Others(format!("Receive packet timeout {e:?}")))?
.ok_or_else(|| Error::RemoteClosed)?
};
let msg = result?;
let res = Response::decode(msg.payload)
.map_err(err_to_others_err!(e, "Unpack response error "))?;
let status = res.status();
if status.code() != Code::OK {
return Err(Error::RpcStatus((*status).clone()));
}
Ok(res)
}
pub async fn new_stream(
&self,
req: Request,
streaming_client: bool,
streaming_server: bool,
) -> Result<StreamInner> {
let stream_id = self.next_stream_id.fetch_add(2, Ordering::Relaxed);
let is_req_payload_empty = req.payload.is_empty();
let mut msg: GenMessage = Message::new_request(stream_id, req)?
.try_into()
.map_err(|e: protobuf::Error| Error::Others(e.to_string()))?;
if streaming_client {
if !is_req_payload_empty {
return Err(get_rpc_status(
Code::INVALID_ARGUMENT,
"Creating a ClientStream and sending payload at the same time is not allowed",
));
}
msg.header.add_flags(FLAG_REMOTE_OPEN | FLAG_NO_DATA);
} else {
msg.header.add_flags(FLAG_REMOTE_CLOSED);
}
let (tx, rx): (ResultSender, ResultReceiver) = mpsc::channel(100);
self.streams
.lock()
.map_err(|_| Error::Others("Failed to acquire lock on streams".to_string()))?
.insert(stream_id, tx);
self.req_tx
.send(SendingMessage::new(msg))
.await
.map_err(|e| Error::Others(format!("Send packet to sender error {e:?}")))?;
Ok(StreamInner::new(
stream_id,
self.req_tx.clone(),
rx,
streaming_client,
streaming_server,
Kind::Client,
self.streams.clone(),
))
}
}
#[derive(Debug)]
struct ClientBuilder {
rx: Option<MessageReceiver>,
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
}
impl Builder for ClientBuilder {
type Reader = ClientReader;
type Writer = ClientWriter;
fn build(&mut self) -> (Self::Reader, Self::Writer) {
let (notifier, waiter) = shutdown::new();
(
ClientReader {
shutdown_waiter: waiter,
streams: self.streams.clone(),
},
ClientWriter {
rx: self.rx.take().unwrap(),
shutdown_notifier: notifier,
streams: self.streams.clone(),
},
)
}
}
struct ClientWriter {
rx: MessageReceiver,
shutdown_notifier: shutdown::Notifier,
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
}
#[async_trait]
impl WriterDelegate for ClientWriter {
async fn recv(&mut self) -> Option<SendingMessage> {
self.rx.recv().await
}
async fn disconnect(&self, msg: &GenMessage, e: Error) {
let resp_tx = {
let mut map = self.streams.lock().unwrap();
map.remove(&msg.header.stream_id)
};
if let Some(resp_tx) = resp_tx {
let e = Error::Socket(format!("{e:?}"));
resp_tx
.send(Err(e))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
}
}
async fn exit(&self) {
self.shutdown_notifier.shutdown();
}
}
async fn get_resp_tx(
req_map: Arc<Mutex<HashMap<u32, ResultSender>>>,
header: &MessageHeader,
) -> Option<ResultSender> {
let resp_tx = match header.type_ {
MESSAGE_TYPE_RESPONSE => match req_map.lock().unwrap().remove(&header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown response packet {:?}", header);
return None;
}
},
MESSAGE_TYPE_DATA => {
if (header.flags & FLAG_REMOTE_CLOSED) == FLAG_REMOTE_CLOSED {
match req_map.lock().unwrap().remove(&header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown data packet {:?}", header);
return None;
}
}
} else {
match req_map.lock().unwrap().get(&header.stream_id) {
Some(tx) => tx.clone(),
None => {
debug!("Receiver got unknown data packet {:?}", header);
return None;
}
}
}
}
_ => {
let resp_tx = match req_map.lock().unwrap().remove(&header.stream_id) {
Some(tx) => tx,
None => {
debug!("Receiver got unknown packet {:?}", header);
return None;
}
};
resp_tx
.send(Err(Error::Others(format!(
"Receiver got malformed packet {header:?}"
))))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
return None;
}
};
Some(resp_tx)
}
struct ClientReader {
streams: Arc<Mutex<HashMap<u32, ResultSender>>>,
shutdown_waiter: shutdown::Waiter,
}
#[async_trait]
impl ReaderDelegate for ClientReader {
async fn wait_shutdown(&self) {
self.shutdown_waiter.wait_shutdown().await
}
async fn disconnect(&self, e: Error, sender: &mut task::JoinHandle<()>) {
sender.abort();
let _ = sender.await;
let mut map = std::mem::take(&mut *self.streams.lock().unwrap());
for (_stream_id, resp_tx) in map.drain() {
if let Err(_e) = resp_tx.send(Err(e.clone())).await {
warn!("Failed to terminate pending RPC: the request has returned");
}
}
}
async fn exit(&self) {}
async fn handle_err(&self, header: MessageHeader, e: Error) {
let req_map = self.streams.clone();
tokio::spawn(async move {
if let Some(resp_tx) = get_resp_tx(req_map, &header).await {
resp_tx
.send(Err(e))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
}
});
}
async fn handle_msg(&self, msg: GenMessage) {
let req_map = self.streams.clone();
tokio::spawn(async move {
if let Some(resp_tx) = get_resp_tx(req_map, &msg.header).await {
resp_tx
.send(Ok(msg))
.await
.unwrap_or_else(|_e| error!("The request has returned"));
}
});
}
}