use std::fmt;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use crate::client::retry::Retry;
use crate::client::walk::{OidOrdering, WalkMode};
use crate::client::{
Auth, ClientConfig, CommunityVersion, DEFAULT_MAX_OIDS_PER_REQUEST, DEFAULT_MAX_REPETITIONS,
DEFAULT_TIMEOUT, UsmConfig,
};
use crate::error::{Error, Result};
use crate::transport::{TcpTransport, Transport, UdpHandle, UdpTransport};
use crate::v3::EngineCache;
use crate::version::Version;
use super::Client;
/// Where an SNMP client should send its requests.
///
/// Usually created through one of the `From` conversions (`&str`, `String`,
/// `&String`, `(&str, u16)`, `(String, u16)` or `SocketAddr`) rather than
/// constructed directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Target {
    /// A textual address such as `host`, `host:port`, a bare IPv6 literal,
    /// `[ipv6]` or `[ipv6]:port`. Port 161 is used when none is given.
    Address(String),
    /// A host (IP literal or DNS name) together with an explicit port.
    HostPort(String, u16),
}
impl fmt::Display for Target {
    /// Renders the target as an address string.
    ///
    /// `Address` values are written verbatim. `HostPort` values are written as
    /// `host:port`, except that a host containing `:` which is not already
    /// wrapped in `[...]` gets brackets added so the port stays unambiguous.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (host, port) = match self {
            Target::Address(addr) => return f.write_str(addr),
            Target::HostPort(host, port) => (host, port),
        };
        let already_bracketed = host.starts_with('[') && host.ends_with(']');
        if !already_bracketed && host.contains(':') {
            write!(f, "[{host}]:{port}")
        } else {
            write!(f, "{host}:{port}")
        }
    }
}
impl From<&str> for Target {
fn from(s: &str) -> Self {
Target::Address(s.to_string())
}
}
impl From<String> for Target {
fn from(s: String) -> Self {
Target::Address(s)
}
}
impl From<&String> for Target {
fn from(s: &String) -> Self {
Target::Address(s.clone())
}
}
impl From<(&str, u16)> for Target {
fn from((host, port): (&str, u16)) -> Self {
Target::HostPort(host.to_string(), port)
}
}
impl From<(String, u16)> for Target {
fn from((host, port): (String, u16)) -> Self {
Target::HostPort(host, port)
}
}
impl From<SocketAddr> for Target {
fn from(addr: SocketAddr) -> Self {
Target::HostPort(addr.ip().to_string(), addr.port())
}
}
/// Builder for an SNMP [`Client`].
///
/// Create one with [`ClientBuilder::new`] and adjust it with the chained setter
/// methods. Finish with [`connect`](ClientBuilder::connect) (new UDP socket),
/// [`build_with`](ClientBuilder::build_with) (existing UDP transport) or
/// [`connect_tcp`](ClientBuilder::connect_tcp). The setters never fail;
/// configuration is validated by those finishing methods.
#[derive(Debug)]
#[must_use = "a ClientBuilder does nothing until `connect`, `build_with` or `connect_tcp` is called"]
pub struct ClientBuilder {
    /// Agent address, resolved when the client is built.
    target: Target,
    /// Community (v1/v2c) or USM (v3) credentials.
    auth: Auth,
    /// Copied into the client configuration; also bounds DNS resolution.
    timeout: Duration,
    /// Retry policy copied into the client configuration.
    retry: Retry,
    /// Maximum OIDs per request; must be non-zero.
    max_oids_per_request: usize,
    /// GETBULK max-repetitions value copied into the client configuration.
    max_repetitions: u32,
    /// Walk mode; `GetBulk` is rejected for SNMPv1.
    walk_mode: WalkMode,
    /// OID ordering policy applied to walks.
    oid_ordering: OidOrdering,
    /// Optional walk-result cap; required with `OidOrdering::AllowNonIncreasing`.
    max_walk_results: Option<usize>,
    /// Shared engine cache; when set the client is built via `Client::with_engine_cache`.
    engine_cache: Option<Arc<EngineCache>>,
    /// Local SNMPv3 engine ID, if configured.
    local_engine_id: Option<Vec<u8>>,
    /// Local SNMPv3 engine boot count.
    local_engine_boots: u32,
}
impl ClientBuilder {
    /// Creates a builder for `target`, authenticated with `auth`.
    ///
    /// All other settings start at their defaults:
    /// - `DEFAULT_TIMEOUT` and `Retry::default()`.
    /// - `DEFAULT_MAX_OIDS_PER_REQUEST` and `DEFAULT_MAX_REPETITIONS`.
    /// - `WalkMode::Auto` and `OidOrdering::Strict`.
    /// - No walk-result limit and no shared engine cache.
    /// - No local engine ID and a local engine boot count of 1.
    pub fn new(target: impl Into<Target>, auth: impl Into<Auth>) -> Self {
        Self {
            target: target.into(),
            auth: auth.into(),
            timeout: DEFAULT_TIMEOUT,
            retry: Retry::default(),
            max_oids_per_request: DEFAULT_MAX_OIDS_PER_REQUEST,
            max_repetitions: DEFAULT_MAX_REPETITIONS,
            walk_mode: WalkMode::Auto,
            oid_ordering: OidOrdering::Strict,
            max_walk_results: None,
            engine_cache: None,
            local_engine_id: None,
            local_engine_boots: 1,
        }
    }

    /// Sets the timeout stored in the client configuration.
    ///
    /// The same duration also bounds DNS resolution of the target when the
    /// client is built.
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the retry policy stored in the client configuration.
    pub fn retry(mut self, retry: impl Into<Retry>) -> Self {
        self.retry = retry.into();
        self
    }

    /// Sets the maximum number of OIDs placed in a single request.
    ///
    /// `0` is rejected with a configuration error when the client is built.
    pub fn max_oids_per_request(mut self, max: usize) -> Self {
        self.max_oids_per_request = max;
        self
    }

    /// Sets the GETBULK max-repetitions value stored in the client configuration.
    pub fn max_repetitions(mut self, max: u32) -> Self {
        self.max_repetitions = max;
        self
    }

    /// Selects the walk mode.
    ///
    /// `WalkMode::GetBulk` combined with SNMPv1 community auth is rejected when
    /// the client is built.
    pub fn walk_mode(mut self, mode: WalkMode) -> Self {
        self.walk_mode = mode;
        self
    }

    /// Selects the OID ordering policy for walks.
    ///
    /// `OidOrdering::AllowNonIncreasing` is only accepted when
    /// [`max_walk_results`](Self::max_walk_results) is also set.
    pub fn oid_ordering(mut self, ordering: OidOrdering) -> Self {
        self.oid_ordering = ordering;
        self
    }

    /// Limits how many results a walk may collect.
    pub fn max_walk_results(mut self, limit: usize) -> Self {
        self.max_walk_results = Some(limit);
        self
    }

    /// Sets the local SNMPv3 engine ID copied into the client configuration.
    pub fn local_engine_id(mut self, engine_id: impl Into<Vec<u8>>) -> Self {
        self.local_engine_id = Some(engine_id.into());
        self
    }

    /// Sets the local SNMPv3 engine boot count (1 unless overridden).
    pub fn local_engine_boots(mut self, boots: u32) -> Self {
        self.local_engine_boots = boots;
        self
    }

    /// Uses a shared `EngineCache`.
    ///
    /// The client is then created with `Client::with_engine_cache` instead of
    /// `Client::new`.
    pub fn engine_cache(mut self, cache: Arc<EngineCache>) -> Self {
        self.engine_cache = Some(cache);
        self
    }

    /// Checks the builder for configuration errors before any I/O happens.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when:
    /// - `max_oids_per_request` is 0.
    /// - USM auth requests privacy without authentication.
    /// - USM auth names an auth or priv protocol without a matching password
    ///   and without master keys.
    /// - SNMPv1 community auth is combined with `WalkMode::GetBulk`.
    /// - `OidOrdering::AllowNonIncreasing` is used without `max_walk_results`.
    fn validate(&self) -> Result<()> {
        if self.max_oids_per_request == 0 {
            return Err(
                Error::Config("max_oids_per_request must be greater than 0".into()).boxed(),
            );
        }
        if let Auth::Usm(usm) = &self.auth {
            if usm.priv_protocol.is_some() && usm.auth_protocol.is_none() {
                return Err(Error::Config("privacy requires authentication".into()).boxed());
            }
            // Pre-derived master keys replace the passwords entirely.
            if usm.auth_protocol.is_some()
                && usm.auth_password.is_none()
                && usm.master_keys.is_none()
            {
                return Err(Error::Config("auth protocol requires password".into()).boxed());
            }
            if usm.priv_protocol.is_some()
                && usm.priv_password.is_none()
                && usm.master_keys.is_none()
            {
                return Err(Error::Config("priv protocol requires password".into()).boxed());
            }
        }
        if let Auth::Community {
            version: CommunityVersion::V1,
            ..
        } = &self.auth
            && self.walk_mode == WalkMode::GetBulk
        {
            return Err(Error::Config("GETBULK not supported in SNMPv1".into()).boxed());
        }
        if self.oid_ordering == OidOrdering::AllowNonIncreasing && self.max_walk_results.is_none() {
            return Err(Error::Config(
                "AllowNonIncreasing requires max_walk_results to bound memory usage".into(),
            )
            .boxed());
        }
        Ok(())
    }

    /// Resolves the configured target to a single socket address.
    ///
    /// IP literals are used directly. Anything else goes through
    /// `tokio::net::lookup_host`, bounded by the configured timeout, and the
    /// first address returned is used. A `HostPort` host written as `[v6]`
    /// (which `Display` already accepts) is unbracketed first.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when the lookup times out, fails, or yields no
    /// addresses.
    async fn resolve_target(&self) -> Result<SocketAddr> {
        let (host, port) = match &self.target {
            Target::Address(addr) => split_host_port(addr),
            Target::HostPort(host, port) => {
                // Neither `IpAddr::from_str` nor DNS accept the URL-style
                // brackets, so strip exactly one surrounding pair.
                let host = host
                    .strip_prefix('[')
                    .and_then(|inner| inner.strip_suffix(']'))
                    .unwrap_or(host.as_str());
                (host, *port)
            }
        };
        if let Ok(ip) = host.parse::<std::net::IpAddr>() {
            return Ok(SocketAddr::new(ip, port));
        }
        let lookup = tokio::net::lookup_host((host, port));
        let mut addrs = tokio::time::timeout(self.timeout, lookup)
            .await
            .map_err(|_| {
                Error::Config(format!("DNS lookup timed out for '{}'", self.target).into()).boxed()
            })?
            .map_err(|e| {
                Error::Config(format!("could not resolve address '{}': {}", self.target, e).into())
                    .boxed()
            })?;
        addrs.next().ok_or_else(|| {
            Error::Config(format!("could not resolve address '{}'", self.target).into()).boxed()
        })
    }

    /// Translates the builder settings into a `ClientConfig`.
    ///
    /// Community auth yields a v1/v2c config carrying the community string and
    /// no v3 security. USM auth yields a v3 config with an empty community and
    /// a `UsmConfig`. That `UsmConfig` uses the master keys when present, and
    /// otherwise the auth/priv protocol and password pairs.
    fn build_config(&self) -> ClientConfig {
        let (version, community, v3_security) = match &self.auth {
            Auth::Community { version, community } => {
                let version = match version {
                    CommunityVersion::V1 => Version::V1,
                    CommunityVersion::V2c => Version::V2c,
                };
                (version, Bytes::copy_from_slice(community.as_bytes()), None)
            }
            Auth::Usm(usm) => {
                let mut security = UsmConfig::new(Bytes::copy_from_slice(usm.username.as_bytes()));
                if let Some(context_name) = &usm.context_name {
                    security =
                        security.context_name(Bytes::copy_from_slice(context_name.as_bytes()));
                }
                if let Some(master_keys) = &usm.master_keys {
                    security = security.with_master_keys(master_keys.clone());
                } else {
                    if let (Some(auth_proto), Some(auth_pass)) =
                        (usm.auth_protocol, &usm.auth_password)
                    {
                        security = security.auth(auth_proto, auth_pass.as_bytes());
                    }
                    if let (Some(priv_proto), Some(priv_pass)) =
                        (usm.priv_protocol, &usm.priv_password)
                    {
                        security = security.privacy(priv_proto, priv_pass.as_bytes());
                    }
                }
                (Version::V3, Bytes::new(), Some(security))
            }
        };
        ClientConfig {
            version,
            community,
            timeout: self.timeout,
            retry: self.retry.clone(),
            max_oids_per_request: self.max_oids_per_request,
            v3_security,
            walk_mode: self.walk_mode,
            oid_ordering: self.oid_ordering,
            max_walk_results: self.max_walk_results,
            max_repetitions: self.max_repetitions,
            local_engine_id: self.local_engine_id.as_deref().map(Bytes::copy_from_slice),
            local_engine_boots: self.local_engine_boots,
        }
    }

    /// Builds the client over `transport`, attaching the shared engine cache
    /// when one was configured.
    fn build_inner<T: Transport>(self, transport: T) -> Client<T> {
        let config = self.build_config();
        match self.engine_cache {
            Some(cache) => Client::with_engine_cache(transport, config, cache),
            None => Client::new(transport, config),
        }
    }

    /// Validates the configuration, resolves the target and creates a client
    /// over a freshly bound UDP socket.
    ///
    /// The socket binds an ephemeral port on the unspecified address of the
    /// target's family (`[::]:0` for IPv6, `0.0.0.0:0` for IPv4).
    ///
    /// # Errors
    ///
    /// Returns validation and resolution errors, or any error from binding the
    /// UDP socket.
    pub async fn connect(self) -> Result<Client<UdpHandle>> {
        self.validate()?;
        let addr = self.resolve_target().await?;
        let bind_addr = if addr.is_ipv6() {
            "[::]:0"
        } else {
            "0.0.0.0:0"
        };
        let transport = UdpTransport::bind(bind_addr).await?;
        let handle = transport.handle(addr);
        Ok(self.build_inner(handle))
    }

    /// Validates the configuration, resolves the target and creates a client
    /// that uses a handle on an existing `UdpTransport`.
    ///
    /// # Errors
    ///
    /// Returns validation and resolution errors.
    pub async fn build_with(self, transport: &UdpTransport) -> Result<Client<UdpHandle>> {
        self.validate()?;
        let addr = self.resolve_target().await?;
        let handle = transport.handle(addr);
        Ok(self.build_inner(handle))
    }

    /// Validates the configuration, resolves the target and creates a client
    /// over a new TCP connection.
    ///
    /// # Errors
    ///
    /// Returns validation and resolution errors, or any error from
    /// `TcpTransport::connect`.
    pub async fn connect_tcp(self) -> Result<Client<TcpTransport>> {
        self.validate()?;
        let addr = self.resolve_target().await?;
        let transport = TcpTransport::connect(addr).await?;
        Ok(self.build_inner(transport))
    }
}
/// Port assumed when an address string does not name one.
const DEFAULT_PORT: u16 = 161;

/// Splits an address string into `(host, port)`, defaulting the port to
/// `DEFAULT_PORT`.
///
/// The string is handled according to its form:
/// - `[v6]:port` and `[v6]` — the brackets are removed from the host.
/// - `host:port` — split only when the host part has no other `:`, so bare
///   IPv6 literals such as `fe80::1` are never mistaken for `host:port`.
/// - A trailing segment that is not a valid `u16` is not treated as a port;
///   the whole string is then returned as the host.
fn split_host_port(target: &str) -> (&str, u16) {
    match target.strip_prefix('[') {
        Some(inner) => inner
            .rsplit_once("]:")
            .and_then(|(addr, port)| port.parse::<u16>().ok().map(|p| (addr, p)))
            .unwrap_or_else(|| (inner.trim_end_matches(']'), DEFAULT_PORT)),
        None => target
            .rsplit_once(':')
            .filter(|(host, _)| !host.contains(':'))
            .and_then(|(host, port)| port.parse::<u16>().ok().map(|p| (host, p)))
            .unwrap_or((target, DEFAULT_PORT)),
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for target conversion/formatting, address parsing, builder
    // settings, validation rules and config construction. None of them
    // perform network I/O beyond resolving IP literals.
    use super::*;
    use crate::v3::{AuthProtocol, MasterKeys, PrivProtocol};

    #[test]
    fn test_builder_defaults() {
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::default());
        assert!(matches!(builder.target, Target::Address(ref s) if s == "192.168.1.1:161"));
        assert_eq!(builder.timeout, DEFAULT_TIMEOUT);
        assert_eq!(builder.retry.max_attempts, 3);
        assert_eq!(builder.max_oids_per_request, DEFAULT_MAX_OIDS_PER_REQUEST);
        assert_eq!(builder.max_repetitions, DEFAULT_MAX_REPETITIONS);
        assert_eq!(builder.walk_mode, WalkMode::Auto);
        assert_eq!(builder.oid_ordering, OidOrdering::Strict);
        assert!(builder.max_walk_results.is_none());
        assert!(builder.engine_cache.is_none());
    }

    #[test]
    fn test_builder_with_options() {
        let cache = Arc::new(EngineCache::new());
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::v2c("private"))
            .timeout(Duration::from_secs(10))
            .retry(Retry::fixed(5, Duration::ZERO))
            .max_oids_per_request(20)
            .max_repetitions(50)
            .walk_mode(WalkMode::GetNext)
            .oid_ordering(OidOrdering::AllowNonIncreasing)
            .max_walk_results(1000)
            .engine_cache(cache.clone());
        assert_eq!(builder.timeout, Duration::from_secs(10));
        assert_eq!(builder.retry.max_attempts, 5);
        assert_eq!(builder.max_oids_per_request, 20);
        assert_eq!(builder.max_repetitions, 50);
        assert_eq!(builder.walk_mode, WalkMode::GetNext);
        assert_eq!(builder.oid_ordering, OidOrdering::AllowNonIncreasing);
        assert_eq!(builder.max_walk_results, Some(1000));
        assert!(builder.engine_cache.is_some());
    }

    #[test]
    fn test_validate_community_ok() {
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::v2c("public"));
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_zero_max_oids_per_request_error() {
        let builder =
            ClientBuilder::new("192.168.1.1:161", Auth::v2c("public")).max_oids_per_request(0);
        let err = builder.validate().unwrap_err();
        assert!(matches!(
            *err,
            Error::Config(ref msg) if msg.contains("max_oids_per_request must be greater than 0")
        ));
    }

    #[test]
    fn test_validate_v1_getbulk_error() {
        let auth = Auth::Community {
            version: CommunityVersion::V1,
            community: "public".into(),
        };
        let builder = ClientBuilder::new("192.168.1.1:161", auth).walk_mode(WalkMode::GetBulk);
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(*err, Error::Config(ref msg) if msg.contains("GETBULK not supported in SNMPv1"))
        );
    }

    #[test]
    fn test_validate_v1_getnext_ok() {
        let auth = Auth::Community {
            version: CommunityVersion::V1,
            community: "public".into(),
        };
        let builder = ClientBuilder::new("192.168.1.1:161", auth).walk_mode(WalkMode::GetNext);
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_allow_non_increasing_requires_limit() {
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::v2c("public"))
            .oid_ordering(OidOrdering::AllowNonIncreasing);
        let err = builder.validate().unwrap_err();
        assert!(matches!(
            *err,
            Error::Config(ref msg) if msg.contains("AllowNonIncreasing requires max_walk_results")
        ));
        let bounded = ClientBuilder::new("192.168.1.1:161", Auth::v2c("public"))
            .oid_ordering(OidOrdering::AllowNonIncreasing)
            .max_walk_results(10);
        assert!(bounded.validate().is_ok());
    }

    #[test]
    fn test_validate_usm_no_auth_no_priv_ok() {
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::usm("readonly"));
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_usm_auth_no_priv_ok() {
        let builder = ClientBuilder::new(
            "192.168.1.1:161",
            Auth::usm("admin").auth(AuthProtocol::Sha256, "authpass"),
        );
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_usm_auth_priv_ok() {
        let builder = ClientBuilder::new(
            "192.168.1.1:161",
            Auth::usm("admin")
                .auth(AuthProtocol::Sha256, "authpass")
                .privacy(PrivProtocol::Aes128, "privpass"),
        );
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_priv_without_auth_error() {
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: None,
            auth_password: None,
            priv_protocol: Some(PrivProtocol::Aes128),
            priv_password: Some("privpass".to_string()),
            context_name: None,
            master_keys: None,
        };
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::Usm(usm));
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(*err, Error::Config(ref msg) if msg.contains("privacy requires authentication"))
        );
    }

    #[test]
    fn test_validate_auth_protocol_without_password_error() {
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: Some(AuthProtocol::Sha256),
            auth_password: None,
            priv_protocol: None,
            priv_password: None,
            context_name: None,
            master_keys: None,
        };
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::Usm(usm));
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(*err, Error::Config(ref msg) if msg.contains("auth protocol requires password"))
        );
    }

    #[test]
    fn test_validate_priv_protocol_without_password_error() {
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: Some(AuthProtocol::Sha256),
            auth_password: Some("authpass".to_string()),
            priv_protocol: Some(PrivProtocol::Aes128),
            priv_password: None,
            context_name: None,
            master_keys: None,
        };
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::Usm(usm));
        let err = builder.validate().unwrap_err();
        assert!(
            matches!(*err, Error::Config(ref msg) if msg.contains("priv protocol requires password"))
        );
    }

    #[test]
    fn test_builder_with_usm_builder() {
        let builder = ClientBuilder::new(
            "192.168.1.1:161",
            Auth::usm("admin").auth(AuthProtocol::Sha256, "pass"),
        );
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_master_keys_bypass_auth_password() {
        let master_keys = MasterKeys::new(AuthProtocol::Sha256, b"authpass").unwrap();
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: Some(AuthProtocol::Sha256),
            auth_password: None,
            priv_protocol: None,
            priv_password: None,
            context_name: None,
            master_keys: Some(master_keys),
        };
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::Usm(usm));
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_validate_master_keys_bypass_priv_password() {
        let master_keys = MasterKeys::new(AuthProtocol::Sha256, b"authpass")
            .unwrap()
            .with_privacy(PrivProtocol::Aes128, b"privpass")
            .unwrap();
        let usm = crate::client::UsmAuth {
            username: "user".to_string(),
            auth_protocol: Some(AuthProtocol::Sha256),
            auth_password: None,
            priv_protocol: Some(PrivProtocol::Aes128),
            priv_password: None,
            context_name: None,
            master_keys: Some(master_keys),
        };
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::Usm(usm));
        assert!(builder.validate().is_ok());
    }

    #[test]
    fn test_build_config_preserves_v3_context_name() {
        let builder = ClientBuilder::new(
            "192.168.1.1:161",
            Auth::usm("admin")
                .auth(AuthProtocol::Sha256, "authpass")
                .context_name("vlan100"),
        );
        let config = builder.build_config();
        let security = config
            .v3_security
            .expect("expected v3 security config to be built");
        assert_eq!(security.context_name.as_ref(), b"vlan100");
    }

    #[test]
    fn test_build_config_community_v2c() {
        let builder = ClientBuilder::new("192.168.1.1:161", Auth::v2c("public"))
            .local_engine_id(vec![0x80, 0x00, 0x01]);
        let config = builder.build_config();
        assert!(matches!(config.version, Version::V2c));
        assert_eq!(config.community.as_ref(), b"public");
        assert!(config.v3_security.is_none());
        assert_eq!(
            config.local_engine_id.as_deref(),
            Some(&[0x80u8, 0x00, 0x01][..])
        );
    }

    #[test]
    fn test_builder_with_host_port_tuple() {
        let builder = ClientBuilder::new(("fe80::1", 161), Auth::default());
        assert!(matches!(
            builder.target,
            Target::HostPort(ref h, 161) if h == "fe80::1"
        ));
    }

    #[test]
    fn test_builder_with_string_host_port_tuple() {
        let builder = ClientBuilder::new(("switch.local".to_string(), 162), Auth::v2c("public"));
        assert!(matches!(
            builder.target,
            Target::HostPort(ref h, 162) if h == "switch.local"
        ));
    }

    #[test]
    fn test_target_from_str() {
        let t: Target = "192.168.1.1:161".into();
        assert!(matches!(t, Target::Address(ref s) if s == "192.168.1.1:161"));
    }

    #[test]
    fn test_target_from_string_ref() {
        let owned = String::from("switch.local:162");
        let t: Target = (&owned).into();
        assert!(matches!(t, Target::Address(ref s) if s == "switch.local:162"));
    }

    #[test]
    fn test_target_from_tuple() {
        let t: Target = ("fe80::1", 161).into();
        assert!(matches!(t, Target::HostPort(ref h, 161) if h == "fe80::1"));
    }

    #[test]
    fn test_target_from_socket_addr() {
        let addr: SocketAddr = "192.168.1.1:162".parse().unwrap();
        let t: Target = addr.into();
        assert!(matches!(t, Target::HostPort(ref h, 162) if h == "192.168.1.1"));
    }

    #[test]
    fn test_target_display() {
        let t: Target = "192.168.1.1:161".into();
        assert_eq!(t.to_string(), "192.168.1.1:161");
        let t: Target = ("fe80::1", 161).into();
        assert_eq!(t.to_string(), "[fe80::1]:161");
        let addr: SocketAddr = "[::1]:162".parse().unwrap();
        let t: Target = addr.into();
        assert_eq!(t.to_string(), "[::1]:162");
    }

    #[test]
    fn test_target_display_bracketed_host_port() {
        // An already-bracketed IPv6 host must not be bracketed twice.
        let t: Target = ("[::1]", 161).into();
        assert_eq!(t.to_string(), "[::1]:161");
    }

    #[tokio::test]
    async fn test_resolve_target_socket_addr() {
        let addr: SocketAddr = "10.0.0.1:162".parse().unwrap();
        let builder = ClientBuilder::new(addr, Auth::default());
        let resolved = builder.resolve_target().await.unwrap();
        assert_eq!(resolved, addr);
    }

    #[tokio::test]
    async fn test_resolve_target_host_port_ipv4() {
        let builder = ClientBuilder::new(("192.168.1.1", 162), Auth::default());
        let addr = builder.resolve_target().await.unwrap();
        assert_eq!(addr, "192.168.1.1:162".parse().unwrap());
    }

    #[tokio::test]
    async fn test_resolve_target_host_port_ipv6() {
        let builder = ClientBuilder::new(("::1", 161), Auth::default());
        let addr = builder.resolve_target().await.unwrap();
        assert_eq!(addr, "[::1]:161".parse().unwrap());
    }

    #[tokio::test]
    async fn test_resolve_target_string_still_works() {
        let builder = ClientBuilder::new("10.0.0.1:162", Auth::default());
        let addr = builder.resolve_target().await.unwrap();
        assert_eq!(addr, "10.0.0.1:162".parse().unwrap());
    }

    #[test]
    fn test_split_host_port_ipv4_with_port() {
        assert_eq!(split_host_port("192.168.1.1:162"), ("192.168.1.1", 162));
    }

    #[test]
    fn test_split_host_port_ipv4_default() {
        assert_eq!(split_host_port("192.168.1.1"), ("192.168.1.1", 161));
    }

    #[test]
    fn test_split_host_port_ipv6_bare() {
        assert_eq!(split_host_port("fe80::1"), ("fe80::1", 161));
    }

    #[test]
    fn test_split_host_port_ipv6_loopback() {
        assert_eq!(split_host_port("::1"), ("::1", 161));
    }

    #[test]
    fn test_split_host_port_ipv6_bracketed_with_port() {
        assert_eq!(split_host_port("[fe80::1]:162"), ("fe80::1", 162));
    }

    #[test]
    fn test_split_host_port_ipv6_bracketed_default() {
        assert_eq!(split_host_port("[::1]"), ("::1", 161));
    }

    #[test]
    fn test_split_host_port_hostname() {
        assert_eq!(split_host_port("switch.local"), ("switch.local", 161));
    }

    #[test]
    fn test_split_host_port_hostname_with_port() {
        assert_eq!(split_host_port("switch.local:162"), ("switch.local", 162));
    }
}