1use clap::{Parser, ValueEnum};
6use std::net::SocketAddr;
7use std::time::Duration;
8
9use crate::Version;
10use crate::client::{Auth, Backoff, Retry};
11use crate::v3::{AuthProtocol, PrivProtocol};
12
13#[derive(Debug, Clone, Copy, Default, ValueEnum)]
15pub enum SnmpVersion {
16 #[value(name = "1")]
18 V1,
19 #[default]
21 #[value(name = "2c")]
22 V2c,
23 #[value(name = "3")]
25 V3,
26}
27
28impl From<SnmpVersion> for Version {
29 fn from(v: SnmpVersion) -> Self {
30 match v {
31 SnmpVersion::V1 => Version::V1,
32 SnmpVersion::V2c => Version::V2c,
33 SnmpVersion::V3 => Version::V3,
34 }
35 }
36}
37
38#[derive(Debug, Clone, Copy, Default, ValueEnum)]
40pub enum OutputFormat {
41 #[default]
43 Human,
44 Json,
46 Raw,
48}
49
50#[derive(Debug, Clone, Copy, Default, ValueEnum)]
52pub enum BackoffStrategy {
53 #[default]
55 None,
56 Fixed,
58 Exponential,
60}
61
62#[derive(Debug, Parser)]
64pub struct CommonArgs {
65 #[arg(value_name = "TARGET")]
67 pub target: String,
68
69 #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
71 pub snmp_version: SnmpVersion,
72
73 #[arg(short = 'c', long = "community", default_value = "public")]
75 pub community: String,
76
77 #[arg(short = 't', long = "timeout", default_value = "5")]
79 pub timeout: f64,
80
81 #[arg(short = 'r', long = "retries", default_value = "3")]
83 pub retries: u32,
84
85 #[arg(long = "backoff", default_value = "none")]
87 pub backoff: BackoffStrategy,
88
89 #[arg(long = "backoff-delay", default_value = "1000")]
91 pub backoff_delay: u64,
92
93 #[arg(long = "backoff-max", default_value = "5000")]
95 pub backoff_max: u64,
96
97 #[arg(long = "backoff-jitter", default_value = "0.25")]
99 pub backoff_jitter: f64,
100}
101
102impl CommonArgs {
103 pub fn target_addr(&self) -> Result<SocketAddr, String> {
105 let addr_str = if self.target.contains(':') {
107 self.target.clone()
108 } else {
109 format!("{}:161", self.target)
110 };
111
112 addr_str
113 .parse()
114 .or_else(|_| {
115 use std::net::ToSocketAddrs;
117 addr_str
118 .to_socket_addrs()
119 .map_err(|e| e.to_string())?
120 .next()
121 .ok_or_else(|| format!("could not resolve hostname: {}", self.target))
122 })
123 .map_err(|e| format!("invalid target '{}': {}", self.target, e))
124 }
125
126 pub fn timeout_duration(&self) -> Duration {
128 Duration::from_secs_f64(self.timeout)
129 }
130
131 pub fn retry_config(&self) -> Retry {
133 let backoff = match self.backoff {
134 BackoffStrategy::None => Backoff::None,
135 BackoffStrategy::Fixed => Backoff::Fixed {
136 delay: Duration::from_millis(self.backoff_delay),
137 },
138 BackoffStrategy::Exponential => Backoff::Exponential {
139 initial: Duration::from_millis(self.backoff_delay),
140 max: Duration::from_millis(self.backoff_max),
141 jitter: self.backoff_jitter.clamp(0.0, 1.0),
142 },
143 };
144 Retry {
145 max_attempts: self.retries,
146 backoff,
147 }
148 }
149}
150
151#[derive(Debug, Parser)]
153pub struct V3Args {
154 #[arg(short = 'u', long = "username")]
156 pub username: Option<String>,
157
158 #[arg(short = 'l', long = "level")]
160 pub level: Option<String>,
161
162 #[arg(short = 'a', long = "auth-protocol")]
164 pub auth_protocol: Option<AuthProtocol>,
165
166 #[arg(short = 'A', long = "auth-password")]
168 pub auth_password: Option<String>,
169
170 #[arg(short = 'x', long = "priv-protocol")]
172 pub priv_protocol: Option<PrivProtocol>,
173
174 #[arg(short = 'X', long = "priv-password")]
176 pub priv_password: Option<String>,
177}
178
179impl V3Args {
180 pub fn is_v3(&self) -> bool {
182 self.username.is_some()
183 }
184
185 pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
190 if let Some(ref username) = self.username {
191 let mut builder = Auth::usm(username);
192 if let Some(proto) = self.auth_protocol {
193 let pass = self
194 .auth_password
195 .as_ref()
196 .ok_or("auth password required")?;
197 builder = builder.auth(proto, pass);
198 }
199 if let Some(proto) = self.priv_protocol {
200 let pass = self
201 .priv_password
202 .as_ref()
203 .ok_or("priv password required")?;
204 builder = builder.privacy(proto, pass);
205 }
206 Ok(builder.into())
207 } else {
208 let community = &common.community;
209 Ok(match common.snmp_version {
210 SnmpVersion::V1 => Auth::v1(community),
211 _ => Auth::v2c(community),
212 })
213 }
214 }
215
216 pub fn validate(&self) -> Result<(), String> {
218 if let Some(ref _username) = self.username {
219 if self.auth_protocol.is_some() && self.auth_password.is_none() {
221 return Err(
222 "authentication password (-A) required when using auth protocol".into(),
223 );
224 }
225
226 if self.priv_protocol.is_some() && self.priv_password.is_none() {
228 return Err("privacy password (-X) required when using priv protocol".into());
229 }
230
231 if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
233 return Err("authentication protocol (-a) required when using privacy".into());
234 }
235 }
236 Ok(())
237 }
238}
239
240#[derive(Debug, Parser)]
242pub struct OutputArgs {
243 #[arg(short = 'O', long = "output", default_value = "human")]
245 pub format: OutputFormat,
246
247 #[arg(long = "verbose")]
249 pub verbose: bool,
250
251 #[arg(long = "hex")]
253 pub hex: bool,
254
255 #[arg(long = "timing")]
257 pub timing: bool,
258
259 #[arg(long = "no-hints")]
261 pub no_hints: bool,
262
263 #[arg(short = 'd', long = "debug")]
265 pub debug: bool,
266
267 #[arg(short = 'D', long = "trace")]
269 pub trace: bool,
270}
271
272impl OutputArgs {
273 pub fn init_tracing(&self) {
278 use tracing_subscriber::EnvFilter;
279
280 let filter = if self.trace {
281 "async_snmp=trace"
282 } else if self.debug {
283 "async_snmp=debug"
284 } else {
285 "async_snmp=warn"
286 };
287
288 let _ = tracing_subscriber::fmt()
289 .with_env_filter(EnvFilter::new(filter))
290 .with_writer(std::io::stderr)
291 .try_init();
292 }
293}
294
295#[derive(Debug, Parser)]
297pub struct WalkArgs {
298 #[arg(long = "getnext")]
300 pub getnext: bool,
301
302 #[arg(long = "max-rep", default_value = "10")]
304 pub max_repetitions: u32,
305}
306
307#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
309pub enum ValueType {
310 #[value(name = "i")]
312 Integer,
313 #[value(name = "u")]
315 Unsigned,
316 #[value(name = "s")]
318 String,
319 #[value(name = "x")]
321 HexString,
322 #[value(name = "o")]
324 Oid,
325 #[value(name = "a")]
327 IpAddress,
328 #[value(name = "t")]
330 TimeTicks,
331 #[value(name = "c")]
333 Counter32,
334 #[value(name = "C")]
336 Counter64,
337}
338
339impl std::str::FromStr for ValueType {
340 type Err = String;
341
342 fn from_str(s: &str) -> Result<Self, Self::Err> {
343 match s {
344 "i" => Ok(ValueType::Integer),
345 "u" => Ok(ValueType::Unsigned),
346 "s" => Ok(ValueType::String),
347 "x" => Ok(ValueType::HexString),
348 "o" => Ok(ValueType::Oid),
349 "a" => Ok(ValueType::IpAddress),
350 "t" => Ok(ValueType::TimeTicks),
351 "c" => Ok(ValueType::Counter32),
352 "C" => Ok(ValueType::Counter64),
353 _ => Err(format!("invalid type specifier: {}", s)),
354 }
355 }
356}
357
358impl ValueType {
359 pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
361 use crate::{Oid, Value};
362
363 match self {
364 ValueType::Integer => {
365 let v: i32 = s
366 .parse()
367 .map_err(|_| format!("invalid integer value: {}", s))?;
368 Ok(Value::Integer(v))
369 }
370 ValueType::Unsigned => {
371 let v: u32 = s
372 .parse()
373 .map_err(|_| format!("invalid unsigned value: {}", s))?;
374 Ok(Value::Gauge32(v))
375 }
376 ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
377 ValueType::HexString => {
378 let bytes = parse_hex_string(s)?;
379 Ok(Value::OctetString(bytes.into()))
380 }
381 ValueType::Oid => {
382 let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
383 Ok(Value::ObjectIdentifier(oid))
384 }
385 ValueType::IpAddress => {
386 let parts: Vec<&str> = s.split('.').collect();
387 if parts.len() != 4 {
388 return Err(format!("invalid IP address: {}", s));
389 }
390 let mut bytes = [0u8; 4];
391 for (i, part) in parts.iter().enumerate() {
392 bytes[i] = part
393 .parse()
394 .map_err(|_| format!("invalid IP address octet: {}", part))?;
395 }
396 Ok(Value::IpAddress(bytes))
397 }
398 ValueType::TimeTicks => {
399 let v: u32 = s
400 .parse()
401 .map_err(|_| format!("invalid timeticks value: {}", s))?;
402 Ok(Value::TimeTicks(v))
403 }
404 ValueType::Counter32 => {
405 let v: u32 = s
406 .parse()
407 .map_err(|_| format!("invalid counter32 value: {}", s))?;
408 Ok(Value::Counter32(v))
409 }
410 ValueType::Counter64 => {
411 let v: u64 = s
412 .parse()
413 .map_err(|_| format!("invalid counter64 value: {}", s))?;
414 Ok(Value::Counter64(v))
415 }
416 }
417 }
418}
419
/// Decodes a hexadecimal string such as `"001a2b"`, `"00 1A 2B"`, `"00:1a:2b"`
/// or `"0x001a2b"` into raw bytes.
///
/// Digits are case-insensitive. Surrounding whitespace, a leading `0x`/`0X`
/// prefix, and whitespace or `:`, `-`, `.` separators between digits are ignored.
/// An empty input decodes to an empty vector.
///
/// # Errors
/// Returns an error if the string contains any other character (so a typo such
/// as `"12zz34"` is reported rather than decoded to two bytes), or if the number
/// of hex digits is odd.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    let trimmed = s.trim();
    let body = trimmed
        .strip_prefix("0x")
        .or_else(|| trimmed.strip_prefix("0X"))
        .unwrap_or(trimmed);

    let mut nibbles = Vec::with_capacity(body.len());
    for c in body.chars() {
        match c.to_digit(16) {
            // `to_digit(16)` yields 0..=15, so the narrowing cast is lossless.
            Some(d) => nibbles.push(d as u8),
            None if c.is_whitespace() || matches!(c, ':' | '-' | '.') => {}
            None => return Err(format!("invalid hex character: {:?}", c)),
        }
    }

    let pairs = nibbles.chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return Err("hex string must have even number of digits".into());
    }
    Ok(pairs.map(|pair| (pair[0] << 4) | pair[1]).collect())
}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// Baseline `CommonArgs`; individual tests override only what they exercise.
    fn common_args() -> CommonArgs {
        CommonArgs {
            target: "192.168.1.1".to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries: 3,
            backoff: BackoffStrategy::None,
            backoff_delay: 100,
            backoff_max: 5000,
            backoff_jitter: 0.25,
        }
    }

    /// Builds `V3Args` with `level` left unset.
    fn v3_args(
        username: Option<&str>,
        auth_protocol: Option<AuthProtocol>,
        auth_password: Option<&str>,
        priv_protocol: Option<PrivProtocol>,
        priv_password: Option<&str>,
    ) -> V3Args {
        V3Args {
            username: username.map(str::to_string),
            level: None,
            auth_protocol,
            auth_password: auth_password.map(str::to_string),
            priv_protocol,
            priv_password: priv_password.map(str::to_string),
        }
    }

    /// Asserts that `value` is an `OctetString` holding exactly `expected`.
    fn expect_octets(value: Value, expected: &[u8]) {
        match value {
            Value::OctetString(bytes) => assert_eq!(&*bytes, expected),
            _ => panic!("expected OctetString"),
        }
    }

    #[test]
    fn test_target_addr_with_port() {
        let args = CommonArgs {
            target: "192.168.1.1:162".to_string(),
            ..common_args()
        };
        assert_eq!(args.target_addr().unwrap().port(), 162);
    }

    #[test]
    fn test_target_addr_default_port() {
        assert_eq!(common_args().target_addr().unwrap().port(), 161);
    }

    #[test]
    fn test_retry_config_none() {
        let retry = common_args().retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let args = CommonArgs {
            retries: 5,
            backoff: BackoffStrategy::Fixed,
            backoff_delay: 200,
            ..common_args()
        };
        let retry = args.retry_config();
        assert_eq!(retry.max_attempts, 5);
        assert!(matches!(
            retry.backoff,
            Backoff::Fixed { delay } if delay == Duration::from_millis(200)
        ));
    }

    #[test]
    fn test_retry_config_exponential() {
        let args = CommonArgs {
            retries: 4,
            backoff: BackoffStrategy::Exponential,
            backoff_delay: 50,
            backoff_max: 2000,
            backoff_jitter: 0.1,
            ..common_args()
        };
        let retry = args.retry_config();
        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential { initial, max, jitter } = retry.backoff else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        // No username: v3 options are not checked.
        assert!(v3_args(None, None, None, None, None).validate().is_ok());

        // Username only (noAuthNoPriv).
        assert!(v3_args(Some("admin"), None, None, None, None)
            .validate()
            .is_ok());

        // Auth protocol without its password.
        assert!(
            v3_args(Some("admin"), Some(AuthProtocol::Sha256), None, None, None)
                .validate()
                .is_err()
        );

        // Privacy requested without authentication.
        assert!(v3_args(
            Some("admin"),
            None,
            None,
            Some(PrivProtocol::Aes128),
            Some("pass"),
        )
        .validate()
        .is_err());

        // Complete authPriv configuration.
        assert!(v3_args(
            Some("admin"),
            Some(AuthProtocol::Sha1),
            Some("pass"),
            Some(PrivProtocol::Aes256),
            Some("pass"),
        )
        .validate()
        .is_ok());
    }

    #[test]
    fn test_value_type_parse_integer() {
        assert!(matches!(
            ValueType::Integer.parse_value("42").unwrap(),
            Value::Integer(42)
        ));
        assert!(matches!(
            ValueType::Integer.parse_value("-100").unwrap(),
            Value::Integer(-100)
        ));
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        assert!(matches!(
            ValueType::Unsigned.parse_value("42").unwrap(),
            Value::Gauge32(42)
        ));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        expect_octets(
            ValueType::String.parse_value("hello world").unwrap(),
            b"hello world",
        );
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        // Contiguous digits.
        expect_octets(
            ValueType::HexString.parse_value("001a2b").unwrap(),
            &[0x00, 0x1a, 0x2b],
        );
        // Space-separated, upper-case digits.
        expect_octets(
            ValueType::HexString.parse_value("00 1A 2B").unwrap(),
            &[0x00, 0x1a, 0x2b],
        );
        // Odd number of digits is rejected.
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        assert!(matches!(
            ValueType::IpAddress.parse_value("192.168.1.1").unwrap(),
            Value::IpAddress([192, 168, 1, 1])
        ));
        assert!(ValueType::IpAddress.parse_value("192.168.1").is_err());
        assert!(ValueType::IpAddress.parse_value("256.1.1.1").is_err());
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        assert!(matches!(
            ValueType::TimeTicks.parse_value("12345678").unwrap(),
            Value::TimeTicks(12345678)
        ));
    }

    #[test]
    fn test_value_type_parse_counters() {
        assert!(matches!(
            ValueType::Counter32.parse_value("4294967295").unwrap(),
            Value::Counter32(4294967295)
        ));
        assert!(matches!(
            ValueType::Counter64
                .parse_value("18446744073709551615")
                .unwrap(),
            Value::Counter64(18446744073709551615)
        ));
    }
}