async_snmp/cli/
args.rs

1//! Command-line argument structures for async-snmp CLI tools.
2//!
3//! This module provides reusable clap argument structures for the `asnmp-*` CLI tools.
4
5use clap::{Parser, ValueEnum};
6use std::net::SocketAddr;
7use std::time::Duration;
8
9use crate::Version;
10use crate::client::{Auth, Backoff, Retry};
11use crate::v3::{AuthProtocol, PrivProtocol};
12
13/// SNMP version for CLI argument parsing.
14#[derive(Debug, Clone, Copy, Default, ValueEnum)]
15pub enum SnmpVersion {
16    /// SNMPv1
17    #[value(name = "1")]
18    V1,
19    /// SNMPv2c (default)
20    #[default]
21    #[value(name = "2c")]
22    V2c,
23    /// SNMPv3
24    #[value(name = "3")]
25    V3,
26}
27
28impl From<SnmpVersion> for Version {
29    fn from(v: SnmpVersion) -> Self {
30        match v {
31            SnmpVersion::V1 => Version::V1,
32            SnmpVersion::V2c => Version::V2c,
33            SnmpVersion::V3 => Version::V3,
34        }
35    }
36}
37
38/// Output format for CLI tools.
39#[derive(Debug, Clone, Copy, Default, ValueEnum)]
40pub enum OutputFormat {
41    /// Human-readable output with type information.
42    #[default]
43    Human,
44    /// JSON output for scripting.
45    Json,
46    /// Raw tab-separated output for scripting.
47    Raw,
48}
49
50/// Backoff strategy for CLI argument parsing.
51#[derive(Debug, Clone, Copy, Default, ValueEnum)]
52pub enum BackoffStrategy {
53    /// No delay between retries (immediate retry on timeout).
54    #[default]
55    None,
56    /// Fixed delay between each retry.
57    Fixed,
58    /// Exponential backoff: delay doubles after each attempt.
59    Exponential,
60}
61
62/// Common arguments shared across all CLI tools.
63#[derive(Debug, Parser)]
64pub struct CommonArgs {
65    /// Target host or host:port (default port 161).
66    #[arg(value_name = "TARGET")]
67    pub target: String,
68
69    /// SNMP version: 1, 2c, or 3.
70    #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
71    pub snmp_version: SnmpVersion,
72
73    /// Community string (v1/v2c).
74    #[arg(short = 'c', long = "community", default_value = "public")]
75    pub community: String,
76
77    /// Request timeout in seconds.
78    #[arg(short = 't', long = "timeout", default_value = "5")]
79    pub timeout: f64,
80
81    /// Retry count.
82    #[arg(short = 'r', long = "retries", default_value = "3")]
83    pub retries: u32,
84
85    /// Backoff strategy between retries: none, fixed, or exponential.
86    #[arg(long = "backoff", default_value = "none")]
87    pub backoff: BackoffStrategy,
88
89    /// Backoff delay in milliseconds (initial delay for exponential, fixed delay otherwise).
90    #[arg(long = "backoff-delay", default_value = "1000")]
91    pub backoff_delay: u64,
92
93    /// Maximum backoff delay in milliseconds (exponential only).
94    #[arg(long = "backoff-max", default_value = "5000")]
95    pub backoff_max: u64,
96
97    /// Jitter factor for exponential backoff (0.0-1.0, e.g., 0.25 means +/-25%).
98    #[arg(long = "backoff-jitter", default_value = "0.25")]
99    pub backoff_jitter: f64,
100}
101
102impl CommonArgs {
103    /// Parse the target into a SocketAddr, defaulting to port 161.
104    pub fn target_addr(&self) -> Result<SocketAddr, String> {
105        // If the target doesn't contain a port, add the default SNMP port
106        let addr_str = if self.target.contains(':') {
107            self.target.clone()
108        } else {
109            format!("{}:161", self.target)
110        };
111
112        addr_str
113            .parse()
114            .or_else(|_| {
115                // Try to resolve as hostname
116                use std::net::ToSocketAddrs;
117                addr_str
118                    .to_socket_addrs()
119                    .map_err(|e| e.to_string())?
120                    .next()
121                    .ok_or_else(|| format!("could not resolve hostname: {}", self.target))
122            })
123            .map_err(|e| format!("invalid target '{}': {}", self.target, e))
124    }
125
126    /// Get the timeout as a Duration.
127    pub fn timeout_duration(&self) -> Duration {
128        Duration::from_secs_f64(self.timeout)
129    }
130
131    /// Build a Retry configuration from the CLI arguments.
132    pub fn retry_config(&self) -> Retry {
133        let backoff = match self.backoff {
134            BackoffStrategy::None => Backoff::None,
135            BackoffStrategy::Fixed => Backoff::Fixed {
136                delay: Duration::from_millis(self.backoff_delay),
137            },
138            BackoffStrategy::Exponential => Backoff::Exponential {
139                initial: Duration::from_millis(self.backoff_delay),
140                max: Duration::from_millis(self.backoff_max),
141                jitter: self.backoff_jitter.clamp(0.0, 1.0),
142            },
143        };
144        Retry {
145            max_attempts: self.retries,
146            backoff,
147        }
148    }
149}
150
151/// SNMPv3 security arguments.
152#[derive(Debug, Parser)]
153pub struct V3Args {
154    /// Security name/username (implies -v 3).
155    #[arg(short = 'u', long = "username")]
156    pub username: Option<String>,
157
158    /// Security level: noAuthNoPriv, authNoPriv, or authPriv.
159    #[arg(short = 'l', long = "level")]
160    pub level: Option<String>,
161
162    /// Authentication protocol: MD5, SHA, SHA-224, SHA-256, SHA-384, SHA-512.
163    #[arg(short = 'a', long = "auth-protocol")]
164    pub auth_protocol: Option<AuthProtocol>,
165
166    /// Authentication passphrase.
167    #[arg(short = 'A', long = "auth-password")]
168    pub auth_password: Option<String>,
169
170    /// Privacy protocol: DES, AES, AES-128, AES-192, AES-256.
171    #[arg(short = 'x', long = "priv-protocol")]
172    pub priv_protocol: Option<PrivProtocol>,
173
174    /// Privacy passphrase.
175    #[arg(short = 'X', long = "priv-password")]
176    pub priv_password: Option<String>,
177}
178
179impl V3Args {
180    /// Check if V3 mode is enabled (username provided).
181    pub fn is_v3(&self) -> bool {
182        self.username.is_some()
183    }
184
185    /// Build an Auth configuration from the V3 args and common args.
186    ///
187    /// If a username is provided, builds a USM auth configuration.
188    /// Otherwise, builds a community auth based on the version and community from common args.
189    pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
190        if let Some(ref username) = self.username {
191            let mut builder = Auth::usm(username);
192            if let Some(proto) = self.auth_protocol {
193                let pass = self
194                    .auth_password
195                    .as_ref()
196                    .ok_or("auth password required")?;
197                builder = builder.auth(proto, pass);
198            }
199            if let Some(proto) = self.priv_protocol {
200                let pass = self
201                    .priv_password
202                    .as_ref()
203                    .ok_or("priv password required")?;
204                builder = builder.privacy(proto, pass);
205            }
206            Ok(builder.into())
207        } else {
208            let community = &common.community;
209            Ok(match common.snmp_version {
210                SnmpVersion::V1 => Auth::v1(community),
211                _ => Auth::v2c(community),
212            })
213        }
214    }
215
216    /// Validate V3 arguments and return an error message if invalid.
217    pub fn validate(&self) -> Result<(), String> {
218        if let Some(ref _username) = self.username {
219            // If auth-protocol is specified, auth-password is required
220            if self.auth_protocol.is_some() && self.auth_password.is_none() {
221                return Err(
222                    "authentication password (-A) required when using auth protocol".into(),
223                );
224            }
225
226            // If priv-protocol is specified, priv-password is required
227            if self.priv_protocol.is_some() && self.priv_password.is_none() {
228                return Err("privacy password (-X) required when using priv protocol".into());
229            }
230
231            // Privacy requires authentication
232            if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
233                return Err("authentication protocol (-a) required when using privacy".into());
234            }
235
236            // Check auth/priv compatibility
237            if let (Some(auth), Some(priv_proto)) = (self.auth_protocol, self.priv_protocol)
238                && !auth.is_compatible_with(priv_proto)
239            {
240                return Err(format!(
241                    "{} authentication does not produce enough key material for {} privacy; use {} or stronger",
242                    auth,
243                    priv_proto,
244                    priv_proto.min_auth_protocol()
245                ));
246            }
247        }
248        Ok(())
249    }
250}
251
252/// Output control arguments.
253#[derive(Debug, Parser)]
254pub struct OutputArgs {
255    /// Output format: human, json, or raw.
256    #[arg(short = 'O', long = "output", default_value = "human")]
257    pub format: OutputFormat,
258
259    /// Show PDU structure and wire details.
260    #[arg(long = "verbose")]
261    pub verbose: bool,
262
263    /// Always display OctetString as hex.
264    #[arg(long = "hex")]
265    pub hex: bool,
266
267    /// Show request timing.
268    #[arg(long = "timing")]
269    pub timing: bool,
270
271    /// Disable well-known OID name hints.
272    #[arg(long = "no-hints")]
273    pub no_hints: bool,
274
275    /// Enable debug logging (async_snmp=debug).
276    #[arg(short = 'd', long = "debug")]
277    pub debug: bool,
278
279    /// Enable trace logging (async_snmp=trace).
280    #[arg(short = 'D', long = "trace")]
281    pub trace: bool,
282}
283
284impl OutputArgs {
285    /// Initialize tracing based on debug/trace flags.
286    ///
287    /// Note: --verbose is handled separately and shows structured request/response info.
288    /// Use -d/--debug for library-level tracing.
289    pub fn init_tracing(&self) {
290        use tracing_subscriber::EnvFilter;
291
292        let filter = if self.trace {
293            "async_snmp=trace"
294        } else if self.debug {
295            "async_snmp=debug"
296        } else {
297            "async_snmp=warn"
298        };
299
300        let _ = tracing_subscriber::fmt()
301            .with_env_filter(EnvFilter::new(filter))
302            .with_writer(std::io::stderr)
303            .try_init();
304    }
305}
306
307/// Walk-specific arguments.
308#[derive(Debug, Parser)]
309pub struct WalkArgs {
310    /// Use GETNEXT instead of GETBULK.
311    #[arg(long = "getnext")]
312    pub getnext: bool,
313
314    /// GETBULK max-repetitions.
315    #[arg(long = "max-rep", default_value = "10")]
316    pub max_repetitions: u32,
317}
318
319/// Set-specific type specifier for values.
320#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
321pub enum ValueType {
322    /// INTEGER (i32)
323    #[value(name = "i")]
324    Integer,
325    /// Unsigned32/Gauge32 (u32)
326    #[value(name = "u")]
327    Unsigned,
328    /// STRING (OctetString from UTF-8)
329    #[value(name = "s")]
330    String,
331    /// Hex-STRING (OctetString from hex)
332    #[value(name = "x")]
333    HexString,
334    /// OBJECT IDENTIFIER
335    #[value(name = "o")]
336    Oid,
337    /// IpAddress
338    #[value(name = "a")]
339    IpAddress,
340    /// TimeTicks
341    #[value(name = "t")]
342    TimeTicks,
343    /// Counter32
344    #[value(name = "c")]
345    Counter32,
346    /// Counter64
347    #[value(name = "C")]
348    Counter64,
349}
350
351impl std::str::FromStr for ValueType {
352    type Err = String;
353
354    fn from_str(s: &str) -> Result<Self, Self::Err> {
355        match s {
356            "i" => Ok(ValueType::Integer),
357            "u" => Ok(ValueType::Unsigned),
358            "s" => Ok(ValueType::String),
359            "x" => Ok(ValueType::HexString),
360            "o" => Ok(ValueType::Oid),
361            "a" => Ok(ValueType::IpAddress),
362            "t" => Ok(ValueType::TimeTicks),
363            "c" => Ok(ValueType::Counter32),
364            "C" => Ok(ValueType::Counter64),
365            _ => Err(format!("invalid type specifier: {}", s)),
366        }
367    }
368}
369
370impl ValueType {
371    /// Parse a string value into an SNMP Value according to the type specifier.
372    pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
373        use crate::{Oid, Value};
374
375        match self {
376            ValueType::Integer => {
377                let v: i32 = s
378                    .parse()
379                    .map_err(|_| format!("invalid integer value: {}", s))?;
380                Ok(Value::Integer(v))
381            }
382            ValueType::Unsigned => {
383                let v: u32 = s
384                    .parse()
385                    .map_err(|_| format!("invalid unsigned value: {}", s))?;
386                Ok(Value::Gauge32(v))
387            }
388            ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
389            ValueType::HexString => {
390                let bytes = parse_hex_string(s)?;
391                Ok(Value::OctetString(bytes.into()))
392            }
393            ValueType::Oid => {
394                let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
395                Ok(Value::ObjectIdentifier(oid))
396            }
397            ValueType::IpAddress => {
398                let parts: Vec<&str> = s.split('.').collect();
399                if parts.len() != 4 {
400                    return Err(format!("invalid IP address: {}", s));
401                }
402                let mut bytes = [0u8; 4];
403                for (i, part) in parts.iter().enumerate() {
404                    bytes[i] = part
405                        .parse()
406                        .map_err(|_| format!("invalid IP address octet: {}", part))?;
407                }
408                Ok(Value::IpAddress(bytes))
409            }
410            ValueType::TimeTicks => {
411                let v: u32 = s
412                    .parse()
413                    .map_err(|_| format!("invalid timeticks value: {}", s))?;
414                Ok(Value::TimeTicks(v))
415            }
416            ValueType::Counter32 => {
417                let v: u32 = s
418                    .parse()
419                    .map_err(|_| format!("invalid counter32 value: {}", s))?;
420                Ok(Value::Counter32(v))
421            }
422            ValueType::Counter64 => {
423                let v: u64 = s
424                    .parse()
425                    .map_err(|_| format!("invalid counter64 value: {}", s))?;
426                Ok(Value::Counter64(v))
427            }
428        }
429    }
430}
431
/// Parse a hex string into bytes.
///
/// Each byte is written as two hex digits, in either case. Whitespace, `:` and
/// `-` may appear anywhere as separators. A single leading `0x`/`0X` prefix is
/// also accepted. For example, `001a2b`, `00 1A 2B`, `00:1a:2b` and `0x001a2b`
/// all parse to `[0x00, 0x1a, 0x2b]`. An empty string yields no bytes.
///
/// # Errors
///
/// Returns an error if the input contains any other character (for example a
/// typo such as `zz11`), or if the number of hex digits is odd.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    let trimmed = s.trim();
    let body = trimmed
        .strip_prefix("0x")
        .or_else(|| trimmed.strip_prefix("0X"))
        .unwrap_or(trimmed);

    // Collect nibble values. Reject anything that is neither a hex digit nor a
    // recognised separator, so malformed input is reported rather than silently
    // dropped.
    let mut nibbles = Vec::with_capacity(body.len());
    for c in body.chars() {
        if let Some(nibble) = c.to_digit(16) {
            // to_digit(16) yields 0..=15, so the cast is lossless.
            nibbles.push(nibble as u8);
        } else if !(c.is_whitespace() || c == ':' || c == '-') {
            return Err(format!("invalid hex character {:?} in: {}", c, s));
        }
    }

    if !nibbles.len().is_multiple_of(2) {
        return Err("hex string must have even number of digits".into());
    }

    Ok(nibbles
        .chunks_exact(2)
        .map(|pair| (pair[0] << 4) | pair[1])
        .collect())
}
450
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// Baseline v2c arguments; each test overrides only the fields it exercises.
    fn common_args(target: &str) -> CommonArgs {
        CommonArgs {
            target: target.to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries: 3,
            backoff: BackoffStrategy::None,
            backoff_delay: 100,
            backoff_max: 5000,
            backoff_jitter: 0.25,
        }
    }

    /// Build `V3Args` with no explicit security level.
    fn v3_args(
        username: Option<&str>,
        auth_protocol: Option<AuthProtocol>,
        auth_password: Option<&str>,
        priv_protocol: Option<PrivProtocol>,
        priv_password: Option<&str>,
    ) -> V3Args {
        V3Args {
            username: username.map(str::to_string),
            level: None,
            auth_protocol,
            auth_password: auth_password.map(str::to_string),
            priv_protocol,
            priv_password: priv_password.map(str::to_string),
        }
    }

    #[test]
    fn test_target_addr_with_port() {
        let addr = common_args("192.168.1.1:162").target_addr().unwrap();
        assert_eq!(addr.port(), 162);
    }

    #[test]
    fn test_target_addr_default_port() {
        let addr = common_args("192.168.1.1").target_addr().unwrap();
        assert_eq!(addr.port(), 161);
    }

    #[test]
    fn test_retry_config_none() {
        let retry = common_args("192.168.1.1").retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let args = CommonArgs {
            retries: 5,
            backoff: BackoffStrategy::Fixed,
            backoff_delay: 200,
            ..common_args("192.168.1.1")
        };
        let retry = args.retry_config();
        assert_eq!(retry.max_attempts, 5);
        assert!(matches!(
            retry.backoff,
            Backoff::Fixed { delay } if delay == Duration::from_millis(200)
        ));
    }

    #[test]
    fn test_retry_config_exponential() {
        let args = CommonArgs {
            retries: 4,
            backoff: BackoffStrategy::Exponential,
            backoff_delay: 50,
            backoff_max: 2000,
            backoff_jitter: 0.1,
            ..common_args("192.168.1.1")
        };
        let retry = args.retry_config();
        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential {
            initial,
            max,
            jitter,
        } = retry.backoff
        else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        // No username - valid (not v3)
        assert!(v3_args(None, None, None, None, None).validate().is_ok());

        // Username only - valid (noAuthNoPriv)
        assert!(
            v3_args(Some("admin"), None, None, None, None)
                .validate()
                .is_ok()
        );

        // Auth protocol without password - invalid
        assert!(
            v3_args(Some("admin"), Some(AuthProtocol::Sha256), None, None, None)
                .validate()
                .is_err()
        );

        // Privacy without auth - invalid
        assert!(
            v3_args(
                Some("admin"),
                None,
                None,
                Some(PrivProtocol::Aes128),
                Some("pass"),
            )
            .validate()
            .is_err()
        );

        // Incompatible auth/priv - invalid (SHA1 with AES256)
        assert!(
            v3_args(
                Some("admin"),
                Some(AuthProtocol::Sha1),
                Some("pass"),
                Some(PrivProtocol::Aes256),
                Some("pass"),
            )
            .validate()
            .is_err()
        );
    }

    #[test]
    fn test_value_type_parse_integer() {
        let v = ValueType::Integer.parse_value("42").unwrap();
        assert!(matches!(v, Value::Integer(42)));

        let v = ValueType::Integer.parse_value("-100").unwrap();
        assert!(matches!(v, Value::Integer(-100)));

        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        let v = ValueType::Unsigned.parse_value("42").unwrap();
        assert!(matches!(v, Value::Gauge32(42)));

        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let v = ValueType::String.parse_value("hello world").unwrap();
        let Value::OctetString(bytes) = v else {
            panic!("expected OctetString");
        };
        assert_eq!(&*bytes, b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        // Plain hex
        let v = ValueType::HexString.parse_value("001a2b").unwrap();
        let Value::OctetString(bytes) = v else {
            panic!("expected OctetString");
        };
        assert_eq!(&*bytes, &[0x00, 0x1a, 0x2b]);

        // With spaces
        let v = ValueType::HexString.parse_value("00 1A 2B").unwrap();
        let Value::OctetString(bytes) = v else {
            panic!("expected OctetString");
        };
        assert_eq!(&*bytes, &[0x00, 0x1a, 0x2b]);

        // Odd number of digits
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        let v = ValueType::IpAddress.parse_value("192.168.1.1").unwrap();
        assert!(matches!(v, Value::IpAddress([192, 168, 1, 1])));

        assert!(ValueType::IpAddress.parse_value("192.168.1").is_err());
        assert!(ValueType::IpAddress.parse_value("256.1.1.1").is_err());
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        let v = ValueType::TimeTicks.parse_value("12345678").unwrap();
        assert!(matches!(v, Value::TimeTicks(12345678)));
    }

    #[test]
    fn test_value_type_parse_counters() {
        let v = ValueType::Counter32.parse_value("4294967295").unwrap();
        assert!(matches!(v, Value::Counter32(4294967295)));

        let v = ValueType::Counter64
            .parse_value("18446744073709551615")
            .unwrap();
        assert!(matches!(v, Value::Counter64(18446744073709551615)));
    }
}