async_snmp/client/
walk.rs

1//! Walk stream implementations.
2
3#![allow(clippy::type_complexity)]
4
5use std::collections::HashSet;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures_core::Stream;
10
11use crate::error::{Error, Result, WalkAbortReason};
12use crate::oid::Oid;
13use crate::transport::Transport;
14use crate::value::Value;
15use crate::varbind::VarBind;
16use crate::version::Version;
17
18use super::Client;
19
/// Strategy used to retrieve successive OIDs during a walk.
///
/// The mode decides which SNMP request type drives the walk. `Hash` is
/// derived alongside the other common traits so the mode can be used as a
/// key in hashed collections.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum WalkMode {
    /// Auto-select based on version (default).
    ///
    /// V1 uses GETNEXT, V2c/V3 uses GETBULK.
    #[default]
    Auto,
    /// Always use GETNEXT (slower but more compatible).
    GetNext,
    /// Always use GETBULK (faster, but rejected with a configuration error
    /// when the session is SNMPv1).
    GetBulk,
}
33
/// OID ordering behavior during walk operations.
///
/// SNMP walks rely on agents returning OIDs in strictly increasing
/// lexicographic order. However, some buggy agents violate this requirement,
/// returning OIDs out of order or even repeating OIDs (which would cause
/// infinite loops).
///
/// This enum controls how the library handles ordering violations:
///
/// - [`Strict`](Self::Strict) (default): Terminates immediately with
///   [`Error::WalkAborted`](crate::Error::WalkAborted) on any violation.
///   Use this unless you know the agent has ordering bugs.
///
/// - [`AllowNonIncreasing`](Self::AllowNonIncreasing): Tolerates out-of-order
///   OIDs but tracks all seen OIDs to detect cycles. Returns
///   [`Error::WalkAborted`](crate::Error::WalkAborted) if the same OID appears twice.
///
/// `Hash` is derived alongside the other common traits so the policy can be
/// used as a key in hashed collections.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum OidOrdering {
    /// Require strictly increasing OIDs (default).
    ///
    /// Walk terminates with [`Error::WalkAborted`](crate::Error::WalkAborted)
    /// on first violation. Most efficient: O(1) memory, O(1) per-item check.
    #[default]
    Strict,

    /// Allow non-increasing OIDs, with cycle detection.
    ///
    /// Some buggy agents return OIDs out of order. This mode tracks all seen
    /// OIDs in a `HashSet` to detect cycles, terminating with an error if the
    /// same OID is returned twice.
    ///
    /// **Warning**: This uses O(n) memory where n = number of walk results.
    /// Always pair with [`ClientBuilder::max_walk_results`] to bound memory
    /// usage. Cycle detection only catches duplicate OIDs; a pathological
    /// agent could still return an infinite sequence of unique OIDs within
    /// the subtree.
    ///
    /// [`ClientBuilder::max_walk_results`]: crate::ClientBuilder::max_walk_results
    AllowNonIncreasing,
}
75
76enum OidTracker {
77    Strict { last: Option<Oid> },
78    Relaxed { seen: HashSet<Oid> },
79}
80
81impl OidTracker {
82    fn new(ordering: OidOrdering) -> Self {
83        match ordering {
84            OidOrdering::Strict => OidTracker::Strict { last: None },
85            OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
86                seen: HashSet::new(),
87            },
88        }
89    }
90
91    fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
92        match self {
93            OidTracker::Strict { last } => {
94                if let Some(prev) = last
95                    && oid <= prev
96                {
97                    tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
98                    return Err(Error::WalkAborted {
99                        target,
100                        reason: WalkAbortReason::NonIncreasing,
101                    }
102                    .boxed());
103                }
104                *last = Some(oid.clone());
105                Ok(())
106            }
107            OidTracker::Relaxed { seen } => {
108                if !seen.insert(oid.clone()) {
109                    tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
110                    return Err(Error::WalkAborted {
111                        target,
112                        reason: WalkAbortReason::Cycle,
113                    }
114                    .boxed());
115                }
116                Ok(())
117            }
118        }
119    }
120}
121
122/// Async stream for walking an OID subtree using GETNEXT.
123///
124/// Created by [`Client::walk_getnext()`].
125pub struct Walk<T: Transport> {
126    client: Client<T>,
127    base_oid: Oid,
128    current_oid: Oid,
129    /// OID tracker for ordering validation.
130    oid_tracker: OidTracker,
131    /// Maximum number of results to return (None = unlimited).
132    max_results: Option<usize>,
133    /// Count of results returned so far.
134    count: usize,
135    done: bool,
136    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
137}
138
139impl<T: Transport> Walk<T> {
140    pub(crate) fn new(
141        client: Client<T>,
142        oid: Oid,
143        ordering: OidOrdering,
144        max_results: Option<usize>,
145    ) -> Self {
146        Self {
147            client,
148            base_oid: oid.clone(),
149            current_oid: oid,
150            oid_tracker: OidTracker::new(ordering),
151            max_results,
152            count: 0,
153            done: false,
154            pending: None,
155        }
156    }
157}
158
159impl<T: Transport + 'static> Walk<T> {
160    /// Get the next varbind, or None when complete.
161    pub async fn next(&mut self) -> Option<Result<VarBind>> {
162        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
163    }
164
165    /// Collect all remaining varbinds.
166    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
167        let mut results = Vec::new();
168        while let Some(result) = self.next().await {
169            results.push(result?);
170        }
171        Ok(results)
172    }
173}
174
175impl<T: Transport + 'static> Stream for Walk<T> {
176    type Item = Result<VarBind>;
177
178    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179        if self.done {
180            return Poll::Ready(None);
181        }
182
183        // Check max_results limit
184        if let Some(max) = self.max_results
185            && self.count >= max
186        {
187            self.done = true;
188            return Poll::Ready(None);
189        }
190
191        // Check if we have a pending request
192        if self.pending.is_none() {
193            // Start a new GETNEXT request
194            let client = self.client.clone();
195            let oid = self.current_oid.clone();
196
197            let fut = Box::pin(async move { client.get_next(&oid).await });
198            self.pending = Some(fut);
199        }
200
201        // Poll the pending future
202        let pending = self.pending.as_mut().unwrap();
203        match pending.as_mut().poll(cx) {
204            Poll::Pending => Poll::Pending,
205            Poll::Ready(result) => {
206                self.pending = None;
207
208                match result {
209                    Ok(vb) => {
210                        // Check for end conditions
211                        if matches!(vb.value, Value::EndOfMibView) {
212                            self.done = true;
213                            return Poll::Ready(None);
214                        }
215
216                        // Check if OID left the subtree
217                        if !vb.oid.starts_with(&self.base_oid) {
218                            self.done = true;
219                            return Poll::Ready(None);
220                        }
221
222                        // Check OID ordering using the tracker
223                        let target = self.client.peer_addr();
224                        if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
225                            self.done = true;
226                            return Poll::Ready(Some(Err(e)));
227                        }
228
229                        // Update current OID for next iteration
230                        self.current_oid = vb.oid.clone();
231                        self.count += 1;
232
233                        Poll::Ready(Some(Ok(vb)))
234                    }
235                    Err(e) => {
236                        self.done = true;
237                        Poll::Ready(Some(Err(e)))
238                    }
239                }
240            }
241        }
242    }
243}
244
245/// Async stream for walking an OID subtree using GETBULK.
246///
247/// Created by [`Client::bulk_walk()`].
248pub struct BulkWalk<T: Transport> {
249    client: Client<T>,
250    base_oid: Oid,
251    current_oid: Oid,
252    max_repetitions: i32,
253    /// OID tracker for ordering validation.
254    oid_tracker: OidTracker,
255    /// Maximum number of results to return (None = unlimited).
256    max_results: Option<usize>,
257    /// Count of results returned so far.
258    count: usize,
259    done: bool,
260    /// Buffered results from the last GETBULK response
261    buffer: Vec<VarBind>,
262    /// Index into the buffer
263    buffer_idx: usize,
264    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
265}
266
267impl<T: Transport> BulkWalk<T> {
268    pub(crate) fn new(
269        client: Client<T>,
270        oid: Oid,
271        max_repetitions: i32,
272        ordering: OidOrdering,
273        max_results: Option<usize>,
274    ) -> Self {
275        Self {
276            client,
277            base_oid: oid.clone(),
278            current_oid: oid,
279            max_repetitions,
280            oid_tracker: OidTracker::new(ordering),
281            max_results,
282            count: 0,
283            done: false,
284            buffer: Vec::new(),
285            buffer_idx: 0,
286            pending: None,
287        }
288    }
289}
290
291impl<T: Transport + 'static> BulkWalk<T> {
292    /// Get the next varbind, or None when complete.
293    pub async fn next(&mut self) -> Option<Result<VarBind>> {
294        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
295    }
296
297    /// Collect all remaining varbinds.
298    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
299        let mut results = Vec::new();
300        while let Some(result) = self.next().await {
301            results.push(result?);
302        }
303        Ok(results)
304    }
305}
306
307impl<T: Transport + 'static> Stream for BulkWalk<T> {
308    type Item = Result<VarBind>;
309
310    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
311        loop {
312            if self.done {
313                return Poll::Ready(None);
314            }
315
316            // Check max_results limit
317            if let Some(max) = self.max_results
318                && self.count >= max
319            {
320                self.done = true;
321                return Poll::Ready(None);
322            }
323
324            // Check if we have buffered results to return
325            if self.buffer_idx < self.buffer.len() {
326                let vb = self.buffer[self.buffer_idx].clone();
327                self.buffer_idx += 1;
328
329                // Check for end conditions
330                if matches!(vb.value, Value::EndOfMibView) {
331                    self.done = true;
332                    return Poll::Ready(None);
333                }
334
335                // Check if OID left the subtree
336                if !vb.oid.starts_with(&self.base_oid) {
337                    self.done = true;
338                    return Poll::Ready(None);
339                }
340
341                // Check OID ordering using the tracker
342                let target = self.client.peer_addr();
343                if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
344                    self.done = true;
345                    return Poll::Ready(Some(Err(e)));
346                }
347
348                // Update current OID for next request
349                self.current_oid = vb.oid.clone();
350                self.count += 1;
351
352                return Poll::Ready(Some(Ok(vb)));
353            }
354
355            // Buffer exhausted, need to fetch more
356            if self.pending.is_none() {
357                let client = self.client.clone();
358                let oid = self.current_oid.clone();
359                let max_rep = self.max_repetitions;
360
361                let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
362                self.pending = Some(fut);
363            }
364
365            // Poll the pending future
366            let pending = self.pending.as_mut().unwrap();
367            match pending.as_mut().poll(cx) {
368                Poll::Pending => return Poll::Pending,
369                Poll::Ready(result) => {
370                    self.pending = None;
371
372                    match result {
373                        Ok(varbinds) => {
374                            if varbinds.is_empty() {
375                                self.done = true;
376                                return Poll::Ready(None);
377                            }
378
379                            self.buffer = varbinds;
380                            self.buffer_idx = 0;
381                            // Continue loop to process buffer
382                        }
383                        Err(e) => {
384                            self.done = true;
385                            return Poll::Ready(Some(Err(e)));
386                        }
387                    }
388                }
389            }
390        }
391    }
392}
393
394// ============================================================================
395// Unified WalkStream - auto-selects GETNEXT or GETBULK based on WalkMode
396// ============================================================================
397
398/// Unified walk stream that auto-selects between GETNEXT and GETBULK.
399///
400/// Created by [`Client::walk()`] when using `WalkMode::Auto` or explicit mode selection.
401/// This type wraps either a [`Walk`] or [`BulkWalk`] internally based on:
402/// - `WalkMode::Auto`: Uses GETNEXT for V1, GETBULK for V2c/V3
403/// - `WalkMode::GetNext`: Always uses GETNEXT
404/// - `WalkMode::GetBulk`: Always uses GETBULK (fails on V1)
405pub enum WalkStream<T: Transport> {
406    /// GETNEXT-based walk (used for V1 or when explicitly requested)
407    GetNext(Walk<T>),
408    /// GETBULK-based walk (used for V2c/V3 or when explicitly requested)
409    GetBulk(BulkWalk<T>),
410}
411
412impl<T: Transport> WalkStream<T> {
413    /// Create a new walk stream with auto-selection based on version and walk mode.
414    pub(crate) fn new(
415        client: Client<T>,
416        oid: Oid,
417        version: Version,
418        walk_mode: WalkMode,
419        ordering: OidOrdering,
420        max_results: Option<usize>,
421        max_repetitions: i32,
422    ) -> Result<Self> {
423        let use_bulk = match walk_mode {
424            WalkMode::Auto => version != Version::V1,
425            WalkMode::GetNext => false,
426            WalkMode::GetBulk => {
427                if version == Version::V1 {
428                    return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
429                }
430                true
431            }
432        };
433
434        Ok(if use_bulk {
435            WalkStream::GetBulk(BulkWalk::new(
436                client,
437                oid,
438                max_repetitions,
439                ordering,
440                max_results,
441            ))
442        } else {
443            WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
444        })
445    }
446}
447
448impl<T: Transport + 'static> WalkStream<T> {
449    /// Get the next varbind, or None when complete.
450    pub async fn next(&mut self) -> Option<Result<VarBind>> {
451        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
452    }
453
454    /// Collect all remaining varbinds.
455    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
456        let mut results = Vec::new();
457        while let Some(result) = self.next().await {
458            results.push(result?);
459        }
460        Ok(results)
461    }
462}
463
464impl<T: Transport + 'static> Stream for WalkStream<T> {
465    type Item = Result<VarBind>;
466
467    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
468        // SAFETY: We're just projecting the pin to the inner enum variant
469        match self.get_mut() {
470            WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
471            WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
472        }
473    }
474}