1use clap::{Parser, ValueEnum};
6use std::net::SocketAddr;
7use std::time::Duration;
8
9use crate::Version;
10use crate::client::{Auth, Backoff, Retry};
11use crate::v3::{AuthProtocol, PrivProtocol};
12
/// SNMP protocol version selected with `-v` / `--snmp-version`.
///
/// Converted into the library's [`Version`] via the `From` impl below.
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum SnmpVersion {
    /// SNMPv1.
    #[value(name = "1")]
    V1,
    /// SNMPv2c (the default).
    #[default]
    #[value(name = "2c")]
    V2c,
    /// SNMPv3.
    #[value(name = "3")]
    V3,
}
27
28impl From<SnmpVersion> for Version {
29 fn from(v: SnmpVersion) -> Self {
30 match v {
31 SnmpVersion::V1 => Version::V1,
32 SnmpVersion::V2c => Version::V2c,
33 SnmpVersion::V3 => Version::V3,
34 }
35 }
36}
37
/// Output format selected with `-O` / `--output` (`human`, `json` or `raw`).
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum OutputFormat {
    /// Human-readable output (the default).
    #[default]
    Human,
    /// JSON output.
    Json,
    /// Raw output.
    Raw,
}
49
/// Retry backoff strategy selected with `--backoff` (`none`, `fixed` or `exponential`).
///
/// Turned into a client `Backoff` by `CommonArgs::retry_config`.
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum BackoffStrategy {
    /// No backoff (maps to `Backoff::None`); the default.
    #[default]
    None,
    /// Constant delay of `--backoff-delay` milliseconds (maps to `Backoff::Fixed`).
    Fixed,
    /// Starts at `--backoff-delay`, capped at `--backoff-max`, randomized by
    /// `--backoff-jitter` (maps to `Backoff::Exponential`).
    Exponential,
}
61
62#[derive(Debug, Parser)]
64pub struct CommonArgs {
65 #[arg(value_name = "TARGET")]
67 pub target: String,
68
69 #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
71 pub snmp_version: SnmpVersion,
72
73 #[arg(short = 'c', long = "community", default_value = "public")]
75 pub community: String,
76
77 #[arg(short = 't', long = "timeout", default_value = "5")]
79 pub timeout: f64,
80
81 #[arg(short = 'r', long = "retries", default_value = "3")]
83 pub retries: u32,
84
85 #[arg(long = "backoff", default_value = "none")]
87 pub backoff: BackoffStrategy,
88
89 #[arg(long = "backoff-delay", default_value = "1000")]
91 pub backoff_delay: u64,
92
93 #[arg(long = "backoff-max", default_value = "5000")]
95 pub backoff_max: u64,
96
97 #[arg(long = "backoff-jitter", default_value = "0.25")]
99 pub backoff_jitter: f64,
100}
101
102impl CommonArgs {
103 pub fn target_addr(&self) -> Result<SocketAddr, String> {
105 let addr_str = if self.target.contains(':') {
107 self.target.clone()
108 } else {
109 format!("{}:161", self.target)
110 };
111
112 addr_str
113 .parse()
114 .or_else(|_| {
115 use std::net::ToSocketAddrs;
117 addr_str
118 .to_socket_addrs()
119 .map_err(|e| e.to_string())?
120 .next()
121 .ok_or_else(|| format!("could not resolve hostname: {}", self.target))
122 })
123 .map_err(|e| format!("invalid target '{}': {}", self.target, e))
124 }
125
126 pub fn timeout_duration(&self) -> Duration {
128 Duration::from_secs_f64(self.timeout)
129 }
130
131 pub fn retry_config(&self) -> Retry {
133 let backoff = match self.backoff {
134 BackoffStrategy::None => Backoff::None,
135 BackoffStrategy::Fixed => Backoff::Fixed {
136 delay: Duration::from_millis(self.backoff_delay),
137 },
138 BackoffStrategy::Exponential => Backoff::Exponential {
139 initial: Duration::from_millis(self.backoff_delay),
140 max: Duration::from_millis(self.backoff_max),
141 jitter: self.backoff_jitter.clamp(0.0, 1.0),
142 },
143 };
144 Retry {
145 max_attempts: self.retries,
146 backoff,
147 }
148 }
149}
150
// SNMPv3 credentials. Supplying `-u` switches `V3Args::auth` from community
// auth to a USM user built from the options below.
#[derive(Debug, Parser)]
pub struct V3Args {
    /// SNMPv3 user name.
    #[arg(short = 'u', long = "username")]
    pub username: Option<String>,

    // NOTE(review): `level` is not read by `auth` or `validate` in this file; the
    // effective security level follows from which protocols are supplied.
    // Confirm whether any caller uses it.
    #[arg(short = 'l', long = "level")]
    pub level: Option<String>,

    /// Authentication protocol (requires -A).
    #[arg(short = 'a', long = "auth-protocol")]
    pub auth_protocol: Option<AuthProtocol>,

    /// Authentication password.
    #[arg(short = 'A', long = "auth-password")]
    pub auth_password: Option<String>,

    /// Privacy protocol (requires -a and -X).
    #[arg(short = 'x', long = "priv-protocol")]
    pub priv_protocol: Option<PrivProtocol>,

    /// Privacy password.
    #[arg(short = 'X', long = "priv-password")]
    pub priv_password: Option<String>,
}
178
179impl V3Args {
180 pub fn is_v3(&self) -> bool {
182 self.username.is_some()
183 }
184
185 pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
190 if let Some(ref username) = self.username {
191 let mut builder = Auth::usm(username);
192 if let Some(proto) = self.auth_protocol {
193 let pass = self
194 .auth_password
195 .as_ref()
196 .ok_or("auth password required")?;
197 builder = builder.auth(proto, pass);
198 }
199 if let Some(proto) = self.priv_protocol {
200 let pass = self
201 .priv_password
202 .as_ref()
203 .ok_or("priv password required")?;
204 builder = builder.privacy(proto, pass);
205 }
206 Ok(builder.into())
207 } else {
208 let community = &common.community;
209 Ok(match common.snmp_version {
210 SnmpVersion::V1 => Auth::v1(community),
211 _ => Auth::v2c(community),
212 })
213 }
214 }
215
216 pub fn validate(&self) -> Result<(), String> {
218 if let Some(ref _username) = self.username {
219 if self.auth_protocol.is_some() && self.auth_password.is_none() {
221 return Err(
222 "authentication password (-A) required when using auth protocol".into(),
223 );
224 }
225
226 if self.priv_protocol.is_some() && self.priv_password.is_none() {
228 return Err("privacy password (-X) required when using priv protocol".into());
229 }
230
231 if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
233 return Err("authentication protocol (-a) required when using privacy".into());
234 }
235
236 if let (Some(auth), Some(priv_proto)) = (self.auth_protocol, self.priv_protocol)
238 && !auth.is_compatible_with(priv_proto)
239 {
240 return Err(format!(
241 "{} authentication does not produce enough key material for {} privacy; use {} or stronger",
242 auth,
243 priv_proto,
244 priv_proto.min_auth_protocol()
245 ));
246 }
247 }
248 Ok(())
249 }
250}
251
// Output formatting and diagnostics options.
#[derive(Debug, Parser)]
pub struct OutputArgs {
    /// Output format.
    #[arg(short = 'O', long = "output", default_value = "human")]
    pub format: OutputFormat,

    // NOTE(review): `verbose`, `hex`, `timing` and `no_hints` are not read in
    // this file; presumably the output/formatting code consumes them. Confirm there.
    #[arg(long = "verbose")]
    pub verbose: bool,

    #[arg(long = "hex")]
    pub hex: bool,

    #[arg(long = "timing")]
    pub timing: bool,

    #[arg(long = "no-hints")]
    pub no_hints: bool,

    /// Enable debug-level logging for async_snmp.
    #[arg(short = 'd', long = "debug")]
    pub debug: bool,

    /// Enable trace-level logging for async_snmp (takes precedence over --debug).
    #[arg(short = 'D', long = "trace")]
    pub trace: bool,
}
283
284impl OutputArgs {
285 pub fn init_tracing(&self) {
290 use tracing_subscriber::EnvFilter;
291
292 let filter = if self.trace {
293 "async_snmp=trace"
294 } else if self.debug {
295 "async_snmp=debug"
296 } else {
297 "async_snmp=warn"
298 };
299
300 let _ = tracing_subscriber::fmt()
301 .with_env_filter(EnvFilter::new(filter))
302 .with_writer(std::io::stderr)
303 .try_init();
304 }
305}
306
// Options controlling walk requests.
// NOTE(review): neither field is read in this file; the walk implementation
// elsewhere is presumed to consume them.
#[derive(Debug, Parser)]
pub struct WalkArgs {
    /// Use GETNEXT requests for the walk.
    #[arg(long = "getnext")]
    pub getnext: bool,

    /// GETBULK max-repetitions value.
    #[arg(long = "max-rep", default_value = "10")]
    pub max_repetitions: u32,
}
318
/// Single-letter type specifier for values supplied on the command line.
///
/// Specifiers are case-sensitive (`c` is Counter32, `C` is Counter64).
/// `ValueType::parse_value` converts the accompanying text into a `Value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum ValueType {
    /// `i`: signed 32-bit INTEGER.
    #[value(name = "i")]
    Integer,
    /// `u`: unsigned 32-bit value, encoded as Gauge32.
    #[value(name = "u")]
    Unsigned,
    /// `s`: OCTET STRING from the text's UTF-8 bytes.
    #[value(name = "s")]
    String,
    /// `x`: OCTET STRING from hex digits.
    #[value(name = "x")]
    HexString,
    /// `o`: OBJECT IDENTIFIER.
    #[value(name = "o")]
    Oid,
    /// `a`: IPv4 address in dotted-quad form.
    #[value(name = "a")]
    IpAddress,
    /// `t`: TimeTicks (unsigned 32-bit).
    #[value(name = "t")]
    TimeTicks,
    /// `c`: Counter32.
    #[value(name = "c")]
    Counter32,
    /// `C`: Counter64.
    #[value(name = "C")]
    Counter64,
}
350
351impl std::str::FromStr for ValueType {
352 type Err = String;
353
354 fn from_str(s: &str) -> Result<Self, Self::Err> {
355 match s {
356 "i" => Ok(ValueType::Integer),
357 "u" => Ok(ValueType::Unsigned),
358 "s" => Ok(ValueType::String),
359 "x" => Ok(ValueType::HexString),
360 "o" => Ok(ValueType::Oid),
361 "a" => Ok(ValueType::IpAddress),
362 "t" => Ok(ValueType::TimeTicks),
363 "c" => Ok(ValueType::Counter32),
364 "C" => Ok(ValueType::Counter64),
365 _ => Err(format!("invalid type specifier: {}", s)),
366 }
367 }
368}
369
370impl ValueType {
371 pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
373 use crate::{Oid, Value};
374
375 match self {
376 ValueType::Integer => {
377 let v: i32 = s
378 .parse()
379 .map_err(|_| format!("invalid integer value: {}", s))?;
380 Ok(Value::Integer(v))
381 }
382 ValueType::Unsigned => {
383 let v: u32 = s
384 .parse()
385 .map_err(|_| format!("invalid unsigned value: {}", s))?;
386 Ok(Value::Gauge32(v))
387 }
388 ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
389 ValueType::HexString => {
390 let bytes = parse_hex_string(s)?;
391 Ok(Value::OctetString(bytes.into()))
392 }
393 ValueType::Oid => {
394 let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
395 Ok(Value::ObjectIdentifier(oid))
396 }
397 ValueType::IpAddress => {
398 let parts: Vec<&str> = s.split('.').collect();
399 if parts.len() != 4 {
400 return Err(format!("invalid IP address: {}", s));
401 }
402 let mut bytes = [0u8; 4];
403 for (i, part) in parts.iter().enumerate() {
404 bytes[i] = part
405 .parse()
406 .map_err(|_| format!("invalid IP address octet: {}", part))?;
407 }
408 Ok(Value::IpAddress(bytes))
409 }
410 ValueType::TimeTicks => {
411 let v: u32 = s
412 .parse()
413 .map_err(|_| format!("invalid timeticks value: {}", s))?;
414 Ok(Value::TimeTicks(v))
415 }
416 ValueType::Counter32 => {
417 let v: u32 = s
418 .parse()
419 .map_err(|_| format!("invalid counter32 value: {}", s))?;
420 Ok(Value::Counter32(v))
421 }
422 ValueType::Counter64 => {
423 let v: u64 = s
424 .parse()
425 .map_err(|_| format!("invalid counter64 value: {}", s))?;
426 Ok(Value::Counter64(v))
427 }
428 }
429 }
430}
431
/// Decode a hex string such as `001a2b`, `00 1A 2B`, `00:1a:2b` or `0x001A`.
///
/// An optional leading `0x`/`0X` prefix is allowed. Whitespace, ':', '-' and
/// '.' may separate digits and are ignored. Digits are case-insensitive and
/// are read in pairs, most significant nibble first. An empty input (or a
/// bare prefix) yields an empty vector.
///
/// # Errors
///
/// - Any other character is rejected, rather than silently dropped, so typos
///   cannot change the decoded bytes.
/// - An odd number of hex digits is rejected.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    let trimmed = s.trim();
    let digits = trimmed
        .strip_prefix("0x")
        .or_else(|| trimmed.strip_prefix("0X"))
        .unwrap_or(trimmed);

    let mut nibbles: Vec<u8> = Vec::with_capacity(digits.len());
    for c in digits.chars() {
        match c.to_digit(16) {
            // `to_digit(16)` only yields 0..=15, so the narrowing cast is lossless.
            Some(d) => nibbles.push(d as u8),
            None if c.is_whitespace() || matches!(c, ':' | '-' | '.') => {}
            None => return Err(format!("invalid hex character: {:?}", c)),
        }
    }

    let pairs = nibbles.chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return Err("hex string must have even number of digits".into());
    }

    Ok(pairs.map(|pair| (pair[0] << 4) | pair[1]).collect())
}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// Baseline `CommonArgs` for the tests; each test overrides only the
    /// fields it cares about via struct-update syntax.
    fn base_common() -> CommonArgs {
        CommonArgs {
            target: "192.168.1.1".to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries: 3,
            backoff: BackoffStrategy::None,
            backoff_delay: 100,
            backoff_max: 5000,
            backoff_jitter: 0.25,
        }
    }

    /// `V3Args` with every option unset.
    fn empty_v3() -> V3Args {
        V3Args {
            username: None,
            level: None,
            auth_protocol: None,
            auth_password: None,
            priv_protocol: None,
            priv_password: None,
        }
    }

    #[test]
    fn test_target_addr_with_port() {
        let args = CommonArgs {
            target: "192.168.1.1:162".to_string(),
            ..base_common()
        };
        assert_eq!(args.target_addr().unwrap().port(), 162);
    }

    #[test]
    fn test_target_addr_default_port() {
        assert_eq!(base_common().target_addr().unwrap().port(), 161);
    }

    #[test]
    fn test_retry_config_none() {
        let retry = base_common().retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let retry = CommonArgs {
            retries: 5,
            backoff: BackoffStrategy::Fixed,
            backoff_delay: 200,
            ..base_common()
        }
        .retry_config();

        assert_eq!(retry.max_attempts, 5);
        assert!(matches!(
            retry.backoff,
            Backoff::Fixed { delay } if delay == Duration::from_millis(200)
        ));
    }

    #[test]
    fn test_retry_config_exponential() {
        let retry = CommonArgs {
            retries: 4,
            backoff: BackoffStrategy::Exponential,
            backoff_delay: 50,
            backoff_max: 2000,
            backoff_jitter: 0.1,
            ..base_common()
        }
        .retry_config();

        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential {
            initial,
            max,
            jitter,
        } = retry.backoff
        else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        // No username: nothing to validate.
        assert!(empty_v3().validate().is_ok());

        // Username alone is valid.
        let user_only = V3Args {
            username: Some("admin".to_string()),
            ..empty_v3()
        };
        assert!(user_only.validate().is_ok());

        // Auth protocol without its password.
        let missing_auth_password = V3Args {
            username: Some("admin".to_string()),
            auth_protocol: Some(AuthProtocol::Sha256),
            ..empty_v3()
        };
        assert!(missing_auth_password.validate().is_err());

        // Privacy without an auth protocol.
        let priv_without_auth = V3Args {
            username: Some("admin".to_string()),
            priv_protocol: Some(PrivProtocol::Aes128),
            priv_password: Some("pass".to_string()),
            ..empty_v3()
        };
        assert!(priv_without_auth.validate().is_err());

        // SHA-1 paired with AES-256 is rejected by the compatibility check.
        let incompatible = V3Args {
            username: Some("admin".to_string()),
            auth_protocol: Some(AuthProtocol::Sha1),
            auth_password: Some("pass".to_string()),
            priv_protocol: Some(PrivProtocol::Aes256),
            priv_password: Some("pass".to_string()),
            ..empty_v3()
        };
        assert!(incompatible.validate().is_err());
    }

    #[test]
    fn test_value_type_parse_integer() {
        assert!(matches!(
            ValueType::Integer.parse_value("42").unwrap(),
            Value::Integer(42)
        ));
        assert!(matches!(
            ValueType::Integer.parse_value("-100").unwrap(),
            Value::Integer(-100)
        ));
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        assert!(matches!(
            ValueType::Unsigned.parse_value("42").unwrap(),
            Value::Gauge32(42)
        ));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let Value::OctetString(bytes) = ValueType::String.parse_value("hello world").unwrap()
        else {
            panic!("expected OctetString");
        };
        assert_eq!(&*bytes, b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        for input in ["001a2b", "00 1A 2B"] {
            let Value::OctetString(bytes) = ValueType::HexString.parse_value(input).unwrap()
            else {
                panic!("expected OctetString");
            };
            assert_eq!(&*bytes, &[0x00, 0x1a, 0x2b]);
        }

        // Odd digit count is rejected.
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        assert!(matches!(
            ValueType::IpAddress.parse_value("192.168.1.1").unwrap(),
            Value::IpAddress([192, 168, 1, 1])
        ));
        assert!(ValueType::IpAddress.parse_value("192.168.1").is_err());
        assert!(ValueType::IpAddress.parse_value("256.1.1.1").is_err());
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        assert!(matches!(
            ValueType::TimeTicks.parse_value("12345678").unwrap(),
            Value::TimeTicks(12345678)
        ));
    }

    #[test]
    fn test_value_type_parse_counters() {
        assert!(matches!(
            ValueType::Counter32.parse_value("4294967295").unwrap(),
            Value::Counter32(4294967295)
        ));
        assert!(matches!(
            ValueType::Counter64
                .parse_value("18446744073709551615")
                .unwrap(),
            Value::Counter64(18446744073709551615)
        ));
    }
}