extern crate alloc;
use alloc::vec::Vec;
use core::time::Duration;
use zerodds_rtps::error::WireError;
use zerodds_rtps::fragment_assembler::AssemblerCaps;
use zerodds_rtps::history_cache::HistoryKind;
use zerodds_rtps::message_builder::{DEFAULT_MTU, OutboundDatagram};
use zerodds_rtps::reader_proxy::ReaderProxy;
use zerodds_rtps::reliable_reader::{
DEFAULT_HEARTBEAT_RESPONSE_DELAY, ReliableReader, ReliableReaderConfig,
};
use zerodds_rtps::reliable_writer::{DEFAULT_FRAGMENT_SIZE, ReliableWriter, ReliableWriterConfig};
use zerodds_rtps::submessages::{
DataFragSubmessage, DataSubmessage, GapSubmessage, HeartbeatSubmessage, NackFragSubmessage,
};
use zerodds_rtps::wire_types::{EntityId, Guid, GuidPrefix, SequenceNumber, VendorId};
use zerodds_rtps::writer_proxy::WriterProxy;
use zerodds_security::error::{SecurityError, SecurityErrorKind, SecurityResult};
use zerodds_security::generic_message::ParticipantGenericMessage;
use crate::security::codec::{decode_generic_message, encode_generic_message};
/// History depth of the volatile-secure writer: used both as `max_samples` and as the
/// `HistoryKind::KeepLast` depth in the writer's `ReliableWriterConfig`.
pub const VOLATILE_SECURE_DEFAULT_DEPTH: usize = 16;
/// Heartbeat period (250 ms) configured on the volatile-secure writer.
pub const VOLATILE_SECURE_HEARTBEAT_PERIOD: Duration = Duration::from_millis(250);
/// Value passed as `max_samples_per_proxy` when the volatile-secure reader is created.
pub const VOLATILE_SECURE_READER_CAPACITY: usize = 64;
/// Writer endpoint that publishes [`ParticipantGenericMessage`]s through a wrapped
/// [`ReliableWriter`] bound to
/// `EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER`.
#[derive(Debug)]
pub struct VolatileSecureMessageWriter {
/// Underlying reliable RTPS writer; every operation of this type delegates to it.
inner: ReliableWriter,
}
impl VolatileSecureMessageWriter {
    /// Builds the writer endpoint for the participant identified by `participant_prefix`.
    ///
    /// The endpoint GUID combines the prefix with
    /// `EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER`. The wrapped
    /// [`ReliableWriter`] starts without reader proxies and is configured with a
    /// `KeepLast` history of [`VOLATILE_SECURE_DEFAULT_DEPTH`],
    /// [`VOLATILE_SECURE_HEARTBEAT_PERIOD`], `DEFAULT_FRAGMENT_SIZE` and `DEFAULT_MTU`.
    #[must_use]
    pub fn new(participant_prefix: GuidPrefix, vendor_id: VendorId) -> Self {
        let config = ReliableWriterConfig {
            guid: Guid::new(
                participant_prefix,
                EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER,
            ),
            vendor_id,
            reader_proxies: Vec::new(),
            max_samples: VOLATILE_SECURE_DEFAULT_DEPTH,
            history_kind: HistoryKind::KeepLast {
                depth: VOLATILE_SECURE_DEFAULT_DEPTH,
            },
            heartbeat_period: VOLATILE_SECURE_HEARTBEAT_PERIOD,
            fragment_size: DEFAULT_FRAGMENT_SIZE,
            mtu: DEFAULT_MTU,
        };
        let inner = ReliableWriter::new(config);
        Self { inner }
    }

    /// GUID of this writer endpoint, as reported by the wrapped writer.
    #[must_use]
    pub fn guid(&self) -> Guid {
        self.inner.guid()
    }

    /// Number of reader proxies currently registered with the wrapped writer.
    #[must_use]
    pub fn reader_proxy_count(&self) -> usize {
        self.inner.reader_proxy_count()
    }

    /// Read-only access to the wrapped [`ReliableWriter`].
    #[must_use]
    pub fn inner(&self) -> &ReliableWriter {
        &self.inner
    }

    /// Registers `proxy` with the wrapped writer.
    pub fn add_reader_proxy(&mut self, proxy: ReaderProxy) {
        self.inner.add_reader_proxy(proxy);
    }

    /// Unregisters the reader proxy identified by `guid`, returning whatever the wrapped
    /// writer hands back for it.
    pub fn remove_reader_proxy(&mut self, guid: Guid) -> Option<ReaderProxy> {
        self.inner.remove_reader_proxy(guid)
    }

    /// Serializes `msg` with `encode_generic_message` and submits the bytes to the
    /// wrapped writer.
    ///
    /// # Errors
    ///
    /// Propagates any [`WireError`] returned by `ReliableWriter::write`.
    pub fn write(
        &mut self,
        msg: &ParticipantGenericMessage,
    ) -> Result<Vec<OutboundDatagram>, WireError> {
        self.inner.write(&encode_generic_message(msg))
    }

    /// Forwards the current time `now` to `ReliableWriter::tick` and returns its datagrams.
    ///
    /// # Errors
    ///
    /// Propagates any [`WireError`] returned by the wrapped writer.
    pub fn tick(&mut self, now: Duration) -> Result<Vec<OutboundDatagram>, WireError> {
        self.inner.tick(now)
    }

    /// Passes an ACKNACK from reader `src_guid` (its `base` and the `requested` sequence
    /// numbers) to the wrapped writer.
    pub fn handle_acknack(
        &mut self,
        src_guid: Guid,
        base: SequenceNumber,
        requested: impl IntoIterator<Item = SequenceNumber>,
    ) {
        self.inner.handle_acknack(src_guid, base, requested);
    }

    /// Passes the NACK_FRAG submessage `nf` from reader `src_guid` to the wrapped writer.
    pub fn handle_nackfrag(&mut self, src_guid: Guid, nf: &NackFragSubmessage) {
        self.inner.handle_nackfrag(src_guid, nf);
    }
}
/// Reader endpoint that receives [`ParticipantGenericMessage`]s through a wrapped
/// [`ReliableReader`] bound to
/// `EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER`.
#[derive(Debug)]
pub struct VolatileSecureMessageReader {
/// Underlying reliable RTPS reader; every operation of this type delegates to it.
inner: ReliableReader,
}
impl VolatileSecureMessageReader {
    /// Builds the reader endpoint for the participant identified by `participant_prefix`.
    ///
    /// The endpoint GUID combines the prefix with
    /// `EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER`. The wrapped
    /// [`ReliableReader`] starts without writer proxies and is configured with
    /// [`VOLATILE_SECURE_READER_CAPACITY`] samples per proxy,
    /// `DEFAULT_HEARTBEAT_RESPONSE_DELAY` and default `AssemblerCaps`.
    #[must_use]
    pub fn new(participant_prefix: GuidPrefix, vendor_id: VendorId) -> Self {
        let config = ReliableReaderConfig {
            guid: Guid::new(
                participant_prefix,
                EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER,
            ),
            vendor_id,
            writer_proxies: Vec::new(),
            max_samples_per_proxy: VOLATILE_SECURE_READER_CAPACITY,
            heartbeat_response_delay: DEFAULT_HEARTBEAT_RESPONSE_DELAY,
            assembler_caps: AssemblerCaps::default(),
        };
        let inner = ReliableReader::new(config);
        Self { inner }
    }

    /// GUID of this reader endpoint, as reported by the wrapped reader.
    #[must_use]
    pub fn guid(&self) -> Guid {
        self.inner.guid()
    }

    /// Number of writer proxies currently registered with the wrapped reader.
    #[must_use]
    pub fn writer_proxy_count(&self) -> usize {
        self.inner.writer_proxy_count()
    }

    /// Read-only access to the wrapped [`ReliableReader`].
    #[must_use]
    pub fn inner(&self) -> &ReliableReader {
        &self.inner
    }

    /// Registers `proxy` with the wrapped reader.
    pub fn add_writer_proxy(&mut self, proxy: WriterProxy) {
        self.inner.add_writer_proxy(proxy);
    }

    /// Unregisters the writer proxy identified by `guid`, returning whatever the wrapped
    /// reader hands back for it.
    pub fn remove_writer_proxy(&mut self, guid: Guid) -> Option<WriterProxy> {
        self.inner.remove_writer_proxy(guid)
    }

    /// Feeds a DATA submessage to the wrapped reader and decodes the payload of every
    /// sample it returns.
    ///
    /// # Errors
    ///
    /// Fails with the first decoding error; no messages are returned in that case.
    pub fn handle_data(
        &mut self,
        data: &DataSubmessage,
    ) -> SecurityResult<Vec<ParticipantGenericMessage>> {
        let payloads = self
            .inner
            .handle_data(data)
            .into_iter()
            .map(|sample| sample.payload);
        decode_samples(payloads)
    }

    /// Feeds a DATA_FRAG submessage (with the current time `now`) to the wrapped reader
    /// and decodes the payload of every sample it returns.
    ///
    /// # Errors
    ///
    /// Fails with the first decoding error; no messages are returned in that case.
    pub fn handle_data_frag(
        &mut self,
        df: &DataFragSubmessage,
        now: Duration,
    ) -> SecurityResult<Vec<ParticipantGenericMessage>> {
        let payloads = self
            .inner
            .handle_data_frag(df, now)
            .into_iter()
            .map(|sample| sample.payload);
        decode_samples(payloads)
    }

    /// Feeds a GAP submessage to the wrapped reader and decodes the payload of every
    /// sample it returns.
    ///
    /// # Errors
    ///
    /// Fails with the first decoding error; no messages are returned in that case.
    pub fn handle_gap(
        &mut self,
        gap: &GapSubmessage,
    ) -> SecurityResult<Vec<ParticipantGenericMessage>> {
        let payloads = self
            .inner
            .handle_gap(gap)
            .into_iter()
            .map(|sample| sample.payload);
        decode_samples(payloads)
    }

    /// Passes the HEARTBEAT submessage `hb` and the current time `now` to the wrapped reader.
    pub fn handle_heartbeat(&mut self, hb: &HeartbeatSubmessage, now: Duration) {
        self.inner.handle_heartbeat(hb, now);
    }

    /// Forwards `now` to `ReliableReader::tick_outbound` and returns its datagrams.
    ///
    /// # Errors
    ///
    /// Propagates any [`WireError`] returned by the wrapped reader.
    pub fn tick_outbound(&mut self, now: Duration) -> Result<Vec<OutboundDatagram>, WireError> {
        self.inner.tick_outbound(now)
    }
}
/// Decodes each byte payload with `decode_generic_message`, preserving input order.
///
/// # Errors
///
/// Stops at the first payload that fails to decode and returns that error; messages
/// decoded before it are discarded.
fn decode_samples<B, I>(payloads: I) -> SecurityResult<Vec<ParticipantGenericMessage>>
where
    B: AsRef<[u8]>,
    I: IntoIterator<Item = B>,
{
    payloads.into_iter().try_fold(Vec::new(), |mut decoded, bytes| {
        // `?` is kept so the error is propagated exactly as in a plain loop.
        decoded.push(decode_generic_message(bytes.as_ref())?);
        Ok(decoded)
    })
}
// Anonymous constants that only name `SecurityErrorKind` and `SecurityError`.
// NOTE(review): presumably these keep the two imports referenced outside of test builds
// so they do not trigger unused-import warnings — confirm before removing.
const _: Option<SecurityErrorKind> = None;
const _: Option<SecurityError> = None;
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;
    use zerodds_rtps::wire_types::Locator;
    use zerodds_security::generic_message::{MessageIdentity, class_id};
    use zerodds_security::token::DataHolder;

    /// GUID prefix of the participant owning the endpoints under test.
    fn local_prefix() -> GuidPrefix {
        GuidPrefix::from_bytes([1; 12])
    }

    /// GUID prefix of the simulated peer participant.
    fn remote_prefix() -> GuidPrefix {
        GuidPrefix::from_bytes([2; 12])
    }

    /// GUID of the peer's volatile-secure reader endpoint.
    fn remote_reader_guid() -> Guid {
        Guid::new(
            remote_prefix(),
            EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER,
        )
    }

    /// GUID of the peer's volatile-secure writer endpoint.
    fn remote_writer_guid() -> Guid {
        Guid::new(
            remote_prefix(),
            EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER,
        )
    }

    /// Fresh local writer endpoint.
    fn new_writer() -> VolatileSecureMessageWriter {
        VolatileSecureMessageWriter::new(local_prefix(), VendorId::ZERODDS)
    }

    /// Fresh local reader endpoint.
    fn new_reader() -> VolatileSecureMessageReader {
        VolatileSecureMessageReader::new(local_prefix(), VendorId::ZERODDS)
    }

    /// Representative crypto-token message used as the payload in the tests below.
    fn sample_msg() -> ParticipantGenericMessage {
        let message_identity = MessageIdentity {
            source_guid: [0xAA; 16],
            sequence_number: 1,
        };
        ParticipantGenericMessage {
            message_identity,
            related_message_identity: MessageIdentity::default(),
            destination_participant_key: [0xBB; 16],
            destination_endpoint_key: [0; 16],
            source_endpoint_key: [0xCC; 16],
            message_class_id: class_id::PARTICIPANT_CRYPTO_TOKENS.into(),
            message_data: alloc::vec![DataHolder::new("DDS:Crypto:AES-GCM-GMAC")],
        }
    }

    #[test]
    fn writer_has_expected_entity_id() {
        let writer = new_writer();
        assert_eq!(
            writer.guid().entity_id,
            EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER
        );
    }

    #[test]
    fn reader_has_expected_entity_id() {
        let reader = new_reader();
        assert_eq!(
            reader.guid().entity_id,
            EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER
        );
    }

    #[test]
    fn writer_starts_with_zero_proxies() {
        assert_eq!(new_writer().reader_proxy_count(), 0);
    }

    #[test]
    fn reader_starts_with_zero_proxies() {
        assert_eq!(new_reader().writer_proxy_count(), 0);
    }

    #[test]
    fn write_without_proxies_returns_empty_datagrams() {
        let mut writer = new_writer();
        let datagrams = writer.write(&sample_msg()).unwrap();
        assert!(datagrams.is_empty());
    }

    #[test]
    fn write_with_one_proxy_produces_one_datagram() {
        let mut writer = new_writer();
        writer.add_reader_proxy(ReaderProxy::new(
            remote_reader_guid(),
            alloc::vec![Locator::udp_v4([127, 0, 0, 1], 7411)],
            alloc::vec![],
            true,
        ));
        let datagrams = writer.write(&sample_msg()).unwrap();
        assert_eq!(datagrams.len(), 1);
    }

    #[test]
    fn add_remove_reader_proxy_roundtrip() {
        let mut writer = new_writer();
        let peer = remote_reader_guid();
        writer.add_reader_proxy(ReaderProxy::new(peer, alloc::vec![], alloc::vec![], true));
        assert_eq!(writer.reader_proxy_count(), 1);
        assert!(writer.remove_reader_proxy(peer).is_some());
        assert_eq!(writer.reader_proxy_count(), 0);
    }

    #[test]
    fn add_remove_writer_proxy_roundtrip() {
        let mut reader = new_reader();
        let peer = remote_writer_guid();
        reader.add_writer_proxy(WriterProxy::new(
            peer,
            alloc::vec![Locator::udp_v4([127, 0, 0, 1], 7411)],
            alloc::vec![],
            true,
        ));
        assert_eq!(reader.writer_proxy_count(), 1);
        assert!(reader.remove_writer_proxy(peer).is_some());
        assert_eq!(reader.writer_proxy_count(), 0);
    }

    #[test]
    fn reader_decodes_data_with_known_writer() {
        let mut reader = new_reader();
        reader.add_writer_proxy(WriterProxy::new(
            remote_writer_guid(),
            alloc::vec![Locator::udp_v4([127, 0, 0, 1], 7411)],
            alloc::vec![],
            true,
        ));
        let expected = sample_msg();
        let data = DataSubmessage {
            extra_flags: 0,
            reader_id: EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER,
            writer_id: EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER,
            writer_sn: SequenceNumber(1),
            inline_qos: None,
            key_flag: false,
            non_standard_flag: false,
            serialized_payload: encode_generic_message(&expected).into(),
        };
        let decoded = reader.handle_data(&data).unwrap();
        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], expected);
    }

    #[test]
    fn reader_drops_data_from_unknown_writer() {
        // No writer proxy is registered, so the well-formed sample must not be delivered.
        let mut reader = new_reader();
        let data = DataSubmessage {
            extra_flags: 0,
            reader_id: EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER,
            writer_id: EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER,
            writer_sn: SequenceNumber(1),
            inline_qos: None,
            key_flag: false,
            non_standard_flag: false,
            serialized_payload: encode_generic_message(&sample_msg()).into(),
        };
        let decoded = reader.handle_data(&data).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn reader_rejects_corrupt_payload() {
        let mut reader = new_reader();
        reader.add_writer_proxy(WriterProxy::new(
            remote_writer_guid(),
            alloc::vec![],
            alloc::vec![],
            true,
        ));
        // Four bytes that do not form a valid encoded generic message.
        let data = DataSubmessage {
            extra_flags: 0,
            reader_id: EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_READER,
            writer_id: EntityId::BUILTIN_PARTICIPANT_VOLATILE_MESSAGE_SECURE_WRITER,
            writer_sn: SequenceNumber(1),
            inline_qos: None,
            key_flag: false,
            non_standard_flag: false,
            serialized_payload: alloc::vec![0x00, 0x99, 0, 0].into(),
        };
        let failure = reader.handle_data(&data).unwrap_err();
        assert_eq!(failure.kind, SecurityErrorKind::BadArgument);
    }
}