use super::*;
use ::zenoh::query::{ConsolidationMode, Parameters, Querier, QueryConsolidation, QueryTarget};
use bevy_ecs::prelude::{In, Res};
use futures_lite::future::race;
use std::time::Duration;
use thiserror::Error as ThisError;
use tokio::sync::mpsc::unbounded_channel;
// Configuration of a `zenoh_querier` diagram node.
//
// The fields are applied to the zenoh querier declared when the node is built
// (congestion control, priority, express, target, consolidation, locality,
// timeout) and to every query the node sends (parameters, encoder).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct ZenohQuerierConfig {
    // Key expression the querier is declared on.
    pub key: Arc<str>,
    // Codec used to encode each incoming request into the query payload.
    pub encoder: ZenohEncodingConfig,
    // Codec used to decode each reply sample into an output message.
    pub decoder: ZenohEncodingConfig,
    // Parameters attached to every query sent by the node.
    #[serde(default, skip_serializing_if = "is_default")]
    pub parameters: HashMap<String, String>,
    // Congestion control setting passed to the querier builder.
    #[serde(default, skip_serializing_if = "is_default")]
    pub congestion_control: ZenohCongestionControlConfig,
    // Priority passed to the querier builder.
    #[serde(default, skip_serializing_if = "is_default")]
    pub priority: ZenohPriorityConfig,
    // Express flag passed to the querier builder.
    #[serde(default, skip_serializing_if = "is_default")]
    pub express: bool,
    // Query target passed to the querier builder.
    #[serde(default, skip_serializing_if = "is_default")]
    pub target: ZenohQueryTargetConfig,
    // Reply consolidation mode passed to the querier builder.
    #[serde(default, skip_serializing_if = "is_default")]
    pub consolidation: ZenohQueryConsolidationModeConfig,
    // Allowed destination (locality) passed to the querier builder.
    #[serde(default, skip_serializing_if = "is_default")]
    pub locality: ZenohLocalityConfig,
    // Optional query timeout. When `None`, no timeout is set on the querier
    // builder, so the builder's own default is kept.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub timeout: Option<Duration>,
    // When the node waits for the querier to report a match before querying.
    #[serde(default, skip_serializing_if = "is_default")]
    pub wait_for_matching: WaitForMatching,
}
// Controls when a `zenoh_querier` node waits until its querier's matching
// status reports a match before sending queries.
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WaitForMatching {
    // Never wait for a match.
    Never,
    // Wait once, right after the querier has been declared.
    Once,
    // Wait before every query (the default).
    #[default]
    Always,
}
impl WaitForMatching {
    /// Returns `true` when a match should be awaited a single time, right
    /// after the querier is created.
    pub fn once(&self) -> bool {
        *self == Self::Once
    }

    /// Returns `true` when a match should be awaited before every query.
    pub fn always(&self) -> bool {
        *self == Self::Always
    }
}
// Serializable mirror of zenoh's `ConsolidationMode`, converted into a
// `QueryConsolidation` through its `From` impl. Defaults to `Auto`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ZenohQueryConsolidationModeConfig {
    // Maps to `ConsolidationMode::Auto`.
    #[default]
    Auto,
    // Maps to `ConsolidationMode::None`.
    None,
    // Maps to `ConsolidationMode::Monotonic`.
    Monotonic,
    // Maps to `ConsolidationMode::Latest`.
    Latest,
}
impl From<ZenohQueryConsolidationModeConfig> for QueryConsolidation {
fn from(value: ZenohQueryConsolidationModeConfig) -> Self {
let mode = match value {
ZenohQueryConsolidationModeConfig::Auto => ConsolidationMode::Auto,
ZenohQueryConsolidationModeConfig::None => ConsolidationMode::None,
ZenohQueryConsolidationModeConfig::Monotonic => ConsolidationMode::Monotonic,
ZenohQueryConsolidationModeConfig::Latest => ConsolidationMode::Latest,
};
mode.into()
}
}
// Serializable mirror of zenoh's `QueryTarget`, converted through its `From`
// impl. Defaults to `BestMatching`.
#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ZenohQueryTargetConfig {
    // Maps to `QueryTarget::BestMatching`.
    #[default]
    BestMatching,
    // Maps to `QueryTarget::All`.
    All,
    // Maps to `QueryTarget::AllComplete`.
    AllComplete,
}
impl From<ZenohQueryTargetConfig> for QueryTarget {
fn from(value: ZenohQueryTargetConfig) -> Self {
match value {
ZenohQueryTargetConfig::BestMatching => QueryTarget::BestMatching,
ZenohQueryTargetConfig::All => QueryTarget::All,
ZenohQueryTargetConfig::AllComplete => QueryTarget::AllComplete,
}
}
}
// Errors produced while creating a zenoh querier or while running a query
// through a `zenoh_querier` node.
#[derive(ThisError, Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum ZenohQuerierError {
    // Awaiting the zenoh session outcome failed, so no session was available
    // to declare the querier on.
    #[error("the zenoh session was removed from its resource")]
    SessionRemoved,
    // The configured encoder could not encode the incoming request; carries
    // the encoder's error message.
    #[error("error while encoding message: {}", .0)]
    EncodingError(String),
    // The shared querier-creation outcome could not be obtained by a request
    // (presumably the creation request was cancelled or dropped — confirm).
    #[error("Failed to create the querier: {}", .0)]
    CreationFailed(ArcError),
    // An error reported by zenoh itself. `#[from]` lets `?` convert an
    // `ArcError` into this variant.
    #[error("{}", .0)]
    ZenohError(#[from] ArcError),
}
impl DiagramElementRegistry {
    /// Registers the `zenoh_querier` node builder with this registry.
    ///
    /// Building a node queues `ensure_session` so a zenoh session gets set up.
    /// It then asynchronously declares one querier from the node's
    /// [`ZenohQuerierConfig`], and that querier is shared by every request the
    /// node receives. For each request the node:
    ///
    /// 1. encodes the request with the configured encoder,
    /// 2. waits for the querier to be created (and, with
    ///    [`WaitForMatching::Always`], for the querier to report a match),
    /// 3. sends a query carrying the configured parameters and the encoded payload,
    /// 4. decodes every reply onto the `out` stream, and reports reply errors
    ///    and decoding failures on the `out_error` stream instead.
    ///
    /// The node responds with a default [`JsonMessage`] once the reply channel
    /// ends, or with a [`ZenohQuerierError`] if the query could not be issued.
    /// The whole operation races against a cancellation signal delivered
    /// through the sender published on the `canceller` stream.
    ///
    /// Building a node fails if either codec in the config cannot be constructed.
    pub(super) fn register_zenoh_querier(&mut self, ensure_session: EnsureZenohSession) {
        // Declares the querier once the session outcome resolves. With
        // `WaitForMatching::Once`, it also waits for a match before handing the
        // querier out.
        let create_querier = |In(config): In<ZenohQuerierConfig>, session: Res<ZenohSession>| {
            let session_outcome = session.outcome.clone();
            async move {
                let session = session_outcome
                    .await
                    .map_err(|_| ZenohQuerierError::SessionRemoved)?
                    .map_err(ZenohQuerierError::ZenohError)?;

                let querier = session
                    .declare_querier(config.key.to_string())
                    .congestion_control(config.congestion_control.into())
                    .priority(config.priority.into())
                    .express(config.express)
                    .target(config.target.into())
                    .consolidation(config.consolidation)
                    .allowed_destination(config.locality.into());
                // Only override the timeout when one is configured; otherwise the
                // builder's own default is left in place.
                let querier = match config.timeout {
                    Some(timeout) => querier.timeout(timeout),
                    None => querier,
                };
                let querier = querier.await.map_err(ArcError::new)?;

                if config.wait_for_matching.once() {
                    wait_for_matching(&querier).await?;
                }

                Ok::<_, ZenohQuerierError>(Arc::new(querier))
            }
        };
        let create_querier = create_querier.into_async_callback();

        self.register_node_builder_fallible(
            NodeBuilderOptions::new("zenoh_querier").with_default_display_text("Zenoh Querier"),
            move |builder, mut config: ZenohQuerierConfig| {
                builder.commands().queue(ensure_session.clone());
                let encoder: Codec = (&config.encoder).try_into()?;
                let decoder: Codec = (&config.decoder).try_into()?;
                // Move the parameters out of the config (it is moved into the
                // creation request below) and share them across all queries.
                let parameters = std::mem::take(&mut config.parameters);
                let parameters: Arc<Parameters> = Arc::new(parameters.into());
                let wait_choice = config.wait_for_matching;

                // The querier is created once per node. The shared future lets
                // every request await the same creation outcome.
                let querier = builder
                    .commands()
                    .request(config, create_querier.clone())
                    .outcome()
                    .shared();

                let node =
                    builder.create_map(move |input: AsyncMap<JsonMessage, ZenohNodeStreams>| {
                        let querier = querier.clone();
                        let parameters = Arc::clone(&parameters);
                        let encoder = encoder.clone();
                        let decoder = decoder.clone();
                        let (sender, mut cancellation_receiver) = unbounded_channel();
                        input.streams.canceller.send(sender);
                        async move {
                            let querying = async move {
                                let payload = encoder
                                    .encode(&input.request)
                                    .map_err(ZenohQuerierError::EncodingError)?;
                                // Outer error: no creation outcome was delivered
                                // (presumably cancelled). Inner error: the
                                // `ZenohQuerierError` returned by `create_querier`.
                                let querier = querier.await.map_err(|err| {
                                    ZenohQuerierError::CreationFailed(ArcError(err.into()))
                                })??;

                                if wait_choice.always() {
                                    wait_for_matching(&querier).await?;
                                }

                                let replies = querier
                                    .get()
                                    .parameters(parameters.as_ref().clone())
                                    .encoding(encoder.encoding())
                                    .payload(payload)
                                    .await
                                    .map_err(ArcError::new)?;

                                // Drain replies until the reply channel closes. Per-reply
                                // failures are reported but do not abort the query.
                                while let Ok(reply) = replies.recv_async().await {
                                    let next_sample = match reply.result() {
                                        Ok(sample) => sample,
                                        Err(err) => {
                                            input.streams.out_error.send(format!("{err}"));
                                            continue;
                                        }
                                    };
                                    match decoder.decode(next_sample) {
                                        Ok(msg) => {
                                            input.streams.out.send(msg);
                                        }
                                        Err(msg) => {
                                            input.streams.out_error.send(msg);
                                        }
                                    }
                                }

                                Ok::<_, ZenohQuerierError>(JsonMessage::default())
                            };
                            let cancel = cancellation_receiver.recv();
                            race(querying, receive_cancel(cancel)).await
                        }
                    });
                Ok(node)
            },
        )
        .with_result();
    }
}
/// Waits until `querier` reports a matching status of `true`.
///
/// This declares a matching listener on the querier and consumes its status
/// updates until one reports `matching() == true`. There is no timeout. If no
/// match ever arrives, this keeps waiting until the listener's channel fails.
///
/// # Errors
///
/// Returns [`ZenohQuerierError::ZenohError`] if the matching listener cannot
/// be declared or if receiving a status update fails.
async fn wait_for_matching(querier: &Querier<'_>) -> Result<(), ZenohQuerierError> {
    let matching_listener = querier.matching_listener().await.map_err(ArcError::new)?;
    loop {
        let status = matching_listener
            .recv_async()
            .await
            .map_err(ArcError::new)?;
        if status.matching() {
            break Ok(());
        }
    }
}