1use std::net::SocketAddr;
7use std::sync::RwLock;
8use std::time::Duration;
9
10use bytes::Bytes;
11use tokio::sync::Mutex as AsyncMutex;
12
13use crate::client::{Auth, Client, ClientConfig, CommunityVersion, Retry, UsmAuth};
14use crate::error::{Error, Result};
15use crate::message::CommunityMessage;
16use crate::notification::{DerivedKeys, UsmConfig};
17use crate::oid::Oid;
18use crate::pdu::Pdu;
19use crate::transport::{UdpHandle, UdpTransport};
20use crate::v3::compute_engine_boots_time;
21use crate::varbind::VarBind;
22use crate::version::Version;
23
24pub(crate) struct TrapSink {
29 pub(crate) dest: SocketAddr,
30 pub(crate) version: Version,
31 pub(crate) community: Bytes,
32 pub(crate) v3_security: Option<UsmConfig>,
33 pub(crate) derived_keys: RwLock<Option<DerivedKeys>>,
36 inform_timeout: Duration,
38 inform_retry: Retry,
39 inform_client: AsyncMutex<Option<(UdpTransport, Client<UdpHandle>)>>,
42}
43
44impl TrapSink {
45 pub(crate) fn new(
47 dest: SocketAddr,
48 auth: Auth,
49 inform_timeout: Duration,
50 inform_retry: Retry,
51 ) -> Self {
52 match auth {
53 Auth::Community { version, community } => {
54 let snmp_version = match version {
55 CommunityVersion::V1 => Version::V1,
56 CommunityVersion::V2c => Version::V2c,
57 };
58 TrapSink {
59 dest,
60 version: snmp_version,
61 community: Bytes::copy_from_slice(community.as_bytes()),
62 v3_security: None,
63 derived_keys: RwLock::new(None),
64 inform_timeout,
65 inform_retry,
66 inform_client: AsyncMutex::new(None),
67 }
68 }
69 Auth::Usm(usm) => {
70 let security = resolve_usm_config(&usm);
71 TrapSink {
72 dest,
73 version: Version::V3,
74 community: Bytes::new(),
75 v3_security: Some(security),
76 derived_keys: RwLock::new(None),
77 inform_timeout,
78 inform_retry,
79 inform_client: AsyncMutex::new(None),
80 }
81 }
82 }
83 }
84
85 fn ensure_keys_derived(&self, engine_id: &[u8]) -> Result<()> {
87 {
88 let keys = self.derived_keys.read().map_err(|_| {
89 Error::Config("trap sink derived_keys lock poisoned".into()).boxed()
90 })?;
91 if keys.is_some() {
92 return Ok(());
93 }
94 }
95
96 let security = self.v3_security.as_ref().ok_or_else(|| {
97 Error::Config("V3 security not configured for trap sink".into()).boxed()
98 })?;
99
100 let keys = security
101 .derive_keys(engine_id)
102 .map_err(|e| Error::Config(e.to_string().into()).boxed())?;
103
104 let mut derived = self
105 .derived_keys
106 .write()
107 .map_err(|_| Error::Config("trap sink derived_keys lock poisoned".into()).boxed())?;
108 *derived = Some(keys);
109
110 Ok(())
111 }
112
113 async fn get_or_create_inform_client(&self) -> Result<Client<UdpHandle>> {
115 let mut guard = self.inform_client.lock().await;
116 if let Some((_, ref client)) = *guard {
117 return Ok(client.clone());
118 }
119
120 let config = match self.version {
121 Version::V1 => unreachable!("v1 does not support informs"),
122 Version::V2c => ClientConfig {
123 version: Version::V2c,
124 community: self.community.clone(),
125 timeout: self.inform_timeout,
126 retry: self.inform_retry.clone(),
127 v3_security: None,
128 ..ClientConfig::default()
129 },
130 Version::V3 => ClientConfig {
131 version: Version::V3,
132 community: Bytes::new(),
133 timeout: self.inform_timeout,
134 retry: self.inform_retry.clone(),
135 v3_security: self.v3_security.clone(),
136 ..ClientConfig::default()
137 },
138 };
139
140 let bind_addr = if self.dest.is_ipv6() {
141 "[::]:0"
142 } else {
143 "0.0.0.0:0"
144 };
145 let transport = UdpTransport::bind(bind_addr).await?;
146 let handle = transport.handle(self.dest);
147 let client = Client::new(handle, config);
148 *guard = Some((transport, client.clone()));
149 Ok(client)
150 }
151}
152
153fn resolve_usm_config(usm: &UsmAuth) -> UsmConfig {
155 let mut security = UsmConfig::new(Bytes::copy_from_slice(usm.username.as_bytes()));
156 if let Some(context_name) = &usm.context_name {
157 security = security.context_name(Bytes::copy_from_slice(context_name.as_bytes()));
158 }
159
160 if let Some(ref master_keys) = usm.master_keys {
161 security = security.with_master_keys(master_keys.clone());
162 } else {
163 if let (Some(auth_proto), Some(auth_pass)) = (usm.auth_protocol, &usm.auth_password) {
164 security = security.auth(auth_proto, auth_pass.as_bytes());
165 }
166 if let (Some(priv_proto), Some(priv_pass)) = (usm.priv_protocol, &usm.priv_password) {
167 security = security.privacy(priv_proto, priv_pass.as_bytes());
168 }
169 }
170
171 security
172}
173
174impl super::Agent {
175 pub async fn send_trap(
202 &self,
203 trap_oid: &Oid,
204 uptime: u32,
205 varbinds: Vec<VarBind>,
206 ) -> Result<()> {
207 let sinks = &self.inner.trap_sinks;
208 if sinks.is_empty() {
209 return Ok(());
210 }
211
212 let request_id = self.next_notification_id();
213 let pdu = Pdu::trap_v2(request_id, uptime, trap_oid, varbinds);
214
215 for sink in sinks {
216 if let Err(e) = self.send_trap_to_sink(sink, &pdu).await {
217 tracing::warn!(target: "async_snmp::agent", { snmp.dest = %sink.dest, error = %e }, "failed to send trap");
218 }
219 }
220
221 Ok(())
222 }
223
224 pub async fn send_inform(
251 &self,
252 trap_oid: &Oid,
253 uptime: u32,
254 varbinds: Vec<VarBind>,
255 ) -> Result<()> {
256 let sinks = &self.inner.trap_sinks;
257 if sinks.is_empty() {
258 return Ok(());
259 }
260
261 for sink in sinks {
262 if sink.version == Version::V1 {
263 continue;
264 }
265
266 if let Err(e) = self
267 .send_inform_to_sink(sink, trap_oid, uptime, &varbinds)
268 .await
269 {
270 tracing::warn!(target: "async_snmp::agent", { snmp.dest = %sink.dest, error = %e }, "failed to send inform");
271 }
272 }
273
274 Ok(())
275 }
276
277 async fn send_trap_to_sink(&self, sink: &TrapSink, pdu: &Pdu) -> Result<()> {
279 let data = match sink.version {
280 Version::V1 => {
281 let local_ip = match self.inner.socket.local_addr() {
284 Ok(addr) => match addr.ip() {
285 std::net::IpAddr::V4(v4) => v4.octets(),
286 std::net::IpAddr::V6(_) => [0, 0, 0, 0],
287 },
288 Err(_) => [0, 0, 0, 0],
289 };
290 let trap = pdu.to_v1_trap(local_ip).ok_or_else(|| {
291 Error::Config("cannot convert trap to v1 for sink (Counter64 varbind?)".into())
292 .boxed()
293 })?;
294 let msg = CommunityMessage::v1_trap(sink.community.clone(), trap);
295 msg.encode()
296 }
297 Version::V2c => {
298 let msg = CommunityMessage::new(Version::V2c, sink.community.clone(), pdu.clone());
299 msg.encode()
300 }
301 Version::V3 => {
302 let security = sink.v3_security.as_ref().ok_or_else(|| {
303 Error::Config("V3 security not configured for trap sink".into()).boxed()
304 })?;
305
306 sink.ensure_keys_derived(&self.inner.state.engine_id)?;
307 let derived = sink.derived_keys.read().map_err(|_| {
308 Error::Config("trap sink derived_keys lock poisoned".into()).boxed()
309 })?;
310
311 let elapsed_secs = self.inner.state.engine_start.elapsed().as_secs();
312 let (engine_boots, engine_time) =
313 compute_engine_boots_time(self.inner.state.engine_boots_base, elapsed_secs);
314
315 let msg_id = self.next_notification_id();
316 let encoded = crate::v3::encode::encode_v3_message(
317 pdu,
318 msg_id,
319 &self.inner.state.engine_id,
320 engine_boots,
321 engine_time,
322 security,
323 derived.as_ref(),
324 &self.inner.salt_counter,
325 false, crate::v3::DEFAULT_MSG_MAX_SIZE,
327 )?;
328 Bytes::from(encoded)
329 }
330 };
331
332 tracing::debug!(target: "async_snmp::agent", { snmp.dest = %sink.dest, snmp.bytes = data.len() }, "sending trap");
333 self.inner
334 .socket
335 .send_to(&data, sink.dest)
336 .await
337 .map_err(|e| Error::Network {
338 target: sink.dest,
339 source: e,
340 })?;
341
342 Ok(())
343 }
344
345 async fn send_inform_to_sink(
347 &self,
348 sink: &TrapSink,
349 trap_oid: &Oid,
350 uptime: u32,
351 varbinds: &[VarBind],
352 ) -> Result<()> {
353 let client = sink.get_or_create_inform_client().await?;
354 client
355 .send_inform(trap_oid, uptime, varbinds.to_vec())
356 .await?;
357
358 Ok(())
359 }
360
361 fn next_notification_id(&self) -> i32 {
363 use std::sync::atomic::Ordering;
364 static COUNTER: std::sync::atomic::AtomicI32 = std::sync::atomic::AtomicI32::new(1);
365 COUNTER
366 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| {
367 Some(if v == i32::MAX { 1 } else { v + 1 })
368 })
369 .unwrap_or(1)
370 }
371}