1#[doc(hidden)]
6extern crate alloc;
7
8#[cfg(feature = "std")]
9extern crate std;
10
11use crate::error::Error;
13use alloc::sync::Arc;
14use core::{fmt::Debug, time::Duration};
15use dimas_core::{
16 Result,
17 enums::OperationState,
18 message_types::{Message, QueryableMsg},
19 traits::{Capability, Context},
20};
21use futures::future::BoxFuture;
22#[cfg(feature = "std")]
23use std::{
24 boxed::Box,
25 string::{String, ToString},
26 vec::Vec,
27};
28#[cfg(feature = "std")]
29use tokio::sync::Mutex;
30use tracing::{Level, error, instrument, warn};
31#[cfg(feature = "unstable")]
32use zenoh::sample::Locality;
33use zenoh::{
34 Session, Wait,
35 query::{ConsolidationMode, QueryTarget},
36 sample::SampleKind,
37};
38pub type GetCallback<P> =
43 Box<dyn FnMut(Context<P>, QueryableMsg) -> BoxFuture<'static, Result<()>> + Send + Sync>;
44pub type ArcGetCallback<P> = Arc<Mutex<GetCallback<P>>>;
46pub struct Querier<P>
51where
52 P: Send + Sync + 'static,
53{
54 session: Arc<Session>,
56 selector: String,
57 context: Context<P>,
59 activation_state: OperationState,
60 callback: ArcGetCallback<P>,
61 mode: ConsolidationMode,
62 #[cfg(feature = "unstable")]
63 allowed_destination: Locality,
64 encoding: String,
65 target: QueryTarget,
66 timeout: Duration,
67 key_expr: std::sync::Mutex<Option<zenoh::key_expr::KeyExpr<'static>>>,
68}
69
70impl<P> Debug for Querier<P>
71where
72 P: Send + Sync + 'static,
73{
74 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75 #[cfg(feature = "unstable")]
76 let res = f
77 .debug_struct("Querier")
78 .field("selector", &self.selector)
79 .field("mode", &self.mode)
80 .field("allowed_destination", &self.allowed_destination)
81 .finish_non_exhaustive();
82 #[cfg(not(feature = "unstable"))]
83 let res = f
84 .debug_struct("Querier")
85 .field("selector", &self.selector)
86 .field("mode", &self.mode)
87 .finish_non_exhaustive();
88 res
89 }
90}
91
92impl<P> crate::traits::Querier for Querier<P>
93where
94 P: Send + Sync + 'static,
95{
96 fn selector(&self) -> &str {
98 &self.selector
99 }
100
101 #[instrument(name="Querier", level = Level::ERROR, skip_all)]
103 fn get(
104 &self,
105 message: Option<Message>,
106 mut callback: Option<&mut dyn FnMut(QueryableMsg) -> Result<()>>,
107 ) -> Result<()> {
108 let cb = self.callback.clone();
109 self.key_expr.lock().map_or_else(
110 |_| todo!(),
111 |key_expr| {
112 let key_expr = key_expr
113 .clone()
114 .ok_or_else(|| Error::InvalidSelector("querier".into()))?;
115
116 let builder = message
117 .map_or_else(
118 || self.session.get(&key_expr),
119 |msg| {
120 self.session
121 .get(&self.selector)
122 .payload(msg.value())
123 },
124 )
125 .encoding(self.encoding.as_str())
126 .target(self.target)
127 .consolidation(self.mode)
128 .timeout(self.timeout);
129
130 #[cfg(feature = "unstable")]
131 let builder = builder.allowed_destination(self.allowed_destination);
132
133 let query = builder
134 .wait()
135 .map_err(|source| Error::QueryCreation { source })?;
136
137 let mut unreached = true;
138 let mut retry_count = 0u8;
139
140 while unreached && retry_count <= 5 {
141 retry_count += 1;
142 while let Ok(reply) = query.recv() {
143 match reply.result() {
144 Ok(sample) => match sample.kind() {
145 SampleKind::Put => {
146 let content: Vec<u8> = sample.payload().to_bytes().into_owned();
147 let msg = QueryableMsg(content);
148 if callback.is_none() {
149 let cb = cb.clone();
150 let ctx = self.context.clone();
151 tokio::task::spawn(async move {
152 let mut lock = cb.lock().await;
153 if let Err(error) = lock(ctx, msg).await {
154 error!("querier callback failed with {error}");
155 }
156 });
157 } else {
158 let callback = callback.as_mut().ok_or_else(|| {
159 Error::AccessingQuerier {
160 selector: key_expr.to_string(),
161 }
162 })?;
163 callback(msg)
164 .map_err(|source| Error::QueryCallback { source })?;
165 }
166 }
167 SampleKind::Delete => {
168 error!("Delete in Querier");
169 }
170 },
171 Err(err) => error!("receive error: {:?})", err),
172 }
173 unreached = false;
174 }
175 if unreached {
176 if retry_count < 5 {
177 std::thread::sleep(self.timeout);
178 } else {
179 return Err(Error::AccessingQueryable {
180 selector: key_expr.to_string(),
181 }
182 .into());
183 }
184 }
185 }
186
187 Ok(())
188 },
189 )
190 }
191}
192
193impl<P> Capability for Querier<P>
194where
195 P: Send + Sync + 'static,
196{
197 fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
198 if state >= &self.activation_state {
199 return self.init();
200 } else if state < &self.activation_state {
201 return self.de_init();
202 }
203 Ok(())
204 }
205}
206
207impl<P> Querier<P>
208where
209 P: Send + Sync + 'static,
210{
211 #[must_use]
213 #[allow(clippy::too_many_arguments)]
214 pub fn new(
215 session: Arc<Session>,
216 selector: String,
217 context: Context<P>,
218 activation_state: OperationState,
219 response_callback: ArcGetCallback<P>,
220 mode: ConsolidationMode,
221 #[cfg(feature = "unstable")] allowed_destination: Locality,
222 encoding: String,
223 target: QueryTarget,
224 timeout: Duration,
225 ) -> Self {
226 Self {
227 session,
228 selector,
229 context,
230 activation_state,
231 callback: response_callback,
232 mode,
233 #[cfg(feature = "unstable")]
234 allowed_destination,
235 encoding,
236 target,
237 timeout,
238 key_expr: std::sync::Mutex::new(None),
239 }
240 }
241
242 fn init(&self) -> Result<()>
245 where
246 P: Send + Sync + 'static,
247 {
248 self.de_init()?;
249
250 self.key_expr.lock().map_or_else(
251 |_| todo!(),
252 |mut key_expr| {
253 let new_key_expr = self
254 .session
255 .declare_keyexpr(self.selector.clone())
256 .wait()?;
257 key_expr.replace(new_key_expr);
258 Ok(())
259 },
260 )
261 }
262
263 #[allow(clippy::unnecessary_wraps)]
266 fn de_init(&self) -> Result<()>
267 where
268 P: Send + Sync + 'static,
269 {
270 self.key_expr.lock().map_or_else(
271 |_| todo!(),
272 |mut key_expr| {
273 key_expr.take();
274 Ok(())
275 },
276 )
277 }
278}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal properties type used as the querier's type parameter.
    #[derive(Debug)]
    struct Props {}

    /// Compiles only when `T` is `Sized`, `Send` and `Sync`.
    const fn require_send_sync<T: Sized + Send + Sync>() {}

    /// Compile-time check that `Querier` is `Send + Sync`.
    #[test]
    const fn normal_types() {
        require_send_sync::<Querier<Props>>();
    }
}