mod handlers;
mod types;
mod varbind;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use bytes::Bytes;
use tokio::net::UdpSocket;
use tracing::instrument;
use crate::ber::Decoder;
use crate::error::internal::DecodeErrorKind;
use crate::error::{Error, Result};
use crate::oid::Oid;
use crate::pdu::TrapV1Pdu;
use crate::util::bind_udp_socket;
use crate::v3::SaltCounter;
use crate::varbind::VarBind;
use crate::version::Version;
pub use types::{DerivedKeys, UsmConfig};
pub use varbind::validate_notification_varbinds;
pub mod oids {
use crate::oid;
pub fn sys_uptime() -> crate::Oid {
oid!(1, 3, 6, 1, 2, 1, 1, 3, 0)
}
pub fn snmp_trap_oid() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 4, 1, 0)
}
pub fn snmp_trap_enterprise() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 4, 3, 0)
}
pub fn snmp_trap_address() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 18, 1, 3, 0)
}
pub fn snmp_traps() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 5)
}
pub fn cold_start() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 5, 1)
}
pub fn warm_start() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 5, 2)
}
pub fn link_down() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 5, 3)
}
pub fn link_up() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 5, 4)
}
pub fn auth_failure() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 5, 5)
}
pub fn egp_neighbor_loss() -> crate::Oid {
oid!(1, 3, 6, 1, 6, 3, 1, 1, 5, 6)
}
}
/// Builder that collects the settings for a [`NotificationReceiver`].
///
/// Start from [`NotificationReceiverBuilder::new`] (or
/// [`NotificationReceiver::builder`]), adjust settings with the chained
/// setter methods, and finish with [`NotificationReceiverBuilder::build`].
pub struct NotificationReceiverBuilder {
// Bind address kept as text; it is parsed into a `SocketAddr` by `build`.
bind_addr: String,
// USM user configurations keyed by username bytes.
usm_users: HashMap<Bytes, UsmConfig>,
// Explicit engine ID; when `None`, `build` generates one.
engine_id: Option<Vec<u8>>,
// Engine boots value handed to the receiver (defaults to 1 in `new`).
engine_boots: u32,
}
impl NotificationReceiverBuilder {
    /// Creates a builder with the default settings: bind address
    /// `0.0.0.0:162`, no USM users, no explicit engine ID (one is generated
    /// by [`build`](Self::build)), and an engine boots value of `1`.
    pub fn new() -> Self {
        Self {
            bind_addr: "0.0.0.0:162".to_string(),
            usm_users: HashMap::new(),
            engine_id: None,
            engine_boots: 1,
        }
    }

    /// Sets the address to bind to.
    ///
    /// The string is stored as given and only parsed as a `SocketAddr` when
    /// [`build`](Self::build) runs, so an invalid address is reported there.
    pub fn bind(mut self, addr: impl Into<String>) -> Self {
        self.bind_addr = addr.into();
        self
    }

    /// Registers a USM user.
    ///
    /// `configure` receives a fresh `UsmConfig::new(username)` and returns the
    /// configuration to store. Registering the same username again replaces
    /// the earlier entry.
    pub fn usm_user<F>(mut self, username: impl Into<Bytes>, configure: F) -> Self
    where
        F: FnOnce(UsmConfig) -> UsmConfig,
    {
        let username_bytes: Bytes = username.into();
        let config = configure(UsmConfig::new(username_bytes.clone()));
        self.usm_users.insert(username_bytes, config);
        self
    }

    /// Sets an explicit engine ID, used instead of a generated one.
    pub fn engine_id(mut self, engine_id: impl Into<Vec<u8>>) -> Self {
        self.engine_id = Some(engine_id.into());
        self
    }

    /// Sets the engine boots value the receiver starts from.
    pub fn engine_boots(mut self, boots: u32) -> Self {
        self.engine_boots = boots;
        self
    }

    /// Binds the UDP socket and constructs the [`NotificationReceiver`].
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when the bind address does not parse as a
    /// `SocketAddr`, and `Error::Network` when binding the socket or reading
    /// its local address fails.
    pub async fn build(self) -> Result<NotificationReceiver> {
        let bind_addr: SocketAddr = self.bind_addr.parse().map_err(|_| {
            Error::Config(format!("invalid bind address: {}", self.bind_addr).into())
        })?;
        let socket = bind_udp_socket(bind_addr, None, None, false)
            .await
            .map_err(|e| Error::Network {
                target: bind_addr,
                source: e,
            })?;
        // Ask the socket for its address so that a request for port 0 reports
        // the port that was actually assigned.
        let local_addr = socket.local_addr().map_err(|e| Error::Network {
            target: bind_addr,
            source: e,
        })?;
        let engine_id: Bytes = self.engine_id.map(Bytes::from).unwrap_or_else(|| {
            // Generated ID: a fixed 5-byte prefix followed by the current Unix
            // time in seconds as 8 big-endian bytes. A clock set before the
            // epoch yields a zero timestamp rather than an error.
            // NOTE(review): RFC 3411 gives the fifth octet a format meaning
            // for the bytes that follow — confirm `0x01` is intended here.
            let mut id = vec![0x80, 0x00, 0x00, 0x00, 0x01];
            let timestamp = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            id.extend_from_slice(&timestamp.to_be_bytes());
            Bytes::from(id)
        });
        Ok(NotificationReceiver {
            inner: Arc::new(ReceiverInner {
                socket,
                local_addr,
                usm_users: self.usm_users,
                engine_id,
                salt_counter: SaltCounter::new(),
                engine_boots_base: self.engine_boots,
                engine_start: Instant::now(),
                usm_unknown_engine_ids: AtomicU32::new(0),
            }),
        })
    }
}
impl Default for NotificationReceiverBuilder {
fn default() -> Self {
Self::new()
}
}
/// A notification delivered by a [`NotificationReceiver`].
///
/// There is one variant per combination of SNMP version (v1, v2c, v3) and
/// notification kind (trap or inform) represented here. The accessor methods
/// on this type give version-independent access to the common data.
#[derive(Debug, Clone)]
pub enum Notification {
/// SNMPv1 trap, carried as the decoded v1 trap PDU.
TrapV1 {
/// Community string of the message.
community: Bytes,
/// The v1 trap PDU; uptime, varbinds and the trap OID are derived from it.
trap: TrapV1Pdu,
},
/// SNMPv2c trap.
TrapV2c {
/// Community string of the message.
community: Bytes,
/// Uptime value reported with the notification.
uptime: u32,
/// OID identifying the notification.
trap_oid: Oid,
/// Variable bindings of the notification.
varbinds: Vec<VarBind>,
/// Request ID (presumably taken from the PDU — populated outside this file).
request_id: i32,
},
/// SNMPv3 trap.
TrapV3 {
/// USM username associated with the message.
username: Bytes,
/// Context engine ID (presumably from the scoped PDU — populated outside this file).
context_engine_id: Bytes,
/// Context name (presumably from the scoped PDU — populated outside this file).
context_name: Bytes,
/// Uptime value reported with the notification.
uptime: u32,
/// OID identifying the notification.
trap_oid: Oid,
/// Variable bindings of the notification.
varbinds: Vec<VarBind>,
/// Request ID (presumably taken from the PDU — populated outside this file).
request_id: i32,
},
/// SNMPv2c inform; reported as confirmed by `is_confirmed`.
InformV2c {
/// Community string of the message.
community: Bytes,
/// Uptime value reported with the notification.
uptime: u32,
/// OID identifying the notification.
trap_oid: Oid,
/// Variable bindings of the notification.
varbinds: Vec<VarBind>,
/// Request ID (presumably taken from the PDU — populated outside this file).
request_id: i32,
},
/// SNMPv3 inform; reported as confirmed by `is_confirmed`.
InformV3 {
/// USM username associated with the message.
username: Bytes,
/// Context engine ID (presumably from the scoped PDU — populated outside this file).
context_engine_id: Bytes,
/// Context name (presumably from the scoped PDU — populated outside this file).
context_name: Bytes,
/// Uptime value reported with the notification.
uptime: u32,
/// OID identifying the notification.
trap_oid: Oid,
/// Variable bindings of the notification.
varbinds: Vec<VarBind>,
/// Request ID (presumably taken from the PDU — populated outside this file).
request_id: i32,
},
}
impl Notification {
    /// Returns the OID identifying this notification.
    ///
    /// For a v1 trap the OID comes from `TrapV1Pdu::v2_trap_oid`, whose result
    /// (including any error) is returned as-is. Every other variant returns a
    /// clone of its stored `trap_oid`.
    pub fn trap_oid(&self) -> Result<Oid> {
        match self {
            Self::TrapV2c { trap_oid: stored, .. }
            | Self::InformV2c { trap_oid: stored, .. }
            | Self::TrapV3 { trap_oid: stored, .. }
            | Self::InformV3 { trap_oid: stored, .. } => Ok(stored.clone()),
            Self::TrapV1 { trap: pdu, .. } => pdu.v2_trap_oid(),
        }
    }

    /// Returns the uptime value of the notification.
    ///
    /// For a v1 trap this is the PDU's `time_stamp` field.
    pub fn uptime(&self) -> u32 {
        let ticks: &u32 = match self {
            Self::TrapV1 { trap: pdu, .. } => &pdu.time_stamp,
            Self::TrapV2c { uptime, .. }
            | Self::InformV2c { uptime, .. }
            | Self::TrapV3 { uptime, .. }
            | Self::InformV3 { uptime, .. } => uptime,
        };
        *ticks
    }

    /// Returns the variable bindings of the notification.
    ///
    /// For a v1 trap these are the varbinds of the trap PDU.
    pub fn varbinds(&self) -> &[VarBind] {
        match self {
            Self::TrapV2c { varbinds: list, .. }
            | Self::InformV2c { varbinds: list, .. }
            | Self::TrapV3 { varbinds: list, .. }
            | Self::InformV3 { varbinds: list, .. } => list,
            Self::TrapV1 { trap: pdu, .. } => &pdu.varbinds,
        }
    }

    /// Returns `true` for the inform variants and `false` for the trap variants.
    pub fn is_confirmed(&self) -> bool {
        match self {
            Self::InformV2c { .. } | Self::InformV3 { .. } => true,
            Self::TrapV1 { .. } | Self::TrapV2c { .. } | Self::TrapV3 { .. } => false,
        }
    }

    /// Returns the SNMP version this notification variant belongs to.
    pub fn version(&self) -> Version {
        match self {
            Self::TrapV3 { .. } | Self::InformV3 { .. } => Version::V3,
            Self::TrapV2c { .. } | Self::InformV2c { .. } => Version::V2c,
            Self::TrapV1 { .. } => Version::V1,
        }
    }
}
/// Receiver for SNMP notifications on a bound UDP socket.
///
/// The receiver is a handle around reference-counted shared state, so a
/// clone refers to the same socket and configuration as the original.
pub struct NotificationReceiver {
// Shared state; cloning the receiver clones this `Arc`, not the state itself.
inner: Arc<ReceiverInner>,
}
/// State shared by every clone of a [`NotificationReceiver`].
struct ReceiverInner {
// Bound UDP socket that `recv` reads datagrams from.
socket: UdpSocket,
// Address reported by the socket after binding; returned by `local_addr()`.
local_addr: SocketAddr,
// USM user configurations keyed by username bytes (empty when created via `bind`).
usm_users: HashMap<Bytes, UsmConfig>,
// This receiver's engine ID, either supplied through the builder or generated.
engine_id: Bytes,
// Created with `SaltCounter::new()`; not used in this file (presumably by the v3 handlers — TODO confirm).
salt_counter: SaltCounter,
// Engine boots value configured at construction (1 unless set through the builder).
engine_boots_base: u32,
// Instant captured when the receiver was constructed.
engine_start: Instant,
// Counter starting at 0 and exposed by `usm_unknown_engine_ids()`; it is
// incremented outside this file.
usm_unknown_engine_ids: AtomicU32,
}
impl NotificationReceiver {
pub fn builder() -> NotificationReceiverBuilder {
NotificationReceiverBuilder::new()
}
pub async fn bind(addr: impl AsRef<str>) -> Result<Self> {
let addr_str = addr.as_ref();
let bind_addr: SocketAddr = addr_str
.parse()
.map_err(|_| Error::Config(format!("invalid bind address: {}", addr_str).into()))?;
let socket = bind_udp_socket(bind_addr, None, None, false)
.await
.map_err(|e| Error::Network {
target: bind_addr,
source: e,
})?;
let local_addr = socket.local_addr().map_err(|e| Error::Network {
target: bind_addr,
source: e,
})?;
let engine_id: Bytes = {
let mut id = vec![0x80, 0x00, 0x00, 0x00, 0x01];
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
id.extend_from_slice(×tamp.to_be_bytes());
Bytes::from(id)
};
Ok(Self {
inner: Arc::new(ReceiverInner {
socket,
local_addr,
usm_users: HashMap::new(),
engine_id,
salt_counter: SaltCounter::new(),
engine_boots_base: 1,
engine_start: Instant::now(),
usm_unknown_engine_ids: AtomicU32::new(0),
}),
})
}
pub fn local_addr(&self) -> SocketAddr {
self.inner.local_addr
}
pub fn engine_id(&self) -> &[u8] {
&self.inner.engine_id
}
pub fn usm_unknown_engine_ids(&self) -> u32 {
self.inner.usm_unknown_engine_ids.load(Ordering::Relaxed)
}
#[instrument(skip(self), err, fields(snmp.local_addr = %self.local_addr()))]
pub async fn recv(&self) -> Result<(Notification, SocketAddr)> {
let mut buf = vec![0u8; 65535];
loop {
let (len, source) =
self.inner
.socket
.recv_from(&mut buf)
.await
.map_err(|e| Error::Network {
target: self.inner.local_addr,
source: e,
})?;
let data = Bytes::copy_from_slice(&buf[..len]);
match self.parse_and_respond(data, source).await {
Ok(Some(notification)) => return Ok((notification, source)),
Ok(None) => continue, Err(e) => {
tracing::warn!(target: "async_snmp::notification", { snmp.source = %source, error = %e }, "failed to parse notification");
continue;
}
}
}
}
async fn parse_and_respond(
&self,
data: Bytes,
source: SocketAddr,
) -> Result<Option<Notification>> {
let mut decoder = Decoder::with_target(data.clone(), source);
let mut seq = decoder.read_sequence()?;
let version_num = seq.read_integer()?;
let version = Version::from_i32(version_num).ok_or_else(|| {
tracing::debug!(target: "async_snmp::notification", { source = %source, kind = %DecodeErrorKind::UnknownVersion(version_num) }, "unknown SNMP version");
Error::MalformedResponse { target: source }.boxed()
})?;
drop(seq);
drop(decoder);
match version {
Version::V1 => self.handle_v1(data, source).await,
Version::V2c => self.handle_v2c(data, source).await,
Version::V3 => self.handle_v3(data, source).await,
}
}
}
/// Cloning produces another handle to the same shared state; only the `Arc`
/// is cloned.
impl Clone for NotificationReceiver {
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::SecurityLevel;
    use crate::oid;
    use crate::pdu::GenericTrap;
    use crate::v3::AuthProtocol;

    // ------------------------------------------------------------------
    // Shared fixtures
    // ------------------------------------------------------------------

    /// Source address used by tests that feed messages straight to `handle_v3`.
    fn peer_addr() -> SocketAddr {
        "127.0.0.1:9999".parse().unwrap()
    }

    /// Builds a receiver on an ephemeral loopback port with the given engine
    /// ID, engine boots 1 and one SHA-1 authenticated user, `informuser`.
    async fn receiver_with_inform_user(engine_id: &[u8]) -> NotificationReceiver {
        NotificationReceiver::builder()
            .bind("127.0.0.1:0")
            .engine_id(engine_id.to_vec())
            .engine_boots(1)
            .usm_user("informuser", |u| {
                u.auth(AuthProtocol::Sha1, b"authpass12345678")
            })
            .build()
            .await
            .unwrap()
    }

    /// Builds a receiver with no users, as used by the discovery tests.
    async fn discovery_receiver() -> NotificationReceiver {
        NotificationReceiver::builder()
            .bind("127.0.0.1:0")
            .engine_id(b"test-discovery-engine".to_vec())
            .build()
            .await
            .unwrap()
    }

    /// Encodes an authenticated inform from `informuser` with the given
    /// engine ID, boots and time.
    fn inform_from_inform_user(engine_id: &[u8], engine_boots: u32, engine_time: u32) -> Bytes {
        build_authed_v3_inform(
            engine_id,
            engine_boots,
            engine_time,
            b"informuser",
            b"authpass12345678",
            AuthProtocol::Sha1,
        )
    }

    /// Encodes an authNoPriv v3 InformRequest (sysUpTime + coldStart trap OID)
    /// and fills in its authentication parameters with a key localized from
    /// `auth_password` and `engine_id`.
    fn build_authed_v3_inform(
        engine_id: &[u8],
        engine_boots: u32,
        engine_time: u32,
        username: &[u8],
        auth_password: &[u8],
        auth_protocol: AuthProtocol,
    ) -> Bytes {
        use crate::message::{MsgFlags, MsgGlobalData, ScopedPdu, V3Message};
        use crate::pdu::{Pdu, PduType};
        use crate::v3::auth::authenticate_message;
        use crate::v3::{LocalizedKey, UsmSecurityParams};
        use crate::value::Value;

        let key = LocalizedKey::from_password(auth_protocol, auth_password, engine_id).unwrap();
        let placeholder_len = key.mac_len();

        let inform = Pdu {
            pdu_type: PduType::InformRequest,
            request_id: 1,
            error_status: 0,
            error_index: 0,
            varbinds: vec![
                VarBind::new(oids::sys_uptime(), Value::TimeTicks(1000)),
                VarBind::new(
                    oids::snmp_trap_oid(),
                    Value::ObjectIdentifier(oids::cold_start()),
                ),
            ],
        };
        let header = MsgGlobalData::new(1, 65507, MsgFlags::new(SecurityLevel::AuthNoPriv, false));
        let security = UsmSecurityParams::new(
            Bytes::copy_from_slice(engine_id),
            engine_boots,
            engine_time,
            Bytes::copy_from_slice(username),
        )
        .with_auth_placeholder(placeholder_len);
        let scoped = ScopedPdu::new(Bytes::copy_from_slice(engine_id), Bytes::new(), inform);

        // Encode with the placeholder first, then overwrite it with the MAC.
        let mut encoded = V3Message::new(header, security.encode(), scoped)
            .encode()
            .to_vec();
        let (auth_offset, auth_len) = UsmSecurityParams::find_auth_params_offset(&encoded).unwrap();
        authenticate_message(&key, &mut encoded, auth_offset, auth_len).unwrap();
        Bytes::from(encoded)
    }

    /// Encodes a noAuthNoPriv GetRequest with empty engine ID and username,
    /// i.e. the shape of an engine discovery request.
    fn build_v3_discovery_request(msg_id: i32, reportable: bool) -> Bytes {
        use crate::message::{MsgFlags, MsgGlobalData, ScopedPdu, V3Message};
        use crate::pdu::{Pdu, PduType};
        use crate::v3::UsmSecurityParams;

        let probe = Pdu {
            pdu_type: PduType::GetRequest,
            request_id: 0,
            error_status: 0,
            error_index: 0,
            varbinds: vec![],
        };
        let flags = MsgFlags::new(SecurityLevel::NoAuthNoPriv, reportable);
        let header = MsgGlobalData::new(msg_id, 65507, flags);
        let security = UsmSecurityParams::new(Bytes::new(), 0, 0, Bytes::new());
        let scoped = ScopedPdu::new(Bytes::new(), Bytes::new(), probe);
        V3Message::new(header, security.encode(), scoped).encode()
    }

    // ------------------------------------------------------------------
    // Notification accessors
    // ------------------------------------------------------------------

    #[test]
    fn test_notification_trap_v1() {
        let notification = Notification::TrapV1 {
            community: Bytes::from_static(b"public"),
            trap: TrapV1Pdu::new(
                oid!(1, 3, 6, 1, 4, 1, 9999),
                [192, 168, 1, 1],
                GenericTrap::LinkDown,
                0,
                12345,
                vec![],
            ),
        };
        assert!(!notification.is_confirmed());
        assert_eq!(notification.version(), Version::V1);
        assert_eq!(notification.uptime(), 12345);
        assert_eq!(notification.trap_oid().unwrap(), oids::link_down());
    }

    #[test]
    fn test_notification_trap_v2c() {
        let notification = Notification::TrapV2c {
            community: Bytes::from_static(b"public"),
            uptime: 54321,
            trap_oid: oids::link_up(),
            varbinds: vec![],
            request_id: 1,
        };
        assert!(!notification.is_confirmed());
        assert_eq!(notification.version(), Version::V2c);
        assert_eq!(notification.uptime(), 54321);
        assert_eq!(notification.trap_oid().unwrap(), oids::link_up());
    }

    #[test]
    fn test_notification_inform() {
        let notification = Notification::InformV2c {
            community: Bytes::from_static(b"public"),
            uptime: 11111,
            trap_oid: oids::cold_start(),
            varbinds: vec![],
            request_id: 42,
        };
        assert!(notification.is_confirmed());
        assert_eq!(notification.version(), Version::V2c);
    }

    #[test]
    fn test_notification_v3_inform() {
        let notification = Notification::InformV3 {
            username: Bytes::from_static(b"testuser"),
            context_engine_id: Bytes::from_static(b"engine123"),
            context_name: Bytes::new(),
            uptime: 99999,
            trap_oid: oids::warm_start(),
            varbinds: vec![],
            request_id: 100,
        };
        assert!(notification.is_confirmed());
        assert_eq!(notification.version(), Version::V3);
        assert_eq!(notification.uptime(), 99999);
        assert_eq!(notification.trap_oid().unwrap(), oids::warm_start());
    }

    /// An enterprise-specific v1 trap maps to `<enterprise>.0.<specific>`.
    #[test]
    fn test_notification_trap_v1_enterprise_specific_oid() {
        let notification = Notification::TrapV1 {
            community: Bytes::from_static(b"public"),
            trap: TrapV1Pdu::new(
                oid!(1, 3, 6, 1, 4, 1, 9999, 1, 2),
                [192, 168, 1, 1],
                GenericTrap::EnterpriseSpecific,
                42,
                12345,
                vec![],
            ),
        };
        assert_eq!(
            notification.trap_oid().unwrap(),
            oid!(1, 3, 6, 1, 4, 1, 9999, 1, 2, 0, 42)
        );
    }

    // ------------------------------------------------------------------
    // Builder state
    // ------------------------------------------------------------------

    #[test]
    fn test_notification_receiver_builder_default() {
        let builder = NotificationReceiverBuilder::new();
        assert_eq!(builder.bind_addr, "0.0.0.0:162");
        assert!(builder.usm_users.is_empty());
    }

    #[test]
    fn test_notification_receiver_builder_with_user() {
        let builder = NotificationReceiverBuilder::new()
            .bind("0.0.0.0:1162")
            .usm_user("trapuser", |u| u.auth(AuthProtocol::Sha1, b"authpass"));
        assert_eq!(builder.bind_addr, "0.0.0.0:1162");
        assert_eq!(builder.usm_users.len(), 1);

        let key = Bytes::from_static(b"trapuser");
        let user = builder.usm_users.get(&key).unwrap();
        assert_eq!(user.security_level(), SecurityLevel::AuthNoPriv);
    }

    #[test]
    fn test_builder_engine_boots_default() {
        let builder = NotificationReceiverBuilder::new();
        assert_eq!(builder.engine_boots, 1);
    }

    #[test]
    fn test_builder_engine_boots_custom() {
        let builder = NotificationReceiverBuilder::new().engine_boots(5);
        assert_eq!(builder.engine_boots, 5);
    }

    /// A fresh builder carries no engine ID; one is generated at build time.
    #[test]
    fn test_auto_generated_engine_id_non_empty() {
        let builder = NotificationReceiverBuilder::new();
        assert!(builder.engine_id.is_none());
    }

    // ------------------------------------------------------------------
    // Engine boots / time computation
    // ------------------------------------------------------------------

    #[test]
    fn test_compute_engine_boots_time_basic() {
        let (boots, time) = crate::v3::compute_engine_boots_time(1, 1000);
        assert_eq!(boots, 1);
        assert_eq!(time, 1000);
    }

    #[test]
    fn test_compute_engine_boots_time_zero_elapsed() {
        let (boots, time) = crate::v3::compute_engine_boots_time(1, 0);
        assert_eq!(boots, 1);
        assert_eq!(time, 0);
    }

    // ------------------------------------------------------------------
    // v3 time-window handling
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_v3_inform_outside_time_window_rejected() {
        let receiver = receiver_with_inform_user(b"test-engine").await;
        // The receiver was just created, so an engine time of 5000 is far ahead.
        let msg = inform_from_inform_user(b"test-engine", 1, 5000);

        let result = receiver.handle_v3(msg, peer_addr()).await;
        assert!(
            result.is_err(),
            "message with engine_time=5000 should be rejected (outside 150s window)"
        );
    }

    #[tokio::test]
    async fn test_v3_inform_wrong_boots_rejected() {
        let receiver = receiver_with_inform_user(b"test-engine").await;
        // Receiver boots is 1; the message claims 2.
        let msg = inform_from_inform_user(b"test-engine", 2, 0);

        let result = receiver.handle_v3(msg, peer_addr()).await;
        assert!(
            result.is_err(),
            "message with wrong engine_boots should be rejected"
        );
    }

    #[tokio::test]
    async fn test_v3_inform_within_time_window_accepted() {
        let receiver = receiver_with_inform_user(b"test-engine").await;
        let msg = inform_from_inform_user(b"test-engine", 1, 0);

        match receiver.handle_v3(msg, peer_addr()).await {
            Ok(Some(_)) => {}
            Ok(None) => panic!("should not return None for a valid InformRequest"),
            // Other failures are tolerated here; only an auth failure would
            // mean the time window check rejected a valid message.
            Err(e) => {
                let err_str = e.to_string();
                assert!(
                    !err_str.contains("Auth"),
                    "should not be an auth error for valid time window, got: {}",
                    err_str
                );
            }
        }
    }

    // ------------------------------------------------------------------
    // v3 discovery and engine ID checks
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_v3_discovery_gets_response() {
        let receiver = discovery_receiver().await;
        let recv_addr = receiver.local_addr();

        let client = tokio::net::UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let client_addr = client.local_addr().unwrap();

        let discovery_msg = build_v3_discovery_request(42, true);
        client.send_to(&discovery_msg, recv_addr).await.unwrap();

        let outcome = receiver.handle_v3(discovery_msg, client_addr).await;
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
        assert_eq!(receiver.usm_unknown_engine_ids(), 1);
    }

    #[tokio::test]
    async fn test_v3_discovery_non_reportable_ignored() {
        let receiver = discovery_receiver().await;
        let discovery_msg = build_v3_discovery_request(42, false);

        let outcome = receiver.handle_v3(discovery_msg, peer_addr()).await;
        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_none());
        assert_eq!(receiver.usm_unknown_engine_ids(), 0);
    }

    #[tokio::test]
    async fn test_v3_engine_id_mismatch_ignored() {
        let receiver = receiver_with_inform_user(b"my-receiver-engine").await;
        let msg = inform_from_inform_user(b"wrong-engine-id", 1, 0);

        let outcome = receiver.handle_v3(msg, peer_addr()).await;
        assert!(outcome.is_ok());
        assert!(
            outcome.unwrap().is_none(),
            "engine ID mismatch should return None"
        );
    }

    // ------------------------------------------------------------------
    // Engine ID selection
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_bind_generates_engine_id() {
        let receiver = NotificationReceiver::bind("127.0.0.1:0").await.unwrap();
        let id = receiver.engine_id();
        assert!(!id.is_empty());
        assert_eq!(id[0], 0x80);
    }

    #[tokio::test]
    async fn test_builder_generates_engine_id() {
        let receiver = NotificationReceiver::builder()
            .bind("127.0.0.1:0")
            .build()
            .await
            .unwrap();
        let id = receiver.engine_id();
        assert!(!id.is_empty());
        assert_eq!(id[0], 0x80);
    }

    #[tokio::test]
    async fn test_builder_custom_engine_id() {
        let receiver = NotificationReceiver::builder()
            .bind("127.0.0.1:0")
            .engine_id(b"custom-engine".to_vec())
            .build()
            .await
            .unwrap();
        assert_eq!(receiver.engine_id(), b"custom-engine");
    }
}