1use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
10
11use knx_rs_core::cemi::CemiFrame;
12use knx_rs_core::knxip::{KnxIpFrame, ServiceType};
13use tokio::net::UdpSocket;
14use tokio::sync::mpsc;
15use tokio::time::{Duration, Instant};
16
17use crate::error::KnxIpError;
18use crate::{KnxConnection, KnxFuture};
19
20pub const KNX_MULTICAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 23, 12);
22
23pub const KNX_PORT: u16 = 3671;
25
26const MAX_PACKETS_PER_SEC: u32 = 50;
28
29pub struct RouterConnection {
31 rx: mpsc::Receiver<CemiFrame>,
32 tx_cmd: mpsc::Sender<RouterCmd>,
33}
34
35enum RouterCmd {
36 Send(
37 CemiFrame,
38 tokio::sync::oneshot::Sender<Result<(), KnxIpError>>,
39 ),
40 Close,
41}
42
43impl RouterConnection {
44 pub async fn connect(
53 local_addr: Ipv4Addr,
54 multicast: SocketAddrV4,
55 ) -> Result<Self, KnxIpError> {
56 Self::connect_v4(local_addr, multicast).await
57 }
58
59 pub async fn connect_v4(
65 local_addr: Ipv4Addr,
66 multicast: SocketAddrV4,
67 ) -> Result<Self, KnxIpError> {
68 if !multicast.ip().is_multicast() {
69 return Err(KnxIpError::Protocol(format!(
70 "router target is not multicast: {multicast}"
71 )));
72 }
73 let bind_addr = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, multicast.port());
74 let socket = UdpSocket::bind(bind_addr).await?;
75
76 socket
77 .join_multicast_v4(*multicast.ip(), local_addr)
78 .map_err(|e| KnxIpError::Protocol(format!("join multicast {}: {e}", multicast.ip())))?;
79
80 socket.set_multicast_loop_v4(false).ok();
81 Ok(Self::spawn(socket, SocketAddr::V4(multicast)))
82 }
83
84 pub async fn connect_v6(interface: u32, multicast: SocketAddrV6) -> Result<Self, KnxIpError> {
93 if !multicast.ip().is_multicast() {
94 return Err(KnxIpError::Protocol(format!(
95 "router target is not multicast: {multicast}"
96 )));
97 }
98 let interface = if interface == 0 {
99 multicast.scope_id()
100 } else {
101 interface
102 };
103 let bind_addr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, multicast.port(), 0, interface);
104 let socket = UdpSocket::bind(bind_addr).await?;
105
106 socket
107 .join_multicast_v6(multicast.ip(), interface)
108 .map_err(|e| KnxIpError::Protocol(format!("join multicast {}: {e}", multicast.ip())))?;
109
110 socket.set_multicast_loop_v6(false).ok();
111 Ok(Self::spawn(socket, SocketAddr::V6(multicast)))
112 }
113
114 pub async fn connect_multicast(multicast: SocketAddr) -> Result<Self, KnxIpError> {
123 match multicast {
124 SocketAddr::V4(v4) => Self::connect_v4(Ipv4Addr::UNSPECIFIED, v4).await,
125 SocketAddr::V6(v6) => Self::connect_v6(v6.scope_id(), v6).await,
126 }
127 }
128
129 pub async fn connect_default(local_addr: Ipv4Addr) -> Result<Self, KnxIpError> {
135 Self::connect(local_addr, SocketAddrV4::new(KNX_MULTICAST_ADDR, KNX_PORT)).await
136 }
137
138 fn spawn(socket: UdpSocket, target: SocketAddr) -> Self {
139 tracing::info!(%target, "KNXnet/IP router joined multicast");
140
141 let (cemi_tx, cemi_rx) = mpsc::channel(64);
142 let (cmd_tx, cmd_rx) = mpsc::channel(16);
143
144 tokio::spawn(router_task(socket, target, cemi_tx, cmd_rx));
145
146 Self {
147 rx: cemi_rx,
148 tx_cmd: cmd_tx,
149 }
150 }
151}
152
153impl KnxConnection for RouterConnection {
154 fn send(&self, frame: CemiFrame) -> KnxFuture<'_, Result<(), KnxIpError>> {
155 let tx_cmd = self.tx_cmd.clone();
156 Box::pin(async move {
157 let (tx, rx) = tokio::sync::oneshot::channel();
158 tx_cmd
159 .send(RouterCmd::Send(frame, tx))
160 .await
161 .map_err(|_| KnxIpError::Closed)?;
162 rx.await.map_err(|_| KnxIpError::Closed)?
163 })
164 }
165
166 fn recv(&mut self) -> KnxFuture<'_, Option<CemiFrame>> {
167 Box::pin(async move { self.rx.recv().await })
168 }
169
170 fn close(&mut self) -> KnxFuture<'_, ()> {
171 let tx_cmd = self.tx_cmd.clone();
172 Box::pin(async move {
173 let _ = tx_cmd.send(RouterCmd::Close).await;
174 })
175 }
176}
177
178struct RateLimiter {
182 timestamps: std::collections::VecDeque<Instant>,
183 max_per_sec: u32,
184}
185
186impl RateLimiter {
187 fn new(max_per_sec: u32) -> Self {
188 Self {
189 timestamps: std::collections::VecDeque::with_capacity(max_per_sec as usize),
190 max_per_sec,
191 }
192 }
193
194 fn check(&mut self) -> Option<Duration> {
196 let now = Instant::now();
197 let window_start = now - Duration::from_secs(1);
198
199 while self.timestamps.front().is_some_and(|&t| t < window_start) {
201 self.timestamps.pop_front();
202 }
203
204 if self.timestamps.len() < self.max_per_sec as usize {
205 self.timestamps.push_back(now);
206 None } else {
208 self.timestamps
210 .front()
211 .map(|&oldest| (oldest + Duration::from_secs(1)) - now)
212 }
213 }
214
215 fn pause(&mut self, duration: Duration) {
217 let future = Instant::now() + duration;
219 self.timestamps.clear();
220 for _ in 0..self.max_per_sec {
221 self.timestamps.push_back(future);
222 }
223 }
224}
225
226async fn router_task(
229 socket: UdpSocket,
230 target: SocketAddr,
231 cemi_tx: mpsc::Sender<CemiFrame>,
232 mut cmd_rx: mpsc::Receiver<RouterCmd>,
233) {
234 let mut buf = [0u8; 1024];
235 let mut rate_limiter = RateLimiter::new(MAX_PACKETS_PER_SEC);
236
237 loop {
238 tokio::select! {
239 result = socket.recv_from(&mut buf) => {
240 let (n, _src) = match result {
241 Ok(r) => r,
242 Err(e) => {
243 tracing::warn!(error = %e, "router recv error");
244 break;
245 }
246 };
247 handle_routing_indication(&buf[..n], &cemi_tx, &mut rate_limiter).await;
248 }
249
250 cmd = cmd_rx.recv() => {
251 match cmd {
252 Some(RouterCmd::Send(cemi, reply)) => {
253 let result = rate_limited_send(
254 &socket, &target, &cemi, &mut rate_limiter,
255 ).await;
256 let _ = reply.send(result);
257 }
258 Some(RouterCmd::Close) | None => break,
259 }
260 }
261 }
262 }
263
264 tracing::debug!("router task ended");
265}
266
267async fn rate_limited_send(
268 socket: &UdpSocket,
269 target: &SocketAddr,
270 cemi: &CemiFrame,
271 limiter: &mut RateLimiter,
272) -> Result<(), KnxIpError> {
273 if let Some(wait) = limiter.check() {
275 tracing::debug!(wait_ms = wait.as_millis(), "rate limit: waiting");
276 tokio::time::sleep(wait).await;
277 if let Some(extra_wait) = limiter.check() {
279 tokio::time::sleep(extra_wait).await;
280 let _ = limiter.check(); }
282 }
283
284 let frame = KnxIpFrame {
285 service_type: ServiceType::RoutingIndication,
286 body: cemi.as_bytes().to_vec(),
287 };
288 let bytes = frame
289 .try_to_bytes()
290 .map_err(|e| KnxIpError::Protocol(e.to_string()))?;
291 socket.send_to(&bytes, target).await?;
292 Ok(())
293}
294
295async fn handle_routing_indication(
296 data: &[u8],
297 cemi_tx: &mpsc::Sender<CemiFrame>,
298 rate_limiter: &mut RateLimiter,
299) {
300 let frame = match KnxIpFrame::parse(data) {
301 Ok(f) => f,
302 Err(e) => {
303 tracing::trace!(error = %e, "ignoring malformed frame");
304 return;
305 }
306 };
307
308 match frame.service_type {
309 ServiceType::RoutingIndication => {
310 if let Ok(cemi) = CemiFrame::parse(&frame.body) {
311 let _ = cemi_tx.send(cemi).await;
312 }
313 }
314 ServiceType::RoutingBusy => {
315 let wait_ms = if frame.body.len() >= 6 {
317 u16::from_be_bytes([frame.body[4], frame.body[5]])
318 } else {
319 50 };
321 tracing::debug!(wait_ms, "received RoutingBusy, pausing sends");
322 rate_limiter.pause(Duration::from_millis(u64::from(wait_ms)));
324 }
325 _ => {}
326 }
327}
328
329#[cfg(test)]
330#[allow(clippy::unwrap_used, clippy::expect_used)]
331mod tests {
332 use super::*;
333
334 #[test]
335 fn rate_limiter_allows_within_limit() {
336 let mut limiter = RateLimiter::new(3);
337 assert!(limiter.check().is_none());
338 assert!(limiter.check().is_none());
339 assert!(limiter.check().is_none());
340 assert!(limiter.check().is_some());
342 }
343}