async_snmp/client/
builder.rs

1//! New unified client builder.
2//!
3//! This module provides the [`ClientBuilder`] type, a single entry point for
4//! constructing SNMP clients with any authentication mode (v1/v2c community
5//! or v3 USM).
6
7use std::net::{SocketAddr, ToSocketAddrs};
8use std::sync::Arc;
9use std::time::Duration;
10
11use bytes::Bytes;
12
13use crate::client::walk::{OidOrdering, WalkMode};
14use crate::client::{Auth, ClientConfig, CommunityVersion, V3SecurityConfig};
15use crate::error::Error;
16use crate::transport::{TcpTransport, Transport, UdpTransport};
17use crate::v3::EngineCache;
18use crate::version::Version;
19
20use super::Client;
21
22/// Builder for constructing SNMP clients.
23///
24/// This is the single entry point for client construction. It supports all
25/// SNMP versions (v1, v2c, v3) through the [`Auth`] enum.
26///
27/// # Example
28///
29/// ```rust,no_run
30/// use async_snmp::{Auth, ClientBuilder};
31/// use std::time::Duration;
32///
33/// # async fn example() -> async_snmp::Result<()> {
34/// // Simple v2c client
35/// let client = ClientBuilder::new("192.168.1.1:161", Auth::v2c("public"))
36///     .connect().await?;
37///
38/// // v3 client with authentication
39/// let client = ClientBuilder::new("192.168.1.1:161",
40///     Auth::usm("admin").auth(async_snmp::AuthProtocol::Sha256, "password"))
41///     .timeout(Duration::from_secs(10))
42///     .retries(5)
43///     .connect().await?;
44/// # Ok(())
45/// # }
46/// ```
47pub struct ClientBuilder {
48    target: String,
49    auth: Auth,
50    timeout: Duration,
51    retries: u32,
52    max_oids_per_request: usize,
53    max_repetitions: u32,
54    walk_mode: WalkMode,
55    oid_ordering: OidOrdering,
56    max_walk_results: Option<usize>,
57    engine_cache: Option<Arc<EngineCache>>,
58    /// Override context engine ID (V3 only, for proxy/routing scenarios).
59    context_engine_id: Option<Vec<u8>>,
60}
61
62impl ClientBuilder {
63    /// Create a new client builder.
64    ///
65    /// # Arguments
66    ///
67    /// * `target` - The target address (e.g., "192.168.1.1:161")
68    /// * `auth` - Authentication configuration (community or USM)
69    ///
70    /// # Example
71    ///
72    /// ```rust,no_run
73    /// use async_snmp::{Auth, ClientBuilder};
74    ///
75    /// // Using Auth::default() for v2c with "public" community
76    /// let builder = ClientBuilder::new("192.168.1.1:161", Auth::default());
77    ///
78    /// // Using Auth::v1() for SNMPv1
79    /// let builder = ClientBuilder::new("192.168.1.1:161", Auth::v1("private"));
80    ///
81    /// // Using Auth::usm() for SNMPv3
82    /// let builder = ClientBuilder::new("192.168.1.1:161",
83    ///     Auth::usm("admin").auth(async_snmp::AuthProtocol::Sha256, "password"));
84    /// ```
85    pub fn new(target: impl Into<String>, auth: impl Into<Auth>) -> Self {
86        Self {
87            target: target.into(),
88            auth: auth.into(),
89            timeout: Duration::from_secs(5),
90            retries: 3,
91            max_oids_per_request: 10,
92            max_repetitions: 25,
93            walk_mode: WalkMode::Auto,
94            oid_ordering: OidOrdering::Strict,
95            max_walk_results: None,
96            engine_cache: None,
97            context_engine_id: None,
98        }
99    }
100
101    /// Set the request timeout (default: 5 seconds).
102    pub fn timeout(mut self, timeout: Duration) -> Self {
103        self.timeout = timeout;
104        self
105    }
106
107    /// Set the number of retries (default: 3).
108    pub fn retries(mut self, retries: u32) -> Self {
109        self.retries = retries;
110        self
111    }
112
113    /// Set the maximum OIDs per request (default: 10).
114    ///
115    /// Requests with more OIDs than this limit are automatically split
116    /// into multiple batches.
117    pub fn max_oids_per_request(mut self, max: usize) -> Self {
118        self.max_oids_per_request = max;
119        self
120    }
121
122    /// Set max-repetitions for GETBULK operations (default: 25).
123    ///
124    /// Controls how many rows are requested per GETBULK PDU during walk
125    /// operations. Higher values reduce round-trips but increase response size.
126    pub fn max_repetitions(mut self, max: u32) -> Self {
127        self.max_repetitions = max;
128        self
129    }
130
131    /// Override walk behavior for devices with buggy GETBULK (default: Auto).
132    ///
133    /// - `WalkMode::Auto`: Use GETNEXT for v1, GETBULK for v2c/v3
134    /// - `WalkMode::GetNext`: Always use GETNEXT (slower but more compatible)
135    /// - `WalkMode::GetBulk`: Always use GETBULK (faster, errors on v1)
136    pub fn walk_mode(mut self, mode: WalkMode) -> Self {
137        self.walk_mode = mode;
138        self
139    }
140
141    /// Set OID ordering behavior for walk operations (default: Strict).
142    ///
143    /// - `OidOrdering::Strict`: Require strictly increasing OIDs. Most efficient.
144    /// - `OidOrdering::AllowNonIncreasing`: Allow non-increasing OIDs with cycle
145    ///   detection. Uses O(n) memory to track seen OIDs.
146    ///
147    /// Use `AllowNonIncreasing` for buggy agents that return OIDs out of order.
148    pub fn oid_ordering(mut self, ordering: OidOrdering) -> Self {
149        self.oid_ordering = ordering;
150        self
151    }
152
153    /// Set maximum results from a single walk operation (default: unlimited).
154    ///
155    /// Safety limit to prevent runaway walks. Walk terminates normally when
156    /// limit is reached.
157    pub fn max_walk_results(mut self, limit: usize) -> Self {
158        self.max_walk_results = Some(limit);
159        self
160    }
161
162    /// Set shared engine cache (V3 only, for high-throughput polling).
163    ///
164    /// Allows multiple clients to share discovered engine state, reducing
165    /// the number of discovery requests.
166    pub fn engine_cache(mut self, cache: Arc<EngineCache>) -> Self {
167        self.engine_cache = Some(cache);
168        self
169    }
170
171    /// Override the context engine ID (V3 only).
172    ///
173    /// By default, the context engine ID is the same as the authoritative
174    /// engine ID discovered during engine discovery. Use this to override for:
175    /// - Proxy scenarios where requests route through an intermediate agent
176    /// - Devices that require a specific context engine ID
177    /// - Pre-configured engine IDs from device documentation
178    ///
179    /// The engine ID should be provided as raw bytes (not hex-encoded).
180    pub fn context_engine_id(mut self, engine_id: impl Into<Vec<u8>>) -> Self {
181        self.context_engine_id = Some(engine_id.into());
182        self
183    }
184
185    /// Validate the configuration.
186    fn validate(&self) -> Result<(), Error> {
187        if let Auth::Usm(usm) = &self.auth {
188            // Privacy requires authentication
189            if usm.priv_protocol.is_some() && usm.auth_protocol.is_none() {
190                return Err(Error::Config("privacy requires authentication".into()));
191            }
192            // Protocol requires password
193            if usm.auth_protocol.is_some() && usm.auth_password.is_none() {
194                return Err(Error::Config("auth protocol requires password".into()));
195            }
196            if usm.priv_protocol.is_some() && usm.priv_password.is_none() {
197                return Err(Error::Config("priv protocol requires password".into()));
198            }
199        }
200
201        // Validate walk mode for v1
202        if let Auth::Community {
203            version: CommunityVersion::V1,
204            ..
205        } = &self.auth
206            && self.walk_mode == WalkMode::GetBulk
207        {
208            return Err(Error::Config("GETBULK not supported in SNMPv1".into()));
209        }
210
211        Ok(())
212    }
213
214    /// Resolve target address to SocketAddr.
215    fn resolve_target(&self) -> Result<SocketAddr, Error> {
216        self.target
217            .to_socket_addrs()
218            .map_err(|e| Error::Io {
219                target: None,
220                source: e,
221            })?
222            .next()
223            .ok_or_else(|| Error::Io {
224                target: None,
225                source: std::io::Error::new(
226                    std::io::ErrorKind::NotFound,
227                    "could not resolve address",
228                ),
229            })
230    }
231
232    /// Build ClientConfig from the builder settings.
233    fn build_config(&self) -> ClientConfig {
234        match &self.auth {
235            Auth::Community { version, community } => {
236                let snmp_version = match version {
237                    CommunityVersion::V1 => Version::V1,
238                    CommunityVersion::V2c => Version::V2c,
239                };
240                ClientConfig {
241                    version: snmp_version,
242                    community: Bytes::copy_from_slice(community.as_bytes()),
243                    timeout: self.timeout,
244                    retries: self.retries,
245                    max_oids_per_request: self.max_oids_per_request,
246                    v3_security: None,
247                    walk_mode: self.walk_mode,
248                    oid_ordering: self.oid_ordering,
249                    max_walk_results: self.max_walk_results,
250                    max_repetitions: self.max_repetitions,
251                }
252            }
253            Auth::Usm(usm) => {
254                let mut security =
255                    V3SecurityConfig::new(Bytes::copy_from_slice(usm.username.as_bytes()));
256
257                // Prefer master_keys over passwords if available
258                if let Some(ref master_keys) = usm.master_keys {
259                    security = security.with_master_keys(master_keys.clone());
260                } else {
261                    if let (Some(auth_proto), Some(auth_pass)) =
262                        (usm.auth_protocol, &usm.auth_password)
263                    {
264                        security = security.auth(auth_proto, auth_pass.as_bytes().to_vec());
265                    }
266
267                    if let (Some(priv_proto), Some(priv_pass)) =
268                        (usm.priv_protocol, &usm.priv_password)
269                    {
270                        security = security.privacy(priv_proto, priv_pass.as_bytes().to_vec());
271                    }
272                }
273
274                ClientConfig {
275                    version: Version::V3,
276                    community: Bytes::new(),
277                    timeout: self.timeout,
278                    retries: self.retries,
279                    max_oids_per_request: self.max_oids_per_request,
280                    v3_security: Some(security),
281                    walk_mode: self.walk_mode,
282                    oid_ordering: self.oid_ordering,
283                    max_walk_results: self.max_walk_results,
284                    max_repetitions: self.max_repetitions,
285                }
286            }
287        }
288    }
289
290    /// Build the client with the given transport.
291    fn build_inner<T: Transport>(self, transport: T) -> Client<T> {
292        let config = self.build_config();
293
294        if let Some(cache) = self.engine_cache {
295            Client::with_engine_cache(transport, config, cache)
296        } else {
297            Client::new(transport, config)
298        }
299    }
300
301    /// Connect via UDP (default).
302    ///
303    /// # Errors
304    ///
305    /// Returns an error if the configuration is invalid or the connection fails.
306    pub async fn connect(self) -> Result<Client<UdpTransport>, Error> {
307        self.validate()?;
308        let addr = self.resolve_target()?;
309        let transport = UdpTransport::connect(addr).await?;
310        Ok(self.build_inner(transport))
311    }
312
313    /// Connect via TCP.
314    ///
315    /// # Errors
316    ///
317    /// Returns an error if the configuration is invalid or the connection fails.
318    pub async fn connect_tcp(self) -> Result<Client<TcpTransport>, Error> {
319        self.validate()?;
320        let addr = self.resolve_target()?;
321        let transport = TcpTransport::connect(addr).await?;
322        Ok(self.build_inner(transport))
323    }
324
325    /// Build with custom transport.
326    ///
327    /// # Errors
328    ///
329    /// Returns an error if the configuration is invalid.
330    pub fn build<T: Transport>(self, transport: T) -> Result<Client<T>, Error> {
331        self.validate()?;
332        Ok(self.build_inner(transport))
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v3::{AuthProtocol, PrivProtocol};

    /// Numeric target used by tests that never touch the network.
    const TARGET: &str = "192.168.1.1:161";

    #[test]
    fn test_builder_defaults() {
        let builder = ClientBuilder::new(TARGET, Auth::default());
        assert_eq!(builder.target, TARGET);
        assert_eq!(builder.timeout, Duration::from_secs(5));
        assert_eq!(builder.retries, 3);
        assert_eq!(builder.max_oids_per_request, 10);
        assert_eq!(builder.max_repetitions, 25);
        assert_eq!(builder.walk_mode, WalkMode::Auto);
        assert_eq!(builder.oid_ordering, OidOrdering::Strict);
        assert!(builder.max_walk_results.is_none());
        assert!(builder.engine_cache.is_none());
        assert!(builder.context_engine_id.is_none());
    }

    #[test]
    fn test_builder_with_options() {
        let cache = Arc::new(EngineCache::new());
        let builder = ClientBuilder::new(TARGET, Auth::v2c("private"))
            .timeout(Duration::from_secs(10))
            .retries(5)
            .max_oids_per_request(20)
            .max_repetitions(50)
            .walk_mode(WalkMode::GetNext)
            .oid_ordering(OidOrdering::AllowNonIncreasing)
            .max_walk_results(1000)
            .engine_cache(Arc::clone(&cache))
            .context_engine_id(vec![0x80, 0x00, 0x01]);

        assert_eq!(builder.timeout, Duration::from_secs(10));
        assert_eq!(builder.retries, 5);
        assert_eq!(builder.max_oids_per_request, 20);
        assert_eq!(builder.max_repetitions, 50);
        assert_eq!(builder.walk_mode, WalkMode::GetNext);
        assert_eq!(builder.oid_ordering, OidOrdering::AllowNonIncreasing);
        assert_eq!(builder.max_walk_results, Some(1000));
        assert!(builder.engine_cache.is_some());
        assert_eq!(builder.context_engine_id, Some(vec![0x80, 0x00, 0x01]));
    }

    #[test]
    fn test_validate_community_ok() {
        let builder = ClientBuilder::new(TARGET, Auth::v2c("public"));
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_v1_getbulk_error() {
        let builder = ClientBuilder::new(TARGET, Auth::v1("public")).walk_mode(WalkMode::GetBulk);
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(err, Error::Config(msg) if msg.contains("GETBULK not supported in SNMPv1"))
        );
    }

    #[test]
    fn test_validate_v1_getnext_ok() {
        let builder = ClientBuilder::new(TARGET, Auth::v1("public")).walk_mode(WalkMode::GetNext);
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_v2c_getbulk_ok() {
        let builder = ClientBuilder::new(TARGET, Auth::v2c("public")).walk_mode(WalkMode::GetBulk);
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_usm_no_auth_no_priv_ok() {
        let builder = ClientBuilder::new(TARGET, Auth::usm("readonly"));
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_usm_auth_no_priv_ok() {
        let builder = ClientBuilder::new(
            TARGET,
            Auth::usm("admin").auth(AuthProtocol::Sha256, "authpass"),
        );
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_usm_auth_priv_ok() {
        let builder = ClientBuilder::new(
            TARGET,
            Auth::usm("admin")
                .auth(AuthProtocol::Sha256, "authpass")
                .privacy(PrivProtocol::Aes128, "privpass"),
        );
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_priv_without_auth_error() {
        // Manually construct UsmAuth with priv but no auth
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: None,
            auth_password: None,
            priv_protocol: Some(PrivProtocol::Aes128),
            priv_password: Some("privpass".to_string()),
            context_name: None,
            master_keys: None,
        };
        let builder = ClientBuilder::new(TARGET, Auth::Usm(usm));
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(err, Error::Config(msg) if msg.contains("privacy requires authentication"))
        );
    }

    #[test]
    fn test_validate_auth_protocol_without_password_error() {
        // Manually construct UsmAuth with auth protocol but no password
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: Some(AuthProtocol::Sha256),
            auth_password: None,
            priv_protocol: None,
            priv_password: None,
            context_name: None,
            master_keys: None,
        };
        let builder = ClientBuilder::new(TARGET, Auth::Usm(usm));
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(err, Error::Config(msg) if msg.contains("auth protocol requires password"))
        );
    }

    #[test]
    fn test_validate_priv_protocol_without_password_error() {
        // Manually construct UsmAuth with priv protocol but no password
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: Some(AuthProtocol::Sha256),
            auth_password: Some("authpass".to_string()),
            priv_protocol: Some(PrivProtocol::Aes128),
            priv_password: None,
            context_name: None,
            master_keys: None,
        };
        let builder = ClientBuilder::new(TARGET, Auth::Usm(usm));
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(err, Error::Config(msg) if msg.contains("priv protocol requires password"))
        );
    }

    #[test]
    fn test_builder_with_usm_builder() {
        // Test that UsmBuilder can be passed directly (via Into<Auth>)
        let builder = ClientBuilder::new(TARGET, Auth::usm("admin").auth(AuthProtocol::Sha256, "pass"));
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_build_config_community_v1() {
        let config = ClientBuilder::new(TARGET, Auth::v1("private"))
            .timeout(Duration::from_secs(7))
            .retries(2)
            .max_repetitions(40)
            .build_config();
        assert!(matches!(config.version, Version::V1));
        assert_eq!(&config.community[..], b"private");
        assert!(config.v3_security.is_none());
        assert_eq!(config.timeout, Duration::from_secs(7));
        assert_eq!(config.retries, 2);
        assert_eq!(config.max_repetitions, 40);
    }

    #[test]
    fn test_build_config_community_v2c() {
        let config = ClientBuilder::new(TARGET, Auth::v2c("public")).build_config();
        assert!(matches!(config.version, Version::V2c));
        assert_eq!(&config.community[..], b"public");
        assert!(config.v3_security.is_none());
    }

    #[test]
    fn test_build_config_usm() {
        let config = ClientBuilder::new(
            TARGET,
            Auth::usm("admin").auth(AuthProtocol::Sha256, "authpass"),
        )
        .build_config();
        assert!(matches!(config.version, Version::V3));
        assert!(config.community.is_empty());
        assert!(config.v3_security.is_some());
    }

    #[test]
    fn test_resolve_target_ip_literal() {
        // Numeric addresses are parsed directly; no DNS lookup is performed.
        let builder = ClientBuilder::new("127.0.0.1:161", Auth::default());
        let expected: SocketAddr = "127.0.0.1:161".parse().expect("valid literal");
        assert!(matches!(builder.resolve_target(), Ok(addr) if addr == expected));
    }

    #[test]
    fn test_resolve_target_missing_port_error() {
        // Without a `:port` suffix std rejects the string before any lookup.
        let builder = ClientBuilder::new("no-port-here", Auth::default());
        assert!(matches!(builder.resolve_target(), Err(Error::Io { .. })));
    }
}