async_snmp/client/
v3.rs

1//! SNMPv3-specific client functionality.
2//!
3//! This module contains V3 security configuration, key derivation, engine discovery,
4//! and V3 message building/handling.
5
6use crate::ber::Decoder;
7use crate::error::{
8    AuthErrorKind, CryptoErrorKind, DecodeErrorKind, EncodeErrorKind, Error, ErrorStatus, Result,
9};
10use crate::format::hex;
11use crate::message::{MsgFlags, MsgGlobalData, ScopedPdu, SecurityLevel, V3Message};
12use crate::pdu::{Pdu, PduType};
13use crate::transport::Transport;
14use crate::v3::{AuthProtocol, PrivProtocol};
15use crate::v3::{
16    LocalizedKey, PrivKey, UsmSecurityParams,
17    auth::{authenticate_message, verify_message},
18    is_not_in_time_window_report, is_unknown_engine_id_report,
19};
20use bytes::Bytes;
21use std::time::Instant;
22use tracing::{Span, instrument};
23
24use super::Client;
25
26/// SNMPv3 security configuration.
27///
28/// Stores the credentials needed for authenticated and/or encrypted communication.
29/// Keys are derived when the engine ID is discovered.
30///
31/// # Master Key Caching
32///
33/// For high-throughput polling of many engines with shared credentials, use
34/// [`MasterKeys`](crate::MasterKeys) to cache the expensive password-to-key
35/// derivation. When `master_keys` is set, passwords are ignored and keys are
36/// derived from the cached master keys.
37#[derive(Clone)]
38pub struct V3SecurityConfig {
39    /// Username for USM authentication
40    pub username: Bytes,
41    /// Authentication protocol and password
42    pub auth: Option<(AuthProtocol, Vec<u8>)>,
43    /// Privacy protocol and password
44    pub privacy: Option<(PrivProtocol, Vec<u8>)>,
45    /// Pre-computed master keys for efficient key derivation
46    pub master_keys: Option<crate::v3::MasterKeys>,
47}
48
49impl V3SecurityConfig {
50    /// Create a new V3 security config with just a username (noAuthNoPriv).
51    pub fn new(username: impl Into<Bytes>) -> Self {
52        Self {
53            username: username.into(),
54            auth: None,
55            privacy: None,
56            master_keys: None,
57        }
58    }
59
60    /// Add authentication (authNoPriv or authPriv).
61    pub fn auth(mut self, protocol: AuthProtocol, password: impl Into<Vec<u8>>) -> Self {
62        self.auth = Some((protocol, password.into()));
63        self
64    }
65
66    /// Add privacy/encryption (authPriv).
67    pub fn privacy(mut self, protocol: PrivProtocol, password: impl Into<Vec<u8>>) -> Self {
68        self.privacy = Some((protocol, password.into()));
69        self
70    }
71
72    /// Use pre-computed master keys for efficient key derivation.
73    ///
74    /// When set, passwords are ignored and keys are derived from the cached
75    /// master keys. This avoids the expensive ~850μs password expansion for
76    /// each engine.
77    pub fn with_master_keys(mut self, master_keys: crate::v3::MasterKeys) -> Self {
78        self.master_keys = Some(master_keys);
79        self
80    }
81
82    /// Get the security level based on configured auth/privacy.
83    pub fn security_level(&self) -> SecurityLevel {
84        // Check master_keys first, then fall back to auth/privacy
85        if let Some(ref master_keys) = self.master_keys {
86            if master_keys.priv_protocol().is_some() {
87                return SecurityLevel::AuthPriv;
88            }
89            return SecurityLevel::AuthNoPriv;
90        }
91
92        match (&self.auth, &self.privacy) {
93            (None, _) => SecurityLevel::NoAuthNoPriv,
94            (Some(_), None) => SecurityLevel::AuthNoPriv,
95            (Some(_), Some(_)) => SecurityLevel::AuthPriv,
96        }
97    }
98
99    /// Derive localized keys for a specific engine ID.
100    ///
101    /// If master keys are configured, uses the cached master keys for efficient
102    /// localization (~1μs). Otherwise, performs full password-to-key derivation
103    /// (~850μs for SHA-256).
104    pub fn derive_keys(&self, engine_id: &[u8]) -> V3DerivedKeys {
105        // Use master keys if available (efficient path)
106        if let Some(ref master_keys) = self.master_keys {
107            tracing::trace!(
108                engine_id_len = engine_id.len(),
109                auth_protocol = ?master_keys.auth_protocol(),
110                priv_protocol = ?master_keys.priv_protocol(),
111                "localizing from cached master keys"
112            );
113            let (auth_key, priv_key) = master_keys.localize(engine_id);
114            tracing::trace!("key localization complete");
115            return V3DerivedKeys {
116                auth_key: Some(auth_key),
117                priv_key,
118            };
119        }
120
121        // Fall back to password-based derivation
122        tracing::trace!(
123            engine_id_len = engine_id.len(),
124            has_auth = self.auth.is_some(),
125            has_priv = self.privacy.is_some(),
126            "deriving localized keys from passwords"
127        );
128
129        let auth_key = self.auth.as_ref().map(|(protocol, password)| {
130            tracing::trace!(auth_protocol = ?protocol, "deriving auth key");
131            LocalizedKey::from_password(*protocol, password, engine_id)
132        });
133
134        let priv_key = match (&self.auth, &self.privacy) {
135            (Some((auth_protocol, _)), Some((priv_protocol, priv_password))) => {
136                tracing::trace!(priv_protocol = ?priv_protocol, "deriving privacy key");
137                Some(PrivKey::from_password(
138                    *auth_protocol,
139                    *priv_protocol,
140                    priv_password,
141                    engine_id,
142                ))
143            }
144            _ => None,
145        };
146
147        tracing::trace!("key derivation complete");
148        V3DerivedKeys { auth_key, priv_key }
149    }
150}
151
152impl std::fmt::Debug for V3SecurityConfig {
153    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
154        f.debug_struct("V3SecurityConfig")
155            .field("username", &String::from_utf8_lossy(&self.username))
156            .field("auth", &self.auth.as_ref().map(|(p, _)| p))
157            .field("privacy", &self.privacy.as_ref().map(|(p, _)| p))
158            .finish()
159    }
160}
161
162/// Derived keys for a specific engine ID.
163pub struct V3DerivedKeys {
164    pub auth_key: Option<LocalizedKey>,
165    pub priv_key: Option<PrivKey>,
166}
167
168// V3-specific Client implementation
169impl<T: Transport> Client<T> {
170    /// Ensure engine ID is discovered for V3 operations.
171    #[instrument(level = "debug", skip(self), fields(snmp.target = %self.peer_addr()))]
172    pub(super) async fn ensure_engine_discovered(&self) -> Result<()> {
173        // Check if already discovered
174        {
175            let state = self.inner.engine_state.read().unwrap();
176            if state.is_some() {
177                return Ok(());
178            }
179        }
180
181        // Check shared cache first
182        if let Some(cache) = &self.inner.engine_cache
183            && let Some(cached_state) = cache.get(&self.peer_addr())
184        {
185            tracing::debug!("using cached engine state");
186            let mut state = self.inner.engine_state.write().unwrap();
187            *state = Some(cached_state.clone());
188            // Derive keys for this engine
189            if let Some(security) = &self.inner.config.v3_security {
190                let keys = security.derive_keys(&cached_state.engine_id);
191                let mut derived = self.inner.derived_keys.write().unwrap();
192                *derived = Some(keys);
193            }
194            return Ok(());
195        }
196
197        // Perform discovery
198        tracing::debug!("performing engine discovery");
199        let msg_id = self.next_request_id();
200        let discovery_msg = V3Message::discovery_request(msg_id);
201        let discovery_data = discovery_msg.encode();
202
203        // Register request and send discovery
204        self.inner
205            .transport
206            .register_request(msg_id, self.inner.config.timeout);
207        self.inner.transport.send(&discovery_data).await?;
208        let (response_data, _source) = self.inner.transport.recv(msg_id).await?;
209
210        // Parse response
211        let response = V3Message::decode(response_data)?;
212
213        // Extract engine state from USM params
214        let engine_state = crate::v3::parse_discovery_response(&response.security_params)?;
215        tracing::debug!(
216            snmp.engine_id = %hex::Bytes(&engine_state.engine_id),
217            snmp.engine_boots = engine_state.engine_boots,
218            snmp.engine_time = engine_state.engine_time,
219            "discovered engine"
220        );
221
222        // Derive keys for this engine
223        if let Some(security) = &self.inner.config.v3_security {
224            let keys = security.derive_keys(&engine_state.engine_id);
225            let mut derived = self.inner.derived_keys.write().unwrap();
226            *derived = Some(keys);
227        }
228
229        // Store in local cache
230        {
231            let mut state = self.inner.engine_state.write().unwrap();
232            *state = Some(engine_state.clone());
233        }
234
235        // Store in shared cache if present
236        if let Some(cache) = &self.inner.engine_cache {
237            cache.insert(self.peer_addr(), engine_state);
238        }
239
240        Ok(())
241    }
242
243    /// Build and encode a V3 message with authentication and/or encryption.
244    pub(super) fn build_v3_message(&self, pdu: &Pdu) -> Result<(Vec<u8>, i32)> {
245        let security = self
246            .inner
247            .config
248            .v3_security
249            .as_ref()
250            .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
251
252        let engine_state = self.inner.engine_state.read().unwrap();
253        let engine_state = engine_state
254            .as_ref()
255            .ok_or_else(|| Error::encode(EncodeErrorKind::EngineNotDiscovered))?;
256
257        let derived = self.inner.derived_keys.read().unwrap();
258
259        let security_level = security.security_level();
260        let msg_id = pdu.request_id; // Use request_id as msg_id for correlation
261
262        // Build scoped PDU
263        let scoped_pdu = ScopedPdu::new(
264            engine_state.engine_id.clone(),
265            Bytes::new(), // empty context name
266            pdu.clone(),
267        );
268
269        // Get current engine time estimate
270        let engine_boots = engine_state.engine_boots;
271        let engine_time = engine_state.estimated_time();
272
273        // Handle encryption if needed
274        let (msg_data, priv_params) = if security_level.requires_priv() {
275            tracing::trace!("encrypting scoped PDU");
276
277            // Get mutable priv_key - we need interior mutability for salt counter
278            // Since PrivKey uses internal counter, we need to clone and use
279            let derived_ref = derived
280                .as_ref()
281                .ok_or_else(|| Error::encode(EncodeErrorKind::KeysNotDerived))?;
282            let mut priv_key = derived_ref
283                .priv_key
284                .as_ref()
285                .ok_or_else(|| Error::encode(EncodeErrorKind::NoPrivKey))?
286                .clone();
287
288            // Encode scoped PDU
289            let scoped_pdu_bytes = scoped_pdu.encode_to_bytes();
290
291            // Encrypt
292            let (ciphertext, salt) = priv_key.encrypt(
293                &scoped_pdu_bytes,
294                engine_boots,
295                engine_time,
296                Some(&self.inner.salt_counter),
297            )?;
298
299            tracing::trace!(
300                plaintext_len = scoped_pdu_bytes.len(),
301                ciphertext_len = ciphertext.len(),
302                "encrypted scoped PDU"
303            );
304
305            (crate::message::V3MessageData::Encrypted(ciphertext), salt)
306        } else {
307            (
308                crate::message::V3MessageData::Plaintext(scoped_pdu),
309                Bytes::new(),
310            )
311        };
312
313        // Build USM security parameters
314        let mac_len = if security_level.requires_auth() {
315            derived
316                .as_ref()
317                .and_then(|d| d.auth_key.as_ref())
318                .map(|k| k.mac_len())
319                .unwrap_or(12)
320        } else {
321            0
322        };
323
324        let mut usm_params = UsmSecurityParams::new(
325            engine_state.engine_id.clone(),
326            engine_boots,
327            engine_time,
328            security.username.clone(),
329        );
330
331        if security_level.requires_auth() {
332            usm_params = usm_params.with_auth_placeholder(mac_len);
333        }
334
335        if security_level.requires_priv() {
336            usm_params = usm_params.with_priv_params(priv_params);
337        }
338
339        let usm_encoded = usm_params.encode();
340
341        // Build global data
342        let msg_flags = MsgFlags::new(security_level, true); // reportable=true for requests
343        let global_data = MsgGlobalData::new(msg_id, 65507, msg_flags);
344
345        // Build complete message
346        let msg = match msg_data {
347            crate::message::V3MessageData::Plaintext(scoped_pdu) => {
348                V3Message::new(global_data, usm_encoded, scoped_pdu)
349            }
350            crate::message::V3MessageData::Encrypted(ciphertext) => {
351                V3Message::new_encrypted(global_data, usm_encoded, ciphertext)
352            }
353        };
354
355        let mut encoded = msg.encode().to_vec();
356
357        // Apply authentication if needed
358        if security_level.requires_auth() {
359            tracing::trace!("applying HMAC authentication");
360
361            let auth_key = derived
362                .as_ref()
363                .and_then(|d| d.auth_key.as_ref())
364                .ok_or_else(|| Error::encode(EncodeErrorKind::MissingAuthKey))?;
365
366            // Find auth params position and apply HMAC
367            if let Some((offset, len)) = UsmSecurityParams::find_auth_params_offset(&encoded) {
368                authenticate_message(auth_key, &mut encoded, offset, len);
369                tracing::trace!(
370                    auth_params_offset = offset,
371                    auth_params_len = len,
372                    "applied HMAC authentication"
373                );
374            } else {
375                return Err(Error::encode(EncodeErrorKind::MissingAuthParams));
376            }
377        }
378
379        Ok((encoded, msg_id))
380    }
381
382    /// Send a V3 request and handle the response.
383    #[instrument(
384        level = "debug",
385        skip(self, pdu),
386        fields(
387            snmp.target = %self.peer_addr(),
388            snmp.request_id = pdu.request_id,
389            snmp.security_level = ?self.inner.config.v3_security.as_ref().map(|s| s.security_level()),
390            snmp.attempt = tracing::field::Empty,
391            snmp.elapsed_ms = tracing::field::Empty,
392        )
393    )]
394    pub(super) async fn send_v3_and_recv(&self, pdu: Pdu) -> Result<Pdu> {
395        let start = Instant::now();
396
397        // Ensure engine is discovered first
398        self.ensure_engine_discovered().await?;
399
400        let security = self
401            .inner
402            .config
403            .v3_security
404            .as_ref()
405            .ok_or_else(|| Error::encode(EncodeErrorKind::NoSecurityConfig))?;
406        let security_level = security.security_level();
407
408        let mut last_error = None;
409        let max_attempts = if self.inner.transport.is_reliable() {
410            0
411        } else {
412            self.inner.config.retry.max_attempts
413        };
414
415        for attempt in 0..=max_attempts {
416            Span::current().record("snmp.attempt", attempt);
417            if attempt > 0 {
418                tracing::debug!("retrying V3 request");
419            }
420
421            // Build message (may need fresh timestamps on retry)
422            let (data, msg_id) = self.build_v3_message(&pdu)?;
423
424            tracing::debug!(
425                snmp.pdu_type = ?pdu.pdu_type,
426                snmp.varbind_count = pdu.varbinds.len(),
427                snmp.msg_id = msg_id,
428                "sending V3 {} request",
429                pdu.pdu_type
430            );
431            tracing::trace!(snmp.bytes = data.len(), "sending V3 request");
432
433            // Register (or re-register) with fresh deadline before sending
434            self.inner
435                .transport
436                .register_request(msg_id, self.inner.config.timeout);
437
438            // Send request
439            self.inner.transport.send(&data).await?;
440
441            // Wait for response (deadline was set by register_request)
442            match self.inner.transport.recv(msg_id).await {
443                Ok((response_data, _source)) => {
444                    tracing::trace!(snmp.bytes = response_data.len(), "received V3 response");
445
446                    // Verify authentication if required
447                    if security_level.requires_auth() {
448                        tracing::trace!("verifying HMAC authentication on response");
449
450                        let derived = self.inner.derived_keys.read().unwrap();
451                        let auth_key = derived
452                            .as_ref()
453                            .and_then(|d| d.auth_key.as_ref())
454                            .ok_or_else(|| {
455                                Error::auth(Some(self.peer_addr()), AuthErrorKind::NoAuthKey)
456                            })?;
457
458                        if let Some((offset, len)) =
459                            UsmSecurityParams::find_auth_params_offset(&response_data)
460                        {
461                            if !verify_message(auth_key, &response_data, offset, len) {
462                                tracing::trace!("HMAC verification failed");
463                                return Err(Error::auth(
464                                    Some(self.peer_addr()),
465                                    AuthErrorKind::HmacMismatch,
466                                ));
467                            }
468                            tracing::trace!(
469                                auth_params_offset = offset,
470                                auth_params_len = len,
471                                "HMAC verification successful"
472                            );
473                        } else {
474                            return Err(Error::auth(
475                                Some(self.peer_addr()),
476                                AuthErrorKind::AuthParamsNotFound,
477                            ));
478                        }
479                    }
480
481                    // Decode response
482                    let response = V3Message::decode(response_data.clone())?;
483
484                    // Check for Report PDU (error response)
485                    if let Some(scoped_pdu) = response.scoped_pdu()
486                        && scoped_pdu.pdu.pdu_type == PduType::Report
487                    {
488                        // Check for time window error - resync and retry
489                        if is_not_in_time_window_report(&scoped_pdu.pdu) {
490                            tracing::debug!("not in time window, resyncing");
491                            // Update engine time from response
492                            let usm_params =
493                                UsmSecurityParams::decode(response.security_params.clone())?;
494                            {
495                                let mut state = self.inner.engine_state.write().unwrap();
496                                if let Some(ref mut s) = *state {
497                                    s.update_time(usm_params.engine_boots, usm_params.engine_time);
498                                }
499                            }
500                            last_error = Some(Error::NotInTimeWindow {
501                                target: Some(self.peer_addr()),
502                            });
503                            // Apply backoff delay before retry (if not last attempt)
504                            if attempt < max_attempts {
505                                let delay = self.inner.config.retry.compute_delay(attempt);
506                                if !delay.is_zero() {
507                                    tracing::debug!(
508                                        delay_ms = delay.as_millis() as u64,
509                                        "backing off"
510                                    );
511                                    tokio::time::sleep(delay).await;
512                                }
513                            }
514                            continue;
515                        }
516
517                        // Check for unknown engine ID
518                        if is_unknown_engine_id_report(&scoped_pdu.pdu) {
519                            return Err(Error::UnknownEngineId {
520                                target: Some(self.peer_addr()),
521                            });
522                        }
523
524                        // Other Report errors
525                        return Err(Error::Snmp {
526                            target: Some(self.peer_addr()),
527                            status: ErrorStatus::GenErr,
528                            index: 0,
529                            oid: scoped_pdu.pdu.varbinds.first().map(|vb| vb.oid.clone()),
530                        });
531                    }
532
533                    // Extract security params before consuming response
534                    let response_security_params = response.security_params.clone();
535
536                    // Handle encrypted response
537                    let response_pdu = if security_level.requires_priv() {
538                        match response.data {
539                            crate::message::V3MessageData::Encrypted(ciphertext) => {
540                                tracing::trace!(
541                                    ciphertext_len = ciphertext.len(),
542                                    "decrypting response"
543                                );
544
545                                // Decrypt
546                                let derived = self.inner.derived_keys.read().unwrap();
547                                let priv_key = derived
548                                    .as_ref()
549                                    .and_then(|d| d.priv_key.as_ref())
550                                    .ok_or_else(|| {
551                                    Error::decrypt(
552                                        Some(self.peer_addr()),
553                                        CryptoErrorKind::NoPrivKey,
554                                    )
555                                })?;
556
557                                let usm_params =
558                                    UsmSecurityParams::decode(response_security_params.clone())?;
559                                let plaintext = priv_key.decrypt(
560                                    &ciphertext,
561                                    usm_params.engine_boots,
562                                    usm_params.engine_time,
563                                    &usm_params.priv_params,
564                                )?;
565
566                                tracing::trace!(
567                                    plaintext_len = plaintext.len(),
568                                    "decrypted response"
569                                );
570
571                                // Decode scoped PDU
572                                let mut decoder = Decoder::new(plaintext);
573                                let scoped_pdu = ScopedPdu::decode(&mut decoder)?;
574                                scoped_pdu.pdu
575                            }
576                            crate::message::V3MessageData::Plaintext(scoped_pdu) => scoped_pdu.pdu,
577                        }
578                    } else {
579                        response
580                            .into_pdu()
581                            .ok_or_else(|| Error::decode(0, DecodeErrorKind::MissingPdu))?
582                    };
583
584                    // Validate request ID
585                    if response_pdu.request_id != pdu.request_id {
586                        return Err(Error::RequestIdMismatch {
587                            expected: pdu.request_id,
588                            actual: response_pdu.request_id,
589                        });
590                    }
591
592                    tracing::debug!(
593                        snmp.pdu_type = ?response_pdu.pdu_type,
594                        snmp.varbind_count = response_pdu.varbinds.len(),
595                        snmp.error_status = response_pdu.error_status,
596                        snmp.error_index = response_pdu.error_index,
597                        "received V3 {} response",
598                        response_pdu.pdu_type
599                    );
600
601                    // Update engine time from successful response
602                    {
603                        let usm_params = UsmSecurityParams::decode(response_security_params)?;
604                        let mut state = self.inner.engine_state.write().unwrap();
605                        if let Some(ref mut s) = *state {
606                            s.update_time(usm_params.engine_boots, usm_params.engine_time);
607                        }
608                    }
609
610                    // Check for SNMP error
611                    if response_pdu.is_error() {
612                        let status = response_pdu.error_status_enum();
613                        // error_index is 1-based; 0 means error applies to PDU, not a specific varbind
614                        let oid = (response_pdu.error_index as usize)
615                            .checked_sub(1)
616                            .and_then(|idx| response_pdu.varbinds.get(idx))
617                            .map(|vb| vb.oid.clone());
618
619                        Span::current()
620                            .record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
621                        return Err(Error::Snmp {
622                            target: Some(self.peer_addr()),
623                            status,
624                            index: response_pdu.error_index as u32,
625                            oid,
626                        });
627                    }
628
629                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
630                    return Ok(response_pdu);
631                }
632                Err(e @ Error::Timeout { .. }) => {
633                    last_error = Some(e);
634                    // Apply backoff delay before next retry (if not last attempt)
635                    if attempt < max_attempts {
636                        let delay = self.inner.config.retry.compute_delay(attempt);
637                        if !delay.is_zero() {
638                            tracing::debug!(delay_ms = delay.as_millis() as u64, "backing off");
639                            tokio::time::sleep(delay).await;
640                        }
641                    }
642                    continue;
643                }
644                Err(e) => {
645                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
646                    return Err(e);
647                }
648            }
649        }
650
651        // All retries exhausted
652        Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
653        Err(last_error.unwrap_or(Error::Timeout {
654            target: Some(self.peer_addr()),
655            elapsed: start.elapsed(),
656            request_id: pdu.request_id,
657            retries: max_attempts,
658        }))
659    }
660}