pub use super::udp_core::TransportStats;
use super::udp_core::UdpCore;
use super::{Transport, extract_request_id};
use crate::error::{Error, Result};
use crate::util::bind_udp_socket;
use bytes::Bytes;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
/// Size in bytes of the scratch buffer the receive loop reads datagrams into.
///
/// Equal to `u16::MAX` (65535); presumably chosen because a UDP length field is
/// 16 bits wide, so no datagram can exceed it.
const UDP_RECV_BUFFER_SIZE: usize = u16::MAX as usize;
/// Tunable settings shared by a [`UdpTransport`] and every handle created from it.
///
/// Implements `Debug` so the configuration can be logged or shown in assertion
/// failures.
#[derive(Debug, Clone)]
pub struct UdpTransportConfig {
    /// Maximum message size, in bytes, that handles report through
    /// `Transport::max_message_size`.
    pub max_message_size: usize,
    /// When `true`, a handle logs a warning if a response arrives from an address
    /// other than the one the handle targets.
    pub warn_on_source_mismatch: bool,
}
impl Default for UdpTransportConfig {
    /// Returns the stock configuration: a 1472-byte maximum message size and
    /// source-mismatch warnings enabled.
    fn default() -> Self {
        // NOTE(review): 1472 looks like a 1500-byte Ethernet MTU minus IPv4 (20)
        // and UDP (8) headers — confirm that is the intent.
        let max_message_size = 1472;
        let warn_on_source_mismatch = true;
        Self {
            max_message_size,
            warn_on_source_mismatch,
        }
    }
}
/// A UDP socket shared by any number of per-target [`UdpHandle`]s.
///
/// Cloning is cheap: every clone points at the same reference-counted state
/// (socket, request bookkeeping, configuration and background receive task).
#[derive(Clone)]
pub struct UdpTransport {
    /// Shared state; also referenced by every handle and by the receive task.
    inner: Arc<UdpTransportInner>,
}
/// State shared between a [`UdpTransport`], its handles and its receive task.
struct UdpTransportInner {
    /// The bound socket used for both sending and receiving.
    socket: UdpSocket,
    /// Local address of `socket`, captured once when the transport is built.
    local_addr: SocketAddr,
    /// Request bookkeeping: registration, response delivery, expiry and stats.
    core: UdpCore,
    /// Settings supplied through the builder.
    config: UdpTransportConfig,
    /// Cancelled to ask the receive task to stop.
    shutdown: CancellationToken,
    /// Join handle of the receive task; taken out by `UdpTransport::shutdown`.
    recv_task: tokio::sync::Mutex<Option<JoinHandle<()>>>,
}
impl Drop for UdpTransport {
    /// Stops the background receive task when the last user-held reference to the
    /// shared state goes away.
    ///
    /// The receive task owns a strong reference of its own for as long as it runs,
    /// so this value is never the *unique* owner while the task is alive; a
    /// uniqueness test such as `Arc::get_mut` can therefore never succeed and the
    /// task would keep the socket open forever. Instead, the count is compared
    /// against two: this value plus the task.
    ///
    /// A count above two means another `UdpTransport` clone or a `UdpHandle` still
    /// uses the socket, so the task is left running.
    fn drop(&mut self) {
        // `self.inner` is still counted here; it is released only after this body returns.
        // Once the count is down to "this + task" it cannot grow again: the task never
        // clones its reference and `self` is exclusively borrowed.
        // NOTE(review): two clones dropped at the same instant on different threads can
        // both observe a count of three and both skip the cancel.
        if Arc::strong_count(&self.inner) <= 2 {
            // Cancelling an already-cancelled token (e.g. after `shutdown()`) is harmless.
            self.inner.shutdown.cancel();
        }
    }
}
impl UdpTransport {
    /// Binds a transport to `addr` with the default configuration.
    ///
    /// Shorthand for `UdpTransport::builder().bind(addr).build().await`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`UdpTransportBuilder::build`].
    pub async fn bind(addr: impl AsRef<str>) -> Result<Self> {
        Self::builder().bind(addr).build().await
    }

    /// Returns a builder for customising the transport before binding.
    pub fn builder() -> UdpTransportBuilder {
        UdpTransportBuilder::new()
    }

    /// Creates a handle that sends to, and waits for responses from, `target`.
    ///
    /// If the transport is bound to an IPv6 address and `target` is IPv4, the
    /// handle stores the IPv4-mapped IPv6 form of `target` instead.
    pub fn handle(&self, target: SocketAddr) -> UdpHandle {
        let target = self.map_to_socket_family(target);
        UdpHandle {
            inner: Arc::clone(&self.inner),
            target,
        }
    }

    /// Rewrites an IPv4 `target` as an IPv4-mapped IPv6 address when the local
    /// socket is IPv6; every other combination is returned unchanged.
    fn map_to_socket_family(&self, target: SocketAddr) -> SocketAddr {
        match target {
            SocketAddr::V4(v4) if self.inner.local_addr.is_ipv6() => {
                SocketAddr::new(std::net::IpAddr::V6(v4.ip().to_ipv6_mapped()), v4.port())
            }
            unchanged => unchanged,
        }
    }

    /// Returns the local address recorded when the socket was bound.
    pub fn local_addr(&self) -> SocketAddr {
        self.inner.local_addr
    }

    /// Returns the statistics reported by the shared request bookkeeping.
    pub fn stats(&self) -> TransportStats {
        self.inner.core.stats()
    }

    /// Requests shutdown of the receive task and waits for it to finish.
    ///
    /// Only the first caller finds a join handle to await; later calls just
    /// re-cancel the (already cancelled) token and return.
    pub async fn shutdown(&self) {
        self.inner.shutdown.cancel();
        // Take the handle in its own statement so the mutex guard is released
        // before the task is awaited.
        let handle = self.inner.recv_task.lock().await.take();
        if let Some(handle) = handle {
            // A join error (task panicked or was aborted) is deliberately ignored.
            let _ = handle.await;
        }
    }

    /// Spawns the background task that reads datagrams from the socket and routes
    /// them to waiting requests, then records its join handle in `recv_task`.
    ///
    /// The task stops when the shutdown token is cancelled, or when it finds that
    /// it holds the only remaining reference to the shared state.
    ///
    /// # Panics
    ///
    /// Panics if `recv_task` is already locked, which should not happen during
    /// construction. Must be called from within a Tokio runtime.
    fn start_recv_loop(inner: &Arc<UdpTransportInner>) {
        let task_inner = Arc::clone(inner);
        let handle = tokio::spawn(async move {
            let mut buf = vec![0u8; UDP_RECV_BUFFER_SIZE];
            let mut cleanup_interval = tokio::time::interval(Duration::from_secs(1));
            loop {
                tokio::select! {
                    // Poll in declaration order: shutdown first, then cleanup, then I/O.
                    biased;
                    _ = task_inner.shutdown.cancelled() => {
                        tracing::debug!(target: "async_snmp::transport", { snmp.local_addr = %task_inner.local_addr }, "UDP transport shutdown");
                        break;
                    }
                    _ = cleanup_interval.tick() => {
                        // This task owns one strong reference itself. If that is the only
                        // one left, every transport and handle has been dropped: nobody can
                        // send, wait for a response or request shutdown any more, and no new
                        // reference can appear. Stop here rather than keep the socket open.
                        if Arc::strong_count(&task_inner) == 1 {
                            tracing::debug!(target: "async_snmp::transport", { snmp.local_addr = %task_inner.local_addr }, "UDP transport orphaned, stopping receive loop");
                            break;
                        }
                        task_inner.core.cleanup_expired();
                    }
                    result = task_inner.socket.recv_from(&mut buf) => {
                        match result {
                            Ok((len, source)) => {
                                // Copy out of the reusable buffer so the next read cannot
                                // overwrite data that has been handed to a waiter.
                                let data = Bytes::copy_from_slice(&buf[..len]);
                                if let Some(request_id) = extract_request_id(&data) {
                                    if !task_inner.core.deliver(request_id, data, source) {
                                        tracing::debug!(target: "async_snmp::transport", { snmp.request_id = request_id, snmp.source = %source }, "response for unknown request");
                                    }
                                } else {
                                    tracing::debug!(target: "async_snmp::transport", { snmp.source = %source, snmp.bytes = len }, "malformed response (no request_id)");
                                }
                            }
                            // An error seen after cancellation is treated as part of shutdown.
                            Err(_) if task_inner.shutdown.is_cancelled() => break,
                            Err(e) => {
                                tracing::error!(target: "async_snmp::transport", { error = %e }, "UDP recv error");
                            }
                        }
                    }
                }
            }
        });
        *inner
            .recv_task
            .try_lock()
            .expect("recv_task lock at startup") = Some(handle);
    }
}
/// Step-by-step configuration for a [`UdpTransport`]; finish with `build()`.
pub struct UdpTransportBuilder {
    /// Textual socket address to bind to; parsed when `build()` runs.
    bind_addr: String,
    /// Settings handed to the transport being built.
    config: UdpTransportConfig,
    /// Requested receive buffer size passed to the socket helper, if any.
    recv_buffer_size: Option<usize>,
    /// Requested send buffer size passed to the socket helper, if any.
    send_buffer_size: Option<usize>,
}
impl UdpTransportBuilder {
pub fn new() -> Self {
Self {
bind_addr: "0.0.0.0:0".into(),
config: UdpTransportConfig::default(),
recv_buffer_size: None,
send_buffer_size: None,
}
}
pub fn bind(mut self, addr: impl AsRef<str>) -> Self {
self.bind_addr = addr.as_ref().to_string();
self
}
pub fn max_message_size(mut self, size: usize) -> Self {
self.config.max_message_size = size;
self
}
pub fn warn_on_source_mismatch(mut self, warn: bool) -> Self {
self.config.warn_on_source_mismatch = warn;
self
}
pub fn recv_buffer_size(mut self, size: usize) -> Self {
self.recv_buffer_size = Some(size);
self
}
pub fn send_buffer_size(mut self, size: usize) -> Self {
self.send_buffer_size = Some(size);
self
}
pub async fn build(self) -> Result<UdpTransport> {
let bind_addr: SocketAddr = self.bind_addr.parse().map_err(|_| {
Error::Config(format!("invalid bind address: {}", self.bind_addr).into())
})?;
let socket = bind_udp_socket(
bind_addr,
self.recv_buffer_size,
self.send_buffer_size,
true,
)
.await
.map_err(|e| Error::Network {
target: bind_addr,
source: e,
})?;
let local_addr = socket.local_addr().map_err(|e| Error::Network {
target: bind_addr,
source: e,
})?;
tracing::debug!(target: "async_snmp::transport", { snmp.local_addr = %local_addr }, "UDP transport bound");
let inner = Arc::new(UdpTransportInner {
socket,
local_addr,
core: UdpCore::new(),
config: self.config,
shutdown: CancellationToken::new(),
recv_task: tokio::sync::Mutex::new(None),
});
UdpTransport::start_recv_loop(&inner);
Ok(UdpTransport { inner })
}
}
impl Default for UdpTransportBuilder {
fn default() -> Self {
Self::new()
}
}
/// A view of a shared [`UdpTransport`] bound to one remote address.
///
/// Created by `UdpTransport::handle`; cloning shares the same underlying socket.
#[derive(Clone)]
pub struct UdpHandle {
    /// State shared with the transport and its receive task.
    inner: Arc<UdpTransportInner>,
    /// Remote address this handle sends to and expects responses from.
    target: SocketAddr,
}
impl Transport for UdpHandle {
async fn send(&self, data: &[u8]) -> Result<()> {
tracing::trace!(target: "async_snmp::transport", { snmp.target = %self.target, snmp.bytes = data.len() }, "UDP send");
self.inner
.socket
.send_to(data, self.target)
.await
.map_err(|e| Error::Network {
target: self.target,
source: e,
})?;
Ok(())
}
async fn recv(&self, request_id: i32) -> Result<(Bytes, SocketAddr)> {
tracing::trace!(target: "async_snmp::transport", { snmp.target = %self.target, snmp.request_id = request_id }, "UDP recv waiting");
let result = self
.inner
.core
.wait_for_response(request_id, self.target)
.await;
match &result {
Ok((data, source)) => {
if self.inner.config.warn_on_source_mismatch && *source != self.target {
tracing::warn!(target: "async_snmp::transport", { snmp.request_id = request_id, snmp.target = %self.target, snmp.source = %source }, "response source address mismatch");
}
tracing::trace!(target: "async_snmp::transport", { snmp.target = %self.target, snmp.source = %source, snmp.bytes = data.len() }, "UDP recv complete");
}
Err(_) => {
tracing::trace!(target: "async_snmp::transport", { snmp.target = %self.target, snmp.request_id = request_id }, "UDP recv failed");
}
}
result
}
fn peer_addr(&self) -> SocketAddr {
self.target
}
fn local_addr(&self) -> SocketAddr {
self.inner.local_addr
}
fn max_message_size(&self) -> u32 {
self.inner.config.max_message_size as u32
}
fn is_reliable(&self) -> bool {
false
}
fn register_request(&self, request_id: i32, timeout: Duration) {
self.inner.core.register(request_id, timeout);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a literal socket address, panicking if it is malformed.
    fn addr(text: &str) -> SocketAddr {
        text.parse().unwrap()
    }

    /// An IPv4 target on an IPv6-bound transport becomes an IPv4-mapped address.
    #[tokio::test]
    async fn ipv6_transport_maps_ipv4_target() {
        let transport = UdpTransport::bind("[::]:0").await.unwrap();
        let handle = transport.handle(addr("127.0.0.1:161"));
        assert_eq!(handle.peer_addr(), addr("[::ffff:127.0.0.1]:161"));
    }

    /// An IPv4 target on an IPv4-bound transport is left untouched.
    #[tokio::test]
    async fn ipv4_transport_preserves_ipv4_target() {
        let transport = UdpTransport::bind("0.0.0.0:0").await.unwrap();
        let handle = transport.handle(addr("127.0.0.1:161"));
        assert_eq!(handle.peer_addr(), addr("127.0.0.1:161"));
    }

    /// An IPv6 target on an IPv6-bound transport is left untouched.
    #[tokio::test]
    async fn ipv6_transport_preserves_ipv6_target() {
        let transport = UdpTransport::bind("[::]:0").await.unwrap();
        let handle = transport.handle(addr("[::1]:161"));
        assert_eq!(handle.peer_addr(), addr("[::1]:161"));
    }

    /// Without configuration a handle reports 1472 bytes.
    #[tokio::test]
    async fn max_message_size_default() {
        let transport = UdpTransport::bind("0.0.0.0:0").await.unwrap();
        let handle = transport.handle(addr("127.0.0.1:161"));
        assert_eq!(handle.max_message_size(), 1472);
    }

    /// A size set on the builder is what the handle reports.
    #[tokio::test]
    async fn max_message_size_custom() {
        let builder = UdpTransport::builder().max_message_size(8192);
        let transport = builder.build().await.unwrap();
        let handle = transport.handle(addr("127.0.0.1:161"));
        assert_eq!(handle.max_message_size(), 8192);
    }

    /// Requesting a receive buffer size still yields a bound transport.
    #[tokio::test]
    async fn recv_buffer_size_configurable() {
        let builder = UdpTransport::builder().recv_buffer_size(2 * 1024 * 1024);
        let transport = builder.build().await.unwrap();
        assert!(transport.local_addr().port() > 0);
    }

    /// Requesting a send buffer size still yields a bound transport.
    #[tokio::test]
    async fn send_buffer_size_configurable() {
        let builder = UdpTransport::builder().send_buffer_size(512 * 1024);
        let transport = builder.build().await.unwrap();
        assert!(transport.local_addr().port() > 0);
    }
}