// async_snmp/client/walk.rs

//! Walk stream implementations.

// Allow complex types for boxed futures in manual Stream implementations.
// The `pending` fields require `Option<Pin<Box<dyn Future<Output = ...> + Send>>>`,
// which triggers this lint but is the standard pattern for storing futures.
6#![allow(clippy::type_complexity)]
7
8use std::collections::{HashSet, VecDeque};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use futures_core::Stream;
13
14use crate::error::{Error, Result, WalkAbortReason};
15use crate::oid::Oid;
16use crate::transport::Transport;
17use crate::value::Value;
18use crate::varbind::VarBind;
19use crate::version::Version;
20
21use super::Client;
22
/// How a walk fetches successive varbinds from the agent.
///
/// GETBULK retrieves many varbinds per round trip, whereas GETNEXT retrieves
/// one at a time but is the only option SNMPv1 offers.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum WalkMode {
    /// Choose from the protocol version (the default):
    /// GETNEXT for V1, GETBULK for V2c and V3.
    #[default]
    Auto,
    /// Always issue GETNEXT requests (slower, but maximally compatible).
    GetNext,
    /// Always issue GETBULK requests (faster; rejected for SNMPv1).
    GetBulk,
}

/// Policy for agents that violate SNMP's OID ordering rules during a walk.
///
/// A walk assumes each response OID sorts strictly after the previous one.
/// Misbehaving agents sometimes return OIDs out of order, or repeat an OID,
/// which without a guard would make the walk loop forever.
///
/// - [`Strict`](Self::Strict) (default): the first violation ends the walk
///   with [`Error::WalkAborted`](crate::Error::WalkAborted). Prefer this
///   unless the agent is known to misorder OIDs.
///
/// - [`AllowNonIncreasing`](Self::AllowNonIncreasing): out-of-order OIDs are
///   accepted, but every OID seen so far is remembered; a repeat ends the walk
///   with [`Error::WalkAborted`](crate::Error::WalkAborted).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum OidOrdering {
    /// Every OID must sort strictly after its predecessor (default).
    ///
    /// The walk stops with [`Error::WalkAborted`](crate::Error::WalkAborted)
    /// at the first violation. Only the previous OID is kept, so this is the
    /// cheapest mode: O(1) memory and an O(1) per-item check.
    #[default]
    Strict,

    /// Accept out-of-order OIDs, but detect cycles.
    ///
    /// Intended for agents that return OIDs out of order. Every OID returned
    /// is recorded in a `HashSet`; seeing one a second time ends the walk with
    /// an error.
    ///
    /// **Warning**: memory grows linearly with the number of walk results.
    /// Always combine this mode with [`ClientBuilder::max_walk_results`] to
    /// cap it. Only repeated OIDs are caught — a pathological agent could
    /// still emit an endless sequence of distinct OIDs inside the subtree.
    ///
    /// [`ClientBuilder::max_walk_results`]: crate::ClientBuilder::max_walk_results
    AllowNonIncreasing,
}

77enum OidTracker {
78    Strict { last: Option<Oid> },
79    Relaxed { seen: HashSet<Oid> },
80}
81
82impl OidTracker {
83    fn new(ordering: OidOrdering) -> Self {
84        match ordering {
85            OidOrdering::Strict => OidTracker::Strict { last: None },
86            OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
87                seen: HashSet::new(),
88            },
89        }
90    }
91
92    fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
93        match self {
94            OidTracker::Strict { last } => {
95                if let Some(prev) = last
96                    && oid <= prev
97                {
98                    tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
99                    return Err(Error::WalkAborted {
100                        target,
101                        reason: WalkAbortReason::NonIncreasing,
102                    }
103                    .boxed());
104                }
105                *last = Some(oid.clone());
106                Ok(())
107            }
108            OidTracker::Relaxed { seen } => {
109                if !seen.insert(oid.clone()) {
110                    tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
111                    return Err(Error::WalkAborted {
112                        target,
113                        reason: WalkAbortReason::Cycle,
114                    }
115                    .boxed());
116                }
117                Ok(())
118            }
119        }
120    }
121}
122
123/// Async stream for walking an OID subtree using GETNEXT.
124///
125/// Created by [`Client::walk_getnext()`].
126pub struct Walk<T: Transport> {
127    client: Client<T>,
128    base_oid: Oid,
129    current_oid: Oid,
130    /// OID tracker for ordering validation.
131    oid_tracker: OidTracker,
132    /// Maximum number of results to return (None = unlimited).
133    max_results: Option<usize>,
134    /// Count of results returned so far.
135    count: usize,
136    done: bool,
137    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
138}
139
140impl<T: Transport> Walk<T> {
141    pub(crate) fn new(
142        client: Client<T>,
143        oid: Oid,
144        ordering: OidOrdering,
145        max_results: Option<usize>,
146    ) -> Self {
147        Self {
148            client,
149            base_oid: oid.clone(),
150            current_oid: oid,
151            oid_tracker: OidTracker::new(ordering),
152            max_results,
153            count: 0,
154            done: false,
155            pending: None,
156        }
157    }
158}
159
160impl<T: Transport + 'static> Walk<T> {
161    /// Get the next varbind, or None when complete.
162    pub async fn next(&mut self) -> Option<Result<VarBind>> {
163        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
164    }
165
166    /// Collect all remaining varbinds.
167    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
168        let mut results = Vec::new();
169        while let Some(result) = self.next().await {
170            results.push(result?);
171        }
172        Ok(results)
173    }
174}
175
176impl<T: Transport + 'static> Stream for Walk<T> {
177    type Item = Result<VarBind>;
178
179    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
180        if self.done {
181            return Poll::Ready(None);
182        }
183
184        // Check max_results limit
185        if let Some(max) = self.max_results
186            && self.count >= max
187        {
188            self.done = true;
189            return Poll::Ready(None);
190        }
191
192        // Check if we have a pending request
193        if self.pending.is_none() {
194            // Start a new GETNEXT request
195            let client = self.client.clone();
196            let oid = self.current_oid.clone();
197
198            let fut = Box::pin(async move { client.get_next(&oid).await });
199            self.pending = Some(fut);
200        }
201
202        // Poll the pending future
203        let pending = self.pending.as_mut().unwrap();
204        match pending.as_mut().poll(cx) {
205            Poll::Pending => Poll::Pending,
206            Poll::Ready(result) => {
207                self.pending = None;
208
209                match result {
210                    Ok(vb) => {
211                        // Check for end conditions
212                        if matches!(vb.value, Value::EndOfMibView) {
213                            self.done = true;
214                            return Poll::Ready(None);
215                        }
216
217                        // Check if OID left the subtree
218                        if !vb.oid.starts_with(&self.base_oid) {
219                            self.done = true;
220                            return Poll::Ready(None);
221                        }
222
223                        // Check OID ordering using the tracker
224                        let target = self.client.peer_addr();
225                        if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
226                            self.done = true;
227                            return Poll::Ready(Some(Err(e)));
228                        }
229
230                        // Update current OID for next iteration
231                        self.current_oid = vb.oid.clone();
232                        self.count += 1;
233
234                        Poll::Ready(Some(Ok(vb)))
235                    }
236                    Err(e) => {
237                        self.done = true;
238                        Poll::Ready(Some(Err(e)))
239                    }
240                }
241            }
242        }
243    }
244}
245
246/// Async stream for walking an OID subtree using GETBULK.
247///
248/// Created by [`Client::bulk_walk()`].
249pub struct BulkWalk<T: Transport> {
250    client: Client<T>,
251    base_oid: Oid,
252    current_oid: Oid,
253    max_repetitions: i32,
254    /// OID tracker for ordering validation.
255    oid_tracker: OidTracker,
256    /// Maximum number of results to return (None = unlimited).
257    max_results: Option<usize>,
258    /// Count of results returned so far.
259    count: usize,
260    done: bool,
261    /// Buffered results from the last GETBULK response
262    buffer: VecDeque<VarBind>,
263    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
264}
265
266impl<T: Transport> BulkWalk<T> {
267    pub(crate) fn new(
268        client: Client<T>,
269        oid: Oid,
270        max_repetitions: i32,
271        ordering: OidOrdering,
272        max_results: Option<usize>,
273    ) -> Self {
274        Self {
275            client,
276            base_oid: oid.clone(),
277            current_oid: oid,
278            max_repetitions,
279            oid_tracker: OidTracker::new(ordering),
280            max_results,
281            count: 0,
282            done: false,
283            buffer: VecDeque::new(),
284            pending: None,
285        }
286    }
287}
288
289impl<T: Transport + 'static> BulkWalk<T> {
290    /// Get the next varbind, or None when complete.
291    pub async fn next(&mut self) -> Option<Result<VarBind>> {
292        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
293    }
294
295    /// Collect all remaining varbinds.
296    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
297        let mut results = Vec::new();
298        while let Some(result) = self.next().await {
299            results.push(result?);
300        }
301        Ok(results)
302    }
303}
304
305impl<T: Transport + 'static> Stream for BulkWalk<T> {
306    type Item = Result<VarBind>;
307
308    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309        loop {
310            if self.done {
311                return Poll::Ready(None);
312            }
313
314            // Check max_results limit
315            if let Some(max) = self.max_results
316                && self.count >= max
317            {
318                self.done = true;
319                return Poll::Ready(None);
320            }
321
322            // Check if we have buffered results to return
323            if let Some(vb) = self.buffer.pop_front() {
324                // Check for end conditions
325                if matches!(vb.value, Value::EndOfMibView) {
326                    self.done = true;
327                    return Poll::Ready(None);
328                }
329
330                // Check if OID left the subtree
331                if !vb.oid.starts_with(&self.base_oid) {
332                    self.done = true;
333                    return Poll::Ready(None);
334                }
335
336                // Check OID ordering using the tracker
337                let target = self.client.peer_addr();
338                if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
339                    self.done = true;
340                    return Poll::Ready(Some(Err(e)));
341                }
342
343                // Update current OID for next request
344                self.current_oid = vb.oid.clone();
345                self.count += 1;
346
347                return Poll::Ready(Some(Ok(vb)));
348            }
349
350            // Buffer exhausted, need to fetch more
351            if self.pending.is_none() {
352                let client = self.client.clone();
353                let oid = self.current_oid.clone();
354                let max_rep = self.max_repetitions;
355
356                let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
357                self.pending = Some(fut);
358            }
359
360            // Poll the pending future
361            let pending = self.pending.as_mut().unwrap();
362            match pending.as_mut().poll(cx) {
363                Poll::Pending => return Poll::Pending,
364                Poll::Ready(result) => {
365                    self.pending = None;
366
367                    match result {
368                        Ok(varbinds) => {
369                            if varbinds.is_empty() {
370                                self.done = true;
371                                return Poll::Ready(None);
372                            }
373
374                            self.buffer = varbinds.into();
375                            // Continue loop to process buffer
376                        }
377                        Err(e) => {
378                            self.done = true;
379                            return Poll::Ready(Some(Err(e)));
380                        }
381                    }
382                }
383            }
384        }
385    }
386}
387
388// ============================================================================
389// Unified WalkStream - auto-selects GETNEXT or GETBULK based on WalkMode
390// ============================================================================
391
392/// Unified walk stream that auto-selects between GETNEXT and GETBULK.
393///
394/// Created by [`Client::walk()`] when using `WalkMode::Auto` or explicit mode selection.
395/// This type wraps either a [`Walk`] or [`BulkWalk`] internally based on:
396/// - `WalkMode::Auto`: Uses GETNEXT for V1, GETBULK for V2c/V3
397/// - `WalkMode::GetNext`: Always uses GETNEXT
398/// - `WalkMode::GetBulk`: Always uses GETBULK (fails on V1)
399pub enum WalkStream<T: Transport> {
400    /// GETNEXT-based walk (used for V1 or when explicitly requested)
401    GetNext(Walk<T>),
402    /// GETBULK-based walk (used for V2c/V3 or when explicitly requested)
403    GetBulk(BulkWalk<T>),
404}
405
406impl<T: Transport> WalkStream<T> {
407    /// Create a new walk stream with auto-selection based on version and walk mode.
408    pub(crate) fn new(
409        client: Client<T>,
410        oid: Oid,
411        version: Version,
412        walk_mode: WalkMode,
413        ordering: OidOrdering,
414        max_results: Option<usize>,
415        max_repetitions: i32,
416    ) -> Result<Self> {
417        let use_bulk = match walk_mode {
418            WalkMode::Auto => version != Version::V1,
419            WalkMode::GetNext => false,
420            WalkMode::GetBulk => {
421                if version == Version::V1 {
422                    return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
423                }
424                true
425            }
426        };
427
428        Ok(if use_bulk {
429            WalkStream::GetBulk(BulkWalk::new(
430                client,
431                oid,
432                max_repetitions,
433                ordering,
434                max_results,
435            ))
436        } else {
437            WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
438        })
439    }
440}
441
442impl<T: Transport + 'static> WalkStream<T> {
443    /// Get the next varbind, or None when complete.
444    pub async fn next(&mut self) -> Option<Result<VarBind>> {
445        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
446    }
447
448    /// Collect all remaining varbinds.
449    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
450        let mut results = Vec::new();
451        while let Some(result) = self.next().await {
452            results.push(result?);
453        }
454        Ok(results)
455    }
456}
457
458impl<T: Transport + 'static> Stream for WalkStream<T> {
459    type Item = Result<VarBind>;
460
461    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
462        // SAFETY: We're just projecting the pin to the inner enum variant
463        match self.get_mut() {
464            WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
465            WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
466        }
467    }
468}