1use clap::{Parser, ValueEnum};
6use std::net::SocketAddr;
7use std::time::Duration;
8
9use crate::Version;
10use crate::client::Auth;
11use crate::v3::{AuthProtocol, PrivProtocol};
12
13#[derive(Debug, Clone, Copy, Default, ValueEnum)]
15pub enum SnmpVersion {
16 #[value(name = "1")]
18 V1,
19 #[default]
21 #[value(name = "2c")]
22 V2c,
23 #[value(name = "3")]
25 V3,
26}
27
28impl From<SnmpVersion> for Version {
29 fn from(v: SnmpVersion) -> Self {
30 match v {
31 SnmpVersion::V1 => Version::V1,
32 SnmpVersion::V2c => Version::V2c,
33 SnmpVersion::V3 => Version::V3,
34 }
35 }
36}
37
38#[derive(Debug, Clone, Copy, Default, ValueEnum)]
40pub enum OutputFormat {
41 #[default]
43 Human,
44 Json,
46 Raw,
48}
49
50#[derive(Debug, Parser)]
52pub struct CommonArgs {
53 #[arg(value_name = "TARGET")]
55 pub target: String,
56
57 #[arg(short = 'v', long = "snmp-version", default_value = "2c")]
59 pub snmp_version: SnmpVersion,
60
61 #[arg(short = 'c', long = "community", default_value = "public")]
63 pub community: String,
64
65 #[arg(short = 't', long = "timeout", default_value = "5")]
67 pub timeout: f64,
68
69 #[arg(short = 'r', long = "retries", default_value = "3")]
71 pub retries: u32,
72}
73
74impl CommonArgs {
75 pub fn target_addr(&self) -> Result<SocketAddr, String> {
77 let addr_str = if self.target.contains(':') {
79 self.target.clone()
80 } else {
81 format!("{}:161", self.target)
82 };
83
84 addr_str
85 .parse()
86 .or_else(|_| {
87 use std::net::ToSocketAddrs;
89 addr_str
90 .to_socket_addrs()
91 .map_err(|e| e.to_string())?
92 .next()
93 .ok_or_else(|| format!("could not resolve hostname: {}", self.target))
94 })
95 .map_err(|e| format!("invalid target '{}': {}", self.target, e))
96 }
97
98 pub fn timeout_duration(&self) -> Duration {
100 Duration::from_secs_f64(self.timeout)
101 }
102}
103
104#[derive(Debug, Parser)]
106pub struct V3Args {
107 #[arg(short = 'u', long = "username")]
109 pub username: Option<String>,
110
111 #[arg(short = 'l', long = "level")]
113 pub level: Option<String>,
114
115 #[arg(short = 'a', long = "auth-protocol")]
117 pub auth_protocol: Option<AuthProtocol>,
118
119 #[arg(short = 'A', long = "auth-password")]
121 pub auth_password: Option<String>,
122
123 #[arg(short = 'x', long = "priv-protocol")]
125 pub priv_protocol: Option<PrivProtocol>,
126
127 #[arg(short = 'X', long = "priv-password")]
129 pub priv_password: Option<String>,
130}
131
132impl V3Args {
133 pub fn is_v3(&self) -> bool {
135 self.username.is_some()
136 }
137
138 pub fn auth(&self, common: &CommonArgs) -> Result<Auth, String> {
143 if let Some(ref username) = self.username {
144 let mut builder = Auth::usm(username);
145 if let Some(proto) = self.auth_protocol {
146 let pass = self
147 .auth_password
148 .as_ref()
149 .ok_or("auth password required")?;
150 builder = builder.auth(proto, pass);
151 }
152 if let Some(proto) = self.priv_protocol {
153 let pass = self
154 .priv_password
155 .as_ref()
156 .ok_or("priv password required")?;
157 builder = builder.privacy(proto, pass);
158 }
159 Ok(builder.into())
160 } else {
161 let community = &common.community;
162 Ok(match common.snmp_version {
163 SnmpVersion::V1 => Auth::v1(community),
164 _ => Auth::v2c(community),
165 })
166 }
167 }
168
169 pub fn validate(&self) -> Result<(), String> {
171 if let Some(ref _username) = self.username {
172 if self.auth_protocol.is_some() && self.auth_password.is_none() {
174 return Err(
175 "authentication password (-A) required when using auth protocol".into(),
176 );
177 }
178
179 if self.priv_protocol.is_some() && self.priv_password.is_none() {
181 return Err("privacy password (-X) required when using priv protocol".into());
182 }
183
184 if self.priv_protocol.is_some() && self.auth_protocol.is_none() {
186 return Err("authentication protocol (-a) required when using privacy".into());
187 }
188
189 if let (Some(auth), Some(priv_proto)) = (self.auth_protocol, self.priv_protocol)
191 && !auth.is_compatible_with(priv_proto)
192 {
193 return Err(format!(
194 "{} authentication does not produce enough key material for {} privacy; use {} or stronger",
195 auth,
196 priv_proto,
197 priv_proto.min_auth_protocol()
198 ));
199 }
200 }
201 Ok(())
202 }
203}
204
205#[derive(Debug, Parser)]
207pub struct OutputArgs {
208 #[arg(short = 'O', long = "output", default_value = "human")]
210 pub format: OutputFormat,
211
212 #[arg(long = "verbose")]
214 pub verbose: bool,
215
216 #[arg(long = "hex")]
218 pub hex: bool,
219
220 #[arg(long = "timing")]
222 pub timing: bool,
223
224 #[arg(long = "no-hints")]
226 pub no_hints: bool,
227
228 #[arg(short = 'd', long = "debug")]
230 pub debug: bool,
231
232 #[arg(short = 'D', long = "trace")]
234 pub trace: bool,
235}
236
237impl OutputArgs {
238 pub fn init_tracing(&self) {
243 use tracing_subscriber::EnvFilter;
244
245 let filter = if self.trace {
246 "async_snmp=trace"
247 } else if self.debug {
248 "async_snmp=debug"
249 } else {
250 "async_snmp=warn"
251 };
252
253 let _ = tracing_subscriber::fmt()
254 .with_env_filter(EnvFilter::new(filter))
255 .with_writer(std::io::stderr)
256 .try_init();
257 }
258}
259
260#[derive(Debug, Parser)]
262pub struct WalkArgs {
263 #[arg(long = "getnext")]
265 pub getnext: bool,
266
267 #[arg(long = "max-rep", default_value = "10")]
269 pub max_repetitions: u32,
270}
271
272#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
274pub enum ValueType {
275 #[value(name = "i")]
277 Integer,
278 #[value(name = "u")]
280 Unsigned,
281 #[value(name = "s")]
283 String,
284 #[value(name = "x")]
286 HexString,
287 #[value(name = "o")]
289 Oid,
290 #[value(name = "a")]
292 IpAddress,
293 #[value(name = "t")]
295 TimeTicks,
296 #[value(name = "c")]
298 Counter32,
299 #[value(name = "C")]
301 Counter64,
302}
303
304impl std::str::FromStr for ValueType {
305 type Err = String;
306
307 fn from_str(s: &str) -> Result<Self, Self::Err> {
308 match s {
309 "i" => Ok(ValueType::Integer),
310 "u" => Ok(ValueType::Unsigned),
311 "s" => Ok(ValueType::String),
312 "x" => Ok(ValueType::HexString),
313 "o" => Ok(ValueType::Oid),
314 "a" => Ok(ValueType::IpAddress),
315 "t" => Ok(ValueType::TimeTicks),
316 "c" => Ok(ValueType::Counter32),
317 "C" => Ok(ValueType::Counter64),
318 _ => Err(format!("invalid type specifier: {}", s)),
319 }
320 }
321}
322
323impl ValueType {
324 pub fn parse_value(&self, s: &str) -> Result<crate::Value, String> {
326 use crate::{Oid, Value};
327
328 match self {
329 ValueType::Integer => {
330 let v: i32 = s
331 .parse()
332 .map_err(|_| format!("invalid integer value: {}", s))?;
333 Ok(Value::Integer(v))
334 }
335 ValueType::Unsigned => {
336 let v: u32 = s
337 .parse()
338 .map_err(|_| format!("invalid unsigned value: {}", s))?;
339 Ok(Value::Gauge32(v))
340 }
341 ValueType::String => Ok(Value::OctetString(s.as_bytes().to_vec().into())),
342 ValueType::HexString => {
343 let bytes = parse_hex_string(s)?;
344 Ok(Value::OctetString(bytes.into()))
345 }
346 ValueType::Oid => {
347 let oid = Oid::parse(s).map_err(|e| format!("invalid OID value: {}", e))?;
348 Ok(Value::ObjectIdentifier(oid))
349 }
350 ValueType::IpAddress => {
351 let parts: Vec<&str> = s.split('.').collect();
352 if parts.len() != 4 {
353 return Err(format!("invalid IP address: {}", s));
354 }
355 let mut bytes = [0u8; 4];
356 for (i, part) in parts.iter().enumerate() {
357 bytes[i] = part
358 .parse()
359 .map_err(|_| format!("invalid IP address octet: {}", part))?;
360 }
361 Ok(Value::IpAddress(bytes))
362 }
363 ValueType::TimeTicks => {
364 let v: u32 = s
365 .parse()
366 .map_err(|_| format!("invalid timeticks value: {}", s))?;
367 Ok(Value::TimeTicks(v))
368 }
369 ValueType::Counter32 => {
370 let v: u32 = s
371 .parse()
372 .map_err(|_| format!("invalid counter32 value: {}", s))?;
373 Ok(Value::Counter32(v))
374 }
375 ValueType::Counter64 => {
376 let v: u64 = s
377 .parse()
378 .map_err(|_| format!("invalid counter64 value: {}", s))?;
379 Ok(Value::Counter64(v))
380 }
381 }
382 }
383}
384
/// Decodes hex text such as `"001a2b"`, `"00 1A 2B"`, `"00:1a:2b"` or `"0x001a2b"`
/// into bytes.
///
/// Surrounding whitespace and an optional leading `0x`/`0X` are ignored.
/// Whitespace and ASCII punctuation between digits are treated as separators.
/// Any other character is rejected, so a typo cannot silently change the
/// decoded bytes.
///
/// # Errors
/// Returns an error on a non-hex, non-separator character, or when the number
/// of hex digits is odd.
fn parse_hex_string(s: &str) -> Result<Vec<u8>, String> {
    let trimmed = s.trim();
    let body = trimmed
        .strip_prefix("0x")
        .or_else(|| trimmed.strip_prefix("0X"))
        .unwrap_or(trimmed);

    let mut nibbles = Vec::with_capacity(body.len());
    for c in body.chars() {
        match c.to_digit(16) {
            // `to_digit(16)` only yields 0..=15, so the narrowing cast is lossless.
            Some(d) => nibbles.push(d as u8),
            None if c.is_whitespace() || c.is_ascii_punctuation() => {}
            None => return Err(format!("invalid hex character: '{}'", c)),
        }
    }

    let pairs = nibbles.chunks_exact(2);
    if !pairs.remainder().is_empty() {
        return Err("hex string must have even number of digits".into());
    }
    Ok(pairs.map(|pair| (pair[0] << 4) | pair[1]).collect())
}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Value;

    /// `CommonArgs` for `target`, with every other field at its CLI default.
    fn common_args(target: &str) -> CommonArgs {
        CommonArgs {
            target: target.to_string(),
            snmp_version: SnmpVersion::V2c,
            community: "public".to_string(),
            timeout: 5.0,
            retries: 3,
        }
    }

    /// `V3Args` from `(protocol, password)` pairs; `level` is always left unset.
    fn v3_args(
        username: Option<&str>,
        auth: (Option<AuthProtocol>, Option<&str>),
        privacy: (Option<PrivProtocol>, Option<&str>),
    ) -> V3Args {
        let (auth_protocol, auth_password) = auth;
        let (priv_protocol, priv_password) = privacy;
        V3Args {
            username: username.map(String::from),
            level: None,
            auth_protocol,
            auth_password: auth_password.map(String::from),
            priv_protocol,
            priv_password: priv_password.map(String::from),
        }
    }

    #[test]
    fn test_target_addr_with_port() {
        let addr = common_args("192.168.1.1:162").target_addr().unwrap();
        assert_eq!(addr.port(), 162);
    }

    #[test]
    fn test_target_addr_default_port() {
        let addr = common_args("192.168.1.1").target_addr().unwrap();
        assert_eq!(addr.port(), 161);
    }

    #[test]
    fn test_v3_args_validation() {
        // No username and no v3 options.
        assert!(v3_args(None, (None, None), (None, None)).validate().is_ok());

        // Username alone.
        assert!(
            v3_args(Some("admin"), (None, None), (None, None))
                .validate()
                .is_ok()
        );

        // An auth protocol without its password.
        assert!(
            v3_args(Some("admin"), (Some(AuthProtocol::Sha256), None), (None, None))
                .validate()
                .is_err()
        );

        // Privacy without an auth protocol.
        assert!(
            v3_args(
                Some("admin"),
                (None, None),
                (Some(PrivProtocol::Aes128), Some("pass")),
            )
            .validate()
            .is_err()
        );

        // SHA-1 does not yield enough key material for AES-256.
        assert!(
            v3_args(
                Some("admin"),
                (Some(AuthProtocol::Sha1), Some("pass")),
                (Some(PrivProtocol::Aes256), Some("pass")),
            )
            .validate()
            .is_err()
        );
    }

    #[test]
    fn test_value_type_parse_integer() {
        for (input, expected) in [("42", 42), ("-100", -100)] {
            let parsed = ValueType::Integer.parse_value(input).unwrap();
            assert!(matches!(parsed, Value::Integer(n) if n == expected));
        }
        assert!(ValueType::Integer.parse_value("not_a_number").is_err());
    }

    #[test]
    fn test_value_type_parse_unsigned() {
        let parsed = ValueType::Unsigned.parse_value("42").unwrap();
        assert!(matches!(parsed, Value::Gauge32(42)));
        assert!(ValueType::Unsigned.parse_value("-1").is_err());
    }

    #[test]
    fn test_value_type_parse_string() {
        let Value::OctetString(bytes) = ValueType::String.parse_value("hello world").unwrap()
        else {
            panic!("expected OctetString");
        };
        assert_eq!(&*bytes, b"hello world");
    }

    #[test]
    fn test_value_type_parse_hex_string() {
        for input in ["001a2b", "00 1A 2B"] {
            let Value::OctetString(bytes) = ValueType::HexString.parse_value(input).unwrap()
            else {
                panic!("expected OctetString");
            };
            assert_eq!(&*bytes, &[0x00, 0x1a, 0x2b]);
        }

        // An odd number of hex digits is rejected.
        assert!(ValueType::HexString.parse_value("001").is_err());
    }

    #[test]
    fn test_value_type_parse_ip_address() {
        let parsed = ValueType::IpAddress.parse_value("192.168.1.1").unwrap();
        assert!(matches!(parsed, Value::IpAddress([192, 168, 1, 1])));

        for bad in ["192.168.1", "256.1.1.1"] {
            assert!(ValueType::IpAddress.parse_value(bad).is_err());
        }
    }

    #[test]
    fn test_value_type_parse_timeticks() {
        let parsed = ValueType::TimeTicks.parse_value("12345678").unwrap();
        assert!(matches!(parsed, Value::TimeTicks(12345678)));
    }

    #[test]
    fn test_value_type_parse_counters() {
        let counter32 = ValueType::Counter32.parse_value("4294967295").unwrap();
        assert!(matches!(counter32, Value::Counter32(u32::MAX)));

        let counter64 = ValueType::Counter64
            .parse_value("18446744073709551615")
            .unwrap();
        assert!(matches!(counter64, Value::Counter64(u64::MAX)));
    }
}