use alloc::collections::BTreeMap;
use alloc::string::String;
use alloc::sync::Arc;
use alloc::vec::Vec;
/// Tag telling a client-side interceptor which interception point it is being
/// invoked for. It is the `point` argument of
/// [`ClientRequestInterceptor::intercept`] and
/// [`InterceptorRegistry::walk_client`].
///
/// NOTE(review): the variant names look like the CORBA Portable Interceptor
/// client points (`send_request`, `send_poll`, `receive_reply`, ...). This file
/// only defines the tags; confirm the exact semantics at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientInterceptionPoint {
/// `send_request` point.
SendRequest,
/// `send_poll` point.
SendPoll,
/// `receive_reply` point.
ReceiveReply,
/// `receive_exception` point.
ReceiveException,
/// `receive_other` point.
ReceiveOther,
}
/// Tag telling a server-side interceptor which interception point it is being
/// invoked for. It is the `point` argument of
/// [`ServerRequestInterceptor::intercept`] and
/// [`InterceptorRegistry::walk_server`].
///
/// NOTE(review): the variant names look like the CORBA Portable Interceptor
/// server points (`receive_request_service_contexts`, `receive_request`, ...).
/// This file only defines the tags; confirm the exact semantics at the call
/// sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerInterceptionPoint {
/// `receive_request_service_contexts` point.
ReceiveRequestServiceContexts,
/// `receive_request` point.
ReceiveRequest,
/// `send_reply` point.
SendReply,
/// `send_exception` point.
SendException,
/// `send_other` point.
SendOther,
}
/// Client-side request interceptor, registered through
/// [`InterceptorRegistry::add_client`].
///
/// The `Send + Sync` supertraits are required of every implementation; the
/// registry stores interceptors as `Arc<dyn ClientRequestInterceptor>`.
pub trait ClientRequestInterceptor: Send + Sync {
/// Name of this interceptor. The registry code in this file never reads it.
fn name(&self) -> &str;
/// Called by [`InterceptorRegistry::walk_client`] with the interception
/// `point` and the `op` string, both forwarded unchanged from the caller
/// (`op` is presumably the operation name — confirm with callers).
fn intercept(&self, point: ClientInterceptionPoint, op: &str);
}
/// Server-side request interceptor, registered through
/// [`InterceptorRegistry::add_server`].
///
/// The `Send + Sync` supertraits are required of every implementation; the
/// registry stores interceptors as `Arc<dyn ServerRequestInterceptor>`.
pub trait ServerRequestInterceptor: Send + Sync {
/// Name of this interceptor. The registry code in this file never reads it.
fn name(&self) -> &str;
/// Called by [`InterceptorRegistry::walk_server`] with the interception
/// `point` and the `op` string, both forwarded unchanged from the caller
/// (`op` is presumably the operation name — confirm with callers).
fn intercept(&self, point: ServerInterceptionPoint, op: &str);
}
/// IOR interceptor, registered through [`InterceptorRegistry::add_ior`].
pub trait IorInterceptor: Send + Sync {
/// Name of this interceptor. The registry code in this file never reads it.
fn name(&self) -> &str;
/// Returns the values this interceptor contributes.
/// [`InterceptorRegistry::walk_ior`] concatenates the results of all
/// registered interceptors in registration order.
///
/// NOTE(review): the values are presumably IOR component tag ids (the
/// registry calls its accumulator `tags`) — confirm with the consumer.
fn establish_components(&self) -> Vec<u32>;
}
/// Registry holding the client, server and IOR interceptors of an ORB.
///
/// Each kind is kept in its own list, in registration order, and the `walk_*`
/// methods visit the lists in that same order.
#[derive(Default)]
pub struct InterceptorRegistry {
    client: Vec<Arc<dyn ClientRequestInterceptor>>,
    server: Vec<Arc<dyn ServerRequestInterceptor>>,
    ior: Vec<Arc<dyn IorInterceptor>>,
}

/// Prints only the number of registered interceptors per kind, because the
/// trait objects themselves are not required to implement `Debug`.
impl core::fmt::Debug for InterceptorRegistry {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut summary = f.debug_struct("InterceptorRegistry");
        summary.field("client_count", &self.client.len());
        summary.field("server_count", &self.server.len());
        summary.field("ior_count", &self.ior.len());
        summary.finish()
    }
}

impl InterceptorRegistry {
    /// Creates a registry with no interceptors registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            client: Vec::new(),
            server: Vec::new(),
            ior: Vec::new(),
        }
    }

    /// Appends a client request interceptor to the end of the client list.
    pub fn add_client(&mut self, c: Arc<dyn ClientRequestInterceptor>) {
        self.client.push(c);
    }

    /// Appends a server request interceptor to the end of the server list.
    pub fn add_server(&mut self, s: Arc<dyn ServerRequestInterceptor>) {
        self.server.push(s);
    }

    /// Appends an IOR interceptor to the end of the IOR list.
    pub fn add_ior(&mut self, i: Arc<dyn IorInterceptor>) {
        self.ior.push(i);
    }

    /// Number of registered client request interceptors.
    #[must_use]
    pub fn client_count(&self) -> usize {
        self.client_interceptors().len()
    }

    /// Number of registered server request interceptors.
    #[must_use]
    pub fn server_count(&self) -> usize {
        self.server_interceptors().len()
    }

    /// Number of registered IOR interceptors.
    #[must_use]
    pub fn ior_count(&self) -> usize {
        self.ior_interceptors().len()
    }

    /// Registered client request interceptors, in registration order.
    #[must_use]
    pub fn client_interceptors(&self) -> &[Arc<dyn ClientRequestInterceptor>] {
        self.client.as_slice()
    }

    /// Registered server request interceptors, in registration order.
    #[must_use]
    pub fn server_interceptors(&self) -> &[Arc<dyn ServerRequestInterceptor>] {
        self.server.as_slice()
    }

    /// Registered IOR interceptors, in registration order.
    #[must_use]
    pub fn ior_interceptors(&self) -> &[Arc<dyn IorInterceptor>] {
        self.ior.as_slice()
    }

    /// Calls `intercept(point, op)` on every client interceptor, in
    /// registration order.
    pub fn walk_client(&self, point: ClientInterceptionPoint, op: &str) {
        self.client
            .iter()
            .for_each(|interceptor| interceptor.intercept(point, op));
    }

    /// Calls `intercept(point, op)` on every server interceptor, in
    /// registration order.
    pub fn walk_server(&self, point: ServerInterceptionPoint, op: &str) {
        self.server
            .iter()
            .for_each(|interceptor| interceptor.intercept(point, op));
    }

    /// Calls `establish_components` on every IOR interceptor, in registration
    /// order, and returns all produced values concatenated in that order.
    #[must_use]
    pub fn walk_ior(&self) -> Vec<u32> {
        self.ior
            .iter()
            .flat_map(|interceptor| interceptor.establish_components())
            .collect()
    }
}
/// Factory that builds policy values for a single policy type.
///
/// NOTE(review): nothing in this file calls the trait, so the encoding of the
/// input and output byte strings is not visible here — confirm with the
/// implementations.
pub trait PolicyFactory: Send + Sync {
/// Numeric policy type this factory is responsible for.
fn policy_type(&self) -> u32;
/// Builds a policy from the raw `value` bytes and returns it as bytes, or
/// `Err(())` when the factory rejects the input.
// A unit error type is used deliberately; hence the clippy allowance.
#[allow(clippy::result_unit_err)]
fn create_policy(&self, value: &[u8]) -> Result<Vec<u8>, ()>;
}
/// Slot table mapping a numeric slot id to an opaque byte payload.
#[derive(Debug, Clone, Default)]
pub struct PiCurrent {
    slots: BTreeMap<u32, Vec<u8>>,
}

impl PiCurrent {
    /// Creates a table with no slots set.
    #[must_use]
    pub fn new() -> Self {
        Self {
            slots: BTreeMap::new(),
        }
    }

    /// Stores `value` under `slot_id`, replacing whatever was stored there.
    pub fn set_slot(&mut self, slot_id: u32, value: Vec<u8>) {
        // The previous payload, if any, is intentionally discarded.
        let _replaced = self.slots.insert(slot_id, value);
    }

    /// Returns the payload stored under `slot_id`, or `None` if the slot was
    /// never set.
    #[must_use]
    pub fn get_slot(&self, slot_id: u32) -> Option<&[u8]> {
        let stored = self.slots.get(&slot_id)?;
        Some(stored.as_slice())
    }
}
/// Messaging policies known to this crate, convertible to their numeric policy
/// type with [`MessagingPolicy::policy_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MessagingPolicy {
    Rebind,
    SyncScope,
    RequestPriority,
    ReplyPriority,
    Routing,
    MaxHops,
    RequestTime,
    ReplyTime,
    RelativeRoundtripTimeout,
    RoutingTypeRange,
}

impl MessagingPolicy {
    /// Numeric policy type of this policy.
    #[must_use]
    pub const fn policy_type(self) -> u32 {
        // Arms are listed in ascending numeric order so gaps are easy to spot.
        //
        // NOTE(review): the OMG Messaging IDL appears to assign
        // RELATIVE_RT_TIMEOUT = 32, ROUTING = 33 and MAX_HOPS = 34, which
        // differs from the values below. They are kept as-is because the
        // unit tests in this file pin Routing = 30 and
        // RelativeRoundtripTimeout = 31 — verify against the spec before
        // relying on them for interoperability.
        match self {
            Self::Rebind => 23,
            Self::SyncScope => 24,
            Self::RequestPriority => 25,
            Self::ReplyPriority => 26,
            Self::RequestTime => 27,
            Self::ReplyTime => 28,
            Self::Routing => 30,
            Self::RelativeRoundtripTimeout => 31,
            Self::MaxHops => 32,
            Self::RoutingTypeRange => 33,
        }
    }
}
/// Selects how an asynchronous (AMI) reply is delivered.
///
/// NOTE(review): outside the unit tests, nothing in the visible code reads
/// this enum; the meaning of each variant is inferred from its name —
/// confirm with the users of the type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AmiReplyHandler {
/// Callback-style delivery.
Callback,
/// Polling-style delivery.
Polling,
}
/// Receiver for asynchronous replies routed by [`dispatch_async_reply`].
///
/// Exactly one of the three methods is called per dispatched reply, chosen by
/// the reply status; each receives the reply's request id and body bytes.
pub trait AmiReplySink: Send + Sync {
/// Called for replies with status `NoException`.
fn handle_reply(&self, request_id: u32, body: &[u8]);
/// Called for replies with status `UserException` or `SystemException`.
fn handle_excep(&self, request_id: u32, body: &[u8]);
/// Called for replies with status `LocationForward`,
/// `LocationForwardPerm` or `NeedsAddressingMode`.
fn handle_other(&self, request_id: u32, body: &[u8]);
}
/// Routes a GIOP reply to the matching method of an [`AmiReplySink`].
///
/// The reply status selects the method:
///
/// * `NoException` → `handle_reply`
/// * `UserException`, `SystemException` → `handle_excep`
/// * `LocationForward`, `LocationForwardPerm`, `NeedsAddressingMode` →
///   `handle_other`
///
/// In every case the sink receives the reply's request id and body.
pub fn dispatch_async_reply<S: AmiReplySink + ?Sized>(sink: &S, reply: &zerodds_corba_giop::Reply) {
    use zerodds_corba_giop::ReplyStatusType as Status;

    let request_id = reply.request_id;
    let body = &reply.body;
    match reply.reply_status {
        Status::NoException => sink.handle_reply(request_id, body),
        Status::UserException | Status::SystemException => sink.handle_excep(request_id, body),
        Status::LocationForward
        | Status::LocationForwardPerm
        | Status::NeedsAddressingMode => sink.handle_other(request_id, body),
    }
}
/// Mutex-protected table of pending requests, keyed by request id.
///
/// A poisoned mutex is never propagated: writes become no-ops and reads behave
/// as if the table were empty.
#[cfg(feature = "std")]
#[derive(Debug, Default)]
pub struct PersistentRequestStore {
    inner: std::sync::Mutex<BTreeMap<u32, PersistentRequestEntry>>,
}

/// One stored request: its body bytes and the deadline after which
/// [`PersistentRequestStore::timeout_expired`] removes it.
#[cfg(feature = "std")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistentRequestEntry {
    /// Stored request body.
    pub body: Vec<u8>,
    /// Deadline, on the same clock as the `now_secs` passed to
    /// `timeout_expired`.
    pub deadline_secs: u64,
}

#[cfg(feature = "std")]
impl PersistentRequestStore {
    /// Creates an empty store.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: std::sync::Mutex::new(BTreeMap::new()),
        }
    }

    /// Stores a request under `request_id`, replacing any entry with the same
    /// id. Does nothing if the mutex is poisoned.
    pub fn add(&self, request_id: u32, body: Vec<u8>, deadline_secs: u64) {
        let Ok(mut table) = self.inner.lock() else {
            return;
        };
        let entry = PersistentRequestEntry {
            body,
            deadline_secs,
        };
        table.insert(request_id, entry);
    }

    /// Removes and returns the entry stored under `request_id`.
    ///
    /// Returns `None` when there is no such entry or the mutex is poisoned.
    #[must_use]
    pub fn poll(&self, request_id: u32) -> Option<PersistentRequestEntry> {
        let mut table = self.inner.lock().ok()?;
        table.remove(&request_id)
    }

    /// Removes every entry whose deadline is strictly earlier than `now_secs`
    /// and returns the removed request ids in ascending order.
    ///
    /// Returns an empty list when the mutex is poisoned.
    pub fn timeout_expired(&self, now_secs: u64) -> Vec<u32> {
        let mut expired = Vec::new();
        if let Ok(mut table) = self.inner.lock() {
            // `BTreeMap::retain` visits keys in ascending order, so `expired`
            // comes out sorted by request id.
            table.retain(|request_id, entry| {
                let still_pending = entry.deadline_secs >= now_secs;
                if !still_pending {
                    expired.push(*request_id);
                }
                still_pending
            });
        }
        expired
    }

    /// Number of stored entries; `0` when the mutex is poisoned.
    pub fn len(&self) -> usize {
        match self.inner.lock() {
            Ok(table) => table.len(),
            Err(_) => 0,
        }
    }

    /// `true` when no entries are stored or the mutex is poisoned.
    pub fn is_empty(&self) -> bool {
        match self.inner.lock() {
            Ok(table) => table.is_empty(),
            Err(_) => true,
        }
    }
}
/// Compression algorithms that can be named on the wire.
///
/// [`CompressionAlgorithm::to_u8`] and [`CompressionAlgorithm::from_u8`]
/// convert to and from the one-byte identifier. With the `std` feature,
/// `compress` / `decompress` implement every algorithm except `Lzma`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionAlgorithm {
    None,
    Zlib,
    Gzip,
    Lzma,
    Deflate,
}

impl CompressionAlgorithm {
    /// One-byte identifier of this algorithm (`None` = 0 … `Deflate` = 4).
    #[must_use]
    pub const fn to_u8(self) -> u8 {
        match self {
            Self::None => 0,
            Self::Zlib => 1,
            Self::Gzip => 2,
            Self::Lzma => 3,
            Self::Deflate => 4,
        }
    }

    /// Inverse of [`CompressionAlgorithm::to_u8`].
    ///
    /// # Errors
    ///
    /// Returns `Err(())` for any identifier outside `0..=4`.
    #[allow(clippy::result_unit_err)]
    pub const fn from_u8(v: u8) -> Result<Self, ()> {
        let algorithm = match v {
            0 => Self::None,
            1 => Self::Zlib,
            2 => Self::Gzip,
            3 => Self::Lzma,
            4 => Self::Deflate,
            _ => return Err(()),
        };
        Ok(algorithm)
    }

    /// Compresses `input` with this algorithm at flate2's default level.
    ///
    /// `None` returns a copy of the input.
    ///
    /// # Errors
    ///
    /// * [`CompressionError::Unsupported`] for `Lzma`.
    /// * [`CompressionError::Io`] if the encoder reports an I/O error.
    #[cfg(feature = "std")]
    pub fn compress(self, input: &[u8]) -> Result<Vec<u8>, CompressionError> {
        use flate2::write::{DeflateEncoder, GzEncoder, ZlibEncoder};
        use std::io::Write;

        let level = flate2::Compression::default();
        let packed = match self {
            Self::None => input.to_vec(),
            Self::Lzma => return Err(CompressionError::Unsupported(Self::Lzma)),
            Self::Zlib => {
                let mut encoder = ZlibEncoder::new(Vec::new(), level);
                encoder.write_all(input)?;
                encoder.finish()?
            }
            Self::Gzip => {
                let mut encoder = GzEncoder::new(Vec::new(), level);
                encoder.write_all(input)?;
                encoder.finish()?
            }
            Self::Deflate => {
                let mut encoder = DeflateEncoder::new(Vec::new(), level);
                encoder.write_all(input)?;
                encoder.finish()?
            }
        };
        Ok(packed)
    }

    /// Decompresses `input` that was produced with this algorithm.
    ///
    /// `None` returns a copy of the input. The whole decompressed stream is
    /// read into memory; no output size limit is applied here.
    ///
    /// # Errors
    ///
    /// * [`CompressionError::Unsupported`] for `Lzma`.
    /// * [`CompressionError::Io`] if the decoder reports an I/O error.
    #[cfg(feature = "std")]
    pub fn decompress(self, input: &[u8]) -> Result<Vec<u8>, CompressionError> {
        /// Reads a decoder to its end and returns everything it produced.
        fn read_all<R: std::io::Read>(mut decoder: R) -> Result<Vec<u8>, CompressionError> {
            let mut unpacked = Vec::new();
            decoder.read_to_end(&mut unpacked)?;
            Ok(unpacked)
        }

        match self {
            Self::None => Ok(input.to_vec()),
            Self::Zlib => read_all(flate2::read::ZlibDecoder::new(input)),
            Self::Gzip => read_all(flate2::read::GzDecoder::new(input)),
            Self::Deflate => read_all(flate2::read::DeflateDecoder::new(input)),
            Self::Lzma => Err(CompressionError::Unsupported(Self::Lzma)),
        }
    }
}
/// Failure of `CompressionAlgorithm::compress` or
/// `CompressionAlgorithm::decompress`.
#[cfg(feature = "std")]
#[derive(Debug)]
pub enum CompressionError {
    /// The underlying encoder or decoder reported an I/O error.
    Io(std::io::Error),
    /// The requested algorithm has no implementation.
    Unsupported(CompressionAlgorithm),
}

#[cfg(feature = "std")]
impl core::fmt::Display for CompressionError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Io(cause) => write!(f, "compression io: {}", cause),
            Self::Unsupported(algorithm) => {
                write!(f, "compression unsupported: {:?}", algorithm)
            }
        }
    }
}

#[cfg(feature = "std")]
impl std::error::Error for CompressionError {}

/// Lets `?` turn an I/O error into [`CompressionError::Io`].
#[cfg(feature = "std")]
impl From<std::io::Error> for CompressionError {
    fn from(e: std::io::Error) -> Self {
        CompressionError::Io(e)
    }
}
/// Compression settings for ZIOP.
///
/// NOTE(review): nothing in the visible code reads this struct. In
/// particular, `CompressionAlgorithm::compress` uses flate2's default level
/// rather than `level` — confirm where these settings are applied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZiopConfig {
/// Algorithm to compress with.
pub algorithm: CompressionAlgorithm,
/// Size threshold below which compression is presumably skipped (unit
/// presumably bytes — confirm with the consumer).
pub min_size_threshold: u32,
/// Requested compression level.
pub level: u8,
}
/// Defaults: no compression, threshold 1024, level 6.
impl Default for ZiopConfig {
fn default() -> Self {
Self {
algorithm: CompressionAlgorithm::None,
min_size_threshold: 1024,
level: 6,
}
}
}
/// Multicast endpoint settings for MIOP.
///
/// NOTE(review): nothing in the visible code reads this struct; the field
/// meanings below follow their names — confirm with the socket setup code.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MiopConfig {
/// IPv4 group address as four octets.
pub group_addr_v4: [u8; 4],
/// Port number.
pub port: u16,
/// Time-to-live value.
pub ttl: u8,
/// Loopback switch.
pub loopback: bool,
}
/// Defaults: group 239.255.0.1, port 5683, TTL 1, loopback off.
impl Default for MiopConfig {
fn default() -> Self {
Self {
group_addr_v4: [239, 255, 0, 1],
port: 5683,
ttl: 1,
loopback: false,
}
}
}
/// The four ASCII bytes `MIOP`. `MiopFrameHeader::encode` writes them first
/// and `MiopFrameHeader::decode` rejects input that does not start with them.
pub const MIOP_MAGIC: [u8; 4] = *b"MIOP";
/// Version byte written into frame headers; the only value that
/// `MiopFrameHeader::decode` accepts.
/// (Presumably major version in the high nibble, minor in the low — confirm.)
pub const MIOP_VERSION_1_0: u8 = 0x10;
/// Fixed-size header that precedes the body of every MIOP datagram.
///
/// Encoded layout (14 bytes, multi-byte fields big-endian):
///
/// | offset | size | field               |
/// |--------|------|---------------------|
/// | 0      | 4    | magic `MIOP`        |
/// | 4      | 1    | `version`           |
/// | 5      | 1    | `flags`             |
/// | 6      | 2    | `packet_length`     |
/// | 8      | 4    | `unique_id`         |
/// | 12     | 1    | `packet_number`     |
/// | 13     | 1    | `number_of_packets` |
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MiopFrameHeader {
    /// Version byte.
    pub version: u8,
    /// Bit 0: little-endian marker, bit 1: last-fragment marker.
    pub flags: u8,
    /// Body length as recorded by the sender.
    pub packet_length: u16,
    /// Identifier shared by all fragments of one message.
    pub unique_id: u32,
    /// Zero-based index of this fragment.
    pub packet_number: u8,
    /// Total number of fragments in the message.
    pub number_of_packets: u8,
}

impl MiopFrameHeader {
    /// Size in bytes of an encoded header.
    pub const ENCODED_LEN: usize = 14;

    /// `flags` bit marking a little-endian payload.
    const FLAG_LITTLE_ENDIAN: u8 = 0x01;
    /// `flags` bit marking the last fragment of a message.
    const FLAG_LAST_FRAGMENT: u8 = 0x02;

    /// Builds the header of a message that fits into a single datagram:
    /// fragment 0 of 1, with the last-fragment bit always set and the
    /// little-endian bit set when `little_endian` is true.
    #[must_use]
    pub const fn single_packet(unique_id: u32, packet_length: u16, little_endian: bool) -> Self {
        let endian_bit = if little_endian {
            Self::FLAG_LITTLE_ENDIAN
        } else {
            0
        };
        Self {
            version: MIOP_VERSION_1_0,
            flags: endian_bit | Self::FLAG_LAST_FRAGMENT,
            packet_length,
            unique_id,
            packet_number: 0,
            number_of_packets: 1,
        }
    }

    /// Appends the 14-byte encoding of this header to `out`.
    ///
    /// Multi-byte fields are always written big-endian, independent of the
    /// little-endian flag bit.
    pub fn encode(&self, out: &mut Vec<u8>) {
        let [len_hi, len_lo] = self.packet_length.to_be_bytes();
        let [id0, id1, id2, id3] = self.unique_id.to_be_bytes();
        let after_magic = [
            self.version,
            self.flags,
            len_hi,
            len_lo,
            id0,
            id1,
            id2,
            id3,
            self.packet_number,
            self.number_of_packets,
        ];
        out.extend_from_slice(&MIOP_MAGIC);
        out.extend_from_slice(&after_magic);
    }

    /// Parses a header from the start of `input`.
    ///
    /// On success returns the header and the number of bytes consumed, which
    /// is always [`Self::ENCODED_LEN`]; trailing bytes are ignored.
    ///
    /// # Errors
    ///
    /// Checked in this order:
    /// * [`MiopError::TooShort`] — fewer than 14 bytes available.
    /// * [`MiopError::InvalidMagic`] — input does not start with `MIOP`.
    /// * [`MiopError::UnsupportedVersion`] — version byte is not
    ///   `MIOP_VERSION_1_0`.
    pub fn decode(input: &[u8]) -> Result<(Self, usize), MiopError> {
        // The slice pattern checks the length and names every byte at once.
        let &[m0, m1, m2, m3, version, flags, len_hi, len_lo, id0, id1, id2, id3, packet_number, number_of_packets, ..] =
            input
        else {
            return Err(MiopError::TooShort);
        };
        if [m0, m1, m2, m3] != MIOP_MAGIC {
            return Err(MiopError::InvalidMagic);
        }
        if version != MIOP_VERSION_1_0 {
            return Err(MiopError::UnsupportedVersion(version));
        }
        let header = Self {
            version,
            flags,
            packet_length: u16::from_be_bytes([len_hi, len_lo]),
            unique_id: u32::from_be_bytes([id0, id1, id2, id3]),
            packet_number,
            number_of_packets,
        };
        Ok((header, Self::ENCODED_LEN))
    }

    /// `true` when the last-fragment bit (`0x02`) is set in `flags`.
    #[must_use]
    pub const fn is_last_fragment(&self) -> bool {
        (self.flags & Self::FLAG_LAST_FRAGMENT) != 0
    }

    /// `true` when the little-endian bit (`0x01`) is set in `flags`.
    #[must_use]
    pub const fn is_little_endian(&self) -> bool {
        (self.flags & Self::FLAG_LITTLE_ENDIAN) != 0
    }
}
/// Reasons why `MiopFrameHeader::decode` rejects its input.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MiopError {
/// The input holds fewer bytes than `MiopFrameHeader::ENCODED_LEN`.
TooShort,
/// The first four bytes are not `MIOP_MAGIC`.
InvalidMagic,
/// The version byte (carried in the variant) is not `MIOP_VERSION_1_0`.
UnsupportedVersion(u8),
}
/// Transport abstraction used by `MiopSender`: something that can send one
/// datagram at a time.
pub trait MulticastSink: Send + Sync {
/// Sends `data` as a single datagram. `MiopSender::send_giop` calls this
/// once per frame, passing header and body already concatenated.
fn send_datagram(&self, data: &[u8]) -> Result<(), MulticastSinkError>;
}
/// Error returned by `MulticastSink::send_datagram`; wraps a free-form message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MulticastSinkError(pub String);
pub struct MiopSender<S: MulticastSink> {
sink: S,
mtu: usize,
next_unique_id: core::sync::atomic::AtomicU32,
}
impl<S: MulticastSink> MiopSender<S> {
#[must_use]
pub fn new(sink: S, mtu: usize) -> Self {
Self {
sink,
mtu,
next_unique_id: core::sync::atomic::AtomicU32::new(1),
}
}
#[must_use]
pub const fn max_body_per_frame(&self) -> usize {
self.mtu.saturating_sub(MiopFrameHeader::ENCODED_LEN)
}
pub fn send_giop(
&self,
giop_bytes: &[u8],
little_endian: bool,
) -> Result<(), MulticastSinkError> {
let max = self.max_body_per_frame();
let unique_id = self
.next_unique_id
.fetch_add(1, core::sync::atomic::Ordering::Relaxed);
if max == 0 || giop_bytes.len() <= max {
let header = MiopFrameHeader::single_packet(
unique_id,
u16::try_from(giop_bytes.len()).unwrap_or(u16::MAX),
little_endian,
);
let mut datagram = Vec::with_capacity(MiopFrameHeader::ENCODED_LEN + giop_bytes.len());
header.encode(&mut datagram);
datagram.extend_from_slice(giop_bytes);
return self.sink.send_datagram(&datagram);
}
let total_len = giop_bytes.len();
let total_packets = total_len.div_ceil(max);
let total_packets_u8 = u8::try_from(total_packets).unwrap_or(u8::MAX);
for (idx, chunk) in giop_bytes.chunks(max).enumerate() {
let mut flags: u8 = 0;
if little_endian {
flags |= 0x01;
}
let is_last = idx + 1 == total_packets;
if is_last {
flags |= 0x02;
}
let header = MiopFrameHeader {
version: MIOP_VERSION_1_0,
flags,
packet_length: u16::try_from(chunk.len()).unwrap_or(u16::MAX),
unique_id,
packet_number: u8::try_from(idx).unwrap_or(u8::MAX),
number_of_packets: total_packets_u8,
};
let mut datagram = Vec::with_capacity(MiopFrameHeader::ENCODED_LEN + chunk.len());
header.encode(&mut datagram);
datagram.extend_from_slice(chunk);
self.sink.send_datagram(&datagram)?;
}
Ok(())
}
}
/// Operating mode of a bridge, selected in `BridgeConfig::mode`.
///
/// NOTE(review): nothing in the visible code acts on the mode; the variant
/// meanings follow their names — confirm with the bridge implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BridgeMode {
/// In-line bridging.
Inline,
/// Request-level bridging.
RequestLevel,
}
/// Configuration of a bridge between two ORBs.
#[derive(Debug, Clone)]
pub struct BridgeConfig {
/// Bridging mode.
pub mode: BridgeMode,
/// Identifier of the source side (free-form string; format not defined in
/// this file).
pub source_orb: String,
/// Identifier of the target side (free-form string; format not defined in
/// this file).
pub target_orb: String,
}
/// Bidirectional-connection policy value.
///
/// NOTE(review): the variant names look like CORBA's `BiDirPolicy` values
/// `NORMAL` and `BOTH`; nothing in the visible code acts on them — confirm
/// with the connection code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BiDirPolicy {
/// `NORMAL` value.
Normal,
/// `BOTH` value.
Both,
}
/// Service-context payload listing the listen points of a peer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BiDirServiceContext {
/// Listen points as `(String, u16)` pairs (presumably host and port —
/// confirm with the encoder).
pub listen_points: Vec<(String, u16)>,
}
// Unit tests for the interceptor registry, PiCurrent, messaging policies,
// AMI dispatch, compression, the persistent request store and MIOP framing.
//
// NOTE(review): `RecordingSink`, `MockSink` and the tests built on them use
// `std::sync::Mutex` but, unlike the compression and request-store tests, are
// not gated on `feature = "std"` — confirm that test builds always link std.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
use super::*;
// Minimal no-op interceptors used to populate the registry.
struct DummyClient;
impl ClientRequestInterceptor for DummyClient {
fn name(&self) -> &str {
"dummy-client"
}
fn intercept(&self, _: ClientInterceptionPoint, _: &str) {}
}
struct DummyServer;
impl ServerRequestInterceptor for DummyServer {
fn name(&self) -> &str {
"dummy-server"
}
fn intercept(&self, _: ServerInterceptionPoint, _: &str) {}
}
struct DummyIor;
impl IorInterceptor for DummyIor {
fn name(&self) -> &str {
"dummy-ior"
}
fn establish_components(&self) -> Vec<u32> {
alloc::vec![]
}
}
// Registering one interceptor of each kind bumps each count to 1.
#[test]
fn registry_add_increments_counts() {
let mut r = InterceptorRegistry::new();
r.add_client(Arc::new(DummyClient) as Arc<dyn ClientRequestInterceptor>);
r.add_server(Arc::new(DummyServer) as Arc<dyn ServerRequestInterceptor>);
r.add_ior(Arc::new(DummyIor) as Arc<dyn IorInterceptor>);
assert_eq!(r.client_count(), 1);
assert_eq!(r.server_count(), 1);
assert_eq!(r.ior_count(), 1);
}
// A set slot is readable; an unset slot yields None.
#[test]
fn picurrent_set_get_slot() {
let mut p = PiCurrent::new();
p.set_slot(7, alloc::vec![0xab, 0xcd]);
assert_eq!(p.get_slot(7), Some(&[0xab, 0xcd][..]));
assert!(p.get_slot(99).is_none());
}
// Sanity checks that enum variants compare as distinct.
#[test]
fn client_interception_points_distinct() {
assert_ne!(
ClientInterceptionPoint::SendRequest,
ClientInterceptionPoint::ReceiveReply
);
}
#[test]
fn server_interception_points_distinct() {
assert_ne!(
ServerInterceptionPoint::ReceiveRequest,
ServerInterceptionPoint::SendReply
);
}
#[test]
fn messaging_policies_distinct() {
assert_ne!(MessagingPolicy::Rebind, MessagingPolicy::SyncScope);
}
#[test]
fn ami_reply_handler_distinct() {
assert_ne!(AmiReplyHandler::Callback, AmiReplyHandler::Polling);
}
// to_u8 followed by from_u8 returns the original algorithm.
#[test]
fn compression_algorithm_round_trip() {
for a in [
CompressionAlgorithm::None,
CompressionAlgorithm::Zlib,
CompressionAlgorithm::Gzip,
CompressionAlgorithm::Lzma,
CompressionAlgorithm::Deflate,
] {
assert_eq!(CompressionAlgorithm::from_u8(a.to_u8()).expect("ok"), a);
}
}
// Identifiers outside the known range are rejected.
#[test]
fn compression_algorithm_unknown_rejected() {
assert!(CompressionAlgorithm::from_u8(99).is_err());
}
// Default ZIOP config: no compression, threshold 1024.
#[test]
fn ziop_config_default_no_compression() {
let c = ZiopConfig::default();
assert_eq!(c.algorithm, CompressionAlgorithm::None);
assert_eq!(c.min_size_threshold, 1024);
}
// `None` copies the input unchanged in both directions.
#[cfg(feature = "std")]
#[test]
fn compression_none_passes_through() {
let input = b"hello, corba";
let out = CompressionAlgorithm::None
.compress(input)
.expect("compress ok");
assert_eq!(out, input);
let back = CompressionAlgorithm::None
.decompress(&out)
.expect("decompress ok");
assert_eq!(back, input);
}
// Zlib shrinks repetitive input and round-trips it.
#[cfg(feature = "std")]
#[test]
fn compression_zlib_round_trip() {
let input = b"the quick brown fox jumps over the lazy dog".repeat(8);
let compressed = CompressionAlgorithm::Zlib
.compress(&input)
.expect("compress ok");
assert!(compressed.len() < input.len());
let back = CompressionAlgorithm::Zlib
.decompress(&compressed)
.expect("decompress ok");
assert_eq!(back, input);
}
// Gzip round trip.
#[cfg(feature = "std")]
#[test]
fn compression_gzip_round_trip() {
let input = b"OMG-CORBA-3.3 18 Compression spec".repeat(16);
let compressed = CompressionAlgorithm::Gzip
.compress(&input)
.expect("compress ok");
let back = CompressionAlgorithm::Gzip
.decompress(&compressed)
.expect("decompress ok");
assert_eq!(back, input);
}
// Raw deflate round trip.
#[cfg(feature = "std")]
#[test]
fn compression_deflate_round_trip() {
let input = b"deflate raw RFC1951".repeat(32);
let compressed = CompressionAlgorithm::Deflate
.compress(&input)
.expect("compress ok");
let back = CompressionAlgorithm::Deflate
.decompress(&compressed)
.expect("decompress ok");
assert_eq!(back, input);
}
// Lzma is declared but not implemented: compress must report Unsupported.
#[cfg(feature = "std")]
#[test]
fn compression_lzma_returns_unsupported() {
let err = CompressionAlgorithm::Lzma
.compress(b"x")
.expect_err("must fail");
assert!(matches!(
err,
CompressionError::Unsupported(CompressionAlgorithm::Lzma)
));
}
// Round trip of 10 000 pseudo-random bytes (multiplicative hash of the index).
#[cfg(feature = "std")]
#[test]
fn compression_zlib_handles_large_block() {
let input: Vec<u8> = (0..10_000_u32)
.map(|i| (i.wrapping_mul(2654435761) >> 24) as u8)
.collect();
let compressed = CompressionAlgorithm::Zlib
.compress(&input)
.expect("compress ok");
let back = CompressionAlgorithm::Zlib
.decompress(&compressed)
.expect("decompress ok");
assert_eq!(back, input);
}
// Default MIOP config: 239.x group, port 5683, TTL 1.
#[test]
fn miop_config_default_uses_239_range() {
let m = MiopConfig::default();
assert_eq!(m.group_addr_v4[0], 239);
assert_eq!(m.port, 5683);
assert_eq!(m.ttl, 1);
}
#[test]
fn bridge_modes_distinct() {
assert_ne!(BridgeMode::Inline, BridgeMode::RequestLevel);
}
// BridgeConfig is a plain struct that can be built field by field.
#[test]
fn bridge_config_construct() {
let c = BridgeConfig {
mode: BridgeMode::RequestLevel,
source_orb: "corba".into(),
target_orb: "dds".into(),
};
assert_eq!(c.source_orb, "corba");
}
#[test]
fn bidir_policy_distinct() {
assert_ne!(BiDirPolicy::Normal, BiDirPolicy::Both);
}
#[test]
fn bidir_service_context_listen_points() {
let sc = BiDirServiceContext {
listen_points: alloc::vec![("client.example".into(), 8080)],
};
assert_eq!(sc.listen_points.len(), 1);
}
// walk_client calls every registered client interceptor exactly once.
#[test]
fn registry_walk_client_invokes_all_client_interceptors() {
use core::sync::atomic::{AtomicUsize, Ordering};
struct Counting {
count: alloc::sync::Arc<AtomicUsize>,
}
impl ClientRequestInterceptor for Counting {
fn name(&self) -> &str {
"counting"
}
fn intercept(&self, _: ClientInterceptionPoint, _: &str) {
self.count.fetch_add(1, Ordering::Relaxed);
}
}
let count = alloc::sync::Arc::new(AtomicUsize::new(0));
let mut r = InterceptorRegistry::new();
r.add_client(alloc::sync::Arc::new(Counting {
count: count.clone(),
}) as Arc<dyn ClientRequestInterceptor>);
r.add_client(alloc::sync::Arc::new(Counting {
count: count.clone(),
}) as Arc<dyn ClientRequestInterceptor>);
r.walk_client(ClientInterceptionPoint::SendRequest, "op");
assert_eq!(count.load(Ordering::Relaxed), 2);
}
// walk_ior returns the values produced by the registered IOR interceptor.
#[test]
fn registry_walk_ior_collects_tags() {
struct EmitTwo;
impl IorInterceptor for EmitTwo {
fn name(&self) -> &str {
"emit-two"
}
fn establish_components(&self) -> Vec<u32> {
alloc::vec![0xAAAA_AAAA, 0xBBBB_BBBB]
}
}
let mut r = InterceptorRegistry::new();
r.add_ior(Arc::new(EmitTwo) as Arc<dyn IorInterceptor>);
let tags = r.walk_ior();
assert_eq!(tags, alloc::vec![0xAAAA_AAAA, 0xBBBB_BBBB]);
}
// Pins four of the numeric policy types returned by policy_type().
#[test]
fn messaging_policy_wire_values_match_omg_messaging_idl() {
assert_eq!(MessagingPolicy::Rebind.policy_type(), 23);
assert_eq!(MessagingPolicy::SyncScope.policy_type(), 24);
assert_eq!(MessagingPolicy::Routing.policy_type(), 30);
assert_eq!(MessagingPolicy::RelativeRoundtripTimeout.policy_type(), 31);
}
// Sink that records (request id, handler name) for every dispatched reply.
struct RecordingSink {
replies: alloc::sync::Arc<std::sync::Mutex<Vec<(u32, &'static str)>>>,
}
impl AmiReplySink for RecordingSink {
fn handle_reply(&self, request_id: u32, _body: &[u8]) {
if let Ok(mut g) = self.replies.lock() {
g.push((request_id, "reply"));
}
}
fn handle_excep(&self, request_id: u32, _body: &[u8]) {
if let Ok(mut g) = self.replies.lock() {
g.push((request_id, "excep"));
}
}
fn handle_other(&self, request_id: u32, _body: &[u8]) {
if let Ok(mut g) = self.replies.lock() {
g.push((request_id, "other"));
}
}
}
// A NoException reply is routed to handle_reply.
#[test]
fn ami_handler_handles_no_exception_reply() {
use zerodds_corba_giop::{Reply, ReplyStatusType, ServiceContextList};
let replies = alloc::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let sink = RecordingSink {
replies: replies.clone(),
};
let r = Reply {
request_id: 42,
reply_status: ReplyStatusType::NoException,
service_context: ServiceContextList::default(),
body: alloc::vec![1, 2, 3],
};
dispatch_async_reply(&sink, &r);
let g = replies.lock().unwrap();
assert_eq!(*g, alloc::vec![(42_u32, "reply")]);
}
// A UserException reply is routed to handle_excep.
#[test]
fn ami_handler_handles_user_exception_reply() {
use zerodds_corba_giop::{Reply, ReplyStatusType, ServiceContextList};
let replies = alloc::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let sink = RecordingSink {
replies: replies.clone(),
};
let r = Reply {
request_id: 7,
reply_status: ReplyStatusType::UserException,
service_context: ServiceContextList::default(),
body: alloc::vec![],
};
dispatch_async_reply(&sink, &r);
let g = replies.lock().unwrap();
assert_eq!(*g, alloc::vec![(7_u32, "excep")]);
}
// add / poll / timeout_expired: poll removes the entry, and only entries
// with a deadline earlier than `now` are expired.
#[cfg(feature = "std")]
#[test]
fn persistent_request_store_add_poll_timeout() {
let s = PersistentRequestStore::new();
s.add(1, alloc::vec![0xAA], 100);
s.add(2, alloc::vec![0xBB], 50);
s.add(3, alloc::vec![0xCC], 200);
assert_eq!(s.len(), 3);
let e1 = s.poll(1).expect("present");
assert_eq!(e1.body, alloc::vec![0xAA]);
assert_eq!(e1.deadline_secs, 100);
assert!(s.poll(1).is_none());
assert_eq!(s.len(), 2);
let expired = s.timeout_expired(120);
assert_eq!(expired, alloc::vec![2]);
assert_eq!(s.len(), 1);
assert!(s.poll(3).is_some());
}
// encode followed by decode reproduces the header and consumes 14 bytes.
#[test]
fn miop_frame_encode_decode_roundtrip() {
let h = MiopFrameHeader::single_packet(0xCAFE_BABE, 1234, true);
let mut bytes = Vec::new();
h.encode(&mut bytes);
assert_eq!(bytes.len(), MiopFrameHeader::ENCODED_LEN);
let (decoded, consumed) = MiopFrameHeader::decode(&bytes).expect("decode");
assert_eq!(consumed, MiopFrameHeader::ENCODED_LEN);
assert_eq!(decoded, h);
assert!(decoded.is_last_fragment());
assert!(decoded.is_little_endian());
}
// decode rejects a wrong magic, a wrong version and truncated input.
#[test]
fn miop_frame_decode_rejects_bad_magic_and_version() {
let mut bad = alloc::vec![b'X', b'X', b'X', b'X'];
bad.extend_from_slice(&[0u8; 10]);
assert_eq!(
MiopFrameHeader::decode(&bad).unwrap_err(),
MiopError::InvalidMagic
);
let mut wrong_version = MIOP_MAGIC.to_vec();
// Magic (4) + version (1) + 9 zero bytes = 14 bytes, a full header.
wrong_version.push(0xFF); wrong_version.extend_from_slice(&[0u8; 9]);
assert_eq!(
MiopFrameHeader::decode(&wrong_version).unwrap_err(),
MiopError::UnsupportedVersion(0xFF)
);
let too_short = MIOP_MAGIC.to_vec();
assert_eq!(
MiopFrameHeader::decode(&too_short).unwrap_err(),
MiopError::TooShort
);
}
// Sink that stores every datagram it is asked to send.
struct MockSink {
sent: alloc::sync::Arc<std::sync::Mutex<Vec<Vec<u8>>>>,
}
impl MulticastSink for MockSink {
fn send_datagram(&self, data: &[u8]) -> Result<(), MulticastSinkError> {
if let Ok(mut g) = self.sent.lock() {
g.push(data.to_vec());
}
Ok(())
}
}
// 100 bytes with MTU 256 go out as one datagram of header + body.
#[test]
fn miop_sender_single_packet_fits_mtu() {
let sent = alloc::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let sink = MockSink { sent: sent.clone() };
let sender = MiopSender::new(sink, 256);
let payload = alloc::vec![0xAB; 100];
sender.send_giop(&payload, false).expect("send");
let g = sent.lock().unwrap();
assert_eq!(g.len(), 1, "single-packet path produces 1 datagram");
assert!(g[0].starts_with(&MIOP_MAGIC));
assert_eq!(g[0].len(), MiopFrameHeader::ENCODED_LEN + 100);
}
// MTU 44 leaves 30 body bytes per frame, so 100 bytes need 4 fragments.
#[test]
fn miop_sender_fragments_multi_packet_over_small_mtu() {
let sent = alloc::sync::Arc::new(std::sync::Mutex::new(Vec::new()));
let sink = MockSink { sent: sent.clone() };
let sender = MiopSender::new(sink, 44);
let payload = alloc::vec![0xCD; 100];
sender.send_giop(&payload, true).expect("send");
let g = sent.lock().unwrap();
assert_eq!(g.len(), 4);
let (last_header, _) = MiopFrameHeader::decode(&g[3]).expect("decode");
assert!(last_header.is_last_fragment());
assert_eq!(last_header.packet_number, 3);
assert_eq!(last_header.number_of_packets, 4);
let (first_header, _) = MiopFrameHeader::decode(&g[0]).expect("decode");
assert!(!first_header.is_last_fragment());
assert_eq!(first_header.packet_number, 0);
}
}