dimas_com/zenoh/
queryable.rs1#[doc(hidden)]
6extern crate alloc;
7
8#[cfg(feature = "std")]
9extern crate std;
10
11use alloc::sync::Arc;
13use alloc::{boxed::Box, string::String};
14use core::fmt::Debug;
15use dimas_core::{
16 Result,
17 enums::{OperationState, TaskSignal},
18 message_types::QueryMsg,
19 traits::{Capability, Context},
20};
21use futures::future::BoxFuture;
22#[cfg(feature = "std")]
23use tokio::{sync::Mutex, task::JoinHandle};
24use tracing::{Level, error, info, instrument, warn};
25use zenoh::Session;
26#[cfg(feature = "unstable")]
27use zenoh::sample::Locality;
/// Boxed, type-erased handler that a [`Queryable`] invokes for each received query.
///
/// The handler is called with a [`Context`] and the incoming [`QueryMsg`] and
/// returns a boxed `'static` future resolving to `Result<()>`. It is an `FnMut`,
/// so it may mutate captured state between calls, and it must be `Send + Sync`.
pub type GetCallback<P> =
    Box<dyn FnMut(Context<P>, QueryMsg) -> BoxFuture<'static, Result<()>> + Send + Sync>;
/// A [`GetCallback`] wrapped in an [`Arc`] and an async (`tokio`) [`Mutex`].
///
/// The `Arc` lets the callback be cloned into the spawned queryable task; the
/// async mutex gives exclusive access to the `FnMut` and may be held across
/// `.await` points while the callback's future runs.
pub type ArcGetCallback<P> = Arc<Mutex<GetCallback<P>>>;
36pub struct Queryable<P>
41where
42 P: Send + Sync + 'static,
43{
44 session: Arc<Session>,
46 selector: String,
47 context: Context<P>,
49 activation_state: OperationState,
50 callback: ArcGetCallback<P>,
51 completeness: bool,
52 #[cfg(feature = "unstable")]
53 allowed_origin: Locality,
54 handle: std::sync::Mutex<Option<JoinHandle<()>>>,
55}
56
57impl<P> Debug for Queryable<P>
58where
59 P: Send + Sync + 'static,
60{
61 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
62 f.debug_struct("Queryable")
63 .field("selector", &self.selector)
64 .field("complete", &self.completeness)
65 .finish_non_exhaustive()
66 }
67}
68
69impl<P> crate::traits::Responder for Queryable<P>
70where
71 P: Send + Sync + 'static,
72{
73 fn selector(&self) -> &str {
75 &self.selector
76 }
77}
78
79impl<P> Capability for Queryable<P>
80where
81 P: Send + Sync + 'static,
82{
83 fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
84 if state >= &self.activation_state {
85 self.start()
86 } else if state < &self.activation_state {
87 self.stop()
88 } else {
89 Ok(())
90 }
91 }
92}
93
94impl<P> Queryable<P>
95where
96 P: Send + Sync + 'static,
97{
98 #[must_use]
100 pub fn new(
101 session: Arc<Session>,
102 selector: String,
103 context: Context<P>,
104 activation_state: OperationState,
105 request_callback: ArcGetCallback<P>,
106 completeness: bool,
107 #[cfg(feature = "unstable")] allowed_origin: Locality,
108 ) -> Self {
109 Self {
110 session,
111 selector,
112 context,
113 activation_state,
114 callback: request_callback,
115 completeness,
116 #[cfg(feature = "unstable")]
117 allowed_origin,
118 handle: std::sync::Mutex::new(None),
119 }
120 }
121
122 #[instrument(level = Level::TRACE, skip_all)]
125 fn start(&self) -> Result<()> {
126 self.stop()?;
127
128 let completeness = self.completeness;
129 #[cfg(feature = "unstable")]
130 let allowed_origin = self.allowed_origin;
131 let selector = self.selector.clone();
132 let cb = self.callback.clone();
133 let ctx1 = self.context.clone();
134 let ctx2 = self.context.clone();
135 let session = self.session.clone();
136
137 self.handle.lock().map_or_else(
138 |_| todo!(),
139 |mut handle| {
140 handle.replace(tokio::task::spawn(async move {
141 let key = selector.clone();
142 std::panic::set_hook(Box::new(move |reason| {
143 error!("queryable panic: {}", reason);
144 if let Err(reason) = ctx1
145 .sender()
146 .blocking_send(TaskSignal::RestartQueryable(key.clone()))
147 {
148 error!("could not restart queryable: {}", reason);
149 } else {
150 info!("restarting queryable!");
151 }
152 }));
153 if let Err(error) = run_queryable(
154 session,
155 selector,
156 cb,
157 completeness,
158 #[cfg(feature = "unstable")]
159 allowed_origin,
160 ctx2,
161 )
162 .await
163 {
164 error!("queryable failed with {error}");
165 }
166 }));
167 Ok(())
168 },
169 )
170 }
171
172 #[instrument(level = Level::TRACE)]
174 fn stop(&self) -> Result<()> {
175 self.handle.lock().map_or_else(
176 |_| todo!(),
177 |mut handle| {
178 handle.take();
179 Ok(())
180 },
181 )
182 }
183}
184
185#[instrument(name="queryable", level = Level::ERROR, skip_all)]
186async fn run_queryable<P>(
187 session: Arc<Session>,
188 selector: String,
189 callback: ArcGetCallback<P>,
190 completeness: bool,
191 #[cfg(feature = "unstable")] allowed_origin: Locality,
192 ctx: Context<P>,
193) -> Result<()>
194where
195 P: Send + Sync + 'static,
196{
197 let builder = session
198 .declare_queryable(&selector)
199 .complete(completeness);
200 #[cfg(feature = "unstable")]
201 let builder = builder.allowed_origin(allowed_origin);
202
203 let queryable = builder.await?;
204
205 loop {
206 let query = queryable.recv_async().await?;
207 let request = QueryMsg(query);
208
209 let ctx = ctx.clone();
210 let mut lock = callback.lock().await;
211 if let Err(error) = lock(ctx, request).await {
212 error!("queryable callback failed with {error}");
213 }
214 }
215}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal property type used to instantiate the generic `Queryable`.
    #[derive(Debug)]
    struct Props {}

    /// Compiles only when `T` is a sized type that is both `Send` and `Sync`.
    const fn assert_normal<T: Send + Sync>() {}

    /// Compile-time check that `Queryable` is `Send + Sync`.
    #[test]
    const fn normal_types() {
        assert_normal::<Queryable<Props>>();
    }
}
232}