use clap::{Parser, ValueEnum};
use std::time::Duration;
use crate::Version;
use crate::client::{Auth, Backoff, Retry};
use crate::format::hex;
use crate::v3::{AuthProtocol, PrivProtocol};
// SNMP protocol version chosen with `-v`/`--snmp-version`.
// Spelled `1`, `2c` or `3` on the command line; `2c` is the default.
// NOTE: plain `//` comments on purpose: clap derive turns `///` into help text.
#[derive(Clone, Copy, Debug, Default, ValueEnum)]
pub enum SnmpVersion {
    #[value(name = "1")]
    V1,
    #[value(name = "2c")]
    #[default]
    V2c,
    #[value(name = "3")]
    V3,
}
impl From<SnmpVersion> for Version {
    /// Maps the CLI version selector one-to-one onto the crate's [`Version`].
    fn from(version: SnmpVersion) -> Self {
        match version {
            SnmpVersion::V1 => Self::V1,
            SnmpVersion::V2c => Self::V2c,
            SnmpVersion::V3 => Self::V3,
        }
    }
}
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum OutputFormat {
#[default]
Human,
Json,
Raw,
}
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum BackoffStrategy {
#[default]
None,
Fixed,
Exponential,
}
// Target, version/community and timeout/retry options.
// NOTE: plain `//` comments on purpose: clap derive turns `///` into --help text.
// A bare `long` lets clap derive the kebab-case field name (e.g. `--snmp-version`).
#[derive(Debug, Parser)]
pub struct CommonArgs {
    // Positional target; its format is not checked here.
    #[arg(value_name = "TARGET")]
    pub target: String,
    #[arg(short = 'v', long, default_value = "2c")]
    pub snmp_version: SnmpVersion,
    // Community string used when no SNMPv3 username is given.
    #[arg(short = 'c', long, default_value = "public")]
    pub community: String,
    // Seconds, fractions allowed; converted by `CommonArgs::timeout_duration`.
    #[arg(short = 't', long, default_value = "5")]
    pub timeout: f64,
    // Forwarded to the retry policy as `max_attempts`.
    #[arg(short = 'r', long, default_value = "3")]
    pub retries: u32,
    #[arg(long, default_value = "none")]
    pub backoff: BackoffStrategy,
    // Milliseconds: the fixed delay, or the initial exponential delay.
    #[arg(long, default_value = "1000")]
    pub backoff_delay: u64,
    // Milliseconds: passed as `max` for exponential backoff.
    #[arg(long, default_value = "5000")]
    pub backoff_max: u64,
    // Exponential-backoff jitter fraction; clamped to [0, 1] in `retry_config`.
    #[arg(long, default_value = "0.25")]
    pub backoff_jitter: f64,
}
impl CommonArgs {
    /// Returns the `--timeout` value (seconds, fractions allowed) as a [`Duration`].
    ///
    /// Unlike a bare [`Duration::from_secs_f64`], this never panics on user
    /// input:
    /// - zero, negative and NaN timeouts map to [`Duration::ZERO`];
    /// - values too large for a `Duration`, including `inf`, saturate to
    ///   [`Duration::MAX`].
    pub fn timeout_duration(&self) -> Duration {
        let secs = self.timeout;
        if secs.is_nan() || secs <= 0.0 {
            return Duration::ZERO;
        }
        // Only overflow can fail here; NaN and negatives were handled above.
        Duration::try_from_secs_f64(secs).unwrap_or(Duration::MAX)
    }

    /// Returns the protocol version to use.
    ///
    /// This is SNMPv3 whenever `v3` carries SNMPv3 credentials
    /// ([`V3Args::is_v3`]); otherwise it is the `-v` selection.
    pub fn effective_version(&self, v3: &V3Args) -> SnmpVersion {
        if v3.is_v3() {
            SnmpVersion::V3
        } else {
            self.snmp_version
        }
    }

    /// Builds the client [`Retry`] policy from `--retries` and the `--backoff*` options.
    ///
    /// `--backoff-delay` and `--backoff-max` are interpreted as milliseconds.
    /// `--backoff-jitter` is clamped to `[0.0, 1.0]`, and a NaN jitter becomes `0.0`.
    pub fn retry_config(&self) -> Retry {
        let backoff = match self.backoff {
            BackoffStrategy::None => Backoff::None,
            BackoffStrategy::Fixed => Backoff::Fixed {
                delay: Duration::from_millis(self.backoff_delay),
            },
            BackoffStrategy::Exponential => Backoff::Exponential {
                initial: Duration::from_millis(self.backoff_delay),
                max: Duration::from_millis(self.backoff_max),
                jitter: Self::sanitized_jitter(self.backoff_jitter),
            },
        };
        Retry {
            max_attempts: self.retries,
            backoff,
        }
    }

    // Clamps a jitter fraction into [0, 1] and maps NaN to 0 (no jitter).
    // `f64::clamp` alone returns NaN unchanged for a NaN input.
    fn sanitized_jitter(jitter: f64) -> f64 {
        if jitter.is_nan() {
            0.0
        } else {
            jitter.clamp(0.0, 1.0)
        }
    }
}
// SNMPv3 (USM) credential options; a username (`-u`) is what enables SNMPv3.
// NOTE: plain `//` comments on purpose: clap derive turns `///` into --help text.
// A bare `long` lets clap derive the kebab-case field name (e.g. `--auth-protocol`).
#[derive(Debug, Parser)]
pub struct V3Args {
    #[arg(short = 'u', long)]
    pub username: Option<String>,
    #[arg(short = 'a', long)]
    pub auth_protocol: Option<AuthProtocol>,
    #[arg(short = 'A', long)]
    pub auth_password: Option<String>,
    #[arg(short = 'x', long)]
    pub priv_protocol: Option<PrivProtocol>,
    #[arg(short = 'X', long)]
    pub priv_password: Option<String>,
}
impl V3Args {
    /// Returns `true` when a username (`-u`) was supplied.
    ///
    /// A username alone decides whether SNMPv3 is used.
    pub fn is_v3(&self) -> bool {
        self.username.is_some()
    }

    /// Builds the client [`Auth`] for this invocation.
    ///
    /// The options are first checked with [`V3Args::validate`].
    /// - With a username, the credential is built via `Auth::usm`, adding the
    ///   authentication and privacy protocols when given.
    /// - Without one, community auth is built from `common.community`:
    ///   SNMPv1 for `-v 1`, SNMPv2c for `-v 2c`.
    ///
    /// # Errors
    ///
    /// Returns a message when [`V3Args::validate`] fails, or when `-v 3` is
    /// selected without a username.
    pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
        self.validate()?;

        let Some(username) = &self.username else {
            let community = &common.community;
            return match common.snmp_version {
                SnmpVersion::V1 => Ok(Auth::v1(community)),
                SnmpVersion::V2c => Ok(Auth::v2c(community)),
                // `-v 3` has no community-based form; refuse instead of
                // silently downgrading the request to SNMPv2c.
                SnmpVersion::V3 => Err("username (-u) required for SNMPv3".to_string()),
            };
        };

        let mut builder = Auth::usm(username);
        if let Some(proto) = self.auth_protocol {
            // Guaranteed by `validate`; kept so this stays safe if those checks change.
            let pass = self
                .auth_password
                .as_ref()
                .ok_or("auth password required")?;
            builder = builder.auth(proto, pass);
        }
        if let Some(proto) = self.priv_protocol {
            let pass = self
                .priv_password
                .as_ref()
                .ok_or("priv password required")?;
            builder = builder.privacy(proto, pass);
        }
        Ok(builder.into())
    }

    /// Checks that the SNMPv3 options are mutually consistent.
    ///
    /// # Errors
    ///
    /// - any of `-a`, `-A`, `-x` or `-X` without a username (`-u`); without one
    ///   `auth` takes the community path and these options would be ignored;
    /// - `-a` without `-A`, or `-x` without `-X`;
    /// - `-x` without `-a` (privacy requires authentication).
    pub fn validate(&self) -> Result<(), String> {
        if self.username.is_none() {
            let has_v3_options = self.auth_protocol.is_some()
                || self.auth_password.is_some()
                || self.priv_protocol.is_some()
                || self.priv_password.is_some();
            return if has_v3_options {
                Err("SNMPv3 options (-a/-A/-x/-X) require a username (-u)".into())
            } else {
                Ok(())
            };
        }
        if self.auth_protocol.is_some() && self.auth_password.is_none() {
            return Err(
                "authentication password (-A) required when using auth protocol".into(),
            );
        }
        if self.priv_protocol.is_some() && self.priv_password.is_none() {
            return Err("privacy password (-X) required when using priv protocol".into());
        }
        if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
            return Err("authentication protocol (-a) required when using privacy".into());
        }
        Ok(())
    }
}
// Output format and diagnostic flags.
// NOTE: plain `//` comments on purpose: clap derive turns `///` into --help text.
// A bare `long` lets clap derive the kebab-case field name (e.g. `--no-hints`).
#[derive(Debug, Parser)]
pub struct OutputArgs {
    // Field is `format`, but the flag is `-O`/`--output`.
    #[arg(short = 'O', long = "output", default_value = "human")]
    pub format: OutputFormat,
    #[arg(long)]
    pub verbose: bool,
    #[arg(long)]
    pub hex: bool,
    // Read by `OutputArgs::elapsed`.
    #[arg(long)]
    pub timing: bool,
    #[arg(long)]
    pub no_hints: bool,
    // `-d`: debug-level tracing (see `OutputArgs::init_tracing`).
    #[arg(short = 'd', long)]
    pub debug: bool,
    // `-D`: trace-level tracing; wins over `--debug`.
    #[arg(short = 'D', long)]
    pub trace: bool,
}
impl OutputArgs {
    /// Returns `Some(elapsed)` when `--timing` was given, otherwise `None`.
    pub fn elapsed(&self, elapsed: Duration) -> Option<Duration> {
        self.timing.then_some(elapsed)
    }

    /// Installs a global `tracing` subscriber that writes to stderr.
    ///
    /// The `async_snmp` target is filtered at:
    /// - `trace` with `--trace` (this wins when both flags are set);
    /// - `debug` with `--debug`;
    /// - `warn` otherwise.
    ///
    /// Installation is best-effort: any error from `try_init` is discarded.
    pub fn init_tracing(&self) {
        use tracing_subscriber::EnvFilter;

        let directive = match (self.trace, self.debug) {
            (true, _) => "async_snmp=trace",
            (false, true) => "async_snmp=debug",
            (false, false) => "async_snmp=warn",
        };
        let subscriber = tracing_subscriber::fmt()
            .with_env_filter(EnvFilter::new(directive))
            .with_writer(std::io::stderr);
        let _ = subscriber.try_init();
    }
}
// Walk tuning options.
// NOTE: plain `//` comments on purpose: clap derive turns `///` into --help text.
#[derive(Debug, Parser)]
pub struct WalkArgs {
    // Presumably forces GETNEXT instead of GETBULK; confirm in the walk command.
    #[arg(long)]
    pub getnext: bool,
    // Flag is `--max-rep`; presumably the GETBULK max-repetitions value.
    #[arg(long = "max-rep", default_value = "10")]
    pub max_repetitions: u32,
}
// One-letter type specifiers for values converted by `ValueType::parse_value`.
// Matching is case-sensitive: `c` is Counter32 and `C` is Counter64.
// NOTE: plain `//` comments on purpose: clap derive turns `///` into help text.
// Variant order is kept as-is because it determines the listed possible values.
#[derive(Clone, Copy, Debug, Eq, PartialEq, ValueEnum)]
pub enum ValueType {
    #[value(name = "i")]
    Integer,
    // Parsed into `Value::Gauge32`.
    #[value(name = "u")]
    Unsigned,
    #[value(name = "s")]
    String,
    #[value(name = "x")]
    HexString,
    #[value(name = "o")]
    Oid,
    #[value(name = "a")]
    IpAddress,
    #[value(name = "t")]
    TimeTicks,
    #[value(name = "c")]
    Counter32,
    #[value(name = "C")]
    Counter64,
}
impl std::str::FromStr for ValueType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"i" => Ok(ValueType::Integer),
"u" => Ok(ValueType::Unsigned),
"s" => Ok(ValueType::String),
"x" => Ok(ValueType::HexString),
"o" => Ok(ValueType::Oid),
"a" => Ok(ValueType::IpAddress),
"t" => Ok(ValueType::TimeTicks),
"c" => Ok(ValueType::Counter32),
"C" => Ok(ValueType::Counter64),
_ => Err(format!("invalid type specifier: {}", s)),
}
}
}
impl ValueType {
    /// Parses `s` into a [`crate::Value`] according to this type specifier.
    ///
    /// How each specifier converts its input:
    /// - `i`: `i32` via `str::parse` → `Value::Integer`
    /// - `u`: `u32` → `Value::Gauge32`
    /// - `t`: `u32` → `Value::TimeTicks`
    /// - `c`: `u32` → `Value::Counter32`
    /// - `C`: `u64` → `Value::Counter64`
    /// - `s`: the raw UTF-8 bytes of `s`
    /// - `x`: decoded with `hex::decode_relaxed`
    /// - `o`: parsed with `Oid::parse`
    /// - `a`: exactly four dot-separated decimal octets
    ///
    /// # Errors
    ///
    /// Returns a descriptive message when `s` is not valid for the type.
    pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
        use crate::{Oid, Value};

        // Shared numeric path; `kind` only feeds the error message.
        fn number<T: std::str::FromStr>(s: &str, kind: &str) -> Result<T, String> {
            s.parse().map_err(|_| format!("invalid {} value: {}", kind, s))
        }

        // Dotted-quad IPv4: exactly four parts, each parsed as a `u8` in order.
        fn ipv4(s: &str) -> Result<[u8; 4], String> {
            let pieces: Vec<&str> = s.split('.').collect();
            if pieces.len() != 4 {
                return Err(format!("invalid IP address: {}", s));
            }
            let mut octets = [0u8; 4];
            for (slot, piece) in octets.iter_mut().zip(&pieces) {
                *slot = piece
                    .parse()
                    .map_err(|_| format!("invalid IP address octet: {}", piece))?;
            }
            Ok(octets)
        }

        let value = match self {
            ValueType::Integer => Value::Integer(number::<i32>(s, "integer")?),
            ValueType::Unsigned => Value::Gauge32(number::<u32>(s, "unsigned")?),
            ValueType::String => Value::OctetString(s.as_bytes().to_vec().into()),
            ValueType::HexString => {
                let decoded = hex::decode_relaxed(s)
                    .map_err(|_| "hex string must have even number of hex digits".to_string())?;
                Value::OctetString(decoded.into())
            }
            ValueType::Oid => Value::ObjectIdentifier(
                Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?,
            ),
            ValueType::IpAddress => Value::IpAddress(ipv4(s)?),
            ValueType::TimeTicks => Value::TimeTicks(number::<u32>(s, "timeticks")?),
            ValueType::Counter32 => Value::Counter32(number::<u32>(s, "counter32")?),
            ValueType::Counter64 => Value::Counter64(number::<u64>(s, "counter64")?),
        };
        Ok(value)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    // CommonArgs for the retry tests; fields `retry_config` ignores are pinned.
    fn common_args(
        retries: u32,
        backoff: BackoffStrategy,
        backoff_delay: u64,
        backoff_max: u64,
        backoff_jitter: f64,
    ) -> CommonArgs {
        CommonArgs {
            target: "192.168.1.1".to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries,
            backoff,
            backoff_delay,
            backoff_max,
            backoff_jitter,
        }
    }

    // V3Args from borrowed strings, keeping the validation cases compact.
    fn v3_args(
        username: Option<&str>,
        auth_protocol: Option<AuthProtocol>,
        auth_password: Option<&str>,
        priv_protocol: Option<PrivProtocol>,
        priv_password: Option<&str>,
    ) -> V3Args {
        V3Args {
            username: username.map(String::from),
            auth_protocol,
            auth_password: auth_password.map(String::from),
            priv_protocol,
            priv_password: priv_password.map(String::from),
        }
    }

    // Unwraps an OctetString into its bytes, failing the test for any other variant.
    fn octets(value: Value) -> Vec<u8> {
        match value {
            Value::OctetString(bytes) => {
                let slice: &[u8] = &bytes;
                slice.to_vec()
            }
            _ => panic!("expected OctetString"),
        }
    }

    #[test]
    fn test_retry_config_none() {
        let retry = common_args(3, BackoffStrategy::None, 100, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let retry = common_args(5, BackoffStrategy::Fixed, 200, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 5);
        assert!(matches!(
            retry.backoff,
            Backoff::Fixed { delay } if delay == Duration::from_millis(200)
        ));
    }

    #[test]
    fn test_retry_config_exponential() {
        let retry = common_args(4, BackoffStrategy::Exponential, 50, 2000, 0.1).retry_config();
        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential {
            initial,
            max,
            jitter,
        } = retry.backoff
        else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        let no_v3 = v3_args(None, None, None, None, None);
        assert!(no_v3.validate().is_ok());

        let user_only = v3_args(Some("admin"), None, None, None, None);
        assert!(user_only.validate().is_ok());

        let auth_without_password =
            v3_args(Some("admin"), Some(AuthProtocol::Sha256), None, None, None);
        assert!(auth_without_password.validate().is_err());

        let priv_without_auth = v3_args(
            Some("admin"),
            None,
            None,
            Some(PrivProtocol::Aes128),
            Some("pass"),
        );
        assert!(priv_without_auth.validate().is_err());

        let auth_priv = v3_args(
            Some("admin"),
            Some(AuthProtocol::Sha1),
            Some("pass"),
            Some(PrivProtocol::Aes256),
            Some("pass"),
        );
        assert!(auth_priv.validate().is_ok());
    }

    #[test]
    fn test_value_type_parse_integer() {
        let positive = ValueType::Integer.parse_value("42").unwrap();
        assert!(matches!(positive, Value::Integer(42)));
        let negative = ValueType::Integer.parse_value("-100").unwrap();
        assert!(matches!(negative, Value::Integer(-100)));
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        let gauge = ValueType::Unsigned.parse_value("42").unwrap();
        assert!(matches!(gauge, Value::Gauge32(42)));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let text = ValueType::String.parse_value("hello world").unwrap();
        assert_eq!(octets(text), b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        let compact = ValueType::HexString.parse_value("001a2b").unwrap();
        assert_eq!(octets(compact), [0x00, 0x1a, 0x2b]);
        let spaced = ValueType::HexString.parse_value("00 1A 2B").unwrap();
        assert_eq!(octets(spaced), [0x00, 0x1a, 0x2b]);
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        let addr = ValueType::IpAddress.parse_value("192.168.1.1").unwrap();
        assert!(matches!(addr, Value::IpAddress([192, 168, 1, 1])));
        assert!(ValueType::IpAddress.parse_value("192.168.1").is_err());
        assert!(ValueType::IpAddress.parse_value("256.1.1.1").is_err());
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        let ticks = ValueType::TimeTicks.parse_value("12345678").unwrap();
        assert!(matches!(ticks, Value::TimeTicks(12345678)));
    }

    #[test]
    fn test_value_type_parse_counters() {
        let c32 = ValueType::Counter32.parse_value("4294967295").unwrap();
        assert!(matches!(c32, Value::Counter32(4294967295)));
        let c64 = ValueType::Counter64
            .parse_value("18446744073709551615")
            .unwrap();
        assert!(matches!(c64, Value::Counter64(18446744073709551615)));
    }
}