1extern crate alloc;
4
5#[cfg(feature = "std")]
6extern crate std;
7
8use alloc::{
10 boxed::Box,
11 string::{String, ToString},
12 sync::Arc,
13 vec::Vec,
14};
15use bitcode::encode;
16use core::time::Duration;
17use dimas_core::{
18 enums::{OperationState, TaskSignal},
19 message_types::{ControlResponse, Message, ObservableResponse},
20 traits::{Capability, Context},
21 utils::feedback_selector_from,
22 Result,
23};
24use futures::future::BoxFuture;
25#[cfg(feature = "std")]
26use tokio::{sync::Mutex, task::JoinHandle};
27use tracing::{error, info, instrument, warn, Level};
28#[cfg(feature = "unstable")]
29use zenoh::sample::Locality;
30use zenoh::Wait;
31use zenoh::{
32 qos::{CongestionControl, Priority},
33 Session,
34};
/// Boxed control callback of an [`Observable`].
///
/// The callback is invoked with the [`Context`] and a [`Message`] and returns a boxed
/// future that resolves to a [`ControlResponse`] (or an error).
pub type ControlCallback<P> = Box<
    dyn FnMut(Context<P>, Message) -> BoxFuture<'static, Result<ControlResponse>> + Send + Sync,
>;
/// A [`ControlCallback`] wrapped in `Arc<Mutex<_>>` so it can be shared between tasks.
pub type ArcControlCallback<P> = Arc<Mutex<ControlCallback<P>>>;
/// Boxed feedback callback of an [`Observable`].
///
/// The callback is invoked with the [`Context`] and returns a boxed future that
/// resolves to the feedback [`Message`] (or an error).
pub type FeedbackCallback<P> =
    Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<Message>> + Send + Sync>;
/// A [`FeedbackCallback`] wrapped in `Arc<Mutex<_>>` so it can be shared between tasks.
pub type ArcFeedbackCallback<P> = Arc<Mutex<FeedbackCallback<P>>>;
/// Boxed execution function of an [`Observable`].
///
/// The function is invoked with the [`Context`] and returns a boxed future that
/// resolves to the result [`Message`] (or an error).
pub type ExecutionCallback<P> =
    Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<Message>> + Send + Sync>;
/// An [`ExecutionCallback`] wrapped in `Arc<Mutex<_>>` so it can be shared between tasks.
pub type ArcExecutionCallback<P> = Arc<Mutex<ExecutionCallback<P>>>;
/// An observable: answers `request`/`cancel` queries on a selector, runs an execution
/// function for an accepted request and publishes feedback and the final result.
pub struct Observable<P>
where
    P: Send + Sync + 'static,
{
    /// The zenoh session used to declare the queryable and the feedback publisher.
    session: Arc<Session>,
    /// The selector (key expression) the queryable is declared on.
    selector: String,
    /// Context handed (as a clone) to every callback invocation.
    context: Context<P>,
    /// Operation state threshold: at or above it the observable is started, below it stopped.
    activation_state: OperationState,
    /// Interval between two feedback publications while an execution is running.
    feedback_interval: Duration,
    /// Callback deciding how to answer a `request` query.
    control_callback: ArcControlCallback<P>,
    /// Callback producing the payload for feedback and cancel messages.
    feedback_callback: ArcFeedbackCallback<P>,
    /// Publisher for feedback/result messages; `Some` only after a request was accepted.
    feedback_publisher: Arc<Mutex<Option<zenoh::pubsub::Publisher<'static>>>>,
    /// Function that is spawned as a task once a request was accepted.
    execution_function: ArcExecutionCallback<P>,
    /// Join handle of the currently spawned execution task, if any.
    execution_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Join handle of the task running the query loop, set by `start` and taken by `stop`.
    handle: std::sync::Mutex<Option<JoinHandle<()>>>,
}
80
81impl<P> core::fmt::Debug for Observable<P>
82where
83 P: Send + Sync + 'static,
84{
85 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
86 f.debug_struct("Observable")
87 .finish_non_exhaustive()
88 }
89}
90
91impl<P> crate::traits::Responder for Observable<P>
92where
93 P: Send + Sync + 'static,
94{
95 fn selector(&self) -> &str {
97 &self.selector
98 }
99}
100
101impl<P> Capability for Observable<P>
102where
103 P: Send + Sync + 'static,
104{
105 fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
106 if state >= &self.activation_state {
107 self.start()
108 } else if state < &self.activation_state {
109 self.stop()
110 } else {
111 Ok(())
112 }
113 }
114}
115
116impl<P> Observable<P>
117where
118 P: Send + Sync + 'static,
119{
120 #[allow(clippy::too_many_arguments)]
122 #[must_use]
123 pub fn new(
124 session: Arc<Session>,
125 selector: String,
126 context: Context<P>,
127 activation_state: OperationState,
128 feedback_interval: Duration,
129 control_callback: ArcControlCallback<P>,
130 feedback_callback: ArcFeedbackCallback<P>,
131 execution_function: ArcExecutionCallback<P>,
132 ) -> Self {
133 Self {
134 session,
135 selector,
136 context,
137 activation_state,
138 feedback_interval,
139 control_callback,
140 feedback_callback,
141 feedback_publisher: Arc::new(Mutex::new(None)),
142 execution_function,
143 execution_handle: Arc::new(Mutex::new(None)),
144 handle: std::sync::Mutex::new(None),
145 }
146 }
147
148 #[instrument(level = Level::TRACE, skip_all)]
151 fn start(&self) -> Result<()> {
152 self.stop()?;
153
154 let selector = self.selector.clone();
155 let interval = self.feedback_interval;
156 let ccb = self.control_callback.clone();
157 let fcb = self.feedback_callback.clone();
158 let fcbp = self.feedback_publisher.clone();
159 let efc = self.execution_function.clone();
160 let efch = self.execution_handle.clone();
161 let ctx1 = self.context.clone();
162 let ctx2 = self.context.clone();
163 let session = self.session.clone();
164
165 self.handle.lock().map_or_else(
166 |_| todo!(),
167 |mut handle| {
168 handle.replace(tokio::task::spawn(async move {
169 let key = selector.clone();
170 std::panic::set_hook(Box::new(move |reason| {
171 error!("observable panic: {}", reason);
172 if let Err(reason) = ctx1
173 .sender()
174 .blocking_send(TaskSignal::RestartObservable(key.clone()))
175 {
176 error!("could not restart observable: {}", reason);
177 } else {
178 info!("restarting observable!");
179 };
180 }));
181 if let Err(error) =
182 run_observable(session, selector, interval, ccb, fcb, fcbp, efc, efch, ctx2)
183 .await
184 {
185 error!("observable failed with {error}");
186 };
187 }));
188
189 Ok(())
190 },
191 )
192 }
193
194 #[instrument(level = Level::TRACE, skip_all)]
196 #[allow(clippy::significant_drop_in_scrutinee)]
197 fn stop(&self) -> Result<()> {
198 self.handle.lock().map_or_else(
199 |_| todo!(),
200 |mut handle| {
201 if let Some(handle) = handle.take() {
202 let feedback_publisher = self.feedback_publisher.clone();
203 let feedback_callback = self.feedback_callback.clone();
204 let execution_handle = self.execution_handle.clone();
205 let ctx = self.context.clone();
206 tokio::spawn(async move {
207 if let Some(execution_handle) = execution_handle.lock().await.take() {
209 execution_handle.abort();
210 if let Some(publisher) = feedback_publisher.lock().await.take() {
212 let Ok(msg) = feedback_callback.lock().await(ctx).await else {
213 todo!()
214 };
215 let response = ObservableResponse::Canceled(msg.value().clone());
216 match publisher
217 .put(Message::encode(&response).value().clone())
218 .wait()
219 {
220 Ok(()) => {}
221 Err(err) => error!("could not send cancel state due to {err}"),
222 };
223 };
224 };
225 handle.abort();
226 });
227 }
228 Ok(())
229 },
230 )
231 }
232}
233#[allow(clippy::significant_drop_tightening)]
237#[allow(clippy::too_many_arguments)]
238#[instrument(name="observable", level = Level::ERROR, skip_all)]
239async fn run_observable<P>(
240 session: Arc<Session>,
241 selector: String,
242 feedback_interval: Duration,
243 control_callback: ArcControlCallback<P>,
244 feedback_callback: ArcFeedbackCallback<P>,
245 feedback_publisher: Arc<Mutex<Option<zenoh::pubsub::Publisher<'static>>>>,
246 execution_function: ArcExecutionCallback<P>,
247 execution_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
248 ctx: Context<P>,
249) -> Result<()>
250where
251 P: Send + Sync + 'static,
252{
253 let builder = session
255 .declare_queryable(&selector)
256 .complete(true);
257
258 #[cfg(feature = "unstable")]
259 let builder = builder.allowed_origin(Locality::Any);
260
261 let queryable = builder.await?;
262
263 let feedback_timer = tokio::time::sleep(feedback_interval);
266 tokio::pin!(feedback_timer);
267
268 let key = selector.clone();
270 let publisher_selector = feedback_selector_from(&key, &session.zid().to_string());
271
272 let mut is_running = false;
274 let (tx, mut rx) = tokio::sync::mpsc::channel(8);
275
276 loop {
280 let ctx = ctx.clone();
281 tokio::select! {
283 Ok(query) = queryable.recv_async() => {
285 let p = query.parameters().as_str();
287 if p == "request" {
288 if is_running {
290 let key = query.selector().key_expr().to_string();
292 let encoded: Vec<u8> = encode(&ControlResponse::Occupied);
293 match query.reply(&key, encoded).wait() {
294 Ok(()) => {},
295 Err(err) => error!("failed to reply with {err}"),
296 };
297 } else {
298 let content = query.payload().map_or_else(
301 || {
302 let content: Vec<u8> = Vec::new();
303 content
304 },
305 |value| {
306 let content: Vec<u8> = value.to_bytes().into_owned();
307 content
308 },
309 );
310 let msg = Message::new(content);
311 let ctx_clone = ctx.clone();
312 let res = control_callback.lock().await(ctx_clone, msg).await;
313 match res {
314 Ok(response) => {
315 if matches!(response, ControlResponse::Accepted ) {
316 let mut fp = feedback_publisher.lock().await;
318 session
319 .declare_publisher(publisher_selector.clone())
320 .congestion_control(CongestionControl::Block)
321 .priority(Priority::RealTime)
322 .wait()
323 .map_or_else(
324 |err| error!("could not create feedback publisher due to {err}"),
325 |publ| { fp.replace(publ); }
326 );
327
328
329 let tx_clone = tx.clone();
331 let execution_function_clone = execution_function.clone();
332 let ctx_clone = ctx.clone();
333 execution_handle.lock().await.replace(tokio::spawn( async move {
334 let res = execution_function_clone.lock().await(ctx_clone).await.unwrap_or_else(|_| { todo!() });
335 if !matches!(tx_clone.send(res).await, Ok(())) { error!("failed to send back execution result") };
336 }));
337
338 feedback_timer.set(tokio::time::sleep(feedback_interval));
340 is_running = true;
341 }
342 let encoded: Vec<u8> = encode(&response);
344 match query.reply(&key, encoded).wait() {
345 Ok(()) => {},
346 Err(err) => error!("failed to reply with {err}"),
347 };
348 }
349 Err(error) => error!("control callback failed with {error}"),
350 }
351 }
352 } else if p == "cancel" {
353 if is_running {
355 is_running = false;
356 let publisher = feedback_publisher.lock().await.take();
357 let handle = execution_handle.lock().await.take();
358 if let Some(h) = handle {
359 h.abort();
360 let _ = h.await;
362 let Ok(msg) = feedback_callback.lock().await(ctx).await else { todo!() };
363 let response =
364 ObservableResponse::Canceled(msg.value().clone());
365 if let Some(p) = publisher {
366 match p.put(Message::encode(&response).value().clone()).wait() {
367 Ok(()) => {},
368 Err(err) => error!("could not send cancel state due to {err}"),
369 };
370 } else {
371 error!("missing publisher");
372 };
373 } else {
374 error!("unexpected absence of execution handle");
375 };
376 }
377 let encoded: Vec<u8> = encode(&ControlResponse::Canceled);
379 match query.reply(&key, encoded).wait() {
380 Ok(()) => {},
381 Err(err) => error!("failed to reply with {err}"),
382 };
383 } else {
384 error!("observable got unknown parameter: {p}");
385 }
386 }
387
388 Some(result) = rx.recv() => {
390 if is_running {
391 is_running = false;
392 execution_handle.lock().await.take();
393 let response = ObservableResponse::Finished(result.value().clone());
394 feedback_publisher.lock().await.take().map_or_else(
395 || error!("could not publish result"),
396 |p| {
397 match p.put(Message::encode(&response).value()).wait() {
398 Ok(()) => {},
399 Err(err) => error!("publishing result failed due to {err}"),
400 };
401 }
402 );
403 }
404 }
405
406 () = &mut feedback_timer, if is_running => {
408 let Ok(msg) = feedback_callback.lock().await(ctx).await else { todo!() };
409 let response =
410 ObservableResponse::Feedback(msg.value().clone());
411
412 let lock = feedback_publisher.lock().await;
413 let publisher = lock.as_ref().map_or_else(
414 || { todo!() },
415 |p| p
416 );
417 match publisher.put(Message::encode(&response).value().clone()).wait() {
418 Ok(()) => {},
419 Err(err) => error!("publishing feedback failed due to {err}"),
420 };
421
422 feedback_timer.set(tokio::time::sleep(feedback_interval));
424 }
425 }
426 }
427}
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty property type used to instantiate the generic [`Observable`].
    #[derive(Debug)]
    struct Props {}

    /// Compiles only for types that are `Sized`, `Send` and `Sync`.
    const fn assert_normal<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    /// Compile-time check that `Observable<Props>` is `Sized + Send + Sync`.
    #[test]
    const fn normal_types() {
        assert_normal::<Observable<Props>>();
    }
}