use crate::App;
use anyhow::Result;
use async_std::task;
use futures::channel::mpsc;
use futures::{Future, FutureExt};
use potatonet_client::{Client, SubscribeId};
use potatonet_common::{Context, LocalServiceId, Request, Response, ServiceId, Topic};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::sync::Arc;
/// Execution context handed to a service running inside a node.
///
/// Bundles the node's `Client` and `App` with the identity of the service this
/// context belongs to, and is the receiver through which that service calls,
/// notifies and publishes (via the `Context` trait).
pub struct NodeContext<'a> {
    /// Client used to look up this node's id and to forward requests to other nodes.
    pub(crate) client: Arc<Client>,
    /// Application whose registered services (`services_map` / `services`) are
    /// dispatched in-process when the target lives on this node.
    pub(crate) app: Arc<App>,
    /// Global id of the service that sent the current request, if known
    /// (`None` for contexts created for topic-subscription handlers).
    pub(crate) from: Option<ServiceId>,
    /// Name of the service this context belongs to.
    pub(crate) service_name: &'a str,
    /// Node-local id of that service; combined with the client's node id to
    /// form the global `ServiceId`.
    pub(crate) local_service_id: LocalServiceId,
    /// Channel used to request node shutdown; presumably drained by the node's
    /// run loop — TODO confirm against the receiver side.
    pub(crate) tx_abort: mpsc::Sender<()>,
}
impl<'a> NodeContext<'a> {
    /// Returns the global id of the service that sent the current request,
    /// or `None` when there is no known sender.
    pub fn from(&self) -> Option<ServiceId> {
        let sender = self.from;
        sender
    }

    /// Returns the name of the service this context belongs to.
    pub fn service_name(&self) -> &str {
        let name: &str = self.service_name;
        name
    }

    /// Returns the global id of this service: its node-local id qualified by
    /// the node id reported by the client.
    pub fn service_id(&self) -> ServiceId {
        let node_id = self.client.node_id();
        self.local_service_id.to_global(node_id)
    }

    /// Asks the node to shut down.
    ///
    /// This is best-effort: if the abort signal cannot be queued (for example,
    /// because the receiving side is gone), the failure is ignored.
    pub fn shutdown_node(&self) {
        // `try_send` needs `&mut Sender`, so send through a fresh clone.
        let mut abort = self.tx_abort.clone();
        let _ = abort.try_send(());
    }
}
impl<'a> NodeContext<'a> {
    /// Subscribes `handler` to messages published on `topic`.
    ///
    /// The subscription holds the node's `Client` and `App` only through weak
    /// references, so it does not keep them alive. After either has been
    /// dropped, incoming messages are ignored: an already-completed future is
    /// returned instead of calling `handler`. Otherwise each message gets a
    /// fresh `NodeContext` for this service, with `from` set to `None`.
    ///
    /// Returns the `SubscribeId` assigned by the client, usable with `unsubscribe`.
    pub async fn subscribe_with_topic<T, F, R>(&self, topic: &str, mut handler: F) -> SubscribeId
    where
        T: Topic,
        F: FnMut(&NodeContext<'_>, T) -> R + Send + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        let weak_client = Arc::downgrade(&self.client);
        let weak_app = Arc::downgrade(&self.app);
        let owned_name = String::from(self.service_name);
        let local_service_id = self.local_service_id;
        let abort_sender = self.tx_abort.clone();

        self.client
            .subscribe_with_topic(topic, move |msg| {
                match (weak_client.upgrade(), weak_app.upgrade()) {
                    (Some(client), Some(app)) => {
                        // `R: 'static`, so the returned future cannot borrow `ctx`;
                        // borrowing the captured name for the call is enough.
                        let ctx = NodeContext {
                            client,
                            app,
                            from: None,
                            service_name: owned_name.as_str(),
                            local_service_id,
                            tx_abort: abort_sender.clone(),
                        };
                        handler(&ctx, msg).boxed()
                    }
                    // Client or App already dropped: skip the message.
                    _ => futures::future::ready(()).boxed(),
                }
            })
            .await
    }

    /// Subscribes `handler` to the topic named by `T::name()`.
    ///
    /// Shorthand for `subscribe_with_topic(T::name(), handler)`.
    pub async fn subscribe<T, F, R>(&self, handler: F) -> SubscribeId
    where
        T: Topic,
        F: FnMut(&NodeContext<'_>, T) -> R + Send + 'static,
        R: Future<Output = ()> + Send + 'static,
    {
        let topic = T::name();
        self.subscribe_with_topic(topic, handler).await
    }

    /// Cancels the subscription identified by `id` through the client.
    pub async fn unsubscribe(&self, id: SubscribeId) {
        let client = &self.client;
        client.unsubscribe(id).await;
    }
}
#[async_trait::async_trait]
impl<'a> Context for NodeContext<'a> {
    /// Calls `request.method` on the service named `service_name` and waits for
    /// its response.
    ///
    /// If a service with that name is registered in this node's `App`, it is
    /// invoked in-process. Otherwise the request is forwarded through the `Client`.
    ///
    /// # Errors
    /// Propagates any error returned by the local service or by `Client::call`.
    async fn call<T, R>(&self, service_name: &str, request: Request<T>) -> Result<Response<R>>
    where
        T: Serialize + Send + 'static,
        R: DeserializeOwned + Send + 'static,
    {
        trace!("call. service={} method={}", service_name, request.method);
        if let Some(lid) = self.app.services_map.get(service_name).copied() {
            let request_bytes = request.to_bytes();
            if let Some((_, service)) = self.app.services.get(lid.0 as usize) {
                // NOTE(review): the local service receives the *caller's* context
                // (its `from`, `service_name`, `service_id`). `notify` instead builds
                // a fresh context for the target. Confirm this asymmetry is intended.
                let resp = service.call(self, request_bytes).await?;
                return Ok(Response::from_bytes(resp));
            }
        }
        self.client.call(service_name, request).await
    }

    /// Sends a fire-and-forget notification to the service named `service_name`.
    ///
    /// If a service with that name is registered locally, it is notified on a
    /// spawned task. That task gets a fresh context whose `from` is this
    /// service's global id. The request is then also passed to `Client::notify`,
    /// whether or not a local service matched.
    async fn notify<T: Serialize + Send + 'static>(&self, service_name: &str, request: Request<T>) {
        trace!("notify. service={} method={}", service_name, request.method);
        if let Some(lid) = self.app.services_map.get(service_name).copied() {
            // Serialize only when there is a local recipient. The client path
            // takes the typed request.
            let request_bytes = request.to_bytes();
            let client = self.client.clone();
            let app = self.app.clone();
            let service_name = service_name.to_string();
            let from = self.local_service_id.to_global(client.node_id());
            let tx_abort = self.tx_abort.clone();
            task::spawn(async move {
                if let Some((_, service)) = app.services.get(lid.0 as usize) {
                    service
                        .notify(
                            &NodeContext {
                                client,
                                app: app.clone(),
                                from: Some(from),
                                service_name: &service_name,
                                local_service_id: lid,
                                tx_abort,
                            },
                            request_bytes,
                        )
                        .await;
                }
            });
        }
        // NOTE(review): when a local service matched, the request is delivered both
        // locally and via the client. Confirm remote delivery is also wanted then.
        self.client.notify(service_name, request).await;
    }

    /// Sends a fire-and-forget notification to the specific service `to`.
    ///
    /// If `to` lives on this node, the target service (looked up by
    /// `to.local_service_id`) is notified on a spawned task. That task gets a
    /// context whose `from` is this service's global id. Otherwise the request is
    /// forwarded through `Client::notify_to`.
    async fn notify_to<T: Serialize + Send + 'static>(&self, to: ServiceId, request: Request<T>) {
        trace!("notify to. to={} method={}", to, request.method);
        if to.node_id == self.client.node_id() {
            let request_bytes = request.to_bytes();
            let client = self.client.clone();
            let app = self.app.clone();
            let from = self.local_service_id.to_global(self.client.node_id());
            // NOTE(review): this is the *sender's* name, while `local_service_id`
            // below is the target's. The receiving service therefore sees a
            // mismatched `service_name()` / `service_id()` pair. Confirm whether the
            // target's name was intended.
            let service_name = self.service_name.to_string();
            let tx_abort = self.tx_abort.clone();
            task::spawn(async move {
                if let Some((_, service)) = app.services.get(to.local_service_id.0 as usize) {
                    // Futures are lazy: without `.await` the notification would be
                    // built and dropped without ever running.
                    service
                        .notify(
                            &NodeContext {
                                client,
                                app: app.clone(),
                                from: Some(from),
                                service_name: &service_name,
                                local_service_id: to.local_service_id,
                                tx_abort,
                            },
                            request_bytes,
                        )
                        .await;
                }
            });
            return;
        }
        self.client.notify_to(to, request).await;
    }

    /// Publishes `msg` on `topic` through the client.
    async fn publish_with_topic<T: Topic>(&self, topic: &str, msg: T) {
        trace!("publish. topic={}", topic);
        self.client.publish_with_topic(topic, msg).await;
    }
}