use super::Transport;
use crate::error::{Error, Result};
use bytes::{Bytes, BytesMut};
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{Mutex, OwnedMutexGuard};
use tokio::time::timeout;
/// Largest message size reported by `max_message_size()`: `i32::MAX` (0x7fff_ffff) bytes.
const MAX_TCP_MESSAGE_SIZE: usize = i32::MAX as usize;
/// Default cap, in bytes (10 MiB), on the BER content length `recv()` will allocate for.
const DEFAULT_MAX_ALLOCATION_SIZE: usize = 10 << 20;
/// Tuning knobs for a [`TcpTransport`].
#[derive(Debug, Clone)]
pub struct TcpOptions {
    /// Upper bound, in bytes, on the BER content length that `recv()` will allocate a
    /// buffer for; messages declaring a longer content length are rejected.
    pub max_allocation_size: usize,
}

impl Default for TcpOptions {
    /// Caps allocations at `DEFAULT_MAX_ALLOCATION_SIZE` (10 MiB).
    fn default() -> Self {
        let max_allocation_size = DEFAULT_MAX_ALLOCATION_SIZE;
        TcpOptions { max_allocation_size }
    }
}
#[derive(Debug)]
pub struct TcpTransportBuilder {
timeout: Option<Duration>,
options: TcpOptions,
}
impl TcpTransportBuilder {
pub fn new() -> Self {
Self {
timeout: None,
options: TcpOptions::default(),
}
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn max_allocation_size(mut self, size: usize) -> Self {
self.options.max_allocation_size = size;
self
}
pub async fn connect(self, target: SocketAddr) -> Result<TcpTransport> {
let stream = match self.timeout {
Some(t) => timeout(t, TcpStream::connect(target))
.await
.map_err(|_| {
Error::Timeout {
target,
elapsed: t,
retries: 0,
}
.boxed()
})?
.map_err(|e| Error::Network { target, source: e }.boxed())?,
None => TcpStream::connect(target)
.await
.map_err(|e| Error::Network { target, source: e }.boxed())?,
};
let local_addr = stream
.local_addr()
.map_err(|e| Error::Network { target, source: e }.boxed())?;
Ok(TcpTransport {
inner: Arc::new(TcpTransportInner {
stream: Arc::new(Mutex::new(stream)),
active_guard: Mutex::new(None),
current_timeout_ms: AtomicU64::new(30_000),
target,
local_addr,
max_allocation_size: self.options.max_allocation_size,
}),
})
}
}
impl Default for TcpTransportBuilder {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone)]
pub struct TcpTransport {
inner: Arc<TcpTransportInner>,
}
struct TcpTransportInner {
stream: Arc<Mutex<TcpStream>>,
active_guard: Mutex<Option<OwnedMutexGuard<TcpStream>>>,
current_timeout_ms: AtomicU64,
target: SocketAddr,
local_addr: SocketAddr,
max_allocation_size: usize,
}
impl TcpTransport {
pub async fn connect(target: SocketAddr) -> Result<Self> {
Self::builder().connect(target).await
}
pub async fn connect_timeout(target: SocketAddr, connect_timeout: Duration) -> Result<Self> {
Self::builder()
.timeout(connect_timeout)
.connect(target)
.await
}
pub fn builder() -> TcpTransportBuilder {
TcpTransportBuilder::new()
}
pub async fn from_socket(
socket: tokio::net::TcpSocket,
target: SocketAddr,
options: TcpOptions,
) -> Result<Self> {
let stream = socket
.connect(target)
.await
.map_err(|e| Error::Network { target, source: e }.boxed())?;
let local_addr = stream
.local_addr()
.map_err(|e| Error::Network { target, source: e }.boxed())?;
Ok(Self {
inner: Arc::new(TcpTransportInner {
stream: Arc::new(Mutex::new(stream)),
active_guard: Mutex::new(None),
current_timeout_ms: AtomicU64::new(30_000),
target,
local_addr,
max_allocation_size: options.max_allocation_size,
}),
})
}
}
impl Transport for TcpTransport {
async fn send(&self, data: &[u8]) -> Result<()> {
let mut stream = self.inner.stream.clone().lock_owned().await;
let target = self.inner.target;
let result = async {
stream
.write_all(data)
.await
.map_err(|e| Error::Network { target, source: e }.boxed())?;
stream
.flush()
.await
.map_err(|e| Error::Network { target, source: e }.boxed())?;
Ok::<_, Box<Error>>(())
}
.await;
match result {
Ok(()) => {
*self.inner.active_guard.lock().await = Some(stream);
Ok(())
}
Err(e) => {
Err(e)
}
}
}
fn register_request(&self, _request_id: i32, timeout: Duration) {
self.inner
.current_timeout_ms
.store(timeout.as_millis() as u64, Ordering::Relaxed);
}
async fn recv(&self, request_id: i32) -> Result<(Bytes, SocketAddr)> {
let recv_timeout =
Duration::from_millis(self.inner.current_timeout_ms.load(Ordering::Relaxed));
let target = self.inner.target;
let mut stream =
self.inner.active_guard.lock().await.take().ok_or_else(|| {
Error::Config("recv() called without prior send()".into()).boxed()
})?;
let max_alloc = self.inner.max_allocation_size;
let result = timeout(
recv_timeout,
read_ber_message(&mut stream, target, max_alloc),
)
.await;
match result {
Ok(Ok(data)) => Ok((data, target)),
Ok(Err(e)) => Err(e),
Err(_) => {
tracing::debug!(target: "async_snmp::transport::tcp", { request_id, %target, elapsed = ?recv_timeout }, "transport timeout");
Err(Error::Timeout {
target,
elapsed: recv_timeout,
retries: 0,
}
.boxed())
}
}
}
fn peer_addr(&self) -> SocketAddr {
self.inner.target
}
fn local_addr(&self) -> SocketAddr {
self.inner.local_addr
}
fn is_reliable(&self) -> bool {
true
}
fn max_message_size(&self) -> u32 {
MAX_TCP_MESSAGE_SIZE as u32
}
}
/// Reads one BER-framed SNMP message (a top-level SEQUENCE, tag `0x30`) from `stream`.
///
/// Returns the complete encoding (tag octet, length octets and content) as one
/// contiguous buffer. Only definite lengths are accepted: the short form, or the long
/// form with at most four length octets.
///
/// # Errors
///
/// * [`Error::Network`] if a read fails, including EOF before the message is complete.
/// * [`Error::MalformedResponse`] in any of these cases:
///   * the tag is not `0x30`;
///   * the length uses the indefinite form (`0x80`);
///   * the long form needs more than four length octets;
///   * the declared content length exceeds `max_allocation_size`.
async fn read_ber_message(
    stream: &mut TcpStream,
    target: SocketAddr,
    max_allocation_size: usize,
) -> Result<Bytes> {
    // Header scratch space: tag octet + initial length octet + up to four more
    // length octets. Kept on the stack; the heap is touched once, for the result.
    let mut header = [0u8; 6];

    // Read and check the tag on its own so a bad tag is reported without waiting
    // for further bytes.
    stream
        .read_exact(&mut header[..1])
        .await
        .map_err(|e| Error::Network { target, source: e }.boxed())?;
    let tag = header[0];
    if tag != 0x30 {
        tracing::debug!(target: "async_snmp::transport::tcp", { expected_tag = 0x30, actual_tag = tag, %target }, "invalid SNMP message tag");
        return Err(Error::MalformedResponse { target }.boxed());
    }

    stream
        .read_exact(&mut header[1..2])
        .await
        .map_err(|e| Error::Network { target, source: e }.boxed())?;
    let first_len_byte = header[1];
    let (content_len, header_len) = if first_len_byte < 0x80 {
        // Short form: the octet itself is the content length.
        (usize::from(first_len_byte), 2)
    } else if first_len_byte == 0x80 {
        tracing::debug!(target: "async_snmp::transport::tcp", { %target }, "indefinite length encoding not supported");
        return Err(Error::MalformedResponse { target }.boxed());
    } else {
        // Long form: the low seven bits count the big-endian length octets that follow.
        let num_len_bytes = usize::from(first_len_byte & 0x7F);
        if num_len_bytes > 4 {
            tracing::debug!(target: "async_snmp::transport::tcp", { octets = num_len_bytes, %target }, "length encoding too long");
            return Err(Error::MalformedResponse { target }.boxed());
        }
        let header_len = 2 + num_len_bytes;
        stream
            .read_exact(&mut header[2..header_len])
            .await
            .map_err(|e| Error::Network { target, source: e }.boxed())?;
        let length = header[2..header_len]
            .iter()
            .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
        (length, header_len)
    };

    // Validate before allocating so a hostile length cannot force a huge buffer.
    if content_len > max_allocation_size {
        tracing::warn!(target: "async_snmp::transport::tcp", { size = content_len, max = max_allocation_size, %target }, "message size exceeds limit");
        return Err(Error::MalformedResponse { target }.boxed());
    }

    // Single exact-size allocation:
    // 1. copy the header in;
    // 2. zero-extend to the full message length;
    // 3. read the content straight into the tail (no temporary buffer or extra copy).
    let total_len = header_len + content_len;
    let mut message = BytesMut::with_capacity(total_len);
    message.extend_from_slice(&header[..header_len]);
    message.resize(total_len, 0);
    stream
        .read_exact(&mut message[header_len..])
        .await
        .map_err(|e| Error::Network { target, source: e }.boxed())?;
    Ok(message.freeze())
}
#[cfg(test)]
mod tests {
    // Loopback tests: each test binds a one-shot server on 127.0.0.1 with an
    // OS-assigned port and drives a `TcpTransport` against it.
    use super::*;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpListener;

    // Round trip: the server reads the request and replies with a fixed 30-byte
    // message; `recv()` must return it together with the server's address.
    #[tokio::test]
    async fn test_tcp_send_recv() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; 1024];
            let n = socket.read(&mut buf).await.unwrap();
            // SEQUENCE (0x1c = 28 content bytes) containing:
            // - INTEGER 1;
            // - OCTET STRING "public";
            // - a 0xa2 PDU (15 bytes) holding INTEGER 1, INTEGER 0, INTEGER 0 and
            //   SEQUENCE { SEQUENCE { NULL } }.
            let response = [
                0x30, 0x1c, 0x02, 0x01, 0x01, 0x04, 0x06, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63,
                0xa2, 0x0f, 0x02, 0x01, 0x01, 0x02, 0x01, 0x00, 0x02, 0x01, 0x00, 0x30, 0x04,
                0x30, 0x02, 0x05, 0x00,
            ];
            socket.write_all(&response).await.unwrap();
            n
        });
        let transport = TcpTransport::connect(server_addr).await.unwrap();
        // SEQUENCE (0x1a = 26 content bytes) containing:
        // - INTEGER 1;
        // - OCTET STRING "public";
        // - a 0xa0 PDU (13 bytes) holding INTEGER 1, INTEGER 0, INTEGER 0 and
        //   SEQUENCE { SEQUENCE {} }.
        let request = [
            0x30, 0x1a, 0x02, 0x01, 0x01, 0x04, 0x06, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0xa0,
            0x0d, 0x02, 0x01, 0x01, 0x02, 0x01, 0x00, 0x02, 0x01, 0x00, 0x30, 0x02, 0x30, 0x00,
        ];
        transport.send(&request).await.unwrap();
        transport.register_request(1, Duration::from_secs(5));
        let (response, source) = transport.recv(1).await.unwrap();
        assert_eq!(source, server_addr);
        assert_eq!(response[0], 0x30);
        assert!(response.len() > 10);
        server.await.unwrap();
    }

    // Long-form length: 0x81 0xc8 declares 200 content bytes. The returned buffer must
    // keep the 3 header octets, so its length is 3 + 200 = 203.
    #[tokio::test]
    async fn test_tcp_long_length_form() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 1];
            let _ = socket.read(&mut buf).await;
            let mut response = vec![0x30, 0x81, 0xc8];
            response.extend(vec![0x00; 200]);
            socket.write_all(&response).await.unwrap();
        });
        let transport = TcpTransport::connect(server_addr).await.unwrap();
        // The payload content is irrelevant; it only unblocks the server's read.
        transport.send(&[0x00]).await.unwrap();
        transport.register_request(1, Duration::from_secs(5));
        let (response, _) = transport.recv(1).await.unwrap();
        assert_eq!(response.len(), 203);
        assert_eq!(response[0], 0x30);
        assert_eq!(response[1], 0x81);
        assert_eq!(response[2], 0xc8);
        server.await.unwrap();
    }

    // TCP transports always report themselves as reliable.
    #[tokio::test]
    async fn test_tcp_is_reliable() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            let _ = listener.accept().await;
        });
        let transport = TcpTransport::connect(server_addr).await.unwrap();
        assert!(transport.is_reliable());
    }

    // Five tasks share one transport, and each must complete its send/recv exchange.
    // The server handles one request at a time and parses only short-form lengths.
    // Response request-ids come from the server's own counter and are not checked
    // by the clients.
    #[tokio::test]
    async fn test_tcp_concurrent_requests() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicI32, Ordering};
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let request_counter = Arc::new(AtomicI32::new(0));
        let counter_clone = request_counter.clone();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            for _ in 0..5 {
                let mut tag = [0u8; 1];
                if socket.read_exact(&mut tag).await.is_err() {
                    break;
                }
                // Short-form length only: `build_request_with_id` always emits 0x1d.
                let mut len_byte = [0u8; 1];
                socket.read_exact(&mut len_byte).await.unwrap();
                let content_len = len_byte[0] as usize;
                let mut content = vec![0u8; content_len];
                socket.read_exact(&mut content).await.unwrap();
                let request_id = counter_clone.fetch_add(1, Ordering::SeqCst) + 1;
                let response = build_response_with_id(request_id);
                socket.write_all(&response).await.unwrap();
            }
        });
        let transport = TcpTransport::connect(server_addr).await.unwrap();
        let mut handles = vec![];
        for i in 0..5 {
            let transport = transport.clone();
            let handle = tokio::spawn(async move {
                let request_id = i + 1;
                let request = build_request_with_id(request_id);
                transport.register_request(request_id, Duration::from_secs(5));
                transport.send(&request).await?;
                let (response, _) = transport.recv(request_id).await?;
                assert_eq!(response[0], 0x30, "Response should be SEQUENCE");
                Ok::<_, Box<Error>>(i)
            });
            handles.push(handle);
        }
        let results: Vec<_> = futures::future::join_all(handles).await;
        // A task counts as successful only if it neither panicked nor returned Err.
        let success_count = results
            .iter()
            .filter(|r| r.as_ref().map(|r| r.is_ok()).unwrap_or(false))
            .count();
        assert_eq!(
            success_count, 5,
            "All 5 concurrent requests should succeed (serialized)"
        );
        server.await.unwrap();
    }

    // Message with a 0xa0 PDU whose first field is `request_id` encoded as a
    // 4-byte big-endian INTEGER. Total length 31 bytes (2 header + 29 content).
    fn build_request_with_id(request_id: i32) -> Vec<u8> {
        let id_bytes = request_id.to_be_bytes();
        vec![
            // SEQUENCE, 29 content bytes
            0x30, 0x1d,
            // INTEGER 1
            0x02, 0x01, 0x01,
            // OCTET STRING "public"
            0x04, 0x06, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63,
            // 0xa0 PDU, 16 content bytes
            0xa0, 0x10,
            // INTEGER, 4 octets: request_id
            0x02, 0x04, id_bytes[0], id_bytes[1], id_bytes[2], id_bytes[3],
            // INTEGER 0, INTEGER 0
            0x02, 0x01, 0x00, 0x02, 0x01, 0x00,
            // SEQUENCE { SEQUENCE {} }
            0x30, 0x02, 0x30, 0x00,
        ]
    }

    // Same layout as `build_request_with_id`, but with PDU tag 0xa2.
    fn build_response_with_id(request_id: i32) -> Vec<u8> {
        let id_bytes = request_id.to_be_bytes();
        vec![
            // SEQUENCE, 29 content bytes
            0x30, 0x1d,
            // INTEGER 1
            0x02, 0x01, 0x01,
            // OCTET STRING "public"
            0x04, 0x06, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63,
            // 0xa2 PDU, 16 content bytes
            0xa2, 0x10,
            // INTEGER, 4 octets: request_id
            0x02, 0x04, id_bytes[0], id_bytes[1], id_bytes[2], id_bytes[3],
            // INTEGER 0, INTEGER 0
            0x02, 0x01, 0x00, 0x02, 0x01, 0x00,
            // SEQUENCE { SEQUENCE {} }
            0x30, 0x02, 0x30, 0x00,
        ]
    }

    // A four-octet long-form length (0x84 0x06 0x40 0x00 0x00) declares 0x0640_0000
    // bytes (100 MiB). That exceeds the default 10 MiB cap, so `recv()` must fail with
    // MalformedResponse instead of allocating.
    #[tokio::test]
    async fn test_tcp_rejects_excessive_claimed_size() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 64];
            let _ = socket.read(&mut buf).await;
            let malicious_response = [
                0x30, 0x84, 0x06, 0x40, 0x00,
                0x00,
            ];
            let _ = socket.write_all(&malicious_response).await;
            // Keep the socket open briefly so the client sees the header, not EOF.
            tokio::time::sleep(Duration::from_millis(100)).await;
        });
        let transport = TcpTransport::connect(server_addr).await.unwrap();
        let request = build_request_with_id(1);
        transport.send(&request).await.unwrap();
        transport.register_request(1, Duration::from_secs(5));
        let result = transport.recv(1).await;
        assert!(result.is_err(), "Should reject excessive claimed size");
        let err = result.unwrap_err();
        assert!(
            matches!(*err, Error::MalformedResponse { .. }),
            "Expected MalformedResponse error, got: {:?}",
            err
        );
        server.await.unwrap();
    }

    // 0x82 0x28 0x00 declares 0x2800 = 10240 content bytes. That exceeds the
    // 1024-byte cap set through the builder, so `recv()` must fail with
    // MalformedResponse.
    #[tokio::test]
    async fn test_tcp_builder_custom_allocation_limit() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let server_addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = [0u8; 64];
            let _ = socket.read(&mut buf).await;
            let response = [
                0x30, 0x82, 0x28, 0x00,
            ];
            let _ = socket.write_all(&response).await;
            tokio::time::sleep(Duration::from_millis(100)).await;
        });
        let transport = TcpTransport::builder()
            .max_allocation_size(1024)
            .connect(server_addr)
            .await
            .unwrap();
        let request = build_request_with_id(1);
        transport.send(&request).await.unwrap();
        transport.register_request(1, Duration::from_secs(5));
        let result = transport.recv(1).await;
        assert!(
            result.is_err(),
            "Should reject message exceeding custom limit"
        );
        let err = result.unwrap_err();
        assert!(
            matches!(*err, Error::MalformedResponse { .. }),
            "Expected MalformedResponse error, got: {:?}",
            err
        );
        server.await.unwrap();
    }
}