Skip to main content

async_snmp/client/
mod.rs

1//! SNMP client implementation.
2
3mod auth;
4mod builder;
5mod retry;
6mod v3;
7mod walk;
8
9pub use auth::{Auth, CommunityVersion, UsmAuth, UsmBuilder};
10pub use builder::{ClientBuilder, Target};
11pub use retry::{Backoff, Retry, RetryBuilder};
12
13// New unified entry point
14impl Client<UdpHandle> {
15    /// Create a new SNMP client builder.
16    ///
17    /// This is the single entry point for client construction, supporting all
18    /// SNMP versions (v1, v2c, v3) through the [`Auth`] enum.
19    ///
20    /// # Example
21    ///
22    /// ```rust,no_run
23    /// use async_snmp::{Auth, Client, Retry};
24    /// use std::time::Duration;
25    ///
26    /// # async fn example() -> async_snmp::Result<()> {
27    /// // (host, port) tuple - convenient when host and port are separate
28    /// let client = Client::builder(("192.168.1.1", 161), Auth::v2c("public"))
29    ///     .connect().await?;
30    ///
31    /// // Combined address string (port defaults to 161 if omitted)
32    /// let client = Client::builder("switch.local", Auth::v2c("public"))
33    ///     .connect().await?;
34    ///
35    /// // SocketAddr works too
36    /// let addr: std::net::SocketAddr = "192.168.1.1:161".parse().unwrap();
37    /// let client = Client::builder(addr, Auth::v2c("public"))
38    ///     .connect().await?;
39    /// # Ok(())
40    /// # }
41    /// ```
42    pub fn builder(target: impl Into<Target>, auth: impl Into<Auth>) -> ClientBuilder {
43        ClientBuilder::new(target, auth)
44    }
45}
46use crate::error::internal::DecodeErrorKind;
47use crate::error::{Error, Result};
48use crate::message::{CommunityMessage, Message};
49use crate::oid::Oid;
50use crate::pdu::{GetBulkPdu, Pdu};
51use crate::transport::Transport;
52use crate::transport::UdpHandle;
53use crate::v3::{EngineCache, EngineState, SaltCounter};
54use crate::value::Value;
55use crate::varbind::VarBind;
56use crate::version::Version;
57use bytes::Bytes;
58use std::net::SocketAddr;
59use std::sync::Arc;
60use std::sync::RwLock;
61use std::time::{Duration, Instant};
62use tokio::sync::Mutex as AsyncMutex;
63use tracing::{Span, instrument};
64
65pub use crate::notification::{DerivedKeys, UsmConfig};
66pub use walk::{BulkWalk, OidOrdering, Walk, WalkMode, WalkStream};
67
68// ============================================================================
69// Shared helpers
70// ============================================================================
71
72/// Extract an SNMP-level error from a PDU and convert it to an `Error::Snmp`.
73///
74/// Returns `Some(err)` if the PDU carries an SNMP error status, `None` otherwise.
75/// The `error_index` field is 1-based; 0 means the error applies to the whole PDU.
76pub(crate) fn pdu_to_snmp_error(pdu: &Pdu, target: SocketAddr) -> Option<Box<Error>> {
77    if !pdu.is_error() {
78        return None;
79    }
80    let status = pdu.error_status_enum();
81    let oid = (pdu.error_index as usize)
82        .checked_sub(1)
83        .and_then(|idx| pdu.varbinds.get(idx))
84        .map(|vb| vb.oid.clone());
85    Some(
86        Error::Snmp {
87            target,
88            status,
89            index: pdu.error_index.max(0) as u32,
90            oid,
91        }
92        .boxed(),
93    )
94}
95
96// ============================================================================
97// Default configuration constants
98// ============================================================================
99
/// Per-attempt response timeout used when none is configured (5 seconds).
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Default upper bound on OIDs placed in a single request PDU.
///
/// GET and GETNEXT calls with more OIDs than this are split into several
/// requests automatically; SET calls over the limit are rejected instead,
/// because a SET must stay atomic.
pub const DEFAULT_MAX_OIDS_PER_REQUEST: usize = 10;

/// Default max-repetitions value for GETBULK-based walks.
///
/// Sets how many values each GETBULK round trip asks for.
pub const DEFAULT_MAX_REPETITIONS: u32 = 25;
113
114/// SNMP client.
115///
116/// Generic over transport type, with `UdpHandle` as default.
117#[derive(Clone)]
118pub struct Client<T: Transport = UdpHandle> {
119    inner: Arc<ClientInner<T>>,
120}
121
122struct ClientInner<T: Transport> {
123    transport: T,
124    config: ClientConfig,
125    /// Cached engine state (V3)
126    engine_state: RwLock<Option<EngineState>>,
127    /// Derived keys for this engine (V3)
128    derived_keys: RwLock<Option<DerivedKeys>>,
129    /// Salt counter for privacy (V3)
130    salt_counter: SaltCounter,
131    /// Shared engine cache (V3, optional)
132    engine_cache: Option<Arc<EngineCache>>,
133    /// Serializes concurrent discovery attempts so only one runs at a time.
134    discovery_lock: AsyncMutex<()>,
135}
136
137/// Client configuration.
138///
139/// Most users should use [`ClientBuilder`] rather than constructing this directly.
140#[derive(Clone)]
141pub struct ClientConfig {
142    /// SNMP version (default: V2c)
143    pub version: Version,
144    /// Community string for v1/v2c (default: "public")
145    pub community: Bytes,
146    /// Request timeout (default: 5 seconds)
147    pub timeout: Duration,
148    /// Retry configuration (default: 3 retries, 1-second delay)
149    pub retry: Retry,
150    /// Maximum OIDs per request (default: 10)
151    pub max_oids_per_request: usize,
152    /// SNMPv3 security configuration (default: None)
153    pub v3_security: Option<UsmConfig>,
154    /// Walk operation mode (default: Auto)
155    pub walk_mode: WalkMode,
156    /// OID ordering behavior during walk operations (default: Strict)
157    pub oid_ordering: OidOrdering,
158    /// Maximum results from a single walk operation (default: None/unlimited)
159    pub max_walk_results: Option<usize>,
160    /// Max-repetitions for GETBULK operations (default: 25)
161    pub max_repetitions: u32,
162}
163
164impl Default for ClientConfig {
165    /// Returns configuration for SNMPv2c with community "public".
166    ///
167    /// See field documentation for all default values.
168    fn default() -> Self {
169        Self {
170            version: Version::V2c,
171            community: Bytes::from_static(b"public"),
172            timeout: DEFAULT_TIMEOUT,
173            retry: Retry::default(),
174            max_oids_per_request: DEFAULT_MAX_OIDS_PER_REQUEST,
175            v3_security: None,
176            walk_mode: WalkMode::Auto,
177            oid_ordering: OidOrdering::Strict,
178            max_walk_results: None,
179            max_repetitions: DEFAULT_MAX_REPETITIONS,
180        }
181    }
182}
183
184impl<T: Transport> Client<T> {
185    /// Create a new client with the given transport and config.
186    ///
187    /// For most use cases, prefer [`Client::builder()`] which provides a more
188    /// ergonomic API. Use this constructor when you need fine-grained control
189    /// over transport configuration (e.g., TCP connection timeout, keepalive
190    /// settings) or when using a custom [`Transport`] implementation.
191    pub fn new(transport: T, config: ClientConfig) -> Self {
192        Self {
193            inner: Arc::new(ClientInner {
194                transport,
195                config,
196                engine_state: RwLock::new(None),
197                derived_keys: RwLock::new(None),
198                salt_counter: SaltCounter::new(),
199                engine_cache: None,
200                discovery_lock: AsyncMutex::new(()),
201            }),
202        }
203    }
204
205    /// Create a new V3 client with a shared engine cache.
206    pub fn with_engine_cache(
207        transport: T,
208        config: ClientConfig,
209        engine_cache: Arc<EngineCache>,
210    ) -> Self {
211        Self {
212            inner: Arc::new(ClientInner {
213                transport,
214                config,
215                engine_state: RwLock::new(None),
216                derived_keys: RwLock::new(None),
217                salt_counter: SaltCounter::new(),
218                engine_cache: Some(engine_cache),
219                discovery_lock: AsyncMutex::new(()),
220            }),
221        }
222    }
223
224    /// Get the peer (target) address.
225    ///
226    /// Returns the remote address that this client sends requests to.
227    /// Named to match [`std::net::TcpStream::peer_addr()`].
228    pub fn peer_addr(&self) -> SocketAddr {
229        self.inner.transport.peer_addr()
230    }
231
232    /// Generate next request ID.
233    ///
234    /// Uses the transport's allocator (backed by a global counter).
235    fn next_request_id(&self) -> i32 {
236        self.inner.transport.alloc_request_id()
237    }
238
239    /// Check if using V3 with authentication/encryption configured.
240    fn is_v3(&self) -> bool {
241        self.inner.config.version == Version::V3 && self.inner.config.v3_security.is_some()
242    }
243
244    /// Send a request and wait for response (internal helper with pre-encoded data).
245    #[instrument(
246        level = "debug",
247        skip(self, data),
248        fields(
249            snmp.target = %self.peer_addr(),
250            snmp.request_id = request_id,
251            snmp.attempt = tracing::field::Empty,
252            snmp.elapsed_ms = tracing::field::Empty,
253        )
254    )]
255    async fn send_and_recv(&self, request_id: i32, data: &[u8]) -> Result<Pdu> {
256        let start = Instant::now();
257        let mut last_error: Option<Box<Error>> = None;
258        let max_attempts = if self.inner.transport.is_reliable() {
259            0
260        } else {
261            self.inner.config.retry.max_attempts
262        };
263
264        for attempt in 0..=max_attempts {
265            Span::current().record("snmp.attempt", attempt);
266            if attempt > 0 {
267                tracing::debug!(target: "async_snmp::client", "retrying request");
268            }
269
270            // Register (or re-register) with fresh deadline before sending
271            self.inner
272                .transport
273                .register_request(request_id, self.inner.config.timeout);
274
275            // Send request
276            tracing::trace!(target: "async_snmp::client", { snmp.bytes = data.len() }, "sending request");
277            self.inner.transport.send(data).await?;
278
279            // Wait for response (deadline was set by register_request)
280            match self.inner.transport.recv(request_id).await {
281                Ok((response_data, _source)) => {
282                    tracing::trace!(target: "async_snmp::client", { snmp.bytes = response_data.len() }, "received response");
283
284                    // Decode response and extract PDU
285                    let response = Message::decode(response_data)?;
286
287                    // Validate response version matches request version
288                    let response_version = response.version();
289                    let expected_version = self.inner.config.version;
290                    if response_version != expected_version {
291                        tracing::warn!(target: "async_snmp::client", { ?expected_version, ?response_version, peer = %self.peer_addr() }, "version mismatch in response");
292                        return Err(Error::MalformedResponse {
293                            target: self.peer_addr(),
294                        }
295                        .boxed());
296                    }
297
298                    let response_pdu = match response.try_into_pdu() {
299                        Some(p) => p,
300                        None => {
301                            tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr() }, "received TrapV1 in response to request");
302                            return Err(Error::MalformedResponse {
303                                target: self.peer_addr(),
304                            }
305                            .boxed());
306                        }
307                    };
308
309                    // Validate request ID
310                    if response_pdu.request_id != request_id {
311                        tracing::warn!(target: "async_snmp::client", { expected_request_id = request_id, actual_request_id = response_pdu.request_id, peer = %self.peer_addr() }, "request ID mismatch in response");
312                        return Err(Error::MalformedResponse {
313                            target: self.peer_addr(),
314                        }
315                        .boxed());
316                    }
317
318                    // Check for SNMP error
319                    if let Some(err) = pdu_to_snmp_error(&response_pdu, self.peer_addr()) {
320                        Span::current()
321                            .record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
322                        return Err(err);
323                    }
324
325                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
326                    return Ok(response_pdu);
327                }
328                Err(e) if matches!(*e, Error::Timeout { .. }) => {
329                    last_error = Some(e);
330                    // Apply backoff delay before next retry (if not last attempt)
331                    if attempt < max_attempts {
332                        let delay = self.inner.config.retry.compute_delay(attempt);
333                        if !delay.is_zero() {
334                            tracing::debug!(target: "async_snmp::client", { delay_ms = delay.as_millis() as u64 }, "backing off");
335                            tokio::time::sleep(delay).await;
336                        }
337                    }
338                    continue;
339                }
340                Err(e) => {
341                    Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
342                    return Err(e);
343                }
344            }
345        }
346
347        // All retries exhausted
348        let elapsed = start.elapsed();
349        Span::current().record("snmp.elapsed_ms", elapsed.as_millis() as u64);
350        tracing::debug!(target: "async_snmp::client", { request_id, peer = %self.peer_addr(), ?elapsed, retries = max_attempts }, "request timed out");
351        Err(last_error.unwrap_or_else(|| {
352            Error::Timeout {
353                target: self.peer_addr(),
354                elapsed,
355                retries: max_attempts,
356            }
357            .boxed()
358        }))
359    }
360
361    /// Send a standard request (GET, GETNEXT, SET) and wait for response.
362    async fn send_request(&self, pdu: Pdu) -> Result<Pdu> {
363        // Dispatch to V3 handler if configured
364        if self.is_v3() {
365            return self.send_v3_and_recv(pdu).await;
366        }
367
368        tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = ?pdu.pdu_type, snmp.varbind_count = pdu.varbinds.len() }, "sending {} request", pdu.pdu_type);
369
370        let request_id = pdu.request_id;
371        let message = CommunityMessage::new(
372            self.inner.config.version,
373            self.inner.config.community.clone(),
374            pdu,
375        );
376        let data = message.encode();
377        let response = self.send_and_recv(request_id, &data).await?;
378
379        tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = ?response.pdu_type, snmp.varbind_count = response.varbinds.len(), snmp.error_status = response.error_status, snmp.error_index = response.error_index }, "received {} response", response.pdu_type);
380
381        Ok(response)
382    }
383
384    /// Send a GETBULK request and wait for response.
385    async fn send_bulk_request(&self, pdu: GetBulkPdu) -> Result<Pdu> {
386        // Dispatch to V3 handler if configured
387        if self.is_v3() {
388            // Convert GetBulkPdu to Pdu for V3 encoding
389            let pdu = Pdu::get_bulk(
390                pdu.request_id,
391                pdu.non_repeaters,
392                pdu.max_repetitions,
393                pdu.varbinds,
394            );
395            return self.send_v3_and_recv(pdu).await;
396        }
397
398        tracing::debug!(target: "async_snmp::client", { snmp.non_repeaters = pdu.non_repeaters, snmp.max_repetitions = pdu.max_repetitions, snmp.varbind_count = pdu.varbinds.len() }, "sending GetBulkRequest");
399
400        let request_id = pdu.request_id;
401        let data = CommunityMessage::encode_bulk(
402            self.inner.config.version,
403            self.inner.config.community.clone(),
404            &pdu,
405        );
406        let response = self.send_and_recv(request_id, &data).await?;
407
408        tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = ?response.pdu_type, snmp.varbind_count = response.varbinds.len(), snmp.error_status = response.error_status, snmp.error_index = response.error_index }, "received {} response", response.pdu_type);
409
410        Ok(response)
411    }
412
413    /// GET a single OID.
414    #[instrument(skip(self), err, fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
415    pub async fn get(&self, oid: &Oid) -> Result<VarBind> {
416        let request_id = self.next_request_id();
417        let pdu = Pdu::get_request(request_id, std::slice::from_ref(oid));
418        let response = self.send_request(pdu).await?;
419
420        response.varbinds.into_iter().next().ok_or_else(|| {
421            tracing::debug!(target: "async_snmp::client", { peer = %self.peer_addr(), kind = %DecodeErrorKind::EmptyResponse }, "empty GET response");
422            Error::MalformedResponse {
423                target: self.peer_addr(),
424            }
425            .boxed()
426        })
427    }
428
429    /// GET multiple OIDs.
430    ///
431    /// If the OID list exceeds `max_oids_per_request`, the request is
432    /// automatically split into multiple batches. Results are returned
433    /// in the same order as the input OIDs.
434    ///
435    /// # Example
436    ///
437    /// ```rust,no_run
438    /// # use async_snmp::{Auth, Client, oid};
439    /// # async fn example() -> async_snmp::Result<()> {
440    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("public")).connect().await?;
441    /// let results = client.get_many(&[
442    ///     oid!(1, 3, 6, 1, 2, 1, 1, 1, 0),  // sysDescr
443    ///     oid!(1, 3, 6, 1, 2, 1, 1, 3, 0),  // sysUpTime
444    ///     oid!(1, 3, 6, 1, 2, 1, 1, 5, 0),  // sysName
445    /// ]).await?;
446    /// # Ok(())
447    /// # }
448    /// ```
449    #[instrument(skip(self, oids), err, fields(snmp.target = %self.peer_addr(), snmp.oid_count = oids.len()))]
450    pub async fn get_many(&self, oids: &[Oid]) -> Result<Vec<VarBind>> {
451        self.get_or_getnext_many(oids, "GET", Pdu::get_request)
452            .await
453    }
454
455    /// GETNEXT for a single OID.
456    #[instrument(skip(self), err, fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
457    pub async fn get_next(&self, oid: &Oid) -> Result<VarBind> {
458        let request_id = self.next_request_id();
459        let pdu = Pdu::get_next_request(request_id, std::slice::from_ref(oid));
460        let response = self.send_request(pdu).await?;
461
462        response.varbinds.into_iter().next().ok_or_else(|| {
463            tracing::debug!(target: "async_snmp::client", { peer = %self.peer_addr(), kind = %DecodeErrorKind::EmptyResponse }, "empty GETNEXT response");
464            Error::MalformedResponse {
465                target: self.peer_addr(),
466            }
467            .boxed()
468        })
469    }
470
471    /// GETNEXT for multiple OIDs.
472    ///
473    /// If the OID list exceeds `max_oids_per_request`, the request is
474    /// automatically split into multiple batches. Results are returned
475    /// in the same order as the input OIDs.
476    ///
477    /// # Example
478    ///
479    /// ```rust,no_run
480    /// # use async_snmp::{Auth, Client, oid};
481    /// # async fn example() -> async_snmp::Result<()> {
482    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("public")).connect().await?;
483    /// let results = client.get_next_many(&[
484    ///     oid!(1, 3, 6, 1, 2, 1, 2, 2, 1, 2),  // ifDescr
485    ///     oid!(1, 3, 6, 1, 2, 1, 2, 2, 1, 3),  // ifType
486    /// ]).await?;
487    /// # Ok(())
488    /// # }
489    /// ```
490    #[instrument(skip(self, oids), err, fields(snmp.target = %self.peer_addr(), snmp.oid_count = oids.len()))]
491    pub async fn get_next_many(&self, oids: &[Oid]) -> Result<Vec<VarBind>> {
492        self.get_or_getnext_many(oids, "GETNEXT", Pdu::get_next_request)
493            .await
494    }
495
496    /// Shared implementation for GET-many and GETNEXT-many.
497    ///
498    /// `op` is the PDU constructor (`Pdu::get_request` or `Pdu::get_next_request`).
499    /// `op_name` is used only for log messages.
500    async fn get_or_getnext_many(
501        &self,
502        oids: &[Oid],
503        op_name: &'static str,
504        op: fn(i32, &[Oid]) -> Pdu,
505    ) -> Result<Vec<VarBind>> {
506        if oids.is_empty() {
507            return Ok(Vec::new());
508        }
509
510        let max_per_request = self.inner.config.max_oids_per_request;
511
512        // Fast path: single request if within limit
513        if oids.len() <= max_per_request {
514            let request_id = self.next_request_id();
515            let pdu = op(request_id, oids);
516            let response = self.send_request(pdu).await?;
517            if response.varbinds.len() > oids.len() {
518                tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = oids.len(), actual = response.varbinds.len(), snmp.op = op_name }, "response has more varbinds than requested");
519                return Err(Error::MalformedResponse {
520                    target: self.peer_addr(),
521                }
522                .boxed());
523            } else if response.varbinds.len() < oids.len() {
524                tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = oids.len(), actual = response.varbinds.len(), snmp.op = op_name }, "response has fewer varbinds than requested");
525            }
526            return Ok(response.varbinds);
527        }
528
529        // Batched path: split into chunks
530        let num_batches = oids.len().div_ceil(max_per_request);
531        tracing::debug!(target: "async_snmp::client", { snmp.oid_count = oids.len(), snmp.max_per_request = max_per_request, snmp.batch_count = num_batches, snmp.op = op_name }, "splitting request into batches");
532
533        let mut all_results = Vec::with_capacity(oids.len());
534
535        for (batch_idx, chunk) in oids.chunks(max_per_request).enumerate() {
536            tracing::debug!(target: "async_snmp::client", { snmp.batch = batch_idx + 1, snmp.batch_total = num_batches, snmp.batch_oid_count = chunk.len(), snmp.op = op_name }, "sending batch");
537            let request_id = self.next_request_id();
538            let pdu = op(request_id, chunk);
539            let response = self.send_request(pdu).await?;
540            if response.varbinds.len() > chunk.len() {
541                tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = chunk.len(), actual = response.varbinds.len(), snmp.batch = batch_idx + 1, snmp.op = op_name }, "response has more varbinds than requested in batch");
542                return Err(Error::MalformedResponse {
543                    target: self.peer_addr(),
544                }
545                .boxed());
546            } else if response.varbinds.len() < chunk.len() {
547                tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = chunk.len(), actual = response.varbinds.len(), snmp.batch = batch_idx + 1, snmp.op = op_name }, "response has fewer varbinds than requested in batch");
548            }
549            all_results.extend(response.varbinds);
550        }
551
552        Ok(all_results)
553    }
554
555    /// SET a single OID.
556    #[instrument(skip(self, value), err, fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
557    pub async fn set(&self, oid: &Oid, value: Value) -> Result<VarBind> {
558        let request_id = self.next_request_id();
559        let varbind = VarBind::new(oid.clone(), value);
560        let pdu = Pdu::set_request(request_id, vec![varbind]);
561        let response = self.send_request(pdu).await?;
562
563        response.varbinds.into_iter().next().ok_or_else(|| {
564            tracing::debug!(target: "async_snmp::client", { peer = %self.peer_addr(), kind = %DecodeErrorKind::EmptyResponse }, "empty SET response");
565            Error::MalformedResponse {
566                target: self.peer_addr(),
567            }
568            .boxed()
569        })
570    }
571
572    /// SET multiple OIDs in a single atomic PDU.
573    ///
574    /// RFC 3416 requires that a SET request be atomic: either all variables
575    /// in the request are set, or none are. To preserve this guarantee,
576    /// `set_many` refuses to split the varbind list across multiple PDUs.
577    ///
578    /// If `varbinds.len()` exceeds `max_oids_per_request`, this method
579    /// returns `Error::Config` rather than silently batching the request.
580    /// Callers that need to set more variables than the per-request limit
581    /// must issue multiple explicit `set_many` calls and handle partial
582    /// failure themselves.
583    ///
584    /// # Example
585    ///
586    /// ```rust,no_run
587    /// # use async_snmp::{Auth, Client, oid, Value};
588    /// # async fn example() -> async_snmp::Result<()> {
589    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("private")).connect().await?;
590    /// let results = client.set_many(&[
591    ///     (oid!(1, 3, 6, 1, 2, 1, 1, 5, 0), Value::from("new-hostname")),
592    ///     (oid!(1, 3, 6, 1, 2, 1, 1, 6, 0), Value::from("new-location")),
593    /// ]).await?;
594    /// # Ok(())
595    /// # }
596    /// ```
597    #[instrument(skip(self, varbinds), err, fields(snmp.target = %self.peer_addr(), snmp.oid_count = varbinds.len()))]
598    pub async fn set_many(&self, varbinds: &[(Oid, Value)]) -> Result<Vec<VarBind>> {
599        if varbinds.is_empty() {
600            return Ok(Vec::new());
601        }
602
603        let max_per_request = self.inner.config.max_oids_per_request;
604
605        if varbinds.len() > max_per_request {
606            return Err(Error::Config(
607                format!(
608                    "set_many: {} varbinds exceeds max_oids_per_request ({}); \
609                     SET must be atomic and cannot be split across PDUs",
610                    varbinds.len(),
611                    max_per_request,
612                )
613                .into(),
614            )
615            .boxed());
616        }
617
618        let request_id = self.next_request_id();
619        let vbs: Vec<VarBind> = varbinds
620            .iter()
621            .map(|(oid, value)| VarBind::new(oid.clone(), value.clone()))
622            .collect();
623        let expected_count = vbs.len();
624        let pdu = Pdu::set_request(request_id, vbs);
625        let response = self.send_request(pdu).await?;
626        if response.varbinds.len() > expected_count {
627            tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = expected_count, actual = response.varbinds.len() }, "SET response has more varbinds than requested");
628            return Err(Error::MalformedResponse {
629                target: self.peer_addr(),
630            }
631            .boxed());
632        } else if response.varbinds.len() < expected_count {
633            tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = expected_count, actual = response.varbinds.len() }, "SET response has fewer varbinds than requested");
634        }
635        Ok(response.varbinds)
636    }
637
638    /// GETBULK request (SNMPv2c/v3 only).
639    ///
640    /// Efficiently retrieves multiple variable bindings in a single request.
641    /// GETBULK splits the requested OIDs into two groups:
642    ///
643    /// - **Non-repeaters** (first N OIDs): Each gets a single GETNEXT, returning
644    ///   one value per OID. Use for scalar values like `sysUpTime.0`.
645    /// - **Repeaters** (remaining OIDs): Each gets up to `max_repetitions` GETNEXTs,
646    ///   returning multiple values per OID. Use for walking table columns.
647    ///
648    /// # Arguments
649    ///
650    /// * `oids` - OIDs to retrieve
651    /// * `non_repeaters` - How many OIDs (from the start) are non-repeating
652    /// * `max_repetitions` - Maximum rows to return for each repeating OID
653    ///
654    /// # Example
655    ///
656    /// ```rust,no_run
657    /// # use async_snmp::{Auth, Client, oid};
658    /// # async fn example() -> async_snmp::Result<()> {
659    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("public")).connect().await?;
660    /// // Get sysUpTime (non-repeater) plus 10 interface descriptions (repeater)
661    /// let results = client.get_bulk(
662    ///     &[oid!(1, 3, 6, 1, 2, 1, 1, 3, 0), oid!(1, 3, 6, 1, 2, 1, 2, 2, 1, 2)],
663    ///     1,  // first OID is non-repeating
664    ///     10, // get up to 10 values for the second OID
665    /// ).await?;
666    /// // Results: [sysUpTime value, ifDescr.1, ifDescr.2, ..., ifDescr.10]
667    /// # Ok(())
668    /// # }
669    /// ```
670    #[instrument(skip(self, oids), err, fields(
671        snmp.target = %self.peer_addr(),
672        snmp.oid_count = oids.len(),
673        snmp.non_repeaters = non_repeaters,
674        snmp.max_repetitions = max_repetitions
675    ))]
676    pub async fn get_bulk(
677        &self,
678        oids: &[Oid],
679        non_repeaters: i32,
680        max_repetitions: i32,
681    ) -> Result<Vec<VarBind>> {
682        let request_id = self.next_request_id();
683        let pdu = GetBulkPdu::new(request_id, non_repeaters, max_repetitions, oids);
684        let response = self.send_bulk_request(pdu).await?;
685        Ok(response.varbinds)
686    }
687
688    /// Walk an OID subtree.
689    ///
690    /// Auto-selects the optimal walk method based on SNMP version and `WalkMode`:
691    /// - `WalkMode::Auto` (default): Uses GETNEXT for V1, GETBULK for V2c/V3
692    /// - `WalkMode::GetNext`: Always uses GETNEXT
693    /// - `WalkMode::GetBulk`: Always uses GETBULK (fails on V1)
694    ///
695    /// Returns an async stream that yields each variable binding in the subtree.
696    /// The walk terminates when an OID outside the subtree is encountered or
697    /// when `EndOfMibView` is returned.
698    ///
699    /// Uses the client's configured `oid_ordering`, `max_walk_results`, and
700    /// `max_repetitions` (for GETBULK) settings.
701    ///
702    /// # Example
703    ///
704    /// ```rust,no_run
705    /// # use async_snmp::{Auth, Client, oid};
706    /// # async fn example() -> async_snmp::Result<()> {
707    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("public")).connect().await?;
708    /// // Auto-selects GETBULK for V2c/V3, GETNEXT for V1
709    /// let results = client.walk(oid!(1, 3, 6, 1, 2, 1, 1))?.collect().await?;
710    /// # Ok(())
711    /// # }
712    /// ```
713    #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
714    pub fn walk(&self, oid: Oid) -> Result<WalkStream<T>>
715    where
716        T: 'static,
717    {
718        let ordering = self.inner.config.oid_ordering;
719        let max_results = self.inner.config.max_walk_results;
720        let walk_mode = self.inner.config.walk_mode;
721        let max_repetitions = self.inner.config.max_repetitions as i32;
722        let version = self.inner.config.version;
723
724        WalkStream::new(
725            self.clone(),
726            oid,
727            version,
728            walk_mode,
729            ordering,
730            max_results,
731            max_repetitions,
732        )
733    }
734
735    /// Walk an OID subtree using GETNEXT.
736    ///
737    /// This method always uses GETNEXT regardless of the client's `WalkMode` configuration.
738    /// For auto-selection based on version and mode, use [`walk()`](Self::walk) instead.
739    ///
740    /// Returns an async stream that yields each variable binding in the subtree.
741    /// The walk terminates when an OID outside the subtree is encountered or
742    /// when `EndOfMibView` is returned.
743    ///
744    /// Uses the client's configured `oid_ordering` and `max_walk_results` settings.
745    ///
746    /// # Example
747    ///
748    /// ```rust,no_run
749    /// # use async_snmp::{Auth, Client, oid};
750    /// # async fn example() -> async_snmp::Result<()> {
751    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("public")).connect().await?;
752    /// // Force GETNEXT even for V2c/V3 clients
753    /// let results = client.walk_getnext(oid!(1, 3, 6, 1, 2, 1, 1)).collect().await?;
754    /// # Ok(())
755    /// # }
756    /// ```
757    #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
758    pub fn walk_getnext(&self, oid: Oid) -> Walk<T>
759    where
760        T: 'static,
761    {
762        let ordering = self.inner.config.oid_ordering;
763        let max_results = self.inner.config.max_walk_results;
764        Walk::new(self.clone(), oid, ordering, max_results)
765    }
766
767    /// Walk an OID subtree using GETBULK (more efficient than GETNEXT).
768    ///
769    /// Returns an async stream that yields each variable binding in the subtree.
770    /// Uses GETBULK internally with `non_repeaters=0`, fetching `max_repetitions`
771    /// values per request for efficient table traversal.
772    ///
773    /// Uses the client's configured `oid_ordering` and `max_walk_results` settings.
774    ///
775    /// # Arguments
776    ///
777    /// * `oid` - The base OID of the subtree to walk
778    /// * `max_repetitions` - How many OIDs to fetch per request
779    ///
780    /// # Example
781    ///
782    /// ```rust,no_run
783    /// # use async_snmp::{Auth, Client, oid};
784    /// # async fn example() -> async_snmp::Result<()> {
785    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("public")).connect().await?;
786    /// // Walk the interfaces table efficiently
787    /// let walk = client.bulk_walk(oid!(1, 3, 6, 1, 2, 1, 2, 2), 25);
788    /// // Process with futures StreamExt
789    /// # Ok(())
790    /// # }
791    /// ```
792    #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid, snmp.max_repetitions = max_repetitions))]
793    pub fn bulk_walk(&self, oid: Oid, max_repetitions: i32) -> BulkWalk<T>
794    where
795        T: 'static,
796    {
797        let ordering = self.inner.config.oid_ordering;
798        let max_results = self.inner.config.max_walk_results;
799        BulkWalk::new(self.clone(), oid, max_repetitions, ordering, max_results)
800    }
801
802    /// Walk an OID subtree using the client's configured `max_repetitions`.
803    ///
804    /// This is a convenience method that uses the client's `max_repetitions` setting
805    /// (default: 25) instead of requiring it as a parameter.
806    ///
807    /// # Example
808    ///
809    /// ```rust,no_run
810    /// # use async_snmp::{Auth, Client, oid};
811    /// # async fn example() -> async_snmp::Result<()> {
812    /// # let client = Client::builder("127.0.0.1:161", Auth::v2c("public")).connect().await?;
813    /// // Walk using configured max_repetitions
814    /// let walk = client.bulk_walk_default(oid!(1, 3, 6, 1, 2, 1, 2, 2));
815    /// // Process with futures StreamExt
816    /// # Ok(())
817    /// # }
818    /// ```
819    #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
820    pub fn bulk_walk_default(&self, oid: Oid) -> BulkWalk<T>
821    where
822        T: 'static,
823    {
824        let ordering = self.inner.config.oid_ordering;
825        let max_results = self.inner.config.max_walk_results;
826        let max_repetitions = self.inner.config.max_repetitions as i32;
827        BulkWalk::new(self.clone(), oid, max_repetitions, ordering, max_results)
828    }
829}
830
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::CommunityMessage;
    use crate::oid::Oid;
    use crate::pdu::{Pdu, PduType};
    use crate::varbind::VarBind;
    use crate::version::Version;
    use bytes::Bytes;
    use std::collections::VecDeque;
    use std::net::SocketAddr;
    use std::sync::{Arc, Mutex};

    // -------------------------------------------------------------------------
    // Mock transport that returns a response with a configurable number of
    // varbinds, regardless of how many were requested.
    // -------------------------------------------------------------------------

    #[derive(Clone)]
    struct TruncatingTransport {
        /// Number of varbinds to include in each response.
        response_varbind_count: usize,
        /// Request IDs of sent requests, in send order. `recv` pops the oldest
        /// one so each response echoes the ID of the request it answers.
        pending: Arc<Mutex<VecDeque<i32>>>,
    }

    impl TruncatingTransport {
        fn new(response_varbind_count: usize) -> Self {
            Self {
                response_varbind_count,
                pending: Arc::new(Mutex::new(VecDeque::new())),
            }
        }
    }

    impl Transport for TruncatingTransport {
        fn send(&self, data: &[u8]) -> impl std::future::Future<Output = Result<()>> + Send {
            // Decode the sent request to extract the request_id.
            let request_id = crate::transport::extract_request_id(data).unwrap_or(1);
            {
                let mut q = self.pending.lock().unwrap();
                q.push_back(request_id);
            }
            async { Ok(()) }
        }

        fn recv(
            &self,
            _request_id: i32,
        ) -> impl std::future::Future<Output = Result<(Bytes, SocketAddr)>> + Send {
            let request_id = {
                let mut q = self.pending.lock().unwrap();
                q.pop_front().unwrap_or(1)
            };
            let n = self.response_varbind_count;
            let peer: SocketAddr = "127.0.0.1:161".parse().unwrap();

            async move {
                // Build a response PDU with n varbinds (NULL values).
                let varbinds: Vec<VarBind> = (0..n)
                    .map(|i| {
                        VarBind::new(
                            Oid::from_slice(&[1, 3, 6, 1, i as u32]),
                            crate::value::Value::Null,
                        )
                    })
                    .collect();

                let pdu = Pdu {
                    pdu_type: PduType::Response,
                    request_id,
                    error_status: 0,
                    error_index: 0,
                    varbinds,
                };

                let msg = CommunityMessage::v2c(Bytes::from_static(b"public"), pdu);
                let encoded = msg.encode();
                Ok((encoded, peer))
            }
        }

        fn peer_addr(&self) -> SocketAddr {
            "127.0.0.1:161".parse().unwrap()
        }

        fn local_addr(&self) -> SocketAddr {
            "127.0.0.1:0".parse().unwrap()
        }

        fn is_reliable(&self) -> bool {
            true
        }
    }

    // -------------------------------------------------------------------------
    // Fixtures shared by the tests below.
    // -------------------------------------------------------------------------

    /// Build a V2c client whose mock transport answers every request with
    /// `response_varbind_count` varbinds.
    ///
    /// At most 10 OIDs are placed in one request (larger requests are batched),
    /// and retries are disabled via `Retry::none()`.
    fn make_client(response_varbind_count: usize) -> Client<TruncatingTransport> {
        let transport = TruncatingTransport::new(response_varbind_count);
        let config = ClientConfig {
            version: Version::V2c,
            max_oids_per_request: 10,
            retry: crate::client::retry::Retry::none(),
            ..Default::default()
        };
        Client::new(transport, config)
    }

    /// The OIDs `1.3.6.1.1`, `1.3.6.1.2` and `1.3.6.1.3`.
    fn three_oids() -> [Oid; 3] {
        [1u32, 2, 3].map(|arc| Oid::from_slice(&[1, 3, 6, 1, arc]))
    }

    /// The OIDs from `three_oids`, paired with the integers 1, 2 and 3.
    fn three_varbinds() -> [(Oid, crate::value::Value); 3] {
        let [first, second, third] = three_oids();
        [
            (first, crate::value::Value::Integer(1)),
            (second, crate::value::Value::Integer(2)),
            (third, crate::value::Value::Integer(3)),
        ]
    }

    /// The OIDs `1.3.6.1.0` through `1.3.6.1.{count - 1}`.
    fn sequential_oids(count: u32) -> Vec<Oid> {
        (0..count)
            .map(|i| Oid::from_slice(&[1, 3, 6, 1, i]))
            .collect()
    }

    #[tokio::test]
    async fn get_many_warns_on_truncated_response() {
        // Request 3 OIDs but the mock returns only 1 varbind - should warn and return what we got.
        let client = make_client(1);
        let oids = three_oids();

        let result = client.get_many(&oids).await;
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn get_many_rejects_inflated_response() {
        // Request 3 OIDs but the mock returns 5 varbinds.
        let client = make_client(5);
        let oids = three_oids();

        let err = client.get_many(&oids).await.unwrap_err();
        assert!(
            matches!(*err, Error::MalformedResponse { .. }),
            "expected MalformedResponse, got: {err}"
        );
    }

    #[tokio::test]
    async fn get_many_accepts_correct_response_count() {
        // Request 3 OIDs and the mock returns exactly 3 varbinds.
        let client = make_client(3);
        let oids = three_oids();

        let result = client.get_many(&oids).await;
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn get_next_many_warns_on_truncated_response() {
        // Request 3 OIDs but the mock returns only 1 varbind - should warn and return what we got.
        let client = make_client(1);
        let oids = three_oids();

        let result = client.get_next_many(&oids).await;
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn get_next_many_rejects_inflated_response() {
        // Request 3 OIDs but the mock returns 5 varbinds.
        let client = make_client(5);
        let oids = three_oids();

        let err = client.get_next_many(&oids).await.unwrap_err();
        assert!(
            matches!(*err, Error::MalformedResponse { .. }),
            "expected MalformedResponse, got: {err}"
        );
    }

    #[tokio::test]
    async fn get_next_many_accepts_correct_response_count() {
        // Request 3 OIDs and the mock returns exactly 3 varbinds.
        let client = make_client(3);
        let oids = three_oids();

        let result = client.get_next_many(&oids).await;
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 3);
    }

    #[tokio::test]
    async fn set_many_warns_on_truncated_response() {
        // Request 3 varbinds but the mock returns only 1 - should warn and return what we got.
        let client = make_client(1);
        let varbinds = three_varbinds();

        let result = client.set_many(&varbinds).await;
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn set_many_rejects_inflated_response() {
        // Request 3 varbinds but the mock returns 5.
        let client = make_client(5);
        let varbinds = three_varbinds();

        let err = client.set_many(&varbinds).await.unwrap_err();
        assert!(
            matches!(*err, Error::MalformedResponse { .. }),
            "expected MalformedResponse, got: {err}"
        );
    }

    #[tokio::test]
    async fn set_many_accepts_correct_response_count() {
        // Request 3 varbinds and the mock returns exactly 3.
        let client = make_client(3);
        let varbinds = three_varbinds();

        let result = client.set_many(&varbinds).await;
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 3);
    }

    // Batched path: get_many with more OIDs than max_per_request.
    #[tokio::test]
    async fn get_many_batched_warns_on_truncated_response() {
        // max_oids_per_request = 10 (from make_client), request 12 OIDs, mock
        // returns 1 per batch. Should warn and return 2 varbinds (1 per batch).
        let client = make_client(1);
        let oids = sequential_oids(12);

        let result = client.get_many(&oids).await;
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), 2); // 1 varbind per batch, 2 batches
    }

    #[tokio::test]
    async fn get_many_batched_rejects_inflated_response() {
        // max_oids_per_request = 10 (from make_client), request 12 OIDs, mock
        // returns 12 per batch, which exceeds the 10 OIDs of the first batch.
        let client = make_client(12);
        let oids = sequential_oids(12);

        let err = client.get_many(&oids).await.unwrap_err();
        assert!(
            matches!(*err, Error::MalformedResponse { .. }),
            "expected MalformedResponse, got: {err}"
        );
    }
}