Skip to main content

async_snmp/cli/
args.rs

1//! Command-line argument structures for async-snmp CLI tools.
2//!
3//! This module provides reusable clap argument structures for the `asnmp-*` CLI tools.
4
5use clap::{Parser, ValueEnum};
6use std::time::Duration;
7
8use crate::Version;
9use crate::client::{Auth, Backoff, Retry};
10use crate::v3::{AuthProtocol, PrivProtocol};
11
/// SNMP version for CLI argument parsing.
// The `#[value(name = ...)]` strings are the exact tokens accepted on the command line
// (`1`, `2c`, `3`); clap shows each variant's doc comment as its possible-value help.
// Converted to the library's `Version` by the `From` impl that follows.
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum SnmpVersion {
    /// SNMPv1
    #[value(name = "1")]
    V1,
    /// SNMPv2c (default)
    // `#[default]` only drives `Default::default()`; the CLI default comes from
    // `default_value = "2c"` on `CommonArgs::snmp_version`.
    #[default]
    #[value(name = "2c")]
    V2c,
    /// SNMPv3
    #[value(name = "3")]
    V3,
}
26
27impl From<SnmpVersion> for Version {
28    fn from(v: SnmpVersion) -> Self {
29        match v {
30            SnmpVersion::V1 => Version::V1,
31            SnmpVersion::V2c => Version::V2c,
32            SnmpVersion::V3 => Version::V3,
33        }
34    }
35}
36
/// Output format for CLI tools.
// Selected with `-O/--output` (`OutputArgs::format`); clap derives the accepted values
// (`human`, `json`, `raw`) from the variant names.
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum OutputFormat {
    /// Human-readable output with type information.
    #[default]
    Human,
    /// JSON output for scripting.
    Json,
    /// Raw tab-separated output for scripting.
    Raw,
}
48
/// Backoff strategy for CLI argument parsing.
// Turned into a `Backoff` by `CommonArgs::retry_config`; clap derives the accepted
// values (`none`, `fixed`, `exponential`) from the variant names.
#[derive(Debug, Clone, Copy, Default, ValueEnum)]
pub enum BackoffStrategy {
    /// No delay between retries (immediate retry on timeout).
    // `--backoff-delay`, `--backoff-max` and `--backoff-jitter` are ignored here.
    #[default]
    None,
    /// Fixed delay between each retry.
    // Uses `--backoff-delay` as the constant delay.
    Fixed,
    /// Exponential backoff: delay doubles after each attempt.
    // Uses `--backoff-delay` as the initial delay, `--backoff-max` as the cap and
    // `--backoff-jitter`; the growth itself is implemented by `Backoff::Exponential`.
    Exponential,
}
60
/// Common arguments shared across all CLI tools.
// Field doc comments below are clap `--help` text; explanatory notes use `//` so they
// stay out of the help output. Short flags appear to mirror net-snmp's (-v/-c/-t/-r).
#[derive(Debug, Parser)]
pub struct CommonArgs {
    // Positional (no short/long). Stored verbatim; no host/port parsing happens here.
    /// Target host or host:port (default port 161).
    #[arg(value_name = "TARGET")]
    pub target: String,

    /// SNMP version: 1, 2c, or 3.
    #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
    pub snmp_version: SnmpVersion,

    // Only used when no `-u` username is given (see `V3Args::auth`).
    /// Community string (v1/v2c).
    #[arg(short = 'c', long = "community", default_value = "public")]
    pub community: String,

    // Fractional seconds. Any value clap parses as `f64` (including `nan`/`inf`, or a
    // negative via `--timeout=-1`) reaches `timeout_duration()` unvalidated.
    /// Request timeout in seconds.
    #[arg(short = 't', long = "timeout", default_value = "5")]
    pub timeout: f64,

    // Passed through as `Retry::max_attempts` by `retry_config()`.
    /// Retry count.
    #[arg(short = 'r', long = "retries", default_value = "3")]
    pub retries: u32,

    /// Backoff strategy between retries: none, fixed, or exponential.
    #[arg(long = "backoff", default_value = "none")]
    pub backoff: BackoffStrategy,

    /// Backoff delay in milliseconds (initial delay for exponential, fixed delay otherwise).
    #[arg(long = "backoff-delay", default_value = "1000")]
    pub backoff_delay: u64,

    // Read only for `--backoff exponential` in `retry_config()`.
    /// Maximum backoff delay in milliseconds (exponential only).
    #[arg(long = "backoff-max", default_value = "5000")]
    pub backoff_max: u64,

    // Clamped to 0.0..=1.0 in `retry_config()`; read only for exponential backoff.
    /// Jitter factor for exponential backoff (0.0-1.0, e.g., 0.25 means +/-25%).
    #[arg(long = "backoff-jitter", default_value = "0.25")]
    pub backoff_jitter: f64,
}
100
101impl CommonArgs {
102    /// Get the timeout as a Duration.
103    pub fn timeout_duration(&self) -> Duration {
104        Duration::from_secs_f64(self.timeout)
105    }
106
107    /// Build a Retry configuration from the CLI arguments.
108    pub fn retry_config(&self) -> Retry {
109        let backoff = match self.backoff {
110            BackoffStrategy::None => Backoff::None,
111            BackoffStrategy::Fixed => Backoff::Fixed {
112                delay: Duration::from_millis(self.backoff_delay),
113            },
114            BackoffStrategy::Exponential => Backoff::Exponential {
115                initial: Duration::from_millis(self.backoff_delay),
116                max: Duration::from_millis(self.backoff_max),
117                jitter: self.backoff_jitter.clamp(0.0, 1.0),
118            },
119        };
120        Retry {
121            max_attempts: self.retries,
122            backoff,
123        }
124    }
125}
126
/// SNMPv3 security arguments.
// Field doc comments below are clap `--help` text; notes use `//`.
// NOTE(review): passphrases given as arguments are visible to other local users via
// the process list; consider an env-var or file alternative.
#[derive(Debug, Parser)]
pub struct V3Args {
    // Presence of a username is what selects USM auth in `auth()` and `is_v3()`.
    /// Security name/username (implies -v 3).
    #[arg(short = 'u', long = "username")]
    pub username: Option<String>,

    // Kept as a free-form `String` rather than a `ValueEnum`, so clap accepts any
    // spelling; `auth()` does not read this field.
    /// Security level: noAuthNoPriv, authNoPriv, or authPriv.
    #[arg(short = 'l', long = "level")]
    pub level: Option<String>,

    // Requires `-A`; both `validate()` and `auth()` reject it without one.
    /// Authentication protocol: MD5, SHA, SHA-224, SHA-256, SHA-384, SHA-512.
    #[arg(short = 'a', long = "auth-protocol")]
    pub auth_protocol: Option<AuthProtocol>,

    /// Authentication passphrase.
    #[arg(short = 'A', long = "auth-password")]
    pub auth_password: Option<String>,

    // Requires `-X`, and `validate()` additionally requires `-a`.
    /// Privacy protocol: DES, AES, AES-128, AES-192, AES-256.
    #[arg(short = 'x', long = "priv-protocol")]
    pub priv_protocol: Option<PrivProtocol>,

    /// Privacy passphrase.
    #[arg(short = 'X', long = "priv-password")]
    pub priv_password: Option<String>,
}
154
155impl V3Args {
156    /// Check if V3 mode is enabled (username provided).
157    pub fn is_v3(&self) -> bool {
158        self.username.is_some()
159    }
160
161    /// Build an Auth configuration from the V3 args and common args.
162    ///
163    /// If a username is provided, builds a USM auth configuration.
164    /// Otherwise, builds a community auth based on the version and community from common args.
165    pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
166        if let Some(ref username) = self.username {
167            let mut builder = Auth::usm(username);
168            if let Some(proto) = self.auth_protocol {
169                let pass = self
170                    .auth_password
171                    .as_ref()
172                    .ok_or("auth password required")?;
173                builder = builder.auth(proto, pass);
174            }
175            if let Some(proto) = self.priv_protocol {
176                let pass = self
177                    .priv_password
178                    .as_ref()
179                    .ok_or("priv password required")?;
180                builder = builder.privacy(proto, pass);
181            }
182            Ok(builder.into())
183        } else {
184            let community = &common.community;
185            Ok(match common.snmp_version {
186                SnmpVersion::V1 => Auth::v1(community),
187                _ => Auth::v2c(community),
188            })
189        }
190    }
191
192    /// Validate V3 arguments and return an error message if invalid.
193    pub fn validate(&self) -> Result<(), String> {
194        if let Some(ref _username) = self.username {
195            // If auth-protocol is specified, auth-password is required
196            if self.auth_protocol.is_some() && self.auth_password.is_none() {
197                return Err(
198                    "authentication password (-A) required when using auth protocol".into(),
199                );
200            }
201
202            // If priv-protocol is specified, priv-password is required
203            if self.priv_protocol.is_some() && self.priv_password.is_none() {
204                return Err("privacy password (-X) required when using priv protocol".into());
205            }
206
207            // Privacy requires authentication
208            if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
209                return Err("authentication protocol (-a) required when using privacy".into());
210            }
211        }
212        Ok(())
213    }
214}
215
/// Output control arguments.
// Field doc comments below are clap `--help` text; notes use `//`. Only `debug` and
// `trace` are read in this file (by `init_tracing`); the other flags are presumably
// consumed by the individual tool binaries.
#[derive(Debug, Parser)]
pub struct OutputArgs {
    /// Output format: human, json, or raw.
    #[arg(short = 'O', long = "output", default_value = "human")]
    pub format: OutputFormat,

    // Long-only: `-v` is already taken by `CommonArgs::snmp_version`.
    /// Show PDU structure and wire details.
    #[arg(long = "verbose")]
    pub verbose: bool,

    /// Always display OctetString as hex.
    #[arg(long = "hex")]
    pub hex: bool,

    /// Show request timing.
    #[arg(long = "timing")]
    pub timing: bool,

    /// Disable well-known OID name hints.
    #[arg(long = "no-hints")]
    pub no_hints: bool,

    /// Enable debug logging (async_snmp=debug).
    #[arg(short = 'd', long = "debug")]
    pub debug: bool,

    // Takes precedence over `debug` when both are set (see `init_tracing`).
    /// Enable trace logging (async_snmp=trace).
    #[arg(short = 'D', long = "trace")]
    pub trace: bool,
}
247
248impl OutputArgs {
249    /// Initialize tracing based on debug/trace flags.
250    ///
251    /// Note: --verbose is handled separately and shows structured request/response info.
252    /// Use -d/--debug for library-level tracing.
253    pub fn init_tracing(&self) {
254        use tracing_subscriber::EnvFilter;
255
256        let filter = if self.trace {
257            "async_snmp=trace"
258        } else if self.debug {
259            "async_snmp=debug"
260        } else {
261            "async_snmp=warn"
262        };
263
264        let _ = tracing_subscriber::fmt()
265            .with_env_filter(EnvFilter::new(filter))
266            .with_writer(std::io::stderr)
267            .try_init();
268    }
269}
270
/// Walk-specific arguments.
#[derive(Debug, Parser)]
pub struct WalkArgs {
    /// Use GETNEXT instead of GETBULK.
    #[arg(long = "getnext")]
    pub getnext: bool,

    // NOTE(review): presumably ignored when `--getnext` is set. Any u32 is accepted,
    // including 0 — confirm the walk code copes with a GETBULK that returns no
    // repetitions (otherwise consider a `value_parser` range of 1..).
    /// GETBULK max-repetitions.
    #[arg(long = "max-rep", default_value = "10")]
    pub max_repetitions: u32,
}
282
/// Set-specific type specifier for values.
// Each `#[value(name = ...)]` letter is the exact, case-sensitive token clap accepts
// (`c` and `C` differ only by case). The same letters are matched by hand in this
// type's `FromStr` impl, so the two lists must be kept in sync.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum ValueType {
    /// INTEGER (i32)
    #[value(name = "i")]
    Integer,
    // `parse_value` produces `Value::Gauge32` for this specifier.
    /// Unsigned32/Gauge32 (u32)
    #[value(name = "u")]
    Unsigned,
    /// STRING (OctetString from UTF-8)
    #[value(name = "s")]
    String,
    /// Hex-STRING (OctetString from hex)
    #[value(name = "x")]
    HexString,
    /// OBJECT IDENTIFIER
    #[value(name = "o")]
    Oid,
    /// IpAddress
    #[value(name = "a")]
    IpAddress,
    /// TimeTicks
    #[value(name = "t")]
    TimeTicks,
    /// Counter32
    #[value(name = "c")]
    Counter32,
    /// Counter64
    #[value(name = "C")]
    Counter64,
}
314
315impl std::str::FromStr for ValueType {
316    type Err = String;
317
318    fn from_str(s: &str) -> Result<Self, Self::Err> {
319        match s {
320            "i" => Ok(ValueType::Integer),
321            "u" => Ok(ValueType::Unsigned),
322            "s" => Ok(ValueType::String),
323            "x" => Ok(ValueType::HexString),
324            "o" => Ok(ValueType::Oid),
325            "a" => Ok(ValueType::IpAddress),
326            "t" => Ok(ValueType::TimeTicks),
327            "c" => Ok(ValueType::Counter32),
328            "C" => Ok(ValueType::Counter64),
329            _ => Err(format!("invalid type specifier: {}", s)),
330        }
331    }
332}
333
334impl ValueType {
335    /// Parse a string value into an SNMP Value according to the type specifier.
336    pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
337        use crate::{Oid, Value};
338
339        match self {
340            ValueType::Integer => {
341                let v: i32 = s
342                    .parse()
343                    .map_err(|_| format!("invalid integer value: {}", s))?;
344                Ok(Value::Integer(v))
345            }
346            ValueType::Unsigned => {
347                let v: u32 = s
348                    .parse()
349                    .map_err(|_| format!("invalid unsigned value: {}", s))?;
350                Ok(Value::Gauge32(v))
351            }
352            ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
353            ValueType::HexString => {
354                let bytes = parse_hex_string(s)?;
355                Ok(Value::OctetString(bytes.into()))
356            }
357            ValueType::Oid => {
358                let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
359                Ok(Value::ObjectIdentifier(oid))
360            }
361            ValueType::IpAddress => {
362                let parts: Vec<&str> = s.split('.').collect();
363                if parts.len() != 4 {
364                    return Err(format!("invalid IP address: {}", s));
365                }
366                let mut bytes = [0u8; 4];
367                for (i, part) in parts.iter().enumerate() {
368                    bytes[i] = part
369                        .parse()
370                        .map_err(|_| format!("invalid IP address octet: {}", part))?;
371                }
372                Ok(Value::IpAddress(bytes))
373            }
374            ValueType::TimeTicks => {
375                let v: u32 = s
376                    .parse()
377                    .map_err(|_| format!("invalid timeticks value: {}", s))?;
378                Ok(Value::TimeTicks(v))
379            }
380            ValueType::Counter32 => {
381                let v: u32 = s
382                    .parse()
383                    .map_err(|_| format!("invalid counter32 value: {}", s))?;
384                Ok(Value::Counter32(v))
385            }
386            ValueType::Counter64 => {
387                let v: u64 = s
388                    .parse()
389                    .map_err(|_| format!("invalid counter64 value: {}", s))?;
390                Ok(Value::Counter64(v))
391            }
392        }
393    }
394}
395
/// Parse a hex string (with or without separators) into bytes.
///
/// Digits may be upper- or lower-case. They may be separated by whitespace, `:` or `-`,
/// e.g. `"001a2b"`, `"00 1A 2B"` or `"00:1a:2b"`. An empty string yields no bytes.
///
/// # Errors
///
/// Fails on any other character, rather than silently skipping it: skipping would turn
/// a typo like `"12zz34"` into a different value. Also fails when the digit count is odd.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    let mut nibbles = Vec::with_capacity(s.len());
    for c in s.chars() {
        if let Some(digit) = c.to_digit(16) {
            // `to_digit(16)` only returns 0..=15, so this narrowing cast is lossless.
            nibbles.push(digit as u8);
        } else if !(c.is_whitespace() || c == ':' || c == '-') {
            return Err(format!("invalid hex character: {:?}", c));
        }
    }

    let pairs = nibbles.chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return Err("hex string must have even number of digits".into());
    }
    Ok(pairs.map(|pair| (pair[0] << 4) | pair[1]).collect())
}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    // Builds `CommonArgs` for the retry tests; fields the tests never vary are fixed.
    fn common_args(
        retries: u32,
        backoff: BackoffStrategy,
        backoff_delay: u64,
        backoff_max: u64,
        backoff_jitter: f64,
    ) -> CommonArgs {
        CommonArgs {
            target: "192.168.1.1".to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries,
            backoff,
            backoff_delay,
            backoff_max,
            backoff_jitter,
        }
    }

    // Builds `V3Args` without a security level; `None` leaves an option unset.
    fn v3_args(
        username: Option<&str>,
        auth_protocol: Option<AuthProtocol>,
        auth_password: Option<&str>,
        priv_protocol: Option<PrivProtocol>,
        priv_password: Option<&str>,
    ) -> V3Args {
        V3Args {
            username: username.map(String::from),
            level: None,
            auth_protocol,
            auth_password: auth_password.map(String::from),
            priv_protocol,
            priv_password: priv_password.map(String::from),
        }
    }

    // Extracts the bytes of an `OctetString`, failing the test for any other variant.
    fn octets(value: Value) -> Vec<u8> {
        let Value::OctetString(bytes) = value else {
            panic!("expected OctetString");
        };
        bytes.to_vec()
    }

    #[test]
    fn test_retry_config_none() {
        let retry = common_args(3, BackoffStrategy::None, 100, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 3);
        assert!(matches!(retry.backoff, Backoff::None));
    }

    #[test]
    fn test_retry_config_fixed() {
        let retry = common_args(5, BackoffStrategy::Fixed, 200, 5000, 0.25).retry_config();
        assert_eq!(retry.max_attempts, 5);
        assert!(matches!(
            retry.backoff,
            Backoff::Fixed { delay } if delay == Duration::from_millis(200)
        ));
    }

    #[test]
    fn test_retry_config_exponential() {
        let retry = common_args(4, BackoffStrategy::Exponential, 50, 2000, 0.1).retry_config();
        assert_eq!(retry.max_attempts, 4);
        let Backoff::Exponential {
            initial,
            max,
            jitter,
        } = retry.backoff
        else {
            panic!("expected Exponential");
        };
        assert_eq!(initial, Duration::from_millis(50));
        assert_eq!(max, Duration::from_millis(2000));
        assert!((jitter - 0.1).abs() < f64::EPSILON);
    }

    #[test]
    fn test_v3_args_validation() {
        let cases = [
            // No username - valid (not v3)
            (v3_args(None, None, None, None, None), true),
            // Username only - valid (noAuthNoPriv)
            (v3_args(Some("admin"), None, None, None, None), true),
            // Auth protocol without password - invalid
            (
                v3_args(Some("admin"), Some(AuthProtocol::Sha256), None, None, None),
                false,
            ),
            // Privacy without auth - invalid
            (
                v3_args(
                    Some("admin"),
                    None,
                    None,
                    Some(PrivProtocol::Aes128),
                    Some("pass"),
                ),
                false,
            ),
            // SHA-1 with AES-256 - valid (key extension auto-applied)
            (
                v3_args(
                    Some("admin"),
                    Some(AuthProtocol::Sha1),
                    Some("pass"),
                    Some(PrivProtocol::Aes256),
                    Some("pass"),
                ),
                true,
            ),
        ];

        for (index, (args, expect_ok)) in cases.iter().enumerate() {
            assert_eq!(args.validate().is_ok(), *expect_ok, "case {}", index);
        }
    }

    #[test]
    fn test_value_type_parse_integer() {
        assert!(matches!(
            ValueType::Integer.parse_value("42").unwrap(),
            Value::Integer(42)
        ));
        assert!(matches!(
            ValueType::Integer.parse_value("-100").unwrap(),
            Value::Integer(-100)
        ));
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        assert!(matches!(
            ValueType::Unsigned.parse_value("42").unwrap(),
            Value::Gauge32(42)
        ));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let bytes = octets(ValueType::String.parse_value("hello world").unwrap());
        assert_eq!(bytes, b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        // Plain hex, then the same digits separated by spaces
        for input in ["001a2b", "00 1A 2B"] {
            let bytes = octets(ValueType::HexString.parse_value(input).unwrap());
            assert_eq!(bytes, [0x00, 0x1a, 0x2b]);
        }

        // Odd number of digits
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        assert!(matches!(
            ValueType::IpAddress.parse_value("192.168.1.1").unwrap(),
            Value::IpAddress([192, 168, 1, 1])
        ));
        for bad in ["192.168.1", "256.1.1.1"] {
            assert!(ValueType::IpAddress.parse_value(bad).is_err());
        }
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        assert!(matches!(
            ValueType::TimeTicks.parse_value("12345678").unwrap(),
            Value::TimeTicks(12345678)
        ));
    }

    #[test]
    fn test_value_type_parse_counters() {
        assert!(matches!(
            ValueType::Counter32.parse_value("4294967295").unwrap(),
            Value::Counter32(4294967295)
        ));
        assert!(matches!(
            ValueType::Counter64
                .parse_value("18446744073709551615")
                .unwrap(),
            Value::Counter64(18446744073709551615)
        ));
    }
}