mod tcp;
mod udp;
mod udp_core;
pub use tcp::*;
pub use udp::*;
use crate::ber::length::parse_ber_length;
use crate::error::Result;
use bytes::Bytes;
use std::future::Future;
use std::net::SocketAddr;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicI32, Ordering};
use std::time::Duration;
/// Largest UDP payload that fits in one IPv4 datagram: the 65 535-byte IP total-length limit
/// minus the minimum 20-byte IPv4 header and the 8-byte UDP header (65 507 bytes).
///
/// Used as the default for `Transport::max_message_size`.
pub const MAX_UDP_PAYLOAD: u32 = 65_535 - 20 - 8;
/// Backing counter for `alloc_request_id`, initialised lazily on first use.
///
/// The starting value is drawn from the system RNG via `getrandom` instead of a fixed number.
/// Initialisation panics if no randomness is available.
static REQUEST_ID_COUNTER: LazyLock<AtomicI32> = LazyLock::new(|| {
    let mut seed_bytes = [0u8; 4];
    getrandom::fill(&mut seed_bytes).expect("getrandom failed");
    AtomicI32::new(i32::from_ne_bytes(seed_bytes))
});
/// Allocates a request ID from the process-wide counter.
///
/// Each call advances the counter by one (`fetch_add` wraps on overflow). The value is masked
/// to its low 31 bits, so the result always lies in `1..=i32::MAX`. Zero is skipped, because
/// the mask yields it when the counter crosses into the negative half.
/// A given ID recurs only after 2^31 further allocations.
pub fn alloc_request_id() -> i32 {
    loop {
        // Relaxed suffices: uniqueness only needs the atomicity of the read-modify-write,
        // not ordering relative to other memory operations.
        let candidate = REQUEST_ID_COUNTER.fetch_add(1, Ordering::Relaxed) & i32::MAX;
        if candidate > 0 {
            break candidate;
        }
    }
}
/// Transport bound to a single remote peer that moves already-encoded message bytes.
///
/// Implementors supply the I/O methods. The provided methods default to:
/// - the process-wide request-ID counter;
/// - a no-op request registration;
/// - a UDP-sized message limit.
pub trait Transport: Send + Sync {
    /// Sends one encoded message to the peer.
    fn send(&self, data: &[u8]) -> impl Future<Output = Result<()>> + Send;

    /// Receives the message associated with `request_id`, together with the address it
    /// arrived from.
    ///
    /// NOTE(review): how a reply is matched to `request_id` is implementation-defined,
    /// presumably via `extract_request_id`; confirm per transport.
    fn recv(&self, request_id: i32) -> impl Future<Output = Result<(Bytes, SocketAddr)>> + Send;

    /// Address of the remote peer this transport talks to.
    fn peer_addr(&self) -> SocketAddr;

    /// Local address this transport is bound to.
    fn local_addr(&self) -> SocketAddr;

    /// Reports whether the underlying transport is reliable.
    ///
    /// NOTE(review): presumably `true` for stream transports (TCP) and `false` for UDP;
    /// confirm against the implementations.
    fn is_reliable(&self) -> bool;

    /// Allocates a fresh request ID; by default delegates to the module-level
    /// `alloc_request_id` function.
    fn alloc_request_id(&self) -> i32 {
        alloc_request_id()
    }

    /// Lets an implementation record an outstanding request and its timeout.
    /// The default does nothing.
    fn register_request(&self, _request_id: i32, _timeout: Duration) {}

    /// Largest message, in bytes, this transport can carry; defaults to `MAX_UDP_PAYLOAD`.
    fn max_message_size(&self) -> u32 {
        MAX_UDP_PAYLOAD
    }
}
/// Datagram-style transport where every call names its peer address explicitly.
///
/// NOTE(review): the name suggests this is the agent (responder) counterpart of `Transport`;
/// confirm against its users.
pub trait AgentTransport: Send + Sync {
    /// Local address this transport is bound to.
    fn local_addr(&self) -> SocketAddr;

    /// Receives one message into `buf`.
    ///
    /// Resolves to `(len, from)`. Presumably `len` is the number of bytes written into `buf`
    /// and `from` is the sender's address, mirroring `UdpSocket::recv_from`.
    fn recv_from(&self, buf: &mut [u8])
    -> impl Future<Output = Result<(usize, SocketAddr)>> + Send;

    /// Sends `data` to `target`.
    fn send_to(&self, data: &[u8], target: SocketAddr) -> impl Future<Output = Result<()>> + Send;
}
pub(crate) fn extract_request_id(data: &[u8]) -> Option<i32> {
let mut pos = 0;
if pos >= data.len() || data[pos] != 0x30 {
return None;
}
pos += 1;
let (_, consumed) = parse_ber_length(&data[pos..])?;
pos += consumed;
if pos >= data.len() || data[pos] != 0x02 {
return None;
}
pos += 1;
let (version_len, consumed) = parse_ber_length(&data[pos..])?;
pos += consumed;
if pos + version_len > data.len() {
return None;
}
let version = if version_len == 1 {
data[pos] as i32
} else {
let mut v: i32 = 0;
for i in 0..version_len {
v = (v << 8) | (data[pos + i] as i32);
}
v
};
pos += version_len;
if pos >= data.len() {
return None;
}
let next_tag = data[pos];
if version == 3 && next_tag == 0x30 {
extract_v3_msg_id(data, pos)
} else if next_tag == 0x04 {
extract_v1v2c_request_id(data, pos)
} else {
None
}
}
/// Reads `msgID` from an SNMPv3 message, where `pos` points at the `msgGlobalData` SEQUENCE tag.
///
/// `msgID` is the first field of `msgGlobalData`, which is never encrypted, so it is readable
/// even when the scoped PDU is not. Returns `None` on a tag mismatch, a truncated field, or a
/// length that overflows the offset arithmetic.
fn extract_v3_msg_id(data: &[u8], mut pos: usize) -> Option<i32> {
    // msgGlobalData SEQUENCE header.
    if data.get(pos) != Some(&0x30) {
        return None;
    }
    pos += 1;
    let (_, consumed) = parse_ber_length(&data[pos..])?;
    pos += consumed;

    // msgID INTEGER.
    if data.get(pos) != Some(&0x02) {
        return None;
    }
    pos += 1;
    let (id_len, consumed) = parse_ber_length(&data[pos..])?;
    pos += consumed;
    // checked_add: an attacker-supplied length must not wrap past the bounds check.
    let id_end = pos.checked_add(id_len).filter(|&end| end <= data.len())?;
    decode_ber_signed_integer(&data[pos..id_end])
}
/// Reads the PDU `request-id` from an SNMPv1/v2c message, where `pos` points at the
/// community OCTET STRING tag.
///
/// The community is skipped unread. The PDU tag must be one of the context-specific
/// constructed tags `[0]`..`[8]` (0xA0..=0xA8), the SNMP PDU types from GetRequest-PDU through
/// Report-PDU. A v1 Trap-PDU (0xA4) starts with an OBJECT IDENTIFIER rather than an INTEGER,
/// so the INTEGER-tag check below rejects it. Returns `None` on a tag mismatch, a truncated
/// field, or a length that overflows the offset arithmetic.
fn extract_v1v2c_request_id(data: &[u8], mut pos: usize) -> Option<i32> {
    // community OCTET STRING: skip header and contents.
    if data.get(pos) != Some(&0x04) {
        return None;
    }
    pos += 1;
    let (community_len, consumed) = parse_ber_length(&data[pos..])?;
    // checked_add: a crafted community length must not wrap `pos` back into the buffer.
    pos = pos.checked_add(consumed)?.checked_add(community_len)?;

    // PDU header.
    let pdu_tag = *data.get(pos)?;
    if !(0xA0..=0xA8).contains(&pdu_tag) {
        return None;
    }
    pos += 1;
    let (_, consumed) = parse_ber_length(&data[pos..])?;
    pos += consumed;

    // request-id INTEGER.
    if data.get(pos) != Some(&0x02) {
        return None;
    }
    pos += 1;
    let (id_len, consumed) = parse_ber_length(&data[pos..])?;
    pos += consumed;
    let id_end = pos.checked_add(id_len).filter(|&end| end <= data.len())?;
    decode_ber_signed_integer(&data[pos..id_end])
}
/// Decodes the contents octets of a BER `INTEGER` (big-endian two's complement) into an `i32`.
///
/// Redundant leading sign-extension octets are ignored, so non-minimal encodings still decode
/// to their value. Such octets are a `0x00` followed by an octet with bit 7 clear, or a `0xFF`
/// followed by an octet with bit 7 set.
///
/// Returns `None` when the value needs more than four octets, i.e. it lies outside the `i32`
/// range, instead of silently truncating it to an unrelated number. Empty input is leniently
/// decoded as `Some(0)`.
fn decode_ber_signed_integer(bytes: &[u8]) -> Option<i32> {
    let mut digits = bytes;
    while let &[lead, next, ..] = digits {
        let redundant_zero = lead == 0x00 && (next & 0x80) == 0;
        let redundant_ones = lead == 0xFF && (next & 0x80) != 0;
        if !(redundant_zero || redundant_ones) {
            break;
        }
        digits = &digits[1..];
    }

    let Some(&first) = digits.first() else {
        // X.690 requires at least one contents octet; empty input is tolerated as zero.
        return Some(0);
    };
    if digits.len() > 4 {
        return None;
    }

    // Seed with the sign so that short negative encodings sign-extend correctly.
    let sign_fill: i32 = if first & 0x80 != 0 { -1 } else { 0 };
    Some(digits.iter().fold(sign_fill, |acc, &b| (acc << 8) | i32::from(b)))
}
#[cfg(test)]
mod request_id_tests {
    use super::*;
    use std::collections::HashSet;
    use std::sync::atomic::AtomicI32;

    /// Number of IDs drawn by each sampling test.
    const SAMPLES: usize = 10_000;

    /// Allocated IDs are strictly positive: bit 31 is always masked off.
    #[test]
    fn request_id_is_always_positive() {
        for id in (0..SAMPLES).map(|_| alloc_request_id()) {
            assert!(id > 0, "request ID must be positive, got {}", id);
        }
    }

    /// Zero is reserved and never handed out.
    #[test]
    fn request_id_zero_is_skipped() {
        (0..SAMPLES)
            .map(|_| alloc_request_id())
            .for_each(|id| assert_ne!(id, 0, "request ID must not be zero"));
    }

    /// Replays the allocator's mask-and-skip step on a local counter started just below
    /// `i32::MAX`, so the wrap to `i32::MIN` (which masks to 0) is hit deterministically.
    ///
    /// NOTE(review): this duplicates the logic of `alloc_request_id` instead of calling it,
    /// so it cannot catch regressions in the real function.
    #[test]
    fn request_id_wrap_around_stays_positive() {
        let counter = AtomicI32::new(i32::MAX - 100);
        let next_id = || loop {
            let masked = counter.fetch_add(1, Ordering::Relaxed) & 0x7FFFFFFF;
            if masked != 0 {
                break masked;
            }
        };
        for iteration in 0..200 {
            let id = next_id();
            assert!(
                id > 0,
                "request ID must be positive after wrap, iteration {}, got {}",
                iteration,
                id
            );
        }
    }

    /// A run of consecutive allocations contains no duplicates.
    #[test]
    fn request_ids_are_unique() {
        let mut seen = HashSet::with_capacity(SAMPLES);
        for id in std::iter::repeat_with(alloc_request_id).take(SAMPLES) {
            assert!(seen.insert(id), "request ID {} was allocated twice", id);
        }
    }
}
#[cfg(test)]
mod extract_tests {
    use super::*;

    /// v2c GetResponse with a two-octet request-id.
    #[test]
    fn test_extract_request_id_v2c() {
        let response = [
            0x30, 0x1c, // message SEQUENCE
            0x02, 0x01, 0x01, // version 1 (v2c)
            0x04, 0x06, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, // community "public"
            0xa2, 0x0f, // GetResponse-PDU
            0x02, 0x02, 0x30, 0x39, // request-id 12345
            0x02, 0x01, 0x00, // error-status
            0x02, 0x01, 0x00, // error-index
            0x30, 0x03, 0x30, 0x01, 0x00, // variable bindings
        ];
        let expected = Some(12345);
        assert_eq!(extract_request_id(&response), expected);
    }

    /// v3 message: the ID comes from msgGlobalData.msgID, not from the scoped PDU.
    #[test]
    fn test_extract_request_id_v3() {
        let v3_response = [
            0x30, 0x35, // message SEQUENCE
            0x02, 0x01, 0x03, // msgVersion 3
            0x30, 0x11, // msgGlobalData
            0x02, 0x02, 0x30, 0x39, // msgID 12345
            0x02, 0x03, 0x00, 0xff, 0xe3, // msgMaxSize 65507
            0x04, 0x01, 0x04, // msgFlags
            0x02, 0x01, 0x03, // msgSecurityModel
            0x04, 0x00, // msgSecurityParameters (empty)
            0x30, 0x1b, // scopedPDU
            0x04, 0x00, // contextEngineID
            0x04, 0x00, // contextName
            0xa2, 0x15, // GetResponse-PDU
            0x02, 0x02, 0x30, 0x39, // request-id 12345
            0x02, 0x01, 0x00, // error-status
            0x02, 0x01, 0x00, // error-index
            0x30, 0x09, 0x30, 0x07, 0x06, 0x03, 0x2b, 0x06, 0x01, 0x05, 0x00, // varbinds
        ];
        let expected = Some(12345);
        assert_eq!(extract_request_id(&v3_response), expected);
    }

    /// v1 GetResponse with a single-octet request-id.
    #[test]
    fn test_extract_request_id_v1() {
        let v1_response = [
            0x30, 0x1b, // message SEQUENCE
            0x02, 0x01, 0x00, // version 0 (v1)
            0x04, 0x06, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, // community "public"
            0xa2, 0x0e, // GetResponse-PDU
            0x02, 0x01, 0x2a, // request-id 42
            0x02, 0x01, 0x00, // error-status
            0x02, 0x01, 0x00, // error-index
            0x30, 0x03, 0x30, 0x01, 0x00, // variable bindings
        ];
        let expected = Some(42);
        assert_eq!(extract_request_id(&v1_response), expected);
    }

    /// A request-id octet with the high bit set decodes as a negative number.
    #[test]
    fn test_extract_request_id_negative() {
        let response = [
            0x30, 0x19, // message SEQUENCE
            0x02, 0x01, 0x01, // version 1 (v2c)
            0x04, 0x06, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, // community "public"
            0xa2, 0x0c, // GetResponse-PDU
            0x02, 0x01, 0xff, // request-id -1
            0x02, 0x01, 0x00, // error-status
            0x02, 0x01, 0x00, // error-index
            0x30, 0x00, // empty variable bindings
        ];
        let expected = Some(-1);
        assert_eq!(extract_request_id(&response), expected);
    }

    /// Empty input, a wrong outer tag, and a truncated header all yield `None`.
    #[test]
    fn test_extract_request_id_malformed() {
        let cases: [&[u8]; 3] = [&[], &[0x02, 0x01, 0x00], &[0x30, 0x10]];
        for bytes in cases {
            assert_eq!(extract_request_id(bytes), None);
        }
    }
}