use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use knx_rs_core::knxip::{HostProtocol, Hpai, KnxIpFrame, ServiceType};
use tokio::net::UdpSocket;
use tokio::time::{Duration, timeout};
use crate::error::KnxIpError;
use crate::router::{KNX_MULTICAST_ADDR, KNX_PORT};
/// How long [`discover`] and [`discover_v6`] wait for search responses.
const DISCOVERY_TIMEOUT: Duration = Duration::from_millis(3_000);
/// A KNXnet/IP gateway that answered a discovery `SEARCH_REQUEST`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GatewayInfo {
    /// Control endpoint advertised in the response HPAI; when that HPAI is unspecified
    /// (NAT mode) the datagram's source address is used instead.
    pub address: SocketAddr,
    /// Friendly name from the device-information DIB; empty when it could not be extracted.
    pub name: String,
    /// KNX individual address from the device-information DIB as a raw 16-bit value;
    /// 0 when the response carried no device-information DIB.
    pub individual_address: u16,
    /// Complete `SEARCH_RESPONSE` body (HPAI followed by the DIBs), kept so callers can
    /// decode fields this module does not extract.
    pub raw_body: Vec<u8>,
}
/// Discovers KNXnet/IP gateways over IPv4 multicast using the default
/// [`DISCOVERY_TIMEOUT`] (3 seconds).
///
/// See [`discover_with_timeout`] for the meaning of `local_addr` and the error cases.
pub async fn discover(local_addr: Ipv4Addr) -> Result<Vec<GatewayInfo>, KnxIpError> {
    let wait = DISCOVERY_TIMEOUT;
    discover_with_timeout(local_addr, wait).await
}
/// Discovers KNXnet/IP gateways over IPv4 multicast, collecting responses for `duration`.
///
/// A `SEARCH_REQUEST` is sent to `KNX_MULTICAST_ADDR:KNX_PORT`, and every valid
/// `SEARCH_RESPONSE` received before the deadline is returned.
///
/// `local_addr` selects the local IPv4 interface:
/// - the UDP socket is bound to it, so the request is sent from that interface;
/// - it is advertised in the request's HPAI as the endpoint gateways should reply to.
///
/// Pass `Ipv4Addr::UNSPECIFIED` to bind to all interfaces and let the OS choose.
///
/// # Errors
///
/// Returns an error if the socket cannot be bound (for example when `local_addr` is not
/// assigned to this host), if the request cannot be encoded, or if it cannot be sent.
pub async fn discover_with_timeout(
    local_addr: Ipv4Addr,
    duration: Duration,
) -> Result<Vec<GatewayInfo>, KnxIpError> {
    // Binding to the chosen interface address rather than 0.0.0.0 lets the kernel send the
    // multicast request from that interface on multi-homed hosts. Linux, for instance, picks
    // the multicast egress device from the bound source address when IP_MULTICAST_IF is
    // unset. It also keeps the socket's address consistent with the HPAI advertised below.
    let socket = UdpSocket::bind(SocketAddrV4::new(local_addr, 0)).await?;
    let local_port = socket.local_addr()?.port();
    // `Ipv4Addr::UNSPECIFIED.octets()` is already [0, 0, 0, 0], so no special case is needed.
    let hpai = Hpai {
        protocol: HostProtocol::Ipv4Udp,
        ip: local_addr.octets(),
        port: local_port,
    };
    let target = SocketAddr::V4(SocketAddrV4::new(KNX_MULTICAST_ADDR, KNX_PORT));
    discover_on(socket, hpai, target, duration).await
}
/// Discovers KNXnet/IP gateways over IPv6 multicast using the default
/// [`DISCOVERY_TIMEOUT`] (3 seconds).
///
/// See [`discover_v6_with_timeout`] for the meaning of `interface` and `multicast`
/// and for the error cases.
pub async fn discover_v6(
    interface: u32,
    multicast: SocketAddrV6,
) -> Result<Vec<GatewayInfo>, KnxIpError> {
    let wait = DISCOVERY_TIMEOUT;
    discover_v6_with_timeout(interface, multicast, wait).await
}
/// Discovers KNXnet/IP gateways over IPv6 multicast, collecting responses for `duration`.
///
/// `interface` is an interface index used as the scope id for both the local socket and
/// the destination. `0` means "use the scope id already carried by `multicast`".
/// The request carries `Hpai::nat_udp(local_port)`. The HPAI holds only an IPv4 address,
/// so gateways are presumably expected to reply to the datagram's source.
///
/// # Errors
///
/// Returns `KnxIpError::Protocol` when `multicast` is not a multicast address. Binding,
/// encoding and send failures are propagated as well.
pub async fn discover_v6_with_timeout(
    interface: u32,
    multicast: SocketAddrV6,
    duration: Duration,
) -> Result<Vec<GatewayInfo>, KnxIpError> {
    if !multicast.ip().is_multicast() {
        return Err(KnxIpError::Protocol(format!(
            "discovery target is not multicast: {multicast}"
        )));
    }

    // An explicit interface index wins over whatever scope the caller put on `multicast`.
    let scope_id = match interface {
        0 => multicast.scope_id(),
        explicit => explicit,
    };

    let bind_addr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, scope_id);
    let socket = UdpSocket::bind(bind_addr).await?;
    let local_port = socket.local_addr()?.port();

    // Same address, port and flow info as requested; only the scope id is overridden.
    let mut destination = multicast;
    destination.set_scope_id(scope_id);

    discover_on(
        socket,
        Hpai::nat_udp(local_port),
        SocketAddr::V6(destination),
        duration,
    )
    .await
}
async fn discover_on(
socket: UdpSocket,
hpai: Hpai,
target: SocketAddr,
duration: Duration,
) -> Result<Vec<GatewayInfo>, KnxIpError> {
let frame = KnxIpFrame {
service_type: ServiceType::SearchRequest,
body: hpai.to_bytes().to_vec(),
};
let bytes = frame
.try_to_bytes()
.map_err(|e| KnxIpError::Protocol(e.to_string()))?;
socket.send_to(&bytes, target).await?;
tracing::debug!("discovery search request sent");
let mut gateways = Vec::new();
let mut buf = [0u8; 512];
let deadline = tokio::time::Instant::now() + duration;
loop {
let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
if remaining.is_zero() {
break;
}
match timeout(remaining, socket.recv_from(&mut buf)).await {
Ok(Ok((n, src))) => {
if let Some(info) = parse_search_response(&buf[..n], src) {
tracing::debug!(name = %info.name, addr = %info.address, "discovered gateway");
gateways.push(info);
}
}
Ok(Err(e)) => {
tracing::trace!(error = %e, "discovery recv error");
}
Err(_) => break, }
}
Ok(gateways)
}
/// Length in bytes of the KNXnet/IP device-information DIB that follows the
/// control-endpoint HPAI in a `SEARCH_RESPONSE` body.
const DEVICE_INFO_DIB_LEN: usize = 54;
/// Description type code identifying a device-information DIB.
const DEVICE_INFO_DIB_TYPE: u8 = 0x01;
/// Offset of the 30-byte friendly-name field inside the device-information DIB.
const DEVICE_INFO_NAME_OFFSET: usize = 24;

/// Decodes a KNXnet/IP `SEARCH_RESPONSE` datagram received from `src` into a [`GatewayInfo`].
///
/// Returns `None` when `data` is not a parsable KNXnet/IP frame, is not a
/// `SEARCH_RESPONSE`, or its body does not start with a parsable HPAI.
/// A missing or malformed device-information DIB is tolerated. The gateway is still
/// reported, with an empty name and individual address 0.
fn parse_search_response(data: &[u8], src: SocketAddr) -> Option<GatewayInfo> {
    let frame = KnxIpFrame::parse(data).ok()?;
    if frame.service_type != ServiceType::SearchResponse {
        return None;
    }
    let hpai = Hpai::parse(&frame.body)?;
    // An unspecified HPAI (NAT mode) means "reply to wherever this datagram came from".
    let address = if hpai.is_unspecified() {
        socket_addr_with_port(src, hpai.port)
    } else {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::from(hpai.ip), hpai.port))
    };
    let (name, individual_address) = frame
        .body
        .get(usize::from(Hpai::LEN)..)
        .and_then(parse_device_info_dib)
        .unwrap_or_default();
    Some(GatewayInfo {
        address,
        name,
        individual_address,
        raw_body: frame.body,
    })
}

/// Extracts `(friendly_name, individual_address)` from the device-information DIB at
/// the start of `dib`.
///
/// Layout of the device-information DIB (KNXnet/IP Core):
/// - `[0]` structure length
/// - `[1]` type code
/// - `[2]` KNX medium
/// - `[3]` device status
/// - `[4..6]` individual address
/// - `[6..8]` project/installation id
/// - `[8..14]` serial number
/// - `[14..18]` routing multicast address
/// - `[18..24]` MAC address
/// - `[24..54]` friendly name
///
/// Returns `None` if fewer than 54 bytes are available or the type code is not 0x01.
fn parse_device_info_dib(dib: &[u8]) -> Option<(String, u16)> {
    let dib = dib.get(..DEVICE_INFO_DIB_LEN)?;
    if dib[1] != DEVICE_INFO_DIB_TYPE {
        return None;
    }
    let individual_address = u16::from_be_bytes([dib[4], dib[5]]);
    let name = decode_friendly_name(&dib[DEVICE_INFO_NAME_OFFSET..]);
    Some((name, individual_address))
}

/// Decodes the NUL-padded friendly-name field, stopping at the first NUL.
///
/// The KNX specification defines the field as ISO 8859-1. Bytes that are valid UTF-8 are
/// taken as-is. Otherwise each byte is mapped to the Unicode code point of the same value,
/// which is exactly the Latin-1 mapping.
fn decode_friendly_name(raw: &[u8]) -> String {
    let end = raw.iter().position(|&b| b == 0).unwrap_or(raw.len());
    let raw = &raw[..end];
    match core::str::from_utf8(raw) {
        Ok(text) => text.to_owned(),
        Err(_) => raw.iter().copied().map(char::from).collect(),
    }
}
/// Returns `src` with its port replaced by `port`, or unchanged when `port` is 0.
///
/// For IPv6 sources the flow info and scope id are carried over from `src`.
const fn socket_addr_with_port(src: SocketAddr, port: u16) -> SocketAddr {
    let effective = match port {
        0 => src.port(),
        explicit => explicit,
    };
    match src {
        SocketAddr::V6(v6) => SocketAddr::V6(SocketAddrV6::new(
            *v6.ip(),
            effective,
            v6.flowinfo(),
            v6.scope_id(),
        )),
        SocketAddr::V4(v4) => SocketAddr::V4(SocketAddrV4::new(*v4.ip(), effective)),
    }
}
#[cfg(test)]
// Unwrapping on fixed, known-good literals is acceptable in tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A header-only `SEARCH_RESPONSE` (service 0x0202, total length 6) has no HPAI,
    /// so no gateway must be reported.
    #[test]
    fn parse_search_response_too_short() {
        let header_only = [0x06, 0x10, 0x02, 0x02, 0x00, 0x06];
        let src: SocketAddr = "0.0.0.0:0".parse().unwrap();
        let parsed = parse_search_response(&header_only, src);
        assert!(parsed.is_none());
    }
}