async_snmp/client/
v3.rs

1//! SNMPv3-specific client functionality.
2//!
3//! This module contains V3 security configuration, key derivation, engine discovery,
4//! and V3 message building/handling.
5
6use crate::ber::Decoder;
7use crate::error::{
8    AuthErrorKind, CryptoErrorKind, DecodeErrorKind, EncodeErrorKind, Error, ErrorStatus, Result,
9};
10use crate::format::hex;
11use crate::message::{MsgFlags, MsgGlobalData, ScopedPdu, SecurityLevel, V3Message};
12use crate::pdu::{Pdu, PduType};
13use crate::transport::Transport;
14use crate::v3::{AuthProtocol, PrivProtocol};
15use crate::v3::{
16    LocalizedKey, PrivKey, UsmSecurityParams,
17    auth::{authenticate_message, verify_message},
18    is_not_in_time_window_report, is_unknown_engine_id_report,
19};
20use bytes::Bytes;
21use std::time::Instant;
22use tracing::{Span, instrument};
23
24use super::Client;
25
26/// SNMPv3 security configuration.
27///
28/// Stores the credentials needed for authenticated and/or encrypted communication.
29/// Keys are derived when the engine ID is discovered.
30///
31/// # Master Key Caching
32///
33/// For high-throughput polling of many engines with shared credentials, use
34/// [`MasterKeys`](crate::MasterKeys) to cache the expensive password-to-key
35/// derivation. When `master_keys` is set, passwords are ignored and keys are
36/// derived from the cached master keys.
37#[derive(Clone)]
38pub struct V3SecurityConfig {
39    /// Username for USM authentication
40    pub username: Bytes,
41    /// Authentication protocol and password
42    pub auth: Option<(AuthProtocol, Vec<u8>)>,
43    /// Privacy protocol and password
44    pub privacy: Option<(PrivProtocol, Vec<u8>)>,
45    /// Pre-computed master keys for efficient key derivation
46    pub master_keys: Option<crate::v3::MasterKeys>,
47}
48
49impl V3SecurityConfig {
50    /// Create a new V3 security config with just a username (noAuthNoPriv).
51    pub fn new(username: impl Into<Bytes>) -> Self {
52        Self {
53            username: username.into(),
54            auth: None,
55            privacy: None,
56            master_keys: None,
57        }
58    }
59
60    /// Add authentication (authNoPriv or authPriv).
61    pub fn auth(mut self, protocol: AuthProtocol, password: impl Into<Vec<u8>>) -> Self {
62        self.auth = Some((protocol, password.into()));
63        self
64    }
65
66    /// Add privacy/encryption (authPriv).
67    pub fn privacy(mut self, protocol: PrivProtocol, password: impl Into<Vec<u8>>) -> Self {
68        self.privacy = Some((protocol, password.into()));
69        self
70    }
71
72    /// Use pre-computed master keys for efficient key derivation.
73    ///
74    /// When set, passwords are ignored and keys are derived from the cached
75    /// master keys. This avoids the expensive ~850μs password expansion for
76    /// each engine.
77    pub fn with_master_keys(mut self, master_keys: crate::v3::MasterKeys) -> Self {
78        self.master_keys = Some(master_keys);
79        self
80    }
81
82    /// Get the security level based on configured auth/privacy.
83    pub fn security_level(&self) -> SecurityLevel {
84        // Check master_keys first, then fall back to auth/privacy
85        if let Some(ref master_keys) = self.master_keys {
86            if master_keys.priv_protocol().is_some() {
87                return SecurityLevel::AuthPriv;
88            }
89            return SecurityLevel::AuthNoPriv;
90        }
91
92        match (&self.auth, &self.privacy) {
93            (None, _) => SecurityLevel::NoAuthNoPriv,
94            (Some(_), None) => SecurityLevel::AuthNoPriv,
95            (Some(_), Some(_)) => SecurityLevel::AuthPriv,
96        }
97    }
98
99    /// Derive localized keys for a specific engine ID.
100    ///
101    /// If master keys are configured, uses the cached master keys for efficient
102    /// localization (~1μs). Otherwise, performs full password-to-key derivation
103    /// (~850μs for SHA-256).
104    pub fn derive_keys(&self, engine_id: &[u8]) -> V3DerivedKeys {
105        // Use master keys if available (efficient path)
106        if let Some(ref master_keys) = self.master_keys {
107            tracing::trace!(
108                engine_id_len = engine_id.len(),
109                auth_protocol = ?master_keys.auth_protocol(),
110                priv_protocol = ?master_keys.priv_protocol(),
111                "localizing from cached master keys"
112            );
113            let (auth_key, priv_key) = master_keys.localize(engine_id);
114            tracing::trace!("key localization complete");
115            return V3DerivedKeys {
116                auth_key: Some(auth_key),
117                priv_key,
118            };
119        }
120
121        // Fall back to password-based derivation
122        tracing::trace!(
123            engine_id_len = engine_id.len(),
124            has_auth = self.auth.is_some(),
125            has_priv = self.privacy.is_some(),
126            "deriving localized keys from passwords"
127        );
128
129        let auth_key = self.auth.as_ref().map(|(protocol, password)| {
130            tracing::trace!(auth_protocol = ?protocol, "deriving auth key");
131            LocalizedKey::from_password(*protocol, password, engine_id)
132        });
133
134        let priv_key = match (&self.auth, &self.privacy) {
135            (Some((auth_protocol, _)), Some((priv_protocol, priv_password))) => {
136                tracing::trace!(priv_protocol = ?priv_protocol, "deriving privacy key");
137                Some(PrivKey::from_password(
138                    *auth_protocol,
139                    *priv_protocol,
140                    priv_password,
141                    engine_id,
142                ))
143            }
144            _ => None,
145        };
146
147        tracing::trace!("key derivation complete");
148        V3DerivedKeys { auth_key, priv_key }
149    }
150}
151
152impl std::fmt::Debug for V3SecurityConfig {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        f.debug_struct("V3SecurityConfig")
155            .field("username", &String::from_utf8_lossy(&self.username))
156            .field("auth", &self.auth.as_ref().map(|(p, _)| p))
157            .field("privacy", &self.privacy.as_ref().map(|(p, _)| p))
158            .finish()
159    }
160}
161
162/// Derived keys for a specific engine ID.
163pub struct V3DerivedKeys {
164    pub auth_key: Option<LocalizedKey>,
165    pub priv_key: Option<PrivKey>,
166}
167
168// V3-specific Client implementation
169impl<T: Transport> Client<T> {
170    /// Ensure engine ID is discovered for V3 operations.
171    #[instrument(level = "debug", skip(self), fields(snmp.target = %self.peer_addr()))]
172    pub(super) async fn ensure_engine_discovered(&self) -> Result<()> {
173        // Check if already discovered
174        {
175            let state = self.inner.engine_state.read().unwrap();
176            if state.is_some() {
177                return Ok(());
178            }
179        }
180
181        // Check shared cache first
182        if let Some(cache) = &self.inner.engine_cache
183            && let Some(cached_state) = cache.get(&self.peer_addr())
184        {
185            tracing::debug!("using cached engine state");
186            let mut state = self.inner.engine_state.write().unwrap();
187            *state = Some(cached_state.clone());
188            // Derive keys for this engine
189            if let Some(security) = &self.inner.config.v3_security {
190                let keys = security.derive_keys(&cached_state.engine_id);
191                let mut derived = self.inner.derived_keys.write().unwrap();
192                *derived = Some(keys);
193            }
194            return Ok(());
195        }
196
197        // Perform discovery
198        tracing::debug!("performing engine discovery");
199        let msg_id = self.next_request_id();
200        let discovery_msg = V3Message::discovery_request(msg_id);
201        let discovery_data = discovery_msg.encode();
202
203        // Send discovery and wait for response
204        self.inner.transport.send(&discovery_data).await?;
205        let (response_data, _source) = self
206            .inner
207            .transport
208            .recv(msg_id, self.inner.config.timeout)
209            .await?;
210
211        // Parse response
212        let response = V3Message::decode(response_data)?;
213
214        // Extract engine state from USM params
215        let engine_state = crate::v3::parse_discovery_response(&response.security_params)?;
216        tracing::debug!(
217            snmp.engine_id = %hex::Bytes(&engine_state.engine_id),
218            snmp.engine_boots = engine_state.engine_boots,
219            snmp.engine_time = engine_state.engine_time,
220            "discovered engine"
221        );
222
223        // Derive keys for this engine
224        if let Some(security) = &self.inner.config.v3_security {
225            let keys = security.derive_keys(&engine_state.engine_id);
226            let mut derived = self.inner.derived_keys.write().unwrap();
227            *derived = Some(keys);
228        }
229
230        // Store in local cache
231        {
232            let mut state = self.inner.engine_state.write().unwrap();
233            *state = Some(engine_state.clone());
234        }
235
236        // Store in shared cache if present
237        if let Some(cache) = &self.inner.engine_cache {
238            cache.insert(self.peer_addr(), engine_state);
239        }
240
241        Ok(())
242    }
243
244    /// Build and encode a V3 message with authentication and/or encryption.
245    pub(super) fn build_v3_message(&self, pdu: &Pdu) -> Result<(Vec<u8>, i32)> {
246        let security = self
247            .inner
248            .config
249            .v3_security
250            .as_ref()
251            .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
252
253        let engine_state = self.inner.engine_state.read().unwrap();
254        let engine_state = engine_state
255            .as_ref()
256            .ok_or_else(|| Error::encode(EncodeErrorKind::EngineNotDiscovered))?;
257
258        let derived = self.inner.derived_keys.read().unwrap();
259
260        let security_level = security.security_level();
261        let msg_id = pdu.request_id; // Use request_id as msg_id for correlation
262
263        // Build scoped PDU
264        let scoped_pdu = ScopedPdu::new(
265            engine_state.engine_id.clone(),
266            Bytes::new(), // empty context name
267            pdu.clone(),
268        );
269
270        // Get current engine time estimate
271        let engine_boots = engine_state.engine_boots;
272        let engine_time = engine_state.estimated_time();
273
274        // Handle encryption if needed
275        let (msg_data, priv_params) = if security_level.requires_priv() {
276            tracing::trace!("encrypting scoped PDU");
277
278            // Get mutable priv_key - we need interior mutability for salt counter
279            // Since PrivKey uses internal counter, we need to clone and use
280            let derived_ref = derived
281                .as_ref()
282                .ok_or_else(|| Error::encode(EncodeErrorKind::KeysNotDerived))?;
283            let mut priv_key = derived_ref
284                .priv_key
285                .as_ref()
286                .ok_or_else(|| Error::encode(EncodeErrorKind::NoPrivKey))?
287                .clone();
288
289            // Encode scoped PDU
290            let scoped_pdu_bytes = scoped_pdu.encode_to_bytes();
291
292            // Encrypt
293            let (ciphertext, salt) = priv_key.encrypt(
294                &scoped_pdu_bytes,
295                engine_boots,
296                engine_time,
297                Some(&self.inner.salt_counter),
298            )?;
299
300            tracing::trace!(
301                plaintext_len = scoped_pdu_bytes.len(),
302                ciphertext_len = ciphertext.len(),
303                "encrypted scoped PDU"
304            );
305
306            (crate::message::V3MessageData::Encrypted(ciphertext), salt)
307        } else {
308            (
309                crate::message::V3MessageData::Plaintext(scoped_pdu),
310                Bytes::new(),
311            )
312        };
313
314        // Build USM security parameters
315        let mac_len = if security_level.requires_auth() {
316            derived
317                .as_ref()
318                .and_then(|d| d.auth_key.as_ref())
319                .map(|k| k.mac_len())
320                .unwrap_or(12)
321        } else {
322            0
323        };
324
325        let mut usm_params = UsmSecurityParams::new(
326            engine_state.engine_id.clone(),
327            engine_boots,
328            engine_time,
329            security.username.clone(),
330        );
331
332        if security_level.requires_auth() {
333            usm_params = usm_params.with_auth_placeholder(mac_len);
334        }
335
336        if security_level.requires_priv() {
337            usm_params = usm_params.with_priv_params(priv_params);
338        }
339
340        let usm_encoded = usm_params.encode();
341
342        // Build global data
343        let msg_flags = MsgFlags::new(security_level, true); // reportable=true for requests
344        let global_data = MsgGlobalData::new(msg_id, 65507, msg_flags);
345
346        // Build complete message
347        let msg = match msg_data {
348            crate::message::V3MessageData::Plaintext(scoped_pdu) => {
349                V3Message::new(global_data, usm_encoded, scoped_pdu)
350            }
351            crate::message::V3MessageData::Encrypted(ciphertext) => {
352                V3Message::new_encrypted(global_data, usm_encoded, ciphertext)
353            }
354        };
355
356        let mut encoded = msg.encode().to_vec();
357
358        // Apply authentication if needed
359        if security_level.requires_auth() {
360            tracing::trace!("applying HMAC authentication");
361
362            let auth_key = derived
363                .as_ref()
364                .and_then(|d| d.auth_key.as_ref())
365                .ok_or_else(|| Error::encode(EncodeErrorKind::MissingAuthKey))?;
366
367            // Find auth params position and apply HMAC
368            if let Some((offset, len)) = UsmSecurityParams::find_auth_params_offset(&encoded) {
369                authenticate_message(auth_key, &mut encoded, offset, len);
370                tracing::trace!(
371                    auth_params_offset = offset,
372                    auth_params_len = len,
373                    "applied HMAC authentication"
374                );
375            } else {
376                return Err(Error::encode(EncodeErrorKind::MissingAuthParams));
377            }
378        }
379
380        Ok((encoded, msg_id))
381    }
382
383    /// Send a V3 request and handle the response.
384    #[instrument(
385        level = "debug",
386        skip(self, pdu),
387        fields(
388            snmp.target = %self.peer_addr(),
389            snmp.request_id = pdu.request_id,
390            snmp.security_level = ?self.inner.config.v3_security.as_ref().map(|s| s.security_level()),
391            snmp.retries = tracing::field::Empty,
392            snmp.elapsed_ms = tracing::field::Empty,
393        )
394    )]
395    pub(super) async fn send_v3_and_recv(&self, pdu: Pdu) -> Result<Pdu> {
396        let start = Instant::now();
397
398        // Ensure engine is discovered first
399        self.ensure_engine_discovered().await?;
400
401        let security = self
402            .inner
403            .config
404            .v3_security
405            .as_ref()
406            .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
407        let security_level = security.security_level();
408
409        let mut last_error = None;
410        let retries = if self.inner.transport.is_stream() {
411            0
412        } else {
413            self.inner.config.retries
414        };
415
416        for attempt in 0..=retries {
417            Span::current().record("snmp.retries", attempt);
418            if attempt > 0 {
419                tracing::debug!("retrying V3 request");
420            }
421
422            // Build message (may need fresh timestamps on retry)
423            let (data, msg_id) = self.build_v3_message(&pdu)?;
424
425            tracing::debug!(
426                snmp.pdu_type = ?pdu.pdu_type,
427                snmp.varbind_count = pdu.varbinds.len(),
428                snmp.msg_id = msg_id,
429                "sending V3 {} request",
430                pdu.pdu_type
431            );
432            tracing::trace!(snmp.bytes = data.len(), "sending V3 request");
433
434            // Send request
435            self.inner.transport.send(&data).await?;
436
437            // Wait for response
438            match self
439                .inner
440                .transport
441                .recv(msg_id, self.inner.config.timeout)
442                .await
443            {
444                Ok((response_data, _source)) => {
445                    tracing::trace!(snmp.bytes = response_data.len(), "received V3 response");
446
447                    // Verify authentication if required
448                    if security_level.requires_auth() {
449                        tracing::trace!("verifying HMAC authentication on response");
450
451                        let derived = self.inner.derived_keys.read().unwrap();
452                        let auth_key = derived
453                            .as_ref()
454                            .and_then(|d| d.auth_key.as_ref())
455                            .ok_or_else(|| {
456                                Error::auth(Some(self.peer_addr()), AuthErrorKind::NoAuthKey)
457                            })?;
458
459                        if let Some((offset, len)) =
460                            UsmSecurityParams::find_auth_params_offset(&response_data)
461                        {
462                            if !verify_message(auth_key, &response_data, offset, len) {
463                                tracing::trace!("HMAC verification failed");
464                                return Err(Error::auth(
465                                    Some(self.peer_addr()),
466                                    AuthErrorKind::HmacMismatch,
467                                ));
468                            }
469                            tracing::trace!(
470                                auth_params_offset = offset,
471                                auth_params_len = len,
472                                "HMAC verification successful"
473                            );
474                        } else {
475                            return Err(Error::auth(
476                                Some(self.peer_addr()),
477                                AuthErrorKind::AuthParamsNotFound,
478                            ));
479                        }
480                    }
481
482                    // Decode response
483                    let response = V3Message::decode(response_data.clone())?;
484
485                    // Check for Report PDU (error response)
486                    if let Some(scoped_pdu) = response.scoped_pdu()
487                        && scoped_pdu.pdu.pdu_type == PduType::Report
488                    {
489                        // Check for time window error - resync and retry
490                        if is_not_in_time_window_report(&scoped_pdu.pdu) {
491                            tracing::debug!("not in time window, resyncing");
492                            // Update engine time from response
493                            let usm_params =
494                                UsmSecurityParams::decode(response.security_params.clone())?;
495                            {
496                                let mut state = self.inner.engine_state.write().unwrap();
497                                if let Some(ref mut s) = *state {
498                                    s.update_time(usm_params.engine_boots, usm_params.engine_time);
499                                }
500                            }
501                            last_error = Some(Error::NotInTimeWindow {
502                                target: Some(self.peer_addr()),
503                            });
504                            continue;
505                        }
506
507                        // Check for unknown engine ID
508                        if is_unknown_engine_id_report(&scoped_pdu.pdu) {
509                            return Err(Error::UnknownEngineId {
510                                target: Some(self.peer_addr()),
511                            });
512                        }
513
514                        // Other Report errors
515                        return Err(Error::Snmp {
516                            target: Some(self.peer_addr()),
517                            status: ErrorStatus::GenErr,
518                            index: 0,
519                            oid: scoped_pdu.pdu.varbinds.first().map(|vb| vb.oid.clone()),
520                        });
521                    }
522
523                    // Extract security params before consuming response
524                    let response_security_params = response.security_params.clone();
525
526                    // Handle encrypted response
527                    let response_pdu = if security_level.requires_priv() {
528                        match response.data {
529                            crate::message::V3MessageData::Encrypted(ciphertext) => {
530                                tracing::trace!(
531                                    ciphertext_len = ciphertext.len(),
532                                    "decrypting response"
533                                );
534
535                                // Decrypt
536                                let derived = self.inner.derived_keys.read().unwrap();
537                                let priv_key = derived
538                                    .as_ref()
539                                    .and_then(|d| d.priv_key.as_ref())
540                                    .ok_or_else(|| {
541                                    Error::decrypt(
542                                        Some(self.peer_addr()),
543                                        CryptoErrorKind::NoPrivKey,
544                                    )
545                                })?;
546
547                                let usm_params =
548                                    UsmSecurityParams::decode(response_security_params.clone())?;
549                                let plaintext = priv_key.decrypt(
550                                    &ciphertext,
551                                    usm_params.engine_boots,
552                                    usm_params.engine_time,
553                                    &usm_params.priv_params,
554                                )?;
555
556                                tracing::trace!(
557                                    plaintext_len = plaintext.len(),
558                                    "decrypted response"
559                                );
560
561                                // Decode scoped PDU
562                                let mut decoder = Decoder::new(plaintext);
563                                let scoped_pdu = ScopedPdu::decode(&mut decoder)?;
564                                scoped_pdu.pdu
565                            }
566                            crate::message::V3MessageData::Plaintext(scoped_pdu) => scoped_pdu.pdu,
567                        }
568                    } else {
569                        response
570                            .into_pdu()
571                            .ok_or_else(|| Error::decode(0, DecodeErrorKind::MissingPdu))?
572                    };
573
574                    // Validate request ID
575                    if response_pdu.request_id != pdu.request_id {
576                        return Err(Error::RequestIdMismatch {
577                            expected: pdu.request_id,
578                            actual: response_pdu.request_id,
579                        });
580                    }
581
582                    tracing::debug!(
583                        snmp.pdu_type = ?response_pdu.pdu_type,
584                        snmp.varbind_count = response_pdu.varbinds.len(),
585                        snmp.error_status = response_pdu.error_status,
586                        snmp.error_index = response_pdu.error_index,
587                        "received V3 {} response",
588                        response_pdu.pdu_type
589                    );
590
591                    // Update engine time from successful response
592                    {
593                        let usm_params = UsmSecurityParams::decode(response_security_params)?;
594                        let mut state = self.inner.engine_state.write().unwrap();
595                        if let Some(ref mut s) = *state {
596                            s.update_time(usm_params.engine_boots, usm_params.engine_time);
597                        }
598                    }
599
600                    // Check for SNMP error
601                    if response_pdu.is_error() {
602                        let status = response_pdu.error_status_enum();
603                        // error_index is 1-based; 0 means error applies to PDU, not a specific varbind
604                        let oid = (response_pdu.error_index as usize)
605                            .checked_sub(1)
606                            .and_then(|idx| response_pdu.varbinds.get(idx))
607                            .map(|vb| vb.oid.clone());
608
609                        Span::current()
610                            .record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
611                        return Err(Error::Snmp {
612                            target: Some(self.peer_addr()),
613                            status,
614                            index: response_pdu.error_index as u32,
615                            oid,
616                        });
617                    }
618
619                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
620                    return Ok(response_pdu);
621                }
622                Err(e @ Error::Timeout { .. }) => {
623                    last_error = Some(e);
624                    continue;
625                }
626                Err(e) => {
627                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
628                    return Err(e);
629                }
630            }
631        }
632
633        // All retries exhausted
634        Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
635        Err(last_error.unwrap_or(Error::Timeout {
636            target: Some(self.peer_addr()),
637            elapsed: self.inner.config.timeout * (retries + 1),
638            request_id: pdu.request_id,
639            retries,
640        }))
641    }
642}