async_snmp/agent/
mod.rs

1//! SNMP Agent (RFC 3413).
2//!
3//! This module provides SNMP agent functionality for responding to
4//! GET, GETNEXT, GETBULK, and SET requests.
5//!
6//! # Features
7//!
8//! - **Async handlers**: All handler methods are async for database queries, network calls, etc.
9//! - **Atomic SET**: Two-phase commit protocol (test/commit/undo) per RFC 3416
10//! - **VACM support**: Optional View-based Access Control Model (RFC 3415)
11//!
12//! # Example
13//!
14//! ```rust,no_run
15//! use async_snmp::agent::Agent;
16//! use async_snmp::handler::{MibHandler, RequestContext, GetResult, GetNextResult, BoxFuture};
17//! use async_snmp::{Oid, Value, VarBind, oid};
18//! use std::sync::Arc;
19//!
20//! // Define a simple handler for the system MIB subtree
21//! struct SystemMibHandler;
22//!
23//! impl MibHandler for SystemMibHandler {
24//!     fn get<'a>(&'a self, ctx: &'a RequestContext, oid: &'a Oid) -> BoxFuture<'a, GetResult> {
25//!         Box::pin(async move {
26//!             // sysDescr.0
27//!             if oid == &oid!(1, 3, 6, 1, 2, 1, 1, 1, 0) {
28//!                 return GetResult::Value(Value::OctetString("My SNMP Agent".into()));
29//!             }
30//!             // sysObjectID.0
31//!             if oid == &oid!(1, 3, 6, 1, 2, 1, 1, 2, 0) {
32//!                 return GetResult::Value(Value::ObjectIdentifier(oid!(1, 3, 6, 1, 4, 1, 99999)));
33//!             }
34//!             GetResult::NoSuchObject
35//!         })
36//!     }
37//!
38//!     fn get_next<'a>(&'a self, ctx: &'a RequestContext, oid: &'a Oid) -> BoxFuture<'a, GetNextResult> {
39//!         Box::pin(async move {
40//!             // Return the lexicographically next OID after the given one
41//!             let sys_descr = oid!(1, 3, 6, 1, 2, 1, 1, 1, 0);
42//!             let sys_object_id = oid!(1, 3, 6, 1, 2, 1, 1, 2, 0);
43//!
44//!             if oid < &sys_descr {
45//!                 return GetNextResult::Value(VarBind::new(sys_descr, Value::OctetString("My SNMP Agent".into())));
46//!             }
47//!             if oid < &sys_object_id {
48//!                 return GetNextResult::Value(VarBind::new(sys_object_id, Value::ObjectIdentifier(oid!(1, 3, 6, 1, 4, 1, 99999))));
49//!             }
50//!             GetNextResult::EndOfMibView
51//!         })
52//!     }
53//! }
54//!
55//! #[tokio::main]
56//! async fn main() -> Result<(), async_snmp::Error> {
57//!     let agent = Agent::builder()
58//!         .bind("0.0.0.0:161")
59//!         .community(b"public")
60//!         .handler(oid!(1, 3, 6, 1, 2, 1, 1), Arc::new(SystemMibHandler))
61//!         .build()
62//!         .await?;
63//!
64//!     agent.run().await
65//! }
66//! ```
67
68mod request;
69mod response;
70mod set_handler;
71pub mod vacm;
72
73pub use vacm::{SecurityModel, VacmBuilder, VacmConfig, View};
74
75use std::collections::HashMap;
76use std::net::SocketAddr;
77use std::sync::Arc;
78use std::sync::atomic::{AtomicU32, Ordering};
79use std::time::Instant;
80
81use bytes::Bytes;
82use subtle::ConstantTimeEq;
83use tokio::net::UdpSocket;
84use tracing::instrument;
85
86use crate::ber::Decoder;
87use crate::error::{DecodeErrorKind, Error, ErrorStatus, Result};
88use crate::handler::{GetNextResult, GetResult, MibHandler, RequestContext};
89use crate::notification::UsmUserConfig;
90use crate::oid::Oid;
91use crate::pdu::{Pdu, PduType};
92use crate::util::bind_udp_socket;
93use crate::v3::SaltCounter;
94use crate::value::Value;
95use crate::varbind::VarBind;
96use crate::version::Version;
97
/// Default maximum message size for UDP responses, in octets (RFC 3417 recommendation).
///
/// Derived as a 1500-octet Ethernet MTU minus the 20-octet IPv4 header and
/// the 8-octet UDP header.
const DEFAULT_MAX_MESSAGE_SIZE: usize = 1500 - 20 - 8;

/// Conservative estimate, in octets, of everything in a response message that
/// is not a variable binding: version, community or USM header, PDU header
/// fields, and so on.
const RESPONSE_OVERHEAD: usize = 100;
104
105/// Registered handler with its OID prefix.
106pub(crate) struct RegisteredHandler {
107    pub(crate) prefix: Oid,
108    pub(crate) handler: Arc<dyn MibHandler>,
109}
110
111/// Builder for [`Agent`].
112pub struct AgentBuilder {
113    bind_addr: String,
114    communities: Vec<Vec<u8>>,
115    usm_users: HashMap<Bytes, UsmUserConfig>,
116    handlers: Vec<RegisteredHandler>,
117    engine_id: Option<Vec<u8>>,
118    max_message_size: usize,
119    vacm: Option<VacmConfig>,
120}
121
122impl AgentBuilder {
123    /// Create a new builder with default settings.
124    pub fn new() -> Self {
125        Self {
126            bind_addr: "0.0.0.0:161".to_string(),
127            communities: Vec::new(),
128            usm_users: HashMap::new(),
129            handlers: Vec::new(),
130            engine_id: None,
131            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
132            vacm: None,
133        }
134    }
135
136    /// Set the bind address.
137    ///
138    /// Default is "0.0.0.0:161".
139    pub fn bind(mut self, addr: impl Into<String>) -> Self {
140        self.bind_addr = addr.into();
141        self
142    }
143
144    /// Add an accepted community string for v1/v2c requests.
145    ///
146    /// Multiple communities can be added. If none are added,
147    /// all community strings are rejected.
148    pub fn community(mut self, community: &[u8]) -> Self {
149        self.communities.push(community.to_vec());
150        self
151    }
152
153    /// Add multiple community strings.
154    pub fn communities<I, C>(mut self, communities: I) -> Self
155    where
156        I: IntoIterator<Item = C>,
157        C: AsRef<[u8]>,
158    {
159        for c in communities {
160            self.communities.push(c.as_ref().to_vec());
161        }
162        self
163    }
164
165    /// Add a USM user for v3 authentication.
166    pub fn usm_user<F>(mut self, username: impl Into<Bytes>, configure: F) -> Self
167    where
168        F: FnOnce(UsmUserConfig) -> UsmUserConfig,
169    {
170        let username_bytes: Bytes = username.into();
171        let config = configure(UsmUserConfig::new(username_bytes.clone()));
172        self.usm_users.insert(username_bytes, config);
173        self
174    }
175
176    /// Set the engine ID for v3.
177    ///
178    /// If not set, a default engine ID will be generated.
179    pub fn engine_id(mut self, engine_id: impl Into<Vec<u8>>) -> Self {
180        self.engine_id = Some(engine_id.into());
181        self
182    }
183
184    /// Set the maximum message size for responses.
185    ///
186    /// Default is 1472 octets (fits Ethernet MTU minus IP/UDP headers).
187    /// GETBULK responses will be truncated to fit within this limit.
188    ///
189    /// For SNMPv3 requests, the agent uses the minimum of this value
190    /// and the msgMaxSize from the request.
191    pub fn max_message_size(mut self, size: usize) -> Self {
192        self.max_message_size = size;
193        self
194    }
195
196    /// Register a MIB handler for an OID subtree.
197    ///
198    /// Handlers are matched by longest prefix. When a request comes in,
199    /// the handler with the longest matching prefix is used.
200    pub fn handler(mut self, prefix: Oid, handler: Arc<dyn MibHandler>) -> Self {
201        self.handlers.push(RegisteredHandler { prefix, handler });
202        self
203    }
204
205    /// Configure VACM (View-based Access Control Model) using a builder function.
206    ///
207    /// When VACM is enabled, all requests are checked against the configured
208    /// access control rules. Requests that don't have proper access are rejected
209    /// with `noAccess` error (v2c/v3) or `noSuchName` (v1).
210    ///
211    /// # Example
212    ///
213    /// ```rust,no_run
214    /// use async_snmp::agent::{Agent, SecurityModel, VacmBuilder};
215    /// use async_snmp::message::SecurityLevel;
216    /// use async_snmp::oid;
217    ///
218    /// # async fn example() -> Result<(), async_snmp::Error> {
219    /// let agent = Agent::builder()
220    ///     .bind("0.0.0.0:161")
221    ///     .community(b"public")
222    ///     .community(b"private")
223    ///     .vacm(|v| v
224    ///         .group("public", SecurityModel::V2c, "readonly_group")
225    ///         .group("private", SecurityModel::V2c, "readwrite_group")
226    ///         .access("readonly_group", |a| a
227    ///             .read_view("full_view"))
228    ///         .access("readwrite_group", |a| a
229    ///             .read_view("full_view")
230    ///             .write_view("write_view"))
231    ///         .view("full_view", |v| v
232    ///             .include(oid!(1, 3, 6, 1)))
233    ///         .view("write_view", |v| v
234    ///             .include(oid!(1, 3, 6, 1, 2, 1, 1))))
235    ///     .build()
236    ///     .await?;
237    /// # Ok(())
238    /// # }
239    /// ```
240    pub fn vacm<F>(mut self, configure: F) -> Self
241    where
242        F: FnOnce(VacmBuilder) -> VacmBuilder,
243    {
244        let builder = VacmBuilder::new();
245        self.vacm = Some(configure(builder).build());
246        self
247    }
248
249    /// Build the agent.
250    ///
251    /// For IPv6 bind addresses, the socket has `IPV6_V6ONLY` set to true.
252    pub async fn build(mut self) -> Result<Agent> {
253        let bind_addr: std::net::SocketAddr = self.bind_addr.parse().map_err(|_| Error::Io {
254            target: None,
255            source: std::io::Error::new(
256                std::io::ErrorKind::InvalidInput,
257                format!("invalid bind address: {}", self.bind_addr),
258            ),
259        })?;
260
261        let socket = bind_udp_socket(bind_addr).await.map_err(|e| Error::Io {
262            target: Some(bind_addr),
263            source: e,
264        })?;
265
266        let local_addr = socket.local_addr().map_err(|e| Error::Io {
267            target: Some(bind_addr),
268            source: e,
269        })?;
270
271        // Generate default engine ID if not provided
272        let engine_id = self.engine_id.unwrap_or_else(|| {
273            // RFC 3411 format: enterprise number + format + local identifier
274            // Use a simple format: 0x80 (local) + timestamp + random
275            let mut id = vec![0x80, 0x00, 0x00, 0x00, 0x01]; // Enterprise format indicator
276            let timestamp = std::time::SystemTime::now()
277                .duration_since(std::time::UNIX_EPOCH)
278                .unwrap_or_default()
279                .as_secs();
280            id.extend_from_slice(&timestamp.to_be_bytes());
281            id
282        });
283
284        // Sort handlers by prefix length (longest first) for matching
285        self.handlers
286            .sort_by(|a, b| b.prefix.len().cmp(&a.prefix.len()));
287
288        Ok(Agent {
289            inner: Arc::new(AgentInner {
290                socket,
291                local_addr,
292                communities: self.communities,
293                usm_users: self.usm_users,
294                handlers: self.handlers,
295                engine_id,
296                engine_boots: AtomicU32::new(1),
297                engine_time: AtomicU32::new(0),
298                engine_start: Instant::now(),
299                salt_counter: SaltCounter::new(),
300                max_message_size: self.max_message_size,
301                vacm: self.vacm,
302                snmp_invalid_msgs: AtomicU32::new(0),
303                snmp_unknown_security_models: AtomicU32::new(0),
304                snmp_silent_drops: AtomicU32::new(0),
305            }),
306        })
307    }
308}
309
310impl Default for AgentBuilder {
311    fn default() -> Self {
312        Self::new()
313    }
314}
315
316/// Inner state shared across agent clones.
317pub(crate) struct AgentInner {
318    pub(crate) socket: UdpSocket,
319    pub(crate) local_addr: SocketAddr,
320    pub(crate) communities: Vec<Vec<u8>>,
321    pub(crate) usm_users: HashMap<Bytes, UsmUserConfig>,
322    pub(crate) handlers: Vec<RegisteredHandler>,
323    pub(crate) engine_id: Vec<u8>,
324    pub(crate) engine_boots: AtomicU32,
325    pub(crate) engine_time: AtomicU32,
326    pub(crate) engine_start: Instant,
327    pub(crate) salt_counter: SaltCounter,
328    pub(crate) max_message_size: usize,
329    pub(crate) vacm: Option<VacmConfig>,
330    // RFC 3412 statistics counters
331    /// snmpInvalidMsgs (1.3.6.1.6.3.11.2.1.2) - messages with invalid msgFlags
332    /// (e.g., privacy without authentication)
333    pub(crate) snmp_invalid_msgs: AtomicU32,
334    /// snmpUnknownSecurityModels (1.3.6.1.6.3.11.2.1.1) - messages with
335    /// unrecognized security model
336    pub(crate) snmp_unknown_security_models: AtomicU32,
337    /// snmpSilentDrops (1.3.6.1.6.3.11.2.1.3) - confirmed-class PDUs silently
338    /// dropped because even an empty response would exceed max message size
339    pub(crate) snmp_silent_drops: AtomicU32,
340}
341
342/// SNMP Agent.
343///
344/// Listens for and responds to SNMP requests (GET, GETNEXT, GETBULK, SET).
345///
346/// # Example
347///
348/// ```rust,no_run
349/// use async_snmp::agent::Agent;
350/// use async_snmp::oid;
351///
352/// # async fn example() -> Result<(), async_snmp::Error> {
353/// let agent = Agent::builder()
354///     .bind("0.0.0.0:161")
355///     .community(b"public")
356///     .build()
357///     .await?;
358///
359/// agent.run().await
360/// # }
361/// ```
362pub struct Agent {
363    pub(crate) inner: Arc<AgentInner>,
364}
365
366impl Agent {
367    /// Create a builder for configuring the agent.
368    pub fn builder() -> AgentBuilder {
369        AgentBuilder::new()
370    }
371
372    /// Get the local address the agent is bound to.
373    pub fn local_addr(&self) -> SocketAddr {
374        self.inner.local_addr
375    }
376
377    /// Get the engine ID.
378    pub fn engine_id(&self) -> &[u8] {
379        &self.inner.engine_id
380    }
381
382    /// Get the snmpInvalidMsgs counter value.
383    ///
384    /// This counter tracks messages with invalid msgFlags, such as
385    /// privacy-without-authentication (RFC 3412 Section 7.2 Step 5d).
386    ///
387    /// OID: 1.3.6.1.6.3.11.2.1.2
388    pub fn snmp_invalid_msgs(&self) -> u32 {
389        self.inner.snmp_invalid_msgs.load(Ordering::Relaxed)
390    }
391
392    /// Get the snmpUnknownSecurityModels counter value.
393    ///
394    /// This counter tracks messages with unrecognized security models
395    /// (RFC 3412 Section 7.2 Step 2).
396    ///
397    /// OID: 1.3.6.1.6.3.11.2.1.1
398    pub fn snmp_unknown_security_models(&self) -> u32 {
399        self.inner
400            .snmp_unknown_security_models
401            .load(Ordering::Relaxed)
402    }
403
404    /// Get the snmpSilentDrops counter value.
405    ///
406    /// This counter tracks confirmed-class PDUs (GetRequest, GetNextRequest,
407    /// GetBulkRequest, SetRequest, InformRequest) that were silently dropped
408    /// because even an empty Response-PDU would exceed the maximum message
409    /// size constraint (RFC 3412 Section 7.1).
410    ///
411    /// OID: 1.3.6.1.6.3.11.2.1.3
412    pub fn snmp_silent_drops(&self) -> u32 {
413        self.inner.snmp_silent_drops.load(Ordering::Relaxed)
414    }
415
416    /// Run the agent, processing requests indefinitely.
417    ///
418    /// This method runs until an error occurs or the task is cancelled.
419    #[instrument(skip(self), err, fields(snmp.local_addr = %self.local_addr()))]
420    pub async fn run(&self) -> Result<()> {
421        let mut buf = vec![0u8; 65535];
422
423        loop {
424            let (len, source) =
425                self.inner
426                    .socket
427                    .recv_from(&mut buf)
428                    .await
429                    .map_err(|e| Error::Io {
430                        target: Some(self.inner.local_addr),
431                        source: e,
432                    })?;
433
434            let data = Bytes::copy_from_slice(&buf[..len]);
435
436            // Update engine time before processing
437            self.update_engine_time();
438
439            match self.handle_request(data, source).await {
440                Ok(Some(response_bytes)) => {
441                    if let Err(e) = self.inner.socket.send_to(&response_bytes, source).await {
442                        tracing::warn!(snmp.source = %source, error = %e, "failed to send response");
443                    }
444                }
445                Ok(None) => {
446                    // No response needed (e.g., invalid message)
447                }
448                Err(e) => {
449                    tracing::warn!(snmp.source = %source, error = %e, "error handling request");
450                }
451            }
452        }
453    }
454
455    /// Process a single request and return the response bytes.
456    ///
457    /// Returns `None` if no response should be sent.
458    async fn handle_request(&self, data: Bytes, source: SocketAddr) -> Result<Option<Bytes>> {
459        // Peek at version
460        let mut decoder = Decoder::new(data.clone());
461        let mut seq = decoder.read_sequence()?;
462        let version_num = seq.read_integer()?;
463        let version = Version::from_i32(version_num).ok_or_else(|| {
464            Error::decode(seq.offset(), DecodeErrorKind::UnknownVersion(version_num))
465        })?;
466        drop(seq);
467        drop(decoder);
468
469        match version {
470            Version::V1 => self.handle_v1(data, source).await,
471            Version::V2c => self.handle_v2c(data, source).await,
472            Version::V3 => self.handle_v3(data, source).await,
473        }
474    }
475
476    /// Update engine time based on elapsed time since start.
477    fn update_engine_time(&self) {
478        let elapsed = self.inner.engine_start.elapsed().as_secs() as u32;
479        self.inner.engine_time.store(elapsed, Ordering::Relaxed);
480    }
481
482    /// Validate community string using constant-time comparison.
483    ///
484    /// Uses constant-time comparison to prevent timing attacks that could
485    /// be used to guess valid community strings character by character.
486    pub(crate) fn validate_community(&self, community: &[u8]) -> bool {
487        if self.inner.communities.is_empty() {
488            // No communities configured = reject all
489            return false;
490        }
491        // Use constant-time comparison for each community string.
492        // We compare against all configured communities regardless of
493        // early matches to maintain constant-time behavior.
494        let mut valid = false;
495        for configured in &self.inner.communities {
496            // ct_eq returns a Choice, which we convert to bool after comparison
497            if configured.len() == community.len()
498                && bool::from(configured.as_slice().ct_eq(community))
499            {
500                valid = true;
501            }
502        }
503        valid
504    }
505
506    /// Dispatch a request to the appropriate handler.
507    async fn dispatch_request(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
508        match pdu.pdu_type {
509            PduType::GetRequest => self.handle_get(ctx, pdu).await,
510            PduType::GetNextRequest => self.handle_get_next(ctx, pdu).await,
511            PduType::GetBulkRequest => self.handle_get_bulk(ctx, pdu).await,
512            PduType::SetRequest => self.handle_set(ctx, pdu).await,
513            PduType::InformRequest => self.handle_inform(pdu),
514            _ => {
515                // Should not happen - filtered earlier
516                Ok(pdu.to_error_response(ErrorStatus::GenErr, 0))
517            }
518        }
519    }
520
521    /// Handle InformRequest PDU.
522    ///
523    /// Per RFC 3416 Section 4.2.7, an InformRequest is a confirmed-class PDU
524    /// that the receiver acknowledges by returning a Response with the same
525    /// request-id and varbind list.
526    fn handle_inform(&self, pdu: &Pdu) -> Result<Pdu> {
527        // Simply acknowledge by returning the same varbinds in a Response
528        Ok(Pdu {
529            pdu_type: PduType::Response,
530            request_id: pdu.request_id,
531            error_status: 0,
532            error_index: 0,
533            varbinds: pdu.varbinds.clone(),
534        })
535    }
536
537    /// Handle GET request.
538    async fn handle_get(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
539        let mut response_varbinds = Vec::with_capacity(pdu.varbinds.len());
540
541        for (index, vb) in pdu.varbinds.iter().enumerate() {
542            // VACM read access check
543            if let Some(ref vacm) = self.inner.vacm
544                && !vacm.check_access(ctx.read_view.as_ref(), &vb.oid)
545            {
546                // v1: noSuchName, v2c/v3: noAccess or NoSuchObject
547                if ctx.version == Version::V1 {
548                    return Ok(Pdu {
549                        pdu_type: PduType::Response,
550                        request_id: pdu.request_id,
551                        error_status: ErrorStatus::NoSuchName.as_i32(),
552                        error_index: (index + 1) as i32,
553                        varbinds: pdu.varbinds.clone(),
554                    });
555                } else {
556                    // For GET, return NoSuchObject for inaccessible OIDs per RFC 3415
557                    response_varbinds.push(VarBind::new(vb.oid.clone(), Value::NoSuchObject));
558                    continue;
559                }
560            }
561
562            let result = if let Some(handler) = self.find_handler(&vb.oid) {
563                handler.handler.get(ctx, &vb.oid).await
564            } else {
565                GetResult::NoSuchObject
566            };
567
568            let response_value = match result {
569                GetResult::Value(v) => v,
570                GetResult::NoSuchObject => {
571                    // v1 returns noSuchName error, v2c/v3 returns NoSuchObject exception
572                    if ctx.version == Version::V1 {
573                        return Ok(Pdu {
574                            pdu_type: PduType::Response,
575                            request_id: pdu.request_id,
576                            error_status: ErrorStatus::NoSuchName.as_i32(),
577                            error_index: (index + 1) as i32,
578                            varbinds: pdu.varbinds.clone(),
579                        });
580                    } else {
581                        Value::NoSuchObject
582                    }
583                }
584                GetResult::NoSuchInstance => {
585                    // v1 returns noSuchName error, v2c/v3 returns NoSuchInstance exception
586                    if ctx.version == Version::V1 {
587                        return Ok(Pdu {
588                            pdu_type: PduType::Response,
589                            request_id: pdu.request_id,
590                            error_status: ErrorStatus::NoSuchName.as_i32(),
591                            error_index: (index + 1) as i32,
592                            varbinds: pdu.varbinds.clone(),
593                        });
594                    } else {
595                        Value::NoSuchInstance
596                    }
597                }
598            };
599
600            response_varbinds.push(VarBind::new(vb.oid.clone(), response_value));
601        }
602
603        Ok(Pdu {
604            pdu_type: PduType::Response,
605            request_id: pdu.request_id,
606            error_status: 0,
607            error_index: 0,
608            varbinds: response_varbinds,
609        })
610    }
611
612    /// Handle GETNEXT request.
613    async fn handle_get_next(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
614        let mut response_varbinds = Vec::with_capacity(pdu.varbinds.len());
615
616        for (index, vb) in pdu.varbinds.iter().enumerate() {
617            // Try to find the next OID from any handler
618            let next = self.get_next_oid(ctx, &vb.oid).await;
619
620            // Check VACM access for the returned OID (if VACM enabled)
621            let next = if let Some(ref next_vb) = next {
622                if let Some(ref vacm) = self.inner.vacm {
623                    if vacm.check_access(ctx.read_view.as_ref(), &next_vb.oid) {
624                        next
625                    } else {
626                        // OID not accessible, continue searching
627                        // For simplicity, we just return EndOfMibView here
628                        // A more complete implementation would continue the search
629                        None
630                    }
631                } else {
632                    next
633                }
634            } else {
635                next
636            };
637
638            match next {
639                Some(next_vb) => {
640                    response_varbinds.push(next_vb);
641                }
642                None => {
643                    // v1 returns noSuchName, v2c/v3 returns endOfMibView
644                    if ctx.version == Version::V1 {
645                        return Ok(Pdu {
646                            pdu_type: PduType::Response,
647                            request_id: pdu.request_id,
648                            error_status: ErrorStatus::NoSuchName.as_i32(),
649                            error_index: (index + 1) as i32,
650                            varbinds: pdu.varbinds.clone(),
651                        });
652                    } else {
653                        response_varbinds.push(VarBind::new(vb.oid.clone(), Value::EndOfMibView));
654                    }
655                }
656            }
657        }
658
659        Ok(Pdu {
660            pdu_type: PduType::Response,
661            request_id: pdu.request_id,
662            error_status: 0,
663            error_index: 0,
664            varbinds: response_varbinds,
665        })
666    }
667
668    /// Handle GETBULK request.
669    ///
670    /// Per RFC 3416 Section 4.2.3, if the response would exceed the message
671    /// size limit, we return fewer variable bindings rather than all of them.
672    async fn handle_get_bulk(&self, ctx: &RequestContext, pdu: &Pdu) -> Result<Pdu> {
673        // For GETBULK, error_status is non_repeaters and error_index is max_repetitions
674        let non_repeaters = pdu.error_status.max(0) as usize;
675        let max_repetitions = pdu.error_index.max(0) as usize;
676
677        let mut response_varbinds = Vec::new();
678        let mut current_size: usize = RESPONSE_OVERHEAD;
679        let max_size = self.inner.max_message_size;
680
681        // Helper to check if we can add a varbind
682        let can_add = |vb: &VarBind, current_size: usize| -> bool {
683            current_size + vb.encoded_size() <= max_size
684        };
685
686        // Handle non-repeaters (first N varbinds get one GETNEXT each)
687        for vb in pdu.varbinds.iter().take(non_repeaters) {
688            let next_vb = match self.get_next_oid(ctx, &vb.oid).await {
689                Some(next_vb) => next_vb,
690                None => VarBind::new(vb.oid.clone(), Value::EndOfMibView),
691            };
692
693            if !can_add(&next_vb, current_size) {
694                // Can't fit even non-repeaters, return tooBig if we have nothing
695                if response_varbinds.is_empty() {
696                    return Ok(Pdu {
697                        pdu_type: PduType::Response,
698                        request_id: pdu.request_id,
699                        error_status: ErrorStatus::TooBig.as_i32(),
700                        error_index: 0,
701                        varbinds: pdu.varbinds.clone(),
702                    });
703                }
704                // Otherwise return what we have
705                break;
706            }
707
708            current_size += next_vb.encoded_size();
709            response_varbinds.push(next_vb);
710        }
711
712        // Handle repeaters
713        if non_repeaters < pdu.varbinds.len() {
714            let repeaters = &pdu.varbinds[non_repeaters..];
715            let mut current_oids: Vec<Oid> = repeaters.iter().map(|vb| vb.oid.clone()).collect();
716            let mut all_done = vec![false; repeaters.len()];
717
718            'outer: for _ in 0..max_repetitions {
719                let mut row_complete = true;
720                for (i, oid) in current_oids.iter_mut().enumerate() {
721                    let next_vb = if all_done[i] {
722                        VarBind::new(oid.clone(), Value::EndOfMibView)
723                    } else {
724                        match self.get_next_oid(ctx, oid).await {
725                            Some(next_vb) => {
726                                *oid = next_vb.oid.clone();
727                                row_complete = false;
728                                next_vb
729                            }
730                            None => {
731                                all_done[i] = true;
732                                VarBind::new(oid.clone(), Value::EndOfMibView)
733                            }
734                        }
735                    };
736
737                    // Check size before adding
738                    if !can_add(&next_vb, current_size) {
739                        // Can't fit more, return what we have
740                        break 'outer;
741                    }
742
743                    current_size += next_vb.encoded_size();
744                    response_varbinds.push(next_vb);
745                }
746
747                if row_complete {
748                    break;
749                }
750            }
751        }
752
753        Ok(Pdu {
754            pdu_type: PduType::Response,
755            request_id: pdu.request_id,
756            error_status: 0,
757            error_index: 0,
758            varbinds: response_varbinds,
759        })
760    }
761
762    /// Find the handler for a given OID.
763    pub(crate) fn find_handler(&self, oid: &Oid) -> Option<&RegisteredHandler> {
764        // Handlers are sorted by prefix length (longest first)
765        self.inner
766            .handlers
767            .iter()
768            .find(|&handler| handler.handler.handles(&handler.prefix, oid))
769            .map(|v| v as _)
770    }
771
772    /// Get the next OID from any handler.
773    async fn get_next_oid(&self, ctx: &RequestContext, oid: &Oid) -> Option<VarBind> {
774        // Find the first handler that can provide a next OID
775        let mut best_result: Option<VarBind> = None;
776
777        for handler in &self.inner.handlers {
778            if let GetNextResult::Value(next) = handler.handler.get_next(ctx, oid).await {
779                // Must be lexicographically greater than the request OID
780                if next.oid > *oid {
781                    match &best_result {
782                        None => best_result = Some(next),
783                        Some(current) if next.oid < current.oid => best_result = Some(next),
784                        _ => {}
785                    }
786                }
787            }
788        }
789
790        best_result
791    }
792}
793
794impl Clone for Agent {
795    fn clone(&self) -> Self {
796        Self {
797            inner: Arc::clone(&self.inner),
798        }
799    }
800}
801
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::{
        BoxFuture, GetNextResult, GetResult, MibHandler, RequestContext, SecurityModel, SetResult,
    };
    use crate::message::SecurityLevel;
    use crate::oid;

    /// Fixture handler serving two scalars below `1.3.6.1.4.1.99999`:
    /// `.1.0` is `Integer(42)` and `.2.0` is `OctetString("test")`.
    struct TestHandler;

    impl MibHandler for TestHandler {
        fn get<'a>(&'a self, _ctx: &'a RequestContext, oid: &'a Oid) -> BoxFuture<'a, GetResult> {
            Box::pin(async move {
                let int_scalar = oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0);
                let str_scalar = oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0);

                if *oid == int_scalar {
                    GetResult::Value(Value::Integer(42))
                } else if *oid == str_scalar {
                    GetResult::Value(Value::OctetString(Bytes::from_static(b"test")))
                } else {
                    GetResult::NoSuchObject
                }
            })
        }

        fn get_next<'a>(
            &'a self,
            _ctx: &'a RequestContext,
            oid: &'a Oid,
        ) -> BoxFuture<'a, GetNextResult> {
            Box::pin(async move {
                let first = oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0);
                let second = oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0);

                // The successor is the first fixture OID strictly after the request.
                if *oid < first {
                    GetNextResult::Value(VarBind::new(first, Value::Integer(42)))
                } else if *oid < second {
                    GetNextResult::Value(VarBind::new(
                        second,
                        Value::OctetString(Bytes::from_static(b"test")),
                    ))
                } else {
                    GetNextResult::EndOfMibView
                }
            })
        }
    }

    /// Builds a v2c `GetRequest` context from `127.0.0.1:12345` using the
    /// `public` community as security name.
    fn test_ctx() -> RequestContext {
        RequestContext {
            source: "127.0.0.1:12345".parse().unwrap(),
            version: Version::V2c,
            security_model: SecurityModel::V2c,
            security_name: Bytes::from_static(b"public"),
            security_level: SecurityLevel::NoAuthNoPriv,
            context_name: Bytes::new(),
            request_id: 1,
            pdu_type: PduType::GetRequest,
            group_name: None,
            read_view: None,
            write_view: None,
        }
    }

    /// Unwraps the varbind carried by a `GetNextResult`, failing the test if
    /// the result is not a value.
    fn expect_varbind(result: GetNextResult) -> VarBind {
        assert!(result.is_value());
        match result {
            GetNextResult::Value(vb) => vb,
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_agent_builder_defaults() {
        let builder = AgentBuilder::new();

        assert_eq!(builder.bind_addr, "0.0.0.0:161");
        assert!(builder.communities.is_empty());
        assert!(builder.usm_users.is_empty());
        assert!(builder.handlers.is_empty());
    }

    #[test]
    fn test_agent_builder_community() {
        let builder = AgentBuilder::new().community(b"public").community(b"private");
        assert_eq!(builder.communities.len(), 2);
    }

    #[test]
    fn test_agent_builder_communities() {
        let builder = AgentBuilder::new().communities(["public", "private"]);
        assert_eq!(builder.communities.len(), 2);
    }

    #[test]
    fn test_agent_builder_handler() {
        let prefix = oid!(1, 3, 6, 1, 4, 1, 99999);
        let builder = AgentBuilder::new().handler(prefix, Arc::new(TestHandler));
        assert_eq!(builder.handlers.len(), 1);
    }

    #[tokio::test]
    async fn test_mib_handler_default_set() {
        let handler = TestHandler;
        let ctx = RequestContext {
            pdu_type: PduType::SetRequest,
            ..test_ctx()
        };

        // TestHandler does not override test_set, so the trait default applies.
        let outcome = handler
            .test_set(&ctx, &oid!(1, 3, 6, 1), &Value::Integer(1))
            .await;
        assert_eq!(outcome, SetResult::NotWritable);
    }

    #[test]
    fn test_mib_handler_handles() {
        let handler = TestHandler;
        let prefix = oid!(1, 3, 6, 1, 4, 1, 99999);

        let cases = [
            // Inside the registered subtree.
            (oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0), true),
            // Before the subtree: still claimed so a GETNEXT walk can enter it.
            (oid!(1, 3, 6, 1, 4, 1, 99998), true),
            // Past the subtree: not handled.
            (oid!(1, 3, 6, 1, 4, 1, 100000), false),
        ];

        for (target, expected) in cases.iter() {
            assert_eq!(handler.handles(&prefix, target), *expected);
        }
    }

    #[tokio::test]
    async fn test_test_handler_get() {
        let handler = TestHandler;
        let ctx = test_ctx();

        let hit = handler.get(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0)).await;
        assert!(matches!(hit, GetResult::Value(Value::Integer(42))));

        let miss = handler.get(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 99, 0)).await;
        assert!(matches!(miss, GetResult::NoSuchObject));
    }

    #[tokio::test]
    async fn test_test_handler_get_next() {
        let handler = TestHandler;
        let ctx = RequestContext {
            pdu_type: PduType::GetNextRequest,
            ..test_ctx()
        };

        // Starting before the first scalar lands on it.
        let from_prefix = expect_varbind(handler.get_next(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999)).await);
        assert_eq!(from_prefix.oid, oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0));

        // Starting on the first scalar moves to the second.
        let from_first = expect_varbind(
            handler
                .get_next(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 1, 0))
                .await,
        );
        assert_eq!(from_first.oid, oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0));

        // Nothing follows the last scalar.
        let past_end = handler
            .get_next(&ctx, &oid!(1, 3, 6, 1, 4, 1, 99999, 2, 0))
            .await;
        assert!(past_end.is_end_of_mib_view());
    }
}