// The `pending` fields of `Walk`/`BulkWalk` spell out
// `Option<Pin<Box<dyn Future<Output = ...> + Send>>>`, which trips clippy's `type_complexity`.
#![allow(clippy::type_complexity)]
// Implements async convenience methods `next` and `collect` on a stream type.
//
// Usage: `impl_stream_helpers!(Walk<T>);`. Every generic parameter gets the bound
// `crate::transport::Transport + 'static`. `next` drives the type's `poll_next` through
// `Pin::new`, so the type must be `Unpin` under those bounds.
macro_rules! impl_stream_helpers {
($type:ident < $($gen:tt),+ >) => {
impl<$($gen),+> $type<$($gen),+>
where
$($gen: crate::transport::Transport + 'static,)+
{
/// Waits for the next item of the stream; returns `None` once the stream is exhausted.
pub async fn next(&mut self) -> Option<crate::error::Result<crate::varbind::VarBind>> {
std::future::poll_fn(|cx| std::pin::Pin::new(&mut *self).poll_next(cx)).await
}
/// Drains the stream into a `Vec`.
///
/// Stops and returns the first error encountered; items after it are not polled.
pub async fn collect(mut self) -> crate::error::Result<Vec<crate::varbind::VarBind>> {
let mut results = Vec::new();
while let Some(result) = self.next().await {
// `?` propagates the first `Err` item immediately.
results.push(result?);
}
Ok(results)
}
}
};
}
use std::collections::{HashSet, VecDeque};
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::Stream;
use crate::error::{Error, Result, WalkAbortReason};
use crate::oid::Oid;
use crate::transport::Transport;
use crate::value::Value;
use crate::varbind::VarBind;
use crate::version::Version;
use super::Client;
/// Selects which request PDU a walk uses to traverse the agent's MIB.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum WalkMode {
/// GETBULK unless the session version is SNMPv1, in which case GETNEXT (the default).
#[default]
Auto,
/// Always GETNEXT: one request per returned varbind.
GetNext,
/// Always GETBULK; building the walk fails with a configuration error on SNMPv1.
GetBulk,
}
/// How a walk validates the order of OIDs returned by the agent.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OidOrdering {
/// Each OID must be strictly greater than the previously accepted one; otherwise the walk
/// aborts with `WalkAbortReason::NonIncreasing` (the default).
#[default]
Strict,
/// OIDs may arrive out of order. The walk aborts with `WalkAbortReason::Cycle` only when an
/// OID is returned a second time. Every accepted OID is remembered, so memory use grows with
/// the number of varbinds walked.
AllowNonIncreasing,
}
/// Per-walk state that enforces an [`OidOrdering`] policy.
enum OidTracker {
/// Strict policy: remembers only the most recently accepted OID.
Strict { last: Option<Oid> },
/// Relaxed policy: remembers every accepted OID so that repeats can be detected.
Relaxed { seen: HashSet<Oid> },
}
/// Verdict of `validate_walk_varbind` for a single varbind.
enum VarbindOutcome {
/// The varbind belongs to the walk and should be yielded to the caller.
Yield,
/// The walk has ended: exception value, or OID outside the base subtree.
Done,
/// The OID violated the ordering policy; the walk stops and reports this error.
Abort(Box<Error>),
}
/// Classifies one varbind returned while walking the subtree rooted at `base_oid`.
///
/// Returns `Done` if the value is an SNMP exception (`EndOfMibView`, `NoSuchObject`,
/// `NoSuchInstance`) or if the OID has left the `base_oid` subtree. Both checks run before the
/// ordering check, so such varbinds are never recorded by `oid_tracker`.
///
/// Otherwise the OID is handed to `oid_tracker`. The result is `Abort` with its error when the
/// ordering policy is violated, and `Yield` when the OID is accepted.
fn validate_walk_varbind(
    vb: &VarBind,
    base_oid: &Oid,
    oid_tracker: &mut OidTracker,
    target: std::net::SocketAddr,
) -> VarbindOutcome {
    let is_exception = matches!(
        vb.value,
        Value::EndOfMibView | Value::NoSuchObject | Value::NoSuchInstance
    );
    // `||` short-circuits, so the subtree test is skipped for exception values.
    if is_exception || !vb.oid.starts_with(base_oid) {
        return VarbindOutcome::Done;
    }
    if let Err(err) = oid_tracker.check(&vb.oid, target) {
        return VarbindOutcome::Abort(err);
    }
    VarbindOutcome::Yield
}
impl OidTracker {
fn new(ordering: OidOrdering) -> Self {
match ordering {
OidOrdering::Strict => OidTracker::Strict { last: None },
OidOrdering::AllowNonIncreasing => OidTracker::Relaxed {
seen: HashSet::new(),
},
}
}
fn check(&mut self, oid: &Oid, target: std::net::SocketAddr) -> Result<()> {
match self {
OidTracker::Strict { last } => {
if let Some(prev) = last
&& oid <= prev
{
tracing::debug!(target: "async_snmp::walk", { previous_oid = %prev, current_oid = %oid, %target }, "non-increasing OID detected");
return Err(Error::WalkAborted {
target,
reason: WalkAbortReason::NonIncreasing,
}
.boxed());
}
*last = Some(oid.clone());
Ok(())
}
OidTracker::Relaxed { seen } => {
if !seen.insert(oid.clone()) {
tracing::debug!(target: "async_snmp::walk", { %oid, %target }, "duplicate OID detected (cycle)");
return Err(Error::WalkAborted {
target,
reason: WalkAbortReason::Cycle,
}
.boxed());
}
Ok(())
}
}
}
}
/// Stream that walks the subtree under a base OID with one GETNEXT request per varbind.
pub struct Walk<T: Transport> {
/// Client used to issue the GETNEXT requests.
client: Client<T>,
/// Root of the walked subtree; the walk ends at the first varbind outside it.
base_oid: Oid,
/// OID sent in the next GETNEXT; advanced to each yielded varbind's OID.
current_oid: Oid,
/// Enforces the configured `OidOrdering` on returned OIDs.
oid_tracker: OidTracker,
/// Optional cap on the number of varbinds yielded.
max_results: Option<usize>,
/// Number of varbinds yielded so far.
count: usize,
/// Set once the walk has terminated (end of subtree, error, or cap reached).
done: bool,
/// In-flight GETNEXT request, if one has been started and not yet completed.
pending: Option<Pin<Box<dyn std::future::Future<Output = Result<VarBind>> + Send>>>,
}
impl<T: Transport> Walk<T> {
    /// Builds a GETNEXT walk rooted at `oid`.
    ///
    /// The first request asks for the successor of `oid` itself. `ordering` is enforced on
    /// every returned OID. When `max_results` is `Some(n)`, the walk stops after `n` varbinds.
    pub(crate) fn new(
        client: Client<T>,
        oid: Oid,
        ordering: OidOrdering,
        max_results: Option<usize>,
    ) -> Self {
        let oid_tracker = OidTracker::new(ordering);
        let base_oid = oid.clone();
        Self {
            client,
            base_oid,
            current_oid: oid,
            oid_tracker,
            max_results,
            count: 0,
            done: false,
            pending: None,
        }
    }
}
impl_stream_helpers!(Walk<T>);
impl<T: Transport + 'static> Stream for Walk<T> {
type Item = Result<VarBind>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if self.done {
return Poll::Ready(None);
}
if let Some(max) = self.max_results
&& self.count >= max
{
self.done = true;
return Poll::Ready(None);
}
if self.pending.is_none() {
let client = self.client.clone();
let oid = self.current_oid.clone();
let fut = Box::pin(async move { client.get_next(&oid).await });
self.pending = Some(fut);
}
let pending = self.pending.as_mut().unwrap();
match pending.as_mut().poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(result) => {
self.pending = None;
match result {
Ok(vb) => {
let target = self.client.peer_addr();
let base_oid = self.base_oid.clone();
match validate_walk_varbind(&vb, &base_oid, &mut self.oid_tracker, target) {
VarbindOutcome::Done => {
self.done = true;
return Poll::Ready(None);
}
VarbindOutcome::Abort(e) => {
self.done = true;
return Poll::Ready(Some(Err(e)));
}
VarbindOutcome::Yield => {}
}
self.current_oid = vb.oid.clone();
self.count += 1;
Poll::Ready(Some(Ok(vb)))
}
Err(e) => {
if self.client.inner.config.version == Version::V1
&& matches!(
&*e,
Error::Snmp {
status: crate::error::ErrorStatus::NoSuchName,
..
}
)
{
self.done = true;
return Poll::Ready(None);
}
self.done = true;
Poll::Ready(Some(Err(e)))
}
}
}
}
}
}
/// Stream that walks the subtree under a base OID using GETBULK requests, yielding the
/// returned varbinds one at a time.
pub struct BulkWalk<T: Transport> {
/// Client used to issue the GETBULK requests.
client: Client<T>,
/// Root of the walked subtree; the walk ends at the first varbind outside it.
base_oid: Oid,
/// OID sent in the next GETBULK; advanced to each yielded varbind's OID.
current_oid: Oid,
/// Max-repetitions value sent with every GETBULK (non-repeaters is always 0).
max_repetitions: i32,
/// Enforces the configured `OidOrdering` on returned OIDs.
oid_tracker: OidTracker,
/// Optional cap on the number of varbinds yielded.
max_results: Option<usize>,
/// Number of varbinds yielded so far.
count: usize,
/// Set once the walk has terminated (end of subtree, error, or cap reached).
done: bool,
/// Varbinds from the latest GETBULK response that have not been validated/yielded yet.
buffer: VecDeque<VarBind>,
/// In-flight GETBULK request, if one has been started and not yet completed.
pending: Option<Pin<Box<dyn std::future::Future<Output = Result<Vec<VarBind>>> + Send>>>,
}
impl<T: Transport> BulkWalk<T> {
    /// Builds a GETBULK walk rooted at `oid`.
    ///
    /// Each request asks for up to `max_repetitions` successors of the current position, which
    /// is initially `oid`, with zero non-repeaters. `ordering` is enforced on every returned
    /// OID. When `max_results` is `Some(n)`, the walk stops after `n` varbinds.
    pub(crate) fn new(
        client: Client<T>,
        oid: Oid,
        max_repetitions: i32,
        ordering: OidOrdering,
        max_results: Option<usize>,
    ) -> Self {
        let oid_tracker = OidTracker::new(ordering);
        let base_oid = oid.clone();
        Self {
            client,
            base_oid,
            current_oid: oid,
            max_repetitions,
            oid_tracker,
            max_results,
            count: 0,
            done: false,
            buffer: VecDeque::new(),
            pending: None,
        }
    }
}
impl_stream_helpers!(BulkWalk<T>);
impl<T: Transport + 'static> Stream for BulkWalk<T> {
    type Item = Result<VarBind>;

    /// Advances the GETBULK walk by at most one varbind.
    ///
    /// Buffered varbinds from the last response are validated and yielded first. When the
    /// buffer is empty, a GETBULK for `current_oid` is started or resumed.
    ///
    /// The walk ends (returns `None`) in these cases:
    /// - the result cap is reached;
    /// - a varbind is an exception value;
    /// - a varbind lies outside the base subtree;
    /// - a response carries no varbinds.
    ///
    /// An ordering violation or a request error is yielded once as `Err`, after which the
    /// stream is finished.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `Self: Unpin` (the request future is stored pre-pinned in a `Box`). Working through a
        // plain `&mut Self` lets `base_oid` and `oid_tracker` be borrowed as disjoint fields,
        // instead of cloning `base_oid` for every varbind.
        let this = self.get_mut();
        loop {
            if this.done {
                return Poll::Ready(None);
            }
            if let Some(max) = this.max_results
                && this.count >= max
            {
                this.done = true;
                return Poll::Ready(None);
            }
            if let Some(vb) = this.buffer.pop_front() {
                let target = this.client.peer_addr();
                match validate_walk_varbind(&vb, &this.base_oid, &mut this.oid_tracker, target) {
                    VarbindOutcome::Yield => {}
                    VarbindOutcome::Done => {
                        this.done = true;
                        return Poll::Ready(None);
                    }
                    VarbindOutcome::Abort(e) => {
                        this.done = true;
                        return Poll::Ready(Some(Err(e)));
                    }
                }
                this.current_oid = vb.oid.clone();
                this.count += 1;
                return Poll::Ready(Some(Ok(vb)));
            }
            if this.pending.is_none() {
                let client = this.client.clone();
                let oid = this.current_oid.clone();
                let max_rep = this.max_repetitions;
                let fut = Box::pin(async move { client.get_bulk(&[oid], 0, max_rep).await });
                this.pending = Some(fut);
            }
            let pending = this
                .pending
                .as_mut()
                .expect("pending GETBULK request was installed above");
            let result = match pending.as_mut().poll(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(result) => result,
            };
            this.pending = None;
            match result {
                // An empty response would otherwise make the loop re-request forever.
                Ok(varbinds) if varbinds.is_empty() => {
                    this.done = true;
                    return Poll::Ready(None);
                }
                // Refill the buffer; the next loop iteration validates its first entry.
                Ok(varbinds) => this.buffer = varbinds.into(),
                Err(e) => {
                    this.done = true;
                    return Poll::Ready(Some(Err(e)));
                }
            }
        }
    }
}
/// Walk stream that dispatches to either a GETNEXT-based [`Walk`] or a GETBULK-based
/// [`BulkWalk`], as chosen when it is constructed.
pub enum WalkStream<T: Transport> {
/// Walk driven by GETNEXT requests.
GetNext(Walk<T>),
/// Walk driven by GETBULK requests.
GetBulk(BulkWalk<T>),
}
impl<T: Transport> WalkStream<T> {
    /// Picks GETNEXT or GETBULK from `walk_mode` and `version`, then builds the matching walk.
    ///
    /// `WalkMode::Auto` uses GETBULK unless `version` is SNMPv1.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when `WalkMode::GetBulk` is requested together with SNMPv1.
    pub(crate) fn new(
        client: Client<T>,
        oid: Oid,
        version: Version,
        walk_mode: WalkMode,
        ordering: OidOrdering,
        max_results: Option<usize>,
        max_repetitions: i32,
    ) -> Result<Self> {
        let is_v1 = version == Version::V1;
        let use_bulk = match walk_mode {
            WalkMode::GetNext => false,
            WalkMode::Auto => !is_v1,
            WalkMode::GetBulk if is_v1 => {
                return Err(Error::Config("GETBULK is not supported in SNMPv1".into()).boxed());
            }
            WalkMode::GetBulk => true,
        };
        let stream = if use_bulk {
            let bulk = BulkWalk::new(client, oid, max_repetitions, ordering, max_results);
            WalkStream::GetBulk(bulk)
        } else {
            WalkStream::GetNext(Walk::new(client, oid, ordering, max_results))
        };
        Ok(stream)
    }
}
impl<T: Transport + 'static> WalkStream<T> {
pub async fn next(&mut self) -> Option<Result<VarBind>> {
std::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
}
pub async fn collect(mut self) -> Result<Vec<VarBind>> {
let mut results = Vec::new();
while let Some(result) = self.next().await {
results.push(result?);
}
if results.is_empty() {
let (client, base_oid) = match &self {
WalkStream::GetNext(w) => (&w.client, &w.base_oid),
WalkStream::GetBulk(bw) => (&bw.client, &bw.base_oid),
};
match client.get(base_oid).await {
Ok(vb)
if !matches!(
vb.value,
Value::NoSuchObject | Value::NoSuchInstance | Value::EndOfMibView
) =>
{
results.push(vb);
}
_ => {}
}
}
Ok(results)
}
}
impl<T: Transport + 'static> Stream for WalkStream<T> {
    type Item = Result<VarBind>;

    /// Forwards polling to whichever walk implementation this stream wraps.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::into_inner(self) {
            WalkStream::GetBulk(inner) => Pin::new(inner).poll_next(cx),
            WalkStream::GetNext(inner) => Pin::new(inner).poll_next(cx),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oid;

    fn target_addr() -> std::net::SocketAddr {
        "127.0.0.1:161".parse().unwrap()
    }

    #[test]
    fn test_walk_terminates_on_no_such_object() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::NoSuchObject);
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_walk_terminates_on_no_such_instance() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::NoSuchInstance);
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_walk_terminates_on_end_of_mib_view() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::EndOfMibView);
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_walk_yields_normal_value() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(42));
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
    }

    #[test]
    fn test_walk_terminates_outside_subtree() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        // A normal value, but in a sibling subtree rather than under `base`.
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 2, 1, 0), Value::Integer(1));
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
    }

    #[test]
    fn test_strict_ordering_aborts_on_decreasing_oid() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let first = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 2, 0), Value::Integer(1));
        let second = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(2));
        assert!(matches!(
            validate_walk_varbind(&first, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        match validate_walk_varbind(&second, &base, &mut tracker, target_addr()) {
            VarbindOutcome::Abort(e) => assert!(matches!(
                *e,
                Error::WalkAborted {
                    reason: WalkAbortReason::NonIncreasing,
                    ..
                }
            )),
            _ => panic!("expected a non-increasing abort"),
        }
    }

    #[test]
    fn test_strict_ordering_aborts_on_repeated_oid() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let vb = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(1));
        assert!(matches!(
            validate_walk_varbind(&vb, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        match validate_walk_varbind(&vb, &base, &mut tracker, target_addr()) {
            VarbindOutcome::Abort(e) => assert!(matches!(
                *e,
                Error::WalkAborted {
                    reason: WalkAbortReason::NonIncreasing,
                    ..
                }
            )),
            _ => panic!("expected a non-increasing abort for a repeated OID"),
        }
    }

    #[test]
    fn test_relaxed_ordering_allows_decrease_but_aborts_on_cycle() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::AllowNonIncreasing);
        let later = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 2, 0), Value::Integer(1));
        let earlier = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(2));
        assert!(matches!(
            validate_walk_varbind(&later, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        // Out-of-order OIDs are accepted under the relaxed policy.
        assert!(matches!(
            validate_walk_varbind(&earlier, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
        // Seeing an already-recorded OID again is a cycle.
        match validate_walk_varbind(&later, &base, &mut tracker, target_addr()) {
            VarbindOutcome::Abort(e) => assert!(matches!(
                *e,
                Error::WalkAborted {
                    reason: WalkAbortReason::Cycle,
                    ..
                }
            )),
            _ => panic!("expected a cycle abort"),
        }
    }

    #[test]
    fn test_exception_value_does_not_update_tracker() {
        let base = oid!(1, 3, 6, 1, 2, 1, 1);
        let mut tracker = OidTracker::new(OidOrdering::Strict);
        let exception = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::EndOfMibView);
        let normal = VarBind::new(oid!(1, 3, 6, 1, 2, 1, 1, 1, 0), Value::Integer(7));
        assert!(matches!(
            validate_walk_varbind(&exception, &base, &mut tracker, target_addr()),
            VarbindOutcome::Done
        ));
        // The exception varbind was not recorded, so the same OID is still accepted.
        assert!(matches!(
            validate_walk_varbind(&normal, &base, &mut tracker, target_addr()),
            VarbindOutcome::Yield
        ));
    }
}