async_snmp/client/
v3.rs

1//! SNMPv3-specific client functionality.
2//!
3//! This module contains V3 security configuration, key derivation, engine discovery,
4//! and V3 message building/handling.
5
6use crate::ber::Decoder;
7use crate::error::{
8    AuthErrorKind, CryptoErrorKind, DecodeErrorKind, EncodeErrorKind, Error, ErrorStatus, Result,
9};
10use crate::format::hex;
11use crate::message::{MsgFlags, MsgGlobalData, ScopedPdu, SecurityLevel, V3Message};
12use crate::pdu::{Pdu, PduType};
13use crate::transport::Transport;
14use crate::v3::{AuthProtocol, PrivProtocol};
15use crate::v3::{
16    LocalizedKey, PrivKey, UsmSecurityParams,
17    auth::{authenticate_message, verify_message},
18    is_not_in_time_window_report, is_unknown_engine_id_report,
19};
20use bytes::Bytes;
21use std::time::Instant;
22use tracing::{Span, instrument};
23
24use super::Client;
25
/// SNMPv3 security configuration.
///
/// Stores the credentials needed for authenticated and/or encrypted communication.
/// Localized keys are derived from these credentials once the remote engine ID is
/// known (see [`V3SecurityConfig::derive_keys`]).
///
/// The effective security level follows from which fields are set (see
/// [`V3SecurityConfig::security_level`]). Without `master_keys`, `auth` is
/// required for any security at all, and `privacy` only takes effect together
/// with `auth` (yielding authPriv).
///
/// # Master Key Caching
///
/// When polling many engines with shared credentials, use
/// [`MasterKeys`](crate::MasterKeys) to cache the expensive password-to-key
/// derivation. When `master_keys` is set, passwords are ignored and keys are
/// derived from the cached master keys.
#[derive(Clone)]
pub struct V3SecurityConfig {
    /// Username for USM authentication; placed in the USM security parameters
    /// of every outgoing V3 request.
    pub username: Bytes,
    /// Authentication protocol and password. The password is only used for key
    /// derivation when `master_keys` is `None`.
    pub auth: Option<(AuthProtocol, Vec<u8>)>,
    /// Privacy protocol and password. Ignored unless `auth` is also set, and
    /// (like `auth`) ignored entirely when `master_keys` is set.
    pub privacy: Option<(PrivProtocol, Vec<u8>)>,
    /// Pre-computed master keys for efficient key derivation. When `Some`, they
    /// take precedence over `auth`/`privacy` for both the security level and
    /// key derivation.
    pub master_keys: Option<crate::v3::MasterKeys>,
}
48
49impl V3SecurityConfig {
50    /// Create a new V3 security config with just a username (noAuthNoPriv).
51    pub fn new(username: impl Into<Bytes>) -> Self {
52        Self {
53            username: username.into(),
54            auth: None,
55            privacy: None,
56            master_keys: None,
57        }
58    }
59
60    /// Add authentication (authNoPriv or authPriv).
61    pub fn auth(mut self, protocol: AuthProtocol, password: impl Into<Vec<u8>>) -> Self {
62        self.auth = Some((protocol, password.into()));
63        self
64    }
65
66    /// Add privacy/encryption (authPriv).
67    pub fn privacy(mut self, protocol: PrivProtocol, password: impl Into<Vec<u8>>) -> Self {
68        self.privacy = Some((protocol, password.into()));
69        self
70    }
71
72    /// Use pre-computed master keys for efficient key derivation.
73    ///
74    /// When set, passwords are ignored and keys are derived from the cached
75    /// master keys. This avoids the expensive ~850μs password expansion for
76    /// each engine.
77    pub fn with_master_keys(mut self, master_keys: crate::v3::MasterKeys) -> Self {
78        self.master_keys = Some(master_keys);
79        self
80    }
81
82    /// Get the security level based on configured auth/privacy.
83    pub fn security_level(&self) -> SecurityLevel {
84        // Check master_keys first, then fall back to auth/privacy
85        if let Some(ref master_keys) = self.master_keys {
86            if master_keys.priv_protocol().is_some() {
87                return SecurityLevel::AuthPriv;
88            }
89            return SecurityLevel::AuthNoPriv;
90        }
91
92        match (&self.auth, &self.privacy) {
93            (None, _) => SecurityLevel::NoAuthNoPriv,
94            (Some(_), None) => SecurityLevel::AuthNoPriv,
95            (Some(_), Some(_)) => SecurityLevel::AuthPriv,
96        }
97    }
98
99    /// Derive localized keys for a specific engine ID.
100    ///
101    /// If master keys are configured, uses the cached master keys for efficient
102    /// localization (~1μs). Otherwise, performs full password-to-key derivation
103    /// (~850μs for SHA-256).
104    pub fn derive_keys(&self, engine_id: &[u8]) -> V3DerivedKeys {
105        // Use master keys if available (efficient path)
106        if let Some(ref master_keys) = self.master_keys {
107            tracing::trace!(
108                engine_id_len = engine_id.len(),
109                auth_protocol = ?master_keys.auth_protocol(),
110                priv_protocol = ?master_keys.priv_protocol(),
111                "localizing from cached master keys"
112            );
113            let (auth_key, priv_key) = master_keys.localize(engine_id);
114            tracing::trace!("key localization complete");
115            return V3DerivedKeys {
116                auth_key: Some(auth_key),
117                priv_key,
118            };
119        }
120
121        // Fall back to password-based derivation
122        tracing::trace!(
123            engine_id_len = engine_id.len(),
124            has_auth = self.auth.is_some(),
125            has_priv = self.privacy.is_some(),
126            "deriving localized keys from passwords"
127        );
128
129        let auth_key = self.auth.as_ref().map(|(protocol, password)| {
130            tracing::trace!(auth_protocol = ?protocol, "deriving auth key");
131            LocalizedKey::from_password(*protocol, password, engine_id)
132        });
133
134        let priv_key = match (&self.auth, &self.privacy) {
135            (Some((auth_protocol, _)), Some((priv_protocol, priv_password))) => {
136                tracing::trace!(priv_protocol = ?priv_protocol, "deriving privacy key");
137                Some(PrivKey::from_password(
138                    *auth_protocol,
139                    *priv_protocol,
140                    priv_password,
141                    engine_id,
142                ))
143            }
144            _ => None,
145        };
146
147        tracing::trace!("key derivation complete");
148        V3DerivedKeys { auth_key, priv_key }
149    }
150}
151
152impl std::fmt::Debug for V3SecurityConfig {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        f.debug_struct("V3SecurityConfig")
155            .field("username", &String::from_utf8_lossy(&self.username))
156            .field("auth", &self.auth.as_ref().map(|(p, _)| p))
157            .field("privacy", &self.privacy.as_ref().map(|(p, _)| p))
158            .finish()
159    }
160}
161
162/// Derived keys for a specific engine ID.
163pub struct V3DerivedKeys {
164    pub auth_key: Option<LocalizedKey>,
165    pub priv_key: Option<PrivKey>,
166}
167
168// V3-specific Client implementation
169impl<T: Transport> Client<T> {
170    /// Ensure engine ID is discovered for V3 operations.
171    #[instrument(level = "debug", skip(self), fields(snmp.target = %self.peer_addr()))]
172    pub(super) async fn ensure_engine_discovered(&self) -> Result<()> {
173        // Check if already discovered
174        {
175            let state = self.inner.engine_state.read().unwrap();
176            if state.is_some() {
177                return Ok(());
178            }
179        }
180
181        // Check shared cache first
182        if let Some(cache) = &self.inner.engine_cache
183            && let Some(cached_state) = cache.get(&self.peer_addr())
184        {
185            tracing::debug!("using cached engine state");
186            let mut state = self.inner.engine_state.write().unwrap();
187            *state = Some(cached_state.clone());
188            // Derive keys for this engine
189            if let Some(security) = &self.inner.config.v3_security {
190                let keys = security.derive_keys(&cached_state.engine_id);
191                let mut derived = self.inner.derived_keys.write().unwrap();
192                *derived = Some(keys);
193            }
194            return Ok(());
195        }
196
197        // Perform discovery
198        tracing::debug!("performing engine discovery");
199        let msg_id = self.next_request_id();
200        let discovery_msg = V3Message::discovery_request(msg_id);
201        let discovery_data = discovery_msg.encode();
202
203        // Register request and send discovery
204        self.inner
205            .transport
206            .register_request(msg_id, self.inner.config.timeout);
207        self.inner.transport.send(&discovery_data).await?;
208        let (response_data, _source) = self.inner.transport.recv(msg_id).await?;
209
210        // Parse response
211        let response = V3Message::decode(response_data)?;
212
213        let reported_msg_max_size = response.global_data.msg_max_size as u32;
214        let session_max = self.inner.transport.max_message_size();
215        let engine_state = crate::v3::parse_discovery_response_with_limits(
216            &response.security_params,
217            reported_msg_max_size,
218            session_max,
219        )?;
220        tracing::debug!(
221            snmp.engine_id = %hex::Bytes(&engine_state.engine_id),
222            snmp.engine_boots = engine_state.engine_boots,
223            snmp.engine_time = engine_state.engine_time,
224            snmp.msg_max_size = engine_state.msg_max_size,
225            "discovered engine"
226        );
227
228        // Derive keys for this engine
229        if let Some(security) = &self.inner.config.v3_security {
230            let keys = security.derive_keys(&engine_state.engine_id);
231            let mut derived = self.inner.derived_keys.write().unwrap();
232            *derived = Some(keys);
233        }
234
235        // Store in local cache
236        {
237            let mut state = self.inner.engine_state.write().unwrap();
238            *state = Some(engine_state.clone());
239        }
240
241        // Store in shared cache if present
242        if let Some(cache) = &self.inner.engine_cache {
243            cache.insert(self.peer_addr(), engine_state);
244        }
245
246        Ok(())
247    }
248
249    /// Build and encode a V3 message with authentication and/or encryption.
250    ///
251    /// The `msg_id` parameter is separate from `pdu.request_id` per RFC 3412
252    /// Section 6.2: retransmissions SHOULD use a new msgID for each attempt.
253    pub(super) fn build_v3_message(&self, pdu: &Pdu, msg_id: i32) -> Result<Vec<u8>> {
254        let security = self
255            .inner
256            .config
257            .v3_security
258            .as_ref()
259            .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
260
261        let engine_state = self.inner.engine_state.read().unwrap();
262        let engine_state = engine_state
263            .as_ref()
264            .ok_or_else(|| Error::encode(EncodeErrorKind::EngineNotDiscovered))?;
265
266        let derived = self.inner.derived_keys.read().unwrap();
267
268        let security_level = security.security_level();
269
270        // Build scoped PDU
271        let scoped_pdu = ScopedPdu::new(
272            engine_state.engine_id.clone(),
273            Bytes::new(), // empty context name
274            pdu.clone(),
275        );
276
277        // Get current engine time estimate
278        let engine_boots = engine_state.engine_boots;
279        let engine_time = engine_state.estimated_time();
280
281        // Handle encryption if needed
282        let (msg_data, priv_params) = if security_level.requires_priv() {
283            tracing::trace!("encrypting scoped PDU");
284
285            // Get mutable priv_key - we need interior mutability for salt counter
286            // Since PrivKey uses internal counter, we need to clone and use
287            let derived_ref = derived
288                .as_ref()
289                .ok_or_else(|| Error::encode(EncodeErrorKind::KeysNotDerived))?;
290            let mut priv_key = derived_ref
291                .priv_key
292                .as_ref()
293                .ok_or_else(|| Error::encode(EncodeErrorKind::NoPrivKey))?
294                .clone();
295
296            // Encode scoped PDU
297            let scoped_pdu_bytes = scoped_pdu.encode_to_bytes();
298
299            // Encrypt
300            let (ciphertext, salt) = priv_key.encrypt(
301                &scoped_pdu_bytes,
302                engine_boots,
303                engine_time,
304                Some(&self.inner.salt_counter),
305            )?;
306
307            tracing::trace!(
308                plaintext_len = scoped_pdu_bytes.len(),
309                ciphertext_len = ciphertext.len(),
310                "encrypted scoped PDU"
311            );
312
313            (crate::message::V3MessageData::Encrypted(ciphertext), salt)
314        } else {
315            (
316                crate::message::V3MessageData::Plaintext(scoped_pdu),
317                Bytes::new(),
318            )
319        };
320
321        // Build USM security parameters
322        let mac_len = if security_level.requires_auth() {
323            derived
324                .as_ref()
325                .and_then(|d| d.auth_key.as_ref())
326                .map(|k| k.mac_len())
327                .unwrap_or(12)
328        } else {
329            0
330        };
331
332        let mut usm_params = UsmSecurityParams::new(
333            engine_state.engine_id.clone(),
334            engine_boots,
335            engine_time,
336            security.username.clone(),
337        );
338
339        if security_level.requires_auth() {
340            usm_params = usm_params.with_auth_placeholder(mac_len);
341        }
342
343        if security_level.requires_priv() {
344            usm_params = usm_params.with_priv_params(priv_params);
345        }
346
347        let usm_encoded = usm_params.encode();
348
349        // Build global data
350        let msg_flags = MsgFlags::new(security_level, true); // reportable=true for requests
351        let global_data = MsgGlobalData::new(msg_id, 65507, msg_flags);
352
353        // Build complete message
354        let msg = match msg_data {
355            crate::message::V3MessageData::Plaintext(scoped_pdu) => {
356                V3Message::new(global_data, usm_encoded, scoped_pdu)
357            }
358            crate::message::V3MessageData::Encrypted(ciphertext) => {
359                V3Message::new_encrypted(global_data, usm_encoded, ciphertext)
360            }
361        };
362
363        let mut encoded = msg.encode().to_vec();
364
365        // Apply authentication if needed
366        if security_level.requires_auth() {
367            tracing::trace!("applying HMAC authentication");
368
369            let auth_key = derived
370                .as_ref()
371                .and_then(|d| d.auth_key.as_ref())
372                .ok_or_else(|| Error::encode(EncodeErrorKind::MissingAuthKey))?;
373
374            // Find auth params position and apply HMAC
375            if let Some((offset, len)) = UsmSecurityParams::find_auth_params_offset(&encoded) {
376                authenticate_message(auth_key, &mut encoded, offset, len);
377                tracing::trace!(
378                    auth_params_offset = offset,
379                    auth_params_len = len,
380                    "applied HMAC authentication"
381                );
382            } else {
383                return Err(Error::encode(EncodeErrorKind::MissingAuthParams));
384            }
385        }
386
387        Ok(encoded)
388    }
389
390    /// Send a V3 request and handle the response.
391    #[instrument(
392        level = "debug",
393        skip(self, pdu),
394        fields(
395            snmp.target = %self.peer_addr(),
396            snmp.request_id = pdu.request_id,
397            snmp.security_level = ?self.inner.config.v3_security.as_ref().map(|s| s.security_level()),
398            snmp.attempt = tracing::field::Empty,
399            snmp.elapsed_ms = tracing::field::Empty,
400        )
401    )]
402    pub(super) async fn send_v3_and_recv(&self, pdu: Pdu) -> Result<Pdu> {
403        let start = Instant::now();
404
405        // Ensure engine is discovered first
406        self.ensure_engine_discovered().await?;
407
408        let security = self
409            .inner
410            .config
411            .v3_security
412            .as_ref()
413            .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
414        let security_level = security.security_level();
415
416        let mut last_error = None;
417        let max_attempts = if self.inner.transport.is_reliable() {
418            0
419        } else {
420            self.inner.config.retry.max_attempts
421        };
422
423        for attempt in 0..=max_attempts {
424            Span::current().record("snmp.attempt", attempt);
425            if attempt > 0 {
426                tracing::debug!("retrying V3 request");
427            }
428
429            // RFC 3412 Section 6.2: use fresh msgID for each transmission attempt
430            let msg_id = self.next_request_id();
431            let data = self.build_v3_message(&pdu, msg_id)?;
432
433            tracing::debug!(
434                snmp.pdu_type = ?pdu.pdu_type,
435                snmp.varbind_count = pdu.varbinds.len(),
436                snmp.msg_id = msg_id,
437                "sending V3 {} request",
438                pdu.pdu_type
439            );
440            tracing::trace!(snmp.bytes = data.len(), "sending V3 request");
441
442            // Register (or re-register) with fresh deadline before sending
443            self.inner
444                .transport
445                .register_request(msg_id, self.inner.config.timeout);
446
447            // Send request
448            self.inner.transport.send(&data).await?;
449
450            // Wait for response (deadline was set by register_request)
451            match self.inner.transport.recv(msg_id).await {
452                Ok((response_data, _source)) => {
453                    tracing::trace!(snmp.bytes = response_data.len(), "received V3 response");
454
455                    // Verify authentication if required
456                    if security_level.requires_auth() {
457                        tracing::trace!("verifying HMAC authentication on response");
458
459                        let derived = self.inner.derived_keys.read().unwrap();
460                        let auth_key = derived
461                            .as_ref()
462                            .and_then(|d| d.auth_key.as_ref())
463                            .ok_or_else(|| {
464                                Error::auth(Some(self.peer_addr()), AuthErrorKind::NoAuthKey)
465                            })?;
466
467                        if let Some((offset, len)) =
468                            UsmSecurityParams::find_auth_params_offset(&response_data)
469                        {
470                            if !verify_message(auth_key, &response_data, offset, len) {
471                                tracing::trace!("HMAC verification failed");
472                                return Err(Error::auth(
473                                    Some(self.peer_addr()),
474                                    AuthErrorKind::HmacMismatch,
475                                ));
476                            }
477                            tracing::trace!(
478                                auth_params_offset = offset,
479                                auth_params_len = len,
480                                "HMAC verification successful"
481                            );
482                        } else {
483                            return Err(Error::auth(
484                                Some(self.peer_addr()),
485                                AuthErrorKind::AuthParamsNotFound,
486                            ));
487                        }
488                    }
489
490                    // Decode response
491                    let response = V3Message::decode(response_data.clone())?;
492
493                    // Check for Report PDU (error response)
494                    if let Some(scoped_pdu) = response.scoped_pdu()
495                        && scoped_pdu.pdu.pdu_type == PduType::Report
496                    {
497                        // Check for time window error - resync and retry
498                        if is_not_in_time_window_report(&scoped_pdu.pdu) {
499                            tracing::debug!("not in time window, resyncing");
500                            // Update engine time from response
501                            let usm_params =
502                                UsmSecurityParams::decode(response.security_params.clone())?;
503                            {
504                                let mut state = self.inner.engine_state.write().unwrap();
505                                if let Some(ref mut s) = *state {
506                                    s.update_time(usm_params.engine_boots, usm_params.engine_time);
507                                }
508                            }
509                            last_error = Some(Error::NotInTimeWindow {
510                                target: Some(self.peer_addr()),
511                            });
512                            // Apply backoff delay before retry (if not last attempt)
513                            if attempt < max_attempts {
514                                let delay = self.inner.config.retry.compute_delay(attempt);
515                                if !delay.is_zero() {
516                                    tracing::debug!(
517                                        delay_ms = delay.as_millis() as u64,
518                                        "backing off"
519                                    );
520                                    tokio::time::sleep(delay).await;
521                                }
522                            }
523                            continue;
524                        }
525
526                        // Check for unknown engine ID
527                        if is_unknown_engine_id_report(&scoped_pdu.pdu) {
528                            return Err(Error::UnknownEngineId {
529                                target: Some(self.peer_addr()),
530                            });
531                        }
532
533                        // Other Report errors
534                        return Err(Error::Snmp {
535                            target: Some(self.peer_addr()),
536                            status: ErrorStatus::GenErr,
537                            index: 0,
538                            oid: scoped_pdu.pdu.varbinds.first().map(|vb| vb.oid.clone()),
539                        });
540                    }
541
542                    // Extract security params before consuming response
543                    let response_security_params = response.security_params.clone();
544
545                    // Handle encrypted response
546                    let response_pdu = if security_level.requires_priv() {
547                        match response.data {
548                            crate::message::V3MessageData::Encrypted(ciphertext) => {
549                                tracing::trace!(
550                                    ciphertext_len = ciphertext.len(),
551                                    "decrypting response"
552                                );
553
554                                // Decrypt
555                                let derived = self.inner.derived_keys.read().unwrap();
556                                let priv_key = derived
557                                    .as_ref()
558                                    .and_then(|d| d.priv_key.as_ref())
559                                    .ok_or_else(|| {
560                                    Error::decrypt(
561                                        Some(self.peer_addr()),
562                                        CryptoErrorKind::NoPrivKey,
563                                    )
564                                })?;
565
566                                let usm_params =
567                                    UsmSecurityParams::decode(response_security_params.clone())?;
568                                let plaintext = priv_key.decrypt(
569                                    &ciphertext,
570                                    usm_params.engine_boots,
571                                    usm_params.engine_time,
572                                    &usm_params.priv_params,
573                                )?;
574
575                                tracing::trace!(
576                                    plaintext_len = plaintext.len(),
577                                    "decrypted response"
578                                );
579
580                                // Decode scoped PDU
581                                let mut decoder = Decoder::new(plaintext);
582                                let scoped_pdu = ScopedPdu::decode(&mut decoder)?;
583                                scoped_pdu.pdu
584                            }
585                            crate::message::V3MessageData::Plaintext(scoped_pdu) => scoped_pdu.pdu,
586                        }
587                    } else {
588                        response
589                            .into_pdu()
590                            .ok_or_else(|| Error::decode(0, DecodeErrorKind::MissingPdu))?
591                    };
592
593                    // Validate request ID
594                    if response_pdu.request_id != pdu.request_id {
595                        return Err(Error::RequestIdMismatch {
596                            expected: pdu.request_id,
597                            actual: response_pdu.request_id,
598                        });
599                    }
600
601                    tracing::debug!(
602                        snmp.pdu_type = ?response_pdu.pdu_type,
603                        snmp.varbind_count = response_pdu.varbinds.len(),
604                        snmp.error_status = response_pdu.error_status,
605                        snmp.error_index = response_pdu.error_index,
606                        "received V3 {} response",
607                        response_pdu.pdu_type
608                    );
609
610                    // Update engine time from successful response
611                    {
612                        let usm_params = UsmSecurityParams::decode(response_security_params)?;
613                        let mut state = self.inner.engine_state.write().unwrap();
614                        if let Some(ref mut s) = *state {
615                            s.update_time(usm_params.engine_boots, usm_params.engine_time);
616                        }
617                    }
618
619                    // Check for SNMP error
620                    if response_pdu.is_error() {
621                        let status = response_pdu.error_status_enum();
622                        // error_index is 1-based; 0 means error applies to PDU, not a specific varbind
623                        let oid = (response_pdu.error_index as usize)
624                            .checked_sub(1)
625                            .and_then(|idx| response_pdu.varbinds.get(idx))
626                            .map(|vb| vb.oid.clone());
627
628                        Span::current()
629                            .record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
630                        return Err(Error::Snmp {
631                            target: Some(self.peer_addr()),
632                            status,
633                            index: response_pdu.error_index as u32,
634                            oid,
635                        });
636                    }
637
638                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
639                    return Ok(response_pdu);
640                }
641                Err(e @ Error::Timeout { .. }) => {
642                    last_error = Some(e);
643                    // Apply backoff delay before next retry (if not last attempt)
644                    if attempt < max_attempts {
645                        let delay = self.inner.config.retry.compute_delay(attempt);
646                        if !delay.is_zero() {
647                            tracing::debug!(delay_ms = delay.as_millis() as u64, "backing off");
648                            tokio::time::sleep(delay).await;
649                        }
650                    }
651                    continue;
652                }
653                Err(e) => {
654                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
655                    return Err(e);
656                }
657            }
658        }
659
660        // All retries exhausted
661        Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
662        Err(last_error.unwrap_or(Error::Timeout {
663            target: Some(self.peer_addr()),
664            elapsed: start.elapsed(),
665            request_id: pdu.request_id,
666            retries: max_attempts,
667        }))
668    }
669}