1extern crate alloc;
4
5#[cfg(feature = "std")]
6extern crate std;
7
8use alloc::{
10 boxed::Box,
11 string::{String, ToString},
12 sync::Arc,
13 vec::Vec,
14};
15use bitcode::encode;
16use core::time::Duration;
17use dimas_core::{
18 Result,
19 enums::{OperationState, TaskSignal},
20 message_types::{ControlResponse, Message, ObservableResponse},
21 traits::{Capability, Context},
22 utils::feedback_selector_from,
23};
24use futures::future::BoxFuture;
25#[cfg(feature = "std")]
26use tokio::{sync::Mutex, task::JoinHandle};
27use tracing::{Level, error, info, instrument, warn};
28use zenoh::Wait;
29#[cfg(feature = "unstable")]
30use zenoh::sample::Locality;
31use zenoh::{
32 Session,
33 qos::{CongestionControl, Priority},
34};
/// Callback deciding how to answer a control request.
///
/// It is a boxed `FnMut` that receives the [`Context`] and the request
/// [`Message`] and returns a boxed future resolving to a [`ControlResponse`].
/// Within this file it is invoked for every `request` query that arrives while
/// no execution is running.
pub type ControlCallback<P> = Box<
    dyn FnMut(Context<P>, Message) -> BoxFuture<'static, Result<ControlResponse>> + Send + Sync,
>;
/// A [`ControlCallback`] behind an async mutex, shareable between tasks via [`Arc`].
pub type ArcControlCallback<P> = Arc<Mutex<ControlCallback<P>>>;
/// Callback producing the current feedback [`Message`].
///
/// It is a boxed `FnMut` that receives the [`Context`] and returns a boxed
/// future resolving to a [`Message`]. Within this file it is invoked whenever
/// the feedback timer elapses and when a running execution is canceled.
pub type FeedbackCallback<P> =
    Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<Message>> + Send + Sync>;
/// A [`FeedbackCallback`] behind an async mutex, shareable between tasks via [`Arc`].
pub type ArcFeedbackCallback<P> = Arc<Mutex<FeedbackCallback<P>>>;
/// Callback performing the actual execution of an accepted request.
///
/// It is a boxed `FnMut` that receives the [`Context`] and returns a boxed
/// future resolving to the result [`Message`]. Within this file it is run in a
/// separately spawned task once the control callback has answered with
/// `ControlResponse::Accepted`.
pub type ExecutionCallback<P> =
    Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<Message>> + Send + Sync>;
/// An [`ExecutionCallback`] behind an async mutex, shareable between tasks via [`Arc`].
pub type ArcExecutionCallback<P> = Arc<Mutex<ExecutionCallback<P>>>;
/// An observable: a queryable that accepts `request` / `cancel` control
/// queries, runs an execution in the background and publishes feedback and the
/// final result through a dedicated publisher.
pub struct Observable<P>
where
    P: Send + Sync + 'static,
{
    /// Zenoh session used to declare the queryable and the feedback publisher.
    session: Arc<Session>,
    /// Key expression the queryable is declared on.
    selector: String,
    /// Context handed to every callback invocation.
    context: Context<P>,
    /// Operation state compared against in `manage_operation_state` to decide
    /// between starting and stopping the observable.
    activation_state: OperationState,
    /// Delay between two feedback publications while an execution is running.
    feedback_interval: Duration,
    /// Callback answering `request` queries.
    control_callback: ArcControlCallback<P>,
    /// Callback producing feedback messages.
    feedback_callback: ArcFeedbackCallback<P>,
    /// Publisher for feedback/result messages; `Some` only after a request was accepted.
    feedback_publisher: Arc<Mutex<Option<zenoh::pubsub::Publisher<'static>>>>,
    /// Callback performing the execution of an accepted request.
    execution_function: ArcExecutionCallback<P>,
    /// Join handle of the task currently running `execution_function`, if any.
    execution_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
    /// Join handle of the task spawned by `start`, if any.
    handle: std::sync::Mutex<Option<JoinHandle<()>>>,
}
80
81impl<P> core::fmt::Debug for Observable<P>
82where
83 P: Send + Sync + 'static,
84{
85 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
86 f.debug_struct("Observable")
87 .finish_non_exhaustive()
88 }
89}
90
91impl<P> crate::traits::Responder for Observable<P>
92where
93 P: Send + Sync + 'static,
94{
95 fn selector(&self) -> &str {
97 &self.selector
98 }
99}
100
101impl<P> Capability for Observable<P>
102where
103 P: Send + Sync + 'static,
104{
105 fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
106 if state >= &self.activation_state {
107 self.start()
108 } else if state < &self.activation_state {
109 self.stop()
110 } else {
111 Ok(())
112 }
113 }
114}
115
116impl<P> Observable<P>
117where
118 P: Send + Sync + 'static,
119{
120 #[allow(clippy::too_many_arguments)]
122 #[must_use]
123 pub fn new(
124 session: Arc<Session>,
125 selector: String,
126 context: Context<P>,
127 activation_state: OperationState,
128 feedback_interval: Duration,
129 control_callback: ArcControlCallback<P>,
130 feedback_callback: ArcFeedbackCallback<P>,
131 execution_function: ArcExecutionCallback<P>,
132 ) -> Self {
133 Self {
134 session,
135 selector,
136 context,
137 activation_state,
138 feedback_interval,
139 control_callback,
140 feedback_callback,
141 feedback_publisher: Arc::new(Mutex::new(None)),
142 execution_function,
143 execution_handle: Arc::new(Mutex::new(None)),
144 handle: std::sync::Mutex::new(None),
145 }
146 }
147
148 #[instrument(level = Level::TRACE, skip_all)]
151 fn start(&self) -> Result<()> {
152 self.stop()?;
153
154 let selector = self.selector.clone();
155 let interval = self.feedback_interval;
156 let ccb = self.control_callback.clone();
157 let fcb = self.feedback_callback.clone();
158 let fcbp = self.feedback_publisher.clone();
159 let efc = self.execution_function.clone();
160 let efch = self.execution_handle.clone();
161 let ctx1 = self.context.clone();
162 let ctx2 = self.context.clone();
163 let session = self.session.clone();
164
165 self.handle.lock().map_or_else(
166 |_| todo!(),
167 |mut handle| {
168 handle.replace(tokio::task::spawn(async move {
169 let key = selector.clone();
170 std::panic::set_hook(Box::new(move |reason| {
171 error!("observable panic: {}", reason);
172 if let Err(reason) = ctx1
173 .sender()
174 .blocking_send(TaskSignal::RestartObservable(key.clone()))
175 {
176 error!("could not restart observable: {}", reason);
177 } else {
178 info!("restarting observable!");
179 }
180 }));
181 if let Err(error) =
182 run_observable(session, selector, interval, ccb, fcb, fcbp, efc, efch, ctx2)
183 .await
184 {
185 error!("observable failed with {error}");
186 }
187 }));
188
189 Ok(())
190 },
191 )
192 }
193
194 #[instrument(level = Level::TRACE, skip_all)]
196 #[allow(clippy::significant_drop_in_scrutinee)]
197 fn stop(&self) -> Result<()> {
198 self.handle.lock().map_or_else(
199 |_| todo!(),
200 |mut handle| {
201 if let Some(handle) = handle.take() {
202 let feedback_publisher = self.feedback_publisher.clone();
203 let feedback_callback = self.feedback_callback.clone();
204 let execution_handle = self.execution_handle.clone();
205 let ctx = self.context.clone();
206 tokio::spawn(async move {
207 if let Some(execution_handle) = execution_handle.lock().await.take() {
209 execution_handle.abort();
210 if let Some(publisher) = feedback_publisher.lock().await.take() {
212 let Ok(msg) = feedback_callback.lock().await(ctx).await else {
213 todo!()
214 };
215 let response = ObservableResponse::Canceled(msg.value().clone());
216 match publisher
217 .put(Message::encode(&response).value().clone())
218 .wait()
219 {
220 Ok(()) => {}
221 Err(err) => error!("could not send cancel state due to {err}"),
222 }
223 }
224 }
225 handle.abort();
226 });
227 }
228 Ok(())
229 },
230 )
231 }
232}
233#[allow(clippy::significant_drop_tightening)]
237#[allow(clippy::too_many_arguments)]
238#[allow(clippy::too_many_lines)]
239#[instrument(name="observable", level = Level::ERROR, skip_all)]
240async fn run_observable<P>(
241 session: Arc<Session>,
242 selector: String,
243 feedback_interval: Duration,
244 control_callback: ArcControlCallback<P>,
245 feedback_callback: ArcFeedbackCallback<P>,
246 feedback_publisher: Arc<Mutex<Option<zenoh::pubsub::Publisher<'static>>>>,
247 execution_function: ArcExecutionCallback<P>,
248 execution_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
249 ctx: Context<P>,
250) -> Result<()>
251where
252 P: Send + Sync + 'static,
253{
254 let builder = session
256 .declare_queryable(&selector)
257 .complete(true);
258
259 #[cfg(feature = "unstable")]
260 let builder = builder.allowed_origin(Locality::Any);
261
262 let queryable = builder.await?;
263
264 let feedback_timer = tokio::time::sleep(feedback_interval);
267 tokio::pin!(feedback_timer);
268
269 let key = selector.clone();
271 let publisher_selector = feedback_selector_from(&key, &session.zid().to_string());
272
273 let mut is_running = false;
275 let (tx, mut rx) = tokio::sync::mpsc::channel(8);
276
277 loop {
281 let ctx = ctx.clone();
282 tokio::select! {
284 Ok(query) = queryable.recv_async() => {
286 let p = query.parameters().as_str();
288 if p == "request" {
289 if is_running {
291 let key = query.selector().key_expr().to_string();
293 let encoded: Vec<u8> = encode(&ControlResponse::Occupied);
294 match query.reply(&key, encoded).wait() {
295 Ok(()) => {},
296 Err(err) => error!("failed to reply with {err}"),
297 }
298 } else {
299 let content = query.payload().map_or_else(
302 || {
303 let content: Vec<u8> = Vec::new();
304 content
305 },
306 |value| {
307 let content: Vec<u8> = value.to_bytes().into_owned();
308 content
309 },
310 );
311 let msg = Message::new(content);
312 let ctx_clone = ctx.clone();
313 let res = control_callback.lock().await(ctx_clone, msg).await;
314 match res {
315 Ok(response) => {
316 if matches!(response, ControlResponse::Accepted ) {
317 let mut fp = feedback_publisher.lock().await;
319 session
320 .declare_publisher(publisher_selector.clone())
321 .congestion_control(CongestionControl::Block)
322 .priority(Priority::RealTime)
323 .wait()
324 .map_or_else(
325 |err| error!("could not create feedback publisher due to {err}"),
326 |publ| { fp.replace(publ); }
327 );
328
329
330 let tx_clone = tx.clone();
332 let execution_function_clone = execution_function.clone();
333 let ctx_clone = ctx.clone();
334 execution_handle.lock().await.replace(tokio::spawn( async move {
335 let res = execution_function_clone.lock().await(ctx_clone).await.unwrap_or_else(|_| { todo!() });
336 if !matches!(tx_clone.send(res).await, Ok(())) { error!("failed to send back execution result") }
337 }));
338
339 feedback_timer.set(tokio::time::sleep(feedback_interval));
341 is_running = true;
342 }
343 let encoded: Vec<u8> = encode(&response);
345 match query.reply(&key, encoded).wait() {
346 Ok(()) => {},
347 Err(err) => error!("failed to reply with {err}"),
348 }
349 }
350 Err(error) => error!("control callback failed with {error}"),
351 }
352 }
353 } else if p == "cancel" {
354 if is_running {
356 is_running = false;
357 let publisher = feedback_publisher.lock().await.take();
358 let handle = execution_handle.lock().await.take();
359 if let Some(h) = handle {
360 h.abort();
361 let _ = h.await;
363 let Ok(msg) = feedback_callback.lock().await(ctx).await else { todo!() };
364 let response =
365 ObservableResponse::Canceled(msg.value().clone());
366 if let Some(p) = publisher {
367 match p.put(Message::encode(&response).value().clone()).wait() {
368 Ok(()) => {},
369 Err(err) => error!("could not send cancel state due to {err}"),
370 }
371 } else {
372 error!("missing publisher");
373 }
374 } else {
375 error!("unexpected absence of execution handle");
376 }
377 }
378 let encoded: Vec<u8> = encode(&ControlResponse::Canceled);
380 match query.reply(&key, encoded).wait() {
381 Ok(()) => {},
382 Err(err) => error!("failed to reply with {err}"),
383 }
384 } else {
385 error!("observable got unknown parameter: {p}");
386 }
387 }
388
389 Some(result) = rx.recv() => {
391 if is_running {
392 is_running = false;
393 execution_handle.lock().await.take();
394 let response = ObservableResponse::Finished(result.value().clone());
395 feedback_publisher.lock().await.take().map_or_else(
396 || error!("could not publish result"),
397 |p| {
398 match p.put(Message::encode(&response).value()).wait() {
399 Ok(()) => {},
400 Err(err) => error!("publishing result failed due to {err}"),
401 }
402 }
403 );
404 }
405 }
406
407 () = &mut feedback_timer, if is_running => {
409 let Ok(msg) = feedback_callback.lock().await(ctx).await else { todo!() };
410 let response =
411 ObservableResponse::Feedback(msg.value().clone());
412
413 let lock = feedback_publisher.lock().await;
414 let publisher = lock.as_ref().map_or_else(
415 || { todo!() },
416 |p| p
417 );
418 match publisher.put(Message::encode(&response).value().clone()).wait() {
419 Ok(()) => {},
420 Err(err) => error!("publishing feedback failed due to {err}"),
421 }
422
423 feedback_timer.set(tokio::time::sleep(feedback_interval));
425 }
426 }
427 }
428}
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty property type used to instantiate the generic type under test.
    #[derive(Debug)]
    struct Props {}

    /// Compiles only for types that are `Sized`, `Send` and `Sync`.
    const fn assert_auto_traits<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    /// `Observable` must be usable across threads.
    #[test]
    const fn normal_types() {
        assert_auto_traits::<Observable<Props>>();
    }
}