async_snmp/client/
walk.rs

1//! Walk stream implementations.
2
3#![allow(clippy::type_complexity)]
4
5use std::collections::HashSet;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures_core::Stream;
10
11use crate::error::{Error, Result};
12use crate::oid::Oid;
13use crate::transport::Transport;
14use crate::value::Value;
15use crate::varbind::VarBind;
16use crate::version::Version;
17
18use super::Client;
19
/// Selects which request PDU a walk uses to traverse a subtree.
///
/// The default, [`WalkMode::Auto`], derives the request type from the
/// client's SNMP version.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum WalkMode {
    /// Pick automatically (the default): GETNEXT on SNMPv1,
    /// GETBULK on SNMPv2c and SNMPv3.
    #[default]
    Auto,
    /// Always use GETNEXT: one OID per round trip, slower but more widely
    /// supported.
    GetNext,
    /// Always use GETBULK: fewer round trips, but rejected with an error
    /// when the client speaks SNMPv1.
    GetBulk,
}
33
/// How a walk reacts when an agent breaks OID ordering.
///
/// Walks depend on the agent answering with OIDs in strictly increasing
/// lexicographic order; that ordering is what makes a walk terminate. Buggy
/// agents may return OIDs out of order, or hand back an OID they already
/// returned, which would otherwise make the walk loop forever.
///
/// - [`Strict`](Self::Strict) (the default) stops the walk with
///   [`Error::NonIncreasingOid`](crate::Error::NonIncreasingOid) at the first
///   OID that is not greater than its predecessor. Prefer this unless the
///   agent is known to misorder its responses.
///
/// - [`AllowNonIncreasing`](Self::AllowNonIncreasing) accepts out-of-order
///   OIDs but remembers every OID it has accepted, stopping with
///   [`Error::DuplicateOid`](crate::Error::DuplicateOid) when one repeats.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum OidOrdering {
    /// Every OID must be strictly greater than the previous one (the default).
    ///
    /// The first violation ends the walk with
    /// [`Error::NonIncreasingOid`](crate::Error::NonIncreasingOid). Only the
    /// previous OID is retained, so memory use stays O(1) and each check is a
    /// single comparison.
    #[default]
    Strict,

    /// Tolerate out-of-order OIDs while still detecting repeats.
    ///
    /// Every accepted OID is recorded in a `HashSet`; receiving one a second
    /// time ends the walk with an error, which breaks cycles.
    ///
    /// **Warning**: memory grows as O(n) in the number of walk results, so
    /// always bound the walk with [`ClientBuilder::max_walk_results`]. Repeat
    /// detection cannot stop an agent that keeps producing *new* OIDs inside
    /// the subtree; only a result limit bounds that case.
    ///
    /// [`ClientBuilder::max_walk_results`]: crate::ClientBuilder::max_walk_results
    AllowNonIncreasing,
}
75
76enum OidTracker {
77    Strict { last: Option<Oid> },
78    Relaxed { seen: HashSet<Oid> },
79}
80
81impl OidTracker {
82    fn new(ordering: OidOrdering) -> Self {
83        match ordering {
84            OidOrdering::Strict => OidTracker::Strict { last: None },
85            OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
86                seen: HashSet::new(),
87            },
88        }
89    }
90
91    fn check(&mut self, oid: &Oid) -> Result<()> {
92        match self {
93            OidTracker::Strict { last } => {
94                if let Some(prev) = last
95                    && oid <= prev
96                {
97                    return Err(Error::NonIncreasingOid {
98                        previous: prev.clone(),
99                        current: oid.clone(),
100                    });
101                }
102                *last = Some(oid.clone());
103                Ok(())
104            }
105            OidTracker::Relaxed { seen } => {
106                if !seen.insert(oid.clone()) {
107                    return Err(Error::DuplicateOid { oid: oid.clone() });
108                }
109                Ok(())
110            }
111        }
112    }
113}
114
115/// Async stream for walking an OID subtree using GETNEXT.
116///
117/// Created by [`Client::walk_getnext()`].
118pub struct Walk<T: Transport> {
119    client: Client<T>,
120    base_oid: Oid,
121    current_oid: Oid,
122    /// OID tracker for ordering validation.
123    oid_tracker: OidTracker,
124    /// Maximum number of results to return (None = unlimited).
125    max_results: Option<usize>,
126    /// Count of results returned so far.
127    count: usize,
128    done: bool,
129    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
130}
131
132impl<T: Transport> Walk<T> {
133    pub(crate) fn new(
134        client: Client<T>,
135        oid: Oid,
136        ordering: OidOrdering,
137        max_results: Option<usize>,
138    ) -> Self {
139        Self {
140            client,
141            base_oid: oid.clone(),
142            current_oid: oid,
143            oid_tracker: OidTracker::new(ordering),
144            max_results,
145            count: 0,
146            done: false,
147            pending: None,
148        }
149    }
150}
151
152impl<T: Transport + 'static> Walk<T> {
153    /// Get the next varbind, or None when complete.
154    pub async fn next(&mut self) -> Option<Result<VarBind>> {
155        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
156    }
157
158    /// Collect all remaining varbinds.
159    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
160        let mut results = Vec::new();
161        while let Some(result) = self.next().await {
162            results.push(result?);
163        }
164        Ok(results)
165    }
166}
167
168impl<T: Transport + 'static> Stream for Walk<T> {
169    type Item = Result<VarBind>;
170
171    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172        if self.done {
173            return Poll::Ready(None);
174        }
175
176        // Check max_results limit
177        if let Some(max) = self.max_results
178            && self.count >= max
179        {
180            self.done = true;
181            return Poll::Ready(None);
182        }
183
184        // Check if we have a pending request
185        if self.pending.is_none() {
186            // Start a new GETNEXT request
187            let client = self.client.clone();
188            let oid = self.current_oid.clone();
189
190            let fut = Box::pin(async move { client.get_next(&oid).await });
191            self.pending = Some(fut);
192        }
193
194        // Poll the pending future
195        let pending = self.pending.as_mut().unwrap();
196        match pending.as_mut().poll(cx) {
197            Poll::Pending => Poll::Pending,
198            Poll::Ready(result) => {
199                self.pending = None;
200
201                match result {
202                    Ok(vb) => {
203                        // Check for end conditions
204                        if matches!(vb.value, Value::EndOfMibView) {
205                            self.done = true;
206                            return Poll::Ready(None);
207                        }
208
209                        // Check if OID left the subtree
210                        if !vb.oid.starts_with(&self.base_oid) {
211                            self.done = true;
212                            return Poll::Ready(None);
213                        }
214
215                        // Check OID ordering using the tracker
216                        if let Err(e) = self.oid_tracker.check(&vb.oid) {
217                            self.done = true;
218                            return Poll::Ready(Some(Err(e)));
219                        }
220
221                        // Update current OID for next iteration
222                        self.current_oid = vb.oid.clone();
223                        self.count += 1;
224
225                        Poll::Ready(Some(Ok(vb)))
226                    }
227                    Err(e) => {
228                        self.done = true;
229                        Poll::Ready(Some(Err(e)))
230                    }
231                }
232            }
233        }
234    }
235}
236
237/// Async stream for walking an OID subtree using GETBULK.
238///
239/// Created by [`Client::bulk_walk()`].
240pub struct BulkWalk<T: Transport> {
241    client: Client<T>,
242    base_oid: Oid,
243    current_oid: Oid,
244    max_repetitions: i32,
245    /// OID tracker for ordering validation.
246    oid_tracker: OidTracker,
247    /// Maximum number of results to return (None = unlimited).
248    max_results: Option<usize>,
249    /// Count of results returned so far.
250    count: usize,
251    done: bool,
252    /// Buffered results from the last GETBULK response
253    buffer: Vec<VarBind>,
254    /// Index into the buffer
255    buffer_idx: usize,
256    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
257}
258
259impl<T: Transport> BulkWalk<T> {
260    pub(crate) fn new(
261        client: Client<T>,
262        oid: Oid,
263        max_repetitions: i32,
264        ordering: OidOrdering,
265        max_results: Option<usize>,
266    ) -> Self {
267        Self {
268            client,
269            base_oid: oid.clone(),
270            current_oid: oid,
271            max_repetitions,
272            oid_tracker: OidTracker::new(ordering),
273            max_results,
274            count: 0,
275            done: false,
276            buffer: Vec::new(),
277            buffer_idx: 0,
278            pending: None,
279        }
280    }
281}
282
283impl<T: Transport + 'static> BulkWalk<T> {
284    /// Get the next varbind, or None when complete.
285    pub async fn next(&mut self) -> Option<Result<VarBind>> {
286        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
287    }
288
289    /// Collect all remaining varbinds.
290    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
291        let mut results = Vec::new();
292        while let Some(result) = self.next().await {
293            results.push(result?);
294        }
295        Ok(results)
296    }
297}
298
299impl<T: Transport + 'static> Stream for BulkWalk<T> {
300    type Item = Result<VarBind>;
301
302    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
303        loop {
304            if self.done {
305                return Poll::Ready(None);
306            }
307
308            // Check max_results limit
309            if let Some(max) = self.max_results
310                && self.count >= max
311            {
312                self.done = true;
313                return Poll::Ready(None);
314            }
315
316            // Check if we have buffered results to return
317            if self.buffer_idx < self.buffer.len() {
318                let vb = self.buffer[self.buffer_idx].clone();
319                self.buffer_idx += 1;
320
321                // Check for end conditions
322                if matches!(vb.value, Value::EndOfMibView) {
323                    self.done = true;
324                    return Poll::Ready(None);
325                }
326
327                // Check if OID left the subtree
328                if !vb.oid.starts_with(&self.base_oid) {
329                    self.done = true;
330                    return Poll::Ready(None);
331                }
332
333                // Check OID ordering using the tracker
334                if let Err(e) = self.oid_tracker.check(&vb.oid) {
335                    self.done = true;
336                    return Poll::Ready(Some(Err(e)));
337                }
338
339                // Update current OID for next request
340                self.current_oid = vb.oid.clone();
341                self.count += 1;
342
343                return Poll::Ready(Some(Ok(vb)));
344            }
345
346            // Buffer exhausted, need to fetch more
347            if self.pending.is_none() {
348                let client = self.client.clone();
349                let oid = self.current_oid.clone();
350                let max_rep = self.max_repetitions;
351
352                let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
353                self.pending = Some(fut);
354            }
355
356            // Poll the pending future
357            let pending = self.pending.as_mut().unwrap();
358            match pending.as_mut().poll(cx) {
359                Poll::Pending => return Poll::Pending,
360                Poll::Ready(result) => {
361                    self.pending = None;
362
363                    match result {
364                        Ok(varbinds) => {
365                            if varbinds.is_empty() {
366                                self.done = true;
367                                return Poll::Ready(None);
368                            }
369
370                            self.buffer = varbinds;
371                            self.buffer_idx = 0;
372                            // Continue loop to process buffer
373                        }
374                        Err(e) => {
375                            self.done = true;
376                            return Poll::Ready(Some(Err(e)));
377                        }
378                    }
379                }
380            }
381        }
382    }
383}
384
385// ============================================================================
386// Unified WalkStream - auto-selects GETNEXT or GETBULK based on WalkMode
387// ============================================================================
388
389/// Unified walk stream that auto-selects between GETNEXT and GETBULK.
390///
391/// Created by [`Client::walk()`] when using `WalkMode::Auto` or explicit mode selection.
392/// This type wraps either a [`Walk`] or [`BulkWalk`] internally based on:
393/// - `WalkMode::Auto`: Uses GETNEXT for V1, GETBULK for V2c/V3
394/// - `WalkMode::GetNext`: Always uses GETNEXT
395/// - `WalkMode::GetBulk`: Always uses GETBULK (fails on V1)
396pub enum WalkStream<T: Transport> {
397    /// GETNEXT-based walk (used for V1 or when explicitly requested)
398    GetNext(Walk<T>),
399    /// GETBULK-based walk (used for V2c/V3 or when explicitly requested)
400    GetBulk(BulkWalk<T>),
401}
402
403impl<T: Transport> WalkStream<T> {
404    /// Create a new walk stream with auto-selection based on version and walk mode.
405    pub(crate) fn new(
406        client: Client<T>,
407        oid: Oid,
408        version: Version,
409        walk_mode: WalkMode,
410        ordering: OidOrdering,
411        max_results: Option<usize>,
412        max_repetitions: i32,
413    ) -> Result<Self> {
414        let use_bulk = match walk_mode {
415            WalkMode::Auto => version != Version::V1,
416            WalkMode::GetNext => false,
417            WalkMode::GetBulk => {
418                if version == Version::V1 {
419                    return Err(Error::GetBulkNotSupportedInV1);
420                }
421                true
422            }
423        };
424
425        Ok(if use_bulk {
426            WalkStream::GetBulk(BulkWalk::new(
427                client,
428                oid,
429                max_repetitions,
430                ordering,
431                max_results,
432            ))
433        } else {
434            WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
435        })
436    }
437}
438
439impl<T: Transport + 'static> WalkStream<T> {
440    /// Get the next varbind, or None when complete.
441    pub async fn next(&mut self) -> Option<Result<VarBind>> {
442        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
443    }
444
445    /// Collect all remaining varbinds.
446    pub async fn collect(mut self) -> Result<Vec<VarBind>> {
447        let mut results = Vec::new();
448        while let Some(result) = self.next().await {
449            results.push(result?);
450        }
451        Ok(results)
452    }
453}
454
455impl<T: Transport + 'static> Stream for WalkStream<T> {
456    type Item = Result<VarBind>;
457
458    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
459        // SAFETY: We're just projecting the pin to the inner enum variant
460        match self.get_mut() {
461            WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
462            WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
463        }
464    }
465}