#[macro_use]
extern crate anyhow;
#[macro_use]
extern crate log;
use async_std::net::{TcpStream, ToSocketAddrs};
use async_std::stream;
use async_std::task;
use futures::channel::{mpsc, oneshot};
use futures::lock::Mutex;
use futures::prelude::*;
use futures::select;
use futures::{SinkExt, StreamExt};
use potatonet_common::bus_message::Message;
use potatonet_common::{
bus_message, Context, Error, LocalServiceId, NodeId, Request, Response, ResponseBytes, Result,
ServiceId, Topic,
};
use serde::de::DeserializeOwned;
use serde::export::PhantomData;
use serde::Serialize;
use slab::Slab;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
/// Type-erased subscription callback.
///
/// Implementations are stored as `Box<dyn SubCallback>` and are handed the raw payload bytes of
/// a published message; decoding into a concrete type is left to the implementation.
#[async_trait::async_trait]
trait SubCallback: Send {
/// Invoked with the raw payload bytes of one published message.
async fn notify(&mut self, data: &[u8]);
}
/// Wraps a typed, async topic handler so it can be stored behind the byte-oriented
/// `SubCallback` trait.
///
/// `T` is the topic message type passed to the handler, `F` is the handler itself and `R` is
/// the future the handler returns.
struct SubCallbackAdapter<T, F, R>
where
    T: Topic,
    F: FnMut(T) -> R + Send + 'static,
    R: Future<Output = ()> + Send + 'static,
{
    /// Ties the adapter to `T` without storing a value of that type.
    _market: PhantomData<T>,
    /// The user-supplied handler.
    f: F,
}
#[async_trait::async_trait]
impl<T, F, R> SubCallback for SubCallbackAdapter<T, F, R>
where
    T: Topic,
    F: FnMut(T) -> R + Send + 'static,
    R: Future<Output = ()> + Send + 'static,
{
    /// Decodes `data` as a `T` and awaits the wrapped handler with the decoded message.
    ///
    /// Payloads that fail to decode are dropped without calling the handler.
    async fn notify(&mut self, data: &[u8]) {
        match T::decode(&data) {
            Ok(msg) => (self.f)(msg).await,
            // Undecodable payload: nothing to deliver.
            Err(_) => {}
        }
    }
}
/// Handle identifying one registered subscription.
///
/// Returned by `Client::subscribe` / `Client::subscribe_with_topic` and accepted by
/// `Client::unsubscribe`. The wrapped value is the subscription's key in the client's `Slab`.
#[derive(Hash, Eq, PartialEq, Copy, Clone)]
pub struct SubscribeId(usize);
/// Mutable client state shared between the `Client` handle and its background control loop.
#[derive(Default)]
struct Inner {
/// In-flight calls keyed by request sequence number; the sender delivers the reply (or the
/// remote error string) to the waiting `call`.
pending: HashMap<u32, oneshot::Sender<std::result::Result<ResponseBytes, String>>>,
/// Sequence number most recently assigned to a request.
seq: u32,
/// Number of local handlers currently registered per topic name.
subscribes_set: HashMap<String, usize>,
/// Registered handlers as `(topic name, callback)`, keyed by the value inside `SubscribeId`.
subscribes: Slab<(String, Box<dyn SubCallback>)>,
}
/// Handle to a connection with the bus.
///
/// Dropping the handle signals the background control loop to shut the connection down.
pub struct Client {
/// Node id received in the `Hello` message while connecting.
node_id: NodeId,
/// Queue of outgoing messages, drained by the writer task.
tx: mpsc::Sender<bus_message::Message>,
/// Used on drop to tell the background control loop to stop.
tx_abort: mpsc::Sender<()>,
/// State shared with the background control loop (pending calls, subscriptions).
inner: Arc<Mutex<Inner>>,
}
impl Drop for Client {
    /// Asks the background control loop to stop and logs that the client was closed.
    fn drop(&mut self) {
        // Best effort: if the signal cannot be queued there is nothing more to do here.
        let _ = self.tx_abort.try_send(());
        info!("client closed");
    }
}
impl Client {
    /// Applies one message received from the bus to the client's shared state.
    ///
    /// - `Rep`: completes the pending call registered under `seq`, if one is still waiting.
    /// - `XPublish`: awaits every subscription callback whose topic equals the published topic.
    /// - Any other message is ignored here.
    ///
    /// NOTE(review): the `inner` lock stays held while subscription callbacks are awaited, so a
    /// callback that awaits a `Client` method which also locks `inner` (`call`, `subscribe`,
    /// `unsubscribe`) would end up waiting on itself.
    async fn process_incoming_msg(inner: Arc<Mutex<Inner>>, msg: &bus_message::Message) {
        match msg {
            bus_message::Message::Rep { seq, result } => {
                let mut inner = inner.lock().await;
                if let Some(tx) = inner.pending.remove(seq) {
                    // Sending fails when the caller already gave up (e.g. timed out); ignore it.
                    tx.send(result.clone().map(|data| Response::new(data))).ok();
                }
            }
            bus_message::Message::XPublish { topic, data } => {
                let mut inner = inner.lock().await;
                for (_, (sub_topic, f)) in &mut inner.subscribes {
                    if topic == sub_topic.as_str() {
                        f.notify(&data).await;
                    }
                }
            }
            _ => {}
        }
    }

    /// Connects to the bus at `addr` and starts the background tasks that drive the connection.
    ///
    /// The first address `addr` resolves to is used. The bus is expected to greet with a `Hello`
    /// message carrying this client's node id. A reader task, a writer task and a control loop
    /// are then spawned; the control loop sends a `Ping` every second, applies each incoming
    /// message to the client state and afterwards passes it to `handle_msg`.
    ///
    /// # Errors
    ///
    /// Fails if `addr` resolves to no address, the TCP connection cannot be established, the
    /// greeting cannot be read, or the first message is something other than `Hello`.
    #[doc(hidden)]
    pub async fn connect_with_notify<A, F, R>(addr: A, mut handle_msg: F) -> Result<Client>
    where
        A: ToSocketAddrs,
        F: FnMut(bus_message::Message) -> R + Send + Sync + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        let addr = addr.to_socket_addrs().await?.next();
        let stream = match addr {
            Some(addr) => {
                info!("connect to bus. addr={}", addr);
                Arc::new(TcpStream::connect(addr).await?)
            }
            None => bail!("could not resolve to any addresses"),
        };

        let node_id = match potatonet_common::bus_message::read_one_message(&*stream).await {
            Ok(bus_message::Message::Hello(node_id)) => node_id,
            // Carry the unexpected message in the error instead of printing it to stdout.
            Ok(msg) => bail!("invalid response: {:?}", msg),
            Err(err) => return Err(err),
        };

        let inner: Arc<Mutex<Inner>> = Default::default();
        let (tx_incoming_msg, mut rx_incoming_msg) = mpsc::channel(16);
        let (tx_outgoing_msg, rx_outgoing_msg) = mpsc::channel(16);

        // Reader and writer run as separate tasks that the control loop can abort on shutdown.
        let (reader_task, abort_reader) =
            future::abortable(bus_message::read_messages(stream.clone(), tx_incoming_msg));
        let reader_handle = task::spawn(reader_task);
        let (writer_task, abort_writer) =
            future::abortable(bus_message::write_messages(stream.clone(), rx_outgoing_msg));
        let writer_handle = task::spawn(writer_task);

        let (tx_abort, mut rx_abort) = mpsc::channel::<()>(1);
        let fut = {
            let inner = inner.clone();
            let mut tx_outgoing_msg = tx_outgoing_msg.clone();
            async move {
                let mut hb_timer = stream::interval(Duration::from_secs(1)).fuse();
                loop {
                    select! {
                        // Fires on an explicit abort signal and also once `tx_abort` is dropped.
                        _ = rx_abort.next() => {
                            break;
                        }
                        _ = hb_timer.next() => {
                            if tx_outgoing_msg.send(bus_message::Message::Ping).await.is_err() {
                                break;
                            }
                        }
                        msg = rx_incoming_msg.next() => {
                            if let Some(msg) = msg {
                                Self::process_incoming_msg(inner.clone(), &msg).await;
                                handle_msg(msg).await;
                            } else {
                                // The reader side is gone; nothing more will arrive.
                                break;
                            }
                        }
                    }
                }

                // Say goodbye (best effort), then stop both I/O tasks and wait for them.
                tx_outgoing_msg.send(bus_message::Message::Bye).await.ok();
                abort_reader.abort();
                abort_writer.abort();
                reader_handle.await.ok();
                writer_handle.await.ok();
                info!("client closed");
            }
        };
        task::spawn(fut);

        Ok(Client {
            node_id,
            tx: tx_outgoing_msg,
            tx_abort,
            inner,
        })
    }

    /// Connects to the bus at `addr` without an extra per-message hook.
    ///
    /// # Errors
    ///
    /// Same failure cases as `connect_with_notify`.
    pub async fn connect<A: ToSocketAddrs>(addr: A) -> Result<Client> {
        Self::connect_with_notify(addr, |_| async move {}).await
    }

    /// Returns the node id received from the bus while connecting.
    pub fn node_id(&self) -> NodeId {
        self.node_id
    }

    /// Queues a `RegisterService` message announcing service `id` under `name`.
    ///
    /// A failure to queue the message is ignored.
    pub async fn register_service<N: Into<String>>(&self, name: N, id: LocalServiceId) {
        self.send_msg(bus_message::Message::RegisterService {
            name: name.into(),
            id,
        })
        .await;
    }

    /// Queues an `UnregisterService` message for service `id`.
    ///
    /// A failure to queue the message is ignored.
    pub async fn unregister_service(&self, id: LocalServiceId) {
        self.send_msg(bus_message::Message::UnregisterService { id })
            .await;
    }

    /// Queues an arbitrary message for sending; a failure to queue it is ignored.
    #[doc(hidden)]
    pub async fn send_msg(&self, msg: Message) {
        self.tx.clone().send(msg).await.ok();
    }

    /// Registers `handler` for the topic named by `T::name()`.
    ///
    /// See `subscribe_with_topic` for details.
    pub async fn subscribe<T, F, R>(&self, handler: F) -> SubscribeId
    where
        T: Topic,
        F: FnMut(T) -> R + Send + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        self.subscribe_with_topic(T::name(), handler).await
    }

    /// Registers `handler` for messages published on `topic` and returns its id.
    ///
    /// A `Subscribe` message is sent to the bus only for the first local handler of a topic;
    /// further handlers for the same topic just bump the local count.
    pub async fn subscribe_with_topic<T, F, R>(&self, topic: &str, handler: F) -> SubscribeId
    where
        T: Topic,
        F: FnMut(T) -> R + Send + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        let mut inner = self.inner.lock().await;
        let id = inner.subscribes.insert((
            topic.to_string(),
            Box::new(SubCallbackAdapter {
                _market: PhantomData,
                f: handler,
            }),
        ));

        if let Some(count) = inner.subscribes_set.get_mut(topic) {
            *count += 1;
        } else {
            inner.subscribes_set.insert(topic.to_string(), 1);
            self.send_msg(bus_message::Message::Subscribe {
                topic: topic.to_string(),
            })
            .await;
        }
        SubscribeId(id)
    }

    /// Removes the handler registered under `id`; unknown ids are ignored.
    ///
    /// An `Unsubscribe` message is sent to the bus only when the removed handler was the last
    /// local one for its topic, mirroring how `subscribe_with_topic` sends `Subscribe` only for
    /// the first. Other handlers of the same topic therefore keep receiving messages.
    pub async fn unsubscribe(&self, id: SubscribeId) {
        let mut inner = self.inner.lock().await;
        if !inner.subscribes.contains(id.0) {
            return;
        }
        let (topic, _) = inner.subscribes.remove(id.0);

        let was_last_handler = match inner.subscribes_set.get_mut(&topic) {
            Some(count) => {
                *count -= 1;
                *count == 0
            }
            // No count is tracked for this topic, so the bus was never told about it.
            None => false,
        };

        if was_last_handler {
            inner.subscribes_set.remove(&topic);
            self.send_msg(bus_message::Message::Unsubscribe { topic })
                .await;
        }
    }
}
#[async_trait::async_trait]
impl Context for Client {
async fn call<T, R>(&self, service_name: &str, request: Request<T>) -> Result<Response<R>>
where
T: Serialize + Send + 'static,
R: DeserializeOwned + Send + 'static,
{
let (seq, rx) = {
let mut inner = self.inner.lock().await;
let (tx, rx) = oneshot::channel();
inner.seq += 1;
let seq = inner.seq;
inner.pending.insert(seq, tx);
(seq, rx)
};
let request = request.to_bytes();
self.tx
.clone()
.send(bus_message::Message::Req {
seq,
from: LocalServiceId(0),
to_service: service_name.to_string(),
method: request.method,
data: request.data,
})
.await
.ok();
match async_std::future::timeout(Duration::from_secs(5), rx).await {
Ok(Ok(Ok(resp))) => Ok(Response::<R>::from_bytes(resp)),
Ok(Ok(Err(err))) => Err(anyhow!(err)),
Ok(Err(_)) => {
let mut inner = self.inner.lock().await;
inner.pending.remove(&seq);
Err(Error::Internal.into())
}
Err(_) => {
let mut inner = self.inner.lock().await;
inner.pending.remove(&seq);
Err(Error::Timeout.into())
}
}
}
async fn notify<T: Serialize + Send + 'static>(&self, service_name: &str, request: Request<T>) {
let request = request.to_bytes();
self.tx
.clone()
.send(bus_message::Message::Notify {
from: LocalServiceId(0),
to_service: service_name.to_string(),
method: request.method,
data: request.data,
})
.await
.ok();
}
async fn notify_to<T: Serialize + Send + 'static>(&self, to: ServiceId, request: Request<T>) {
let request = request.to_bytes();
self.tx
.clone()
.send(bus_message::Message::NotifyTo {
from: LocalServiceId(0),
to,
method: request.method,
data: request.data,
})
.await
.ok();
}
async fn publish_with_topic<T: Topic>(&self, topic: &str, msg: T) {
if let Ok(data) = msg.encode() {
self.tx
.clone()
.send(bus_message::Message::Publish {
topic: topic.to_string(),
data: data.into(),
})
.await
.ok();
}
}
}