#[doc(hidden)]
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
use crate::error::Error;
use alloc::sync::Arc;
use alloc::{boxed::Box, string::String, vec::Vec};
use dimas_core::{
Result,
enums::{OperationState, TaskSignal},
message_types::Message,
traits::{Capability, Context},
};
use futures::future::BoxFuture;
#[cfg(feature = "std")]
use tokio::{sync::Mutex, task::JoinHandle};
use tracing::{Level, error, info, instrument, warn};
use zenoh::Session;
#[cfg(feature = "unstable")]
use zenoh::sample::Locality;
use zenoh::sample::SampleKind;
/// Boxed callback invoked for samples of kind `SampleKind::Put`.
///
/// It receives the [`Context`] and the received payload wrapped in a [`Message`]
/// and returns a boxed `'static` future that resolves to `Result<()>`.
/// The closure is `FnMut` and must be `Send + Sync`.
pub type PutCallback<P> =
Box<dyn FnMut(Context<P>, Message) -> BoxFuture<'static, Result<()>> + Send + Sync>;
/// A [`PutCallback`] behind `Arc<Mutex<_>>` so it can be shared with the subscriber task;
/// the (tokio) mutex is locked around each invocation of the `FnMut` closure.
pub type ArcPutCallback<P> = Arc<Mutex<PutCallback<P>>>;
/// Boxed callback invoked for samples of kind `SampleKind::Delete`.
///
/// It receives only the [`Context`] and returns a boxed `'static` future that
/// resolves to `Result<()>`. The closure is `FnMut` and must be `Send + Sync`.
pub type DeleteCallback<P> =
Box<dyn FnMut(Context<P>) -> BoxFuture<'static, Result<()>> + Send + Sync>;
/// A [`DeleteCallback`] behind `Arc<Mutex<_>>` so it can be shared with the subscriber task;
/// the (tokio) mutex is locked around each invocation of the `FnMut` closure.
pub type ArcDeleteCallback<P> = Arc<Mutex<DeleteCallback<P>>>;
/// A zenoh subscriber that runs in its own tokio task and forwards received
/// samples to user supplied callbacks.
///
/// `P` is the type parameter of the [`Context`] handed to the callbacks.
pub struct Subscriber<P>
where
P: Send + Sync + 'static,
{
/// Zenoh session on which the subscriber is declared.
session: Arc<Session>,
/// Expression passed to `declare_subscriber`; also returned by `Responder::selector`.
selector: String,
/// Context that is cloned into the subscriber task and passed to every callback invocation.
context: Context<P>,
/// Threshold state: `manage_operation_state` starts the subscriber at or above it
/// and stops it below.
activation_state: OperationState,
/// Origin restriction applied to the subscriber builder (only with feature `unstable`).
#[cfg(feature = "unstable")]
allowed_origin: Locality,
/// Callback run for every `Put` sample.
put_callback: ArcPutCallback<P>,
/// Optional callback run for every `Delete` sample; such samples are ignored when `None`.
delete_callback: Option<ArcDeleteCallback<P>>,
/// Join handle of the spawned subscriber task; initialised to `None` in `new`.
handle: std::sync::Mutex<Option<JoinHandle<()>>>,
}
impl<P> core::fmt::Debug for Subscriber<P>
where
    P: Send + Sync + 'static,
{
    /// Formats the subscriber as `Subscriber { selector: .., .. }`.
    ///
    /// Only the selector is printed; all other fields are elided via the
    /// non-exhaustive finisher.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut repr = f.debug_struct("Subscriber");
        repr.field("selector", &self.selector);
        repr.finish_non_exhaustive()
    }
}
impl<P> crate::traits::Responder for Subscriber<P>
where
    P: Send + Sync + 'static,
{
    /// Returns the selector this subscriber was created with.
    fn selector(&self) -> &str {
        self.selector.as_str()
    }
}
impl<P> Capability for Subscriber<P>
where
    P: Send + Sync + 'static,
{
    /// Reacts to a change of the operation state.
    ///
    /// * `state` at or above the configured activation state: the subscriber is (re)started.
    /// * `state` below the activation state: the subscriber is stopped.
    /// * neither comparison holds: nothing happens and `Ok(())` is returned.
    fn manage_operation_state(&self, state: &OperationState) -> Result<()> {
        let threshold = &self.activation_state;
        if state >= threshold {
            return self.start();
        }
        if state < threshold {
            return self.stop();
        }
        Ok(())
    }
}
impl<P> Subscriber<P>
where
    P: Send + Sync + 'static,
{
    /// Creates a new [`Subscriber`] that is not yet running.
    ///
    /// No task is spawned here; the task handle starts out as `None` and is
    /// only populated by `start`.
    ///
    /// # Parameters
    /// * `session` - zenoh session used to declare the subscriber.
    /// * `selector` - expression to subscribe to.
    /// * `context` - context handed to the callbacks.
    /// * `activation_state` - threshold state used by `manage_operation_state`.
    /// * `allowed_origin` - origin restriction (only with feature `unstable`).
    /// * `put_callback` - callback for `Put` samples.
    /// * `delete_callback` - optional callback for `Delete` samples.
    #[must_use]
    pub fn new(
        session: Arc<Session>,
        selector: String,
        context: Context<P>,
        activation_state: OperationState,
        #[cfg(feature = "unstable")] allowed_origin: Locality,
        put_callback: ArcPutCallback<P>,
        delete_callback: Option<ArcDeleteCallback<P>>,
    ) -> Self {
        Self {
            session,
            selector,
            context,
            activation_state,
            #[cfg(feature = "unstable")]
            allowed_origin,
            put_callback,
            delete_callback,
            handle: std::sync::Mutex::new(None),
        }
    }

    /// Starts the subscriber task, stopping a previously running one first.
    ///
    /// The spawned task installs a panic hook that sends
    /// `TaskSignal::RestartSubscriber` with this subscriber's selector through
    /// the context's sender, and then runs `run_subscriber` until it returns.
    /// An error returned by `run_subscriber` is logged, not propagated.
    ///
    /// # Errors
    /// Propagates an error from `stop`.
    #[instrument(level = Level::TRACE, skip_all)]
    fn start(&self) -> Result<()> {
        self.stop()?;

        let selector = self.selector.clone();
        let p_cb = self.put_callback.clone();
        let d_cb = self.delete_callback.clone();
        let hook_ctx = self.context.clone();
        let task_ctx = self.context.clone();
        let session = self.session.clone();
        #[cfg(feature = "unstable")]
        let allowed_origin = self.allowed_origin;

        // Spawn before taking the lock so that a panic in `spawn` cannot poison the mutex.
        let task = tokio::task::spawn(async move {
            let key = selector.clone();
            // `set_hook` replaces the process-wide panic hook, i.e. the hook of the
            // most recently started task wins.
            // NOTE(review): `blocking_send` is invoked from inside the hook; if the
            // sender is a tokio mpsc sender this may not be callable from a runtime
            // worker thread - confirm against `Context::sender`.
            std::panic::set_hook(Box::new(move |reason| {
                error!("subscriber panic: {}", reason);
                if let Err(reason) = hook_ctx
                    .sender()
                    .blocking_send(TaskSignal::RestartSubscriber(key.clone()))
                {
                    error!("could not restart subscriber: {}", reason);
                } else {
                    info!("restarting subscriber!");
                }
            }));
            if let Err(error) = run_subscriber(
                session,
                selector,
                #[cfg(feature = "unstable")]
                allowed_origin,
                p_cb,
                d_cb,
                task_ctx,
            )
            .await
            {
                error!("spawning subscriber failed with {error}");
            }
        });

        // The guarded value is a plain `Option<JoinHandle>` which cannot be left in
        // an inconsistent state, so a poisoned lock is recovered instead of panicking.
        let mut handle = self
            .handle
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Should a concurrent `start` have stored a task in the meantime, cancel it:
        // merely dropping a `JoinHandle` would leave that task running detached.
        if let Some(previous) = handle.replace(task) {
            previous.abort();
        }
        Ok(())
    }

    /// Stops the subscriber task if one is running.
    ///
    /// The stored join handle is taken and the task is aborted. Dropping the
    /// handle alone would only detach the task and leave the subscription alive.
    /// Calling this on a subscriber that is not running is a no-op.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for interface stability.
    #[instrument(level = Level::TRACE, skip_all)]
    fn stop(&self) -> Result<()> {
        let running = self
            .handle
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();
        if let Some(task) = running {
            task.abort();
        }
        Ok(())
    }
}
/// Declares a zenoh subscriber on `selector` and dispatches received samples forever.
///
/// * `Put` samples: the payload is copied into a [`Message`] and passed, together
///   with a clone of `ctx`, to the put callback.
/// * `Delete` samples: the delete callback, if any, is called with a clone of `ctx`.
///
/// Errors returned by a callback are logged and do not end the loop.
///
/// # Errors
/// Returns an error if declaring the subscriber fails, or
/// `Error::SubscriberCreation` if receiving a sample fails.
#[instrument(name="subscriber", level = Level::ERROR, skip_all)]
async fn run_subscriber<P>(
    session: Arc<Session>,
    selector: String,
    #[cfg(feature = "unstable")] allowed_origin: Locality,
    p_cb: ArcPutCallback<P>,
    d_cb: Option<ArcDeleteCallback<P>>,
    ctx: Context<P>,
) -> Result<()>
where
    P: Send + Sync + 'static,
{
    let builder = session.declare_subscriber(&selector);
    #[cfg(feature = "unstable")]
    let builder = builder.allowed_origin(allowed_origin);
    let subscriber = builder.await?;

    loop {
        let sample = subscriber
            .recv_async()
            .await
            .map_err(|source| Error::SubscriberCreation { source })?;

        match sample.kind() {
            SampleKind::Put => {
                let payload: Vec<u8> = sample.payload().to_bytes().into_owned();
                let message = Message::new(payload);
                // The guard is held while the callback future runs, which
                // serialises invocations of the `FnMut` closure.
                let mut callback = p_cb.lock().await;
                if let Err(error) = callback(ctx.clone(), message).await {
                    error!("subscriber put callback failed with {error}");
                }
            }
            SampleKind::Delete => {
                if let Some(shared) = d_cb.as_ref() {
                    let delete_ctx = ctx.clone();
                    let mut callback = shared.lock().await;
                    if let Err(error) = callback(delete_ctx).await {
                        error!("subscriber delete callback failed with {error}");
                    }
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal stand-in for the properties type parameter.
    #[derive(Debug)]
    struct Props {}

    /// Only instantiable for types that are `Sized + Send + Sync`.
    const fn is_normal<T>()
    where
        T: Sized + Send + Sync,
    {
    }

    /// Compile-time check of the auto traits of [`Subscriber`].
    #[test]
    const fn normal_types() {
        is_normal::<Subscriber<Props>>();
    }
}