1use crate::ber::Decoder;
7use crate::error::{
8 AuthErrorKind, CryptoErrorKind, DecodeErrorKind, EncodeErrorKind, Error, ErrorStatus, Result,
9};
10use crate::format::hex;
11use crate::message::{MsgFlags, MsgGlobalData, ScopedPdu, SecurityLevel, V3Message};
12use crate::pdu::{Pdu, PduType};
13use crate::transport::Transport;
14use crate::v3::{AuthProtocol, PrivProtocol};
15use crate::v3::{
16 LocalizedKey, PrivKey, UsmSecurityParams,
17 auth::{authenticate_message, verify_message},
18 is_not_in_time_window_report, is_unknown_engine_id_report,
19};
20use bytes::Bytes;
21use std::time::Instant;
22use tracing::{Span, instrument};
23
24use super::Client;
25
26#[derive(Clone)]
38pub struct V3SecurityConfig {
39 pub username: Bytes,
41 pub auth: Option<(AuthProtocol, Vec<u8>)>,
43 pub privacy: Option<(PrivProtocol, Vec<u8>)>,
45 pub master_keys: Option<crate::v3::MasterKeys>,
47}
48
49impl V3SecurityConfig {
50 pub fn new(username: impl Into<Bytes>) -> Self {
52 Self {
53 username: username.into(),
54 auth: None,
55 privacy: None,
56 master_keys: None,
57 }
58 }
59
60 pub fn auth(mut self, protocol: AuthProtocol, password: impl Into<Vec<u8>>) -> Self {
62 self.auth = Some((protocol, password.into()));
63 self
64 }
65
66 pub fn privacy(mut self, protocol: PrivProtocol, password: impl Into<Vec<u8>>) -> Self {
68 self.privacy = Some((protocol, password.into()));
69 self
70 }
71
72 pub fn with_master_keys(mut self, master_keys: crate::v3::MasterKeys) -> Self {
78 self.master_keys = Some(master_keys);
79 self
80 }
81
82 pub fn security_level(&self) -> SecurityLevel {
84 if let Some(ref master_keys) = self.master_keys {
86 if master_keys.priv_protocol().is_some() {
87 return SecurityLevel::AuthPriv;
88 }
89 return SecurityLevel::AuthNoPriv;
90 }
91
92 match (&self.auth, &self.privacy) {
93 (None, _) => SecurityLevel::NoAuthNoPriv,
94 (Some(_), None) => SecurityLevel::AuthNoPriv,
95 (Some(_), Some(_)) => SecurityLevel::AuthPriv,
96 }
97 }
98
99 pub fn derive_keys(&self, engine_id: &[u8]) -> V3DerivedKeys {
105 if let Some(ref master_keys) = self.master_keys {
107 tracing::trace!(
108 engine_id_len = engine_id.len(),
109 auth_protocol = ?master_keys.auth_protocol(),
110 priv_protocol = ?master_keys.priv_protocol(),
111 "localizing from cached master keys"
112 );
113 let (auth_key, priv_key) = master_keys.localize(engine_id);
114 tracing::trace!("key localization complete");
115 return V3DerivedKeys {
116 auth_key: Some(auth_key),
117 priv_key,
118 };
119 }
120
121 tracing::trace!(
123 engine_id_len = engine_id.len(),
124 has_auth = self.auth.is_some(),
125 has_priv = self.privacy.is_some(),
126 "deriving localized keys from passwords"
127 );
128
129 let auth_key = self.auth.as_ref().map(|(protocol, password)| {
130 tracing::trace!(auth_protocol = ?protocol, "deriving auth key");
131 LocalizedKey::from_password(*protocol, password, engine_id)
132 });
133
134 let priv_key = match (&self.auth, &self.privacy) {
135 (Some((auth_protocol, _)), Some((priv_protocol, priv_password))) => {
136 tracing::trace!(priv_protocol = ?priv_protocol, "deriving privacy key");
137 Some(PrivKey::from_password(
138 *auth_protocol,
139 *priv_protocol,
140 priv_password,
141 engine_id,
142 ))
143 }
144 _ => None,
145 };
146
147 tracing::trace!("key derivation complete");
148 V3DerivedKeys { auth_key, priv_key }
149 }
150}
151
152impl std::fmt::Debug for V3SecurityConfig {
153 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154 f.debug_struct("V3SecurityConfig")
155 .field("username", &String::from_utf8_lossy(&self.username))
156 .field("auth", &self.auth.as_ref().map(|(p, _)| p))
157 .field("privacy", &self.privacy.as_ref().map(|(p, _)| p))
158 .finish()
159 }
160}
161
162pub struct V3DerivedKeys {
164 pub auth_key: Option<LocalizedKey>,
165 pub priv_key: Option<PrivKey>,
166}
167
168impl<T: Transport> Client<T> {
170 #[instrument(level = "debug", skip(self), fields(snmp.target = %self.peer_addr()))]
172 pub(super) async fn ensure_engine_discovered(&self) -> Result<()> {
173 {
175 let state = self.inner.engine_state.read().unwrap();
176 if state.is_some() {
177 return Ok(());
178 }
179 }
180
181 if let Some(cache) = &self.inner.engine_cache
183 && let Some(cached_state) = cache.get(&self.peer_addr())
184 {
185 tracing::debug!("using cached engine state");
186 let mut state = self.inner.engine_state.write().unwrap();
187 *state = Some(cached_state.clone());
188 if let Some(security) = &self.inner.config.v3_security {
190 let keys = security.derive_keys(&cached_state.engine_id);
191 let mut derived = self.inner.derived_keys.write().unwrap();
192 *derived = Some(keys);
193 }
194 return Ok(());
195 }
196
197 tracing::debug!("performing engine discovery");
199 let msg_id = self.next_request_id();
200 let discovery_msg = V3Message::discovery_request(msg_id);
201 let discovery_data = discovery_msg.encode();
202
203 self.inner
205 .transport
206 .register_request(msg_id, self.inner.config.timeout);
207 self.inner.transport.send(&discovery_data).await?;
208 let (response_data, _source) = self.inner.transport.recv(msg_id).await?;
209
210 let response = V3Message::decode(response_data)?;
212
213 let engine_state = crate::v3::parse_discovery_response(&response.security_params)?;
215 tracing::debug!(
216 snmp.engine_id = %hex::Bytes(&engine_state.engine_id),
217 snmp.engine_boots = engine_state.engine_boots,
218 snmp.engine_time = engine_state.engine_time,
219 "discovered engine"
220 );
221
222 if let Some(security) = &self.inner.config.v3_security {
224 let keys = security.derive_keys(&engine_state.engine_id);
225 let mut derived = self.inner.derived_keys.write().unwrap();
226 *derived = Some(keys);
227 }
228
229 {
231 let mut state = self.inner.engine_state.write().unwrap();
232 *state = Some(engine_state.clone());
233 }
234
235 if let Some(cache) = &self.inner.engine_cache {
237 cache.insert(self.peer_addr(), engine_state);
238 }
239
240 Ok(())
241 }
242
243 pub(super) fn build_v3_message(&self, pdu: &Pdu) -> Result<(Vec<u8>, i32)> {
245 let security = self
246 .inner
247 .config
248 .v3_security
249 .as_ref()
250 .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
251
252 let engine_state = self.inner.engine_state.read().unwrap();
253 let engine_state = engine_state
254 .as_ref()
255 .ok_or_else(|| Error::encode(EncodeErrorKind::EngineNotDiscovered))?;
256
257 let derived = self.inner.derived_keys.read().unwrap();
258
259 let security_level = security.security_level();
260 let msg_id = pdu.request_id; let scoped_pdu = ScopedPdu::new(
264 engine_state.engine_id.clone(),
265 Bytes::new(), pdu.clone(),
267 );
268
269 let engine_boots = engine_state.engine_boots;
271 let engine_time = engine_state.estimated_time();
272
273 let (msg_data, priv_params) = if security_level.requires_priv() {
275 tracing::trace!("encrypting scoped PDU");
276
277 let derived_ref = derived
280 .as_ref()
281 .ok_or_else(|| Error::encode(EncodeErrorKind::KeysNotDerived))?;
282 let mut priv_key = derived_ref
283 .priv_key
284 .as_ref()
285 .ok_or_else(|| Error::encode(EncodeErrorKind::NoPrivKey))?
286 .clone();
287
288 let scoped_pdu_bytes = scoped_pdu.encode_to_bytes();
290
291 let (ciphertext, salt) = priv_key.encrypt(
293 &scoped_pdu_bytes,
294 engine_boots,
295 engine_time,
296 Some(&self.inner.salt_counter),
297 )?;
298
299 tracing::trace!(
300 plaintext_len = scoped_pdu_bytes.len(),
301 ciphertext_len = ciphertext.len(),
302 "encrypted scoped PDU"
303 );
304
305 (crate::message::V3MessageData::Encrypted(ciphertext), salt)
306 } else {
307 (
308 crate::message::V3MessageData::Plaintext(scoped_pdu),
309 Bytes::new(),
310 )
311 };
312
313 let mac_len = if security_level.requires_auth() {
315 derived
316 .as_ref()
317 .and_then(|d| d.auth_key.as_ref())
318 .map(|k| k.mac_len())
319 .unwrap_or(12)
320 } else {
321 0
322 };
323
324 let mut usm_params = UsmSecurityParams::new(
325 engine_state.engine_id.clone(),
326 engine_boots,
327 engine_time,
328 security.username.clone(),
329 );
330
331 if security_level.requires_auth() {
332 usm_params = usm_params.with_auth_placeholder(mac_len);
333 }
334
335 if security_level.requires_priv() {
336 usm_params = usm_params.with_priv_params(priv_params);
337 }
338
339 let usm_encoded = usm_params.encode();
340
341 let msg_flags = MsgFlags::new(security_level, true); let global_data = MsgGlobalData::new(msg_id, 65507, msg_flags);
344
345 let msg = match msg_data {
347 crate::message::V3MessageData::Plaintext(scoped_pdu) => {
348 V3Message::new(global_data, usm_encoded, scoped_pdu)
349 }
350 crate::message::V3MessageData::Encrypted(ciphertext) => {
351 V3Message::new_encrypted(global_data, usm_encoded, ciphertext)
352 }
353 };
354
355 let mut encoded = msg.encode().to_vec();
356
357 if security_level.requires_auth() {
359 tracing::trace!("applying HMAC authentication");
360
361 let auth_key = derived
362 .as_ref()
363 .and_then(|d| d.auth_key.as_ref())
364 .ok_or_else(|| Error::encode(EncodeErrorKind::MissingAuthKey))?;
365
366 if let Some((offset, len)) = UsmSecurityParams::find_auth_params_offset(&encoded) {
368 authenticate_message(auth_key, &mut encoded, offset, len);
369 tracing::trace!(
370 auth_params_offset = offset,
371 auth_params_len = len,
372 "applied HMAC authentication"
373 );
374 } else {
375 return Err(Error::encode(EncodeErrorKind::MissingAuthParams));
376 }
377 }
378
379 Ok((encoded, msg_id))
380 }
381
382 #[instrument(
384 level = "debug",
385 skip(self, pdu),
386 fields(
387 snmp.target = %self.peer_addr(),
388 snmp.request_id = pdu.request_id,
389 snmp.security_level = ?self.inner.config.v3_security.as_ref().map(|s| s.security_level()),
390 snmp.attempt = tracing::field::Empty,
391 snmp.elapsed_ms = tracing::field::Empty,
392 )
393 )]
394 pub(super) async fn send_v3_and_recv(&self, pdu: Pdu) -> Result<Pdu> {
395 let start = Instant::now();
396
397 self.ensure_engine_discovered().await?;
399
400 let security = self
401 .inner
402 .config
403 .v3_security
404 .as_ref()
405 .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
406 let security_level = security.security_level();
407
408 let mut last_error = None;
409 let max_attempts = if self.inner.transport.is_reliable() {
410 0
411 } else {
412 self.inner.config.retry.max_attempts
413 };
414
415 for attempt in 0..=max_attempts {
416 Span::current().record("snmp.attempt", attempt);
417 if attempt > 0 {
418 tracing::debug!("retrying V3 request");
419 }
420
421 let (data, msg_id) = self.build_v3_message(&pdu)?;
423
424 tracing::debug!(
425 snmp.pdu_type = ?pdu.pdu_type,
426 snmp.varbind_count = pdu.varbinds.len(),
427 snmp.msg_id = msg_id,
428 "sending V3 {} request",
429 pdu.pdu_type
430 );
431 tracing::trace!(snmp.bytes = data.len(), "sending V3 request");
432
433 self.inner
435 .transport
436 .register_request(msg_id, self.inner.config.timeout);
437
438 self.inner.transport.send(&data).await?;
440
441 match self.inner.transport.recv(msg_id).await {
443 Ok((response_data, _source)) => {
444 tracing::trace!(snmp.bytes = response_data.len(), "received V3 response");
445
446 if security_level.requires_auth() {
448 tracing::trace!("verifying HMAC authentication on response");
449
450 let derived = self.inner.derived_keys.read().unwrap();
451 let auth_key = derived
452 .as_ref()
453 .and_then(|d| d.auth_key.as_ref())
454 .ok_or_else(|| {
455 Error::auth(Some(self.peer_addr()), AuthErrorKind::NoAuthKey)
456 })?;
457
458 if let Some((offset, len)) =
459 UsmSecurityParams::find_auth_params_offset(&response_data)
460 {
461 if !verify_message(auth_key, &response_data, offset, len) {
462 tracing::trace!("HMAC verification failed");
463 return Err(Error::auth(
464 Some(self.peer_addr()),
465 AuthErrorKind::HmacMismatch,
466 ));
467 }
468 tracing::trace!(
469 auth_params_offset = offset,
470 auth_params_len = len,
471 "HMAC verification successful"
472 );
473 } else {
474 return Err(Error::auth(
475 Some(self.peer_addr()),
476 AuthErrorKind::AuthParamsNotFound,
477 ));
478 }
479 }
480
481 let response = V3Message::decode(response_data.clone())?;
483
484 if let Some(scoped_pdu) = response.scoped_pdu()
486 && scoped_pdu.pdu.pdu_type == PduType::Report
487 {
488 if is_not_in_time_window_report(&scoped_pdu.pdu) {
490 tracing::debug!("not in time window, resyncing");
491 let usm_params =
493 UsmSecurityParams::decode(response.security_params.clone())?;
494 {
495 let mut state = self.inner.engine_state.write().unwrap();
496 if let Some(ref mut s) = *state {
497 s.update_time(usm_params.engine_boots, usm_params.engine_time);
498 }
499 }
500 last_error = Some(Error::NotInTimeWindow {
501 target: Some(self.peer_addr()),
502 });
503 if attempt < max_attempts {
505 let delay = self.inner.config.retry.compute_delay(attempt);
506 if !delay.is_zero() {
507 tracing::debug!(
508 delay_ms = delay.as_millis() as u64,
509 "backing off"
510 );
511 tokio::time::sleep(delay).await;
512 }
513 }
514 continue;
515 }
516
517 if is_unknown_engine_id_report(&scoped_pdu.pdu) {
519 return Err(Error::UnknownEngineId {
520 target: Some(self.peer_addr()),
521 });
522 }
523
524 return Err(Error::Snmp {
526 target: Some(self.peer_addr()),
527 status: ErrorStatus::GenErr,
528 index: 0,
529 oid: scoped_pdu.pdu.varbinds.first().map(|vb| vb.oid.clone()),
530 });
531 }
532
533 let response_security_params = response.security_params.clone();
535
536 let response_pdu = if security_level.requires_priv() {
538 match response.data {
539 crate::message::V3MessageData::Encrypted(ciphertext) => {
540 tracing::trace!(
541 ciphertext_len = ciphertext.len(),
542 "decrypting response"
543 );
544
545 let derived = self.inner.derived_keys.read().unwrap();
547 let priv_key = derived
548 .as_ref()
549 .and_then(|d| d.priv_key.as_ref())
550 .ok_or_else(|| {
551 Error::decrypt(
552 Some(self.peer_addr()),
553 CryptoErrorKind::NoPrivKey,
554 )
555 })?;
556
557 let usm_params =
558 UsmSecurityParams::decode(response_security_params.clone())?;
559 let plaintext = priv_key.decrypt(
560 &ciphertext,
561 usm_params.engine_boots,
562 usm_params.engine_time,
563 &usm_params.priv_params,
564 )?;
565
566 tracing::trace!(
567 plaintext_len = plaintext.len(),
568 "decrypted response"
569 );
570
571 let mut decoder = Decoder::new(plaintext);
573 let scoped_pdu = ScopedPdu::decode(&mut decoder)?;
574 scoped_pdu.pdu
575 }
576 crate::message::V3MessageData::Plaintext(scoped_pdu) => scoped_pdu.pdu,
577 }
578 } else {
579 response
580 .into_pdu()
581 .ok_or_else(|| Error::decode(0, DecodeErrorKind::MissingPdu))?
582 };
583
584 if response_pdu.request_id != pdu.request_id {
586 return Err(Error::RequestIdMismatch {
587 expected: pdu.request_id,
588 actual: response_pdu.request_id,
589 });
590 }
591
592 tracing::debug!(
593 snmp.pdu_type = ?response_pdu.pdu_type,
594 snmp.varbind_count = response_pdu.varbinds.len(),
595 snmp.error_status = response_pdu.error_status,
596 snmp.error_index = response_pdu.error_index,
597 "received V3 {} response",
598 response_pdu.pdu_type
599 );
600
601 {
603 let usm_params = UsmSecurityParams::decode(response_security_params)?;
604 let mut state = self.inner.engine_state.write().unwrap();
605 if let Some(ref mut s) = *state {
606 s.update_time(usm_params.engine_boots, usm_params.engine_time);
607 }
608 }
609
610 if response_pdu.is_error() {
612 let status = response_pdu.error_status_enum();
613 let oid = (response_pdu.error_index as usize)
615 .checked_sub(1)
616 .and_then(|idx| response_pdu.varbinds.get(idx))
617 .map(|vb| vb.oid.clone());
618
619 Span::current()
620 .record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
621 return Err(Error::Snmp {
622 target: Some(self.peer_addr()),
623 status,
624 index: response_pdu.error_index as u32,
625 oid,
626 });
627 }
628
629 Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
630 return Ok(response_pdu);
631 }
632 Err(e @ Error::Timeout { .. }) => {
633 last_error = Some(e);
634 if attempt < max_attempts {
636 let delay = self.inner.config.retry.compute_delay(attempt);
637 if !delay.is_zero() {
638 tracing::debug!(delay_ms = delay.as_millis() as u64, "backing off");
639 tokio::time::sleep(delay).await;
640 }
641 }
642 continue;
643 }
644 Err(e) => {
645 Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
646 return Err(e);
647 }
648 }
649 }
650
651 Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
653 Err(last_error.unwrap_or(Error::Timeout {
654 target: Some(self.peer_addr()),
655 elapsed: start.elapsed(),
656 request_id: pdu.request_id,
657 retries: max_attempts,
658 }))
659 }
660}