use bytes::Bytes;
use std::net::SocketAddr;
use crate::ber::Decoder;
use crate::error::internal::{CryptoErrorKind, DecodeErrorKind};
use crate::error::{Error, Result};
use crate::handler::{RequestContext, SecurityModel};
use crate::message::{
CommunityMessage, MsgFlags, MsgGlobalData, ScopedPdu, SecurityLevel, V3Message, V3MessageData,
};
use crate::pdu::{Pdu, PduType};
use crate::v3::auth::verify_message;
use crate::v3::{MAX_ENGINE_TIME, UsmSecurityParams};
use crate::value::Value;
use crate::varbind::VarBind;
use crate::version::Version;
use std::sync::atomic::Ordering;
use super::Agent;
impl Agent {
/// Entry point for an incoming SNMPv1 datagram.
///
/// Delegates to [`Agent::handle_community`] with the version pinned to
/// [`Version::V1`] and returns whatever that handler produces.
pub(super) async fn handle_v1(
    &self,
    data: Bytes,
    source: SocketAddr,
) -> Result<Option<Bytes>> {
    let version = Version::V1;
    self.handle_community(data, source, version).await
}
/// Entry point for an incoming SNMPv2c datagram.
///
/// Delegates to [`Agent::handle_community`] with the version pinned to
/// [`Version::V2c`] and returns whatever that handler produces.
pub(super) async fn handle_v2c(&self, data: Bytes, source: SocketAddr) -> Result<Option<Bytes>> {
    let version = Version::V2c;
    self.handle_community(data, source, version).await
}
/// Shared request path for the community-based versions (v1 and v2c).
///
/// Decodes `data` as a [`CommunityMessage`], checks the community string,
/// builds a [`RequestContext`], resolves VACM views, dispatches the request
/// and encodes the response.
///
/// Returns `Ok(None)` (no reply is produced) when the community string is
/// rejected or when the message does not carry a standard request PDU.
/// Decode and dispatch errors are propagated to the caller.
///
/// # Panics
///
/// Panics if called with [`Version::V3`]; v3 traffic must go through
/// `handle_v3`.
async fn handle_community(
    &self,
    data: Bytes,
    source: SocketAddr,
    version: Version,
) -> Result<Option<Bytes>> {
    let message = CommunityMessage::decode(data)?;

    // Guard: an unaccepted community gets no response at all.
    if !self.validate_community(&message.community) {
        tracing::debug!(target: "async_snmp::agent", { snmp.source = %source }, "invalid community string");
        return Ok(None);
    }

    // Guard: only standard PDUs that are requests are answered.
    let Some(request) = message
        .pdu
        .standard()
        .filter(|candidate| is_request_pdu(candidate.pdu_type))
    else {
        return Ok(None);
    };

    let security_model = match version {
        Version::V1 => SecurityModel::V1,
        Version::V2c => SecurityModel::V2c,
        Version::V3 => unreachable!("handle_community called with V3"),
    };

    // Community-based messages carry no authentication or privacy, no
    // context name and no advertised maximum message size.
    let mut request_ctx = RequestContext {
        source,
        version,
        security_model,
        security_name: message.community.clone(),
        security_level: SecurityLevel::NoAuthNoPriv,
        context_name: Bytes::new(),
        request_id: request.request_id,
        pdu_type: request.pdu_type,
        group_name: None,
        read_view: None,
        write_view: None,
        msg_max_size: None,
    };
    self.resolve_vacm(&mut request_ctx);

    let reply_pdu = self.dispatch_request(&request_ctx, request).await?;

    // Wrap the reply in a message of the same version and community.
    let reply = match version {
        Version::V1 => CommunityMessage::v1(message.community, reply_pdu),
        Version::V2c => CommunityMessage::v2c(message.community, reply_pdu),
        Version::V3 => unreachable!("handle_community called with V3"),
    };
    Ok(Some(reply.encode()))
}
/// Handles one incoming SNMPv3 message secured with USM.
///
/// Processing steps, in order:
///
/// 1. Decode the v3 message and its USM security parameters.
/// 2. An empty authoritative engine ID is handed to
///    [`Agent::handle_v3_discovery`].
/// 3. An engine ID different from ours, an unknown user name, or (for
///    authenticated levels) a local `engine_boots` equal to
///    `MAX_ENGINE_TIME` each bump the matching USM statistics counter and
///    return the result of `send_v3_report`.
/// 4. For authenticated levels the digest is verified and the timeliness
///    window is checked: the message's `engine_boots` must equal ours and
///    its `engine_time` must be within 150 seconds of ours
///    (RFC 3414 section 3.2, step 7).
/// 5. For `AuthPriv` the scoped PDU is decrypted; otherwise the plaintext
///    scoped PDU is used.
/// 6. Non-request PDUs are ignored (`Ok(None)`); requests are dispatched
///    and the response is built with `build_v3_response`.
///
/// # Errors
///
/// Propagates decode errors, key-derivation failures (as `Error::Config`),
/// `Error::Auth` when the auth parameters cannot be located or verified,
/// and `Error::MalformedResponse` when the payload encryption state does
/// not match the security level.
pub(super) async fn handle_v3(&self, data: Bytes, source: SocketAddr) -> Result<Option<Bytes>> {
    let msg = V3Message::decode(data.clone())?;
    let security_level = msg.global_data.msg_flags.security_level;
    let usm_params = UsmSecurityParams::decode(msg.security_params.clone())?;

    // Empty engine ID: the peer is asking who we are.
    if usm_params.engine_id.is_empty() {
        return self.handle_v3_discovery(&msg, source);
    }

    if usm_params.engine_id.as_ref() != self.inner.state.engine_id.as_ref() {
        tracing::debug!(target: "async_snmp::agent", { snmp.source = %source }, "engine ID mismatch");
        // wrapping_add: the statistics counter wraps instead of panicking on overflow.
        let count = self
            .inner
            .state
            .usm_unknown_engine_ids
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1);
        return self.send_v3_report(
            &msg,
            &usm_params,
            crate::v3::report_oids::unknown_engine_ids(),
            count,
            source,
        );
    }

    // Unknown users are reported before any key material is derived.
    let Some(user_config) = self.inner.usm_users.get(&usm_params.username) else {
        tracing::debug!(target: "async_snmp::agent", { snmp.source = %source, snmp.username = %String::from_utf8_lossy(&usm_params.username) }, "unknown user");
        let count = self
            .inner
            .state
            .usm_unknown_usernames
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1);
        return self.send_v3_report(
            &msg,
            &usm_params,
            crate::v3::report_oids::unknown_user_names(),
            count,
            source,
        );
    };
    let derived_keys = user_config
        .derive_keys(&self.inner.state.engine_id)
        .map_err(|e| Error::Config(e.to_string().into()).boxed())?;

    // Once engine boots has latched at its maximum, no authenticated
    // message can be considered timely any more.
    if security_level.requires_auth()
        && self.inner.state.engine_boots.load(Ordering::Relaxed) == MAX_ENGINE_TIME
    {
        tracing::warn!(target: "async_snmp::agent", { snmp.source = %source }, "engine boots at maximum, rejecting authenticated request");
        let count = self
            .inner
            .state
            .usm_not_in_time_windows
            .fetch_add(1, Ordering::Relaxed)
            .wrapping_add(1);
        return self.send_v3_report(
            &msg,
            &usm_params,
            crate::v3::report_oids::not_in_time_windows(),
            count,
            source,
        );
    }

    if matches!(
        security_level,
        SecurityLevel::AuthNoPriv | SecurityLevel::AuthPriv
    ) {
        // The user must actually have an authentication key configured.
        let Some(auth_key) = derived_keys.auth_key.as_ref() else {
            tracing::debug!(target: "async_snmp::agent", { snmp.source = %source, snmp.username = %String::from_utf8_lossy(&usm_params.username) }, "user does not support requested security level");
            let count = self
                .inner
                .state
                .usm_unsupported_sec_levels
                .fetch_add(1, Ordering::Relaxed)
                .wrapping_add(1);
            return self.send_v3_report(
                &msg,
                &usm_params,
                crate::v3::report_oids::unsupported_sec_levels(),
                count,
                source,
            );
        };

        let (auth_offset, auth_len) = UsmSecurityParams::find_auth_params_offset(&data)
            .ok_or_else(|| {
                tracing::debug!(target: "async_snmp::agent", { source = %source }, "could not find auth params in message");
                Error::Auth { target: source }.boxed()
            })?;

        if !verify_message(auth_key, &data, auth_offset, auth_len)
            .map_err(|_| Error::Auth { target: source }.boxed())?
        {
            tracing::debug!(target: "async_snmp::agent", { snmp.source = %source }, "authentication failed");
            let count = self
                .inner
                .state
                .usm_wrong_digests
                .fetch_add(1, Ordering::Relaxed)
                .wrapping_add(1);
            return self.send_v3_report(
                &msg,
                &usm_params,
                crate::v3::report_oids::wrong_digests(),
                count,
                source,
            );
        }

        // Timeliness (RFC 3414 section 3.2 step 7): the message must come
        // from the current boot cycle AND be within 150 seconds of our
        // clock. Checking the time alone would accept messages replayed
        // from a previous boot cycle.
        let our_boots = self.inner.state.engine_boots.load(Ordering::Relaxed);
        let our_time = self.inner.state.engine_time.load(Ordering::Relaxed);
        let time_diff = (usm_params.engine_time as i64 - our_time as i64).abs();
        if usm_params.engine_boots != our_boots || time_diff > 150 {
            tracing::debug!(target: "async_snmp::agent", { snmp.source = %source }, "message outside time window");
            let count = self
                .inner
                .state
                .usm_not_in_time_windows
                .fetch_add(1, Ordering::Relaxed)
                .wrapping_add(1);
            return self.send_v3_report(
                &msg,
                &usm_params,
                crate::v3::report_oids::not_in_time_windows(),
                count,
                source,
            );
        }
    }

    let scoped_pdu = if security_level == SecurityLevel::AuthPriv {
        // The user must have a privacy key to use AuthPriv.
        let Some(priv_key) = derived_keys.priv_key.as_ref() else {
            tracing::debug!(target: "async_snmp::agent", { source = %source, kind = %CryptoErrorKind::NoPrivKey }, "no privacy key configured for user");
            let count = self
                .inner
                .state
                .usm_unsupported_sec_levels
                .fetch_add(1, Ordering::Relaxed)
                .wrapping_add(1);
            return self.send_v3_report(
                &msg,
                &usm_params,
                crate::v3::report_oids::unsupported_sec_levels(),
                count,
                source,
            );
        };

        let encrypted_data = match &msg.data {
            V3MessageData::Encrypted(data) => data,
            V3MessageData::Plaintext(_) => {
                tracing::debug!(target: "async_snmp::agent", { source = %source, kind = %DecodeErrorKind::ExpectedEncryption }, "expected encrypted scoped PDU");
                return Err(Error::MalformedResponse { target: source }.boxed());
            }
        };

        let decrypted = match priv_key.decrypt(
            encrypted_data,
            usm_params.engine_boots,
            usm_params.engine_time,
            &usm_params.priv_params,
        ) {
            Ok(data) => data,
            Err(e) => {
                tracing::debug!(target: "async_snmp::agent", { source = %source, error = %e }, "decryption failed");
                let count = self
                    .inner
                    .state
                    .usm_decryption_errors
                    .fetch_add(1, Ordering::Relaxed)
                    .wrapping_add(1);
                return self.send_v3_report(
                    &msg,
                    &usm_params,
                    crate::v3::report_oids::decryption_errors(),
                    count,
                    source,
                );
            }
        };

        let mut decoder = Decoder::with_target(decrypted, source);
        ScopedPdu::decode(&mut decoder)?
    } else {
        match msg.scoped_pdu() {
            Some(sp) => sp.clone(),
            None => {
                tracing::debug!(target: "async_snmp::agent", { source = %source, kind = %DecodeErrorKind::UnexpectedEncryption }, "unexpected encrypted scoped PDU");
                return Err(Error::MalformedResponse { target: source }.boxed());
            }
        }
    };

    // Only request PDUs are answered; anything else is ignored.
    let pdu = &scoped_pdu.pdu;
    if !is_request_pdu(pdu.pdu_type) {
        return Ok(None);
    }

    let mut ctx = RequestContext {
        source,
        version: Version::V3,
        security_model: SecurityModel::Usm,
        security_name: usm_params.username.clone(),
        security_level,
        context_name: scoped_pdu.context_name.clone(),
        request_id: pdu.request_id,
        pdu_type: pdu.pdu_type,
        group_name: None,
        read_view: None,
        write_view: None,
        msg_max_size: Some(msg.global_data.msg_max_size as u32),
    };
    self.resolve_vacm(&mut ctx);

    let response_pdu = self.dispatch_request(&ctx, pdu).await?;
    self.build_v3_response(
        &msg,
        &usm_params,
        response_pdu,
        scoped_pdu.context_engine_id.clone(),
        scoped_pdu.context_name.clone(),
        Some(&derived_keys),
    )
}
fn resolve_vacm(&self, ctx: &mut RequestContext) {
if let Some(ref vacm) = self.inner.vacm
&& let Some(group) = vacm.get_group(ctx.security_model, &ctx.security_name)
{
ctx.group_name = Some(group.clone());
if let Some(access) = vacm.get_access(
group,
&ctx.context_name,
ctx.security_model,
ctx.security_level,
) {
ctx.read_view = Some(access.read_view.clone());
ctx.write_view = Some(access.write_view.clone());
} else {
tracing::warn!(
target: "async_snmp::agent",
group = %String::from_utf8_lossy(group),
context = %String::from_utf8_lossy(&ctx.context_name),
security_model = ?ctx.security_model,
security_level = ?ctx.security_level,
"VACM group has no matching access entry, denying access"
);
}
}
}
/// Answers an SNMPv3 engine-discovery request with a Report message.
///
/// If the incoming message does not have the reportable flag set, nothing
/// is sent and `Ok(None)` is returned. Otherwise the unknown-engine-IDs
/// counter is incremented and a `NoAuthNoPriv` Report is encoded that
/// carries this agent's engine ID, engine boots and engine time in the USM
/// parameters, and the new counter value as its single varbind.
///
/// The report reuses the incoming message ID (also as the PDU request ID)
/// and echoes the incoming maximum message size.
pub(super) fn handle_v3_discovery(
    &self,
    incoming: &V3Message,
    _source: SocketAddr,
) -> Result<Option<Bytes>> {
    if !incoming.global_data.msg_flags.reportable {
        tracing::debug!(target: "async_snmp::agent", "discovery request has reportable=false, not sending report");
        return Ok(None);
    }
    let engine_boots = self.inner.state.engine_boots.load(Ordering::Relaxed);
    let engine_time = self.inner.state.engine_time.load(Ordering::Relaxed);
    // `fetch_add` returns the previous value; wrapping_add yields the new
    // one without an overflow panic when the counter rolls over.
    let unknown_engine_ids_count = self
        .inner
        .state
        .usm_unknown_engine_ids
        .fetch_add(1, Ordering::Relaxed)
        .wrapping_add(1);
    let report_pdu = Pdu {
        pdu_type: PduType::Report,
        request_id: incoming.global_data.msg_id,
        error_status: 0,
        error_index: 0,
        varbinds: vec![VarBind::new(
            crate::v3::report_oids::unknown_engine_ids(),
            Value::Counter32(unknown_engine_ids_count),
        )],
    };
    let response_global = MsgGlobalData::new(
        incoming.global_data.msg_id,
        incoming.global_data.msg_max_size,
        // NOTE(review): second argument is presumably the reportable flag — confirm.
        MsgFlags::new(SecurityLevel::NoAuthNoPriv, false),
    );
    // The report carries our engine ID, boots and time; the last argument
    // (presumably the user name — confirm) is left empty.
    let response_usm = UsmSecurityParams::new(
        self.inner.state.engine_id.clone(),
        engine_boots,
        engine_time,
        Bytes::new(),
    );
    let response_scoped =
        ScopedPdu::new(self.inner.state.engine_id.clone(), Bytes::new(), report_pdu);
    let response_msg = V3Message::new(response_global, response_usm.encode(), response_scoped);
    Ok(Some(response_msg.encode()))
}
}
/// Returns `true` when `pdu_type` is one of the PDU types the agent treats
/// as a request: Get, GetNext, GetBulk, Set or Inform. Every other PDU
/// type yields `false`.
pub(super) fn is_request_pdu(pdu_type: PduType) -> bool {
    match pdu_type {
        PduType::GetRequest
        | PduType::GetNextRequest
        | PduType::GetBulkRequest
        | PduType::SetRequest
        | PduType::InformRequest => true,
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every request-class PDU type is accepted by `is_request_pdu`, while
    /// `Response` and `TrapV2` are rejected.
    #[test]
    fn test_is_request_pdu() {
        let accepted = [
            PduType::GetRequest,
            PduType::GetNextRequest,
            PduType::GetBulkRequest,
            PduType::SetRequest,
            PduType::InformRequest,
        ];
        for (idx, pdu_type) in accepted.into_iter().enumerate() {
            assert!(
                is_request_pdu(pdu_type),
                "request PDU type at index {idx} was not recognized"
            );
        }

        let rejected = [PduType::Response, PduType::TrapV2];
        for (idx, pdu_type) in rejected.into_iter().enumerate() {
            assert!(
                !is_request_pdu(pdu_type),
                "non-request PDU type at index {idx} was treated as a request"
            );
        }
    }
}