1mod request;
69mod response;
70mod set_handler;
71pub mod vacm;
72
73pub use vacm::{SecurityModel, VacmBuilder, VacmConfig, View};
74
75use std::collections::HashMap;
76use std::net::SocketAddr;
77use std::sync::Arc;
78use std::sync::atomic::{AtomicU32, Ordering};
79use std::time::Instant;
80
81use bytes::Bytes;
82use subtle::ConstantTimeEq;
83use tokio::net::UdpSocket;
84use tracing::instrument;
85
86use crate::ber::Decoder;
87use crate::error::{DecodeErrorKind, Error, ErrorStatus, Result};
88use crate::handler::{GetNextResult, GetResult, MibHandler, RequestContext};
89use crate::notification::UsmUserConfig;
90use crate::oid::Oid;
91use crate::pdu::{Pdu, PduType};
92use crate::util::bind_udp_socket;
93use crate::v3::SaltCounter;
94use crate::value::Value;
95use crate::varbind::VarBind;
96use crate::version::Version;
97
/// Default upper bound, in bytes, on the size of generated responses; used to
/// budget GetBulk responses unless overridden with `AgentBuilder::max_message_size`.
///
/// NOTE(review): 1472 = 1500-byte Ethernet MTU - 20-byte IPv4 header - 8-byte UDP
/// header; presumably chosen to avoid IP fragmentation — confirm.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 1472;

/// Bytes reserved for message/PDU framing when sizing GetBulk responses: the
/// varbind budget starts counting from this value rather than from zero.
const RESPONSE_OVERHEAD: usize = 100;
104
/// A MIB handler registered for an OID subtree.
///
/// `AgentBuilder::build` sorts registered handlers by descending prefix length,
/// so lookups via `Agent::find_handler` try the most specific prefix first.
pub(crate) struct RegisteredHandler {
    /// OID prefix the handler was registered under.
    pub(crate) prefix: Oid,
    /// The handler implementation, shared behind an `Arc`.
    pub(crate) handler: Arc<dyn MibHandler>,
}
110
/// Builder for an SNMP [`Agent`].
///
/// Create one with `AgentBuilder::new()` (or `Agent::builder()`), chain the
/// configuration methods, then call `build().await` to bind the socket.
pub struct AgentBuilder {
    // Address string; parsed as a `SocketAddr` in `build` (no DNS resolution).
    bind_addr: String,
    // Accepted community strings. When empty, `Agent::validate_community`
    // rejects every community.
    communities: Vec<Vec<u8>>,
    // USM user configurations keyed by user name.
    usm_users: HashMap<Bytes, UsmUserConfig>,
    // Handlers in registration order; sorted by prefix length in `build`.
    handlers: Vec<RegisteredHandler>,
    // Explicit engine ID; `None` means one is generated in `build`.
    engine_id: Option<Vec<u8>>,
    // Size budget (bytes) applied to GetBulk responses.
    max_message_size: usize,
    // Optional VACM configuration; `None` skips view checks in Get/GetNext.
    vacm: Option<VacmConfig>,
}
121
122impl AgentBuilder {
123 pub fn new() -> Self {
125 Self {
126 bind_addr: "0.0.0.0:161".to_string(),
127 communities: Vec::new(),
128 usm_users: HashMap::new(),
129 handlers: Vec::new(),
130 engine_id: None,
131 max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
132 vacm: None,
133 }
134 }
135
136 pub fn bind(mut self, addr: impl Into<String>) -> Self {
140 self.bind_addr = addr.into();
141 self
142 }
143
144 pub fn community(mut self, community: &[u8]) -> Self {
149 self.communities.push(community.to_vec());
150 self
151 }
152
153 pub fn communities<I, C>(mut self, communities: I) -> Self
155 where
156 I: IntoIterator<Item = C>,
157 C: AsRef<[u8]>,
158 {
159 for c in communities {
160 self.communities.push(c.as_ref().to_vec());
161 }
162 self
163 }
164
165 pub fn usm_user<F>(mut self, username: impl Into<Bytes>, configure: F) -> Self
167 where
168 F: FnOnce(UsmUserConfig) -> UsmUserConfig,
169 {
170 let username_bytes: Bytes = username.into();
171 let config = configure(UsmUserConfig::new(username_bytes.clone()));
172 self.usm_users.insert(username_bytes, config);
173 self
174 }
175
176 pub fn engine_id(mut self, engine_id: impl Into<Vec<u8>>) -> Self {
180 self.engine_id = Some(engine_id.into());
181 self
182 }
183
184 pub fn max_message_size(mut self, size: usize) -> Self {
192 self.max_message_size = size;
193 self
194 }
195
196 pub fn handler(mut self, prefix: Oid, handler: Arc<dyn MibHandler>) -> Self {
201 self.handlers.push(RegisteredHandler { prefix, handler });
202 self
203 }
204
205 pub fn vacm<F>(mut self, configure: F) -> Self
241 where
242 F: FnOnce(VacmBuilder) -> VacmBuilder,
243 {
244 let builder = VacmBuilder::new();
245 self.vacm = Some(configure(builder).build());
246 self
247 }
248
249 pub async fn build(mut self) -> Result<Agent> {
253 let bind_addr: std::net::SocketAddr = self.bind_addr.parse().map_err(|_| Error::Io {
254 target: None,
255 source: std::io::Error::new(
256 std::io::ErrorKind::InvalidInput,
257 format!("invalid bind address: {}", self.bind_addr),
258 ),
259 })?;
260
261 let socket = bind_udp_socket(bind_addr).await.map_err(|e| Error::Io {
262 target: Some(bind_addr),
263 source: e,
264 })?;
265
266 let local_addr = socket.local_addr().map_err(|e| Error::Io {
267 target: Some(bind_addr),
268 source: e,
269 })?;
270
271 let engine_id = self.engine_id.unwrap_or_else(|| {
273 let mut id = vec![0x80, 0x00, 0x00, 0x00, 0x01]; let timestamp = std::time::SystemTime::now()
277 .duration_since(std::time::UNIX_EPOCH)
278 .unwrap_or_default()
279 .as_secs();
280 id.extend_from_slice(×tamp.to_be_bytes());
281 id
282 });
283
284 self.handlers
286 .sort_by(|a, b| b.prefix.len().cmp(&a.prefix.len()));
287
288 Ok(Agent {
289 inner: Arc::new(AgentInner {
290 socket,
291 local_addr,
292 communities: self.communities,
293 usm_users: self.usm_users,
294 handlers: self.handlers,
295 engine_id,
296 engine_boots: AtomicU32::new(1),
297 engine_time: AtomicU32::new(0),
298 engine_start: Instant::now(),
299 salt_counter: SaltCounter::new(),
300 max_message_size: self.max_message_size,
301 vacm: self.vacm,
302 snmp_invalid_msgs: AtomicU32::new(0),
303 snmp_unknown_security_models: AtomicU32::new(0),
304 snmp_silent_drops: AtomicU32::new(0),
305 }),
306 })
307 }
308}
309
310impl Default for AgentBuilder {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
/// Shared state of a running agent; every [`Agent`] clone points at the same
/// instance through an `Arc`.
pub(crate) struct AgentInner {
    /// UDP socket that requests are received on and responses are sent from.
    pub(crate) socket: UdpSocket,
    /// Address the socket reported after binding.
    pub(crate) local_addr: SocketAddr,
    /// Accepted community strings, as configured on the builder.
    pub(crate) communities: Vec<Vec<u8>>,
    /// USM user configurations keyed by user name.
    pub(crate) usm_users: HashMap<Bytes, UsmUserConfig>,
    /// Registered handlers, sorted by descending prefix length at build time.
    pub(crate) handlers: Vec<RegisteredHandler>,
    /// Engine ID (configured, or generated at build time).
    pub(crate) engine_id: Vec<u8>,
    /// Engine boot counter; initialized to 1 at build time.
    pub(crate) engine_boots: AtomicU32,
    /// Whole seconds since `engine_start`, refreshed on each received datagram.
    pub(crate) engine_time: AtomicU32,
    /// Instant captured when the agent was built.
    pub(crate) engine_start: Instant,
    /// NOTE(review): presumably supplies salts for SNMPv3 privacy; not used in
    /// the code visible in this file.
    pub(crate) salt_counter: SaltCounter,
    /// Size budget (bytes) applied to GetBulk responses.
    pub(crate) max_message_size: usize,
    /// Optional VACM configuration; `None` skips view checks.
    pub(crate) vacm: Option<VacmConfig>,
    /// Invalid-message counter. NOTE(review): presumably RFC 3412
    /// `snmpInvalidMsgs`; incremented outside the code shown here.
    pub(crate) snmp_invalid_msgs: AtomicU32,
    /// Unknown-security-model counter. NOTE(review): presumably RFC 3412
    /// `snmpUnknownSecurityModels`; incremented outside the code shown here.
    pub(crate) snmp_unknown_security_models: AtomicU32,
    /// Silent-drop counter. NOTE(review): presumably RFC 3418
    /// `snmpSilentDrops`; incremented outside the code shown here.
    pub(crate) snmp_silent_drops: AtomicU32,
}
341
/// An SNMP agent bound to a UDP socket.
///
/// Construct with `Agent::builder()`, then drive it with `run().await`.
/// Cloning is cheap: every clone shares the same inner state through an `Arc`.
pub struct Agent {
    pub(crate) inner: Arc<AgentInner>,
}
365
366impl Agent {
367 pub fn builder() -> AgentBuilder {
369 AgentBuilder::new()
370 }
371
372 pub fn local_addr(&self) -> SocketAddr {
374 self.inner.local_addr
375 }
376
377 pub fn engine_id(&self) -> &[u8] {
379 &self.inner.engine_id
380 }
381
382 pub fn snmp_invalid_msgs(&self) -> u32 {
389 self.inner.snmp_invalid_msgs.load(Ordering::Relaxed)
390 }
391
392 pub fn snmp_unknown_security_models(&self) -> u32 {
399 self.inner
400 .snmp_unknown_security_models
401 .load(Ordering::Relaxed)
402 }
403
404 pub fn snmp_silent_drops(&self) -> u32 {
413 self.inner.snmp_silent_drops.load(Ordering::Relaxed)
414 }
415
416 #[instrument(skip(self), err, fields(snmp.local_addr = %self.local_addr()))]
420 pub async fn run(&self) -> Result<()> {
421 let mut buf = vec![0u8; 65535];
422
423 loop {
424 let (len, source) =
425 self.inner
426 .socket
427 .recv_from(&mut buf)
428 .await
429 .map_err(|e| Error::Io {
430 target: Some(self.inner.local_addr),
431 source: e,
432 })?;
433
434 let data = Bytes::copy_from_slice(&buf[..len]);
435
436 self.update_engine_time();
438
439 match self.handle_request(data, source).await {
440 Ok(Some(response_bytes)) => {
441 if let Err(e) = self.inner.socket.send_to(&response_bytes, source).await {
442 tracing::warn!(snmp.source = %source, error = %e, "failed to send response");
443 }
444 }
445 Ok(None) => {
446 }
448 Err(e) => {
449 tracing::warn!(snmp.source = %source, error = %e, "error handling request");
450 }
451 }
452 }
453 }
454
455 async fn handle_request(&self, data: Bytes, source: SocketAddr) -> Result<Option<Bytes>> {
459 let mut decoder = Decoder::new(data.clone());
461 let mut seq = decoder.read_sequence()?;
462 let version_num = seq.read_integer()?;
463 let version = Version::from_i32(version_num).ok_or_else(|| {
464 Error::decode(seq.offset(), DecodeErrorKind::UnknownVersion(version_num))
465 })?;
466 drop(seq);
467 drop(decoder);
468
469 match version {
470 Version::V1 => self.handle_v1(data, source).await,
471 Version::V2c => self.handle_v2c(data, source).await,
472 Version::V3 => self.handle_v3(data, source).await,
473 }
474 }
475
476 fn update_engine_time(&self) {
478 let elapsed = self.inner.engine_start.elapsed().as_secs() as u32;
479 self.inner.engine_time.store(elapsed, Ordering::Relaxed);
480 }
481
482 pub(crate) fn validate_community(&self, community: &[u8]) -> bool {
487 if self.inner.communities.is_empty() {
488 return false;
490 }
491 let mut valid = false;
495 for configured in &self.inner.communities {
496 if configured.len() == community.len()
498 && bool::from(configured.as_slice().ct_eq(community))
499 {
500 valid = true;
501 }
502 }
503 valid
504 }
505
506 async fn dispatch_request(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
508 match pdu.pdu_type {
509 PduType::GetRequest => self.handle_get(ctx, pdu).await,
510 PduType::GetNextRequest => self.handle_get_next(ctx, pdu).await,
511 PduType::GetBulkRequest => self.handle_get_bulk(ctx, pdu).await,
512 PduType::SetRequest => self.handle_set(ctx, pdu).await,
513 PduType::InformRequest => self.handle_inform(pdu),
514 _ => {
515 Ok(pdu.to_error_response(ErrorStatus::GenErr, 0))
517 }
518 }
519 }
520
521 fn handle_inform(&self, pdu: &Pdu) -> Result<Pdu> {
527 Ok(Pdu {
529 pdu_type: PduType::Response,
530 request_id: pdu.request_id,
531 error_status: 0,
532 error_index: 0,
533 varbinds: pdu.varbinds.clone(),
534 })
535 }
536
537 async fn handle_get(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
539 let mut response_varbinds = Vec::with_capacity(pdu.varbinds.len());
540
541 for (index, vb) in pdu.varbinds.iter().enumerate() {
542 if let Some(ref vacm) = self.inner.vacm
544 && !vacm.check_access(ctx.read_view.as_ref(), &vb.oid)
545 {
546 if ctx.version == Version::V1 {
548 return Ok(Pdu {
549 pdu_type: PduType::Response,
550 request_id: pdu.request_id,
551 error_status: ErrorStatus::NoSuchName.as_i32(),
552 error_index: (index + 1) as i32,
553 varbinds: pdu.varbinds.clone(),
554 });
555 } else {
556 response_varbinds.push(VarBind::new(vb.oid.clone(), Value::NoSuchObject));
558 continue;
559 }
560 }
561
562 let result = if let Some(handler) = self.find_handler(&vb.oid) {
563 handler.handler.get(ctx, &vb.oid).await
564 } else {
565 GetResult::NoSuchObject
566 };
567
568 let response_value = match result {
569 GetResult::Value(v) => v,
570 GetResult::NoSuchObject => {
571 if ctx.version == Version::V1 {
573 return Ok(Pdu {
574 pdu_type: PduType::Response,
575 request_id: pdu.request_id,
576 error_status: ErrorStatus::NoSuchName.as_i32(),
577 error_index: (index + 1) as i32,
578 varbinds: pdu.varbinds.clone(),
579 });
580 } else {
581 Value::NoSuchObject
582 }
583 }
584 GetResult::NoSuchInstance => {
585 if ctx.version == Version::V1 {
587 return Ok(Pdu {
588 pdu_type: PduType::Response,
589 request_id: pdu.request_id,
590 error_status: ErrorStatus::NoSuchName.as_i32(),
591 error_index: (index + 1) as i32,
592 varbinds: pdu.varbinds.clone(),
593 });
594 } else {
595 Value::NoSuchInstance
596 }
597 }
598 };
599
600 response_varbinds.push(VarBind::new(vb.oid.clone(), response_value));
601 }
602
603 Ok(Pdu {
604 pdu_type: PduType::Response,
605 request_id: pdu.request_id,
606 error_status: 0,
607 error_index: 0,
608 varbinds: response_varbinds,
609 })
610 }
611
612 async fn handle_get_next(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
614 let mut response_varbinds = Vec::with_capacity(pdu.varbinds.len());
615
616 for (index, vb) in pdu.varbinds.iter().enumerate() {
617 let next = self.get_next_oid(ctx, &vb.oid).await;
619
620 let next = if let Some(ref next_vb) = next {
622 if let Some(ref vacm) = self.inner.vacm {
623 if vacm.check_access(ctx.read_view.as_ref(), &next_vb.oid) {
624 next
625 } else {
626 None
630 }
631 } else {
632 next
633 }
634 } else {
635 next
636 };
637
638 match next {
639 Some(next_vb) => {
640 response_varbinds.push(next_vb);
641 }
642 None => {
643 if ctx.version == Version::V1 {
645 return Ok(Pdu {
646 pdu_type: PduType::Response,
647 request_id: pdu.request_id,
648 error_status: ErrorStatus::NoSuchName.as_i32(),
649 error_index: (index + 1) as i32,
650 varbinds: pdu.varbinds.clone(),
651 });
652 } else {
653 response_varbinds.push(VarBind::new(vb.oid.clone(), Value::EndOfMibView));
654 }
655 }
656 }
657 }
658
659 Ok(Pdu {
660 pdu_type: PduType::Response,
661 request_id: pdu.request_id,
662 error_status: 0,
663 error_index: 0,
664 varbinds: response_varbinds,
665 })
666 }
667
668 async fn handle_get_bulk(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
673 let non_repeaters = pdu.error_status.max(0) as usize;
675 let max_repetitions = pdu.error_index.max(0) as usize;
676
677 let mut response_varbinds = Vec::new();
678 let mut current_size: usize = RESPONSE_OVERHEAD;
679 let max_size = self.inner.max_message_size;
680
681 let can_add = |vb: &VarBind, current_size: usize| -> bool {
683 current_size + vb.encoded_size() <= max_size
684 };
685
686 for vb in pdu.varbinds.iter().take(non_repeaters) {
688 let next_vb = match self.get_next_oid(ctx, &vb.oid).await {
689 Some(next_vb) => next_vb,
690 None => VarBind::new(vb.oid.clone(), Value::EndOfMibView),
691 };
692
693 if !can_add(&next_vb, current_size) {
694 if response_varbinds.is_empty() {
696 return Ok(Pdu {
697 pdu_type: PduType::Response,
698 request_id: pdu.request_id,
699 error_status: ErrorStatus::TooBig.as_i32(),
700 error_index: 0,
701 varbinds: pdu.varbinds.clone(),
702 });
703 }
704 break;
706 }
707
708 current_size += next_vb.encoded_size();
709 response_varbinds.push(next_vb);
710 }
711
712 if non_repeaters < pdu.varbinds.len() {
714 let repeaters = &pdu.varbinds[non_repeaters..];
715 let mut current_oids: Vec<Oid> = repeaters.iter().map(|vb| vb.oid.clone()).collect();
716 let mut all_done = vec![false; repeaters.len()];
717
718 'outer: for _ in 0..max_repetitions {
719 let mut row_complete = true;
720 for (i, oid) in current_oids.iter_mut().enumerate() {
721 let next_vb = if all_done[i] {
722 VarBind::new(oid.clone(), Value::EndOfMibView)
723 } else {
724 match self.get_next_oid(ctx, oid).await {
725 Some(next_vb) => {
726 *oid = next_vb.oid.clone();
727 row_complete = false;
728 next_vb
729 }
730 None => {
731 all_done[i] = true;
732 VarBind::new(oid.clone(), Value::EndOfMibView)
733 }
734 }
735 };
736
737 if !can_add(&next_vb, current_size) {
739 break 'outer;
741 }
742
743 current_size += next_vb.encoded_size();
744 response_varbinds.push(next_vb);
745 }
746
747 if row_complete {
748 break;
749 }
750 }
751 }
752
753 Ok(Pdu {
754 pdu_type: PduType::Response,
755 request_id: pdu.request_id,
756 error_status: 0,
757 error_index: 0,
758 varbinds: response_varbinds,
759 })
760 }
761
762 pub(crate) fn find_handler(&self, oid: &Oid) -> Option<&RegisteredHandler> {
764 self.inner
766 .handlers
767 .iter()
768 .find(|&handler| handler.handler.handles(&handler.prefix, oid))
769 .map(|v| v as _)
770 }
771
772 async fn get_next_oid(&self, ctx: &RequestContext, oid: &Oid) -> Option<VarBind> {
774 let mut best_result: Option<VarBind> = None;
776
777 for handler in &self.inner.handlers {
778 if let GetNextResult::Value(next) = handler.handler.get_next(ctx, oid).await {
779 if next.oid > *oid {
781 match &best_result {
782 None => best_result = Some(next),
783 Some(current) if next.oid < current.oid => best_result = Some(next),
784 _ => {}
785 }
786 }
787 }
788 }
789
790 best_result
791 }
792}
793
794impl Clone for Agent {
795 fn clone(&self) -> Self {
796 Self {
797 inner: Arc::clone(&self.inner),
798 }
799 }
800}
801
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::{
        BoxFuture, GetNextResult, GetResult, MibHandler, RequestContext, SecurityModel, SetResult,
    };
    use crate::message::SecurityLevel;
    use crate::oid;

    /// Minimal handler exposing two scalars under enterprise 99999:
    /// `.1.0` = Integer(42) and `.2.0` = OctetString("test").
    struct TestHandler;

    impl MibHandler for TestHandler {
        fn get<'a>(&'a self, _ctx: &'a RequestContext, oid: &'a Oid) -> BoxFuture<'a, GetResult> {
            Box::pin(async move {
                if oid == &oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0) {
                    return GetResult::Value(Value::Integer(42));
                }
                if oid == &oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0) {
                    return GetResult::Value(Value::OctetString(Bytes::from_static(b"test")));
                }
                GetResult::NoSuchObject
            })
        }

        // Returns the first of the two scalars that sorts after `oid`.
        fn get_next<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            oid: &'a Oid,
        ) -> BoxFuture<'a, GetNextResult> {
            Box::pin(async move {
                let oid1 = oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0);
                let oid2 = oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0);

                if oid < &oid1 {
                    return GetNextResult::Value(VarBind::new(oid1, Value::Integer(42)));
                }
                if oid < &oid2 {
                    return GetNextResult::Value(VarBind::new(
                        oid2,
                        Value::OctetString(Bytes::from_static(b"test")),
                    ));
                }
                GetNextResult::EndOfMibView
            })
        }
    }

    /// A v2c, noAuthNoPriv GetRequest context from 127.0.0.1 with no VACM views.
    fn test_ctx() -> RequestContext {
        RequestContext {
            source: "127.0.0.1:12345".parse().unwrap(),
            version: Version::V2c,
            security_model: SecurityModel::V2c,
            security_name: Bytes::from_static(b"public"),
            security_level: SecurityLevel::NoAuthNoPriv,
            context_name: Bytes::new(),
            request_id: 1,
            pdu_type: PduType::GetRequest,
            group_name: None,
            read_view: None,
            write_view: None,
        }
    }

    #[test]
    fn test_agent_builder_defaults() {
        let builder = AgentBuilder::new();
        assert_eq!(builder.bind_addr, "0.0.0.0:161");
        assert!(builder.communities.is_empty());
        assert!(builder.usm_users.is_empty());
        assert!(builder.handlers.is_empty());
    }

    #[test]
    fn test_agent_builder_community() {
        let builder = AgentBuilder::new()
            .community(b"public")
            .community(b"private");
        assert_eq!(builder.communities.len(), 2);
    }

    #[test]
    fn test_agent_builder_communities() {
        let builder = AgentBuilder::new().communities(["public", "private"]);
        assert_eq!(builder.communities.len(), 2);
    }

    #[test]
    fn test_agent_builder_handler() {
        let builder =
            AgentBuilder::new().handler(oid!(1, 3, 6, 1, 4, 1, 99999), Arc::new(TestHandler));
        assert_eq!(builder.handlers.len(), 1);
    }

    // TestHandler does not override `test_set`, so this exercises the trait default.
    #[tokio::test]
    async fn test_mib_handler_default_set() {
        let handler = TestHandler;
        let mut ctx = test_ctx();
        ctx.pdu_type = PduType::SetRequest;

        let result = handler
            .test_set(&ctx, &oid!(1, 3, 6, 1), &Value::Integer(1))
            .await;
        assert_eq!(result, SetResult::NotWritable);
    }

    // Exercises the default `MibHandler::handles` implementation.
    #[test]
    fn test_mib_handler_handles() {
        let handler = TestHandler;
        let prefix = oid!(1, 3, 6, 1, 4, 1, 99999);

        // An OID inside the prefix subtree is handled.
        assert!(handler.handles(&prefix, &oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0)));

        // An OID sorting before the prefix is also reported as handled.
        // NOTE(review): presumably so GetNext from an earlier OID reaches this
        // handler; confirm against `MibHandler::handles`.
        assert!(handler.handles(&prefix, &oid!(1, 3, 6, 1, 4, 1, 99998)));

        // An OID past the prefix subtree is not handled.
        assert!(!handler.handles(&prefix, &oid!(1, 3, 6, 1, 4, 1, 100000)));
    }

    #[tokio::test]
    async fn test_test_handler_get() {
        let handler = TestHandler;
        let ctx = test_ctx();

        let result = handler
            .get(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0))
            .await;
        assert!(matches!(result, GetResult::Value(Value::Integer(42))));

        // An OID the handler does not know yields NoSuchObject.
        let result = handler
            .get(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 99, 0))
            .await;
        assert!(matches!(result, GetResult::NoSuchObject));
    }

    // Walks the handler: prefix -> .1.0 -> .2.0 -> end of MIB view.
    #[tokio::test]
    async fn test_test_handler_get_next() {
        let handler = TestHandler;
        let mut ctx = test_ctx();
        ctx.pdu_type = PduType::GetNextRequest;

        let next = handler.get_next(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999)).await;
        assert!(next.is_value());
        if let GetNextResult::Value(vb) = next {
            assert_eq!(vb.oid, oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0));
        }

        let next = handler
            .get_next(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0))
            .await;
        assert!(next.is_value());
        if let GetNextResult::Value(vb) = next {
            assert_eq!(vb.oid, oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0));
        }

        let next = handler
            .get_next(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0))
            .await;
        assert!(next.is_end_of_mib_view());
    }
}
968}