use bytes::Bytes;
use crate::ber::length::parse_ber_length;
use crate::ber::{Decoder, EncodeBuf};
use crate::error::internal::DecodeErrorKind;
use crate::error::{Error, Result, UNKNOWN_TARGET};
/// User-based Security Model (USM) security parameters of an SNMPv3 message.
///
/// Field order matches the order in which [`UsmSecurityParams::decode_from`]
/// reads them from the wire. `Default` produces the same all-empty value as
/// [`UsmSecurityParams::empty`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UsmSecurityParams {
    /// Authoritative engine ID (raw octets).
    pub engine_id: Bytes,
    /// Authoritative engine boot counter; decoding rejects negative values.
    pub engine_boots: u32,
    /// Authoritative engine time; decoding rejects negative values.
    pub engine_time: u32,
    /// Security (user) name as raw octets.
    pub username: Bytes,
    /// Authentication parameters; may be a zero-filled placeholder
    /// (see `with_auth_placeholder`).
    pub auth_params: Bytes,
    /// Privacy parameters.
    pub priv_params: Bytes,
}
impl UsmSecurityParams {
pub fn new(
engine_id: impl Into<Bytes>,
engine_boots: u32,
engine_time: u32,
username: impl Into<Bytes>,
) -> Self {
Self {
engine_id: engine_id.into(),
engine_boots,
engine_time,
username: username.into(),
auth_params: Bytes::new(),
priv_params: Bytes::new(),
}
}
pub fn empty() -> Self {
Self {
engine_id: Bytes::new(),
engine_boots: 0,
engine_time: 0,
username: Bytes::new(),
auth_params: Bytes::new(),
priv_params: Bytes::new(),
}
}
pub fn with_auth_params(mut self, auth_params: impl Into<Bytes>) -> Self {
self.auth_params = auth_params.into();
self
}
pub fn with_priv_params(mut self, priv_params: impl Into<Bytes>) -> Self {
self.priv_params = priv_params.into();
self
}
pub fn with_auth_placeholder(mut self, mac_len: usize) -> Self {
self.auth_params = Bytes::from(vec![0u8; mac_len]);
self
}
pub fn encode(&self) -> Bytes {
let mut buf = EncodeBuf::new();
self.encode_to_buf(&mut buf);
buf.finish()
}
pub fn encode_to_buf(&self, buf: &mut EncodeBuf) {
buf.push_sequence(|buf| {
buf.push_octet_string(&self.priv_params);
buf.push_octet_string(&self.auth_params);
buf.push_octet_string(&self.username);
buf.push_unsigned32(crate::ber::tag::universal::INTEGER, self.engine_time);
buf.push_unsigned32(crate::ber::tag::universal::INTEGER, self.engine_boots);
buf.push_octet_string(&self.engine_id);
});
}
pub fn decode(data: Bytes) -> Result<Self> {
let mut decoder = Decoder::new(data);
Self::decode_from(&mut decoder)
}
pub fn decode_from(decoder: &mut Decoder) -> Result<Self> {
let mut seq = decoder.read_sequence()?;
let engine_id = seq.read_octet_string()?;
let raw_boots = seq.read_integer()?;
if raw_boots < 0 {
tracing::debug!(target: "async_snmp::usm", { offset = seq.offset(), value = raw_boots, kind = %DecodeErrorKind::InvalidEngineBoots { value: raw_boots } }, "decode error");
return Err(Error::MalformedResponse {
target: UNKNOWN_TARGET,
}
.boxed());
}
let engine_boots = raw_boots as u32;
let raw_time = seq.read_integer()?;
if raw_time < 0 {
tracing::debug!(target: "async_snmp::usm", { offset = seq.offset(), value = raw_time, kind = %DecodeErrorKind::InvalidEngineTime { value: raw_time } }, "decode error");
return Err(Error::MalformedResponse {
target: UNKNOWN_TARGET,
}
.boxed());
}
let engine_time = raw_time as u32;
let username = seq.read_octet_string()?;
let auth_params = seq.read_octet_string()?;
let priv_params = seq.read_octet_string()?;
Ok(Self {
engine_id,
engine_boots,
engine_time,
username,
auth_params,
priv_params,
})
}
pub fn find_auth_params_offset(encoded_msg: &[u8]) -> Option<(usize, usize)> {
let mut offset = 0;
if offset >= encoded_msg.len() {
return None;
}
if encoded_msg[offset] != 0x30 {
return None;
}
offset += 1;
let (_, len_size) = parse_ber_length(&encoded_msg[offset..])?;
offset += len_size;
if offset >= encoded_msg.len() {
return None;
}
if encoded_msg[offset] != 0x02 {
return None;
}
offset += 1;
let (ver_len, len_size) = parse_ber_length(&encoded_msg[offset..])?;
offset = offset.checked_add(len_size)?.checked_add(ver_len)?;
if offset >= encoded_msg.len() {
return None;
}
if encoded_msg[offset] != 0x30 {
return None;
}
offset += 1;
let (global_len, len_size) = parse_ber_length(&encoded_msg[offset..])?;
offset = offset.checked_add(len_size)?.checked_add(global_len)?;
if offset >= encoded_msg.len() {
return None;
}
if encoded_msg[offset] != 0x04 {
return None;
}
offset += 1;
let (_, len_size) = parse_ber_length(&encoded_msg[offset..])?;
offset = offset.checked_add(len_size)?;
if offset >= encoded_msg.len() {
return None;
}
if encoded_msg[offset] != 0x30 {
return None;
}
offset += 1;
let (_, len_size) = parse_ber_length(&encoded_msg[offset..])?;
offset = offset.checked_add(len_size)?;
if offset >= encoded_msg.len() {
return None;
}
offset = skip_tlv(encoded_msg, offset)?;
offset = skip_tlv(encoded_msg, offset)?;
offset = skip_tlv(encoded_msg, offset)?;
offset = skip_tlv(encoded_msg, offset)?;
if offset >= encoded_msg.len() {
return None;
}
if encoded_msg[offset] != 0x04 {
return None;
}
offset += 1;
let (auth_len, len_size) = parse_ber_length(&encoded_msg[offset..])?;
let auth_start = offset.checked_add(len_size)?;
if auth_start.checked_add(auth_len)? > encoded_msg.len() {
return None;
}
Some((auth_start, auth_len))
}
}
/// Skips one BER TLV whose (single-byte) tag sits at `offset`.
///
/// Returns the offset just past the element's content, or `None` if the length
/// field is missing or malformed, the content runs past the end of `data`, or
/// the offset arithmetic would overflow. The tag value itself is not checked.
fn skip_tlv(data: &[u8], offset: usize) -> Option<usize> {
    // The length field starts immediately after the one-byte tag.
    let len_pos = offset.checked_add(1)?;
    if len_pos >= data.len() {
        return None;
    }
    let (len, len_size) = parse_ber_length(&data[len_pos..])?;
    // A hostile long-form length can be close to usize::MAX; plain `+` would
    // panic in debug builds and wrap in release builds.
    let end = len_pos.checked_add(len_size)?.checked_add(len)?;
    if end > data.len() {
        return None;
    }
    Some(end)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::{MsgFlags, MsgGlobalData, ScopedPdu, SecurityLevel, V3Message};
    use crate::oid;
    use crate::pdu::Pdu;

    /// Hand-encodes a USM SEQUENCE with empty octet strings and the given raw
    /// INTEGER values, so out-of-range counters can be fed to `decode`.
    fn encode_raw_usm(boots: i32, time: i32) -> Bytes {
        let mut buf = EncodeBuf::new();
        // Pushed last-to-first, mirroring `UsmSecurityParams::encode_to_buf`.
        buf.push_sequence(|buf| {
            buf.push_octet_string(&[]); // priv params
            buf.push_octet_string(&[]); // auth params
            buf.push_octet_string(&[]); // user name
            buf.push_integer(time);
            buf.push_integer(boots);
            buf.push_octet_string(&[]); // engine ID
        });
        buf.finish()
    }

    /// Encodes an AuthNoPriv GetRequest whose USM auth params are a 12-byte
    /// zero placeholder.
    fn sample_v3_message() -> Vec<u8> {
        let global =
            MsgGlobalData::new(12345, 1472, MsgFlags::new(SecurityLevel::AuthNoPriv, true));
        let usm_params =
            UsmSecurityParams::new(b"engine123".as_slice(), 100, 200, b"testuser".as_slice())
                .with_auth_placeholder(12);
        let pdu = Pdu::get_request(42, &[oid!(1, 3, 6, 1, 2, 1, 1, 1, 0)]);
        let scoped = ScopedPdu::with_empty_context(pdu);
        V3Message::new(global, usm_params.encode(), scoped)
            .encode()
            .to_vec()
    }

    #[test]
    fn test_usm_params_empty_roundtrip() {
        let params = UsmSecurityParams::empty();
        let encoded = params.encode();
        let decoded = UsmSecurityParams::decode(encoded).unwrap();
        assert!(decoded.engine_id.is_empty());
        assert_eq!(decoded.engine_boots, 0);
        assert_eq!(decoded.engine_time, 0);
        assert!(decoded.username.is_empty());
        assert!(decoded.auth_params.is_empty());
        assert!(decoded.priv_params.is_empty());
    }

    #[test]
    fn test_usm_params_roundtrip() {
        let params =
            UsmSecurityParams::new(b"engine-id".as_slice(), 1234, 5678, b"admin".as_slice())
                .with_auth_params(b"auth123456789012".as_slice())
                .with_priv_params(b"priv1234".as_slice());
        let encoded = params.encode();
        let decoded = UsmSecurityParams::decode(encoded).unwrap();
        assert_eq!(decoded.engine_id.as_ref(), b"engine-id");
        assert_eq!(decoded.engine_boots, 1234);
        assert_eq!(decoded.engine_time, 5678);
        assert_eq!(decoded.username.as_ref(), b"admin");
        assert_eq!(decoded.auth_params.as_ref(), b"auth123456789012");
        assert_eq!(decoded.priv_params.as_ref(), b"priv1234");
    }

    #[test]
    fn test_usm_params_with_placeholder() {
        let params = UsmSecurityParams::new(b"engine".as_slice(), 100, 200, b"user".as_slice())
            .with_auth_placeholder(12);
        assert_eq!(params.auth_params.len(), 12);
        assert!(params.auth_params.iter().all(|&b| b == 0));
    }

    #[test]
    fn test_find_auth_params_offset() {
        let encoded = sample_v3_message();
        let (offset, len) = UsmSecurityParams::find_auth_params_offset(&encoded).unwrap();
        assert_eq!(len, 12);
        assert!(encoded[offset..offset + len].iter().all(|&b| b == 0));
    }

    #[test]
    fn test_usm_params_rejects_negative_engine_boots() {
        let err = UsmSecurityParams::decode(encode_raw_usm(-1, 100)).unwrap_err();
        assert!(matches!(*err, Error::MalformedResponse { .. }));
    }

    #[test]
    fn test_usm_params_rejects_negative_engine_time() {
        let err = UsmSecurityParams::decode(encode_raw_usm(100, -1)).unwrap_err();
        assert!(matches!(*err, Error::MalformedResponse { .. }));
    }

    #[test]
    fn test_usm_params_accepts_max_values() {
        let decoded = UsmSecurityParams::decode(encode_raw_usm(i32::MAX, i32::MAX)).unwrap();
        assert_eq!(decoded.engine_boots, i32::MAX as u32);
        assert_eq!(decoded.engine_time, i32::MAX as u32);
    }

    #[test]
    fn test_usm_params_accepts_zero_values() {
        let decoded = UsmSecurityParams::decode(encode_raw_usm(0, 0)).unwrap();
        assert_eq!(decoded.engine_boots, 0);
        assert_eq!(decoded.engine_time, 0);
    }

    #[test]
    fn test_find_auth_params_offset_truncated_returns_none() {
        assert_eq!(UsmSecurityParams::find_auth_params_offset(&[]), None);
        assert_eq!(UsmSecurityParams::find_auth_params_offset(&[0x30]), None);
        // SEQUENCE header claiming 100 content bytes, with nothing after it.
        let msg: &[u8] = &[0x30, 0x64];
        assert_eq!(UsmSecurityParams::find_auth_params_offset(msg), None);
    }

    #[test]
    fn test_find_auth_params_offset_inflated_global_len_returns_none() {
        // Version INTEGER followed by a global-data SEQUENCE claiming 127 bytes.
        let msg: &[u8] = &[0x30, 0x06, 0x02, 0x01, 0x03, 0x30, 0x7f];
        assert_eq!(UsmSecurityParams::find_auth_params_offset(msg), None);
    }

    #[test]
    fn test_find_auth_params_offset_auth_len_overflow_returns_none() {
        let global = MsgGlobalData::new(1, 1472, MsgFlags::new(SecurityLevel::AuthNoPriv, true));
        let usm_params = UsmSecurityParams::new(b"eng".as_slice(), 1, 1, b"u".as_slice())
            .with_auth_placeholder(12);
        let pdu = Pdu::get_request(1, &[oid!(1, 3, 6, 1, 2, 1, 1, 1, 0)]);
        let scoped = ScopedPdu::with_empty_context(pdu);
        let mut encoded: Vec<u8> = V3Message::new(global, usm_params.encode(), scoped)
            .encode()
            .to_vec();
        let (auth_start, auth_len) = UsmSecurityParams::find_auth_params_offset(&encoded).unwrap();
        assert_eq!(auth_len, 12);
        // Inflate the auth-params length byte so the field runs past the message.
        encoded[auth_start - 1] = 0x40;
        assert_eq!(UsmSecurityParams::find_auth_params_offset(&encoded), None);
    }

    #[test]
    fn test_find_auth_params_offset_truncated_prefixes_stay_in_bounds() {
        let encoded = sample_v3_message();
        for cut in 0..encoded.len() {
            let prefix = &encoded[..cut];
            if let Some((start, len)) = UsmSecurityParams::find_auth_params_offset(prefix) {
                assert!(start + len <= prefix.len(), "out-of-bounds result at cut {cut}");
            }
        }
    }

    #[test]
    fn usm_security_params_equality() {
        let a = UsmSecurityParams {
            engine_id: Bytes::from_static(b"\x80\x00\x01"),
            engine_boots: 1,
            engine_time: 100,
            username: Bytes::from_static(b"user"),
            auth_params: Bytes::new(),
            priv_params: Bytes::new(),
        };
        let b = a.clone();
        assert_eq!(a, b);
    }
}