1use std::net::{SocketAddr, ToSocketAddrs};
8use std::sync::Arc;
9use std::time::Duration;
10
11use bytes::Bytes;
12
13use crate::client::walk::{OidOrdering, WalkMode};
14use crate::client::{Auth, ClientConfig, CommunityVersion, V3SecurityConfig};
15use crate::error::Error;
16use crate::transport::{TcpTransport, Transport, UdpTransport};
17use crate::v3::EngineCache;
18use crate::version::Version;
19
20use super::Client;
21
22pub struct ClientBuilder {
48 target: String,
49 auth: Auth,
50 timeout: Duration,
51 retries: u32,
52 max_oids_per_request: usize,
53 max_repetitions: u32,
54 walk_mode: WalkMode,
55 oid_ordering: OidOrdering,
56 max_walk_results: Option<usize>,
57 engine_cache: Option<Arc<EngineCache>>,
58 context_engine_id: Option<Vec<u8>>,
60}
61
62impl ClientBuilder {
63 pub fn new(target: impl Into<String>, auth: impl Into<Auth>) -> Self {
86 Self {
87 target: target.into(),
88 auth: auth.into(),
89 timeout: Duration::from_secs(5),
90 retries: 3,
91 max_oids_per_request: 10,
92 max_repetitions: 25,
93 walk_mode: WalkMode::Auto,
94 oid_ordering: OidOrdering::Strict,
95 max_walk_results: None,
96 engine_cache: None,
97 context_engine_id: None,
98 }
99 }
100
101 pub fn timeout(mut self, timeout: Duration) -> Self {
103 self.timeout = timeout;
104 self
105 }
106
107 pub fn retries(mut self, retries: u32) -> Self {
109 self.retries = retries;
110 self
111 }
112
113 pub fn max_oids_per_request(mut self, max: usize) -> Self {
118 self.max_oids_per_request = max;
119 self
120 }
121
122 pub fn max_repetitions(mut self, max: u32) -> Self {
127 self.max_repetitions = max;
128 self
129 }
130
131 pub fn walk_mode(mut self, mode: WalkMode) -> Self {
137 self.walk_mode = mode;
138 self
139 }
140
141 pub fn oid_ordering(mut self, ordering: OidOrdering) -> Self {
149 self.oid_ordering = ordering;
150 self
151 }
152
153 pub fn max_walk_results(mut self, limit: usize) -> Self {
158 self.max_walk_results = Some(limit);
159 self
160 }
161
162 pub fn engine_cache(mut self, cache: Arc<EngineCache>) -> Self {
167 self.engine_cache = Some(cache);
168 self
169 }
170
171 pub fn context_engine_id(mut self, engine_id: impl Into<Vec<u8>>) -> Self {
181 self.context_engine_id = Some(engine_id.into());
182 self
183 }
184
185 fn validate(&self) -> Result<(), Error> {
187 if let Auth::Usm(usm) = &self.auth {
188 if usm.priv_protocol.is_some() && usm.auth_protocol.is_none() {
190 return Err(Error::Config("privacy requires authentication".into()));
191 }
192 if usm.auth_protocol.is_some() && usm.auth_password.is_none() {
194 return Err(Error::Config("auth protocol requires password".into()));
195 }
196 if usm.priv_protocol.is_some() && usm.priv_password.is_none() {
197 return Err(Error::Config("priv protocol requires password".into()));
198 }
199 }
200
201 if let Auth::Community {
203 version: CommunityVersion::V1,
204 ..
205 } = &self.auth
206 && self.walk_mode == WalkMode::GetBulk
207 {
208 return Err(Error::Config("GETBULK not supported in SNMPv1".into()));
209 }
210
211 Ok(())
212 }
213
214 fn resolve_target(&self) -> Result<SocketAddr, Error> {
216 self.target
217 .to_socket_addrs()
218 .map_err(|e| Error::Io {
219 target: None,
220 source: e,
221 })?
222 .next()
223 .ok_or_else(|| Error::Io {
224 target: None,
225 source: std::io::Error::new(
226 std::io::ErrorKind::NotFound,
227 "could not resolve address",
228 ),
229 })
230 }
231
232 fn build_config(&self) -> ClientConfig {
234 match &self.auth {
235 Auth::Community { version, community } => {
236 let snmp_version = match version {
237 CommunityVersion::V1 => Version::V1,
238 CommunityVersion::V2c => Version::V2c,
239 };
240 ClientConfig {
241 version: snmp_version,
242 community: Bytes::copy_from_slice(community.as_bytes()),
243 timeout: self.timeout,
244 retries: self.retries,
245 max_oids_per_request: self.max_oids_per_request,
246 v3_security: None,
247 walk_mode: self.walk_mode,
248 oid_ordering: self.oid_ordering,
249 max_walk_results: self.max_walk_results,
250 max_repetitions: self.max_repetitions,
251 }
252 }
253 Auth::Usm(usm) => {
254 let mut security =
255 V3SecurityConfig::new(Bytes::copy_from_slice(usm.username.as_bytes()));
256
257 if let Some(ref master_keys) = usm.master_keys {
259 security = security.with_master_keys(master_keys.clone());
260 } else {
261 if let (Some(auth_proto), Some(auth_pass)) =
262 (usm.auth_protocol, &usm.auth_password)
263 {
264 security = security.auth(auth_proto, auth_pass.as_bytes().to_vec());
265 }
266
267 if let (Some(priv_proto), Some(priv_pass)) =
268 (usm.priv_protocol, &usm.priv_password)
269 {
270 security = security.privacy(priv_proto, priv_pass.as_bytes().to_vec());
271 }
272 }
273
274 ClientConfig {
275 version: Version::V3,
276 community: Bytes::new(),
277 timeout: self.timeout,
278 retries: self.retries,
279 max_oids_per_request: self.max_oids_per_request,
280 v3_security: Some(security),
281 walk_mode: self.walk_mode,
282 oid_ordering: self.oid_ordering,
283 max_walk_results: self.max_walk_results,
284 max_repetitions: self.max_repetitions,
285 }
286 }
287 }
288 }
289
290 fn build_inner<T: Transport>(self, transport: T) -> Client<T> {
292 let config = self.build_config();
293
294 if let Some(cache) = self.engine_cache {
295 Client::with_engine_cache(transport, config, cache)
296 } else {
297 Client::new(transport, config)
298 }
299 }
300
301 pub async fn connect(self) -> Result<Client<UdpTransport>, Error> {
307 self.validate()?;
308 let addr = self.resolve_target()?;
309 let transport = UdpTransport::connect(addr).await?;
310 Ok(self.build_inner(transport))
311 }
312
313 pub async fn connect_tcp(self) -> Result<Client<TcpTransport>, Error> {
319 self.validate()?;
320 let addr = self.resolve_target()?;
321 let transport = TcpTransport::connect(addr).await?;
322 Ok(self.build_inner(transport))
323 }
324
325 pub fn build<T: Transport>(self, transport: T) -> Result<Client<T>, Error> {
331 self.validate()?;
332 Ok(self.build_inner(transport))
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::v3::{AuthProtocol, PrivProtocol};

    /// Target shared by every test; none of them resolve or connect to it.
    const TARGET: &str = "192.168.1.1:161";

    /// Build a `UsmAuth` for user `"user"` field by field, so each validation
    /// test can choose exactly which protocols and passwords are present.
    fn raw_usm(
        auth_protocol: Option<AuthProtocol>,
        auth_password: Option<&str>,
        priv_protocol: Option<PrivProtocol>,
        priv_password: Option<&str>,
    ) -> crate::client::UsmAuth {
        crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol,
            auth_password: auth_password.map(String::from),
            priv_protocol,
            priv_password: priv_password.map(String::from),
            context_name: None,
            master_keys: None,
        }
    }

    /// Assert that a builder for `auth` passes validation.
    fn assert_valid(auth: impl Into<Auth>) {
        let candidate = ClientBuilder::new(TARGET, auth);
        assert!(candidate.validate().is_ok());
    }

    /// Assert that a builder for `auth` fails validation with an
    /// `Error::Config` whose message contains `needle`.
    fn assert_config_error(auth: impl Into<Auth>, needle: &str) {
        let failure = ClientBuilder::new(TARGET, auth).validate().unwrap_err();
        assert!(matches!(failure, Error::Config(msg) if msg.contains(needle)));
    }

    #[test]
    fn test_builder_defaults() {
        let b = ClientBuilder::new(TARGET, Auth::default());

        assert_eq!(b.target, TARGET);
        assert_eq!(b.timeout, Duration::from_secs(5));
        assert_eq!(b.retries, 3);
        assert_eq!(b.max_oids_per_request, 10);
        assert_eq!(b.max_repetitions, 25);
        assert_eq!(b.walk_mode, WalkMode::Auto);
        assert_eq!(b.oid_ordering, OidOrdering::Strict);
        assert!(b.max_walk_results.is_none());
        assert!(b.engine_cache.is_none());
        assert!(b.context_engine_id.is_none());
    }

    #[test]
    fn test_builder_with_options() {
        let shared_cache = Arc::new(EngineCache::new());
        let b = ClientBuilder::new(TARGET, Auth::v2c("private"))
            .timeout(Duration::from_secs(10))
            .retries(5)
            .max_oids_per_request(20)
            .max_repetitions(50)
            .walk_mode(WalkMode::GetNext)
            .oid_ordering(OidOrdering::AllowNonIncreasing)
            .max_walk_results(1000)
            .engine_cache(Arc::clone(&shared_cache))
            .context_engine_id(vec![0x80, 0x00, 0x01]);

        assert_eq!(b.timeout, Duration::from_secs(10));
        assert_eq!(b.retries, 5);
        assert_eq!(b.max_oids_per_request, 20);
        assert_eq!(b.max_repetitions, 50);
        assert_eq!(b.walk_mode, WalkMode::GetNext);
        assert_eq!(b.oid_ordering, OidOrdering::AllowNonIncreasing);
        assert_eq!(b.max_walk_results, Some(1000));
        assert!(b.engine_cache.is_some());
        assert_eq!(b.context_engine_id, Some(vec![0x80, 0x00, 0x01]));
    }

    #[test]
    fn test_validate_community_ok() {
        assert_valid(Auth::v2c("public"));
    }

    #[test]
    fn test_validate_usm_no_auth_no_priv_ok() {
        assert_valid(Auth::usm("readonly"));
    }

    #[test]
    fn test_validate_usm_auth_no_priv_ok() {
        assert_valid(Auth::usm("admin").auth(AuthProtocol::Sha256, "authpass"));
    }

    #[test]
    fn test_validate_usm_auth_priv_ok() {
        assert_valid(
            Auth::usm("admin")
                .auth(AuthProtocol::Sha256, "authpass")
                .privacy(PrivProtocol::Aes128, "privpass"),
        );
    }

    #[test]
    fn test_validate_priv_without_auth_error() {
        let usm = raw_usm(None, None, Some(PrivProtocol::Aes128), Some("privpass"));
        assert_config_error(Auth::Usm(usm), "privacy requires authentication");
    }

    #[test]
    fn test_validate_auth_protocol_without_password_error() {
        let usm = raw_usm(Some(AuthProtocol::Sha256), None, None, None);
        assert_config_error(Auth::Usm(usm), "auth protocol requires password");
    }

    #[test]
    fn test_validate_priv_protocol_without_password_error() {
        let usm = raw_usm(
            Some(AuthProtocol::Sha256),
            Some("authpass"),
            Some(PrivProtocol::Aes128),
            None,
        );
        assert_config_error(Auth::Usm(usm), "priv protocol requires password");
    }

    #[test]
    fn test_builder_with_usm_builder() {
        assert_valid(Auth::usm("admin").auth(AuthProtocol::Sha256, "pass"));
    }
}