1use clap::{Parser, ValueEnum};
6use std::time::Duration;
7
8use crate::Version;
9use crate::client::{Auth, Backoff, Retry};
10use crate::v3::{AuthProtocol, PrivProtocol};
11
12#[derive(Debug, Clone, Copy, Default, ValueEnum)]
14pub enum SnmpVersion {
15 #[value(name = "1")]
17 V1,
18 #[default]
20 #[value(name = "2c")]
21 V2c,
22 #[value(name = "3")]
24 V3,
25}
26
27impl From<SnmpVersion> for Version {
28 fn from(v: SnmpVersion) -> Self {
29 match v {
30 SnmpVersion::V1 => Version::V1,
31 SnmpVersion::V2c => Version::V2c,
32 SnmpVersion::V3 => Version::V3,
33 }
34 }
35}
36
37#[derive(Debug, Clone, Copy, Default, ValueEnum)]
39pub enum OutputFormat {
40 #[default]
42 Human,
43 Json,
45 Raw,
47}
48
49#[derive(Debug, Clone, Copy, Default, ValueEnum)]
51pub enum BackoffStrategy {
52 #[default]
54 None,
55 Fixed,
57 Exponential,
59}
60
61#[derive(Debug, Parser)]
63pub struct CommonArgs {
64 #[arg(value_name = "TARGET")]
66 pub target: String,
67
68 #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
70 pub snmp_version: SnmpVersion,
71
72 #[arg(short = 'c', long = "community", default_value = "public")]
74 pub community: String,
75
76 #[arg(short = 't', long = "timeout", default_value = "5")]
78 pub timeout: f64,
79
80 #[arg(short = 'r', long = "retries", default_value = "3")]
82 pub retries: u32,
83
84 #[arg(long = "backoff", default_value = "none")]
86 pub backoff: BackoffStrategy,
87
88 #[arg(long = "backoff-delay", default_value = "1000")]
90 pub backoff_delay: u64,
91
92 #[arg(long = "backoff-max", default_value = "5000")]
94 pub backoff_max: u64,
95
96 #[arg(long = "backoff-jitter", default_value = "0.25")]
98 pub backoff_jitter: f64,
99}
100
101impl CommonArgs {
102 pub fn timeout_duration(&self) -> Duration {
104 Duration::from_secs_f64(self.timeout)
105 }
106
107 pub fn retry_config(&self) -> Retry {
109 let backoff = match self.backoff {
110 BackoffStrategy::None => Backoff::None,
111 BackoffStrategy::Fixed => Backoff::Fixed {
112 delay: Duration::from_millis(self.backoff_delay),
113 },
114 BackoffStrategy::Exponential => Backoff::Exponential {
115 initial: Duration::from_millis(self.backoff_delay),
116 max: Duration::from_millis(self.backoff_max),
117 jitter: self.backoff_jitter.clamp(0.0, 1.0),
118 },
119 };
120 Retry {
121 max_attempts: self.retries,
122 backoff,
123 }
124 }
125}
126
127#[derive(Debug, Parser)]
129pub struct V3Args {
130 #[arg(short = 'u', long = "username")]
132 pub username: Option<String>,
133
134 #[arg(short = 'l', long = "level")]
136 pub level: Option<String>,
137
138 #[arg(short = 'a', long = "auth-protocol")]
140 pub auth_protocol: Option<AuthProtocol>,
141
142 #[arg(short = 'A', long = "auth-password")]
144 pub auth_password: Option<String>,
145
146 #[arg(short = 'x', long = "priv-protocol")]
148 pub priv_protocol: Option<PrivProtocol>,
149
150 #[arg(short = 'X', long = "priv-password")]
152 pub priv_password: Option<String>,
153}
154
155impl V3Args {
156 pub fn is_v3(&self) -> bool {
158 self.username.is_some()
159 }
160
161 pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
166 if let Some(ref username) = self.username {
167 let mut builder = Auth::usm(username);
168 if let Some(proto) = self.auth_protocol {
169 let pass = self
170 .auth_password
171 .as_ref()
172 .ok_or("auth password required")?;
173 builder = builder.auth(proto, pass);
174 }
175 if let Some(proto) = self.priv_protocol {
176 let pass = self
177 .priv_password
178 .as_ref()
179 .ok_or("priv password required")?;
180 builder = builder.privacy(proto, pass);
181 }
182 Ok(builder.into())
183 } else {
184 let community = &common.community;
185 Ok(match common.snmp_version {
186 SnmpVersion::V1 => Auth::v1(community),
187 _ => Auth::v2c(community),
188 })
189 }
190 }
191
192 pub fn validate(&self) -> Result<(), String> {
194 if let Some(ref _username) = self.username {
195 if self.auth_protocol.is_some() && self.auth_password.is_none() {
197 return Err(
198 "authentication password (-A) required when using auth protocol".into(),
199 );
200 }
201
202 if self.priv_protocol.is_some() && self.priv_password.is_none() {
204 return Err("privacy password (-X) required when using priv protocol".into());
205 }
206
207 if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
209 return Err("authentication protocol (-a) required when using privacy".into());
210 }
211 }
212 Ok(())
213 }
214}
215
216#[derive(Debug, Parser)]
218pub struct OutputArgs {
219 #[arg(short = 'O', long = "output", default_value = "human")]
221 pub format: OutputFormat,
222
223 #[arg(long = "verbose")]
225 pub verbose: bool,
226
227 #[arg(long = "hex")]
229 pub hex: bool,
230
231 #[arg(long = "timing")]
233 pub timing: bool,
234
235 #[arg(long = "no-hints")]
237 pub no_hints: bool,
238
239 #[arg(short = 'd', long = "debug")]
241 pub debug: bool,
242
243 #[arg(short = 'D', long = "trace")]
245 pub trace: bool,
246}
247
248impl OutputArgs {
249 pub fn init_tracing(&self) {
254 use tracing_subscriber::EnvFilter;
255
256 let filter = if self.trace {
257 "async_snmp=trace"
258 } else if self.debug {
259 "async_snmp=debug"
260 } else {
261 "async_snmp=warn"
262 };
263
264 let _ = tracing_subscriber::fmt()
265 .with_env_filter(EnvFilter::new(filter))
266 .with_writer(std::io::stderr)
267 .try_init();
268 }
269}
270
271#[derive(Debug, Parser)]
273pub struct WalkArgs {
274 #[arg(long = "getnext")]
276 pub getnext: bool,
277
278 #[arg(long = "max-rep", default_value = "10")]
280 pub max_repetitions: u32,
281}
282
283#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
285pub enum ValueType {
286 #[value(name = "i")]
288 Integer,
289 #[value(name = "u")]
291 Unsigned,
292 #[value(name = "s")]
294 String,
295 #[value(name = "x")]
297 HexString,
298 #[value(name = "o")]
300 Oid,
301 #[value(name = "a")]
303 IpAddress,
304 #[value(name = "t")]
306 TimeTicks,
307 #[value(name = "c")]
309 Counter32,
310 #[value(name = "C")]
312 Counter64,
313}
314
315impl std::str::FromStr for ValueType {
316 type Err = String;
317
318 fn from_str(s: &str) -> Result<Self, Self::Err> {
319 match s {
320 "i" => Ok(ValueType::Integer),
321 "u" => Ok(ValueType::Unsigned),
322 "s" => Ok(ValueType::String),
323 "x" => Ok(ValueType::HexString),
324 "o" => Ok(ValueType::Oid),
325 "a" => Ok(ValueType::IpAddress),
326 "t" => Ok(ValueType::TimeTicks),
327 "c" => Ok(ValueType::Counter32),
328 "C" => Ok(ValueType::Counter64),
329 _ => Err(format!("invalid type specifier: {}", s)),
330 }
331 }
332}
333
334impl ValueType {
335 pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
337 use crate::{Oid, Value};
338
339 match self {
340 ValueType::Integer => {
341 let v: i32 = s
342 .parse()
343 .map_err(|_| format!("invalid integer value: {}", s))?;
344 Ok(Value::Integer(v))
345 }
346 ValueType::Unsigned => {
347 let v: u32 = s
348 .parse()
349 .map_err(|_| format!("invalid unsigned value: {}", s))?;
350 Ok(Value::Gauge32(v))
351 }
352 ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
353 ValueType::HexString => {
354 let bytes = parse_hex_string(s)?;
355 Ok(Value::OctetString(bytes.into()))
356 }
357 ValueType::Oid => {
358 let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
359 Ok(Value::ObjectIdentifier(oid))
360 }
361 ValueType::IpAddress => {
362 let parts: Vec<&str> = s.split('.').collect();
363 if parts.len() != 4 {
364 return Err(format!("invalid IP address: {}", s));
365 }
366 let mut bytes = [0u8; 4];
367 for (i, part) in parts.iter().enumerate() {
368 bytes[i] = part
369 .parse()
370 .map_err(|_| format!("invalid IP address octet: {}", part))?;
371 }
372 Ok(Value::IpAddress(bytes))
373 }
374 ValueType::TimeTicks => {
375 let v: u32 = s
376 .parse()
377 .map_err(|_| format!("invalid timeticks value: {}", s))?;
378 Ok(Value::TimeTicks(v))
379 }
380 ValueType::Counter32 => {
381 let v: u32 = s
382 .parse()
383 .map_err(|_| format!("invalid counter32 value: {}", s))?;
384 Ok(Value::Counter32(v))
385 }
386 ValueType::Counter64 => {
387 let v: u64 = s
388 .parse()
389 .map_err(|_| format!("invalid counter64 value: {}", s))?;
390 Ok(Value::Counter64(v))
391 }
392 }
393 }
394}
395
/// Decodes a hex string such as `"001a2b"`, `"00 1A 2B"` or `"00:1a:2b"` into bytes.
///
/// ASCII whitespace, `:` and `-` are accepted as separators and ignored; digits
/// are case-insensitive. Any other character is an error rather than being
/// silently dropped, so typos like `"deadbeefq"` or a `"0x"` prefix are
/// reported instead of producing the wrong bytes.
///
/// # Errors
///
/// Returns an error for a non-hex, non-separator character, or when the number
/// of hex digits is odd.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    let nibbles = s
        .chars()
        .filter(|c| !(c.is_ascii_whitespace() || *c == ':' || *c == '-'))
        .map(|c| {
            c.to_digit(16)
                // `to_digit(16)` yields at most 15, so the narrowing is lossless.
                .map(|d| d as u8)
                .ok_or_else(|| format!("invalid hex character: {}", c))
        })
        .collect::<Result<Vec<u8>, String>>()?;

    if !nibbles.len().is_multiple_of(2) {
        return Err("hex string must have even number of digits".into());
    }

    Ok(nibbles
        .chunks_exact(2)
        .map(|pair| (pair[0] << 4) | pair[1])
        .collect())
}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    // Builds a v2c `CommonArgs` for a fixed target; only the retry and backoff
    // knobs vary between tests.
    fn common_args(
        retries: u32,
        backoff: BackoffStrategy,
        backoff_delay: u64,
        backoff_max: u64,
        backoff_jitter: f64,
    ) -> CommonArgs {
        CommonArgs {
            target: "192.168.1.1".to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries,
            backoff,
            backoff_delay,
            backoff_max,
            backoff_jitter,
        }
    }

    // Builds a `V3Args` with no security level; `None` leaves an option unset.
    fn v3_args(
        username: Option<&str>,
        auth_protocol: Option<AuthProtocol>,
        auth_password: Option<&str>,
        priv_protocol: Option<PrivProtocol>,
        priv_password: Option<&str>,
    ) -> V3Args {
        V3Args {
            username: username.map(String::from),
            level: None,
            auth_protocol,
            auth_password: auth_password.map(String::from),
            priv_protocol,
            priv_password: priv_password.map(String::from),
        }
    }

    // Extracts the raw bytes of an `OctetString`, failing the test otherwise.
    fn octets(value: Value) -> Vec<u8> {
        let Value::OctetString(bytes) = value else {
            panic!("expected OctetString");
        };
        let raw: &[u8] = &bytes;
        raw.to_vec()
    }

    #[test]
    fn test_retry_config_none() {
        let retry = common_args(3, BackoffStrategy::None, 100, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let retry = common_args(5, BackoffStrategy::Fixed, 200, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 5);
        assert!(matches!(
            retry.backoff,
            Backoff::Fixed { delay } if delay == Duration::from_millis(200)
        ));
    }

    #[test]
    fn test_retry_config_exponential() {
        let retry = common_args(4, BackoffStrategy::Exponential, 50, 2000, 0.1).retry_config();
        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential {
            initial,
            max,
            jitter,
        } = retry.backoff
        else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        // No username: v3 options are not checked at all.
        assert!(v3_args(None, None, None, None, None).validate().is_ok());

        // Username only.
        assert!(v3_args(Some("admin"), None, None, None, None)
            .validate()
            .is_ok());

        // Auth protocol without its password.
        assert!(
            v3_args(Some("admin"), Some(AuthProtocol::Sha256), None, None, None)
                .validate()
                .is_err()
        );

        // Privacy without authentication.
        assert!(v3_args(
            Some("admin"),
            None,
            None,
            Some(PrivProtocol::Aes128),
            Some("pass"),
        )
        .validate()
        .is_err());

        // Complete auth + privacy configuration.
        assert!(v3_args(
            Some("admin"),
            Some(AuthProtocol::Sha1),
            Some("pass"),
            Some(PrivProtocol::Aes256),
            Some("pass"),
        )
        .validate()
        .is_ok());
    }

    #[test]
    fn test_value_type_parse_integer() {
        assert!(matches!(
            ValueType::Integer.parse_value("42").unwrap(),
            Value::Integer(42)
        ));
        assert!(matches!(
            ValueType::Integer.parse_value("-100").unwrap(),
            Value::Integer(-100)
        ));
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        assert!(matches!(
            ValueType::Unsigned.parse_value("42").unwrap(),
            Value::Gauge32(42)
        ));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let bytes = octets(ValueType::String.parse_value("hello world").unwrap());
        assert_eq!(bytes, b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        // Plain digits.
        let bytes = octets(ValueType::HexString.parse_value("001a2b").unwrap());
        assert_eq!(bytes, [0x00u8, 0x1a, 0x2b]);

        // Whitespace-separated, upper-case digits.
        let bytes = octets(ValueType::HexString.parse_value("00 1A 2B").unwrap());
        assert_eq!(bytes, [0x00u8, 0x1a, 0x2b]);

        // Odd digit count.
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        assert!(matches!(
            ValueType::IpAddress.parse_value("192.168.1.1").unwrap(),
            Value::IpAddress([192, 168, 1, 1])
        ));
        assert!(ValueType::IpAddress.parse_value("192.168.1").is_err());
        assert!(ValueType::IpAddress.parse_value("256.1.1.1").is_err());
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        assert!(matches!(
            ValueType::TimeTicks.parse_value("12345678").unwrap(),
            Value::TimeTicks(12345678)
        ));
    }

    #[test]
    fn test_value_type_parse_counters() {
        assert!(matches!(
            ValueType::Counter32.parse_value("4294967295").unwrap(),
            Value::Counter32(4294967295)
        ));
        assert!(matches!(
            ValueType::Counter64
                .parse_value("18446744073709551615")
                .unwrap(),
            Value::Counter64(18446744073709551615)
        ));
    }
}