async_snmp/cli/
args.rs

1//! Command-line argument structures for async-snmp CLI tools.
2//!
3//! This module provides reusable clap argument structures for the `asnmp-*` CLI tools.
4
5use clap::{Parser, ValueEnum};
6use std::net::SocketAddr;
7use std::time::Duration;
8
9use crate::Version;
10use crate::client::Auth;
11use crate::v3::{AuthProtocol, PrivProtocol};
12
13/// SNMP version for CLI argument parsing.
14#[derive(Debug, Clone, Copy, Default, ValueEnum)]
15pub enum SnmpVersion {
16    /// SNMPv1
17    #[value(name = "1")]
18    V1,
19    /// SNMPv2c (default)
20    #[default]
21    #[value(name = "2c")]
22    V2c,
23    /// SNMPv3
24    #[value(name = "3")]
25    V3,
26}
27
28impl From<SnmpVersion> for Version {
29    fn from(v: SnmpVersion) -> Self {
30        match v {
31            SnmpVersion::V1 => Version::V1,
32            SnmpVersion::V2c => Version::V2c,
33            SnmpVersion::V3 => Version::V3,
34        }
35    }
36}
37
38/// Output format for CLI tools.
39#[derive(Debug, Clone, Copy, Default, ValueEnum)]
40pub enum OutputFormat {
41    /// Human-readable output with type information.
42    #[default]
43    Human,
44    /// JSON output for scripting.
45    Json,
46    /// Raw tab-separated output for scripting.
47    Raw,
48}
49
50/// Common arguments shared across all CLI tools.
51#[derive(Debug, Parser)]
52pub struct CommonArgs {
53    /// Target host or host:port (default port 161).
54    #[arg(value_name = "TARGET")]
55    pub target: String,
56
57    /// SNMP version: 1, 2c, or 3.
58    #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
59    pub snmp_version: SnmpVersion,
60
61    /// Community string (v1/v2c).
62    #[arg(short = 'c', long = "community", default_value = "public")]
63    pub community: String,
64
65    /// Request timeout in seconds.
66    #[arg(short = 't', long = "timeout", default_value = "5")]
67    pub timeout: f64,
68
69    /// Retry count.
70    #[arg(short = 'r', long = "retries", default_value = "3")]
71    pub retries: u32,
72}
73
74impl CommonArgs {
75    /// Parse the target into a SocketAddr, defaulting to port 161.
76    pub fn target_addr(&self) -> Result<SocketAddr, String> {
77        // If the target doesn't contain a port, add the default SNMP port
78        let addr_str = if self.target.contains(':') {
79            self.target.clone()
80        } else {
81            format!("{}:161", self.target)
82        };
83
84        addr_str
85            .parse()
86            .or_else(|_| {
87                // Try to resolve as hostname
88                use std::net::ToSocketAddrs;
89                addr_str
90                    .to_socket_addrs()
91                    .map_err(|e| e.to_string())?
92                    .next()
93                    .ok_or_else(|| format!("could not resolve hostname: {}", self.target))
94            })
95            .map_err(|e| format!("invalid target '{}': {}", self.target, e))
96    }
97
98    /// Get the timeout as a Duration.
99    pub fn timeout_duration(&self) -> Duration {
100        Duration::from_secs_f64(self.timeout)
101    }
102}
103
104/// SNMPv3 security arguments.
105#[derive(Debug, Parser)]
106pub struct V3Args {
107    /// Security name/username (implies -v 3).
108    #[arg(short = 'u', long = "username")]
109    pub username: Option<String>,
110
111    /// Security level: noAuthNoPriv, authNoPriv, or authPriv.
112    #[arg(short = 'l', long = "level")]
113    pub level: Option<String>,
114
115    /// Authentication protocol: MD5, SHA, SHA-224, SHA-256, SHA-384, SHA-512.
116    #[arg(short = 'a', long = "auth-protocol")]
117    pub auth_protocol: Option<AuthProtocol>,
118
119    /// Authentication passphrase.
120    #[arg(short = 'A', long = "auth-password")]
121    pub auth_password: Option<String>,
122
123    /// Privacy protocol: DES, AES, AES-128, AES-192, AES-256.
124    #[arg(short = 'x', long = "priv-protocol")]
125    pub priv_protocol: Option<PrivProtocol>,
126
127    /// Privacy passphrase.
128    #[arg(short = 'X', long = "priv-password")]
129    pub priv_password: Option<String>,
130}
131
132impl V3Args {
133    /// Check if V3 mode is enabled (username provided).
134    pub fn is_v3(&self) -> bool {
135        self.username.is_some()
136    }
137
138    /// Build an Auth configuration from the V3 args and common args.
139    ///
140    /// If a username is provided, builds a USM auth configuration.
141    /// Otherwise, builds a community auth based on the version and community from common args.
142    pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
143        if let Some(ref username) = self.username {
144            let mut builder = Auth::usm(username);
145            if let Some(proto) = self.auth_protocol {
146                let pass = self
147                    .auth_password
148                    .as_ref()
149                    .ok_or("auth password required")?;
150                builder = builder.auth(proto, pass);
151            }
152            if let Some(proto) = self.priv_protocol {
153                let pass = self
154                    .priv_password
155                    .as_ref()
156                    .ok_or("priv password required")?;
157                builder = builder.privacy(proto, pass);
158            }
159            Ok(builder.into())
160        } else {
161            let community = &common.community;
162            Ok(match common.snmp_version {
163                SnmpVersion::V1 => Auth::v1(community),
164                _ => Auth::v2c(community),
165            })
166        }
167    }
168
169    /// Validate V3 arguments and return an error message if invalid.
170    pub fn validate(&self) -> Result<(), String> {
171        if let Some(ref _username) = self.username {
172            // If auth-protocol is specified, auth-password is required
173            if self.auth_protocol.is_some() && self.auth_password.is_none() {
174                return Err(
175                    "authentication password (-A) required when using auth protocol".into(),
176                );
177            }
178
179            // If priv-protocol is specified, priv-password is required
180            if self.priv_protocol.is_some() && self.priv_password.is_none() {
181                return Err("privacy password (-X) required when using priv protocol".into());
182            }
183
184            // Privacy requires authentication
185            if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
186                return Err("authentication protocol (-a) required when using privacy".into());
187            }
188
189            // Check auth/priv compatibility
190            if let (Some(auth), Some(priv_proto)) = (self.auth_protocol, self.priv_protocol)
191                && !auth.is_compatible_with(priv_proto)
192            {
193                return Err(format!(
194                    "{} authentication does not produce enough key material for {} privacy; use {} or stronger",
195                    auth,
196                    priv_proto,
197                    priv_proto.min_auth_protocol()
198                ));
199            }
200        }
201        Ok(())
202    }
203}
204
205/// Output control arguments.
206#[derive(Debug, Parser)]
207pub struct OutputArgs {
208    /// Output format: human, json, or raw.
209    #[arg(short = 'O', long = "output", default_value = "human")]
210    pub format: OutputFormat,
211
212    /// Show PDU structure and wire details.
213    #[arg(long = "verbose")]
214    pub verbose: bool,
215
216    /// Always display OctetString as hex.
217    #[arg(long = "hex")]
218    pub hex: bool,
219
220    /// Show request timing.
221    #[arg(long = "timing")]
222    pub timing: bool,
223
224    /// Disable well-known OID name hints.
225    #[arg(long = "no-hints")]
226    pub no_hints: bool,
227
228    /// Enable debug logging (async_snmp=debug).
229    #[arg(short = 'd', long = "debug")]
230    pub debug: bool,
231
232    /// Enable trace logging (async_snmp=trace).
233    #[arg(short = 'D', long = "trace")]
234    pub trace: bool,
235}
236
237impl OutputArgs {
238    /// Initialize tracing based on debug/trace flags.
239    ///
240    /// Note: --verbose is handled separately and shows structured request/response info.
241    /// Use -d/--debug for library-level tracing.
242    pub fn init_tracing(&self) {
243        use tracing_subscriber::EnvFilter;
244
245        let filter = if self.trace {
246            "async_snmp=trace"
247        } else if self.debug {
248            "async_snmp=debug"
249        } else {
250            "async_snmp=warn"
251        };
252
253        let _ = tracing_subscriber::fmt()
254            .with_env_filter(EnvFilter::new(filter))
255            .with_writer(std::io::stderr)
256            .try_init();
257    }
258}
259
260/// Walk-specific arguments.
261#[derive(Debug, Parser)]
262pub struct WalkArgs {
263    /// Use GETNEXT instead of GETBULK.
264    #[arg(long = "getnext")]
265    pub getnext: bool,
266
267    /// GETBULK max-repetitions.
268    #[arg(long = "max-rep", default_value = "10")]
269    pub max_repetitions: u32,
270}
271
/// Set-specific type specifier for values.
///
/// Each variant is selected by the single-letter name in its
/// `#[value(name = ...)]` attribute. The `FromStr` impl for this type accepts
/// the same letters. Letters are case-sensitive: `c` is Counter32 and `C` is
/// Counter64. `ValueType::parse_value` turns the accompanying text into a value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum ValueType {
    /// INTEGER (i32)
    #[value(name = "i")]
    Integer,
    // Values parsed with this specifier are produced as Gauge32 by `parse_value`.
    /// Unsigned32/Gauge32 (u32)
    #[value(name = "u")]
    Unsigned,
    // The input text's UTF-8 bytes are used verbatim.
    /// STRING (OctetString from UTF-8)
    #[value(name = "s")]
    String,
    /// Hex-STRING (OctetString from hex)
    #[value(name = "x")]
    HexString,
    /// OBJECT IDENTIFIER
    #[value(name = "o")]
    Oid,
    // Dotted-quad IPv4 only (four '.'-separated octets).
    /// IpAddress
    #[value(name = "a")]
    IpAddress,
    /// TimeTicks
    #[value(name = "t")]
    TimeTicks,
    /// Counter32
    #[value(name = "c")]
    Counter32,
    /// Counter64
    #[value(name = "C")]
    Counter64,
}
303
304impl std::str::FromStr for ValueType {
305    type Err = String;
306
307    fn from_str(s: &str) -> Result<Self, Self::Err> {
308        match s {
309            "i" => Ok(ValueType::Integer),
310            "u" => Ok(ValueType::Unsigned),
311            "s" => Ok(ValueType::String),
312            "x" => Ok(ValueType::HexString),
313            "o" => Ok(ValueType::Oid),
314            "a" => Ok(ValueType::IpAddress),
315            "t" => Ok(ValueType::TimeTicks),
316            "c" => Ok(ValueType::Counter32),
317            "C" => Ok(ValueType::Counter64),
318            _ => Err(format!("invalid type specifier: {}", s)),
319        }
320    }
321}
322
323impl ValueType {
324    /// Parse a string value into an SNMP Value according to the type specifier.
325    pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
326        use crate::{Oid, Value};
327
328        match self {
329            ValueType::Integer => {
330                let v: i32 = s
331                    .parse()
332                    .map_err(|_| format!("invalid integer value: {}", s))?;
333                Ok(Value::Integer(v))
334            }
335            ValueType::Unsigned => {
336                let v: u32 = s
337                    .parse()
338                    .map_err(|_| format!("invalid unsigned value: {}", s))?;
339                Ok(Value::Gauge32(v))
340            }
341            ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
342            ValueType::HexString => {
343                let bytes = parse_hex_string(s)?;
344                Ok(Value::OctetString(bytes.into()))
345            }
346            ValueType::Oid => {
347                let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
348                Ok(Value::ObjectIdentifier(oid))
349            }
350            ValueType::IpAddress => {
351                let parts: Vec<&str> = s.split('.').collect();
352                if parts.len() != 4 {
353                    return Err(format!("invalid IP address: {}", s));
354                }
355                let mut bytes = [0u8; 4];
356                for (i, part) in parts.iter().enumerate() {
357                    bytes[i] = part
358                        .parse()
359                        .map_err(|_| format!("invalid IP address octet: {}", part))?;
360                }
361                Ok(Value::IpAddress(bytes))
362            }
363            ValueType::TimeTicks => {
364                let v: u32 = s
365                    .parse()
366                    .map_err(|_| format!("invalid timeticks value: {}", s))?;
367                Ok(Value::TimeTicks(v))
368            }
369            ValueType::Counter32 => {
370                let v: u32 = s
371                    .parse()
372                    .map_err(|_| format!("invalid counter32 value: {}", s))?;
373                Ok(Value::Counter32(v))
374            }
375            ValueType::Counter64 => {
376                let v: u64 = s
377                    .parse()
378                    .map_err(|_| format!("invalid counter64 value: {}", s))?;
379                Ok(Value::Counter64(v))
380            }
381        }
382    }
383}
384
/// Parse a hex string into bytes.
///
/// Digits may be upper- or lower-case and may be separated by whitespace, `:`
/// or `-` (e.g. `001a2b`, `00 1A 2B`, `00:1a:2b`). Separators are skipped, and
/// the remaining digits are paired in order. An empty input yields no bytes.
///
/// # Errors
///
/// Returns an error if the input contains any other character, or if the total
/// number of hex digits is odd.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    // Reject unknown characters instead of dropping them: silently discarding a
    // typo such as "0G1A" would produce different bytes rather than an error.
    let nibbles = s
        .chars()
        .filter(|&c| !(c.is_whitespace() || c == ':' || c == '-'))
        .map(|c| {
            c.to_digit(16)
                // `to_digit(16)` only yields 0..=15, so narrowing to u8 is lossless.
                .map(|d| d as u8)
                .ok_or_else(|| format!("invalid hex character: {:?}", c))
        })
        .collect::<Result<Vec<u8>, String>>()?;

    let pairs = nibbles.chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return Err("hex string must have even number of digits".into());
    }
    Ok(pairs.map(|pair| (pair[0] << 4) | pair[1]).collect())
}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// Build `CommonArgs` for `target`, with every other field at its CLI default.
    fn common_args(target: &str) -> CommonArgs {
        CommonArgs {
            target: target.to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries: 3,
        }
    }

    /// Build `V3Args` with only `username` set; each case fills in the rest.
    fn v3_args(username: Option<&str>) -> V3Args {
        V3Args {
            username: username.map(str::to_string),
            level: None,
            auth_protocol: None,
            auth_password: None,
            priv_protocol: None,
            priv_password: None,
        }
    }

    /// Extract the bytes of an `OctetString`, failing the test on any other variant.
    fn octets(value: Value) -> Vec<u8> {
        match value {
            Value::OctetString(bytes) => bytes.to_vec(),
            _ => panic!("expected OctetString"),
        }
    }

    #[test]
    fn test_target_addr_with_port() {
        let addr = common_args("192.168.1.1:162").target_addr().unwrap();
        assert_eq!(addr.port(), 162);
    }

    #[test]
    fn test_target_addr_default_port() {
        let addr = common_args("192.168.1.1").target_addr().unwrap();
        assert_eq!(addr.port(), 161);
    }

    #[test]
    fn test_v3_args_validation() {
        // No username: not SNMPv3, nothing to check.
        assert!(v3_args(None).validate().is_ok());

        // Username only: noAuthNoPriv.
        assert!(v3_args(Some("admin")).validate().is_ok());

        // Auth protocol without its password.
        let mut args = v3_args(Some("admin"));
        args.auth_protocol = Some(AuthProtocol::Sha256);
        assert!(args.validate().is_err());

        // Privacy without authentication.
        let mut args = v3_args(Some("admin"));
        args.priv_protocol = Some(PrivProtocol::Aes128);
        args.priv_password = Some("pass".to_string());
        assert!(args.validate().is_err());

        // Incompatible pair: SHA-1 cannot key AES-256.
        let mut args = v3_args(Some("admin"));
        args.auth_protocol = Some(AuthProtocol::Sha1);
        args.auth_password = Some("pass".to_string());
        args.priv_protocol = Some(PrivProtocol::Aes256);
        args.priv_password = Some("pass".to_string());
        assert!(args.validate().is_err());
    }

    #[test]
    fn test_value_type_parse_integer() {
        let positive = ValueType::Integer.parse_value("42").unwrap();
        assert!(matches!(positive, Value::Integer(42)));

        let negative = ValueType::Integer.parse_value("-100").unwrap();
        assert!(matches!(negative, Value::Integer(-100)));

        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        let parsed = ValueType::Unsigned.parse_value("42").unwrap();
        assert!(matches!(parsed, Value::Gauge32(42)));

        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let parsed = ValueType::String.parse_value("hello world").unwrap();
        assert_eq!(octets(parsed), b"hello world".to_vec());
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        // Plain hex.
        let plain = ValueType::HexString.parse_value("001a2b").unwrap();
        assert_eq!(octets(plain), vec![0x00, 0x1a, 0x2b]);

        // Space-separated, upper-case.
        let spaced = ValueType::HexString.parse_value("00 1A 2B").unwrap();
        assert_eq!(octets(spaced), vec![0x00, 0x1a, 0x2b]);

        // Odd digit count.
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        let parsed = ValueType::IpAddress.parse_value("192.168.1.1").unwrap();
        assert!(matches!(parsed, Value::IpAddress([192, 168, 1, 1])));

        assert!(ValueType::IpAddress.parse_value("192.168.1").is_err());
        assert!(ValueType::IpAddress.parse_value("256.1.1.1").is_err());
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        let parsed = ValueType::TimeTicks.parse_value("12345678").unwrap();
        assert!(matches!(parsed, Value::TimeTicks(12345678)));
    }

    #[test]
    fn test_value_type_parse_counters() {
        let c32 = ValueType::Counter32.parse_value("4294967295").unwrap();
        assert!(matches!(c32, Value::Counter32(4294967295)));

        let c64 = ValueType::Counter64
            .parse_value("18446744073709551615")
            .unwrap();
        assert!(matches!(c64, Value::Counter64(18446744073709551615)));
    }
}