1#![allow(clippy::type_complexity)]
7
8use std::collections::{HashSet, VecDeque};
9use std::pin::Pin;
10use std::task::{Context, Poll};
11
12use futures_core::Stream;
13
14use crate::error::{Error, Result, WalkAbortReason};
15use crate::oid::Oid;
16use crate::transport::Transport;
17use crate::value::Value;
18use crate::varbind::VarBind;
19use crate::version::Version;
20
21use super::Client;
22
/// Selects which PDU type a walk uses to step through the MIB.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum WalkMode {
    /// Use GETBULK unless the session is SNMPv1, in which case fall back to GETNEXT.
    #[default]
    Auto,
    /// Always step with GETNEXT, one varbind per request.
    GetNext,
    /// Always step with GETBULK; rejected with a configuration error for SNMPv1.
    GetBulk,
}
35
/// Policy for validating the order of OIDs returned during a walk.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum OidOrdering {
    /// Every returned OID must be strictly greater than the previous one.
    /// Otherwise the walk aborts with `WalkAbortReason::NonIncreasing`.
    #[default]
    Strict,

    /// OIDs may go backwards. The walk aborts only when an OID repeats, with
    /// `WalkAbortReason::Cycle`. Every OID seen is remembered, so memory use
    /// grows with the length of the walk.
    AllowNonIncreasing,
}
76
77enum OidTracker {
78 Strict { last: Option<Oid> },
79 Relaxed { seen: HashSet<Oid> },
80}
81
82impl OidTracker {
83 fn new(ordering: OidOrdering) -> Self {
84 match ordering {
85 OidOrdering::Strict => OidTracker::Strict { last: None },
86 OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
87 seen: HashSet::new(),
88 },
89 }
90 }
91
92 fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
93 match self {
94 OidTracker::Strict { last } => {
95 if let Some(prev) = last
96 && oid <= prev
97 {
98 tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
99 return Err(Error::WalkAborted {
100 target,
101 reason: WalkAbortReason::NonIncreasing,
102 }
103 .boxed());
104 }
105 *last = Some(oid.clone());
106 Ok(())
107 }
108 OidTracker::Relaxed { seen } => {
109 if !seen.insert(oid.clone()) {
110 tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
111 return Err(Error::WalkAborted {
112 target,
113 reason: WalkAbortReason::Cycle,
114 }
115 .boxed());
116 }
117 Ok(())
118 }
119 }
120 }
121}
122
123pub struct Walk<T: Transport> {
127 client: Client<T>,
128 base_oid: Oid,
129 current_oid: Oid,
130 oid_tracker: OidTracker,
132 max_results: Option<usize>,
134 count: usize,
136 done: bool,
137 pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
138}
139
140impl<T: Transport> Walk<T> {
141 pub(crate) fn new(
142 client: Client<T>,
143 oid: Oid,
144 ordering: OidOrdering,
145 max_results: Option<usize>,
146 ) -> Self {
147 Self {
148 client,
149 base_oid: oid.clone(),
150 current_oid: oid,
151 oid_tracker: OidTracker::new(ordering),
152 max_results,
153 count: 0,
154 done: false,
155 pending: None,
156 }
157 }
158}
159
160impl<T: Transport + 'static> Walk<T> {
161 pub async fn next(&mut self) -> Option<Result<VarBind>> {
163 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
164 }
165
166 pub async fn collect(mut self) -> Result<Vec<VarBind>> {
168 let mut results = Vec::new();
169 while let Some(result) = self.next().await {
170 results.push(result?);
171 }
172 Ok(results)
173 }
174}
175
176impl<T: Transport + 'static> Stream for Walk<T> {
177 type Item = Result<VarBind>;
178
179 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
180 if self.done {
181 return Poll::Ready(None);
182 }
183
184 if let Some(max) = self.max_results
186 && self.count >= max
187 {
188 self.done = true;
189 return Poll::Ready(None);
190 }
191
192 if self.pending.is_none() {
194 let client = self.client.clone();
196 let oid = self.current_oid.clone();
197
198 let fut = Box::pin(async move { client.get_next(&oid).await });
199 self.pending = Some(fut);
200 }
201
202 let pending = self.pending.as_mut().unwrap();
204 match pending.as_mut().poll(cx) {
205 Poll::Pending => Poll::Pending,
206 Poll::Ready(result) => {
207 self.pending = None;
208
209 match result {
210 Ok(vb) => {
211 if matches!(vb.value, Value::EndOfMibView) {
213 self.done = true;
214 return Poll::Ready(None);
215 }
216
217 if !vb.oid.starts_with(&self.base_oid) {
219 self.done = true;
220 return Poll::Ready(None);
221 }
222
223 let target = self.client.peer_addr();
225 if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
226 self.done = true;
227 return Poll::Ready(Some(Err(e)));
228 }
229
230 self.current_oid = vb.oid.clone();
232 self.count += 1;
233
234 Poll::Ready(Some(Ok(vb)))
235 }
236 Err(e) => {
237 self.done = true;
238 Poll::Ready(Some(Err(e)))
239 }
240 }
241 }
242 }
243 }
244}
245
246pub struct BulkWalk<T: Transport> {
250 client: Client<T>,
251 base_oid: Oid,
252 current_oid: Oid,
253 max_repetitions: i32,
254 oid_tracker: OidTracker,
256 max_results: Option<usize>,
258 count: usize,
260 done: bool,
261 buffer: VecDeque<VarBind>,
263 pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
264}
265
266impl<T: Transport> BulkWalk<T> {
267 pub(crate) fn new(
268 client: Client<T>,
269 oid: Oid,
270 max_repetitions: i32,
271 ordering: OidOrdering,
272 max_results: Option<usize>,
273 ) -> Self {
274 Self {
275 client,
276 base_oid: oid.clone(),
277 current_oid: oid,
278 max_repetitions,
279 oid_tracker: OidTracker::new(ordering),
280 max_results,
281 count: 0,
282 done: false,
283 buffer: VecDeque::new(),
284 pending: None,
285 }
286 }
287}
288
289impl<T: Transport + 'static> BulkWalk<T> {
290 pub async fn next(&mut self) -> Option<Result<VarBind>> {
292 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
293 }
294
295 pub async fn collect(mut self) -> Result<Vec<VarBind>> {
297 let mut results = Vec::new();
298 while let Some(result) = self.next().await {
299 results.push(result?);
300 }
301 Ok(results)
302 }
303}
304
305impl<T: Transport + 'static> Stream for BulkWalk<T> {
306 type Item = Result<VarBind>;
307
308 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309 loop {
310 if self.done {
311 return Poll::Ready(None);
312 }
313
314 if let Some(max) = self.max_results
316 && self.count >= max
317 {
318 self.done = true;
319 return Poll::Ready(None);
320 }
321
322 if let Some(vb) = self.buffer.pop_front() {
324 if matches!(vb.value, Value::EndOfMibView) {
326 self.done = true;
327 return Poll::Ready(None);
328 }
329
330 if !vb.oid.starts_with(&self.base_oid) {
332 self.done = true;
333 return Poll::Ready(None);
334 }
335
336 let target = self.client.peer_addr();
338 if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
339 self.done = true;
340 return Poll::Ready(Some(Err(e)));
341 }
342
343 self.current_oid = vb.oid.clone();
345 self.count += 1;
346
347 return Poll::Ready(Some(Ok(vb)));
348 }
349
350 if self.pending.is_none() {
352 let client = self.client.clone();
353 let oid = self.current_oid.clone();
354 let max_rep = self.max_repetitions;
355
356 let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
357 self.pending = Some(fut);
358 }
359
360 let pending = self.pending.as_mut().unwrap();
362 match pending.as_mut().poll(cx) {
363 Poll::Pending => return Poll::Pending,
364 Poll::Ready(result) => {
365 self.pending = None;
366
367 match result {
368 Ok(varbinds) => {
369 if varbinds.is_empty() {
370 self.done = true;
371 return Poll::Ready(None);
372 }
373
374 self.buffer = varbinds.into();
375 }
377 Err(e) => {
378 self.done = true;
379 return Poll::Ready(Some(Err(e)));
380 }
381 }
382 }
383 }
384 }
385 }
386}
387
388pub enum WalkStream<T: Transport> {
400 GetNext(Walk<T>),
402 GetBulk(BulkWalk<T>),
404}
405
406impl<T: Transport> WalkStream<T> {
407 pub(crate) fn new(
409 client: Client<T>,
410 oid: Oid,
411 version: Version,
412 walk_mode: WalkMode,
413 ordering: OidOrdering,
414 max_results: Option<usize>,
415 max_repetitions: i32,
416 ) -> Result<Self> {
417 let use_bulk = match walk_mode {
418 WalkMode::Auto => version != Version::V1,
419 WalkMode::GetNext => false,
420 WalkMode::GetBulk => {
421 if version == Version::V1 {
422 return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
423 }
424 true
425 }
426 };
427
428 Ok(if use_bulk {
429 WalkStream::GetBulk(BulkWalk::new(
430 client,
431 oid,
432 max_repetitions,
433 ordering,
434 max_results,
435 ))
436 } else {
437 WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
438 })
439 }
440}
441
442impl<T: Transport + 'static> WalkStream<T> {
443 pub async fn next(&mut self) -> Option<Result<VarBind>> {
445 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
446 }
447
448 pub async fn collect(mut self) -> Result<Vec<VarBind>> {
450 let mut results = Vec::new();
451 while let Some(result) = self.next().await {
452 results.push(result?);
453 }
454 Ok(results)
455 }
456}
457
458impl<T: Transport + 'static> Stream for WalkStream<T> {
459 type Item = Result<VarBind>;
460
461 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
462 match self.get_mut() {
464 WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
465 WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
466 }
467 }
468}