async_snmp/cli/
args.rs

//! Command-line argument structures for async-snmp CLI tools.
//!
//! This module provides reusable clap argument structures for the `asnmp-*` CLI tools.
4
5use clap::{Parser, ValueEnum};
6use std::net::SocketAddr;
7use std::time::Duration;
8
9use crate::Version;
10use crate::client::{Auth, Backoff, Retry};
11use crate::v3::{AuthProtocol, PrivProtocol};
12
13/// SNMP version for CLI argument parsing.
14#[derive(Debug, Clone, Copy, Default, ValueEnum)]
15pub enum SnmpVersion {
16    /// SNMPv1
17    #[value(name = "1")]
18    V1,
19    /// SNMPv2c (default)
20    #[default]
21    #[value(name = "2c")]
22    V2c,
23    /// SNMPv3
24    #[value(name = "3")]
25    V3,
26}
27
28impl From<SnmpVersion> for Version {
29    fn from(v: SnmpVersion) -> Self {
30        match v {
31            SnmpVersion::V1 => Version::V1,
32            SnmpVersion::V2c => Version::V2c,
33            SnmpVersion::V3 => Version::V3,
34        }
35    }
36}
37
38/// Output format for CLI tools.
39#[derive(Debug, Clone, Copy, Default, ValueEnum)]
40pub enum OutputFormat {
41    /// Human-readable output with type information.
42    #[default]
43    Human,
44    /// JSON output for scripting.
45    Json,
46    /// Raw tab-separated output for scripting.
47    Raw,
48}
49
50/// Backoff strategy for CLI argument parsing.
51#[derive(Debug, Clone, Copy, Default, ValueEnum)]
52pub enum BackoffStrategy {
53    /// No delay between retries (immediate retry on timeout).
54    #[default]
55    None,
56    /// Fixed delay between each retry.
57    Fixed,
58    /// Exponential backoff: delay doubles after each attempt.
59    Exponential,
60}
61
/// Common arguments shared across all CLI tools.
// Maintainer notes in this struct use plain `//` comments: clap's derive turns the `///` doc
// comments into user-visible help text, so those lines are part of the CLI's output.
#[derive(Debug, Parser)]
pub struct CommonArgs {
    /// Target host or host:port (default port 161).
    // Stored as the raw string; `target_addr` parses or resolves it on demand.
    #[arg(value_name = "TARGET")]
    pub target: String,

    /// SNMP version: 1, 2c, or 3.
    // The default string "2c" must stay in sync with the `#[value(name = "2c")]` variant name.
    #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
    pub snmp_version: SnmpVersion,

    /// Community string (v1/v2c).
    // Read by `V3Args::auth` only when no USM username was supplied.
    #[arg(short = 'c', long = "community", default_value = "public")]
    pub community: String,

    /// Request timeout in seconds.
    // Fractional seconds are allowed; `timeout_duration` converts this to a `Duration`.
    #[arg(short = 't', long = "timeout", default_value = "5")]
    pub timeout: f64,

    /// Retry count.
    // Passed through unchanged as `Retry::max_attempts` by `retry_config`.
    #[arg(short = 'r', long = "retries", default_value = "3")]
    pub retries: u32,

    /// Backoff strategy between retries: none, fixed, or exponential.
    #[arg(long = "backoff", default_value = "none")]
    pub backoff: BackoffStrategy,

    /// Backoff delay in milliseconds (initial delay for exponential, fixed delay otherwise).
    // Not read by `retry_config` when the strategy is `none`.
    #[arg(long = "backoff-delay", default_value = "1000")]
    pub backoff_delay: u64,

    /// Maximum backoff delay in milliseconds (exponential only).
    #[arg(long = "backoff-max", default_value = "5000")]
    pub backoff_max: u64,

    /// Jitter factor for exponential backoff (0.0-1.0, e.g., 0.25 means +/-25%).
    // Out-of-range values are not rejected here; `retry_config` clamps them into 0.0..=1.0.
    #[arg(long = "backoff-jitter", default_value = "0.25")]
    pub backoff_jitter: f64,
}
101
102impl CommonArgs {
103    /// Parse the target into a SocketAddr, defaulting to port 161.
104    pub fn target_addr(&self) -> Result<SocketAddr, String> {
105        // If the target doesn't contain a port, add the default SNMP port
106        let addr_str = if self.target.contains(':') {
107            self.target.clone()
108        } else {
109            format!("{}:161", self.target)
110        };
111
112        addr_str
113            .parse()
114            .or_else(|_| {
115                // Try to resolve as hostname
116                use std::net::ToSocketAddrs;
117                addr_str
118                    .to_socket_addrs()
119                    .map_err(|e| e.to_string())?
120                    .next()
121                    .ok_or_else(|| format!("could not resolve hostname: {}", self.target))
122            })
123            .map_err(|e| format!("invalid target '{}': {}", self.target, e))
124    }
125
126    /// Get the timeout as a Duration.
127    pub fn timeout_duration(&self) -> Duration {
128        Duration::from_secs_f64(self.timeout)
129    }
130
131    /// Build a Retry configuration from the CLI arguments.
132    pub fn retry_config(&self) -> Retry {
133        let backoff = match self.backoff {
134            BackoffStrategy::None => Backoff::None,
135            BackoffStrategy::Fixed => Backoff::Fixed {
136                delay: Duration::from_millis(self.backoff_delay),
137            },
138            BackoffStrategy::Exponential => Backoff::Exponential {
139                initial: Duration::from_millis(self.backoff_delay),
140                max: Duration::from_millis(self.backoff_max),
141                jitter: self.backoff_jitter.clamp(0.0, 1.0),
142            },
143        };
144        Retry {
145            max_attempts: self.retries,
146            backoff,
147        }
148    }
149}
150
151/// SNMPv3 security arguments.
152#[derive(Debug, Parser)]
153pub struct V3Args {
154    /// Security name/username (implies -v 3).
155    #[arg(short = 'u', long = "username")]
156    pub username: Option<String>,
157
158    /// Security level: noAuthNoPriv, authNoPriv, or authPriv.
159    #[arg(short = 'l', long = "level")]
160    pub level: Option<String>,
161
162    /// Authentication protocol: MD5, SHA, SHA-224, SHA-256, SHA-384, SHA-512.
163    #[arg(short = 'a', long = "auth-protocol")]
164    pub auth_protocol: Option<AuthProtocol>,
165
166    /// Authentication passphrase.
167    #[arg(short = 'A', long = "auth-password")]
168    pub auth_password: Option<String>,
169
170    /// Privacy protocol: DES, AES, AES-128, AES-192, AES-256.
171    #[arg(short = 'x', long = "priv-protocol")]
172    pub priv_protocol: Option<PrivProtocol>,
173
174    /// Privacy passphrase.
175    #[arg(short = 'X', long = "priv-password")]
176    pub priv_password: Option<String>,
177}
178
179impl V3Args {
180    /// Check if V3 mode is enabled (username provided).
181    pub fn is_v3(&self) -> bool {
182        self.username.is_some()
183    }
184
185    /// Build an Auth configuration from the V3 args and common args.
186    ///
187    /// If a username is provided, builds a USM auth configuration.
188    /// Otherwise, builds a community auth based on the version and community from common args.
189    pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
190        if let Some(ref username) = self.username {
191            let mut builder = Auth::usm(username);
192            if let Some(proto) = self.auth_protocol {
193                let pass = self
194                    .auth_password
195                    .as_ref()
196                    .ok_or("auth password required")?;
197                builder = builder.auth(proto, pass);
198            }
199            if let Some(proto) = self.priv_protocol {
200                let pass = self
201                    .priv_password
202                    .as_ref()
203                    .ok_or("priv password required")?;
204                builder = builder.privacy(proto, pass);
205            }
206            Ok(builder.into())
207        } else {
208            let community = &common.community;
209            Ok(match common.snmp_version {
210                SnmpVersion::V1 => Auth::v1(community),
211                _ => Auth::v2c(community),
212            })
213        }
214    }
215
216    /// Validate V3 arguments and return an error message if invalid.
217    pub fn validate(&self) -> Result<(), String> {
218        if let Some(ref _username) = self.username {
219            // If auth-protocol is specified, auth-password is required
220            if self.auth_protocol.is_some() && self.auth_password.is_none() {
221                return Err(
222                    "authentication password (-A) required when using auth protocol".into(),
223                );
224            }
225
226            // If priv-protocol is specified, priv-password is required
227            if self.priv_protocol.is_some() && self.priv_password.is_none() {
228                return Err("privacy password (-X) required when using priv protocol".into());
229            }
230
231            // Privacy requires authentication
232            if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
233                return Err("authentication protocol (-a) required when using privacy".into());
234            }
235        }
236        Ok(())
237    }
238}
239
/// Output control arguments.
// Maintainer notes use plain `//` comments: clap's derive turns the `///` doc comments into
// user-visible help text. Only `debug` and `trace` are consumed in this module (by
// `init_tracing`); the remaining flags are read by whichever tool embeds this struct.
#[derive(Debug, Parser)]
pub struct OutputArgs {
    /// Output format: human, json, or raw.
    // The default string "human" must match a value name of `OutputFormat`.
    #[arg(short = 'O', long = "output", default_value = "human")]
    pub format: OutputFormat,

    /// Show PDU structure and wire details.
    // Independent of the tracing flags below; see the note on `init_tracing`.
    #[arg(long = "verbose")]
    pub verbose: bool,

    /// Always display OctetString as hex.
    #[arg(long = "hex")]
    pub hex: bool,

    /// Show request timing.
    #[arg(long = "timing")]
    pub timing: bool,

    /// Disable well-known OID name hints.
    #[arg(long = "no-hints")]
    pub no_hints: bool,

    /// Enable debug logging (async_snmp=debug).
    // Ignored by `init_tracing` when `trace` is also set.
    #[arg(short = 'd', long = "debug")]
    pub debug: bool,

    /// Enable trace logging (async_snmp=trace).
    // Takes precedence over `debug` in `init_tracing`.
    #[arg(short = 'D', long = "trace")]
    pub trace: bool,
}
271
272impl OutputArgs {
273    /// Initialize tracing based on debug/trace flags.
274    ///
275    /// Note: --verbose is handled separately and shows structured request/response info.
276    /// Use -d/--debug for library-level tracing.
277    pub fn init_tracing(&self) {
278        use tracing_subscriber::EnvFilter;
279
280        let filter = if self.trace {
281            "async_snmp=trace"
282        } else if self.debug {
283            "async_snmp=debug"
284        } else {
285            "async_snmp=warn"
286        };
287
288        let _ = tracing_subscriber::fmt()
289            .with_env_filter(EnvFilter::new(filter))
290            .with_writer(std::io::stderr)
291            .try_init();
292    }
293}
294
/// Walk-specific arguments.
// Maintainer notes use plain `//` comments: clap's derive turns the `///` doc comments into
// user-visible help text. Neither field is consumed in this module.
#[derive(Debug, Parser)]
pub struct WalkArgs {
    /// Use GETNEXT instead of GETBULK.
    #[arg(long = "getnext")]
    pub getnext: bool,

    /// GETBULK max-repetitions.
    // Presumably irrelevant when `getnext` is set — confirm against the walk tool.
    #[arg(long = "max-rep", default_value = "10")]
    pub max_repetitions: u32,
}
306
/// Set-specific type specifier for values.
// Each variant is selected by a single-character name. The same characters are accepted by the
// `FromStr` impl in this module, so the two lists must be kept in sync. `parse_value` turns a
// textual value into the matching SNMP value for the chosen variant.
// The variant doc comments are left verbatim: clap's derive uses them as help text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum ValueType {
    /// INTEGER (i32)
    #[value(name = "i")]
    Integer,
    /// Unsigned32/Gauge32 (u32)
    // `parse_value` produces `Value::Gauge32` for this specifier.
    #[value(name = "u")]
    Unsigned,
    /// STRING (OctetString from UTF-8)
    #[value(name = "s")]
    String,
    /// Hex-STRING (OctetString from hex)
    #[value(name = "x")]
    HexString,
    /// OBJECT IDENTIFIER
    #[value(name = "o")]
    Oid,
    /// IpAddress
    #[value(name = "a")]
    IpAddress,
    /// TimeTicks
    #[value(name = "t")]
    TimeTicks,
    /// Counter32
    #[value(name = "c")]
    Counter32,
    /// Counter64
    // Upper-case "C": the specifier names are case-sensitive ("c" is Counter32).
    #[value(name = "C")]
    Counter64,
}
338
339impl std::str::FromStr for ValueType {
340    type Err = String;
341
342    fn from_str(s: &str) -> Result<Self, Self::Err> {
343        match s {
344            "i" => Ok(ValueType::Integer),
345            "u" => Ok(ValueType::Unsigned),
346            "s" => Ok(ValueType::String),
347            "x" => Ok(ValueType::HexString),
348            "o" => Ok(ValueType::Oid),
349            "a" => Ok(ValueType::IpAddress),
350            "t" => Ok(ValueType::TimeTicks),
351            "c" => Ok(ValueType::Counter32),
352            "C" => Ok(ValueType::Counter64),
353            _ => Err(format!("invalid type specifier: {}", s)),
354        }
355    }
356}
357
358impl ValueType {
359    /// Parse a string value into an SNMP Value according to the type specifier.
360    pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
361        use crate::{Oid, Value};
362
363        match self {
364            ValueType::Integer => {
365                let v: i32 = s
366                    .parse()
367                    .map_err(|_| format!("invalid integer value: {}", s))?;
368                Ok(Value::Integer(v))
369            }
370            ValueType::Unsigned => {
371                let v: u32 = s
372                    .parse()
373                    .map_err(|_| format!("invalid unsigned value: {}", s))?;
374                Ok(Value::Gauge32(v))
375            }
376            ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
377            ValueType::HexString => {
378                let bytes = parse_hex_string(s)?;
379                Ok(Value::OctetString(bytes.into()))
380            }
381            ValueType::Oid => {
382                let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
383                Ok(Value::ObjectIdentifier(oid))
384            }
385            ValueType::IpAddress => {
386                let parts: Vec<&str> = s.split('.').collect();
387                if parts.len() != 4 {
388                    return Err(format!("invalid IP address: {}", s));
389                }
390                let mut bytes = [0u8; 4];
391                for (i, part) in parts.iter().enumerate() {
392                    bytes[i] = part
393                        .parse()
394                        .map_err(|_| format!("invalid IP address octet: {}", part))?;
395                }
396                Ok(Value::IpAddress(bytes))
397            }
398            ValueType::TimeTicks => {
399                let v: u32 = s
400                    .parse()
401                    .map_err(|_| format!("invalid timeticks value: {}", s))?;
402                Ok(Value::TimeTicks(v))
403            }
404            ValueType::Counter32 => {
405                let v: u32 = s
406                    .parse()
407                    .map_err(|_| format!("invalid counter32 value: {}", s))?;
408                Ok(Value::Counter32(v))
409            }
410            ValueType::Counter64 => {
411                let v: u64 = s
412                    .parse()
413                    .map_err(|_| format!("invalid counter64 value: {}", s))?;
414                Ok(Value::Counter64(v))
415            }
416        }
417    }
418}
419
/// Parse a hex string (with or without separators) into bytes.
///
/// Whitespace, `:` and `-` between digits are skipped. Any other non-hex character is an
/// error: silently dropping it would let a typo such as `0g` or `zz` decode to different (or
/// empty) data than the user intended.
///
/// # Errors
///
/// * `invalid hex character: '<c>'` for a character that is neither a hex digit nor a separator;
/// * `hex string must have even number of digits` when a trailing half-byte is left over.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    let mut nibbles: Vec<u8> = Vec::with_capacity(s.len());
    for c in s.chars() {
        match c.to_digit(16) {
            // `to_digit(16)` yields 0..=15, so narrowing to `u8` cannot truncate.
            Some(digit) => nibbles.push(digit as u8),
            None if c.is_whitespace() || c == ':' || c == '-' => {}
            None => return Err(format!("invalid hex character: {:?}", c)),
        }
    }

    let pairs = nibbles.chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return Err("hex string must have even number of digits".into());
    }

    Ok(pairs.map(|pair| (pair[0] << 4) | pair[1]).collect())
}
438
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// Baseline v2c arguments for `target`; individual tests override fields with struct
    /// update syntax.
    fn base_args(target: &str) -> CommonArgs {
        CommonArgs {
            target: target.to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries: 3,
            backoff: BackoffStrategy::None,
            backoff_delay: 100,
            backoff_max: 5000,
            backoff_jitter: 0.25,
        }
    }

    /// V3 arguments carrying only the username `admin` (noAuthNoPriv).
    fn admin_v3() -> V3Args {
        V3Args {
            username: Some("admin".to_string()),
            level: None,
            auth_protocol: None,
            auth_password: None,
            priv_protocol: None,
            priv_password: None,
        }
    }

    #[test]
    fn test_target_addr_with_port() {
        let addr = base_args("192.168.1.1:162").target_addr().unwrap();
        assert_eq!(addr.port(), 162);
    }

    #[test]
    fn test_target_addr_default_port() {
        let addr = base_args("192.168.1.1").target_addr().unwrap();
        assert_eq!(addr.port(), 161);
    }

    #[test]
    fn test_retry_config_none() {
        let retry = base_args("192.168.1.1").retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let args = CommonArgs {
            retries: 5,
            backoff: BackoffStrategy::Fixed,
            backoff_delay: 200,
            ..base_args("192.168.1.1")
        };
        let retry = args.retry_config();
        assert_eq!(retry.max_attempts, 5);
        let Backoff::Fixed { delay } = retry.backoff else {
            panic!("expected Fixed");
        };
        assert_eq!(delay, Duration::from_millis(200));
    }

    #[test]
    fn test_retry_config_exponential() {
        let args = CommonArgs {
            retries: 4,
            backoff: BackoffStrategy::Exponential,
            backoff_delay: 50,
            backoff_max: 2000,
            backoff_jitter: 0.1,
            ..base_args("192.168.1.1")
        };
        let retry = args.retry_config();
        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential {
            initial,
            max,
            jitter,
        } = retry.backoff
        else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        // No username - valid (not v3)
        let not_v3 = V3Args {
            username: None,
            ..admin_v3()
        };
        assert!(not_v3.validate().is_ok());

        // Username only - valid (noAuthNoPriv)
        assert!(admin_v3().validate().is_ok());

        // Auth protocol without password - invalid
        let auth_without_password = V3Args {
            auth_protocol: Some(AuthProtocol::Sha256),
            ..admin_v3()
        };
        assert!(auth_without_password.validate().is_err());

        // Privacy without auth - invalid
        let privacy_without_auth = V3Args {
            priv_protocol: Some(PrivProtocol::Aes128),
            priv_password: Some("pass".to_string()),
            ..admin_v3()
        };
        assert!(privacy_without_auth.validate().is_err());

        // SHA-1 with AES-256 - valid (key extension auto-applied)
        let sha1_with_aes256 = V3Args {
            auth_protocol: Some(AuthProtocol::Sha1),
            auth_password: Some("pass".to_string()),
            priv_protocol: Some(PrivProtocol::Aes256),
            priv_password: Some("pass".to_string()),
            ..admin_v3()
        };
        assert!(sha1_with_aes256.validate().is_ok());
    }

    #[test]
    fn test_value_type_parse_integer() {
        assert!(matches!(
            ValueType::Integer.parse_value("42"),
            Ok(Value::Integer(42))
        ));
        assert!(matches!(
            ValueType::Integer.parse_value("-100"),
            Ok(Value::Integer(-100))
        ));
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        assert!(matches!(
            ValueType::Unsigned.parse_value("42"),
            Ok(Value::Gauge32(42))
        ));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let Value::OctetString(bytes) = ValueType::String.parse_value("hello world").unwrap()
        else {
            panic!("expected OctetString");
        };
        assert_eq!(&*bytes, b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        // Plain hex, then the same bytes written with spaces.
        for input in ["001a2b", "00 1A 2B"] {
            let Value::OctetString(bytes) = ValueType::HexString.parse_value(input).unwrap()
            else {
                panic!("expected OctetString");
            };
            assert_eq!(&*bytes, &[0x00, 0x1a, 0x2b]);
        }

        // Odd number of digits
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        assert!(matches!(
            ValueType::IpAddress.parse_value("192.168.1.1"),
            Ok(Value::IpAddress([192, 168, 1, 1]))
        ));

        for malformed in ["192.168.1", "256.1.1.1"] {
            assert!(ValueType::IpAddress.parse_value(malformed).is_err());
        }
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        assert!(matches!(
            ValueType::TimeTicks.parse_value("12345678"),
            Ok(Value::TimeTicks(12345678))
        ));
    }

    #[test]
    fn test_value_type_parse_counters() {
        assert!(matches!(
            ValueType::Counter32.parse_value("4294967295"),
            Ok(Value::Counter32(4294967295))
        ));
        assert!(matches!(
            ValueType::Counter64.parse_value("18446744073709551615"),
            Ok(Value::Counter64(18446744073709551615))
        ));
    }
}