1#![allow(clippy::type_complexity)]
7
/// Generates inherent `next` and `collect` convenience methods for a walk
/// stream type, so callers can drive it without importing a stream
/// extension trait.
///
/// Invoke as `impl_stream_helpers!(Type<T>)`. Every generic parameter listed
/// is bounded by `crate::transport::Transport + 'static` in the generated
/// `impl`. The target type must implement `Stream` and be `Unpin`, because the
/// generated code pins it with `Pin::new`.
macro_rules! impl_stream_helpers {
    ($type:ident < $($gen:tt),+ >) => {
        impl<$($gen),+> $type<$($gen),+>
        where
            $($gen: crate::transport::Transport + 'static,)+
        {
            /// Waits for the next item of the walk.
            ///
            /// Resolves to `None` once the stream reports that it is finished.
            pub async fn next(&mut self) -> Option<crate::error::Result<crate::varbind::VarBind>> {
                std::future::poll_fn(|cx| {
                    let stream = std::pin::Pin::new(&mut *self);
                    stream.poll_next(cx)
                })
                .await
            }

            /// Drives the walk to completion and gathers every varbind.
            ///
            /// # Errors
            ///
            /// Returns the first error produced by the stream; varbinds
            /// gathered before that error are discarded.
            pub async fn collect(mut self) -> crate::error::Result<Vec<crate::varbind::VarBind>> {
                let mut collected = Vec::new();
                loop {
                    match self.next().await {
                        Some(Ok(varbind)) => collected.push(varbind),
                        Some(Err(err)) => return Err(err),
                        None => return Ok(collected),
                    }
                }
            }
        }
    };
}
31
32use std::collections::{HashSet, VecDeque};
33use std::pin::Pin;
34use std::task::{Context, Poll};
35
36use futures_core::Stream;
37
38use crate::error::{Error, Result, WalkAbortReason};
39use crate::oid::Oid;
40use crate::transport::Transport;
41use crate::value::Value;
42use crate::varbind::VarBind;
43use crate::version::Version;
44
45use super::Client;
46
/// Selects which SNMP request type a walk uses to traverse a subtree.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum WalkMode {
    /// Use GETBULK unless the session is SNMPv1, in which case GETNEXT is used.
    #[default]
    Auto,
    /// Always use GETNEXT, regardless of protocol version.
    GetNext,
    /// Always use GETBULK; constructing a walk with this mode on SNMPv1 is
    /// rejected with a configuration error.
    GetBulk,
}
59
/// Policy for how strictly the OIDs returned during a walk are checked.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum OidOrdering {
    /// Every returned OID must be strictly greater than the previous one;
    /// otherwise the walk is aborted as non-increasing.
    #[default]
    Strict,

    /// Out-of-order OIDs are tolerated; the walk is aborted only when an OID
    /// that was already returned shows up again (a cycle). Every accepted OID
    /// is remembered for the lifetime of the walk.
    AllowNonIncreasing,
}
100
101enum OidTracker {
102 Strict { last: Option<Oid> },
103 Relaxed { seen: HashSet<Oid> },
104}
105
106enum VarbindOutcome {
108 Yield,
110 Done,
112 Abort(Box<Error>),
114}
115
116fn validate_walk_varbind(
121 vb: &VarBind,
122 base_oid: &Oid,
123 oid_tracker: &mut OidTracker,
124 target: std::net::SocketAddr,
125) -> VarbindOutcome {
126 if matches!(
127 vb.value,
128 Value::EndOfMibView | Value::NoSuchObject | Value::NoSuchInstance
129 ) {
130 return VarbindOutcome::Done;
131 }
132 if !vb.oid.starts_with(base_oid) {
133 return VarbindOutcome::Done;
134 }
135 match oid_tracker.check(&vb.oid, target) {
136 Ok(()) => VarbindOutcome::Yield,
137 Err(e) => VarbindOutcome::Abort(e),
138 }
139}
140
141impl OidTracker {
142 fn new(ordering: OidOrdering) -> Self {
143 match ordering {
144 OidOrdering::Strict => OidTracker::Strict { last: None },
145 OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
146 seen: HashSet::new(),
147 },
148 }
149 }
150
151 fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
152 match self {
153 OidTracker::Strict { last } => {
154 if let Some(prev) = last
155 && oid <= prev
156 {
157 tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
158 return Err(Error::WalkAborted {
159 target,
160 reason: WalkAbortReason::NonIncreasing,
161 }
162 .boxed());
163 }
164 *last = Some(oid.clone());
165 Ok(())
166 }
167 OidTracker::Relaxed { seen } => {
168 if !seen.insert(oid.clone()) {
169 tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
170 return Err(Error::WalkAborted {
171 target,
172 reason: WalkAbortReason::Cycle,
173 }
174 .boxed());
175 }
176 Ok(())
177 }
178 }
179 }
180}
181
182pub struct Walk<T: Transport> {
186 client: Client<T>,
187 base_oid: Oid,
188 current_oid: Oid,
189 oid_tracker: OidTracker,
191 max_results: Option<usize>,
193 count: usize,
195 done: bool,
196 pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
197}
198
199impl<T: Transport> Walk<T> {
200 pub(crate) fn new(
201 client: Client<T>,
202 oid: Oid,
203 ordering: OidOrdering,
204 max_results: Option<usize>,
205 ) -> Self {
206 Self {
207 client,
208 base_oid: oid.clone(),
209 current_oid: oid,
210 oid_tracker: OidTracker::new(ordering),
211 max_results,
212 count: 0,
213 done: false,
214 pending: None,
215 }
216 }
217}
218
219impl_stream_helpers!(Walk<T>);
220
221impl<T: Transport + 'static> Stream for Walk<T> {
222 type Item = Result<VarBind>;
223
224 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
225 if self.done {
226 return Poll::Ready(None);
227 }
228
229 if let Some(max) = self.max_results
231 && self.count >= max
232 {
233 self.done = true;
234 return Poll::Ready(None);
235 }
236
237 if self.pending.is_none() {
239 let client = self.client.clone();
241 let oid = self.current_oid.clone();
242
243 let fut = Box::pin(async move { client.get_next(&oid).await });
244 self.pending = Some(fut);
245 }
246
247 let pending = self.pending.as_mut().unwrap();
249 match pending.as_mut().poll(cx) {
250 Poll::Pending => Poll::Pending,
251 Poll::Ready(result) => {
252 self.pending = None;
253
254 match result {
255 Ok(vb) => {
256 let target = self.client.peer_addr();
257 let base_oid = self.base_oid.clone();
258 match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
259 VarbindOutcome::Done => {
260 self.done = true;
261 return Poll::Ready(None);
262 }
263 VarbindOutcome::Abort(e) => {
264 self.done = true;
265 return Poll::Ready(Some(Err(e)));
266 }
267 VarbindOutcome::Yield => {}
268 }
269
270 self.current_oid = vb.oid.clone();
272 self.count += 1;
273
274 Poll::Ready(Some(Ok(vb)))
275 }
276 Err(e) => {
277 if self.client.inner.config.version == Version::V1
278 && matches!(
279 &*e,
280 Error::Snmp {
281 status: crate::error::ErrorStatus::NoSuchName,
282 ..
283 }
284 )
285 {
286 self.done = true;
287 return Poll::Ready(None);
288 }
289
290 self.done = true;
291 Poll::Ready(Some(Err(e)))
292 }
293 }
294 }
295 }
296 }
297}
298
299pub struct BulkWalk<T: Transport> {
303 client: Client<T>,
304 base_oid: Oid,
305 current_oid: Oid,
306 max_repetitions: i32,
307 oid_tracker: OidTracker,
309 max_results: Option<usize>,
311 count: usize,
313 done: bool,
314 buffer: VecDeque<VarBind>,
316 pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
317}
318
319impl<T: Transport> BulkWalk<T> {
320 pub(crate) fn new(
321 client: Client<T>,
322 oid: Oid,
323 max_repetitions: i32,
324 ordering: OidOrdering,
325 max_results: Option<usize>,
326 ) -> Self {
327 Self {
328 client,
329 base_oid: oid.clone(),
330 current_oid: oid,
331 max_repetitions,
332 oid_tracker: OidTracker::new(ordering),
333 max_results,
334 count: 0,
335 done: false,
336 buffer: VecDeque::new(),
337 pending: None,
338 }
339 }
340}
341
342impl_stream_helpers!(BulkWalk<T>);
343
344impl<T: Transport + 'static> Stream for BulkWalk<T> {
345 type Item = Result<VarBind>;
346
347 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
348 loop {
349 if self.done {
350 return Poll::Ready(None);
351 }
352
353 if let Some(max) = self.max_results
355 && self.count >= max
356 {
357 self.done = true;
358 return Poll::Ready(None);
359 }
360
361 if let Some(vb) = self.buffer.pop_front() {
363 let target = self.client.peer_addr();
364 let base_oid = self.base_oid.clone();
365 match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
366 VarbindOutcome::Done => {
367 self.done = true;
368 return Poll::Ready(None);
369 }
370 VarbindOutcome::Abort(e) => {
371 self.done = true;
372 return Poll::Ready(Some(Err(e)));
373 }
374 VarbindOutcome::Yield => {}
375 }
376
377 self.current_oid = vb.oid.clone();
379 self.count += 1;
380
381 return Poll::Ready(Some(Ok(vb)));
382 }
383
384 if self.pending.is_none() {
386 let client = self.client.clone();
387 let oid = self.current_oid.clone();
388 let max_rep = self.max_repetitions;
389
390 let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
391 self.pending = Some(fut);
392 }
393
394 let pending = self.pending.as_mut().unwrap();
396 match pending.as_mut().poll(cx) {
397 Poll::Pending => return Poll::Pending,
398 Poll::Ready(result) => {
399 self.pending = None;
400
401 match result {
402 Ok(varbinds) => {
403 if varbinds.is_empty() {
404 self.done = true;
405 return Poll::Ready(None);
406 }
407
408 self.buffer = varbinds.into();
409 }
411 Err(e) => {
412 self.done = true;
413 return Poll::Ready(Some(Err(e)));
414 }
415 }
416 }
417 }
418 }
419 }
420}
421
422pub enum WalkStream<T: Transport> {
434 GetNext(Walk<T>),
436 GetBulk(BulkWalk<T>),
438}
439
440impl<T: Transport> WalkStream<T> {
441 pub(crate) fn new(
443 client: Client<T>,
444 oid: Oid,
445 version: Version,
446 walk_mode: WalkMode,
447 ordering: OidOrdering,
448 max_results: Option<usize>,
449 max_repetitions: i32,
450 ) -> Result<Self> {
451 let use_bulk = match walk_mode {
452 WalkMode::Auto => version != Version::V1,
453 WalkMode::GetNext => false,
454 WalkMode::GetBulk => {
455 if version == Version::V1 {
456 return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
457 }
458 true
459 }
460 };
461
462 Ok(if use_bulk {
463 WalkStream::GetBulk(BulkWalk::new(
464 client,
465 oid,
466 max_repetitions,
467 ordering,
468 max_results,
469 ))
470 } else {
471 WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
472 })
473 }
474}
475
476impl<T: Transport + 'static> WalkStream<T> {
477 pub async fn next(&mut self) -> Option<Result<VarBind>> {
479 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
480 }
481
482 pub async fn collect(mut self) -> Result<Vec<VarBind>> {
489 let mut results = Vec::new();
490 while let Some(result) = self.next().await {
491 results.push(result?);
492 }
493 if results.is_empty() {
494 let (client, base_oid) = match &self {
495 WalkStream::GetNext(w) => (&w.client, &w.base_oid),
496 WalkStream::GetBulk(bw) => (&bw.client, &bw.base_oid),
497 };
498 match client.get(base_oid).await {
499 Ok(vb)
500 if !matches!(
501 vb.value,
502 Value::NoSuchObject | Value::NoSuchInstance | Value::EndOfMibView
503 ) =>
504 {
505 results.push(vb);
506 }
507 _ => {}
508 }
509 }
510 Ok(results)
511 }
512}
513
514impl<T: Transport + 'static> Stream for WalkStream<T> {
515 type Item = Result<VarBind>;
516
517 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
518 match self.get_mut() {
520 WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
521 WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
522 }
523 }
524}
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oid;

    /// Fixed peer address handed to the validator; only used for diagnostics.
    fn target_addr() -> std::net::SocketAddr {
        std::net::SocketAddr::from(([127, 0, 0, 1], 161))
    }

    /// Validates a varbind at `1.3.6.1.2.1.1.1.0` carrying `value` against the
    /// base OID `1.3.6.1.2.1.1`, using a fresh strict tracker.
    fn outcome_for(value: Value) -> VarbindOutcome {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), value);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        validate_walk_varbind(&vb, &base, &mut tracker, target_addr())
    }

    #[test]
    fn test_walk_terminates_on_no_such_object() {
        let outcome = outcome_for(Value::NoSuchObject);
        assert!(matches!(outcome, VarbindOutcome::Done));
    }

    #[test]
    fn test_walk_terminates_on_no_such_instance() {
        let outcome = outcome_for(Value::NoSuchInstance);
        assert!(matches!(outcome, VarbindOutcome::Done));
    }

    #[test]
    fn test_walk_terminates_on_end_of_mib_view() {
        let outcome = outcome_for(Value::EndOfMibView);
        assert!(matches!(outcome, VarbindOutcome::Done));
    }

    #[test]
    fn test_walk_yields_normal_value() {
        let outcome = outcome_for(Value::Integer(42));
        assert!(matches!(outcome, VarbindOutcome::Yield));
    }
}