dimas_com/zenoh/
observable.rs

1// Copyright © 2024 Stephan Kunz
2
3extern crate alloc;
4
5#[cfg(feature = "std")]
6extern crate std;
7
8// region:		--- modules
9use alloc::{
10	boxed::Box,
11	string::{String, ToString},
12	sync::Arc,
13	vec::Vec,
14};
15use bitcode::encode;
16use core::time::Duration;
17use dimas_core::{
18	Result,
19	enums::{OperationState, TaskSignal},
20	message_types::{ControlResponse, Message, ObservableResponse},
21	traits::{Capability, Context},
22	utils::feedback_selector_from,
23};
24use futures::future::BoxFuture;
25#[cfg(feature = "std")]
26use tokio::{sync::Mutex, task::JoinHandle};
27use tracing::{Level, error, info, instrument, warn};
28use zenoh::Wait;
29#[cfg(feature = "unstable")]
30use zenoh::sample::Locality;
31use zenoh::{
32	Session,
33	qos::{CongestionControl, Priority},
34};
35// endregion:	--- modules
36
37// region:    	--- types
38/// Type definition for an observables `control` callback
39pub type ControlCallback<P> = Box<
40	dyn FnMut(Context<P>, Message) -> BoxFuture<'static, Result<ControlResponse>> + Send + Sync,
41>;
42/// Type definition for an observables atomic reference counted `control` callback
43pub type ArcControlCallback<P> = Arc<Mutex<ControlCallback<P>>>;
44/// Type definition for an observables `feedback` callback
45pub type FeedbackCallback<P> =
46	Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<Message>> + Send + Sync>;
47/// Type definition for an observables atomic reference counted `feedback` callback
48pub type ArcFeedbackCallback<P> = Arc<Mutex<FeedbackCallback<P>>>;
49/// Type definition for an observables atomic reference counted `execution` callback
50pub type ExecutionCallback<P> =
51	Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<Message>> + Send + Sync>;
52/// Type definition for an observables atomic reference counted `execution` callback
53pub type ArcExecutionCallback<P> = Arc<Mutex<ExecutionCallback<P>>>;
54// endregion: 	--- types
55
56// region:		--- Observable
/// Observable
///
/// Serves `request` and `cancel` queries on its selector, runs an execution
/// function for accepted requests and publishes feedback, cancelation and
/// result messages for it.
// NOTE: field order is also the drop order of the fields.
pub struct Observable<P>
where
	P: Send + Sync + 'static,
{
	/// the zenoh session this observable belongs to
	session: Arc<Session>,
	/// The observable's key expression
	selector: String,
	/// Context for the Observable
	context: Context<P>,
	/// Operation state from which on the observable is started (it is stopped below it)
	activation_state: OperationState,
	/// Interval between feedback publications while an execution is running
	feedback_interval: Duration,
	/// callback deciding whether an observation request is accepted
	control_callback: ArcControlCallback<P>,
	/// callback for observation feedback
	feedback_callback: ArcFeedbackCallback<P>,
	/// publisher for feedback, cancelation and result messages;
	/// declared when a request is accepted and taken on finish or cancelation
	feedback_publisher: Arc<Mutex<Option<zenoh::pubsub::Publisher<'static>>>>,
	/// function for observation execution
	execution_function: ArcExecutionCallback<P>,
	/// handle of the spawned execution task, if one was started
	execution_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
	/// handle of the control loop task; locked synchronously by `start`/`stop`
	handle: std::sync::Mutex<Option<JoinHandle<()>>>,
}
80
81impl<P> core::fmt::Debug for Observable<P>
82where
83	P: Send + Sync + 'static,
84{
85	fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
86		f.debug_struct("Observable")
87			.finish_non_exhaustive()
88	}
89}
90
91impl<P> crate::traits::Responder for Observable<P>
92where
93	P: Send + Sync + 'static,
94{
95	/// Get `selector`
96	fn selector(&self) -> &str {
97		&self.selector
98	}
99}
100
101impl<P> Capability for Observable<P>
102where
103	P: Send + Sync + 'static,
104{
105	fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
106		if state >= &self.activation_state {
107			self.start()
108		} else if state < &self.activation_state {
109			self.stop()
110		} else {
111			Ok(())
112		}
113	}
114}
115
116impl<P> Observable<P>
117where
118	P: Send + Sync + 'static,
119{
120	/// Constructor for an [`Observable`]
121	#[allow(clippy::too_many_arguments)]
122	#[must_use]
123	pub fn new(
124		session: Arc<Session>,
125		selector: String,
126		context: Context<P>,
127		activation_state: OperationState,
128		feedback_interval: Duration,
129		control_callback: ArcControlCallback<P>,
130		feedback_callback: ArcFeedbackCallback<P>,
131		execution_function: ArcExecutionCallback<P>,
132	) -> Self {
133		Self {
134			session,
135			selector,
136			context,
137			activation_state,
138			feedback_interval,
139			control_callback,
140			feedback_callback,
141			feedback_publisher: Arc::new(Mutex::new(None)),
142			execution_function,
143			execution_handle: Arc::new(Mutex::new(None)),
144			handle: std::sync::Mutex::new(None),
145		}
146	}
147
148	/// Start or restart the Observable.
149	/// An already running Observable will be stopped.
150	#[instrument(level = Level::TRACE, skip_all)]
151	fn start(&self) -> Result<()> {
152		self.stop()?;
153
154		let selector = self.selector.clone();
155		let interval = self.feedback_interval;
156		let ccb = self.control_callback.clone();
157		let fcb = self.feedback_callback.clone();
158		let fcbp = self.feedback_publisher.clone();
159		let efc = self.execution_function.clone();
160		let efch = self.execution_handle.clone();
161		let ctx1 = self.context.clone();
162		let ctx2 = self.context.clone();
163		let session = self.session.clone();
164
165		self.handle.lock().map_or_else(
166			|_| todo!(),
167			|mut handle| {
168				handle.replace(tokio::task::spawn(async move {
169					let key = selector.clone();
170					std::panic::set_hook(Box::new(move |reason| {
171						error!("observable panic: {}", reason);
172						if let Err(reason) = ctx1
173							.sender()
174							.blocking_send(TaskSignal::RestartObservable(key.clone()))
175						{
176							error!("could not restart observable: {}", reason);
177						} else {
178							info!("restarting observable!");
179						}
180					}));
181					if let Err(error) =
182						run_observable(session, selector, interval, ccb, fcb, fcbp, efc, efch, ctx2)
183							.await
184					{
185						error!("observable failed with {error}");
186					}
187				}));
188
189				Ok(())
190			},
191		)
192	}
193
194	/// Stop a running Observable
195	#[instrument(level = Level::TRACE, skip_all)]
196	#[allow(clippy::significant_drop_in_scrutinee)]
197	fn stop(&self) -> Result<()> {
198		self.handle.lock().map_or_else(
199			|_| todo!(),
200			|mut handle| {
201				if let Some(handle) = handle.take() {
202					let feedback_publisher = self.feedback_publisher.clone();
203					let feedback_callback = self.feedback_callback.clone();
204					let execution_handle = self.execution_handle.clone();
205					let ctx = self.context.clone();
206					tokio::spawn(async move {
207						// stop execution if running
208						if let Some(execution_handle) = execution_handle.lock().await.take() {
209							execution_handle.abort();
210							// send back cancelation message
211							if let Some(publisher) = feedback_publisher.lock().await.take() {
212								let Ok(msg) = feedback_callback.lock().await(ctx).await else {
213									todo!()
214								};
215								let response = ObservableResponse::Canceled(msg.value().clone());
216								match publisher
217									.put(Message::encode(&response).value().clone())
218									.wait()
219								{
220									Ok(()) => {}
221									Err(err) => error!("could not send cancel state due to {err}"),
222								}
223							}
224						}
225						handle.abort();
226					});
227				}
228				Ok(())
229			},
230		)
231	}
232}
233// endregion:	--- Observable
234
235// region:		--- functions
236#[allow(clippy::significant_drop_tightening)]
237#[allow(clippy::too_many_arguments)]
238#[allow(clippy::too_many_lines)]
239#[instrument(name="observable", level = Level::ERROR, skip_all)]
240async fn run_observable<P>(
241	session: Arc<Session>,
242	selector: String,
243	feedback_interval: Duration,
244	control_callback: ArcControlCallback<P>,
245	feedback_callback: ArcFeedbackCallback<P>,
246	feedback_publisher: Arc<Mutex<Option<zenoh::pubsub::Publisher<'static>>>>,
247	execution_function: ArcExecutionCallback<P>,
248	execution_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
249	ctx: Context<P>,
250) -> Result<()>
251where
252	P: Send + Sync + 'static,
253{
254	// create the control queryable
255	let builder = session
256		.declare_queryable(&selector)
257		.complete(true);
258
259	#[cfg(feature = "unstable")]
260	let builder = builder.allowed_origin(Locality::Any);
261
262	let queryable = builder.await?;
263
264	// initialize a pinned feedback timer
265	// TODO: init here leads to on unnecessary timer-cycle without doing something
266	let feedback_timer = tokio::time::sleep(feedback_interval);
267	tokio::pin!(feedback_timer);
268
269	// base communication key & selector for feedback publisher
270	let key = selector.clone();
271	let publisher_selector = feedback_selector_from(&key, &session.zid().to_string());
272
273	// variables to manage control loop
274	let mut is_running = false;
275	let (tx, mut rx) = tokio::sync::mpsc::channel(8);
276
277	// main control loop of the observable
278	// started and terminated by state management
279	// do not terminate loop in case of errors during execution
280	loop {
281		let ctx = ctx.clone();
282		// different cases that may happen
283		tokio::select! {
284			// got query from an observer
285			Ok(query) = queryable.recv_async() => {
286				// TODO: make a proper "key: value" implementation
287				let p = query.parameters().as_str();
288				if p == "request" {
289					// received request => if no execution is running: spawn execution with channel for result else: return already running message
290					if is_running {
291						// send occupied response
292						let key = query.selector().key_expr().to_string();
293						let encoded: Vec<u8> = encode(&ControlResponse::Occupied);
294						match query.reply(&key, encoded).wait() {
295							Ok(()) => {},
296							Err(err) => error!("failed to reply with {err}"),
297						}
298					} else {
299						// start a computation
300						// create Message from payload
301						let content = query.payload().map_or_else(
302							|| {
303								let content: Vec<u8> = Vec::new();
304								content
305							},
306							|value| {
307								let content: Vec<u8> = value.to_bytes().into_owned();
308								content
309							},
310						);
311						let msg = Message::new(content);
312						let ctx_clone = ctx.clone();
313						let res = control_callback.lock().await(ctx_clone, msg).await;
314						match res {
315							Ok(response) => {
316								if matches!(response, ControlResponse::Accepted ) {
317									// create feedback publisher
318									let mut fp = feedback_publisher.lock().await;
319									session
320										.declare_publisher(publisher_selector.clone())
321										.congestion_control(CongestionControl::Block)
322										.priority(Priority::RealTime)
323										.wait()
324										.map_or_else(
325											|err| error!("could not create feedback publisher due to {err}"),
326											|publ| { fp.replace(publ); }
327										);
328
329
330									// spawn execution
331									let tx_clone = tx.clone();
332									let execution_function_clone = execution_function.clone();
333									let ctx_clone = ctx.clone();
334									execution_handle.lock().await.replace(tokio::spawn( async move {
335										let res = execution_function_clone.lock().await(ctx_clone).await.unwrap_or_else(|_| { todo!() });
336										if !matches!(tx_clone.send(res).await, Ok(())) { error!("failed to send back execution result") }
337									}));
338
339									// start feedback timer
340									feedback_timer.set(tokio::time::sleep(feedback_interval));
341									is_running = true;
342								}
343								// send  response back to requestor
344								let encoded: Vec<u8> = encode(&response);
345								match query.reply(&key, encoded).wait() {
346									Ok(()) => {},
347									Err(err) => error!("failed to reply with {err}"),
348								}
349							}
350							Err(error) => error!("control callback failed with {error}"),
351						}
352					}
353				} else if p == "cancel" {
354					// received cancel => abort a running execution
355					if is_running {
356						is_running = false;
357						let publisher = feedback_publisher.lock().await.take();
358						let handle = execution_handle.lock().await.take();
359						if let Some(h) = handle {
360							h.abort();
361							// wait for abortion
362							let _ = h.await;
363							let Ok(msg) = feedback_callback.lock().await(ctx).await else { todo!() };
364							let response =
365								ObservableResponse::Canceled(msg.value().clone());
366							if let Some(p) = publisher {
367								match p.put(Message::encode(&response).value().clone()).wait() {
368									Ok(()) => {},
369									Err(err) => error!("could not send cancel state due to {err}"),
370								}
371							} else {
372								error!("missing publisher");
373							}
374						} else {
375							error!("unexpected absence of execution handle");
376						}
377					}
378					// acknowledge cancel request
379					let encoded: Vec<u8> = encode(&ControlResponse::Canceled);
380					match query.reply(&key, encoded).wait() {
381						Ok(()) => {},
382						Err(err) => error!("failed to reply with {err}"),
383					}
384				} else {
385					error!("observable got unknown parameter: {p}");
386				}
387			}
388
389			// request finished => send back result of request (which may be a failure)
390			Some(result) = rx.recv() => {
391				if is_running {
392					is_running = false;
393					execution_handle.lock().await.take();
394					let response = ObservableResponse::Finished(result.value().clone());
395					feedback_publisher.lock().await.take().map_or_else(
396						|| error!("could not publish result"),
397						|p| {
398							match p.put(Message::encode(&response).value()).wait() {
399								Ok(()) => {},
400								Err(err) => error!("publishing result failed due to {err}"),
401							}
402						}
403					);
404				}
405			}
406
407			// feedback timer expired and observable still is executing
408			() = &mut feedback_timer, if is_running => {
409				let Ok(msg) = feedback_callback.lock().await(ctx).await else { todo!() };
410				let response =
411					ObservableResponse::Feedback(msg.value().clone());
412
413				let lock = feedback_publisher.lock().await;
414				let publisher = lock.as_ref().map_or_else(
415					|| { todo!() },
416					|p| p
417				);
418				match publisher.put(Message::encode(&response).value().clone()).wait() {
419					Ok(()) => {},
420					Err(err) => error!("publishing feedback failed due to {err}"),
421				}
422
423				// restart timer
424				feedback_timer.set(tokio::time::sleep(feedback_interval));
425			}
426		}
427	}
428}
429// endregion:	--- functions
430
#[cfg(test)]
mod tests {
	use super::*;

	/// Dummy properties type to instantiate the generic [`Observable`] with.
	#[derive(Debug)]
	struct TestProps {}

	/// Compiles only if `T` is `Sized` and implements the auto traits `Send` and `Sync`.
	const fn assert_auto_traits<T: Sized + Send + Sync>() {}

	#[test]
	const fn normal_types() {
		assert_auto_traits::<Observable<TestProps>>();
	}
}