Skip to main content

async_snmp/client/
walk.rs

1//! Walk stream implementations.
2
3// Allow complex types for boxed futures in manual Stream implementations.
4// The `pending` fields require `Option<Pin<Box<dyn Future<Output = ...> + Send>>>`
5// which triggers this lint but is the standard pattern for storing futures.
6#![allow(clippy::type_complexity)]
7
8/// Implement `next()` and `collect()` for a Stream type that implements poll_next.
/// Generate inherent async `next()` / `collect()` helpers for a walk stream type.
///
/// The target type must implement `Stream<Item = Result<VarBind>>` and be
/// `Unpin`, because `next()` re-pins `&mut self` with `Pin::new`.
macro_rules! impl_stream_helpers {
    ($type:ident < $($gen:tt),+ >) => {
        impl<$($gen),+> $type<$($gen),+>
        where
            $($gen: crate::transport::Transport + 'static,)+
        {
            /// Wait for the next varbind produced by the walk.
            ///
            /// Resolves to `None` once the walk has finished.
            pub async fn next(&mut self) -> Option<crate::error::Result<crate::varbind::VarBind>> {
                let mut stream = std::pin::Pin::new(self);
                std::future::poll_fn(move |cx| stream.as_mut().poll_next(cx)).await
            }

            /// Drain the rest of the walk into a `Vec`.
            ///
            /// Stops at, and returns, the first error the walk yields.
            pub async fn collect(mut self) -> crate::error::Result<Vec<crate::varbind::VarBind>> {
                let mut varbinds = Vec::new();
                loop {
                    match self.next().await {
                        Some(Ok(vb)) => varbinds.push(vb),
                        Some(Err(err)) => return Err(err),
                        None => return Ok(varbinds),
                    }
                }
            }
        }
    };
}
31
32use std::collections::{HashSet, VecDeque};
33use std::pin::Pin;
34use std::task::{Context, Poll};
35
36use futures_core::Stream;
37
38use crate::error::{Error, Result, WalkAbortReason};
39use crate::oid::Oid;
40use crate::transport::Transport;
41use crate::value::Value;
42use crate::varbind::VarBind;
43use crate::version::Version;
44
45use super::Client;
46
/// Request strategy used to traverse an OID subtree.
///
/// With the default, [`WalkMode::Auto`], SNMPv1 sessions walk with GETNEXT and
/// SNMPv2c/SNMPv3 sessions walk with GETBULK.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum WalkMode {
    /// Choose from the protocol version (the default):
    /// GETNEXT for v1, GETBULK for v2c and v3.
    #[default]
    Auto,
    /// Always use GETNEXT: one OID per round trip, slower but widely supported.
    GetNext,
    /// Always use GETBULK: several OIDs per round trip; rejected for v1.
    GetBulk,
}
59
/// How a walk reacts when the agent breaks OID ordering.
///
/// A walk depends on the agent returning OIDs in strictly increasing
/// lexicographic order. Some buggy agents return OIDs out of order, or even
/// repeat an OID, which would otherwise loop forever.
///
/// Two policies are available:
///
/// - [`Strict`](Self::Strict) (default): the first violation ends the walk with
///   [`Error::WalkAborted`](crate::Error::WalkAborted). Prefer this unless the
///   agent is known to misbehave.
///
/// - [`AllowNonIncreasing`](Self::AllowNonIncreasing): out-of-order OIDs are
///   accepted, but every OID seen is remembered so that a repeated OID (a cycle)
///   ends the walk with [`Error::WalkAborted`](crate::Error::WalkAborted).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum OidOrdering {
    /// Every OID must be strictly greater than the one before it (default).
    ///
    /// The first violation ends the walk with
    /// [`Error::WalkAborted`](crate::Error::WalkAborted). Only the previous OID
    /// is kept, so memory use is O(1) and each check is O(1).
    #[default]
    Strict,

    /// Accept OIDs in any order, but abort if one repeats.
    ///
    /// Intended for buggy agents that return OIDs out of order. Every OID seen
    /// is stored in a HashSet, and the walk ends with an error when an OID is
    /// returned a second time.
    ///
    /// **Warning**: memory use is O(n), where n is the number of walk results.
    /// Always combine this with [`ClientBuilder::max_walk_results`] to bound
    /// memory. Only duplicate OIDs are detected; a pathological agent could
    /// still produce an endless stream of distinct OIDs inside the subtree.
    ///
    /// [`ClientBuilder::max_walk_results`]: crate::ClientBuilder::max_walk_results
    AllowNonIncreasing,
}
100
101enum OidTracker {
102    Strict { last: Option<Oid> },
103    Relaxed { seen: HashSet<Oid> },
104}
105
106/// Outcome of validating a single varbind from a walk response.
107enum VarbindOutcome {
108    /// Varbind is valid and within the subtree; emit it.
109    Yield,
110    /// Walk is complete (end-of-MIB or out-of-subtree).
111    Done,
112    /// Walk should abort with the given error.
113    Abort(Box<Error>),
114}
115
116/// Validate a varbind received during a walk.
117///
118/// Checks end-of-MIB, subtree containment, and OID ordering.
119/// Returns the outcome, updating `oid_tracker` on success.
120fn validate_walk_varbind(
121    vb: &VarBind,
122    base_oid: &Oid,
123    oid_tracker: &mut OidTracker,
124    target: std::net::SocketAddr,
125) -> VarbindOutcome {
126    if matches!(vb.value, Value::EndOfMibView) {
127        return VarbindOutcome::Done;
128    }
129    if !vb.oid.starts_with(base_oid) {
130        return VarbindOutcome::Done;
131    }
132    match oid_tracker.check(&vb.oid, target) {
133        Ok(()) => VarbindOutcome::Yield,
134        Err(e) => VarbindOutcome::Abort(e),
135    }
136}
137
138impl OidTracker {
139    fn new(ordering: OidOrdering) -> Self {
140        match ordering {
141            OidOrdering::Strict => OidTracker::Strict { last: None },
142            OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
143                seen: HashSet::new(),
144            },
145        }
146    }
147
148    fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
149        match self {
150            OidTracker::Strict { last } => {
151                if let Some(prev) = last
152                    && oid <= prev
153                {
154                    tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
155                    return Err(Error::WalkAborted {
156                        target,
157                        reason: WalkAbortReason::NonIncreasing,
158                    }
159                    .boxed());
160                }
161                *last = Some(oid.clone());
162                Ok(())
163            }
164            OidTracker::Relaxed { seen } => {
165                if !seen.insert(oid.clone()) {
166                    tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
167                    return Err(Error::WalkAborted {
168                        target,
169                        reason: WalkAbortReason::Cycle,
170                    }
171                    .boxed());
172                }
173                Ok(())
174            }
175        }
176    }
177}
178
179/// Async stream for walking an OID subtree using GETNEXT.
180///
181/// Created by [`Client::walk_getnext()`].
182pub struct Walk<T: Transport> {
183    client: Client<T>,
184    base_oid: Oid,
185    current_oid: Oid,
186    /// OID tracker for ordering validation.
187    oid_tracker: OidTracker,
188    /// Maximum number of results to return (None = unlimited).
189    max_results: Option<usize>,
190    /// Count of results returned so far.
191    count: usize,
192    done: bool,
193    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
194}
195
196impl<T: Transport> Walk<T> {
197    pub(crate) fn new(
198        client: Client<T>,
199        oid: Oid,
200        ordering: OidOrdering,
201        max_results: Option<usize>,
202    ) -> Self {
203        Self {
204            client,
205            base_oid: oid.clone(),
206            current_oid: oid,
207            oid_tracker: OidTracker::new(ordering),
208            max_results,
209            count: 0,
210            done: false,
211            pending: None,
212        }
213    }
214}
215
216impl_stream_helpers!(Walk<T>);
217
218impl<T: Transport + 'static> Stream for Walk<T> {
219    type Item = Result<VarBind>;
220
221    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222        if self.done {
223            return Poll::Ready(None);
224        }
225
226        // Check max_results limit
227        if let Some(max) = self.max_results
228            && self.count >= max
229        {
230            self.done = true;
231            return Poll::Ready(None);
232        }
233
234        // Check if we have a pending request
235        if self.pending.is_none() {
236            // Start a new GETNEXT request
237            let client = self.client.clone();
238            let oid = self.current_oid.clone();
239
240            let fut = Box::pin(async move { client.get_next(&oid).await });
241            self.pending = Some(fut);
242        }
243
244        // Poll the pending future
245        let pending = self.pending.as_mut().unwrap();
246        match pending.as_mut().poll(cx) {
247            Poll::Pending => Poll::Pending,
248            Poll::Ready(result) => {
249                self.pending = None;
250
251                match result {
252                    Ok(vb) => {
253                        let target = self.client.peer_addr();
254                        let base_oid = self.base_oid.clone();
255                        match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
256                            VarbindOutcome::Done => {
257                                self.done = true;
258                                return Poll::Ready(None);
259                            }
260                            VarbindOutcome::Abort(e) => {
261                                self.done = true;
262                                return Poll::Ready(Some(Err(e)));
263                            }
264                            VarbindOutcome::Yield => {}
265                        }
266
267                        // Update current OID for next iteration
268                        self.current_oid = vb.oid.clone();
269                        self.count += 1;
270
271                        Poll::Ready(Some(Ok(vb)))
272                    }
273                    Err(e) => {
274                        if self.client.inner.config.version == Version::V1
275                            && matches!(
276                                &*e,
277                                Error::Snmp {
278                                    status: crate::error::ErrorStatus::NoSuchName,
279                                    ..
280                                }
281                            )
282                        {
283                            self.done = true;
284                            return Poll::Ready(None);
285                        }
286
287                        self.done = true;
288                        Poll::Ready(Some(Err(e)))
289                    }
290                }
291            }
292        }
293    }
294}
295
296/// Async stream for walking an OID subtree using GETBULK.
297///
298/// Created by [`Client::bulk_walk()`].
299pub struct BulkWalk<T: Transport> {
300    client: Client<T>,
301    base_oid: Oid,
302    current_oid: Oid,
303    max_repetitions: i32,
304    /// OID tracker for ordering validation.
305    oid_tracker: OidTracker,
306    /// Maximum number of results to return (None = unlimited).
307    max_results: Option<usize>,
308    /// Count of results returned so far.
309    count: usize,
310    done: bool,
311    /// Buffered results from the last GETBULK response
312    buffer: VecDeque<VarBind>,
313    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
314}
315
316impl<T: Transport> BulkWalk<T> {
317    pub(crate) fn new(
318        client: Client<T>,
319        oid: Oid,
320        max_repetitions: i32,
321        ordering: OidOrdering,
322        max_results: Option<usize>,
323    ) -> Self {
324        Self {
325            client,
326            base_oid: oid.clone(),
327            current_oid: oid,
328            max_repetitions,
329            oid_tracker: OidTracker::new(ordering),
330            max_results,
331            count: 0,
332            done: false,
333            buffer: VecDeque::new(),
334            pending: None,
335        }
336    }
337}
338
339impl_stream_helpers!(BulkWalk<T>);
340
341impl<T: Transport + 'static> Stream for BulkWalk<T> {
342    type Item = Result<VarBind>;
343
344    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
345        loop {
346            if self.done {
347                return Poll::Ready(None);
348            }
349
350            // Check max_results limit
351            if let Some(max) = self.max_results
352                && self.count >= max
353            {
354                self.done = true;
355                return Poll::Ready(None);
356            }
357
358            // Check if we have buffered results to return
359            if let Some(vb) = self.buffer.pop_front() {
360                let target = self.client.peer_addr();
361                let base_oid = self.base_oid.clone();
362                match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
363                    VarbindOutcome::Done => {
364                        self.done = true;
365                        return Poll::Ready(None);
366                    }
367                    VarbindOutcome::Abort(e) => {
368                        self.done = true;
369                        return Poll::Ready(Some(Err(e)));
370                    }
371                    VarbindOutcome::Yield => {}
372                }
373
374                // Update current OID for next request
375                self.current_oid = vb.oid.clone();
376                self.count += 1;
377
378                return Poll::Ready(Some(Ok(vb)));
379            }
380
381            // Buffer exhausted, need to fetch more
382            if self.pending.is_none() {
383                let client = self.client.clone();
384                let oid = self.current_oid.clone();
385                let max_rep = self.max_repetitions;
386
387                let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
388                self.pending = Some(fut);
389            }
390
391            // Poll the pending future
392            let pending = self.pending.as_mut().unwrap();
393            match pending.as_mut().poll(cx) {
394                Poll::Pending => return Poll::Pending,
395                Poll::Ready(result) => {
396                    self.pending = None;
397
398                    match result {
399                        Ok(varbinds) => {
400                            if varbinds.is_empty() {
401                                self.done = true;
402                                return Poll::Ready(None);
403                            }
404
405                            self.buffer = varbinds.into();
406                            // Continue loop to process buffer
407                        }
408                        Err(e) => {
409                            self.done = true;
410                            return Poll::Ready(Some(Err(e)));
411                        }
412                    }
413                }
414            }
415        }
416    }
417}
418
419// ============================================================================
420// Unified WalkStream - auto-selects GETNEXT or GETBULK based on WalkMode
421// ============================================================================
422
423/// Unified walk stream that auto-selects between GETNEXT and GETBULK.
424///
425/// Created by [`Client::walk()`] when using `WalkMode::Auto` or explicit mode selection.
426/// This type wraps either a [`Walk`] or [`BulkWalk`] internally based on:
427/// - `WalkMode::Auto`: Uses GETNEXT for V1, GETBULK for V2c/V3
428/// - `WalkMode::GetNext`: Always uses GETNEXT
429/// - `WalkMode::GetBulk`: Always uses GETBULK (fails on V1)
430pub enum WalkStream<T: Transport> {
431    /// GETNEXT-based walk (used for V1 or when explicitly requested)
432    GetNext(Walk<T>),
433    /// GETBULK-based walk (used for V2c/V3 or when explicitly requested)
434    GetBulk(BulkWalk<T>),
435}
436
437impl<T: Transport> WalkStream<T> {
438    /// Create a new walk stream with auto-selection based on version and walk mode.
439    pub(crate) fn new(
440        client: Client<T>,
441        oid: Oid,
442        version: Version,
443        walk_mode: WalkMode,
444        ordering: OidOrdering,
445        max_results: Option<usize>,
446        max_repetitions: i32,
447    ) -> Result<Self> {
448        let use_bulk = match walk_mode {
449            WalkMode::Auto => version != Version::V1,
450            WalkMode::GetNext => false,
451            WalkMode::GetBulk => {
452                if version == Version::V1 {
453                    return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
454                }
455                true
456            }
457        };
458
459        Ok(if use_bulk {
460            WalkStream::GetBulk(BulkWalk::new(
461                client,
462                oid,
463                max_repetitions,
464                ordering,
465                max_results,
466            ))
467        } else {
468            WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
469        })
470    }
471}
472
473impl_stream_helpers!(WalkStream<T>);
474
475impl<T: Transport + 'static> Stream for WalkStream<T> {
476    type Item = Result<VarBind>;
477
478    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
479        // SAFETY: We're just projecting the pin to the inner enum variant
480        match self.get_mut() {
481            WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
482            WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
483        }
484    }
485}