use crate::{
error::{RpcError, RpcResult},
io::{self, Message},
};
use async_std::{
channel::{bounded, Receiver, Sender},
future,
net::{TcpListener, TcpStream},
stream::StreamExt,
sync::{Arc, Mutex},
task,
};
use identity::Identity;
use std::{
collections::BTreeMap,
net::Shutdown,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
};
/// Shorthand for a value shared between tasks: reference counted via
/// `Arc` and guarded by the async-aware `async_std::sync::Mutex`.
type Lock<T> = Arc<Mutex<T>>;
/// Return the default `(host, port)` pair for the RPC socket.
///
/// The result can be passed straight to the `addr` / `port` parameters of
/// the socket constructors.
pub fn default_socket_path() -> (&'static str, u16) {
    const DEFAULT_HOST: &str = "localhost";
    const DEFAULT_PORT: u16 = 10222;

    (DEFAULT_HOST, DEFAULT_PORT)
}
/// A bi-directional RPC socket that is either a connected client or a
/// listening server.
///
/// Client sockets (see `connect` / `connect_timeout`) own a `stream` and
/// run a background task that reads incoming messages.  Server sockets
/// (see `server`) own a `listen`er and run a background accept loop.
pub struct RpcSocket {
/// Outgoing/incoming TCP stream; `Some` for client sockets, `None` for
/// sockets created via `server`.
stream: Option<TcpStream>,
/// Bound TCP listener; `Some` for server sockets, `None` for clients.
listen: Option<Arc<TcpListener>>,
/// Set to `true` on construction and cleared by `shutdown`; polled by the
/// background receive and accept loops.
running: AtomicBool,
/// Set once `listen` has been called (server sockets start out `true`).
listening: AtomicBool,
/// Reply routing table: maps the ID of a message sent via `send` to the
/// channel its caller is waiting on.
/// NOTE(review): "wfm" presumably abbreviates "waiting for message".
wfm: Lock<BTreeMap<Identity, Sender<Message>>>,
/// Bounded queue (capacity 4) of incoming messages that did not match any
/// entry in `wfm`; drained by the task spawned in `listen`.
inc_io: (Sender<Message>, Receiver<Message>),
/// How long `send` waits for a matching reply before giving up.
timeout: Duration,
}
impl RpcSocket {
pub async fn connect(addr: &str, port: u16) -> RpcResult<Arc<Self>> {
Self::connect_timeout(addr, port, Duration::from_secs(5)).await
}
pub async fn connect_timeout(addr: &str, port: u16, timeout: Duration) -> RpcResult<Arc<Self>> {
let stream = TcpStream::connect(&format!("{}:{}", addr, port)).await?;
let _self = Arc::new(Self {
stream: Some(stream),
listen: None,
running: true.into(),
listening: false.into(),
wfm: Default::default(),
inc_io: bounded(4),
timeout,
});
_self.spawn_incoming();
Ok(_self)
}
pub async fn listen<F: Fn(Message) + Send + 'static>(self: &Arc<Self>, cb: F) {
let _self = Arc::clone(self);
_self.listening.swap(true, Ordering::Relaxed);
task::spawn(async move {
while let Ok(msg) = _self.inc_io.1.recv().await {
cb(msg);
}
});
}
pub async fn server<F, D>(addr: &str, port: u16, cb: F, data: D) -> RpcResult<Arc<Self>>
where
F: Fn(TcpStream, D) + Send + Copy + 'static,
D: Send + Sync + Clone + 'static,
{
info!("Opening qrpc socket on {}:{}", addr, port);
let listen = Arc::new(TcpListener::bind(format!("{}:{}", addr, port)).await?);
let _self = Arc::new(Self {
stream: None,
listen: Some(listen),
running: true.into(),
listening: true.into(),
wfm: Default::default(),
inc_io: bounded(4),
timeout: Duration::from_secs(5),
});
let s = Arc::clone(&_self);
task::spawn(async move {
let mut inc = s.listen.as_ref().unwrap().incoming();
while let Some(Ok(stream)) = inc.next().await {
if !s.running() {
break;
}
debug!("New incoming qrpc connection! ({:?})", stream.peer_addr());
let d = data.clone();
task::spawn(async move { cb(stream, d) });
}
info!("Terminating rpc accept loop...");
});
Ok(_self)
}
fn spawn_incoming(self: &Arc<Self>) {
let _self = Arc::clone(self);
task::spawn(async move {
let mut sock = _self.stream.clone().unwrap();
while _self.running.load(Ordering::Relaxed) {
let msg = match io::recv(&mut sock).await {
Ok(msg) => msg,
Err(e) => {
task::sleep(std::time::Duration::from_millis(10)).await;
error!("Failed reading message: {}", e.to_string());
continue;
}
};
let id = msg.id;
let mut wfm = _self.wfm.lock().await;
match wfm.get(&id) {
Some(sender) => sender.send(msg).await.unwrap(),
None => _self.inc_io.0.send(msg).await.unwrap(),
}
wfm.remove(&id);
}
});
}
pub async fn reply(self: &Arc<Self>, msg: Message) -> RpcResult<()> {
let mut s = self.stream.clone().unwrap();
io::send(&mut s, msg).await
}
pub async fn send<T, F>(self: &Arc<Self>, msg: Message, convert: F) -> RpcResult<T>
where
F: Fn(Message) -> RpcResult<T>,
{
let id = msg.id;
let (tx, rx) = bounded(1);
self.wfm.lock().await.insert(id, tx);
let mut s = self.stream.clone().unwrap();
io::send(&mut s, msg).await?;
future::timeout(self.timeout, async move {
match rx.recv().await {
Ok(msg) => convert(msg),
Err(_) => Err(RpcError::ConnectionFault(
"No message with matching ID received!".into(),
)),
}
})
.await?
}
pub fn shutdown(self: &Arc<Self>) {
self.running.swap(false, Ordering::Relaxed);
if let Some(ref s) = self.stream {
s.shutdown(Shutdown::Both).unwrap();
}
}
pub fn running(&self) -> bool {
self.running.load(Ordering::Relaxed)
}
pub fn listening(&self) -> bool {
self.listening.load(Ordering::Relaxed)
}
}