#[doc(hidden)]
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
use crate::error::Error;
use alloc::sync::Arc;
use core::{fmt::Debug, time::Duration};
use dimas_core::{
Result,
enums::OperationState,
message_types::{Message, QueryableMsg},
traits::{Capability, Context},
};
use futures::future::BoxFuture;
#[cfg(feature = "std")]
use std::{
boxed::Box,
string::{String, ToString},
vec::Vec,
};
#[cfg(feature = "std")]
use tokio::sync::Mutex;
use tracing::{Level, error, instrument, warn};
#[cfg(feature = "unstable")]
use zenoh::sample::Locality;
use zenoh::{
Session, Wait,
query::{ConsolidationMode, QueryTarget},
sample::SampleKind,
};
/// Boxed, sendable callback that receives a [`Context`] and a [`QueryableMsg`]
/// and returns a boxed future resolving to `Result<()>`.
pub type GetCallback<P> =
Box<dyn FnMut(Context<P>, QueryableMsg) -> BoxFuture<'static, Result<()>> + Send + Sync>;
/// A [`GetCallback`] wrapped in an [`Arc`] and a tokio [`Mutex`], so it can be
/// shared between tasks and locked with `.lock().await`.
pub type ArcGetCallback<P> = Arc<Mutex<GetCallback<P>>>;
/// Querier that issues queries through a zenoh [`Session`] and forwards the
/// replies to a callback.
pub struct Querier<P>
where
P: Send + Sync + 'static,
{
/// Session used to declare the key expression and to issue queries.
session: Arc<Session>,
/// Selector string; returned by `selector()` and declared as key expression in `init()`.
selector: String,
/// Context that is cloned and handed to the stored callback.
context: Context<P>,
/// State that `manage_operation_state()` compares the requested state against.
activation_state: OperationState,
/// Stored callback, used by `get()` when no per-call callback is supplied.
callback: ArcGetCallback<P>,
/// Consolidation mode applied to every query.
mode: ConsolidationMode,
/// Allowed destination applied to every query (only with feature `unstable`).
#[cfg(feature = "unstable")]
allowed_destination: Locality,
/// Encoding applied to every query.
encoding: String,
/// Query target applied to every query.
target: QueryTarget,
/// Timeout applied to every query; also used as the pause between retry passes in `get()`.
timeout: Duration,
/// Declared key expression: set by `init()`, cleared by `de_init()`.
key_expr: std::sync::Mutex<Option<zenoh::key_expr::KeyExpr<'static>>>,
}
impl<P> Debug for Querier<P>
where
    P: Send + Sync + 'static,
{
    /// Writes a non-exhaustive debug representation containing the selector,
    /// the consolidation mode and, with feature `unstable`, the allowed destination.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("Querier");
        out.field("selector", &self.selector);
        out.field("mode", &self.mode);
        // Only present in the struct when the `unstable` feature is enabled.
        #[cfg(feature = "unstable")]
        out.field("allowed_destination", &self.allowed_destination);
        out.finish_non_exhaustive()
    }
}
impl<P> crate::traits::Querier for Querier<P>
where
P: Send + Sync + 'static,
{
fn selector(&self) -> &str {
&self.selector
}
#[instrument(name="Querier", level = Level::ERROR, skip_all)]
fn get(
&self,
message: Option<Message>,
mut callback: Option<&mut dyn FnMut(QueryableMsg) -> Result<()>>,
) -> Result<()> {
let cb = self.callback.clone();
self.key_expr.lock().map_or_else(
|_| todo!(),
|key_expr| {
let key_expr = key_expr
.clone()
.ok_or_else(|| Error::InvalidSelector("querier".into()))?;
let builder = message
.map_or_else(
|| self.session.get(&key_expr),
|msg| {
self.session
.get(&self.selector)
.payload(msg.value())
},
)
.encoding(self.encoding.as_str())
.target(self.target)
.consolidation(self.mode)
.timeout(self.timeout);
#[cfg(feature = "unstable")]
let builder = builder.allowed_destination(self.allowed_destination);
let query = builder
.wait()
.map_err(|source| Error::QueryCreation { source })?;
let mut unreached = true;
let mut retry_count = 0u8;
while unreached && retry_count <= 5 {
retry_count += 1;
while let Ok(reply) = query.recv() {
match reply.result() {
Ok(sample) => match sample.kind() {
SampleKind::Put => {
let content: Vec<u8> = sample.payload().to_bytes().into_owned();
let msg = QueryableMsg(content);
if callback.is_none() {
let cb = cb.clone();
let ctx = self.context.clone();
tokio::task::spawn(async move {
let mut lock = cb.lock().await;
if let Err(error) = lock(ctx, msg).await {
error!("querier callback failed with {error}");
}
});
} else {
let callback = callback.as_mut().ok_or_else(|| {
Error::AccessingQuerier {
selector: key_expr.to_string(),
}
})?;
callback(msg)
.map_err(|source| Error::QueryCallback { source })?;
}
}
SampleKind::Delete => {
error!("Delete in Querier");
}
},
Err(err) => error!("receive error: {:?})", err),
}
unreached = false;
}
if unreached {
if retry_count < 5 {
std::thread::sleep(self.timeout);
} else {
return Err(Error::AccessingQueryable {
selector: key_expr.to_string(),
}
.into());
}
}
}
Ok(())
},
)
}
}
impl<P> Capability for Querier<P>
where
    P: Send + Sync + 'static,
{
    /// Initializes the querier when `state` is at or above its activation
    /// state and de-initializes it when `state` is below.
    fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
        let threshold = &self.activation_state;
        if state >= threshold {
            self.init()
        } else if state < threshold {
            self.de_init()
        } else {
            // Only reached if the two states do not compare at all.
            Ok(())
        }
    }
}
impl<P> Querier<P>
where
    P: Send + Sync + 'static,
{
    /// Creates a [`Querier`] from its parts.
    ///
    /// No key expression is declared yet; that happens in `init()`, which is
    /// reached through `manage_operation_state()`.
    #[must_use]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        session: Arc<Session>,
        selector: String,
        context: Context<P>,
        activation_state: OperationState,
        response_callback: ArcGetCallback<P>,
        mode: ConsolidationMode,
        #[cfg(feature = "unstable")] allowed_destination: Locality,
        encoding: String,
        target: QueryTarget,
        timeout: Duration,
    ) -> Self {
        Self {
            session,
            selector,
            context,
            activation_state,
            callback: response_callback,
            mode,
            #[cfg(feature = "unstable")]
            allowed_destination,
            encoding,
            target,
            timeout,
            key_expr: std::sync::Mutex::new(None),
        }
    }

    /// Declares the selector as key expression on the session and stores it.
    ///
    /// Any previously stored key expression is dropped first, so a failed
    /// declaration leaves the querier without a key expression.
    ///
    /// # Errors
    /// Returns the error of the key expression declaration.
    fn init(&self) -> Result<()> {
        self.de_init()?;
        // A poisoned lock is recovered instead of panicking: the slot is only
        // ever written with a complete `Option`, so its content stays consistent.
        let mut slot = self
            .key_expr
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let declared = self
            .session
            .declare_keyexpr(self.selector.clone())
            .wait()?;
        *slot = Some(declared);
        Ok(())
    }

    /// Drops the stored key expression, if any.
    // Always `Ok`; the `Result` keeps the signature parallel to `init()` so
    // `manage_operation_state()` can return either call directly.
    #[allow(clippy::unnecessary_wraps)]
    fn de_init(&self) -> Result<()> {
        let mut slot = self
            .key_expr
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *slot = None;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty property type used to instantiate the generic querier.
    #[derive(Debug)]
    struct Props {}

    /// Compiles only for types that are `Sized + Send + Sync`.
    const fn is_normal<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    /// Compile-time check that `Querier<Props>` satisfies the bounds above.
    #[test]
    const fn normal_types() {
        is_normal::<Querier<Props>>();
    }
}