#[doc(hidden)]
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
use alloc::sync::Arc;
use alloc::{boxed::Box, string::String};
use core::fmt::Debug;
use dimas_core::{
Result,
enums::{OperationState, TaskSignal},
message_types::QueryMsg,
traits::{Capability, Context},
};
use futures::future::BoxFuture;
#[cfg(feature = "std")]
use tokio::{sync::Mutex, task::JoinHandle};
use tracing::{Level, error, info, instrument, warn};
use zenoh::Session;
#[cfg(feature = "unstable")]
use zenoh::sample::Locality;
/// Type of the callback invoked for a query received by a [`Queryable`].
///
/// The callback takes a [`Context`] and the incoming [`QueryMsg`] and returns a
/// boxed future that resolves to a [`Result`]. It must be `Send + Sync`.
pub type GetCallback<P> =
Box<dyn FnMut(Context<P>, QueryMsg) -> BoxFuture<'static, Result<()>> + Send + Sync>;
/// A [`GetCallback`] shared through an [`Arc`] and guarded by an async (tokio) `Mutex`.
pub type ArcGetCallback<P> = Arc<Mutex<GetCallback<P>>>;
/// A queryable that serves queries on a selector by invoking a user supplied callback.
///
/// The serving loop runs in a spawned tokio task whose handle is kept in `handle`.
pub struct Queryable<P>
where
P: Send + Sync + 'static,
{
/// Session on which the queryable is declared.
session: Arc<Session>,
/// Selector the queryable is declared for.
selector: String,
/// Context handed (as a clone) to the callback for every received query.
context: Context<P>,
/// Operation state that is compared against the requested state to decide
/// whether the queryable is started or stopped.
activation_state: OperationState,
/// Callback invoked for every received query.
callback: ArcGetCallback<P>,
/// Value passed to the queryable builder's `complete(..)` setting.
completeness: bool,
/// Value passed to the queryable builder's `allowed_origin(..)` setting
/// (only with feature `unstable`).
#[cfg(feature = "unstable")]
allowed_origin: Locality,
/// Join handle of the spawned serving task; `None` while no handle is stored.
handle: std::sync::Mutex<Option<JoinHandle<()>>>,
}
impl<P> Debug for Queryable<P>
where
    P: Send + Sync + 'static,
{
    /// Writes a non-exhaustive debug representation showing only the selector
    /// and the completeness flag.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut repr = f.debug_struct("Queryable");
        repr.field("selector", &self.selector);
        repr.field("complete", &self.completeness);
        // The remaining fields are deliberately left out of the output.
        repr.finish_non_exhaustive()
    }
}
impl<P> crate::traits::Responder for Queryable<P>
where
    P: Send + Sync + 'static,
{
    /// Returns the selector this queryable was created with.
    fn selector(&self) -> &str {
        self.selector.as_str()
    }
}
impl<P> Capability for Queryable<P>
where
    P: Send + Sync + 'static,
{
    /// Starts or stops the queryable depending on the requested operation state.
    ///
    /// The queryable is started when `state` is at or above the configured
    /// activation state and stopped when it is below. If neither comparison
    /// holds, nothing is done.
    ///
    /// # Errors
    /// Propagates whatever error starting or stopping returns.
    fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
        let threshold = &self.activation_state;

        if state >= threshold {
            return self.start();
        }
        if state < threshold {
            return self.stop();
        }
        Ok(())
    }
}
impl<P> Queryable<P>
where
P: Send + Sync + 'static,
{
#[must_use]
pub fn new(
session: Arc<Session>,
selector: String,
context: Context<P>,
activation_state: OperationState,
request_callback: ArcGetCallback<P>,
completeness: bool,
#[cfg(feature = "unstable")] allowed_origin: Locality,
) -> Self {
Self {
session,
selector,
context,
activation_state,
callback: request_callback,
completeness,
#[cfg(feature = "unstable")]
allowed_origin,
handle: std::sync::Mutex::new(None),
}
}
#[instrument(level = Level::TRACE, skip_all)]
fn start(&self) -> Result<()> {
self.stop()?;
let completeness = self.completeness;
#[cfg(feature = "unstable")]
let allowed_origin = self.allowed_origin;
let selector = self.selector.clone();
let cb = self.callback.clone();
let ctx1 = self.context.clone();
let ctx2 = self.context.clone();
let session = self.session.clone();
self.handle.lock().map_or_else(
|_| todo!(),
|mut handle| {
handle.replace(tokio::task::spawn(async move {
let key = selector.clone();
std::panic::set_hook(Box::new(move |reason| {
error!("queryable panic: {}", reason);
if let Err(reason) = ctx1
.sender()
.blocking_send(TaskSignal::RestartQueryable(key.clone()))
{
error!("could not restart queryable: {}", reason);
} else {
info!("restarting queryable!");
}
}));
if let Err(error) = run_queryable(
session,
selector,
cb,
completeness,
#[cfg(feature = "unstable")]
allowed_origin,
ctx2,
)
.await
{
error!("queryable failed with {error}");
}
}));
Ok(())
},
)
}
#[instrument(level = Level::TRACE)]
fn stop(&self) -> Result<()> {
self.handle.lock().map_or_else(
|_| todo!(),
|mut handle| {
handle.take();
Ok(())
},
)
}
}
/// Declares the queryable on `session` and serves incoming queries in a loop.
///
/// Every received query is wrapped in a [`QueryMsg`] and passed, together with
/// a clone of `ctx`, to `callback`. A failing callback is logged and does not
/// end the loop.
///
/// # Errors
/// Returns an error if declaring the queryable fails or if receiving the next
/// query fails.
#[instrument(name="queryable", level = Level::ERROR, skip_all)]
async fn run_queryable<P>(
    session: Arc<Session>,
    selector: String,
    callback: ArcGetCallback<P>,
    completeness: bool,
    #[cfg(feature = "unstable")] allowed_origin: Locality,
    ctx: Context<P>,
) -> Result<()>
where
    P: Send + Sync + 'static,
{
    let declaration = session
        .declare_queryable(&selector)
        .complete(completeness);
    #[cfg(feature = "unstable")]
    let declaration = declaration.allowed_origin(allowed_origin);
    let queryable = declaration.await?;

    loop {
        let request = QueryMsg(queryable.recv_async().await?);
        let query_ctx = ctx.clone();
        // The callback lock stays held while the callback's future is awaited.
        let mut guard = callback.lock().await;
        if let Err(error) = guard(query_ctx, request).await {
            error!("queryable callback failed with {error}");
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal property type used to instantiate the generic type under test.
    #[derive(Debug)]
    struct Props {}

    /// Compiles only for types that are `Sized + Send + Sync`.
    const fn is_normal<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    /// Compile-time check that `Queryable<Props>` is `Sized + Send + Sync`.
    #[test]
    const fn normal_types() {
        is_normal::<Queryable<Props>>();
    }
}