1mod auth;
4mod builder;
5mod retry;
6mod v3;
7mod walk;
8
9pub use auth::{Auth, CommunityVersion, UsmAuth, UsmBuilder};
10pub use builder::{ClientBuilder, Target};
11pub use retry::{Backoff, Retry, RetryBuilder};
12
13impl Client<UdpHandle> {
15 pub fn builder(target: impl Into<Target>, auth: impl Into<Auth>) -> ClientBuilder {
43 ClientBuilder::new(target, auth)
44 }
45}
46use crate::error::internal::DecodeErrorKind;
47use crate::error::{Error, ErrorStatus, Result};
48use crate::message::{CommunityMessage, Message};
49use crate::oid::Oid;
50use crate::pdu::{GetBulkPdu, Pdu, TrapV1Pdu};
51use crate::transport::Transport;
52use crate::transport::UdpHandle;
53use crate::v3::{EngineCache, EngineState, SaltCounter};
54use crate::value::Value;
55use crate::varbind::VarBind;
56use crate::version::Version;
57use bytes::Bytes;
58use std::net::SocketAddr;
59use std::pin::Pin;
60use std::sync::Arc;
61use std::sync::RwLock;
62use std::time::{Duration, Instant};
63use tokio::sync::Mutex as AsyncMutex;
64use tracing::{Span, instrument};
65
66pub use crate::notification::{DerivedKeys, UsmConfig};
67pub use walk::{BulkWalk, OidOrdering, Walk, WalkMode, WalkStream};
68
69pub(crate) fn pdu_to_snmp_error(pdu: &Pdu, target: SocketAddr) -> Option<Box<Error>> {
78 if !pdu.is_error() {
79 return None;
80 }
81 let status = pdu.error_status_enum();
82 let oid = (pdu.error_index as usize)
83 .checked_sub(1)
84 .and_then(|idx| pdu.varbinds.get(idx))
85 .map(|vb| vb.oid.clone());
86 Some(
87 Error::Snmp {
88 target,
89 status,
90 index: pdu.error_index.max(0) as u32,
91 oid,
92 }
93 .boxed(),
94 )
95}
96
/// Default time [`ClientConfig::timeout`] allows for each request attempt.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Default [`ClientConfig::max_oids_per_request`]: how many OIDs are packed into
/// one GET/GETNEXT PDU, and the largest batch `set_many` will accept.
pub const DEFAULT_MAX_OIDS_PER_REQUEST: usize = 10;

/// Default [`ClientConfig::max_repetitions`] used for GETBULK-based walks.
pub const DEFAULT_MAX_REPETITIONS: u32 = 25;
114
115pub struct Client<T: Transport = UdpHandle> {
119 inner: Arc<ClientInner<T>>,
120}
121
122impl<T: Transport> Clone for Client<T> {
123 fn clone(&self) -> Self {
124 Self {
125 inner: Arc::clone(&self.inner),
126 }
127 }
128}
129
130struct ClientInner<T: Transport> {
131 transport: T,
132 config: ClientConfig,
133 engine_state: RwLock<Option<EngineState>>,
135 derived_keys: RwLock<Option<DerivedKeys>>,
137 salt_counter: SaltCounter,
139 engine_cache: Option<Arc<EngineCache>>,
141 discovery_lock: AsyncMutex<()>,
143 local_engine_start: Instant,
145 local_derived_keys: RwLock<Option<DerivedKeys>>,
147}
148
149#[derive(Clone)]
153pub struct ClientConfig {
154 pub version: Version,
156 pub community: Bytes,
158 pub timeout: Duration,
160 pub retry: Retry,
162 pub max_oids_per_request: usize,
164 pub v3_security: Option<UsmConfig>,
166 pub walk_mode: WalkMode,
168 pub oid_ordering: OidOrdering,
170 pub max_walk_results: Option<usize>,
172 pub max_repetitions: u32,
174 pub local_engine_id: Option<Bytes>,
179 pub local_engine_boots: u32,
181}
182
183impl Default for ClientConfig {
184 fn default() -> Self {
188 Self {
189 version: Version::V2c,
190 community: Bytes::from_static(b"public"),
191 timeout: DEFAULT_TIMEOUT,
192 retry: Retry::default(),
193 max_oids_per_request: DEFAULT_MAX_OIDS_PER_REQUEST,
194 v3_security: None,
195 walk_mode: WalkMode::Auto,
196 oid_ordering: OidOrdering::Strict,
197 max_walk_results: None,
198 max_repetitions: DEFAULT_MAX_REPETITIONS,
199 local_engine_id: None,
200 local_engine_boots: 1,
201 }
202 }
203}
204
205impl<T: Transport> Client<T> {
206 pub fn new(transport: T, config: ClientConfig) -> Self {
213 Self {
214 inner: Arc::new(ClientInner {
215 transport,
216 config,
217 engine_state: RwLock::new(None),
218 derived_keys: RwLock::new(None),
219 salt_counter: SaltCounter::new(),
220 engine_cache: None,
221 discovery_lock: AsyncMutex::new(()),
222 local_engine_start: Instant::now(),
223 local_derived_keys: RwLock::new(None),
224 }),
225 }
226 }
227
228 pub fn with_engine_cache(
230 transport: T,
231 config: ClientConfig,
232 engine_cache: Arc<EngineCache>,
233 ) -> Self {
234 Self {
235 inner: Arc::new(ClientInner {
236 transport,
237 config,
238 engine_state: RwLock::new(None),
239 derived_keys: RwLock::new(None),
240 salt_counter: SaltCounter::new(),
241 engine_cache: Some(engine_cache),
242 discovery_lock: AsyncMutex::new(()),
243 local_engine_start: Instant::now(),
244 local_derived_keys: RwLock::new(None),
245 }),
246 }
247 }
248
249 pub fn peer_addr(&self) -> SocketAddr {
254 self.inner.transport.peer_addr()
255 }
256
257 fn next_request_id(&self) -> i32 {
261 self.inner.transport.alloc_request_id()
262 }
263
264 fn is_v3(&self) -> bool {
266 self.inner.config.version == Version::V3 && self.inner.config.v3_security.is_some()
267 }
268
269 #[instrument(
271 level = "debug",
272 skip(self, data),
273 fields(
274 snmp.target = %self.peer_addr(),
275 snmp.request_id = request_id,
276 snmp.attempt = tracing::field::Empty,
277 snmp.elapsed_ms = tracing::field::Empty,
278 )
279 )]
280 async fn send_and_recv(&self, request_id: i32, data: &[u8]) -> Result<Pdu> {
281 let start = Instant::now();
282 let mut last_error: Option<Box<Error>> = None;
283 let max_attempts = if self.inner.transport.is_reliable() {
284 0
285 } else {
286 self.inner.config.retry.max_attempts
287 };
288
289 for attempt in 0..=max_attempts {
290 Span::current().record("snmp.attempt", attempt);
291 if attempt > 0 {
292 tracing::debug!(target: "async_snmp::client", "retrying request");
293 }
294
295 self.inner
297 .transport
298 .register_request(request_id, self.inner.config.timeout);
299
300 tracing::trace!(target: "async_snmp::client", { snmp.bytes = data.len() }, "sending request");
302 self.inner.transport.send(data).await?;
303
304 match self.inner.transport.recv(request_id).await {
306 Ok((response_data, _source)) => {
307 tracing::trace!(target: "async_snmp::client", { snmp.bytes = response_data.len() }, "received response");
308
309 let response = Message::decode(response_data)?;
311
312 let response_version = response.version();
314 let expected_version = self.inner.config.version;
315 if response_version != expected_version {
316 tracing::warn!(target: "async_snmp::client", { ?expected_version, ?response_version, peer = %self.peer_addr() }, "version mismatch in response");
317 return Err(Error::MalformedResponse {
318 target: self.peer_addr(),
319 }
320 .boxed());
321 }
322
323 let response_pdu = match response.into_pdu() {
324 Some(p) => p,
325 None => {
326 tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr() }, "received TrapV1 in response to request");
327 return Err(Error::MalformedResponse {
328 target: self.peer_addr(),
329 }
330 .boxed());
331 }
332 };
333
334 if response_pdu.request_id != request_id {
336 tracing::warn!(target: "async_snmp::client", { expected_request_id = request_id, actual_request_id = response_pdu.request_id, peer = %self.peer_addr() }, "request ID mismatch in response");
337 return Err(Error::MalformedResponse {
338 target: self.peer_addr(),
339 }
340 .boxed());
341 }
342
343 if let Some(err) = pdu_to_snmp_error(&response_pdu, self.peer_addr()) {
345 Span::current()
346 .record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
347 return Err(err);
348 }
349
350 Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
351 return Ok(response_pdu);
352 }
353 Err(e) if matches!(*e, Error::Timeout { .. }) => {
354 last_error = Some(e);
355 if attempt < max_attempts {
357 let delay = self.inner.config.retry.compute_delay(attempt);
358 if !delay.is_zero() {
359 tracing::debug!(target: "async_snmp::client", { delay_ms = delay.as_millis() as u64 }, "backing off");
360 tokio::time::sleep(delay).await;
361 }
362 }
363 continue;
364 }
365 Err(e) => {
366 Span::current().record("snmp.elapsed_ms", start.elapsed().as_millis() as u64);
367 return Err(e);
368 }
369 }
370 }
371
372 let elapsed = start.elapsed();
374 Span::current().record("snmp.elapsed_ms", elapsed.as_millis() as u64);
375 tracing::debug!(target: "async_snmp::client", { request_id, peer = %self.peer_addr(), ?elapsed, retries = max_attempts }, "request timed out");
376 Err(last_error.unwrap_or_else(|| {
377 Error::Timeout {
378 target: self.peer_addr(),
379 elapsed,
380 retries: max_attempts,
381 }
382 .boxed()
383 }))
384 }
385
386 async fn send_request(&self, pdu: Pdu) -> Result<Pdu> {
388 if self.is_v3() {
390 return self.send_v3_and_recv(pdu).await;
391 }
392
393 tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = ?pdu.pdu_type, snmp.varbind_count = pdu.varbinds.len() }, "sending {} request", pdu.pdu_type);
394
395 let request_id = pdu.request_id;
396 let message = CommunityMessage::new(
397 self.inner.config.version,
398 self.inner.config.community.clone(),
399 pdu,
400 );
401 let data = message.encode();
402 let response = self.send_and_recv(request_id, &data).await?;
403
404 tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = ?response.pdu_type, snmp.varbind_count = response.varbinds.len(), snmp.error_status = response.error_status, snmp.error_index = response.error_index }, "received {} response", response.pdu_type);
405
406 Ok(response)
407 }
408
409 async fn send_bulk_request(&self, pdu: GetBulkPdu) -> Result<Pdu> {
411 if self.is_v3() {
413 let pdu = Pdu::get_bulk(
415 pdu.request_id,
416 pdu.non_repeaters,
417 pdu.max_repetitions,
418 pdu.varbinds,
419 );
420 return self.send_v3_and_recv(pdu).await;
421 }
422
423 tracing::debug!(target: "async_snmp::client", { snmp.non_repeaters = pdu.non_repeaters, snmp.max_repetitions = pdu.max_repetitions, snmp.varbind_count = pdu.varbinds.len() }, "sending GetBulkRequest");
424
425 let request_id = pdu.request_id;
426 let data = CommunityMessage::encode_bulk(
427 self.inner.config.version,
428 self.inner.config.community.clone(),
429 &pdu,
430 );
431 let response = self.send_and_recv(request_id, &data).await?;
432
433 tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = ?response.pdu_type, snmp.varbind_count = response.varbinds.len(), snmp.error_status = response.error_status, snmp.error_index = response.error_index }, "received {} response", response.pdu_type);
434
435 Ok(response)
436 }
437
438 #[instrument(skip(self), err, fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
440 pub async fn get(&self, oid: &Oid) -> Result<VarBind> {
441 let request_id = self.next_request_id();
442 let pdu = Pdu::get_request(request_id, std::slice::from_ref(oid));
443 let response = self.send_request(pdu).await?;
444
445 response.varbinds.into_iter().next().ok_or_else(|| {
446 tracing::debug!(target: "async_snmp::client", { peer = %self.peer_addr(), kind = %DecodeErrorKind::EmptyResponse }, "empty GET response");
447 Error::MalformedResponse {
448 target: self.peer_addr(),
449 }
450 .boxed()
451 })
452 }
453
454 #[instrument(skip(self, oids), err, fields(snmp.target = %self.peer_addr(), snmp.oid_count = oids.len()))]
475 pub async fn get_many(&self, oids: &[Oid]) -> Result<Vec<VarBind>> {
476 self.get_or_getnext_many(oids, "GET", Pdu::get_request)
477 .await
478 }
479
480 #[instrument(skip(self), err, fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
482 pub async fn get_next(&self, oid: &Oid) -> Result<VarBind> {
483 let request_id = self.next_request_id();
484 let pdu = Pdu::get_next_request(request_id, std::slice::from_ref(oid));
485 let response = self.send_request(pdu).await?;
486
487 response.varbinds.into_iter().next().ok_or_else(|| {
488 tracing::debug!(target: "async_snmp::client", { peer = %self.peer_addr(), kind = %DecodeErrorKind::EmptyResponse }, "empty GETNEXT response");
489 Error::MalformedResponse {
490 target: self.peer_addr(),
491 }
492 .boxed()
493 })
494 }
495
496 #[instrument(skip(self, oids), err, fields(snmp.target = %self.peer_addr(), snmp.oid_count = oids.len()))]
516 pub async fn get_next_many(&self, oids: &[Oid]) -> Result<Vec<VarBind>> {
517 self.get_or_getnext_many(oids, "GETNEXT", Pdu::get_next_request)
518 .await
519 }
520
521 async fn get_or_getnext_many(
526 &self,
527 oids: &[Oid],
528 op_name: &'static str,
529 op: fn(i32, &[Oid]) -> Pdu,
530 ) -> Result<Vec<VarBind>> {
531 if oids.is_empty() {
532 return Ok(Vec::new());
533 }
534
535 let max_per_request = self.inner.config.max_oids_per_request;
536 let mut all_results = Vec::with_capacity(oids.len());
537
538 for chunk in oids.chunks(max_per_request) {
539 self.send_batch_with_bisect(chunk, op_name, op, &mut all_results)
540 .await?;
541 }
542
543 Ok(all_results)
544 }
545
546 fn send_batch_with_bisect<'a>(
552 &'a self,
553 oids: &'a [Oid],
554 op_name: &'static str,
555 op: fn(i32, &[Oid]) -> Pdu,
556 results: &'a mut Vec<VarBind>,
557 ) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> {
558 Box::pin(async move {
559 let request_id = self.next_request_id();
560 let pdu = op(request_id, oids);
561 match self.send_request(pdu).await {
562 Ok(response) => {
563 if response.varbinds.len() > oids.len() {
564 tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = oids.len(), actual = response.varbinds.len(), snmp.op = op_name }, "response has more varbinds than requested");
565 return Err(Error::MalformedResponse {
566 target: self.peer_addr(),
567 }
568 .boxed());
569 } else if response.varbinds.len() < oids.len() {
570 tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = oids.len(), actual = response.varbinds.len(), snmp.op = op_name }, "response has fewer varbinds than requested");
571 }
572 results.extend(response.varbinds);
573 Ok(())
574 }
575 Err(e)
576 if oids.len() > 1
577 && matches!(
578 &*e,
579 Error::Snmp {
580 status: ErrorStatus::TooBig,
581 ..
582 }
583 ) =>
584 {
585 let mid = oids.len() / 2;
586 tracing::debug!(target: "async_snmp::client", { peer = %self.peer_addr(), snmp.batch_size = oids.len(), snmp.split_at = mid, snmp.op = op_name }, "tooBig response, bisecting batch");
587 self.send_batch_with_bisect(&oids[..mid], op_name, op, results)
588 .await?;
589 self.send_batch_with_bisect(&oids[mid..], op_name, op, results)
590 .await?;
591 Ok(())
592 }
593 Err(e) => Err(e),
594 }
595 })
596 }
597
598 #[instrument(skip(self, value), err, fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
600 pub async fn set(&self, oid: &Oid, value: Value) -> Result<VarBind> {
601 let request_id = self.next_request_id();
602 let varbind = VarBind::new(oid.clone(), value);
603 let pdu = Pdu::set_request(request_id, vec![varbind]);
604 let response = self.send_request(pdu).await?;
605
606 response.varbinds.into_iter().next().ok_or_else(|| {
607 tracing::debug!(target: "async_snmp::client", { peer = %self.peer_addr(), kind = %DecodeErrorKind::EmptyResponse }, "empty SET response");
608 Error::MalformedResponse {
609 target: self.peer_addr(),
610 }
611 .boxed()
612 })
613 }
614
615 #[instrument(skip(self, varbinds), err, fields(snmp.target = %self.peer_addr(), snmp.oid_count = varbinds.len()))]
641 pub async fn set_many(&self, varbinds: &[(Oid, Value)]) -> Result<Vec<VarBind>> {
642 if varbinds.is_empty() {
643 return Ok(Vec::new());
644 }
645
646 let max_per_request = self.inner.config.max_oids_per_request;
647
648 if varbinds.len() > max_per_request {
649 return Err(Error::Config(
650 format!(
651 "set_many: {} varbinds exceeds max_oids_per_request ({}); \
652 SET must be atomic and cannot be split across PDUs",
653 varbinds.len(),
654 max_per_request,
655 )
656 .into(),
657 )
658 .boxed());
659 }
660
661 let request_id = self.next_request_id();
662 let vbs: Vec<VarBind> = varbinds
663 .iter()
664 .map(|(oid, value)| VarBind::new(oid.clone(), value.clone()))
665 .collect();
666 let expected_count = vbs.len();
667 let pdu = Pdu::set_request(request_id, vbs);
668 let response = self.send_request(pdu).await?;
669 if response.varbinds.len() > expected_count {
670 tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = expected_count, actual = response.varbinds.len() }, "SET response has more varbinds than requested");
671 return Err(Error::MalformedResponse {
672 target: self.peer_addr(),
673 }
674 .boxed());
675 } else if response.varbinds.len() < expected_count {
676 tracing::warn!(target: "async_snmp::client", { peer = %self.peer_addr(), expected = expected_count, actual = response.varbinds.len() }, "SET response has fewer varbinds than requested");
677 }
678 Ok(response.varbinds)
679 }
680
681 #[instrument(skip(self, varbinds), err, fields(snmp.target = %self.peer_addr(), snmp.trap_oid = %trap_oid))]
700 pub async fn send_trap(
701 &self,
702 trap_oid: &Oid,
703 uptime: u32,
704 varbinds: Vec<VarBind>,
705 ) -> Result<()> {
706 if self.inner.config.version == Version::V1 {
707 let local_ip = match self.inner.transport.local_addr().ip() {
710 std::net::IpAddr::V4(v4) => v4.octets(),
711 std::net::IpAddr::V6(_) => [0, 0, 0, 0],
712 };
713 let pdu = Pdu::trap_v2(0, uptime, trap_oid, varbinds);
716 let trap = pdu.to_v1_trap(local_ip).ok_or_else(|| {
717 Error::Config("cannot convert trap to v1 (Counter64 varbind?)".into()).boxed()
718 })?;
719 return self.send_v1_trap(trap).await;
720 }
721
722 let request_id = self.next_request_id();
723 let pdu = Pdu::trap_v2(request_id, uptime, trap_oid, varbinds);
724
725 if self.is_v3() {
726 self.ensure_local_keys_derived()?;
727 let msg_id = self.next_request_id();
728 let data = self.build_v3_trap_message(&pdu, msg_id)?;
729 tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = "TrapV2", snmp.varbind_count = pdu.varbinds.len(), snmp.bytes = data.len() }, "sending V3 trap");
730 self.inner.transport.send(&data).await?;
731 } else {
732 let message = CommunityMessage::new(
733 self.inner.config.version,
734 self.inner.config.community.clone(),
735 pdu,
736 );
737 let data = message.encode();
738 tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = "TrapV2", snmp.bytes = data.len() }, "sending v2c trap");
739 self.inner.transport.send(&data).await?;
740 }
741
742 Ok(())
743 }
744
745 #[instrument(skip(self, trap), err, fields(snmp.target = %self.peer_addr(), snmp.generic_trap = %trap.generic_trap))]
775 pub async fn send_v1_trap(&self, trap: TrapV1Pdu) -> Result<()> {
776 if self.inner.config.version != Version::V1 {
777 return Err(Error::Config("send_v1_trap requires a V1 client".into()).boxed());
778 }
779
780 let message = CommunityMessage::v1_trap(self.inner.config.community.clone(), trap);
781 let data = message.encode();
782 tracing::debug!(target: "async_snmp::client", { snmp.pdu_type = "TrapV1", snmp.bytes = data.len() }, "sending v1 trap");
783 self.inner.transport.send(&data).await?;
784
785 Ok(())
786 }
787
788 #[instrument(skip(self, varbinds), err, fields(snmp.target = %self.peer_addr(), snmp.trap_oid = %trap_oid))]
803 pub async fn send_inform(
804 &self,
805 trap_oid: &Oid,
806 uptime: u32,
807 varbinds: Vec<VarBind>,
808 ) -> Result<()> {
809 if self.inner.config.version == Version::V1 {
810 return Err(Error::Config("v1 inform sending not supported".into()).boxed());
811 }
812
813 let request_id = self.next_request_id();
814 let pdu = Pdu::inform_request(request_id, uptime, trap_oid, varbinds);
815 let _response = self.send_request(pdu).await?;
816 Ok(())
817 }
818
819 #[instrument(skip(self, oids), err, fields(
852 snmp.target = %self.peer_addr(),
853 snmp.oid_count = oids.len(),
854 snmp.non_repeaters = non_repeaters,
855 snmp.max_repetitions = max_repetitions
856 ))]
857 pub async fn get_bulk(
858 &self,
859 oids: &[Oid],
860 non_repeaters: i32,
861 max_repetitions: i32,
862 ) -> Result<Vec<VarBind>> {
863 let request_id = self.next_request_id();
864 let pdu = GetBulkPdu::new(request_id, non_repeaters, max_repetitions, oids);
865 let response = self.send_bulk_request(pdu).await?;
866 Ok(response.varbinds)
867 }
868
869 #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
895 pub fn walk(&self, oid: Oid) -> Result<WalkStream<T>>
896 where
897 T: 'static,
898 {
899 let ordering = self.inner.config.oid_ordering;
900 let max_results = self.inner.config.max_walk_results;
901 let walk_mode = self.inner.config.walk_mode;
902 let max_repetitions = self.inner.config.max_repetitions as i32;
903 let version = self.inner.config.version;
904
905 WalkStream::new(
906 self.clone(),
907 oid,
908 version,
909 walk_mode,
910 ordering,
911 max_results,
912 max_repetitions,
913 )
914 }
915
916 #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
939 pub fn walk_getnext(&self, oid: Oid) -> Walk<T>
940 where
941 T: 'static,
942 {
943 let ordering = self.inner.config.oid_ordering;
944 let max_results = self.inner.config.max_walk_results;
945 Walk::new(self.clone(), oid, ordering, max_results)
946 }
947
948 #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid, snmp.max_repetitions = max_repetitions))]
974 pub fn bulk_walk(&self, oid: Oid, max_repetitions: i32) -> BulkWalk<T>
975 where
976 T: 'static,
977 {
978 let ordering = self.inner.config.oid_ordering;
979 let max_results = self.inner.config.max_walk_results;
980 BulkWalk::new(self.clone(), oid, max_repetitions, ordering, max_results)
981 }
982
983 #[instrument(skip(self), fields(snmp.target = %self.peer_addr(), snmp.oid = %oid))]
1001 pub fn bulk_walk_default(&self, oid: Oid) -> BulkWalk<T>
1002 where
1003 T: 'static,
1004 {
1005 let ordering = self.inner.config.oid_ordering;
1006 let max_results = self.inner.config.max_walk_results;
1007 let max_repetitions = self.inner.config.max_repetitions as i32;
1008 BulkWalk::new(self.clone(), oid, max_repetitions, ordering, max_results)
1009 }
1010}
1011
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::CommunityMessage;
    use crate::oid::Oid;
    use crate::pdu::{Pdu, PduType};
    use crate::varbind::VarBind;
    use crate::version::Version;
    use bytes::Bytes;
    use std::collections::VecDeque;
    use std::net::SocketAddr;
    use std::sync::{Arc, Mutex};

    // ----- shared fixtures -------------------------------------------------

    /// Loopback agent address reported by both mock transports.
    fn mock_peer() -> SocketAddr {
        "127.0.0.1:161".parse().unwrap()
    }

    /// Encodes a v2c Response PDU (community `public`, error index 0).
    fn encode_response(request_id: i32, error_status: i32, varbinds: Vec<VarBind>) -> Bytes {
        let pdu = Pdu {
            pdu_type: PduType::Response,
            request_id,
            error_status,
            error_index: 0,
            varbinds,
        };
        CommunityMessage::v2c(Bytes::from_static(b"public"), pdu).encode()
    }

    /// v2c configuration shared by every test: 10 OIDs per request, no retries.
    fn test_config() -> ClientConfig {
        ClientConfig {
            version: Version::V2c,
            max_oids_per_request: 10,
            retry: crate::client::retry::Retry::none(),
            ..Default::default()
        }
    }

    /// OIDs `1.3.6.1.i` for each `i` in `suffixes`.
    fn oids_with_suffixes(suffixes: std::ops::Range<u32>) -> Vec<Oid> {
        suffixes.map(|i| Oid::from_slice(&[1, 3, 6, 1, i])).collect()
    }

    /// `1.3.6.1.1`, `1.3.6.1.2`, `1.3.6.1.3`.
    fn three_oids() -> Vec<Oid> {
        oids_with_suffixes(1..4)
    }

    /// `(1.3.6.1.k, Integer(k))` for k = 1, 2, 3.
    fn three_int_varbinds() -> Vec<(Oid, crate::value::Value)> {
        (1..4u32)
            .map(|k| {
                (
                    Oid::from_slice(&[1, 3, 6, 1, k]),
                    crate::value::Value::Integer(k as i32),
                )
            })
            .collect()
    }

    /// Asserts `result` is `Ok` and holds exactly `expected` varbinds.
    fn assert_ok_len(result: Result<Vec<VarBind>>, expected: usize) {
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
        assert_eq!(result.unwrap().len(), expected);
    }

    /// Asserts `err` is a `MalformedResponse`.
    fn assert_malformed(err: &Error) {
        assert!(
            matches!(err, Error::MalformedResponse { .. }),
            "expected MalformedResponse, got: {err}"
        );
    }

    // ----- transport that returns a fixed number of varbinds ---------------

    /// Answers every request with exactly `response_varbind_count` Null varbinds,
    /// regardless of how many were requested.
    #[derive(Clone)]
    struct TruncatingTransport {
        response_varbind_count: usize,
        /// Request IDs of sent-but-unanswered requests, in send order.
        pending: Arc<Mutex<VecDeque<i32>>>,
    }

    impl TruncatingTransport {
        fn new(response_varbind_count: usize) -> Self {
            Self {
                response_varbind_count,
                pending: Arc::default(),
            }
        }
    }

    impl Transport for TruncatingTransport {
        fn send(&self, data: &[u8]) -> impl std::future::Future<Output = Result<()>> + Send {
            // Remember the request ID so `recv` can echo it back.
            let request_id = crate::transport::extract_request_id(data).unwrap_or(1);
            self.pending.lock().unwrap().push_back(request_id);
            async { Ok(()) }
        }

        fn recv(
            &self,
            _request_id: i32,
        ) -> impl std::future::Future<Output = Result<(Bytes, SocketAddr)>> + Send {
            let request_id = self.pending.lock().unwrap().pop_front().unwrap_or(1);
            let count = self.response_varbind_count;

            async move {
                let varbinds = (0..count)
                    .map(|i| {
                        VarBind::new(
                            Oid::from_slice(&[1, 3, 6, 1, i as u32]),
                            crate::value::Value::Null,
                        )
                    })
                    .collect();
                Ok((encode_response(request_id, 0, varbinds), mock_peer()))
            }
        }

        fn peer_addr(&self) -> SocketAddr {
            mock_peer()
        }

        fn local_addr(&self) -> SocketAddr {
            "127.0.0.1:0".parse().unwrap()
        }

        fn is_reliable(&self) -> bool {
            true
        }
    }

    fn make_client(response_varbind_count: usize) -> Client<TruncatingTransport> {
        Client::new(TruncatingTransport::new(response_varbind_count), test_config())
    }

    #[tokio::test]
    async fn get_many_warns_on_truncated_response() {
        let client = make_client(1);
        assert_ok_len(client.get_many(&three_oids()).await, 1);
    }

    #[tokio::test]
    async fn get_many_rejects_inflated_response() {
        let client = make_client(5);
        assert_malformed(&client.get_many(&three_oids()).await.unwrap_err());
    }

    #[tokio::test]
    async fn get_many_accepts_correct_response_count() {
        let client = make_client(3);
        assert_ok_len(client.get_many(&three_oids()).await, 3);
    }

    #[tokio::test]
    async fn get_next_many_warns_on_truncated_response() {
        let client = make_client(1);
        assert_ok_len(client.get_next_many(&three_oids()).await, 1);
    }

    #[tokio::test]
    async fn get_next_many_rejects_inflated_response() {
        let client = make_client(5);
        assert_malformed(&client.get_next_many(&three_oids()).await.unwrap_err());
    }

    #[tokio::test]
    async fn get_next_many_accepts_correct_response_count() {
        let client = make_client(3);
        assert_ok_len(client.get_next_many(&three_oids()).await, 3);
    }

    #[tokio::test]
    async fn set_many_warns_on_truncated_response() {
        let client = make_client(1);
        assert_ok_len(client.set_many(&three_int_varbinds()).await, 1);
    }

    #[tokio::test]
    async fn set_many_rejects_inflated_response() {
        let client = make_client(5);
        assert_malformed(&client.set_many(&three_int_varbinds()).await.unwrap_err());
    }

    #[tokio::test]
    async fn set_many_accepts_correct_response_count() {
        let client = make_client(3);
        assert_ok_len(client.set_many(&three_int_varbinds()).await, 3);
    }

    // ----- transport that answers tooBig above a size threshold ------------

    /// Replies `tooBig` to any request carrying more than `max_varbinds` varbinds
    /// and otherwise echoes one Integer varbind per requested OID.
    #[derive(Clone)]
    struct TooBigTransport {
        max_varbinds: usize,
        /// `(request_id, varbind_count)` of each sent-but-unanswered request.
        pending: Arc<Mutex<VecDeque<(i32, usize)>>>,
    }

    impl TooBigTransport {
        fn new(max_varbinds: usize) -> Self {
            Self {
                max_varbinds,
                pending: Arc::default(),
            }
        }
    }

    impl Transport for TooBigTransport {
        fn send(&self, data: &[u8]) -> impl std::future::Future<Output = Result<()>> + Send {
            let request_id = crate::transport::extract_request_id(data).unwrap_or(1);
            // Decode the request to learn how many varbinds it carries.
            let msg = CommunityMessage::decode(Bytes::copy_from_slice(data)).unwrap();
            let varbind_count = msg.pdu.standard().unwrap().varbinds.len();
            self.pending
                .lock()
                .unwrap()
                .push_back((request_id, varbind_count));
            async { Ok(()) }
        }

        fn recv(
            &self,
            _request_id: i32,
        ) -> impl std::future::Future<Output = Result<(Bytes, SocketAddr)>> + Send {
            let (request_id, varbind_count) =
                self.pending.lock().unwrap().pop_front().unwrap_or((1, 0));
            let max = self.max_varbinds;

            async move {
                let encoded = if varbind_count > max {
                    // Too big: error status only, no varbinds.
                    encode_response(request_id, ErrorStatus::TooBig.as_i32(), Vec::new())
                } else {
                    let varbinds = (0..varbind_count)
                        .map(|i| {
                            VarBind::new(
                                Oid::from_slice(&[1, 3, 6, 1, i as u32]),
                                crate::value::Value::Integer(i as i32),
                            )
                        })
                        .collect();
                    encode_response(request_id, 0, varbinds)
                };
                Ok((encoded, mock_peer()))
            }
        }

        fn peer_addr(&self) -> SocketAddr {
            mock_peer()
        }

        fn local_addr(&self) -> SocketAddr {
            "127.0.0.1:0".parse().unwrap()
        }

        fn is_reliable(&self) -> bool {
            true
        }
    }

    fn too_big_client(max_varbinds: usize) -> Client<TooBigTransport> {
        Client::new(TooBigTransport::new(max_varbinds), test_config())
    }

    #[tokio::test]
    async fn get_many_bisects_on_too_big() {
        // 8 OIDs with a limit of 3: 8 -> 4+4 -> 2+2+2+2, all of which succeed.
        let client = too_big_client(3);
        let result = client.get_many(&oids_with_suffixes(0..8)).await.unwrap();
        assert_eq!(result.len(), 8);
    }

    #[tokio::test]
    async fn get_many_single_oid_too_big_is_unrecoverable() {
        let client = too_big_client(0);
        let err = client
            .get_many(&oids_with_suffixes(1..2))
            .await
            .unwrap_err();
        assert!(
            matches!(
                &*err,
                Error::Snmp {
                    status: ErrorStatus::TooBig,
                    ..
                }
            ),
            "expected TooBig, got: {err}"
        );
    }

    #[tokio::test]
    async fn get_next_many_bisects_on_too_big() {
        let client = too_big_client(3);
        let result = client
            .get_next_many(&oids_with_suffixes(0..8))
            .await
            .unwrap();
        assert_eq!(result.len(), 8);
    }

    #[tokio::test]
    async fn get_many_batched_warns_on_truncated_response() {
        // 12 OIDs split into batches of 10 and 2; each reply carries 1 varbind.
        let client = make_client(1);
        assert_ok_len(client.get_many(&oids_with_suffixes(0..12)).await, 2);
    }

    #[tokio::test]
    async fn get_many_batched_rejects_inflated_response() {
        // The first batch of 10 receives 12 varbinds and must be rejected.
        let client = make_client(12);
        assert_malformed(
            &client
                .get_many(&oids_with_suffixes(0..12))
                .await
                .unwrap_err(),
        );
    }
}