use std::{
any::Any,
ffi::{CStr, CString},
sync::{Arc, Mutex, MutexGuard},
};
use rosidl_runtime_rs::{Message, RmwMessage};
use crate::{
error::ToResult, log_error, qos::QoSProfile, rcl_bindings::*, IntoPrimitiveOptions, Node,
NodeHandle, RclPrimitive, RclPrimitiveHandle, RclPrimitiveKind, RclrsError, ReadyKind,
Waitable, WaitableLifecycle, WorkScope, Worker, WorkerCommands, ENTITY_LIFECYCLE_MUTEX,
};
mod any_subscription_callback;
pub use any_subscription_callback::*;
mod node_subscription_callback;
pub use node_subscription_callback::*;
mod into_async_subscription_callback;
pub use into_async_subscription_callback::*;
mod into_node_subscription_callback;
pub use into_node_subscription_callback::*;
mod into_worker_subscription_callback;
pub use into_worker_subscription_callback::*;
mod message_info;
pub use message_info::*;
mod readonly_loaned_message;
pub use readonly_loaned_message::*;
mod worker_subscription_callback;
pub use worker_subscription_callback::*;
/// A subscription whose callbacks run in the scope of a [`Node`].
///
/// This is a reference-counted pointer to a [`SubscriptionState`]; cloning it
/// shares the same underlying subscription rather than creating a new one.
pub type Subscription<T> = Arc<SubscriptionState<T, Node>>;
/// A subscription whose callbacks run in the scope of a [`Worker`]; its callback
/// type is parameterized by that worker's `Payload`.
pub type WorkerSubscription<T, Payload> = Arc<SubscriptionState<T, Worker<Payload>>>;
/// Shared state of a subscription receiving messages of type `T`.
///
/// `Scope` selects the context the callback runs in (a [`Node`] or a [`Worker`])
/// and therefore which callback-setting methods are available.
pub struct SubscriptionState<T, Scope>
where
    T: Message,
    Scope: WorkScope,
{
    /// The rcl subscription; also shared with the `SubscriptionExecutable`
    /// registered in the wait set.
    handle: Arc<SubscriptionHandle>,
    /// The user callback. It is shared with the `SubscriptionExecutable` so the
    /// `set_*callback` methods can replace it after creation.
    callback: Arc<Mutex<AnySubscriptionCallback<T, Scope::Payload>>>,
    /// Never read. It is held only so that it is dropped together with this
    /// state. Presumably dropping it detaches the waitable from the wait set;
    /// confirm in `WaitableLifecycle`.
    #[allow(unused)]
    lifecycle: WaitableLifecycle,
}
impl<T, Scope> SubscriptionState<T, Scope>
where
T: Message,
Scope: WorkScope,
{
pub fn topic_name(&self) -> String {
self.handle.topic_name()
}
pub fn qos(&self) -> QoSProfile {
let options = unsafe {
let handle = self.handle.lock();
let options = rcl_subscription_get_options(&*handle);
if options.is_null() {
None
} else {
Some((&(*options).qos).into())
}
};
if options.is_none() {
log_error!("Subscroption.qos", "Options returned null");
}
options.unwrap_or_default()
}
pub(crate) fn create<'a>(
options: impl Into<SubscriptionOptions<'a>>,
callback: AnySubscriptionCallback<T, Scope::Payload>,
node_handle: &Arc<NodeHandle>,
commands: &Arc<WorkerCommands>,
) -> Result<Arc<Self>, RclrsError> {
let SubscriptionOptions { topic, qos } = options.into();
let callback = Arc::new(Mutex::new(callback));
let mut rcl_subscription = unsafe { rcl_get_zero_initialized_subscription() };
let type_support =
<T as Message>::RmwMsg::get_type_support() as *const rosidl_message_type_support_t;
let topic_c_string = CString::new(topic).map_err(|err| RclrsError::StringContainsNul {
err,
s: topic.into(),
})?;
let mut rcl_subscription_options = unsafe { rcl_subscription_get_default_options() };
rcl_subscription_options.qos = qos.into();
{
let rcl_node = node_handle.rcl_node.lock().unwrap();
let _lifecycle_lock = ENTITY_LIFECYCLE_MUTEX.lock().unwrap();
unsafe {
rcl_subscription_init(
&mut rcl_subscription,
&*rcl_node,
type_support,
topic_c_string.as_ptr(),
&rcl_subscription_options,
)
.ok()?;
}
}
let handle = Arc::new(SubscriptionHandle {
rcl_subscription: Mutex::new(rcl_subscription),
node_handle: Arc::clone(node_handle),
});
let (waitable, lifecycle) = Waitable::new(
Box::new(SubscriptionExecutable {
handle: Arc::clone(&handle),
callback: Arc::clone(&callback),
commands: Arc::clone(commands),
}),
Some(Arc::clone(commands.get_guard_condition())),
);
commands.add_to_wait_set(waitable);
Ok(Arc::new(Self {
handle,
callback,
lifecycle,
}))
}
}
// Callback setters that are only available when callbacks run in a `Node` scope.
impl<T: Message> SubscriptionState<T, Node> {
    /// Replaces the callback that is invoked for this subscription with a
    /// (blocking) node callback.
    pub fn set_callback<Args>(&self, callback: impl IntoNodeSubscriptionCallback<T, Args>) {
        let replacement = callback.into_node_subscription_callback();
        let mut slot = self.callback.lock().unwrap();
        *slot = replacement;
    }

    /// Replaces the callback that is invoked for this subscription with an
    /// async callback.
    pub fn set_async_callback<Args>(&self, callback: impl IntoAsyncSubscriptionCallback<T, Args>) {
        let replacement = callback.into_async_subscription_callback();
        let mut slot = self.callback.lock().unwrap();
        *slot = replacement;
    }
}
// Callback setter that is only available when callbacks run in a `Worker` scope.
impl<T: Message, Payload: 'static + Send + Sync> SubscriptionState<T, Worker<Payload>> {
    /// Replaces the callback that is invoked for this subscription with a
    /// worker callback.
    pub fn set_worker_callback<Args>(
        &self,
        callback: impl IntoWorkerSubscriptionCallback<T, Payload, Args>,
    ) {
        let replacement = callback.into_worker_subscription_callback();
        let mut slot = self.callback.lock().unwrap();
        *slot = replacement;
    }
}
/// Options used when creating a subscription.
///
/// Usually this is built implicitly from anything implementing
/// [`IntoPrimitiveOptions`], e.g. `"topic_name".best_effort()`. The struct is
/// `#[non_exhaustive]`, so code outside this crate must use
/// [`SubscriptionOptions::new`] or a conversion instead of a struct literal.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SubscriptionOptions<'a> {
    /// Name of the topic to subscribe to.
    pub topic: &'a str,
    /// QoS profile for the subscription. `SubscriptionOptions::new` initializes
    /// it to `QoSProfile::topics_default()`.
    pub qos: QoSProfile,
}
impl<'a> SubscriptionOptions<'a> {
    /// Creates options for `topic` with the QoS set to `QoSProfile::topics_default()`.
    pub fn new(topic: &'a str) -> Self {
        let qos = QoSProfile::topics_default();
        Self { topic, qos }
    }
}
impl<'a, T: IntoPrimitiveOptions<'a>> From<T> for SubscriptionOptions<'a> {
fn from(value: T) -> Self {
let primitive = value.into_primitive_options();
let mut options = Self::new(primitive.name);
primitive.apply_to(&mut options.qos);
options
}
}
/// The entry that `SubscriptionState::create` boxes and registers in the wait set.
/// Executing it runs the user callback (see its `RclPrimitive` impl).
struct SubscriptionExecutable<T: Message, Payload> {
    /// Shared with `SubscriptionState`. It is handed to the callback, presumably
    /// so the callback can take messages from the subscription.
    handle: Arc<SubscriptionHandle>,
    /// Shared with `SubscriptionState` so the callback can be replaced at runtime.
    callback: Arc<Mutex<AnySubscriptionCallback<T, Payload>>>,
    /// Passed through to the callback on every execution.
    commands: Arc<WorkerCommands>,
}
impl<T, Payload: 'static> RclPrimitive for SubscriptionExecutable<T, Payload>
where
T: Message,
{
unsafe fn execute(
&mut self,
ready: ReadyKind,
payload: &mut dyn Any,
) -> Result<(), RclrsError> {
ready.for_basic()?;
self.callback
.lock()
.unwrap()
.execute(&self.handle, payload, &self.commands)
}
fn kind(&self) -> RclPrimitiveKind {
RclPrimitiveKind::Subscription
}
fn handle(&self) -> RclPrimitiveHandle<'_> {
RclPrimitiveHandle::Subscription(self.handle.lock())
}
}
// SAFETY(review): marking `rcl_subscription_t` as `Send` lets
// `Mutex<rcl_subscription_t>` (below) be shared between threads. In this file,
// every access goes through that mutex (`SubscriptionHandle::lock`) or through
// `&mut self` in `Drop`. That moving the raw rcl handle between threads is sound
// is an assumption; confirm against the rcl threading documentation.
unsafe impl Send for rcl_subscription_t {}
/// Owns an rcl subscription and finalizes it when dropped.
///
/// In this file it is only constructed after `rcl_subscription_init` succeeds.
/// The node handle is kept alive because `rcl_subscription_fini` needs the node.
pub(crate) struct SubscriptionHandle {
    pub(crate) rcl_subscription: Mutex<rcl_subscription_t>,
    pub(crate) node_handle: Arc<NodeHandle>,
}
impl SubscriptionHandle {
pub(crate) fn lock(&self) -> MutexGuard<'_, rcl_subscription_t> {
self.rcl_subscription.lock().unwrap()
}
pub(crate) fn topic_name(&self) -> String {
unsafe {
let raw_topic_pointer = rcl_subscription_get_topic_name(&*self.lock());
CStr::from_ptr(raw_topic_pointer)
}
.to_string_lossy()
.into_owned()
}
fn take<T: Message>(&self) -> Result<(T, MessageInfo), RclrsError> {
let mut rmw_message = <T as Message>::RmwMsg::default();
let message_info = Self::take_inner::<T>(self, &mut rmw_message)?;
Ok((T::from_rmw_message(rmw_message), message_info))
}
fn take_boxed<T: Message>(&self) -> Result<(Box<T>, MessageInfo), RclrsError> {
let mut rmw_message = Box::<<T as Message>::RmwMsg>::default();
let message_info = Self::take_inner::<T>(self, &mut *rmw_message)?;
let message = Box::new(T::from_rmw_message(*rmw_message));
Ok((message, message_info))
}
fn take_inner<T: Message>(
&self,
rmw_message: &mut <T as Message>::RmwMsg,
) -> Result<MessageInfo, RclrsError> {
let mut message_info = unsafe { rmw_get_zero_initialized_message_info() };
let rcl_subscription = &mut *self.lock();
unsafe {
rcl_take(
rcl_subscription,
rmw_message as *mut <T as Message>::RmwMsg as *mut _,
&mut message_info,
std::ptr::null_mut(),
)
.ok()?
};
Ok(MessageInfo::from_rmw_message_info(&message_info))
}
fn take_loaned<T: Message>(
self: &Arc<Self>,
) -> Result<(ReadOnlyLoanedMessage<T>, MessageInfo), RclrsError> {
let mut msg_ptr = std::ptr::null_mut();
let mut message_info = unsafe { rmw_get_zero_initialized_message_info() };
unsafe {
rcl_take_loaned_message(
&*self.lock(),
&mut msg_ptr,
&mut message_info,
std::ptr::null_mut(),
)
.ok()?;
}
let read_only_loaned_msg = ReadOnlyLoanedMessage {
msg_ptr: msg_ptr as *const T::RmwMsg,
handle: Arc::clone(self),
};
Ok((
read_only_loaned_msg,
MessageInfo::from_rmw_message_info(&message_info),
))
}
}
impl Drop for SubscriptionHandle {
    /// Finalizes the rcl subscription.
    ///
    /// Locks are taken in the same order as in `SubscriptionState::create`: the
    /// node mutex first, then `ENTITY_LIFECYCLE_MUTEX`. Keep this order;
    /// inconsistent lock ordering can deadlock.
    fn drop(&mut self) {
        // `get_mut` needs no locking because `&mut self` guarantees exclusive access.
        let rcl_subscription = self.rcl_subscription.get_mut().unwrap();
        let mut rcl_node = self.node_handle.rcl_node.lock().unwrap();
        let _lifecycle_lock = ENTITY_LIFECYCLE_MUTEX.lock().unwrap();
        // NOTE(review): the return code of `rcl_subscription_fini` is ignored, so a
        // failed finalization is not reported anywhere.
        unsafe {
            rcl_subscription_fini(rcl_subscription, &mut *rcl_node);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{test_helpers::*, vendor::test_msgs::msg};

    /// Subscription state must be shareable across threads.
    #[test]
    fn traits() {
        assert_send::<SubscriptionState<msg::BoundedSequences, Node>>();
        assert_sync::<SubscriptionState<msg::BoundedSequences, Node>>();
    }

    /// Checks that created subscriptions show up in the graph introspection APIs.
    #[test]
    fn test_subscriptions() -> Result<(), RclrsError> {
        use crate::TopicEndpointInfo;

        let namespace = "/test_subscriptions_graph";
        let graph = construct_test_graph(namespace)?;

        let node_2_empty_subscription = graph
            .node2
            .create_subscription::<msg::Empty, _>("graph_test_topic_1", |_msg: msg::Empty| {})?;
        let topic1 = node_2_empty_subscription.topic_name();
        let node_2_basic_types_subscription =
            graph.node2.create_subscription::<msg::BasicTypes, _>(
                "graph_test_topic_2",
                |_msg: msg::BasicTypes| {},
            )?;
        let topic2 = node_2_basic_types_subscription.topic_name();

        let node_1_defaults_subscription = graph.node1.create_subscription::<msg::Defaults, _>(
            "graph_test_topic_3",
            |_msg: msg::Defaults| {},
        )?;
        let topic3 = node_1_defaults_subscription.topic_name();

        // Presumably gives the ROS graph time to register the new subscriptions.
        std::thread::sleep(std::time::Duration::from_millis(100));

        assert_eq!(graph.node2.count_subscriptions(&topic1)?, 1);
        assert_eq!(graph.node2.count_subscriptions(&topic2)?, 1);

        let node_1_subscription_names_and_types = graph
            .node1
            .get_subscription_names_and_types_by_node(&graph.node1.name(), namespace)?;
        let types = node_1_subscription_names_and_types.get(&topic3).unwrap();
        assert!(types.contains(&"test_msgs/msg/Defaults".to_string()));

        let node_2_subscription_names_and_types = graph
            .node2
            .get_subscription_names_and_types_by_node(&graph.node2.name(), namespace)?;
        let types = node_2_subscription_names_and_types.get(&topic1).unwrap();
        assert!(types.contains(&"test_msgs/msg/Empty".to_string()));
        let types = node_2_subscription_names_and_types.get(&topic2).unwrap();
        assert!(types.contains(&"test_msgs/msg/BasicTypes".to_string()));

        let expected_subscriptions_info = vec![TopicEndpointInfo {
            node_name: String::from("graph_test_node_2"),
            node_namespace: String::from(namespace),
            topic_type: String::from("test_msgs/msg/Empty"),
        }];
        // Both nodes should see the same endpoint info for topic1.
        assert_eq!(
            graph.node1.get_subscriptions_info_by_topic(&topic1)?,
            expected_subscriptions_info
        );
        assert_eq!(
            graph.node2.get_subscriptions_info_by_topic(&topic1)?,
            expected_subscriptions_info
        );
        Ok(())
    }

    /// Checks that a subscription keeps receiving messages after the local `Node`
    /// binding that created it has gone out of scope, as long as the subscription
    /// itself is kept alive.
    #[test]
    fn test_node_subscription_raii() {
        use crate::*;
        use std::sync::atomic::{AtomicBool, Ordering};

        let mut executor = Context::default().create_basic_executor();

        let triggered = Arc::new(AtomicBool::new(false));
        let inner_triggered = Arc::clone(&triggered);
        let callback = move |_: msg::Empty| {
            inner_triggered.store(true, Ordering::Release);
        };

        let (_subscription, publisher) = {
            let node = executor
                .create_node(&format!("test_node_subscription_raii_{}", line!()))
                .unwrap();

            let qos = QoSProfile::default().keep_all().reliable();
            let subscription = node
                .create_subscription::<msg::Empty, _>("test_topic".qos(qos), callback)
                .unwrap();
            let publisher = node
                .create_publisher::<msg::Empty>("test_topic".qos(qos))
                .unwrap();
            (subscription, publisher)
        };

        publisher.publish(msg::Empty::default()).unwrap();
        let start_time = std::time::Instant::now();
        // Spin until the callback fires, failing after 10 seconds.
        while !triggered.load(Ordering::Acquire) {
            assert!(executor.spin(SpinOptions::spin_once()).is_empty());
            assert!(start_time.elapsed() < std::time::Duration::from_secs(10));
        }
    }

    /// Creates a subscription from a background thread via `commands.run`
    /// (likely after the executor has started spinning) and checks that it
    /// still receives a published message.
    #[test]
    fn test_delayed_subscription() {
        use crate::{vendor::example_interfaces::msg::Empty, *};
        use futures::{
            channel::{mpsc, oneshot},
            StreamExt,
        };
        use std::sync::atomic::{AtomicBool, Ordering};

        let mut executor = Context::default().create_basic_executor();
        let node = executor
            .create_node(
                format!("test_delayed_subscription_{}", line!())
                    .start_parameter_services(false),
            )
            .unwrap();

        let (promise, receiver) = oneshot::channel();
        let promise = Arc::new(Mutex::new(Some(promise)));
        let success = Arc::new(AtomicBool::new(false));
        let send_success = Arc::clone(&success);

        let publisher = node.create_publisher("test_delayed_subscription").unwrap();

        let commands = Arc::clone(executor.commands());
        std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            let _ = commands.run(async move {
                let (sender, mut receiver) = mpsc::unbounded();
                let _subscription = node
                    .create_subscription("test_delayed_subscription", move |_: Empty| {
                        let _ = sender.unbounded_send(());
                    })
                    .unwrap();

                let _ = publisher.notify_on_subscriber_ready().await;
                publisher.publish(Empty::default()).unwrap();

                if receiver.next().await.is_some() {
                    send_success.store(true, Ordering::Release);
                    if let Some(promise) = promise.lock().unwrap().take() {
                        promise.send(()).unwrap();
                    }
                }
            });
        });

        let r = executor.spin(
            SpinOptions::default()
                .until_promise_resolved(receiver)
                .timeout(std::time::Duration::from_secs(10)),
        );

        assert!(r.is_empty(), "{r:?}");
        let message_was_received = success.load(Ordering::Acquire);
        assert!(message_was_received);
    }

    /// Checks that QoS reliability requested via each options style is reflected
    /// by `SubscriptionState::qos`.
    #[test]
    fn test_subscription_qos_settings() {
        use crate::vendor::example_interfaces::msg::Empty;
        use crate::*;

        let executor = Context::default().create_basic_executor();
        let node = executor
            .create_node(&format!("test_subscription_qos_settings_{}", line!()))
            .unwrap();

        // Shorthand builder on a topic name.
        let subscription = node
            .create_subscription("test_subscription_qos_topic".best_effort(), |_: Empty| {})
            .unwrap();
        let qos = subscription.qos();
        assert_eq!(qos.reliability, QoSReliabilityPolicy::BestEffort);

        let expected_qos = QoSProfile::topics_default().best_effort();
        assert_eq!(expected_qos.reliability, QoSReliabilityPolicy::BestEffort);

        // Full QoS profile attached to a topic name.
        let subscription = node
            .create_subscription(
                "test_subscription_qos_topic_2".qos(expected_qos),
                |_: Empty| {},
            )
            .unwrap();
        let qos = subscription.qos();
        assert_eq!(expected_qos.reliability, qos.reliability);
        assert_eq!(qos.reliability, QoSReliabilityPolicy::BestEffort);

        // Explicit `SubscriptionOptions` struct literal.
        let subscription = node
            .create_subscription(
                SubscriptionOptions {
                    topic: "test_subscription_qos_topic_3",
                    qos: expected_qos,
                },
                |_: Empty| {},
            )
            .unwrap();
        let qos = subscription.qos();
        assert_eq!(expected_qos.reliability, qos.reliability);
        assert_eq!(qos.reliability, QoSReliabilityPolicy::BestEffort);
    }

    /// Checks that a QoS reliability read from a node parameter (set on the
    /// command line) is applied to the subscription.
    #[test]
    fn test_setting_qos_from_parameters() {
        use crate::vendor::example_interfaces::msg::Empty;
        use crate::*;

        let args = ["--ros-args", "-p", "qos_reliability:=best_effort"].map(ToString::to_string);
        let context = Context::new(args, InitOptions::default()).unwrap();
        let executor = context.create_basic_executor();
        let node = executor
            .create_node(&format!("test_setting_qos_from_parameters_{}", line!()))
            .unwrap();

        let qos_reliability_str = node
            .declare_parameter::<Arc<str>>("qos_reliability")
            .default("best_effort".into())
            .mandatory()
            .unwrap()
            .get();

        let mut expected_qos = QOS_PROFILE_DEFAULT;
        expected_qos.reliability = match &*qos_reliability_str {
            "reliable" => QoSReliabilityPolicy::Reliable,
            "best_effort" => QoSReliabilityPolicy::BestEffort,
            #[cfg(not(ros_distro = "humble"))]
            "best_available" => QoSReliabilityPolicy::BestAvailable,
            x => panic!("unknown reliability string: {x}"),
        };
        assert_eq!(expected_qos.reliability, QoSReliabilityPolicy::BestEffort);

        let subscription = node
            .create_subscription(
                "test_setting_qos_from_parameters_topic".qos(expected_qos),
                |_: Empty| {},
            )
            .unwrap();
        let qos = subscription.qos();
        assert_eq!(expected_qos.reliability, qos.reliability);
        assert_eq!(qos.reliability, QoSReliabilityPolicy::BestEffort);
    }
}