dimas_com/zenoh/
subscriber.rs1#[doc(hidden)]
7extern crate alloc;
8
9#[cfg(feature = "std")]
10extern crate std;
11
12use crate::error::Error;
14use alloc::sync::Arc;
15use alloc::{boxed::Box, string::String, vec::Vec};
16use dimas_core::{
17 Result,
18 enums::{OperationState, TaskSignal},
19 message_types::Message,
20 traits::{Capability, Context},
21};
22use futures::future::BoxFuture;
23#[cfg(feature = "std")]
24use tokio::{sync::Mutex, task::JoinHandle};
25use tracing::{Level, error, info, instrument, warn};
26use zenoh::Session;
27#[cfg(feature = "unstable")]
28use zenoh::sample::Locality;
29use zenoh::sample::SampleKind;
30pub type PutCallback<P> =
35 Box<dyn FnMut(Context<P>, Message) -> BoxFuture<'static, Result<()>> + Send + Sync>;
36pub type ArcPutCallback<P> = Arc<Mutex<PutCallback<P>>>;
38pub type DeleteCallback<P> =
40 Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<()>> + Send + Sync>;
41pub type ArcDeleteCallback<P> = Arc<Mutex<DeleteCallback<P>>>;
43pub struct Subscriber<P>
48where
49 P: Send + Sync + 'static,
50{
51 session: Arc<Session>,
53 selector: String,
55 context: Context<P>,
57 activation_state: OperationState,
59 #[cfg(feature = "unstable")]
60 allowed_origin: Locality,
61 put_callback: ArcPutCallback<P>,
62 delete_callback: Option<ArcDeleteCallback<P>>,
63 handle: std::sync::Mutex<Option<JoinHandle<()>>>,
64}
65
66impl<P> core::fmt::Debug for Subscriber<P>
67where
68 P: Send + Sync + 'static,
69{
70 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
71 f.debug_struct("Subscriber")
72 .field("selector", &self.selector)
73 .finish_non_exhaustive()
74 }
75}
76
77impl<P> crate::traits::Responder for Subscriber<P>
78where
79 P: Send + Sync + 'static,
80{
81 fn selector(&self) -> &str {
83 &self.selector
84 }
85}
86
87impl<P> Capability for Subscriber<P>
88where
89 P: Send + Sync + 'static,
90{
91 fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
92 if state >= &self.activation_state {
93 self.start()
94 } else if state < &self.activation_state {
95 self.stop()
96 } else {
97 Ok(())
98 }
99 }
100}
101
102impl<P> Subscriber<P>
103where
104 P: Send + Sync + 'static,
105{
106 #[must_use]
108 pub fn new(
109 session: Arc<Session>,
110 selector: String,
111 context: Context<P>,
112 activation_state: OperationState,
113 #[cfg(feature = "unstable")] allowed_origin: Locality,
114 put_callback: ArcPutCallback<P>,
115 delete_callback: Option<ArcDeleteCallback<P>>,
116 ) -> Self {
117 Self {
118 session,
119 selector,
120 context,
121 activation_state,
122 #[cfg(feature = "unstable")]
123 allowed_origin,
124 put_callback,
125 delete_callback,
126 handle: std::sync::Mutex::new(None),
127 }
128 }
129 #[instrument(level = Level::TRACE, skip_all)]
132 fn start(&self) -> Result<()> {
133 self.stop()?;
134
135 let selector = self.selector.clone();
136 let p_cb = self.put_callback.clone();
137 let d_cb = self.delete_callback.clone();
138 let ctx1 = self.context.clone();
139 let ctx2 = self.context.clone();
140 let session = self.session.clone();
141 #[cfg(feature = "unstable")]
142 let allowed_origin = self.allowed_origin;
143
144 self.handle.lock().map_or_else(
145 |_| todo!(),
146 |mut handle| {
147 handle.replace(tokio::task::spawn(async move {
148 let key = selector.clone();
149 std::panic::set_hook(Box::new(move |reason| {
150 error!("subscriber panic: {}", reason);
151 if let Err(reason) = ctx1
152 .sender()
153 .blocking_send(TaskSignal::RestartSubscriber(key.clone()))
154 {
155 error!("could not restart subscriber: {}", reason);
156 } else {
157 info!("restarting subscriber!");
158 }
159 }));
160 if let Err(error) = run_subscriber(
161 session,
162 selector,
163 #[cfg(feature = "unstable")]
164 allowed_origin,
165 p_cb,
166 d_cb,
167 ctx2.clone(),
168 )
169 .await
170 {
171 error!("spawning subscriber failed with {error}");
172 }
173 }));
174 Ok(())
175 },
176 )
177 }
178
179 #[instrument(level = Level::TRACE, skip_all)]
181 fn stop(&self) -> Result<()> {
182 self.handle.lock().map_or_else(
183 |_| todo!(),
184 |mut handle| {
185 handle.take();
186 Ok(())
187 },
188 )
189 }
190}
191
192#[instrument(name="subscriber", level = Level::ERROR, skip_all)]
193async fn run_subscriber<P>(
194 session: Arc<Session>,
195 selector: String,
196 #[cfg(feature = "unstable")] allowed_origin: Locality,
197 p_cb: ArcPutCallback<P>,
198 d_cb: Option<ArcDeleteCallback<P>>,
199 ctx: Context<P>,
200) -> Result<()>
201where
202 P: Send + Sync + 'static,
203{
204 let builder = session.declare_subscriber(&selector);
205
206 #[cfg(feature = "unstable")]
207 let builder = builder.allowed_origin(allowed_origin);
208
209 let subscriber = builder.await?;
210
211 loop {
212 let sample = subscriber
213 .recv_async()
214 .await
215 .map_err(|source| Error::SubscriberCreation { source })?;
216
217 match sample.kind() {
218 SampleKind::Put => {
219 let content: Vec<u8> = sample.payload().to_bytes().into_owned();
220 let msg = Message::new(content);
221 let mut lock = p_cb.lock().await;
222 let ctx = ctx.clone();
223 if let Err(error) = lock(ctx, msg).await {
224 error!("subscriber put callback failed with {error}");
225 }
226 }
227 SampleKind::Delete => {
228 if let Some(cb) = d_cb.clone() {
229 let ctx = ctx.clone();
230 let mut lock = cb.lock().await;
231 if let Err(error) = lock(ctx).await {
232 error!("subscriber delete callback failed with {error}");
233 }
234 }
235 }
236 }
237 }
238}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder property type used only to instantiate `Subscriber`.
    #[derive(Debug)]
    struct Props {}

    /// Compiles only when `T` is sized, `Send` and `Sync`.
    const fn assert_send_sync<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    #[test]
    const fn normal_types() {
        assert_send_sync::<Subscriber<Props>>();
    }
}
255}