use bytes::Bytes;
use crate::ber::{Decoder, EncodeBuf};
use crate::error::internal::DecodeErrorKind;
use crate::error::{Error, Result, UNKNOWN_TARGET};
use crate::pdu::Pdu;
/// Security model identifier carried in the `msgSecurityModel` header field.
///
/// Only the user-based security model (wire value 3) is known to this crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum SecurityModel {
    /// User-based Security Model (USM), wire value 3.
    Usm = 3,
}

impl SecurityModel {
    /// Maps a raw `msgSecurityModel` value to a known model.
    ///
    /// Returns `None` for every value other than 3.
    pub fn from_i32(value: i32) -> Option<Self> {
        if value == Self::Usm as i32 {
            Some(Self::Usm)
        } else {
            None
        }
    }

    /// Returns the raw integer written to the wire for this model.
    pub fn as_i32(self) -> i32 {
        match self {
            Self::Usm => 3,
        }
    }
}
/// Security level of an SNMPv3 message, carried in the low two bits of `msgFlags`.
///
/// Variants are declared from weakest to strongest, so the derived ordering makes
/// comparisons such as `level >= SecurityLevel::AuthNoPriv` meaningful.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum SecurityLevel {
    /// Neither authentication nor privacy (flags bits `0b00`).
    NoAuthNoPriv,
    /// Authentication without privacy (flags bits `0b01`).
    AuthNoPriv,
    /// Authentication and privacy (flags bits `0b11`).
    AuthPriv,
}

impl SecurityLevel {
    // Authentication bit within the msgFlags byte.
    const AUTH_BIT: u8 = 0x01;
    // Privacy bit within the msgFlags byte.
    const PRIV_BIT: u8 = 0x02;

    /// Extracts the security level from a `msgFlags` byte.
    ///
    /// Only the auth (0x01) and priv (0x02) bits are inspected; all other bits are
    /// ignored. Returns `None` when the priv bit is set without the auth bit, since
    /// that combination has no corresponding level.
    pub fn from_flags(flags: u8) -> Option<Self> {
        let has_auth = flags & Self::AUTH_BIT == Self::AUTH_BIT;
        let has_priv = flags & Self::PRIV_BIT == Self::PRIV_BIT;

        if !has_auth {
            // Privacy without authentication is not a valid level.
            return if has_priv { None } else { Some(Self::NoAuthNoPriv) };
        }
        Some(if has_priv { Self::AuthPriv } else { Self::AuthNoPriv })
    }

    /// Encodes this level as the auth/priv bits of a `msgFlags` byte.
    ///
    /// The returned value never has any bit other than 0x01 or 0x02 set.
    pub fn to_flags(self) -> u8 {
        let auth = if self.requires_auth() { Self::AUTH_BIT } else { 0 };
        let privacy = if self.requires_priv() { Self::PRIV_BIT } else { 0 };
        auth | privacy
    }

    /// Returns `true` for [`SecurityLevel::AuthNoPriv`] and [`SecurityLevel::AuthPriv`].
    pub fn requires_auth(self) -> bool {
        self >= Self::AuthNoPriv
    }

    /// Returns `true` only for [`SecurityLevel::AuthPriv`].
    pub fn requires_priv(self) -> bool {
        self == Self::AuthPriv
    }
}

impl TryFrom<u8> for SecurityLevel {
    /// The rejected flags byte, returned unchanged.
    type Error = u8;

    /// Same as [`SecurityLevel::from_flags`], but yields the input byte on failure.
    fn try_from(flags: u8) -> std::result::Result<Self, u8> {
        match Self::from_flags(flags) {
            Some(level) => Ok(level),
            None => Err(flags),
        }
    }
}

impl From<SecurityLevel> for u8 {
    /// Same as [`SecurityLevel::to_flags`].
    fn from(level: SecurityLevel) -> u8 {
        SecurityLevel::to_flags(level)
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MsgFlags {
pub security_level: SecurityLevel,
pub reportable: bool,
}
impl MsgFlags {
pub fn new(security_level: SecurityLevel, reportable: bool) -> Self {
Self {
security_level,
reportable,
}
}
pub fn from_byte(byte: u8) -> Result<Self> {
let security_level = SecurityLevel::from_flags(byte).ok_or_else(|| {
tracing::debug!(target: "async_snmp::v3", { byte, kind = %DecodeErrorKind::InvalidMsgFlags }, "decode error");
Error::MalformedResponse {
target: UNKNOWN_TARGET,
}
.boxed()
})?;
let reportable = byte & 0x04 != 0;
Ok(Self {
security_level,
reportable,
})
}
pub fn to_byte(self) -> u8 {
let mut flags = self.security_level.to_flags();
if self.reportable {
flags |= 0x04;
}
flags
}
}
/// Header fields (`msgGlobalData`) present in every SNMPv3 message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MsgGlobalData {
    /// Message identifier; `decode` rejects negative values.
    pub msg_id: i32,
    /// Largest message size the sender accepts; `decode` rejects values below 484.
    pub msg_max_size: i32,
    /// Security level and reportable flag.
    pub msg_flags: MsgFlags,
    /// Security model used to process the message.
    pub msg_security_model: SecurityModel,
}

impl MsgGlobalData {
    /// Creates header data with the security model set to [`SecurityModel::Usm`].
    pub fn new(msg_id: i32, msg_max_size: i32, msg_flags: MsgFlags) -> Self {
        Self {
            msg_id,
            msg_max_size,
            msg_flags,
            msg_security_model: SecurityModel::Usm,
        }
    }

    /// Appends this header as a SEQUENCE of msgID, msgMaxSize, msgFlags and
    /// msgSecurityModel to `buf`.
    pub fn encode(&self, buf: &mut EncodeBuf) {
        buf.push_sequence(|buf| {
            // Fields are pushed last-to-first; `decode` reads them msgID-first.
            buf.push_integer(self.msg_security_model.as_i32());
            buf.push_octet_string(&[self.msg_flags.to_byte()]);
            buf.push_integer(self.msg_max_size);
            buf.push_integer(self.msg_id);
        });
    }

    /// Decodes a `msgGlobalData` SEQUENCE from `decoder`.
    ///
    /// Each field is validated immediately after it is read, so decoding stops at
    /// the first invalid field.
    ///
    /// # Errors
    ///
    /// Propagates decoder errors for malformed encodings. Returns
    /// `Error::MalformedResponse` when msgID is negative, msgMaxSize is negative or
    /// below 484, msgFlags is not exactly one byte or encodes an invalid security
    /// level, or msgSecurityModel is not a known model. Each rejection is also
    /// emitted as a `tracing` debug event.
    pub fn decode(decoder: &mut Decoder) -> Result<Self> {
        const MSG_MAX_SIZE_MINIMUM: i32 = 484;
        let mut seq = decoder.read_sequence()?;

        // Validate msgID before reading msgMaxSize so the logged offset refers to
        // the msgID field, matching how every other check below logs the offset
        // right after its own field.
        let msg_id = seq.read_integer()?;
        if msg_id < 0 {
            tracing::debug!(target: "async_snmp::v3", { offset = seq.offset(), value = msg_id, kind = %DecodeErrorKind::InvalidMsgId { value: msg_id } }, "decode error");
            return Err(Error::MalformedResponse {
                target: UNKNOWN_TARGET,
            }
            .boxed());
        }

        let msg_max_size = seq.read_integer()?;
        if msg_max_size < 0 {
            tracing::debug!(target: "async_snmp::v3", { offset = seq.offset(), value = msg_max_size, kind = %DecodeErrorKind::MsgMaxSizeTooLarge { value: msg_max_size } }, "decode error");
            return Err(Error::MalformedResponse {
                target: UNKNOWN_TARGET,
            }
            .boxed());
        }
        if msg_max_size < MSG_MAX_SIZE_MINIMUM {
            tracing::debug!(target: "async_snmp::v3", { offset = seq.offset(), value = msg_max_size, minimum = MSG_MAX_SIZE_MINIMUM, kind = %DecodeErrorKind::MsgMaxSizeTooSmall { value: msg_max_size, minimum: MSG_MAX_SIZE_MINIMUM } }, "decode error");
            return Err(Error::MalformedResponse {
                target: UNKNOWN_TARGET,
            }
            .boxed());
        }

        // msgFlags is an OCTET STRING that must hold exactly one byte.
        let flags_bytes = seq.read_octet_string()?;
        if flags_bytes.len() != 1 {
            tracing::debug!(target: "async_snmp::v3", { offset = seq.offset(), expected = 1, actual = flags_bytes.len() }, "invalid msgFlags length");
            return Err(Error::MalformedResponse {
                target: UNKNOWN_TARGET,
            }
            .boxed());
        }
        let msg_flags = MsgFlags::from_byte(flags_bytes[0])?;

        let msg_security_model_raw = seq.read_integer()?;
        let msg_security_model =
            SecurityModel::from_i32(msg_security_model_raw).ok_or_else(|| {
                tracing::debug!(target: "async_snmp::v3", { offset = seq.offset(), model = msg_security_model_raw, kind = %DecodeErrorKind::UnknownSecurityModel(msg_security_model_raw) }, "decode error");
                Error::MalformedResponse {
                    target: UNKNOWN_TARGET,
                }
                .boxed()
            })?;

        Ok(Self {
            msg_id,
            msg_max_size,
            msg_flags,
            msg_security_model,
        })
    }
}
/// A PDU paired with the SNMPv3 context (engine ID and context name) it targets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScopedPdu {
    /// Raw `contextEngineID` octets; empty when no context is set.
    pub context_engine_id: Bytes,
    /// Raw `contextName` octets; empty when no context is set.
    pub context_name: Bytes,
    /// The enclosed PDU.
    pub pdu: Pdu,
}

impl ScopedPdu {
    /// Wraps `pdu` with the given context engine ID and context name.
    pub fn new(
        context_engine_id: impl Into<Bytes>,
        context_name: impl Into<Bytes>,
        pdu: Pdu,
    ) -> Self {
        let context_engine_id: Bytes = context_engine_id.into();
        let context_name: Bytes = context_name.into();
        Self { context_engine_id, context_name, pdu }
    }

    /// Wraps `pdu` with an empty context engine ID and an empty context name.
    pub fn with_empty_context(pdu: Pdu) -> Self {
        Self::new(Bytes::new(), Bytes::new(), pdu)
    }

    /// Appends this scoped PDU as a SEQUENCE of contextEngineID, contextName and
    /// the PDU to `out`.
    pub fn encode(&self, out: &mut EncodeBuf) {
        out.push_sequence(|inner| {
            // Pushed last-to-first: `decode` reads contextEngineID first.
            self.pdu.encode(inner);
            inner.push_octet_string(&self.context_name);
            inner.push_octet_string(&self.context_engine_id);
        });
    }

    /// Encodes this scoped PDU into a fresh buffer and returns the finished bytes.
    pub fn encode_to_bytes(&self) -> Bytes {
        let mut out = EncodeBuf::new();
        self.encode(&mut out);
        out.finish()
    }

    /// Decodes a scoped PDU SEQUENCE from `decoder`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the decoder or from [`Pdu::decode`].
    pub fn decode(decoder: &mut Decoder) -> Result<Self> {
        let mut body = decoder.read_sequence()?;
        let context_engine_id = body.read_octet_string()?;
        let context_name = body.read_octet_string()?;
        Ok(Self {
            context_engine_id,
            context_name,
            pdu: Pdu::decode(&mut body)?,
        })
    }
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct V3Message {
pub global_data: MsgGlobalData,
pub security_params: Bytes,
pub data: V3MessageData,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum V3MessageData {
Plaintext(ScopedPdu),
Encrypted(Bytes),
}
impl V3Message {
pub fn new(global_data: MsgGlobalData, security_params: Bytes, scoped_pdu: ScopedPdu) -> Self {
Self {
global_data,
security_params,
data: V3MessageData::Plaintext(scoped_pdu),
}
}
pub fn new_encrypted(
global_data: MsgGlobalData,
security_params: Bytes,
encrypted: Bytes,
) -> Self {
Self {
global_data,
security_params,
data: V3MessageData::Encrypted(encrypted),
}
}
pub fn scoped_pdu(&self) -> Option<&ScopedPdu> {
match &self.data {
V3MessageData::Plaintext(pdu) => Some(pdu),
V3MessageData::Encrypted(_) => None,
}
}
pub fn into_scoped_pdu(self) -> Option<ScopedPdu> {
match self.data {
V3MessageData::Plaintext(pdu) => Some(pdu),
V3MessageData::Encrypted(_) => None,
}
}
pub fn pdu(&self) -> Option<&Pdu> {
self.scoped_pdu().map(|s| &s.pdu)
}
pub fn into_pdu(self) -> Option<Pdu> {
self.into_scoped_pdu().map(|s| s.pdu)
}
pub fn msg_id(&self) -> i32 {
self.global_data.msg_id
}
pub fn security_level(&self) -> SecurityLevel {
self.global_data.msg_flags.security_level
}
pub fn encode(&self) -> Bytes {
let mut buf = EncodeBuf::new();
buf.push_sequence(|buf| {
match &self.data {
V3MessageData::Plaintext(scoped_pdu) => {
scoped_pdu.encode(buf);
}
V3MessageData::Encrypted(ciphertext) => {
buf.push_octet_string(ciphertext);
}
}
buf.push_octet_string(&self.security_params);
self.global_data.encode(buf);
buf.push_integer(3);
});
buf.finish()
}
pub fn decode(data: Bytes) -> Result<Self> {
let mut decoder = Decoder::new(data);
let mut seq = decoder.read_sequence()?;
let version = seq.read_integer()?;
if version != 3 {
tracing::debug!(target: "async_snmp::v3", { offset = seq.offset(), version, kind = %DecodeErrorKind::UnknownVersion(version) }, "decode error");
return Err(Error::MalformedResponse {
target: UNKNOWN_TARGET,
}
.boxed());
}
Self::decode_from_sequence(&mut seq)
}
pub(crate) fn decode_from_sequence(seq: &mut Decoder) -> Result<Self> {
let global_data = MsgGlobalData::decode(seq)?;
let security_params = seq.read_octet_string()?;
let data = if global_data.msg_flags.security_level.requires_priv() {
let encrypted = seq.read_octet_string()?;
V3MessageData::Encrypted(encrypted)
} else {
let scoped_pdu = ScopedPdu::decode(seq)?;
V3MessageData::Plaintext(scoped_pdu)
};
Ok(Self {
global_data,
security_params,
data,
})
}
pub fn discovery_request(msg_id: i32) -> Self {
let global_data = MsgGlobalData::new(
msg_id,
65507, MsgFlags::new(SecurityLevel::NoAuthNoPriv, true),
);
let security_params = crate::v3::UsmSecurityParams::empty().encode();
let pdu = Pdu::get_request(0, &[]);
let scoped_pdu = ScopedPdu::with_empty_context(pdu);
Self::new(global_data, security_params, scoped_pdu)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oid;

    /// Encodes a raw msgGlobalData SEQUENCE from arbitrary (possibly invalid) field
    /// values. `EncodeBuf` is filled last-to-first, so fields are pushed in reverse
    /// wire order.
    fn encode_raw_global_data(msg_id: i32, msg_max_size: i32, flags: &[u8], model: i32) -> Bytes {
        let mut buf = EncodeBuf::new();
        buf.push_sequence(|buf| {
            buf.push_integer(model);
            buf.push_octet_string(flags);
            buf.push_integer(msg_max_size);
            buf.push_integer(msg_id);
        });
        buf.finish()
    }

    /// Decodes a msgGlobalData SEQUENCE from `encoded`.
    fn decode_global(encoded: Bytes) -> Result<MsgGlobalData> {
        MsgGlobalData::decode(&mut Decoder::new(encoded))
    }

    /// Asserts that `result` failed with `Error::MalformedResponse`.
    fn assert_malformed<T: std::fmt::Debug>(result: Result<T>) {
        assert!(matches!(
            *result.unwrap_err(),
            Error::MalformedResponse { .. }
        ));
    }

    #[test]
    fn test_security_model_from_i32() {
        assert_eq!(SecurityModel::from_i32(3), Some(SecurityModel::Usm));
        assert_eq!(SecurityModel::from_i32(0), None);
        assert_eq!(SecurityModel::from_i32(99), None);
        assert_eq!(SecurityModel::Usm.as_i32(), 3);
    }

    #[test]
    fn test_security_level_flags() {
        assert_eq!(SecurityLevel::NoAuthNoPriv.to_flags(), 0x00);
        assert_eq!(SecurityLevel::AuthNoPriv.to_flags(), 0x01);
        assert_eq!(SecurityLevel::AuthPriv.to_flags(), 0x03);
        assert_eq!(
            SecurityLevel::from_flags(0x00),
            Some(SecurityLevel::NoAuthNoPriv)
        );
        assert_eq!(
            SecurityLevel::from_flags(0x01),
            Some(SecurityLevel::AuthNoPriv)
        );
        assert_eq!(
            SecurityLevel::from_flags(0x03),
            Some(SecurityLevel::AuthPriv)
        );
        // Privacy without authentication is not a valid level.
        assert_eq!(SecurityLevel::from_flags(0x02), None);
    }

    #[test]
    fn test_security_level_ignores_reportable_bit() {
        assert_eq!(
            SecurityLevel::from_flags(0x04),
            Some(SecurityLevel::NoAuthNoPriv)
        );
        assert_eq!(
            SecurityLevel::from_flags(0x07),
            Some(SecurityLevel::AuthPriv)
        );
        assert_eq!(SecurityLevel::from_flags(0x06), None);
    }

    #[test]
    fn security_level_try_from_u8() {
        assert_eq!(
            SecurityLevel::try_from(0x00),
            Ok(SecurityLevel::NoAuthNoPriv)
        );
        assert_eq!(SecurityLevel::try_from(0x01), Ok(SecurityLevel::AuthNoPriv));
        assert_eq!(SecurityLevel::try_from(0x03), Ok(SecurityLevel::AuthPriv));
        assert_eq!(SecurityLevel::try_from(0x02), Err(0x02));
    }

    #[test]
    fn security_level_into_u8() {
        assert_eq!(u8::from(SecurityLevel::NoAuthNoPriv), 0x00);
        assert_eq!(u8::from(SecurityLevel::AuthNoPriv), 0x01);
        assert_eq!(u8::from(SecurityLevel::AuthPriv), 0x03);
    }

    #[test]
    fn test_msg_flags_roundtrip() {
        let flags = MsgFlags::new(SecurityLevel::AuthPriv, true);
        let byte = flags.to_byte();
        assert_eq!(byte, 0x07);
        let decoded = MsgFlags::from_byte(byte).unwrap();
        assert_eq!(decoded.security_level, SecurityLevel::AuthPriv);
        assert!(decoded.reportable);
    }

    #[test]
    fn test_msg_flags_not_reportable() {
        let flags = MsgFlags::from_byte(0x01).unwrap();
        assert_eq!(flags.security_level, SecurityLevel::AuthNoPriv);
        assert!(!flags.reportable);
        assert_eq!(flags.to_byte(), 0x01);
    }

    #[test]
    fn test_msg_flags_rejects_priv_without_auth() {
        assert_malformed(MsgFlags::from_byte(0x02));
        assert_malformed(MsgFlags::from_byte(0x06));
    }

    #[test]
    fn test_msg_global_data_roundtrip() {
        let global =
            MsgGlobalData::new(12345, 1472, MsgFlags::new(SecurityLevel::AuthNoPriv, true));
        let mut buf = EncodeBuf::new();
        global.encode(&mut buf);
        let decoded = decode_global(buf.finish()).unwrap();
        assert_eq!(decoded.msg_id, 12345);
        assert_eq!(decoded.msg_max_size, 1472);
        assert_eq!(decoded.msg_flags.security_level, SecurityLevel::AuthNoPriv);
        assert!(decoded.msg_flags.reportable);
        assert_eq!(decoded.msg_security_model, SecurityModel::Usm);
    }

    #[test]
    fn test_scoped_pdu_roundtrip() {
        let pdu = Pdu::get_request(42, &[oid!(1, 3, 6, 1, 2, 1, 1, 1, 0)]);
        let scoped = ScopedPdu::new(b"engine".as_slice(), b"ctx".as_slice(), pdu);
        let mut buf = EncodeBuf::new();
        scoped.encode(&mut buf);
        let encoded = buf.finish();
        let mut decoder = Decoder::new(encoded);
        let decoded = ScopedPdu::decode(&mut decoder).unwrap();
        assert_eq!(decoded.context_engine_id.as_ref(), b"engine");
        assert_eq!(decoded.context_name.as_ref(), b"ctx");
        assert_eq!(decoded.pdu.request_id, 42);
    }

    #[test]
    fn test_v3_message_plaintext_roundtrip() {
        let global =
            MsgGlobalData::new(100, 1472, MsgFlags::new(SecurityLevel::NoAuthNoPriv, true));
        let pdu = Pdu::get_request(42, &[oid!(1, 3, 6, 1, 2, 1, 1, 1, 0)]);
        let scoped = ScopedPdu::with_empty_context(pdu);
        let msg = V3Message::new(global, Bytes::from_static(b"usm-params"), scoped);
        let encoded = msg.encode();
        let decoded = V3Message::decode(encoded).unwrap();
        assert_eq!(decoded.global_data.msg_id, 100);
        assert_eq!(decoded.security_level(), SecurityLevel::NoAuthNoPriv);
        assert_eq!(decoded.security_params.as_ref(), b"usm-params");
        let scoped_pdu = decoded.scoped_pdu().unwrap();
        assert_eq!(scoped_pdu.pdu.request_id, 42);
    }

    #[test]
    fn test_v3_message_encrypted_roundtrip() {
        let global = MsgGlobalData::new(200, 1472, MsgFlags::new(SecurityLevel::AuthPriv, false));
        let msg = V3Message::new_encrypted(
            global,
            Bytes::from_static(b"usm-params"),
            Bytes::from_static(b"encrypted-data"),
        );
        let encoded = msg.encode();
        let decoded = V3Message::decode(encoded).unwrap();
        assert_eq!(decoded.global_data.msg_id, 200);
        assert_eq!(decoded.security_level(), SecurityLevel::AuthPriv);
        match &decoded.data {
            V3MessageData::Encrypted(data) => {
                assert_eq!(data.as_ref(), b"encrypted-data");
            }
            V3MessageData::Plaintext(_) => panic!("expected encrypted data"),
        }
    }

    #[test]
    fn test_v3_message_rejects_non_v3_version() {
        let mut buf = EncodeBuf::new();
        buf.push_sequence(|buf| {
            buf.push_integer(1);
        });
        assert_malformed(V3Message::decode(buf.finish()));
    }

    #[test]
    fn test_discovery_request_roundtrip() {
        let msg = V3Message::discovery_request(7);
        let decoded = V3Message::decode(msg.encode()).unwrap();
        assert_eq!(decoded.msg_id(), 7);
        assert_eq!(decoded.security_level(), SecurityLevel::NoAuthNoPriv);
        assert!(decoded.global_data.msg_flags.reportable);
        assert!(decoded.pdu().is_some());
    }

    #[test]
    fn test_msg_global_data_rejects_msg_max_size_below_minimum() {
        let global = MsgGlobalData {
            msg_id: 100,
            msg_max_size: 400,
            msg_flags: MsgFlags::new(SecurityLevel::NoAuthNoPriv, true),
            msg_security_model: SecurityModel::Usm,
        };
        let mut buf = EncodeBuf::new();
        global.encode(&mut buf);
        assert_malformed(decode_global(buf.finish()));
    }

    #[test]
    fn test_msg_global_data_accepts_msg_max_size_at_minimum() {
        let global = MsgGlobalData::new(100, 484, MsgFlags::new(SecurityLevel::NoAuthNoPriv, true));
        let mut buf = EncodeBuf::new();
        global.encode(&mut buf);
        let decoded = decode_global(buf.finish()).unwrap();
        assert_eq!(decoded.msg_max_size, 484);
    }

    #[test]
    fn test_msg_global_data_rejects_unknown_security_model() {
        assert_malformed(decode_global(encode_raw_global_data(100, 1472, &[0x04], 99)));
    }

    #[test]
    fn test_msg_global_data_accepts_usm_security_model() {
        let global =
            MsgGlobalData::new(100, 1472, MsgFlags::new(SecurityLevel::NoAuthNoPriv, true));
        let mut buf = EncodeBuf::new();
        global.encode(&mut buf);
        let decoded = decode_global(buf.finish()).unwrap();
        assert_eq!(decoded.msg_security_model, SecurityModel::Usm);
    }

    #[test]
    fn test_msg_global_data_rejects_negative_msg_id() {
        assert_malformed(decode_global(encode_raw_global_data(-1, 1472, &[0x04], 3)));
    }

    #[test]
    fn test_msg_global_data_rejects_negative_msg_max_size() {
        assert_malformed(decode_global(encode_raw_global_data(100, -1, &[0x04], 3)));
    }

    #[test]
    fn test_msg_global_data_rejects_wrong_msg_flags_length() {
        assert_malformed(decode_global(encode_raw_global_data(100, 1472, &[0x04, 0x00], 3)));
        assert_malformed(decode_global(encode_raw_global_data(100, 1472, &[], 3)));
    }

    #[test]
    fn test_msg_global_data_rejects_priv_without_auth_flags() {
        assert_malformed(decode_global(encode_raw_global_data(100, 1472, &[0x02], 3)));
    }

    #[test]
    fn test_msg_global_data_accepts_msg_id_at_zero() {
        let decoded = decode_global(encode_raw_global_data(0, 1472, &[0x04], 3)).unwrap();
        assert_eq!(decoded.msg_id, 0);
    }

    #[test]
    fn test_msg_global_data_accepts_msg_id_at_maximum() {
        let decoded = decode_global(encode_raw_global_data(i32::MAX, 1472, &[0x04], 3)).unwrap();
        assert_eq!(decoded.msg_id, i32::MAX);
    }

    #[test]
    fn test_msg_global_data_accepts_msg_max_size_at_maximum() {
        let decoded = decode_global(encode_raw_global_data(100, i32::MAX, &[0x04], 3)).unwrap();
        assert_eq!(decoded.msg_max_size, i32::MAX);
    }
}