1#![allow(clippy::type_complexity)]
4
5use std::collections::HashSet;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9use futures_core::Stream;
10
11use crate::error::{Error, Result, WalkAbortReason};
12use crate::oid::Oid;
13use crate::transport::Transport;
14use crate::value::Value;
15use crate::varbind::VarBind;
16use crate::version::Version;
17
18use super::Client;
19
/// Request strategy used to walk an OID subtree.
///
/// The mode is resolved against the session's SNMP version by
/// `WalkStream::new`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum WalkMode {
    /// Use GETBULK for every version except SNMPv1, which falls back to
    /// GETNEXT. This is the default.
    #[default]
    Auto,
    /// Always walk with GETNEXT requests.
    GetNext,
    /// Always walk with GETBULK requests; rejected with a configuration
    /// error when the session is SNMPv1.
    GetBulk,
}
33
/// How a walk validates the order of the OIDs returned by the agent.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum OidOrdering {
    /// Every returned OID must be strictly greater than the previous one;
    /// otherwise the walk aborts with `WalkAbortReason::NonIncreasing`.
    /// This is the default.
    #[default]
    Strict,

    /// OIDs may arrive in any order; the walk aborts only when an OID repeats
    /// (`WalkAbortReason::Cycle`). Every returned OID is kept in a set for the
    /// lifetime of the walk, so memory grows with the number of results.
    AllowNonIncreasing,
}
75
/// Per-walk state used to reject misbehaving agents, selected by [`OidOrdering`].
enum OidTracker {
    /// Strict mode: remembers the last accepted OID so the next one can be
    /// required to be strictly greater.
    Strict { last: Option<Oid> },
    /// Relaxed mode: remembers every accepted OID so that a repeat (a cycle)
    /// can be detected.
    Relaxed { seen: HashSet<Oid> },
}
80
81impl OidTracker {
82 fn new(ordering: OidOrdering) -> Self {
83 match ordering {
84 OidOrdering::Strict => OidTracker::Strict { last: None },
85 OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
86 seen: HashSet::new(),
87 },
88 }
89 }
90
91 fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
92 match self {
93 OidTracker::Strict { last } => {
94 if let Some(prev) = last
95 && oid <= prev
96 {
97 tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
98 return Err(Error::WalkAborted {
99 target,
100 reason: WalkAbortReason::NonIncreasing,
101 }
102 .boxed());
103 }
104 *last = Some(oid.clone());
105 Ok(())
106 }
107 OidTracker::Relaxed { seen } => {
108 if !seen.insert(oid.clone()) {
109 tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
110 return Err(Error::WalkAborted {
111 target,
112 reason: WalkAbortReason::Cycle,
113 }
114 .boxed());
115 }
116 Ok(())
117 }
118 }
119 }
120}
121
/// Stream that walks the OID subtree rooted at `base_oid`, sending one GETNEXT
/// request per returned varbind.
///
/// The walk yields each varbind under the subtree. It ends at `EndOfMibView`,
/// at the first OID outside the subtree, after `max_results` items, or after
/// the first error.
pub struct Walk<T: Transport> {
    /// Client used to send requests; cloned into each request future.
    client: Client<T>,
    /// Root of the walked subtree.
    base_oid: Oid,
    /// OID sent in the next GETNEXT; starts at `base_oid` and advances to each
    /// yielded OID.
    current_oid: Oid,
    /// Enforces the configured [`OidOrdering`] on returned OIDs.
    oid_tracker: OidTracker,
    /// Optional cap on the number of varbinds yielded.
    max_results: Option<usize>,
    /// Number of varbinds yielded so far.
    count: usize,
    /// Set once the walk has ended; later polls return `None`.
    done: bool,
    /// In-flight GETNEXT request, kept across `Pending` polls.
    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
}
138
139impl<T: Transport> Walk<T> {
140 pub(crate) fn new(
141 client: Client<T>,
142 oid: Oid,
143 ordering: OidOrdering,
144 max_results: Option<usize>,
145 ) -> Self {
146 Self {
147 client,
148 base_oid: oid.clone(),
149 current_oid: oid,
150 oid_tracker: OidTracker::new(ordering),
151 max_results,
152 count: 0,
153 done: false,
154 pending: None,
155 }
156 }
157}
158
159impl<T: Transport + 'static> Walk<T> {
160 pub async fn next(&mut self) -> Option<Result<VarBind>> {
162 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
163 }
164
165 pub async fn collect(mut self) -> Result<Vec<VarBind>> {
167 let mut results = Vec::new();
168 while let Some(result) = self.next().await {
169 results.push(result?);
170 }
171 Ok(results)
172 }
173}
174
175impl<T: Transport + 'static> Stream for Walk<T> {
176 type Item = Result<VarBind>;
177
178 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
179 if self.done {
180 return Poll::Ready(None);
181 }
182
183 if let Some(max) = self.max_results
185 && self.count >= max
186 {
187 self.done = true;
188 return Poll::Ready(None);
189 }
190
191 if self.pending.is_none() {
193 let client = self.client.clone();
195 let oid = self.current_oid.clone();
196
197 let fut = Box::pin(async move { client.get_next(&oid).await });
198 self.pending = Some(fut);
199 }
200
201 let pending = self.pending.as_mut().unwrap();
203 match pending.as_mut().poll(cx) {
204 Poll::Pending => Poll::Pending,
205 Poll::Ready(result) => {
206 self.pending = None;
207
208 match result {
209 Ok(vb) => {
210 if matches!(vb.value, Value::EndOfMibView) {
212 self.done = true;
213 return Poll::Ready(None);
214 }
215
216 if !vb.oid.starts_with(&self.base_oid) {
218 self.done = true;
219 return Poll::Ready(None);
220 }
221
222 let target = self.client.peer_addr();
224 if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
225 self.done = true;
226 return Poll::Ready(Some(Err(e)));
227 }
228
229 self.current_oid = vb.oid.clone();
231 self.count += 1;
232
233 Poll::Ready(Some(Ok(vb)))
234 }
235 Err(e) => {
236 self.done = true;
237 Poll::Ready(Some(Err(e)))
238 }
239 }
240 }
241 }
242 }
243}
244
/// Stream that walks the OID subtree rooted at `base_oid` using GETBULK
/// requests, buffering each response and yielding its varbinds one at a time.
///
/// The walk ends at `EndOfMibView`, at the first OID outside the subtree, on an
/// empty response, after `max_results` items, or after the first error.
pub struct BulkWalk<T: Transport> {
    /// Client used to send requests; cloned into each request future.
    client: Client<T>,
    /// Root of the walked subtree.
    base_oid: Oid,
    /// Start OID of the next GETBULK; starts at `base_oid` and advances to each
    /// yielded OID.
    current_oid: Oid,
    /// Configured max-repetitions for GETBULK requests.
    max_repetitions: i32,
    /// Enforces the configured [`OidOrdering`] on returned OIDs.
    oid_tracker: OidTracker,
    /// Optional cap on the number of varbinds yielded.
    max_results: Option<usize>,
    /// Number of varbinds yielded so far.
    count: usize,
    /// Set once the walk has ended; later polls return `None`.
    done: bool,
    /// Varbinds from the most recent GETBULK response.
    buffer: Vec<VarBind>,
    /// Index of the next unconsumed entry in `buffer`.
    buffer_idx: usize,
    /// In-flight GETBULK request, kept across `Pending` polls.
    pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
}
266
267impl<T: Transport> BulkWalk<T> {
268 pub(crate) fn new(
269 client: Client<T>,
270 oid: Oid,
271 max_repetitions: i32,
272 ordering: OidOrdering,
273 max_results: Option<usize>,
274 ) -> Self {
275 Self {
276 client,
277 base_oid: oid.clone(),
278 current_oid: oid,
279 max_repetitions,
280 oid_tracker: OidTracker::new(ordering),
281 max_results,
282 count: 0,
283 done: false,
284 buffer: Vec::new(),
285 buffer_idx: 0,
286 pending: None,
287 }
288 }
289}
290
291impl<T: Transport + 'static> BulkWalk<T> {
292 pub async fn next(&mut self) -> Option<Result<VarBind>> {
294 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
295 }
296
297 pub async fn collect(mut self) -> Result<Vec<VarBind>> {
299 let mut results = Vec::new();
300 while let Some(result) = self.next().await {
301 results.push(result?);
302 }
303 Ok(results)
304 }
305}
306
307impl<T: Transport + 'static> Stream for BulkWalk<T> {
308 type Item = Result<VarBind>;
309
310 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
311 loop {
312 if self.done {
313 return Poll::Ready(None);
314 }
315
316 if let Some(max) = self.max_results
318 && self.count >= max
319 {
320 self.done = true;
321 return Poll::Ready(None);
322 }
323
324 if self.buffer_idx < self.buffer.len() {
326 let vb = self.buffer[self.buffer_idx].clone();
327 self.buffer_idx += 1;
328
329 if matches!(vb.value, Value::EndOfMibView) {
331 self.done = true;
332 return Poll::Ready(None);
333 }
334
335 if !vb.oid.starts_with(&self.base_oid) {
337 self.done = true;
338 return Poll::Ready(None);
339 }
340
341 let target = self.client.peer_addr();
343 if let Err(e) = self.oid_tracker.check(&vb.oid, target) {
344 self.done = true;
345 return Poll::Ready(Some(Err(e)));
346 }
347
348 self.current_oid = vb.oid.clone();
350 self.count += 1;
351
352 return Poll::Ready(Some(Ok(vb)));
353 }
354
355 if self.pending.is_none() {
357 let client = self.client.clone();
358 let oid = self.current_oid.clone();
359 let max_rep = self.max_repetitions;
360
361 let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
362 self.pending = Some(fut);
363 }
364
365 let pending = self.pending.as_mut().unwrap();
367 match pending.as_mut().poll(cx) {
368 Poll::Pending => return Poll::Pending,
369 Poll::Ready(result) => {
370 self.pending = None;
371
372 match result {
373 Ok(varbinds) => {
374 if varbinds.is_empty() {
375 self.done = true;
376 return Poll::Ready(None);
377 }
378
379 self.buffer = varbinds;
380 self.buffer_idx = 0;
381 }
383 Err(e) => {
384 self.done = true;
385 return Poll::Ready(Some(Err(e)));
386 }
387 }
388 }
389 }
390 }
391 }
392}
393
/// Walk over an OID subtree using either GETNEXT or GETBULK requests.
///
/// Built by `WalkStream::new`, which picks the variant from the requested
/// [`WalkMode`] and the SNMP version. The [`Stream`] implementation delegates
/// to the wrapped walk.
pub enum WalkStream<T: Transport> {
    /// GETNEXT-based walk, used for `WalkMode::GetNext` and for
    /// `WalkMode::Auto` on SNMPv1.
    GetNext(Walk<T>),
    /// GETBULK-based walk, used for `WalkMode::GetBulk` and for
    /// `WalkMode::Auto` on versions other than SNMPv1.
    GetBulk(BulkWalk<T>),
}
411
412impl<T: Transport> WalkStream<T> {
413 pub(crate) fn new(
415 client: Client<T>,
416 oid: Oid,
417 version: Version,
418 walk_mode: WalkMode,
419 ordering: OidOrdering,
420 max_results: Option<usize>,
421 max_repetitions: i32,
422 ) -> Result<Self> {
423 let use_bulk = match walk_mode {
424 WalkMode::Auto => version != Version::V1,
425 WalkMode::GetNext => false,
426 WalkMode::GetBulk => {
427 if version == Version::V1 {
428 return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
429 }
430 true
431 }
432 };
433
434 Ok(if use_bulk {
435 WalkStream::GetBulk(BulkWalk::new(
436 client,
437 oid,
438 max_repetitions,
439 ordering,
440 max_results,
441 ))
442 } else {
443 WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
444 })
445 }
446}
447
448impl<T: Transport + 'static> WalkStream<T> {
449 pub async fn next(&mut self) -> Option<Result<VarBind>> {
451 std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
452 }
453
454 pub async fn collect(mut self) -> Result<Vec<VarBind>> {
456 let mut results = Vec::new();
457 while let Some(result) = self.next().await {
458 results.push(result?);
459 }
460 Ok(results)
461 }
462}
463
464impl<T: Transport + 'static> Stream for WalkStream<T> {
465 type Item = Result<VarBind>;
466
467 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
468 match self.get_mut() {
470 WalkStream::GetNext(walk) => Pin::new(walk).poll_next(cx),
471 WalkStream::GetBulk(bulk_walk) => Pin::new(bulk_walk).poll_next(cx),
472 }
473 }
474}