1#![allow(clippy::type_complexity)]
7
/// Generates inherent async convenience methods (`next` and `collect`) for a walk
/// stream type whose generic parameters are all `Transport`s.
///
/// The generated methods poll the type's `Stream` implementation through
/// `Pin::new`, so the target type must be `Unpin`.
macro_rules! impl_stream_helpers {
    ($name:ident < $($param:tt),+ >) => {
        impl<$($param),+> $name<$($param),+>
        where
            $($param: crate::transport::Transport + 'static,)+
        {
            /// Resolves to the next walk result, or `None` once the walk has finished.
            pub async fn next(&mut self) -> Option<crate::error::Result<crate::varbind::VarBind>> {
                std::future::poll_fn(|cx| std::pin::Pin::new(&mut *self).poll_next(cx)).await
            }

            /// Consumes the walk and gathers every varbind in order.
            ///
            /// Stops at the first error and returns it; varbinds gathered before the
            /// error are discarded.
            pub async fn collect(mut self) -> crate::error::Result<Vec<crate::varbind::VarBind>> {
                let mut gathered = Vec::new();
                loop {
                    match self.next().await {
                        Some(item) => gathered.push(item?),
                        None => break Ok(gathered),
                    }
                }
            }
        }
    };
}
31
32use std::collections::{HashSet, VecDeque};
33use std::pin::Pin;
34use std::task::{Context, Poll};
35
36use futures_core::Stream;
37
38use crate::error::{Error, Result, WalkAbortReason};
39use crate::oid::Oid;
40use crate::transport::Transport;
41use crate::value::Value;
42use crate::varbind::VarBind;
43use crate::version::Version;
44
45use super::Client;
46
/// Which request PDU a walk uses to step through a subtree.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum WalkMode {
    /// Pick automatically: GETBULK for any version other than SNMPv1, GETNEXT for SNMPv1.
    #[default]
    Auto,
    /// Always step with GETNEXT requests.
    GetNext,
    /// Always step with GETBULK requests; rejected with a configuration error on SNMPv1.
    GetBulk,
}
59
/// How strictly a walk validates the order of OIDs returned by the agent.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum OidOrdering {
    /// Every returned OID must be strictly greater than the previous one; otherwise
    /// the walk aborts with `WalkAbortReason::NonIncreasing`.
    #[default]
    Strict,

    /// OIDs may fail to increase, but the walk aborts with `WalkAbortReason::Cycle`
    /// as soon as an OID repeats. Every OID seen is kept in memory for that check.
    AllowNonIncreasing,
}
100
101enum OidTracker {
102 Strict { last: Option<Oid> },
103 Relaxed { seen: HashSet<Oid> },
104}
105
106enum VarbindOutcome {
108 Yield,
110 Done,
112 Abort(Box<Error>),
114}
115
116fn validate_walk_varbind(
121 vb: &VarBind,
122 base_oid: &Oid,
123 oid_tracker: &mut OidTracker,
124 target: std::net::SocketAddr,
125) -> VarbindOutcome {
126 if matches!(vb.value, Value::EndOfMibView) {
127 return VarbindOutcome::Done;
128 }
129 if !vb.oid.starts_with(base_oid) {
130 return VarbindOutcome::Done;
131 }
132 match oid_tracker.check(&vb.oid, target) {
133 Ok(()) => VarbindOutcome::Yield,
134 Err(e) => VarbindOutcome::Abort(e),
135 }
136}
137
138impl OidTracker {
139 fn new(ordering: OidOrdering) -> Self {
140 match ordering {
141 OidOrdering::Strict => OidTracker::Strict { last: None },
142 OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
143 seen: HashSet::new(),
144 },
145 }
146 }
147
148 fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
149 match self {
150 OidTracker::Strict { last } => {
151 if let Some(prev) = last
152 && oid <= prev
153 {
154 tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
155 return Err(Error::WalkAborted {
156 target,
157 reason: WalkAbortReason::NonIncreasing,
158 }
159 .boxed());
160 }
161 *last = Some(oid.clone());
162 Ok(())
163 }
164 OidTracker::Relaxed { seen } => {
165 if !seen.insert(oid.clone()) {
166 tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
167 return Err(Error::WalkAborted {
168 target,
169 reason: WalkAbortReason::Cycle,
170 }
171 .boxed());
172 }
173 Ok(())
174 }
175 }
176 }
177}
178
179pub struct Walk<T: Transport> {
183 client: Client<T>,
184 base_oid: Oid,
185 current_oid: Oid,
186 oid_tracker: OidTracker,
188 max_results: Option<usize>,
190 count: usize,
192 done: bool,
193 pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
194}
195
196impl<T: Transport> Walk<T> {
197 pub(crate) fn new(
198 client: Client<T>,
199 oid: Oid,
200 ordering: OidOrdering,
201 max_results: Option<usize>,
202 ) -> Self {
203 Self {
204 client,
205 base_oid: oid.clone(),
206 current_oid: oid,
207 oid_tracker: OidTracker::new(ordering),
208 max_results,
209 count: 0,
210 done: false,
211 pending: None,
212 }
213 }
214}
215
216impl_stream_helpers!(Walk<T>);
217
218impl<T: Transport + 'static> Stream for Walk<T> {
219 type Item = Result<VarBind>;
220
221 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
222 if self.done {
223 return Poll::Ready(None);
224 }
225
226 if let Some(max) = self.max_results
228 && self.count >= max
229 {
230 self.done = true;
231 return Poll::Ready(None);
232 }
233
234 if self.pending.is_none() {
236 let client = self.client.clone();
238 let oid = self.current_oid.clone();
239
240 let fut = Box::pin(async move { client.get_next(&oid).await });
241 self.pending = Some(fut);
242 }
243
244 let pending = self.pending.as_mut().unwrap();
246 match pending.as_mut().poll(cx) {
247 Poll::Pending => Poll::Pending,
248 Poll::Ready(result) => {
249 self.pending = None;
250
251 match result {
252 Ok(vb) => {
253 let target = self.client.peer_addr();
254 let base_oid = self.base_oid.clone();
255 match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
256 VarbindOutcome::Done => {
257 self.done = true;
258 return Poll::Ready(None);
259 }
260 VarbindOutcome::Abort(e) => {
261 self.done = true;
262 return Poll::Ready(Some(Err(e)));
263 }
264 VarbindOutcome::Yield => {}
265 }
266
267 self.current_oid = vb.oid.clone();
269 self.count += 1;
270
271 Poll::Ready(Some(Ok(vb)))
272 }
273 Err(e) => {
274 if self.client.inner.config.version == Version::V1
275 && matches!(
276 &*e,
277 Error::Snmp {
278 status: crate::error::ErrorStatus::NoSuchName,
279 ..
280 }
281 )
282 {
283 self.done = true;
284 return Poll::Ready(None);
285 }
286
287 self.done = true;
288 Poll::Ready(Some(Err(e)))
289 }
290 }
291 }
292 }
293 }
294}
295
296pub struct BulkWalk<T: Transport> {
300 client: Client<T>,
301 base_oid: Oid,
302 current_oid: Oid,
303 max_repetitions: i32,
304 oid_tracker: OidTracker,
306 max_results: Option<usize>,
308 count: usize,
310 done: bool,
311 buffer: VecDeque<VarBind>,
313 pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
314}
315
316impl<T: Transport> BulkWalk<T> {
317 pub(crate) fn new(
318 client: Client<T>,
319 oid: Oid,
320 max_repetitions: i32,
321 ordering: OidOrdering,
322 max_results: Option<usize>,
323 ) -> Self {
324 Self {
325 client,
326 base_oid: oid.clone(),
327 current_oid: oid,
328 max_repetitions,
329 oid_tracker: OidTracker::new(ordering),
330 max_results,
331 count: 0,
332 done: false,
333 buffer: VecDeque::new(),
334 pending: None,
335 }
336 }
337}
338
339impl_stream_helpers!(BulkWalk<T>);
340
341impl<T: Transport + 'static> Stream for BulkWalk<T> {
342 type Item = Result<VarBind>;
343
344 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
345 loop {
346 if self.done {
347 return Poll::Ready(None);
348 }
349
350 if let Some(max) = self.max_results
352 && self.count >= max
353 {
354 self.done = true;
355 return Poll::Ready(None);
356 }
357
358 if let Some(vb) = self.buffer.pop_front() {
360 let target = self.client.peer_addr();
361 let base_oid = self.base_oid.clone();
362 match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
363 VarbindOutcome::Done => {
364 self.done = true;
365 return Poll::Ready(None);
366 }
367 VarbindOutcome::Abort(e) => {
368 self.done = true;
369 return Poll::Ready(Some(Err(e)));
370 }
371 VarbindOutcome::Yield => {}
372 }
373
374 self.current_oid = vb.oid.clone();
376 self.count += 1;
377
378 return Poll::Ready(Some(Ok(vb)));
379 }
380
381 if self.pending.is_none() {
383 let client = self.client.clone();
384 let oid = self.current_oid.clone();
385 let max_rep = self.max_repetitions;
386
387 let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
388 self.pending = Some(fut);
389 }
390
391 let pending = self.pending.as_mut().unwrap();
393 match pending.as_mut().poll(cx) {
394 Poll::Pending => return Poll::Pending,
395 Poll::Ready(result) => {
396 self.pending = None;
397
398 match result {
399 Ok(varbinds) => {
400 if varbinds.is_empty() {
401 self.done = true;
402 return Poll::Ready(None);
403 }
404
405 self.buffer = varbinds.into();
406 }
408 Err(e) => {
409 self.done = true;
410 return Poll::Ready(Some(Err(e)));
411 }
412 }
413 }
414 }
415 }
416 }
417}
418
419pub enum WalkStream<T: Transport> {
431 GetNext(Walk<T>),
433 GetBulk(BulkWalk<T>),
435}
436
437impl<T: Transport> WalkStream<T> {
438 pub(crate) fn new(
440 client: Client<T>,
441 oid: Oid,
442 version: Version,
443 walk_mode: WalkMode,
444 ordering: OidOrdering,
445 max_results: Option<usize>,
446 max_repetitions: i32,
447 ) -> Result<Self> {
448 let use_bulk = match walk_mode {
449 WalkMode::Auto => version != Version::V1,
450 WalkMode::GetNext => false,
451 WalkMode::GetBulk => {
452 if version == Version::V1 {
453 return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
454 }
455 true
456 }
457 };
458
459 Ok(if use_bulk {
460 WalkStream::GetBulk(BulkWalk::new(
461 client,
462 oid,
463 max_repetitions,
464 ordering,
465 max_results,
466 ))
467 } else {
468 WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
469 })
470 }
471}
472
473impl_stream_helpers!(WalkStream<T>);
474
475impl<T: Transport + 'static> Stream for WalkStream<T> {
476 type Item = Result<VarBind>;
477
478 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
479 match self.get_mut() {
481 WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
482 WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
483 }
484 }
485}