Skip to main content

async_snmp/client/
walk.rs

1//! Walk stream implementations.
2
3// Allow complex types for boxed futures in manual Stream implementations.
4// The `pending` fields require `Option<Pin<Box<dyn Future<Output = ...> + Send>>>`
5// which triggers this lint but is the standard pattern for storing futures.
6#![allow(clippy::type_complexity)]
7
/// Generate inherent `next()` and `collect()` methods for a walk stream type.
///
/// The target type must implement [`Stream`] with
/// `Item = crate::error::Result<crate::varbind::VarBind>` and must be `Unpin`,
/// because `poll_next` is reached through `Pin::new`. Every generic parameter
/// of the target type is bounded by `Transport + 'static`.
macro_rules! impl_stream_helpers {
    ($type:ident < $($gen:tt),+ >) => {
        impl<$($gen),+> $type<$($gen),+>
        where
            $($gen: crate::transport::Transport + 'static,)+
        {
            /// Get the next varbind, or None when complete.
            pub async fn next(&mut self) -> Option<crate::error::Result<crate::varbind::VarBind>> {
                let mut this = std::pin::Pin::new(self);
                std::future::poll_fn(move |cx| this.as_mut().poll_next(cx)).await
            }

            /// Collect all remaining varbinds.
            ///
            /// Stops at, and returns, the first error yielded by the stream.
            pub async fn collect(mut self) -> crate::error::Result<Vec<crate::varbind::VarBind>> {
                let mut gathered = Vec::new();
                loop {
                    match self.next().await {
                        Some(item) => gathered.push(item?),
                        None => return Ok(gathered),
                    }
                }
            }
        }
    };
}
31
32use std::collections::{HashSet, VecDeque};
33use std::pin::Pin;
34use std::task::{Context, Poll};
35
36use futures_core::Stream;
37
38use crate::error::{Error, Result, WalkAbortReason};
39use crate::oid::Oid;
40use crate::transport::Transport;
41use crate::value::Value;
42use crate::varbind::VarBind;
43use crate::version::Version;
44
45use super::Client;
46
/// Walk operation mode.
///
/// Selects which request type a walk uses to traverse a subtree.
/// [`WalkStream`] maps each mode to a concrete stream.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum WalkMode {
    /// Auto-select based on version (default).
    /// V1 uses GETNEXT, V2c/V3 uses GETBULK.
    #[default]
    Auto,
    /// Always use GETNEXT (slower but more compatible).
    ///
    /// One request is issued per returned varbind.
    GetNext,
    /// Always use GETBULK (faster, errors on v1).
    ///
    /// Creating a walk in this mode for an SNMPv1 client fails with
    /// [`Error::Config`](crate::Error::Config).
    GetBulk,
}
59
/// OID ordering behavior during walk operations.
///
/// SNMP walks rely on agents returning OIDs in strictly increasing
/// lexicographic order. However, some buggy agents violate this requirement,
/// returning OIDs out of order or even repeating OIDs (which would cause
/// infinite loops).
///
/// This enum controls how the library handles ordering violations:
///
/// - [`Strict`](Self::Strict) (default): Terminates immediately with
///   [`Error::WalkAborted`](crate::Error::WalkAborted) on any violation.
///   Use this unless you know the agent has ordering bugs.
///
/// - [`AllowNonIncreasing`](Self::AllowNonIncreasing): Tolerates out-of-order
///   OIDs but tracks all seen OIDs to detect cycles. Returns
///   [`Error::WalkAborted`](crate::Error::WalkAborted) if the same OID appears twice.
///
/// Ordering is only checked for varbinds that would otherwise be yielded.
/// Exception values (`endOfMibView`, `noSuchObject`, `noSuchInstance`) and
/// OIDs outside the walked subtree end the walk before any ordering check.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OidOrdering {
    /// Require strictly increasing OIDs (default).
    ///
    /// Walk terminates with [`Error::WalkAborted`](crate::Error::WalkAborted)
    /// on first violation. Most efficient: only the previously accepted OID is
    /// retained, and each item costs a single comparison against it.
    #[default]
    Strict,

    /// Allow non-increasing OIDs, with cycle detection.
    ///
    /// Some buggy agents return OIDs out of order. This mode tracks all seen
    /// OIDs in a HashSet to detect cycles, terminating with an error if the
    /// same OID is returned twice.
    ///
    /// **Warning**: This uses O(n) memory where n = number of walk results.
    /// Always pair with [`ClientBuilder::max_walk_results`] to bound memory
    /// usage. Cycle detection only catches duplicate OIDs; a pathological
    /// agent could still return an infinite sequence of unique OIDs within
    /// the subtree.
    ///
    /// [`ClientBuilder::max_walk_results`]: crate::ClientBuilder::max_walk_results
    AllowNonIncreasing,
}
100
/// Per-walk state used to enforce an [`OidOrdering`] policy.
enum OidTracker {
    /// State for [`OidOrdering::Strict`]: only the last accepted OID
    /// (`None` before the first one).
    Strict { last: Option<Oid> },
    /// State for [`OidOrdering::AllowNonIncreasing`]: every accepted OID,
    /// used to detect repeats.
    Relaxed { seen: HashSet<Oid> },
}
105
/// Outcome of validating a single varbind from a walk response.
enum VarbindOutcome {
    /// Varbind is valid and within the subtree; emit it.
    Yield,
    /// Walk is complete: the value was `endOfMibView`, `noSuchObject` or
    /// `noSuchInstance`, or the OID lies outside the base subtree.
    Done,
    /// Walk should abort with the given error (an OID ordering violation
    /// reported by the tracker).
    Abort(Box<Error>),
}
115
116/// Validate a varbind received during a walk.
117///
118/// Checks end-of-MIB, subtree containment, and OID ordering.
119/// Returns the outcome, updating `oid_tracker` on success.
120fn validate_walk_varbind(
121    vb: &VarBind,
122    base_oid: &Oid,
123    oid_tracker: &mut OidTracker,
124    target: std::net::SocketAddr,
125) -> VarbindOutcome {
126    if matches!(
127        vb.value,
128        Value::EndOfMibView | Value::NoSuchObject | Value::NoSuchInstance
129    ) {
130        return VarbindOutcome::Done;
131    }
132    if !vb.oid.starts_with(base_oid) {
133        return VarbindOutcome::Done;
134    }
135    match oid_tracker.check(&vb.oid, target) {
136        Ok(()) => VarbindOutcome::Yield,
137        Err(e) => VarbindOutcome::Abort(e),
138    }
139}
140
141impl OidTracker {
142    fn new(ordering: OidOrdering) -> Self {
143        match ordering {
144            OidOrdering::Strict => OidTracker::Strict { last: None },
145            OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
146                seen: HashSet::new(),
147            },
148        }
149    }
150
151    fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
152        match self {
153            OidTracker::Strict { last } => {
154                if let Some(prev) = last
155                    && oid <= prev
156                {
157                    tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
158                    return Err(Error::WalkAborted {
159                        target,
160                        reason: WalkAbortReason::NonIncreasing,
161                    }
162                    .boxed());
163                }
164                *last = Some(oid.clone());
165                Ok(())
166            }
167            OidTracker::Relaxed { seen } => {
168                if !seen.insert(oid.clone()) {
169                    tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
170                    return Err(Error::WalkAborted {
171                        target,
172                        reason: WalkAbortReason::Cycle,
173                    }
174                    .boxed());
175                }
176                Ok(())
177            }
178        }
179    }
180}
181
/// Async stream for walking an OID subtree using GETNEXT.
///
/// Created by [`Client::walk_getnext()`].
///
/// One GETNEXT request is issued per yielded varbind, each starting from the
/// previously yielded OID. The walk can end normally, by reaching
/// `max_results`, or after yielding an error; once it has ended, the stream
/// only yields `None`.
pub struct Walk<T: Transport> {
    /// Client used to issue the GETNEXT requests.
    client: Client<T>,
    /// Root of the walked subtree; a returned OID outside it ends the walk.
    base_oid: Oid,
    /// OID the next GETNEXT request is issued for (initially the base OID,
    /// then the last yielded OID).
    current_oid: Oid,
    /// OID tracker for ordering validation.
    oid_tracker: OidTracker,
    /// Maximum number of results to return (None = unlimited).
    max_results: Option<usize>,
    /// Count of results returned so far.
    count: usize,
    /// Set once the walk has finished; subsequent polls return `None`.
    done: bool,
    /// In-flight GETNEXT request, or `None` when no request is outstanding.
    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
}
198
199impl<T: Transport> Walk<T> {
200    pub(crate) fn new(
201        client: Client<T>,
202        oid: Oid,
203        ordering: OidOrdering,
204        max_results: Option<usize>,
205    ) -> Self {
206        Self {
207            client,
208            base_oid: oid.clone(),
209            current_oid: oid,
210            oid_tracker: OidTracker::new(ordering),
211            max_results,
212            count: 0,
213            done: false,
214            pending: None,
215        }
216    }
217}
218
219impl_stream_helpers!(Walk<T>);
220
221impl<T: Transport + 'static> Stream for Walk<T> {
222    type Item = Result<VarBind>;
223
224    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
225        if self.done {
226            return Poll::Ready(None);
227        }
228
229        // Check max_results limit
230        if let Some(max) = self.max_results
231            && self.count >= max
232        {
233            self.done = true;
234            return Poll::Ready(None);
235        }
236
237        // Check if we have a pending request
238        if self.pending.is_none() {
239            // Start a new GETNEXT request
240            let client = self.client.clone();
241            let oid = self.current_oid.clone();
242
243            let fut = Box::pin(async move { client.get_next(&oid).await });
244            self.pending = Some(fut);
245        }
246
247        // Poll the pending future
248        let pending = self.pending.as_mut().unwrap();
249        match pending.as_mut().poll(cx) {
250            Poll::Pending => Poll::Pending,
251            Poll::Ready(result) => {
252                self.pending = None;
253
254                match result {
255                    Ok(vb) => {
256                        let target = self.client.peer_addr();
257                        let base_oid = self.base_oid.clone();
258                        match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
259                            VarbindOutcome::Done => {
260                                self.done = true;
261                                return Poll::Ready(None);
262                            }
263                            VarbindOutcome::Abort(e) => {
264                                self.done = true;
265                                return Poll::Ready(Some(Err(e)));
266                            }
267                            VarbindOutcome::Yield => {}
268                        }
269
270                        // Update current OID for next iteration
271                        self.current_oid = vb.oid.clone();
272                        self.count += 1;
273
274                        Poll::Ready(Some(Ok(vb)))
275                    }
276                    Err(e) => {
277                        if self.client.inner.config.version == Version::V1
278                            && matches!(
279                                &*e,
280                                Error::Snmp {
281                                    status: crate::error::ErrorStatus::NoSuchName,
282                                    ..
283                                }
284                            )
285                        {
286                            self.done = true;
287                            return Poll::Ready(None);
288                        }
289
290                        self.done = true;
291                        Poll::Ready(Some(Err(e)))
292                    }
293                }
294            }
295        }
296    }
297}
298
/// Async stream for walking an OID subtree using GETBULK.
///
/// Created by [`Client::bulk_walk()`].
///
/// Varbinds from each GETBULK response are buffered and yielded one at a
/// time. A new request, continuing from the last yielded OID, is issued only
/// when the buffer is empty.
pub struct BulkWalk<T: Transport> {
    /// Client used to issue the GETBULK requests.
    client: Client<T>,
    /// Root of the walked subtree; a returned OID outside it ends the walk.
    base_oid: Oid,
    /// OID the next GETBULK request starts from (initially the base OID,
    /// then the last yielded OID).
    current_oid: Oid,
    /// GETBULK max-repetitions setting for this walk.
    max_repetitions: i32,
    /// OID tracker for ordering validation.
    oid_tracker: OidTracker,
    /// Maximum number of results to return (None = unlimited).
    max_results: Option<usize>,
    /// Count of results returned so far.
    count: usize,
    /// Set once the walk has finished; subsequent polls return `None`.
    done: bool,
    /// Buffered results from the last GETBULK response
    buffer: VecDeque<VarBind>,
    /// In-flight GETBULK request, or `None` when no request is outstanding.
    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
}
318
319impl<T: Transport> BulkWalk<T> {
320    pub(crate) fn new(
321        client: Client<T>,
322        oid: Oid,
323        max_repetitions: i32,
324        ordering: OidOrdering,
325        max_results: Option<usize>,
326    ) -> Self {
327        Self {
328            client,
329            base_oid: oid.clone(),
330            current_oid: oid,
331            max_repetitions,
332            oid_tracker: OidTracker::new(ordering),
333            max_results,
334            count: 0,
335            done: false,
336            buffer: VecDeque::new(),
337            pending: None,
338        }
339    }
340}
341
342impl_stream_helpers!(BulkWalk<T>);
343
344impl<T: Transport + 'static> Stream for BulkWalk<T> {
345    type Item = Result<VarBind>;
346
347    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
348        loop {
349            if self.done {
350                return Poll::Ready(None);
351            }
352
353            // Check max_results limit
354            if let Some(max) = self.max_results
355                && self.count >= max
356            {
357                self.done = true;
358                return Poll::Ready(None);
359            }
360
361            // Check if we have buffered results to return
362            if let Some(vb) = self.buffer.pop_front() {
363                let target = self.client.peer_addr();
364                let base_oid = self.base_oid.clone();
365                match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
366                    VarbindOutcome::Done => {
367                        self.done = true;
368                        return Poll::Ready(None);
369                    }
370                    VarbindOutcome::Abort(e) => {
371                        self.done = true;
372                        return Poll::Ready(Some(Err(e)));
373                    }
374                    VarbindOutcome::Yield => {}
375                }
376
377                // Update current OID for next request
378                self.current_oid = vb.oid.clone();
379                self.count += 1;
380
381                return Poll::Ready(Some(Ok(vb)));
382            }
383
384            // Buffer exhausted, need to fetch more
385            if self.pending.is_none() {
386                let client = self.client.clone();
387                let oid = self.current_oid.clone();
388                let max_rep = self.max_repetitions;
389
390                let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
391                self.pending = Some(fut);
392            }
393
394            // Poll the pending future
395            let pending = self.pending.as_mut().unwrap();
396            match pending.as_mut().poll(cx) {
397                Poll::Pending => return Poll::Pending,
398                Poll::Ready(result) => {
399                    self.pending = None;
400
401                    match result {
402                        Ok(varbinds) => {
403                            if varbinds.is_empty() {
404                                self.done = true;
405                                return Poll::Ready(None);
406                            }
407
408                            self.buffer = varbinds.into();
409                            // Continue loop to process buffer
410                        }
411                        Err(e) => {
412                            self.done = true;
413                            return Poll::Ready(Some(Err(e)));
414                        }
415                    }
416                }
417            }
418        }
419    }
420}
421
// ============================================================================
// Unified WalkStream - auto-selects GETNEXT or GETBULK based on WalkMode
// ============================================================================

/// Unified walk stream that auto-selects between GETNEXT and GETBULK.
///
/// Created by [`Client::walk()`] when using `WalkMode::Auto` or explicit mode selection.
/// This type wraps either a [`Walk`] or [`BulkWalk`] internally based on:
/// - `WalkMode::Auto`: Uses GETNEXT for V1, GETBULK for V2c/V3
/// - `WalkMode::GetNext`: Always uses GETNEXT
/// - `WalkMode::GetBulk`: Always uses GETBULK (creating the stream fails on V1)
///
/// Unlike the inner streams, its `collect()` falls back to a plain GET on the
/// base OID when the walk produced nothing (useful for scalar OIDs).
pub enum WalkStream<T: Transport> {
    /// GETNEXT-based walk (used for V1 or when explicitly requested)
    GetNext(Walk<T>),
    /// GETBULK-based walk (used for V2c/V3 or when explicitly requested)
    GetBulk(BulkWalk<T>),
}
439
440impl<T: Transport> WalkStream<T> {
441    /// Create a new walk stream with auto-selection based on version and walk mode.
442    pub(crate) fn new(
443        client: Client<T>,
444        oid: Oid,
445        version: Version,
446        walk_mode: WalkMode,
447        ordering: OidOrdering,
448        max_results: Option<usize>,
449        max_repetitions: i32,
450    ) -> Result<Self> {
451        let use_bulk = match walk_mode {
452            WalkMode::Auto => version != Version::V1,
453            WalkMode::GetNext => false,
454            WalkMode::GetBulk => {
455                if version == Version::V1 {
456                    return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
457                }
458                true
459            }
460        };
461
462        Ok(if use_bulk {
463            WalkStream::GetBulk(BulkWalk::new(
464                client,
465                oid,
466                max_repetitions,
467                ordering,
468                max_results,
469            ))
470        } else {
471            WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
472        })
473    }
474}
475
476impl<T: Transport + 'static> WalkStream<T> {
477    /// Get the next varbind, or None when complete.
478    pub async fn next(&mut self) -> Option<Result<VarBind>> {
479        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
480    }
481
482    /// Collect all remaining varbinds.
483    ///
484    /// If the walk completes with no results, a fallback GET is attempted on the
485    /// base OID. This handles scalar OIDs (e.g. `sysDescr.0`) where GETNEXT would
486    /// walk past the value. The GET result is only returned if it contains a real
487    /// value (not `NoSuchObject`, `NoSuchInstance`, or `EndOfMibView`).
488    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
489        let mut results = Vec::new();
490        while let Some(result) = self.next().await {
491            results.push(result?);
492        }
493        if results.is_empty() {
494            let (client, base_oid) = match &self {
495                WalkStream::GetNext(w) => (&w.client, &w.base_oid),
496                WalkStream::GetBulk(bw) => (&bw.client, &bw.base_oid),
497            };
498            match client.get(base_oid).await {
499                Ok(vb)
500                    if !matches!(
501                        vb.value,
502                        Value::NoSuchObject | Value::NoSuchInstance | Value::EndOfMibView
503                    ) =>
504                {
505                    results.push(vb);
506                }
507                _ => {}
508            }
509        }
510        Ok(results)
511    }
512}
513
514impl<T: Transport + 'static> Stream for WalkStream<T> {
515    type Item = Result<VarBind>;
516
517    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
518        // SAFETY: We're just projecting the pin to the inner enum variant
519        match self.get_mut() {
520            WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
521            WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
522        }
523    }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oid;

    fn target_addr() -> std::net::SocketAddr {
        "127.0.0.1:161".parse().unwrap()
    }

    #[test]
    fn test_walk_terminates_on_no_such_object() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::NoSuchObject);
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_walk_terminates_on_no_such_instance() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::NoSuchInstance);
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_walk_terminates_on_end_of_mib_view() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::EndOfMibView);
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_walk_yields_normal_value() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(42));
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
    }

    #[test]
    fn test_walk_terminates_outside_subtree() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        // First OID of the next sibling subtree: the walk must stop, not abort.
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 2, 1, 0), Value::Integer(1));
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_strict_aborts_on_decreasing_oid() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let later = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 2, 0), Value::Integer(1));
        let earlier = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(2));
        assert!(matches!(
            validate_walk_varbind(&later, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        let VarbindOutcome::Abort(err) =
            validate_walk_varbind(&earlier, &base, &mut tracker, target_addr())
        else {
            panic!("expected a decreasing OID to abort the walk");
        };
        assert!(matches!(
            &*err,
            Error::WalkAborted {
                reason: WalkAbortReason::NonIncreasing,
                ..
            }
        ));
    }

    #[test]
    fn test_strict_aborts_on_repeated_oid() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(1));
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        let VarbindOutcome::Abort(err) =
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr())
        else {
            panic!("expected a repeated OID to abort a strict walk");
        };
        assert!(matches!(
            &*err,
            Error::WalkAborted {
                reason: WalkAbortReason::NonIncreasing,
                ..
            }
        ));
    }

    #[test]
    fn test_relaxed_allows_out_of_order_oids() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::AllowNonIncreasing);
        let later = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 2, 0), Value::Integer(1));
        let earlier = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(2));
        assert!(matches!(
            validate_walk_varbind(&later, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        assert!(matches!(
            validate_walk_varbind(&earlier, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
    }

    #[test]
    fn test_relaxed_aborts_on_cycle() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::AllowNonIncreasing);
        let first = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(1));
        let second = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 2, 0), Value::Integer(2));
        assert!(matches!(
            validate_walk_varbind(&first, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        assert!(matches!(
            validate_walk_varbind(&second, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        let VarbindOutcome::Abort(err) =
            validate_walk_varbind(&first, &base, &mut tracker, target_addr())
        else {
            panic!("expected a revisited OID to abort a relaxed walk");
        };
        assert!(matches!(
            &*err,
            Error::WalkAborted {
                reason: WalkAbortReason::Cycle,
                ..
            }
        ));
    }
}