extern crate alloc;
use alloc::sync::Arc;
use alloc::vec::Vec;
use crate::datagram::encode_data_datagram;
use crate::error::WireError;
use crate::header::RtpsHeader;
use crate::submessages::DataSubmessage;
use crate::wire_types::{EntityId, Guid, GuidPrefix, SequenceNumber, VendorId};
/// Minimal RTPS writer that turns payloads into single-DATA datagrams.
///
/// The writer holds only the state needed to stamp outgoing datagrams: its
/// own GUID, the vendor id for the header, the next sequence number and the
/// reader the samples are addressed to. It keeps no record of samples it has
/// already produced.
#[derive(Debug, Clone)]
pub struct BestEffortWriter {
/// GUID of this writer; its prefix goes into the header and its entity id
/// into the `writer_id` of every DATA submessage.
guid: Guid,
/// Vendor id placed in the header of every produced datagram.
vendor_id: VendorId,
/// Sequence number that the next call to `write` will assign (starts at 1).
next_sn: i64,
/// Entity id written into the `reader_id` of every DATA submessage.
target_reader: EntityId,
}
impl BestEffortWriter {
    /// Creates a writer whose GUID is `participant_prefix` combined with
    /// `writer_id`.
    ///
    /// Every DATA submessage the writer produces is addressed to
    /// `target_reader`. The vendor id starts out as [`VendorId::ZERODDS`] and
    /// the first sample will carry sequence number 1.
    #[must_use]
    pub fn new(
        participant_prefix: GuidPrefix,
        writer_id: EntityId,
        target_reader: EntityId,
    ) -> Self {
        let guid = Guid::new(participant_prefix, writer_id);
        Self {
            guid,
            vendor_id: VendorId::ZERODDS,
            next_sn: 1,
            target_reader,
        }
    }

    /// Replaces the vendor id used in the header of subsequent datagrams.
    pub fn set_vendor_id(&mut self, vendor: VendorId) {
        self.vendor_id = vendor;
    }

    /// Returns the GUID of this writer.
    #[must_use]
    pub fn guid(&self) -> Guid {
        self.guid
    }

    /// Returns the sequence number the next successful [`write`](Self::write)
    /// will assign, without consuming it.
    #[must_use]
    pub fn next_sequence_number(&self) -> SequenceNumber {
        let upcoming = self.next_sn;
        SequenceNumber(upcoming)
    }

    /// Wraps `payload` in one DATA submessage and returns the encoded
    /// datagram bytes.
    ///
    /// The submessage carries the current sequence number, this writer's
    /// entity id and the configured target reader; no inline QoS is attached
    /// and all optional flags are cleared.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::ValueOutOfRange`] when the sequence counter cannot
    /// be advanced past its current value, and forwards any error reported by
    /// `encode_data_datagram`. The counter is advanced before encoding, so a
    /// sequence number stays consumed even if encoding then fails.
    pub fn write(&mut self, payload: &[u8]) -> Result<Vec<u8>, WireError> {
        let overflow = WireError::ValueOutOfRange {
            message: "writer sequence number overflow",
        };

        // Reserve the current number for this sample and move the counter on;
        // on overflow we bail out before touching the counter.
        let issued = self.next_sn;
        let following = issued.checked_add(1).ok_or(overflow)?;
        self.next_sn = following;

        let submessage = DataSubmessage {
            extra_flags: 0,
            reader_id: self.target_reader,
            writer_id: self.guid.entity_id,
            writer_sn: SequenceNumber(issued),
            inline_qos: None,
            key_flag: false,
            non_standard_flag: false,
            serialized_payload: Arc::from(payload),
        };

        encode_data_datagram(
            RtpsHeader::new(self.vendor_id, self.guid.prefix),
            &[submessage],
        )
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
    use super::*;
    use crate::datagram::{ParsedSubmessage, decode_datagram};

    /// Builds the writer fixture shared by all tests in this module.
    fn writer() -> BestEffortWriter {
        let prefix = GuidPrefix::from_bytes([1; 12]);
        let writer_id = EntityId::user_writer_with_key([0x10, 0x20, 0x30]);
        let reader_id = EntityId::user_reader_with_key([0xA0, 0xB0, 0xC0]);
        BestEffortWriter::new(prefix, writer_id, reader_id)
    }

    /// A fresh writer reports 1 as its next sequence number.
    #[test]
    fn writer_starts_at_sequence_number_one() {
        let fresh = writer();
        assert_eq!(fresh.next_sequence_number(), SequenceNumber(1));
    }

    /// Each successful write moves the next sequence number forward by one.
    #[test]
    fn writer_increments_sequence_number_per_write() {
        let mut w = writer();
        let steps: [(&[u8], i64); 2] = [(b"a", 2), (b"b", 3)];
        for (payload, expected_next) in steps {
            w.write(payload).unwrap();
            assert_eq!(w.next_sequence_number(), SequenceNumber(expected_next));
        }
    }

    /// The bytes returned by `write` decode back into the DATA submessage
    /// that was sent.
    #[test]
    fn writer_produces_decodable_datagram() {
        let mut w = writer();
        let datagram = w.write(b"hello world").unwrap();
        let decoded = decode_datagram(&datagram).unwrap();
        assert_eq!(decoded.submessages.len(), 1);

        let first = &decoded.submessages[0];
        match first {
            ParsedSubmessage::Data(d) => {
                assert_eq!(d.writer_sn, SequenceNumber(1));
                assert_eq!(d.serialized_payload.as_ref(), b"hello world");
                assert_eq!(d.writer_id.entity_key, [0x10, 0x20, 0x30]);
                assert_eq!(d.reader_id.entity_key, [0xA0, 0xB0, 0xC0]);
            }
            other => panic!("expected Data, got {other:?}"),
        }
    }

    /// Without configuration the header carries the ZERODDS vendor id.
    #[test]
    fn writer_sets_header_with_zerodds_vendor() {
        let mut w = writer();
        let datagram = w.write(b"x").unwrap();
        let header = decode_datagram(&datagram).unwrap().header;
        assert_eq!(header.vendor_id, VendorId::ZERODDS);
    }

    /// `set_vendor_id` changes the vendor id seen in later datagrams.
    #[test]
    fn writer_set_vendor_id_overrides_default() {
        let custom = VendorId([0xAB, 0xCD]);
        let mut w = writer();
        w.set_vendor_id(custom);
        let datagram = w.write(b"x").unwrap();
        let header = decode_datagram(&datagram).unwrap().header;
        assert_eq!(header.vendor_id, VendorId([0xAB, 0xCD]));
    }

    /// A counter already at `i64::MAX` makes `write` fail instead of wrapping.
    #[test]
    fn writer_sn_overflow_is_error() {
        let mut w = writer();
        w.next_sn = i64::MAX;
        let outcome = w.write(b"x");
        assert!(matches!(outcome, Err(WireError::ValueOutOfRange { .. })));
    }

    /// Three consecutive writes decode to sequence numbers 1, 2 and 3.
    #[test]
    fn writer_three_writes_have_increasing_sn_in_decoded() {
        use alloc::format;
        let mut w = writer();
        let sns: Vec<SequenceNumber> = (0..3)
            .filter_map(|i| {
                let datagram = w.write(format!("msg-{i}").as_bytes()).unwrap();
                let decoded = decode_datagram(&datagram).unwrap();
                // Non-DATA submessages are skipped; the final comparison then
                // fails on the shorter list.
                match &decoded.submessages[0] {
                    ParsedSubmessage::Data(d) => Some(d.writer_sn),
                    _ => None,
                }
            })
            .collect();
        assert_eq!(
            sns,
            alloc::vec![SequenceNumber(1), SequenceNumber(2), SequenceNumber(3)]
        );
    }
}