use std::collections::HashSet;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;
use zenoh::config::ZenohId;
use zenoh::key_expr::KeyExpr;
use zenoh::liveliness::LivelinessToken;
use zenoh::Session;
use crate::serialize::serialize;
use crate::service::Service;
use crate::status::{Code, Status};
use crate::types::{Message, ServerMetadata, ServerTaskFuture, WireMessage};
/// Builder for a [`Server`]: collects the session, the services and the labels
/// that `build` moves into the final server.
pub struct ServerBuilder {
/// Zenoh session handed over to the built server.
pub(crate) session: Session,
/// Registered services, keyed by name (`add_service` uses `svc.name()` as key).
pub(crate) services: HashMap<String, Arc<dyn Service + Send + Sync>>,
/// Labels handed over to the built server.
pub(crate) labels: HashSet<String>,
}
impl ServerBuilder {
    /// Replaces the Zenoh session the built server will use.
    pub fn session(mut self, session: Session) -> Self {
        self.session = session;
        self
    }

    /// Registers `svc` under the name returned by `svc.name()`.
    ///
    /// A service previously registered under the same name is replaced.
    pub fn add_service(mut self, svc: Arc<dyn Service + Send + Sync>) -> Self {
        self.services.insert(svc.name(), svc);
        self
    }

    /// Replaces the whole service map with `services`.
    pub fn services(mut self, services: HashMap<String, Arc<dyn Service + Send + Sync>>) -> Self {
        self.services = services;
        self
    }

    /// Adds a single label; adding a label that is already present is a no-op.
    pub fn add_label<IntoString>(mut self, label: IntoString) -> Self
    where
        IntoString: Into<String>,
    {
        self.labels.insert(label.into());
        self
    }

    /// Adds every label yielded by `labels` to the labels already collected.
    ///
    /// Accepts anything iterable (a `Vec`, an array, a set, or an iterator),
    /// so callers do not have to call `.into_iter()` themselves.
    pub fn labels<IterIntoString, IntoString>(mut self, labels: IterIntoString) -> Self
    where
        IntoString: Into<String>,
        IterIntoString: IntoIterator<Item = IntoString>,
    {
        self.labels.extend(labels.into_iter().map(Into::into));
        self
    }

    /// Consumes the builder and produces a [`Server`] with an empty
    /// liveliness-token list.
    pub fn build(self) -> Server {
        Server {
            session: self.session,
            services: self.services,
            tokens: Arc::new(Mutex::new(Vec::new())),
            labels: self.labels,
        }
    }
}
/// RPC server: answers queries received on its Zenoh session by dispatching
/// them to the registered services (see `serve`).
pub struct Server {
/// Zenoh session used to declare the queryable and the liveliness tokens;
/// its id is also the server instance id.
pub(crate) session: Session,
/// Services that can be called, keyed by service name.
pub(crate) services: HashMap<String, Arc<dyn Service + Send + Sync>>,
/// Liveliness tokens declared by `serve`, one per service.
/// NOTE(review): presumably stored so the tokens are not dropped while the
/// server is alive — confirm against the zenoh liveliness semantics.
pub(crate) tokens: Arc<Mutex<Vec<LivelinessToken>>>,
/// Labels returned in the reply to a metadata query.
pub(crate) labels: HashSet<String>,
}
impl Server {
pub fn builder(session: Session) -> ServerBuilder {
ServerBuilder {
session,
services: HashMap::new(),
labels: HashSet::new(),
}
}
pub fn instance_uuid(&self) -> ZenohId {
self.session.zid()
}
pub async fn serve(&self) -> Result<(), Status> {
let mut tokens = vec![];
let ke = format!("@rpc/{}/**", self.instance_uuid());
let queryable = self.session.declare_queryable(&ke).await.map_err(|e| {
Status::new(
Code::InternalError,
format!("Cannot declare queryable: {e:?}"),
)
})?;
tracing::debug!("[Server] declared queryabled on: {ke}");
for k in self.services.keys() {
let ke = format!("@rpc/{}/service/{k}", self.instance_uuid());
let lt = self
.session
.liveliness()
.declare_token(ke)
.await
.map_err(|e| {
Status::new(
Code::InternalError,
format!("Cannot declare liveliness token: {e:?}"),
)
})?;
tokens.push(lt)
}
self.tokens.lock().await.extend(tokens);
loop {
let query = queryable
.recv_async()
.await
.map_err(|e| Status::internal_error(format!("Cannot receive query: {e:?}")))?;
let ke = query.key_expr().clone();
tracing::debug!("[Server] received query on: {ke}");
let fut: ServerTaskFuture = match Self::get_token(&ke, 2) {
Some("service") => {
tracing::debug!("[Server] call to service");
let service_name = Self::get_service_name(&ke)?;
let svc = self
.services
.get(service_name)
.ok_or_else(|| {
Status::internal_error(format!("Service not found: {service_name}"))
})?
.clone();
let payload = query
.payload()
.ok_or_else(|| {
Status::internal_error("Query has empty value cannot proceed")
})?
.to_bytes()
.to_vec();
Box::pin(Self::service_call(svc, ke.clone(), payload))
}
Some("metadata") => {
tracing::debug!("[Server] call to metadata");
Box::pin(Self::server_metadata(
self.labels.clone(),
self.instance_uuid(),
))
}
Some(_) | None => {
tracing::error!("[Server] unknown call");
Box::pin(Self::create_error())
}
};
tokio::task::spawn(async move {
let res = fut.await;
let sample = match res {
Ok(data) => data,
Err(e) => {
let wmgs = WireMessage {
payload: None,
status: e,
};
serialize(&wmgs).unwrap_or_default()
}
};
let res = query.reply(ke, sample).await;
tracing::debug!("Query Result is: {res:?}");
});
}
}
async fn service_call(
svc: Arc<dyn Service + Send + Sync>,
ke: KeyExpr<'_>,
payload: Vec<u8>,
) -> Result<Vec<u8>, Status> {
let method_name = Self::get_method_name(&ke)?;
let msg = Message {
method: method_name.into(),
body: payload,
metadata: HashMap::new(),
status: Status::new(Code::Accepted, ""),
};
match svc.call(msg).await {
Ok(msg) => {
tracing::debug!("Service response: {msg:?}");
let wmsg = WireMessage {
payload: Some(msg.body),
status: Status::ok(""),
};
serialize(&wmsg)
.map_err(|e| Status::internal_error(format!("Serialization error: {e:?}")))
}
Err(e) => {
tracing::error!("Service error is : {e:?}");
let wmsg = WireMessage {
payload: None,
status: e,
};
serialize(&wmsg)
.map_err(|e| Status::internal_error(format!("Serialization error: {e:?}")))
}
}
}
async fn server_metadata(labels: HashSet<String>, id: ZenohId) -> Result<Vec<u8>, Status> {
let metadata = ServerMetadata { labels, id };
let serialized_metadata = serialize(&metadata)
.map_err(|e| Status::internal_error(format!("Serialization error: {e:?}")))?;
let wmsg = WireMessage {
payload: Some(serialized_metadata),
status: Status::ok(""),
};
serialize(&wmsg).map_err(|e| Status::internal_error(format!("Serialization error: {e:?}")))
}
async fn create_error() -> Result<Vec<u8>, Status> {
Err(Status::unavailable("Unavailable"))
}
fn get_service_name<'a>(ke: &'a KeyExpr) -> Result<&'a str, Status> {
Self::get_token(ke, 3).ok_or(Status::internal_error("Cannot get service name"))
}
fn get_method_name<'a>(ke: &'a KeyExpr) -> Result<&'a str, Status> {
Self::get_token(ke, 4).ok_or(Status::internal_error("Cannot get method name"))
}
fn get_token<'a>(ke: &'a KeyExpr, index: usize) -> Option<&'a str> {
let tokens: Vec<_> = ke.split('/').collect();
tokens.get(index).copied()
}
}