Skip to main content

async_snmp/cli/
args.rs

1//! Command-line argument structures for async-snmp CLI tools.
2//!
3//! This module provides reusable clap argument structures for the `asnmp-*` CLI tools.
4
5use clap::{Parser, ValueEnum};
6use std::time::Duration;
7
8use crate::Version;
9use crate::client::{Auth, Backoff, Retry};
10use crate::format::hex;
11use crate::v3::{AuthProtocol, PrivProtocol};
12
/// SNMP version for CLI argument parsing.
// Parsed from `-v/--snmp-version` in `CommonArgs`. The `#[value(name = ...)]`
// strings are the exact tokens accepted on the command line ("1", "2c", "3").
// Converted to the library's `Version` through the `From` impl that follows.
// NOTE: clap's derive typically surfaces the `///` variant docs as help text,
// so they are intentionally left untouched.
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum SnmpVersion {
    /// SNMPv1
    #[value(name = "1")]
    V1,
    /// SNMPv2c (default)
    // Kept in agreement with `default_value = "2c"` on `CommonArgs::snmp_version`.
    #[default]
    #[value(name = "2c")]
    V2c,
    /// SNMPv3
    #[value(name = "3")]
    V3,
}
27
28impl From<SnmpVersion> for Version {
29    fn from(v: SnmpVersion) -> Self {
30        match v {
31            SnmpVersion::V1 => Version::V1,
32            SnmpVersion::V2c => Version::V2c,
33            SnmpVersion::V3 => Version::V3,
34        }
35    }
36}
37
/// Output format for CLI tools.
// Selected with `-O/--output` on `OutputArgs`, whose `default_value = "human"`
// matches the `#[default]` variant here. Accepted tokens are derived by clap
// from the variant names (presumably `human`, `json`, `raw` — as listed in the
// `--output` help text).
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum OutputFormat {
    /// Human-readable output with type information.
    #[default]
    Human,
    /// JSON output for scripting.
    Json,
    /// Raw tab-separated output for scripting.
    Raw,
}
49
/// Backoff strategy for CLI argument parsing.
// Selected with `--backoff` on `CommonArgs`. `CommonArgs::retry_config` maps each
// variant onto `crate::client::Backoff`, pulling the delay, max and jitter values
// from the `--backoff-delay`, `--backoff-max` and `--backoff-jitter` flags.
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum BackoffStrategy {
    /// No delay between retries (immediate retry on timeout).
    #[default]
    None,
    /// Fixed delay between each retry.
    Fixed,
    /// Exponential backoff: delay doubles after each attempt.
    Exponential,
}
61
62/// Common arguments shared across all CLI tools.
63#[derive(Debug, Parser)]
64pub struct CommonArgs {
65    /// Target host or host:port (default port 161).
66    #[arg(value_name = "TARGET")]
67    pub target: String,
68
69    /// SNMP version: 1, 2c, or 3.
70    #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
71    pub snmp_version: SnmpVersion,
72
73    /// Community string (v1/v2c).
74    #[arg(short = 'c', long = "community", default_value = "public")]
75    pub community: String,
76
77    /// Request timeout in seconds.
78    #[arg(short = 't', long = "timeout", default_value = "5")]
79    pub timeout: f64,
80
81    /// Retry count.
82    #[arg(short = 'r', long = "retries", default_value = "3")]
83    pub retries: u32,
84
85    /// Backoff strategy between retries: none, fixed, or exponential.
86    #[arg(long = "backoff", default_value = "none")]
87    pub backoff: BackoffStrategy,
88
89    /// Backoff delay in milliseconds (initial delay for exponential, fixed delay otherwise).
90    #[arg(long = "backoff-delay", default_value = "1000")]
91    pub backoff_delay: u64,
92
93    /// Maximum backoff delay in milliseconds (exponential only).
94    #[arg(long = "backoff-max", default_value = "5000")]
95    pub backoff_max: u64,
96
97    /// Jitter factor for exponential backoff (0.0-1.0, e.g., 0.25 means +/-25%).
98    #[arg(long = "backoff-jitter", default_value = "0.25")]
99    pub backoff_jitter: f64,
100}
101
102impl CommonArgs {
103    /// Get the timeout as a Duration.
104    pub fn timeout_duration(&self) -> Duration {
105        Duration::from_secs_f64(self.timeout)
106    }
107
108    /// Resolve the effective SNMP version, upgrading to V3 when a username is set.
109    pub fn effective_version(&self, v3: &V3Args) -> SnmpVersion {
110        if v3.is_v3() {
111            SnmpVersion::V3
112        } else {
113            self.snmp_version
114        }
115    }
116
117    /// Build a Retry configuration from the CLI arguments.
118    pub fn retry_config(&self) -> Retry {
119        let backoff = match self.backoff {
120            BackoffStrategy::None => Backoff::None,
121            BackoffStrategy::Fixed => Backoff::Fixed {
122                delay: Duration::from_millis(self.backoff_delay),
123            },
124            BackoffStrategy::Exponential => Backoff::Exponential {
125                initial: Duration::from_millis(self.backoff_delay),
126                max: Duration::from_millis(self.backoff_max),
127                jitter: self.backoff_jitter.clamp(0.0, 1.0),
128            },
129        };
130        Retry {
131            max_attempts: self.retries,
132            backoff,
133        }
134    }
135}
136
/// SNMPv3 security arguments.
// Every field is optional. Only `username` switches a tool into v3 mode
// (`V3Args::is_v3`, `CommonArgs::effective_version`). Valid combinations are
// checked by `V3Args::validate`: a protocol needs its password, and privacy
// requires authentication.
// NOTE(review): passphrases supplied as command-line arguments are typically
// visible in shell history and process listings — consider an env/file option.
#[derive(Debug, Parser)]
pub struct V3Args {
    /// Security name/username (implies -v 3).
    #[arg(short = 'u', long = "username")]
    pub username: Option<String>,

    /// Authentication protocol: MD5, SHA, SHA-224, SHA-256, SHA-384, SHA-512.
    #[arg(short = 'a', long = "auth-protocol")]
    pub auth_protocol: Option<AuthProtocol>,

    /// Authentication passphrase.
    #[arg(short = 'A', long = "auth-password")]
    pub auth_password: Option<String>,

    /// Privacy protocol: DES, AES, AES-128, AES-192, AES-256.
    #[arg(short = 'x', long = "priv-protocol")]
    pub priv_protocol: Option<PrivProtocol>,

    /// Privacy passphrase.
    #[arg(short = 'X', long = "priv-password")]
    pub priv_password: Option<String>,
}
160
161impl V3Args {
162    /// Check if V3 mode is enabled (username provided).
163    pub fn is_v3(&self) -> bool {
164        self.username.is_some()
165    }
166
167    /// Build an Auth configuration from the V3 args and common args.
168    ///
169    /// If a username is provided, builds a USM auth configuration.
170    /// Otherwise, builds a community auth based on the version and community from common args.
171    pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
172        if let Some(ref username) = self.username {
173            let mut builder = Auth::usm(username);
174            if let Some(proto) = self.auth_protocol {
175                let pass = self
176                    .auth_password
177                    .as_ref()
178                    .ok_or("auth password required")?;
179                builder = builder.auth(proto, pass);
180            }
181            if let Some(proto) = self.priv_protocol {
182                let pass = self
183                    .priv_password
184                    .as_ref()
185                    .ok_or("priv password required")?;
186                builder = builder.privacy(proto, pass);
187            }
188            Ok(builder.into())
189        } else {
190            let community = &common.community;
191            Ok(match common.snmp_version {
192                SnmpVersion::V1 => Auth::v1(community),
193                _ => Auth::v2c(community),
194            })
195        }
196    }
197
198    /// Validate V3 arguments and return an error message if invalid.
199    pub fn validate(&self) -> Result<(), String> {
200        if let Some(ref _username) = self.username {
201            // If auth-protocol is specified, auth-password is required
202            if self.auth_protocol.is_some() && self.auth_password.is_none() {
203                return Err(
204                    "authentication password (-A) required when using auth protocol".into(),
205                );
206            }
207
208            // If priv-protocol is specified, priv-password is required
209            if self.priv_protocol.is_some() && self.priv_password.is_none() {
210                return Err("privacy password (-X) required when using priv protocol".into());
211            }
212
213            // Privacy requires authentication
214            if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
215                return Err("authentication protocol (-a) required when using privacy".into());
216            }
217        }
218        Ok(())
219    }
220}
221
/// Output control arguments.
// Logging: `-D/--trace` takes precedence over `-d/--debug` (see
// `OutputArgs::init_tracing`). `--timing` is consumed via `OutputArgs::elapsed`.
// The other flags are only read by the individual tools, outside this file.
#[derive(Debug, Parser)]
pub struct OutputArgs {
    /// Output format: human, json, or raw.
    #[arg(short = 'O', long = "output", default_value = "human")]
    pub format: OutputFormat,

    /// Show PDU structure and wire details.
    #[arg(long = "verbose")]
    pub verbose: bool,

    /// Always display OctetString as hex.
    #[arg(long = "hex")]
    pub hex: bool,

    /// Show request timing.
    #[arg(long = "timing")]
    pub timing: bool,

    /// Disable well-known OID name hints.
    #[arg(long = "no-hints")]
    pub no_hints: bool,

    /// Enable debug logging (async_snmp=debug).
    #[arg(short = 'd', long = "debug")]
    pub debug: bool,

    /// Enable trace logging (async_snmp=trace).
    #[arg(short = 'D', long = "trace")]
    pub trace: bool,
}
253
254impl OutputArgs {
255    /// Return elapsed as Some if timing output is enabled, None otherwise.
256    pub fn elapsed(&self, elapsed: Duration) -> Option<Duration> {
257        if self.timing { Some(elapsed) } else { None }
258    }
259
260    /// Initialize tracing based on debug/trace flags.
261    ///
262    /// Note: --verbose is handled separately and shows structured request/response info.
263    /// Use -d/--debug for library-level tracing.
264    pub fn init_tracing(&self) {
265        use tracing_subscriber::EnvFilter;
266
267        let filter = if self.trace {
268            "async_snmp=trace"
269        } else if self.debug {
270            "async_snmp=debug"
271        } else {
272            "async_snmp=warn"
273        };
274
275        let _ = tracing_subscriber::fmt()
276            .with_env_filter(EnvFilter::new(filter))
277            .with_writer(std::io::stderr)
278            .try_init();
279    }
280}
281
/// Walk-specific arguments.
// NOTE(review): `--max-rep` has no lower bound, so 0 is accepted; confirm how the
// walk implementation handles a GETBULK max-repetitions of 0 (an agent may return
// no varbinds for it). It is presumably ignored when `--getnext` is set — verify.
#[derive(Debug, Parser)]
pub struct WalkArgs {
    /// Use GETNEXT instead of GETBULK.
    #[arg(long = "getnext")]
    pub getnext: bool,

    /// GETBULK max-repetitions.
    #[arg(long = "max-rep", default_value = "10")]
    pub max_repetitions: u32,
}
293
/// Set-specific type specifier for values.
// Single-character codes. The manual `FromStr` impl below accepts the same codes and
// matches them case-sensitively (`c` = Counter32, `C` = Counter64); keep the two in sync.
// `ValueType::parse_value` turns a raw string into the corresponding `Value`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum ValueType {
    /// INTEGER (i32)
    #[value(name = "i")]
    Integer,
    /// Unsigned32/Gauge32 (u32)
    // Encoded as `Value::Gauge32` by `parse_value`.
    #[value(name = "u")]
    Unsigned,
    /// STRING (OctetString from UTF-8)
    #[value(name = "s")]
    String,
    /// Hex-STRING (OctetString from hex)
    #[value(name = "x")]
    HexString,
    /// OBJECT IDENTIFIER
    #[value(name = "o")]
    Oid,
    /// IpAddress
    #[value(name = "a")]
    IpAddress,
    /// TimeTicks
    #[value(name = "t")]
    TimeTicks,
    /// Counter32
    #[value(name = "c")]
    Counter32,
    /// Counter64
    #[value(name = "C")]
    Counter64,
}
325
326impl std::str::FromStr for ValueType {
327    type Err = String;
328
329    fn from_str(s: &str) -> Result<Self, Self::Err> {
330        match s {
331            "i" => Ok(ValueType::Integer),
332            "u" => Ok(ValueType::Unsigned),
333            "s" => Ok(ValueType::String),
334            "x" => Ok(ValueType::HexString),
335            "o" => Ok(ValueType::Oid),
336            "a" => Ok(ValueType::IpAddress),
337            "t" => Ok(ValueType::TimeTicks),
338            "c" => Ok(ValueType::Counter32),
339            "C" => Ok(ValueType::Counter64),
340            _ => Err(format!("invalid type specifier: {}", s)),
341        }
342    }
343}
344
345impl ValueType {
346    /// Parse a string value into an SNMP Value according to the type specifier.
347    pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
348        use crate::{Oid, Value};
349
350        match self {
351            ValueType::Integer => {
352                let v: i32 = s
353                    .parse()
354                    .map_err(|_| format!("invalid integer value: {}", s))?;
355                Ok(Value::Integer(v))
356            }
357            ValueType::Unsigned => {
358                let v: u32 = s
359                    .parse()
360                    .map_err(|_| format!("invalid unsigned value: {}", s))?;
361                Ok(Value::Gauge32(v))
362            }
363            ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
364            ValueType::HexString => {
365                let bytes = hex::decode_relaxed(s)
366                    .map_err(|_| "hex string must have even number of hex digits".to_string())?;
367                Ok(Value::OctetString(bytes.into()))
368            }
369            ValueType::Oid => {
370                let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
371                Ok(Value::ObjectIdentifier(oid))
372            }
373            ValueType::IpAddress => {
374                let parts: Vec<&str> = s.split('.').collect();
375                if parts.len() != 4 {
376                    return Err(format!("invalid IP address: {}", s));
377                }
378                let mut bytes = [0u8; 4];
379                for (i, part) in parts.iter().enumerate() {
380                    bytes[i] = part
381                        .parse()
382                        .map_err(|_| format!("invalid IP address octet: {}", part))?;
383                }
384                Ok(Value::IpAddress(bytes))
385            }
386            ValueType::TimeTicks => {
387                let v: u32 = s
388                    .parse()
389                    .map_err(|_| format!("invalid timeticks value: {}", s))?;
390                Ok(Value::TimeTicks(v))
391            }
392            ValueType::Counter32 => {
393                let v: u32 = s
394                    .parse()
395                    .map_err(|_| format!("invalid counter32 value: {}", s))?;
396                Ok(Value::Counter32(v))
397            }
398            ValueType::Counter64 => {
399                let v: u64 = s
400                    .parse()
401                    .map_err(|_| format!("invalid counter64 value: {}", s))?;
402                Ok(Value::Counter64(v))
403            }
404        }
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// `CommonArgs` for a fixed v2c target where only the retry/backoff knobs vary.
    fn common_args(
        retries: u32,
        backoff: BackoffStrategy,
        backoff_delay: u64,
        backoff_max: u64,
        backoff_jitter: f64,
    ) -> CommonArgs {
        CommonArgs {
            target: "192.168.1.1".to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries,
            backoff,
            backoff_delay,
            backoff_max,
            backoff_jitter,
        }
    }

    /// `V3Args` built from borrowed strings, to keep the validation table compact.
    fn v3_args(
        username: Option<&str>,
        auth_protocol: Option<AuthProtocol>,
        auth_password: Option<&str>,
        priv_protocol: Option<PrivProtocol>,
        priv_password: Option<&str>,
    ) -> V3Args {
        V3Args {
            username: username.map(str::to_string),
            auth_protocol,
            auth_password: auth_password.map(str::to_string),
            priv_protocol,
            priv_password: priv_password.map(str::to_string),
        }
    }

    /// Extract the payload of an `OctetString`, failing the test for any other variant.
    fn octets(value: Value) -> Vec<u8> {
        match value {
            Value::OctetString(bytes) => bytes.to_vec(),
            _ => panic!("expected OctetString"),
        }
    }

    #[test]
    fn test_retry_config_none() {
        let retry = common_args(3, BackoffStrategy::None, 100, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let retry = common_args(5, BackoffStrategy::Fixed, 200, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 5);
        assert!(matches!(
            retry.backoff,
            Backoff::Fixed { delay } if delay == Duration::from_millis(200)
        ));
    }

    #[test]
    fn test_retry_config_exponential() {
        let retry = common_args(4, BackoffStrategy::Exponential, 50, 2000, 0.1).retry_config();
        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential {
            initial,
            max,
            jitter,
        } = retry.backoff
        else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        // Each case pairs the arguments with whether `validate` should accept them.
        let cases = [
            // No username - valid (not v3)
            (v3_args(None, None, None, None, None), true),
            // Username only - valid (noAuthNoPriv)
            (v3_args(Some("admin"), None, None, None, None), true),
            // Auth protocol without password - invalid
            (
                v3_args(Some("admin"), Some(AuthProtocol::Sha256), None, None, None),
                false,
            ),
            // Privacy without auth - invalid
            (
                v3_args(
                    Some("admin"),
                    None,
                    None,
                    Some(PrivProtocol::Aes128),
                    Some("pass"),
                ),
                false,
            ),
            // SHA-1 with AES-256 - valid (key extension auto-applied)
            (
                v3_args(
                    Some("admin"),
                    Some(AuthProtocol::Sha1),
                    Some("pass"),
                    Some(PrivProtocol::Aes256),
                    Some("pass"),
                ),
                true,
            ),
        ];

        for (args, expect_valid) in cases {
            assert_eq!(args.validate().is_ok(), expect_valid, "{:?}", args);
        }
    }

    #[test]
    fn test_value_type_parse_integer() {
        assert!(matches!(
            ValueType::Integer.parse_value("42").unwrap(),
            Value::Integer(42)
        ));
        assert!(matches!(
            ValueType::Integer.parse_value("-100").unwrap(),
            Value::Integer(-100)
        ));
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        assert!(matches!(
            ValueType::Unsigned.parse_value("42").unwrap(),
            Value::Gauge32(42)
        ));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let parsed = ValueType::String.parse_value("hello world").unwrap();
        assert_eq!(octets(parsed), b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        // Plain hex, then the same bytes with spaces and mixed case.
        for input in ["001a2b", "00 1A 2B"] {
            let parsed = ValueType::HexString.parse_value(input).unwrap();
            assert_eq!(octets(parsed), [0x00_u8, 0x1a, 0x2b]);
        }

        // Odd number of digits
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        assert!(matches!(
            ValueType::IpAddress.parse_value("192.168.1.1").unwrap(),
            Value::IpAddress([192, 168, 1, 1])
        ));
        assert!(ValueType::IpAddress.parse_value("192.168.1").is_err());
        assert!(ValueType::IpAddress.parse_value("256.1.1.1").is_err());
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        assert!(matches!(
            ValueType::TimeTicks.parse_value("12345678").unwrap(),
            Value::TimeTicks(12345678)
        ));
    }

    #[test]
    fn test_value_type_parse_counters() {
        assert!(matches!(
            ValueType::Counter32.parse_value("4294967295").unwrap(),
            Value::Counter32(4294967295)
        ));
        assert!(matches!(
            ValueType::Counter64
                .parse_value("18446744073709551615")
                .unwrap(),
            Value::Counter64(18446744073709551615)
        ));
    }
}