1extern crate alloc;
29
30use alloc::string::String;
31use alloc::sync::Arc;
32use alloc::vec::Vec;
33use core::marker::PhantomData;
34use core::time::Duration;
35use std::collections::HashMap;
36use std::sync::{Mutex, mpsc};
37
38use zerodds_dcps::dds_type::{DdsType, RawBytes};
39use zerodds_dcps::participant::DomainParticipant;
40use zerodds_dcps::publisher::DataWriter;
41use zerodds_dcps::qos::{PublisherQos, SubscriberQos, TopicQos};
42use zerodds_dcps::subscriber::DataReader;
43
44use crate::common_types::{RemoteExceptionCode, RequestHeader, SampleIdentity};
45use crate::error::{RpcError, RpcResult};
46use crate::qos_profile::RpcQos;
47use crate::topic_naming::ServiceTopicNames;
48use crate::wire_codec::{decode_reply_frame, encode_request_frame};
49
/// Which side of a service an instance-name claim belongs to.
///
/// The role is part of [`InstanceKey`], so a requester and a replier can use
/// the same instance name on the same participant without colliding.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub(crate) enum InstanceRole {
    /// The requesting side of a service.
    Requester,
    /// The replying side of a service.
    Replier,
}

/// Registry key: `(participant address, role, service name, instance name)`.
type InstanceKey = (usize, InstanceRole, String, String);
65
66fn instance_registry() -> &'static Mutex<std::collections::HashSet<InstanceKey>> {
67 use std::sync::OnceLock;
68 static REGISTRY: OnceLock<Mutex<std::collections::HashSet<InstanceKey>>> = OnceLock::new();
69 REGISTRY.get_or_init(|| Mutex::new(std::collections::HashSet::new()))
70}
71
72fn participant_addr(p: &DomainParticipant) -> usize {
73 core::ptr::from_ref(p) as usize
76}
77
78pub(crate) fn try_claim_instance(
79 p: &DomainParticipant,
80 role: InstanceRole,
81 service_name: &str,
82 instance_name: &str,
83) -> RpcResult<InstanceClaim> {
84 if instance_name.is_empty() {
85 return Ok(InstanceClaim::anonymous());
88 }
89 let key: InstanceKey = (
90 participant_addr(p),
91 role,
92 service_name.into(),
93 instance_name.into(),
94 );
95 let mut reg = instance_registry()
96 .lock()
97 .map_err(|_| RpcError::Dcps("instance-registry poisoned".into()))?;
98 if !reg.insert(key.clone()) {
99 return Err(RpcError::DuplicateInstanceName(instance_name.into()));
100 }
101 Ok(InstanceClaim::owned(key))
102}
103
104#[derive(Debug)]
108pub(crate) struct InstanceClaim {
109 key: Option<InstanceKey>,
110}
111
112impl InstanceClaim {
113 fn anonymous() -> Self {
114 Self { key: None }
115 }
116 fn owned(key: InstanceKey) -> Self {
117 Self { key: Some(key) }
118 }
119}
120
121impl Drop for InstanceClaim {
122 fn drop(&mut self) {
123 if let Some(key) = self.key.take() {
124 if let Ok(mut reg) = instance_registry().lock() {
125 reg.remove(&key);
126 }
127 }
128 }
129}
130
131pub type ReplyOutcome = Result<Vec<u8>, RemoteExceptionCode>;
139
140struct PendingSlot {
142 sender: mpsc::Sender<ReplyOutcome>,
143}
144
145pub struct Requester<TIn: DdsType, TOut: DdsType> {
152 service_name: String,
153 instance_name: String,
154 request_writer: DataWriter<RawBytes>,
155 reply_reader: DataReader<RawBytes>,
156 writer_guid: [u8; 16],
162 next_seq: Mutex<u64>,
163 pending: Arc<Mutex<HashMap<SampleIdentity, PendingSlot>>>,
164 qos: RpcQos,
165 _claim: InstanceClaim,
166 _phantom: PhantomData<fn() -> (TIn, TOut)>,
167}
168
169impl<TIn: DdsType, TOut: DdsType> core::fmt::Debug for Requester<TIn, TOut> {
170 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
171 f.debug_struct("Requester")
172 .field("service", &self.service_name)
173 .field("instance", &self.instance_name)
174 .finish_non_exhaustive()
175 }
176}
177
178impl<TIn: DdsType + Send + 'static, TOut: DdsType + Send + 'static> Requester<TIn, TOut> {
179 pub fn new(
193 participant: &DomainParticipant,
194 service_name: &str,
195 qos: &RpcQos,
196 ) -> RpcResult<Self> {
197 Self::with_instance(participant, service_name, "", qos)
198 }
199
200 pub fn with_instance(
206 participant: &DomainParticipant,
207 service_name: &str,
208 instance_name: &str,
209 qos: &RpcQos,
210 ) -> RpcResult<Self> {
211 let topics = ServiceTopicNames::new(service_name)?;
212 let claim = try_claim_instance(
213 participant,
214 InstanceRole::Requester,
215 service_name,
216 instance_name,
217 )?;
218 let request_topic = participant
219 .create_topic::<RawBytes>(&topics.request, TopicQos::default())
220 .map_err(|e| RpcError::Dcps(alloc::format!("create_topic request: {e:?}")))?;
221 let reply_topic = participant
222 .create_topic::<RawBytes>(&topics.reply, TopicQos::default())
223 .map_err(|e| RpcError::Dcps(alloc::format!("create_topic reply: {e:?}")))?;
224 let publisher = participant.create_publisher(PublisherQos::default());
225 let subscriber = participant.create_subscriber(SubscriberQos::default());
226 let request_writer = publisher
227 .create_datawriter::<RawBytes>(&request_topic, qos.request_writer_qos())
228 .map_err(|e| RpcError::Dcps(alloc::format!("create_datawriter: {e:?}")))?;
229 let reply_reader = subscriber
230 .create_datareader::<RawBytes>(&reply_topic, qos.reply_reader_qos())
231 .map_err(|e| RpcError::Dcps(alloc::format!("create_datareader: {e:?}")))?;
232 let writer_guid = synthesize_writer_guid();
234 Ok(Self {
235 service_name: service_name.into(),
236 instance_name: instance_name.into(),
237 request_writer,
238 reply_reader,
239 writer_guid,
240 next_seq: Mutex::new(1),
241 pending: Arc::new(Mutex::new(HashMap::new())),
242 qos: qos.clone(),
243 _claim: claim,
244 _phantom: PhantomData,
245 })
246 }
247
248 #[must_use]
250 pub fn service_name(&self) -> &str {
251 &self.service_name
252 }
253
254 #[must_use]
256 pub fn instance_name(&self) -> &str {
257 &self.instance_name
258 }
259
260 #[must_use]
262 pub fn pending_count(&self) -> usize {
263 self.pending.lock().map(|m| m.len()).unwrap_or(0)
264 }
265
266 #[must_use]
268 pub fn default_timeout(&self) -> Duration {
269 self.qos.request_timeout
270 }
271
272 pub fn send_request_async(
280 &self,
281 payload: &TIn,
282 ) -> RpcResult<(SampleIdentity, mpsc::Receiver<ReplyOutcome>)> {
283 let id = self.next_request_id()?;
284 let header = RequestHeader::new(id, self.instance_name.clone());
285 let mut user_buf = Vec::new();
286 payload
287 .encode(&mut user_buf)
288 .map_err(|e| RpcError::Dcps(alloc::format!("encode TIn: {e}")))?;
289 let frame = encode_request_frame(&header, &user_buf);
290 let (tx, rx) = mpsc::channel();
291 {
295 let mut pend = self
296 .pending
297 .lock()
298 .map_err(|_| RpcError::Dcps("pending-table poisoned".into()))?;
299 pend.insert(id, PendingSlot { sender: tx });
300 }
301 if let Err(e) = self.request_writer.write(&RawBytes::new(frame)) {
302 if let Ok(mut pend) = self.pending.lock() {
304 pend.remove(&id);
305 }
306 return Err(RpcError::Dcps(alloc::format!("write request: {e:?}")));
307 }
308 Ok((id, rx))
309 }
310
311 pub fn send_oneway(&self, payload: &TIn) -> RpcResult<SampleIdentity> {
317 let id = self.next_request_id()?;
318 let header = RequestHeader::new(id, self.instance_name.clone());
319 let mut user_buf = Vec::new();
320 payload
321 .encode(&mut user_buf)
322 .map_err(|e| RpcError::Dcps(alloc::format!("encode TIn: {e}")))?;
323 let frame = encode_request_frame(&header, &user_buf);
324 self.request_writer
325 .write(&RawBytes::new(frame))
326 .map_err(|e| RpcError::Dcps(alloc::format!("write oneway: {e:?}")))?;
327 Ok(id)
328 }
329
330 pub fn send_request_blocking(
341 &self,
342 payload: &TIn,
343 timeout: Option<Duration>,
344 ) -> RpcResult<TOut> {
345 let timeout = timeout.unwrap_or(self.qos.request_timeout);
346 let (_id, rx) = self.send_request_async(payload)?;
347 let deadline = std::time::Instant::now() + timeout;
348 let poll = Duration::from_millis(2);
349 loop {
350 self.tick();
352 match rx.try_recv() {
353 Ok(Ok(bytes)) => {
354 let out = TOut::decode(&bytes)
355 .map_err(|e| RpcError::Dcps(alloc::format!("decode TOut: {e}")))?;
356 return Ok(out);
357 }
358 Ok(Err(code)) => return Err(RpcError::RemoteException(code.as_u32())),
359 Err(mpsc::TryRecvError::Empty) => {}
360 Err(mpsc::TryRecvError::Disconnected) => {
361 return Err(RpcError::Dcps("reply channel disconnected".into()));
362 }
363 }
364 if std::time::Instant::now() >= deadline {
365 return Err(RpcError::Timeout);
366 }
367 std::thread::sleep(poll);
368 }
369 }
370
371 pub fn tick(&self) {
375 let samples = match self.reply_reader.take() {
376 Ok(s) => s,
377 Err(_) => return,
378 };
379 if samples.is_empty() {
380 return;
381 }
382 let mut pend = match self.pending.lock() {
383 Ok(p) => p,
384 Err(_) => return,
385 };
386 for raw in samples {
387 let bytes = raw.data;
388 let (header, payload) = match decode_reply_frame(&bytes) {
389 Ok(t) => t,
390 Err(_) => continue, };
392 let Some(slot) = pend.remove(&header.related_request_id) else {
393 continue;
396 };
397 let payload_owned = payload.to_vec();
398 let result = if header.remote_ex == RemoteExceptionCode::Ok {
399 Ok(payload_owned)
400 } else {
401 Err(header.remote_ex)
402 };
403 let _ = slot.sender.send(result);
405 }
406 }
407
408 fn next_request_id(&self) -> RpcResult<SampleIdentity> {
409 let mut g = self
410 .next_seq
411 .lock()
412 .map_err(|_| RpcError::Dcps("seq counter poisoned".into()))?;
413 let sn = *g;
414 *g = sn.checked_add(1).ok_or_else(|| {
415 RpcError::Dcps("rpc sequence-number wrapped — ran out of u64 space".into())
416 })?;
417 Ok(SampleIdentity::new(self.writer_guid, sn))
418 }
419
420 #[doc(hidden)]
424 #[must_use]
425 pub fn __drain_request_writer(&self) -> Vec<Vec<u8>> {
426 self.request_writer.__drain_pending()
427 }
428
429 #[doc(hidden)]
432 pub fn __push_reply_raw(&self, bytes: Vec<u8>) -> RpcResult<()> {
433 self.reply_reader
434 .__push_raw(bytes)
435 .map_err(|e| RpcError::Dcps(alloc::format!("push raw: {e:?}")))
436 }
437
438 #[doc(hidden)]
441 #[must_use]
442 pub fn __writer_guid(&self) -> [u8; 16] {
443 self.writer_guid
444 }
445}
446
/// Generates a 16-byte GUID for use in request `SampleIdentity`s.
///
/// Layout:
/// - bytes `0..8`: a per-process salt, computed once;
/// - bytes `8..16`: a process-wide counter (little-endian) starting at 1.
///
/// Calls within one process therefore share the salt and differ in the counter.
///
/// The salt combines a heap address, the wall clock (ns since the Unix epoch,
/// truncated to 64 bits) and the process id. These inputs are fed through a
/// `RandomState` hasher, whose SipHash keys std initializes randomly. Hashing
/// mixes the inputs and adds real randomness, instead of relying on an
/// unmixed XOR of correlated, partly predictable values.
fn synthesize_writer_guid() -> [u8; 16] {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::OnceLock;

    static SALT: OnceLock<[u8; 8]> = OnceLock::new();
    static CTR: AtomicU64 = AtomicU64::new(1);

    let salt = *SALT.get_or_init(|| {
        // A fresh heap allocation's address; it may differ between runs under ASLR.
        let probe = Box::new(0u8);
        let addr = (&*probe as *const u8) as usize as u64;
        drop(probe);
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64) // truncation is fine: only used as entropy
            .unwrap_or(0xCAFE_BABE_DEAD_BEEF);
        let pid = u64::from(std::process::id());

        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u64(addr);
        hasher.write_u64(now);
        hasher.write_u64(pid);
        hasher.finish().to_le_bytes()
    });
    // Relaxed suffices: uniqueness only needs the atomicity of fetch_add.
    let counter = CTR.fetch_add(1, Ordering::Relaxed);

    let mut out = [0u8; 16];
    out[..8].copy_from_slice(&salt);
    out[8..].copy_from_slice(&counter.to_le_bytes());
    out
}
480
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::common_types::ReplyHeader;
    use zerodds_dcps::factory::DomainParticipantFactory;
    use zerodds_dcps::qos::DomainParticipantQos;

    /// Builds a participant via `create_participant_offline`. Each test uses
    /// its own domain id.
    fn participant(domain: i32) -> DomainParticipant {
        DomainParticipantFactory::instance()
            .create_participant_offline(domain, DomainParticipantQos::default())
    }

    /// Reads the little-endian counter half (bytes `8..16`) of a synthesized GUID.
    fn guid_counter(guid: &[u8; 16]) -> u64 {
        let mut le = [0u8; 8];
        le.copy_from_slice(&guid[8..]);
        u64::from_le_bytes(le)
    }

    #[test]
    fn synthesize_writer_guid_is_unique_per_call() {
        let a = synthesize_writer_guid();
        let b = synthesize_writer_guid();
        assert_ne!(a, b);
        // Same process: same salt, and only the counter half differs.
        assert_eq!(&a[..8], &b[..8]);
        assert_ne!(&a[8..], &b[8..]);
        // The counter is a process-wide atomic, so a later call in this thread
        // sees a larger value even if other tests synthesize GUIDs concurrently.
        assert!(guid_counter(&b) > guid_counter(&a));
    }

    #[test]
    fn requester_new_creates_topics_and_writer() {
        let p = participant(101);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        assert_eq!(r.service_name(), "Calc");
        assert_eq!(r.instance_name(), "");
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn requester_invalid_service_name_rejected() {
        let p = participant(102);
        let q = RpcQos::default_basic();
        let err = Requester::<RawBytes, RawBytes>::new(&p, "", &q).unwrap_err();
        assert!(matches!(err, RpcError::InvalidServiceName(_)));
    }

    #[test]
    fn requester_uses_qos_default_timeout() {
        let p = participant(103);
        let mut q = RpcQos::default_basic();
        q.request_timeout = Duration::from_millis(7);
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        assert_eq!(r.default_timeout(), Duration::from_millis(7));
    }

    #[test]
    fn send_request_async_assigns_unique_sample_ids() {
        let p = participant(104);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let payload = RawBytes::new(alloc::vec![1, 2, 3]);
        let (id1, _rx1) = r.send_request_async(&payload).unwrap();
        let (id2, _rx2) = r.send_request_async(&payload).unwrap();
        assert_ne!(id1.sequence_number, id2.sequence_number);
        assert_eq!(id1.writer_guid, id2.writer_guid);
        assert_eq!(r.pending_count(), 2);
    }

    #[test]
    fn send_request_async_increments_seq_monotonically() {
        let p = participant(105);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let payload = RawBytes::new(alloc::vec![]);
        let (id1, _rx1) = r.send_request_async(&payload).unwrap();
        let (id2, _rx2) = r.send_request_async(&payload).unwrap();
        let (id3, _rx3) = r.send_request_async(&payload).unwrap();
        assert_eq!(id1.sequence_number + 1, id2.sequence_number);
        assert_eq!(id2.sequence_number + 1, id3.sequence_number);
    }

    #[test]
    fn send_oneway_does_not_register_pending_slot() {
        let p = participant(106);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let payload = RawBytes::new(alloc::vec![9]);
        let id = r.send_oneway(&payload).unwrap();
        assert!(id.sequence_number > 0);
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn send_request_blocking_times_out_when_no_reply() {
        let p = participant(107);
        let mut q = RpcQos::default_basic();
        q.request_timeout = Duration::from_millis(20);
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let err = r
            .send_request_blocking(&RawBytes::new(alloc::vec![1]), None)
            .unwrap_err();
        assert!(matches!(err, RpcError::Timeout));
    }

    #[test]
    fn send_request_blocking_returns_decoded_reply() {
        let p = participant(116);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        // A fresh requester issues sequence number 1 first, so the reply can be
        // queued before the blocking call sends the request.
        let expected_id = SampleIdentity::new(r.__writer_guid(), 1);
        let frame = crate::wire_codec::encode_reply_frame(
            &ReplyHeader::new(expected_id, RemoteExceptionCode::Ok),
            &[4u8, 2],
        );
        r.__push_reply_raw(frame).unwrap();
        let out = r
            .send_request_blocking(&RawBytes::new(alloc::vec![1]), Some(Duration::from_secs(1)))
            .unwrap();
        assert_eq!(out.data, alloc::vec![4u8, 2]);
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn duplicate_instance_name_rejected_on_same_participant() {
        let p = participant(108);
        let q = RpcQos::default_basic();
        let _r1 = Requester::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-A", &q).unwrap();
        let err =
            Requester::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-A", &q).unwrap_err();
        assert!(matches!(err, RpcError::DuplicateInstanceName(ref n) if n == "calc-A"));
    }

    #[test]
    fn duplicate_instance_name_freed_after_drop() {
        let p = participant(109);
        let q = RpcQos::default_basic();
        {
            let _r1 =
                Requester::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-X", &q).unwrap();
        }
        // The first requester's claim was released when it went out of scope.
        let _r2 = Requester::<RawBytes, RawBytes>::with_instance(&p, "Calc", "calc-X", &q).unwrap();
    }

    #[test]
    fn anonymous_instance_name_allows_multiple_requesters() {
        let p = participant(110);
        let q = RpcQos::default_basic();
        let _r1 = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let _r2 = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
    }

    #[test]
    fn tick_correlates_reply_with_pending_slot() {
        let p = participant(111);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let (id, rx) = r
            .send_request_async(&RawBytes::new(alloc::vec![1]))
            .unwrap();
        let reply_header = ReplyHeader::new(id, RemoteExceptionCode::Ok);
        let frame = crate::wire_codec::encode_reply_frame(&reply_header, &[7u8, 8, 9]);
        r.__push_reply_raw(frame).unwrap();
        r.tick();
        let result = rx.try_recv().expect("reply expected after tick");
        let bytes = result.expect("ok reply");
        assert_eq!(bytes, alloc::vec![7u8, 8, 9]);
        assert_eq!(r.pending_count(), 0);
    }

    #[test]
    fn tick_drops_reply_without_matching_request() {
        let p = participant(112);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        // An outstanding request that the unrelated reply must not satisfy.
        let (_id, rx) = r.send_request_async(&RawBytes::new(alloc::vec![])).unwrap();
        let bogus = SampleIdentity::new([0xFF; 16], 999);
        let frame = crate::wire_codec::encode_reply_frame(
            &ReplyHeader::new(bogus, RemoteExceptionCode::Ok),
            &[],
        );
        r.__push_reply_raw(frame).unwrap();
        r.tick();
        assert_eq!(r.pending_count(), 1);
        assert!(matches!(rx.try_recv(), Err(mpsc::TryRecvError::Empty)));
    }

    #[test]
    fn tick_propagates_remote_exception_to_caller() {
        let p = participant(113);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let (id, rx) = r.send_request_async(&RawBytes::new(alloc::vec![])).unwrap();
        let frame = crate::wire_codec::encode_reply_frame(
            &ReplyHeader::new(id, RemoteExceptionCode::InvalidArgument),
            &[],
        );
        r.__push_reply_raw(frame).unwrap();
        r.tick();
        let res = rx.try_recv().expect("reply expected");
        assert_eq!(res, Err(RemoteExceptionCode::InvalidArgument));
    }

    #[test]
    fn tick_handles_malformed_reply_silently() {
        let p = participant(114);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let (_id, _rx) = r.send_request_async(&RawBytes::new(alloc::vec![])).unwrap();
        r.__push_reply_raw(alloc::vec![0u8; 4]).unwrap();
        r.tick();
        // The undecodable frame is skipped; the request stays pending.
        assert_eq!(r.pending_count(), 1);
    }

    #[test]
    fn drain_request_writer_yields_encoded_frames() {
        let p = participant(115);
        let q = RpcQos::default_basic();
        let r = Requester::<RawBytes, RawBytes>::new(&p, "Calc", &q).unwrap();
        let _ = r
            .send_oneway(&RawBytes::new(alloc::vec![0xDE, 0xAD]))
            .unwrap();
        let frames = r.__drain_request_writer();
        assert_eq!(frames.len(), 1);
        let (header, payload) = crate::wire_codec::decode_request_frame(&frames[0]).unwrap();
        assert_eq!(payload, &[0xDE, 0xAD]);
        assert_eq!(header.instance_name, "");
    }
}