dimas_com/zenoh/
queryable.rs

1// Copyright © 2023 Stephan Kunz
2
3//! Module `queryable` provides an information/compute provider `Queryable` which can be created using the `QueryableBuilder`.
4
5#[doc(hidden)]
6extern crate alloc;
7
8#[cfg(feature = "std")]
9extern crate std;
10
11// region:		--- modules
12use alloc::sync::Arc;
13use alloc::{boxed::Box, string::String};
14use core::fmt::Debug;
15use dimas_core::{
16	Result,
17	enums::{OperationState, TaskSignal},
18	message_types::QueryMsg,
19	traits::{Capability, Context},
20};
21use futures::future::BoxFuture;
22#[cfg(feature = "std")]
23use tokio::{sync::Mutex, task::JoinHandle};
24use tracing::{Level, error, info, instrument, warn};
25use zenoh::Session;
26#[cfg(feature = "unstable")]
27use zenoh::sample::Locality;
28// endregion:	--- modules
29
30// region:    	--- types
31/// type defnition for a queryables `request` callback
32pub type GetCallback<P> =
33	Box<dyn FnMut(Context<P>, QueryMsg) -> BoxFuture<'static, Result<()>> + Send + Sync>;
34/// type defnition for a queryables atomic reference counted `request` callback
35pub type ArcGetCallback<P> = Arc<Mutex<GetCallback<P>>>;
36// endregion: 	--- types
37
38// region:		--- Queryable
39/// Queryable
40pub struct Queryable<P>
41where
42	P: Send + Sync + 'static,
43{
44	/// the zenoh session this queryable belongs to
45	session: Arc<Session>,
46	selector: String,
47	/// Context for the Subscriber
48	context: Context<P>,
49	activation_state: OperationState,
50	callback: ArcGetCallback<P>,
51	completeness: bool,
52	#[cfg(feature = "unstable")]
53	allowed_origin: Locality,
54	handle: std::sync::Mutex<Option<JoinHandle<()>>>,
55}
56
57impl<P> Debug for Queryable<P>
58where
59	P: Send + Sync + 'static,
60{
61	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
62		f.debug_struct("Queryable")
63			.field("selector", &self.selector)
64			.field("complete", &self.completeness)
65			.finish_non_exhaustive()
66	}
67}
68
69impl<P> crate::traits::Responder for Queryable<P>
70where
71	P: Send + Sync + 'static,
72{
73	/// Get `selector`
74	fn selector(&self) -> &str {
75		&self.selector
76	}
77}
78
79impl<P> Capability for Queryable<P>
80where
81	P: Send + Sync + 'static,
82{
83	fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
84		if state >= &self.activation_state {
85			self.start()
86		} else if state < &self.activation_state {
87			self.stop()
88		} else {
89			Ok(())
90		}
91	}
92}
93
94impl<P> Queryable<P>
95where
96	P: Send + Sync + 'static,
97{
98	/// Constructor for a [`Queryable`]
99	#[must_use]
100	pub fn new(
101		session: Arc<Session>,
102		selector: String,
103		context: Context<P>,
104		activation_state: OperationState,
105		request_callback: ArcGetCallback<P>,
106		completeness: bool,
107		#[cfg(feature = "unstable")] allowed_origin: Locality,
108	) -> Self {
109		Self {
110			session,
111			selector,
112			context,
113			activation_state,
114			callback: request_callback,
115			completeness,
116			#[cfg(feature = "unstable")]
117			allowed_origin,
118			handle: std::sync::Mutex::new(None),
119		}
120	}
121
122	/// Start or restart the queryable.
123	/// An already running queryable will be stopped.
124	#[instrument(level = Level::TRACE, skip_all)]
125	fn start(&self) -> Result<()> {
126		self.stop()?;
127
128		let completeness = self.completeness;
129		#[cfg(feature = "unstable")]
130		let allowed_origin = self.allowed_origin;
131		let selector = self.selector.clone();
132		let cb = self.callback.clone();
133		let ctx1 = self.context.clone();
134		let ctx2 = self.context.clone();
135		let session = self.session.clone();
136
137		self.handle.lock().map_or_else(
138			|_| todo!(),
139			|mut handle| {
140				handle.replace(tokio::task::spawn(async move {
141					let key = selector.clone();
142					std::panic::set_hook(Box::new(move |reason| {
143						error!("queryable panic: {}", reason);
144						if let Err(reason) = ctx1
145							.sender()
146							.blocking_send(TaskSignal::RestartQueryable(key.clone()))
147						{
148							error!("could not restart queryable: {}", reason);
149						} else {
150							info!("restarting queryable!");
151						}
152					}));
153					if let Err(error) = run_queryable(
154						session,
155						selector,
156						cb,
157						completeness,
158						#[cfg(feature = "unstable")]
159						allowed_origin,
160						ctx2,
161					)
162					.await
163					{
164						error!("queryable failed with {error}");
165					}
166				}));
167				Ok(())
168			},
169		)
170	}
171
172	/// Stop a running Queryable
173	#[instrument(level = Level::TRACE)]
174	fn stop(&self) -> Result<()> {
175		self.handle.lock().map_or_else(
176			|_| todo!(),
177			|mut handle| {
178				handle.take();
179				Ok(())
180			},
181		)
182	}
183}
184
185#[instrument(name="queryable", level = Level::ERROR, skip_all)]
186async fn run_queryable<P>(
187	session: Arc<Session>,
188	selector: String,
189	callback: ArcGetCallback<P>,
190	completeness: bool,
191	#[cfg(feature = "unstable")] allowed_origin: Locality,
192	ctx: Context<P>,
193) -> Result<()>
194where
195	P: Send + Sync + 'static,
196{
197	let builder = session
198		.declare_queryable(&selector)
199		.complete(completeness);
200	#[cfg(feature = "unstable")]
201	let builder = builder.allowed_origin(allowed_origin);
202
203	let queryable = builder.await?;
204
205	loop {
206		let query = queryable.recv_async().await?;
207		let request = QueryMsg(query);
208
209		let ctx = ctx.clone();
210		let mut lock = callback.lock().await;
211		if let Err(error) = lock(ctx, request).await {
212			error!("queryable callback failed with {error}");
213		}
214	}
215}
216// endregion:	--- Queryable
217
#[cfg(test)]
mod tests {
	use super::*;

	/// Placeholder properties type, used only to instantiate the generic `Queryable`.
	#[derive(Debug)]
	struct Props {}

	/// Compiles only for types that are `Sized + Send + Sync`.
	const fn assert_auto_traits<T>()
	where
		T: Sized + Send + Sync,
	{
	}

	/// Checks at compile time that `Queryable` keeps its auto traits.
	#[test]
	const fn normal_types() {
		assert_auto_traits::<Queryable<Props>>();
	}
}