use std::{
error::Error,
fmt::Display,
io::Write,
net::{IpAddr, Ipv4Addr},
str::Utf8Error,
time::Duration,
};
use log::{debug, trace};
use tokio::net::UdpSocket;
/// An AirTouch 5 console located by [`discover`].
///
/// Every field is taken from the console's comma-separated discovery reply.
/// `Eq` and `Hash` are derived so consoles can be deduplicated or used as
/// keys in hashed collections.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Console {
    /// Address the console reported for itself (first reply field).
    /// [`discover`] only returns a console when this equals the address the
    /// reply datagram was actually received from.
    pub address: IpAddr,
    /// Last reply field, kept verbatim (presumably the console's display
    /// name — TODO confirm against the protocol documentation).
    pub name: String,
    /// Fourth reply field, parsed as an unsigned 32-bit integer.
    pub airtouch_id: u32,
    /// Second reply field, kept verbatim.
    pub console_id: String,
}
/// Errors that can occur while discovering a console.
#[derive(Debug)]
pub enum DiscoveryError {
    /// An I/O operation failed. Also used in this file for a short send, a
    /// failure to enumerate network interfaces, and the absence of any
    /// broadcast-capable interface.
    IoError(std::io::Error),
    /// A received datagram was not valid UTF-8.
    EncodingError(Utf8Error),
    /// A received datagram did not have the expected shape; carries the
    /// datagram text.
    ProtocolError(String),
    /// A well-formed reply reported an address different from the one the
    /// datagram was received from.
    AddressError { reported: IpAddr, actual: IpAddr },
    /// No reply arrived before the timeout elapsed (produced when the
    /// `timeout` feature converts a `tokio` `Elapsed` error).
    NoResponse,
}

impl Display for DiscoveryError {
    /// Writes a human-readable description; use `{:?}` for the raw variant
    /// structure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IoError(cause) => write!(f, "I/O error during discovery: {cause}"),
            Self::EncodingError(cause) => {
                write!(f, "discovery reply is not valid UTF-8: {cause}")
            }
            // Debug-format the payload so control characters and surrounding
            // whitespace in the offending datagram stay visible.
            Self::ProtocolError(reply) => write!(f, "unexpected discovery reply: {reply:?}"),
            Self::AddressError { reported, actual } => write!(
                f,
                "console reported address {reported} but replied from {actual}"
            ),
            Self::NoResponse => write!(f, "no response to discovery request"),
        }
    }
}

impl Error for DiscoveryError {
    /// Exposes the wrapped lower-level error, if any, so callers can walk the
    /// error chain or downcast to inspect e.g. the `io::ErrorKind`.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::IoError(cause) => Some(cause),
            Self::EncodingError(cause) => Some(cause),
            Self::ProtocolError(_) | Self::AddressError { .. } | Self::NoResponse => None,
        }
    }
}
impl From<std::io::Error> for DiscoveryError {
fn from(value: std::io::Error) -> Self {
Self::IoError(value)
}
}
impl From<Utf8Error> for DiscoveryError {
fn from(value: Utf8Error) -> Self {
Self::EncodingError(value)
}
}
/// Maps an elapsed `tokio` timeout onto [`DiscoveryError::NoResponse`].
///
/// Only compiled when the `timeout` feature is enabled.
#[cfg(feature = "timeout")]
impl From<tokio::time::error::Elapsed> for DiscoveryError {
    fn from(_elapsed: tokio::time::error::Elapsed) -> Self {
        // The elapsed marker carries nothing the error variant needs.
        DiscoveryError::NoResponse
    }
}
/// Datagram payload `discover` sends to every broadcast address; a received
/// datagram equal to this text is skipped rather than treated as a reply.
const DISCOVERY_REQUEST: &str = "::REQUEST-POLYAIRE-AIRTOUCH-DEVICE-INFO:;";
/// UDP port the discovery request is sent to.
const DISCOVERY_PORT: u16 = 49005;
/// Local UDP port `discover` binds to outside of test builds: the discovery
/// port itself.
#[cfg(not(test))]
const LISTEN_PORT: u16 = DISCOVERY_PORT;
/// Broadcasts a discovery request and resolves with the first console reply.
///
/// The request is sent once to every address returned by
/// `broadcast_addresses`, then datagrams are read until one of them settles
/// the outcome. This function never times out on its own.
///
/// # Errors
///
/// * [`DiscoveryError::IoError`] – binding, sending (including a short send)
///   or receiving failed.
/// * [`DiscoveryError::EncodingError`] – a datagram was not valid UTF-8.
/// * [`DiscoveryError::AddressError`] – a well-formed reply reported an
///   address other than the one it was received from.
/// * [`DiscoveryError::ProtocolError`] – a datagram was neither a well-formed
///   reply nor a copy of the request.
pub async fn discover() -> Result<Console, DiscoveryError> {
    let socket = create_socket(LISTEN_PORT).await?;
    debug!("discover listening on {}", socket.local_addr()?);

    let request = DISCOVERY_REQUEST.as_bytes();
    for target in broadcast_addresses()? {
        debug!("discover sending req to {}... ", target);
        std::io::stderr().flush()?;
        let sent = socket.send_to(request, (target, DISCOVERY_PORT)).await?;
        if sent != request.len() {
            return Err(DiscoveryError::IoError(std::io::Error::other("short send")));
        }
        trace!("discover send done");
    }

    let mut buf = vec![0u8; 4096];
    loop {
        trace!("discover waiting for reply...");
        let (len, peer) = socket.recv_from(&mut buf).await?;
        let reply = std::str::from_utf8(&buf[..len])?;
        debug!("discover received {:?}... ", reply);

        // Outside of tests we listen on the very port we broadcast to, so a
        // copy of our own request can arrive here; ignore it and keep waiting.
        if reply == DISCOVERY_REQUEST {
            continue;
        }

        // Expected reply: `<address>,<console id>,AirTouch5,<airtouch id>,<name>`.
        // At most five fields are split off, so the name keeps any commas.
        let fields: Vec<&str> = reply.splitn(5, ',').collect();
        let parsed = match fields[..] {
            [addr, conid, "AirTouch5", atid, name] => addr
                .parse::<IpAddr>()
                .ok()
                .zip(atid.parse::<u32>().ok())
                .map(|(ip, id)| (ip, conid, id, name)),
            _ => None,
        };

        return match parsed {
            // Accept the reply only when the self-reported address matches
            // the datagram's actual sender.
            Some((address, conid, airtouch_id, name)) if address == peer.ip() => Ok(Console {
                address,
                console_id: conid.to_string(),
                airtouch_id,
                name: name.to_string(),
            }),
            Some((reported, ..)) => Err(DiscoveryError::AddressError {
                reported,
                actual: peer.ip(),
            }),
            None => Err(DiscoveryError::ProtocolError(reply.to_string())),
        };
    }
}
/// Runs [`discover`] under a deadline.
///
/// `duration` is the time to wait; `None` selects one second. When the
/// deadline passes first, the elapsed timeout is converted into a
/// [`DiscoveryError`] through its `From` impl; otherwise the result of
/// [`discover`] is returned unchanged.
///
/// Only compiled when the `timeout` feature is enabled.
#[cfg(feature = "timeout")]
pub async fn discover_timeout(duration: Option<Duration>) -> Result<Console, DiscoveryError> {
    let limit = match duration {
        Some(requested) => requested,
        None => Duration::from_secs(1),
    };
    match tokio::time::timeout(limit, discover()).await {
        Ok(outcome) => outcome,
        Err(elapsed) => Err(DiscoveryError::from(elapsed)),
    }
}
/// Binds a UDP socket to the unspecified IPv4 address on `port` and enables
/// broadcast on it.
///
/// Returns the I/O error from either the bind or the `set_broadcast` call.
async fn create_socket(port: u16) -> std::io::Result<UdpSocket> {
    let bind_addr = (Ipv4Addr::UNSPECIFIED, port);
    let sock = UdpSocket::bind(bind_addr).await?;
    sock.set_broadcast(true).map(|()| sock)
}
fn broadcast_addresses_real() -> Result<Vec<IpAddr>, DiscoveryError> {
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig};
let addrs: Vec<IpAddr> = NetworkInterface::show()
.map_err(|x| DiscoveryError::IoError(std::io::Error::other(x.to_string())))?
.iter()
.flat_map(|iface| iface.addr.iter())
.filter_map(|addr| match addr {
Addr::V4(a) if a.ip != Ipv4Addr::LOCALHOST => a.broadcast,
_ => None,
})
.map(IpAddr::V4)
.collect();
if addrs.is_empty() {
Err(DiscoveryError::IoError(std::io::Error::other(
"no broadcast interfaces",
)))
} else {
Ok(addrs)
}
}
/// Broadcast targets used outside of test builds: whatever
/// `broadcast_addresses_real` finds on this host.
#[cfg(not(test))]
fn broadcast_addresses() -> Result<Vec<IpAddr>, DiscoveryError> {
    let targets = broadcast_addresses_real()?;
    Ok(targets)
}
/// Broadcast targets used in test builds: only the IPv4 loopback address, so
/// the request stays on this machine.
#[cfg(test)]
fn broadcast_addresses() -> Result<Vec<IpAddr>, DiscoveryError> {
    let loopback = IpAddr::V4(Ipv4Addr::LOCALHOST);
    Ok(vec![loopback])
}
/// Local UDP port `discover` binds to in test builds. It is one below the
/// discovery port because the mock console in the tests binds the discovery
/// port itself.
#[cfg(test)]
const LISTEN_PORT: u16 = DISCOVERY_PORT - 1;
#[cfg(test)]
#[serial_test::serial(mock_discoverable)]
mod tests {
    use tokio::net::UdpSocket;

    use super::*;

    /// Stand-in console that answers a single discovery request on the
    /// discovery port.
    struct MockDiscoverable {
        /// Console data the mock reports in its reply.
        c: Console,
        /// Socket bound to the discovery port.
        socket: UdpSocket,
    }

    impl MockDiscoverable {
        /// Builds a mock whose address is the first test broadcast target and
        /// binds its socket to `DISCOVERY_PORT`.
        async fn new() -> std::io::Result<Self> {
            let mock = Self {
                c: Console {
                    address: broadcast_addresses().unwrap()[0],
                    console_id: "AT5C202502000001".to_owned(),
                    airtouch_id: 13,
                    name: "Test unit".to_owned(),
                },
                socket: create_socket(DISCOVERY_PORT).await?,
            };
            debug!("mock unit listening on {}", mock.socket.local_addr()?);
            Ok(mock)
        }

        /// Receives one datagram, asserts that it is exactly the discovery
        /// request, and sends the console reply back to the sender.
        async fn single(&self) -> std::io::Result<()> {
            let mut buf = vec![0u8; 128];
            trace!("mock unit waiting for request... ");
            std::io::stderr().flush()?;
            let (len, sockaddr) = self.socket.recv_from(&mut buf).await?;
            debug!(
                "mock unit received {:?}... ",
                std::str::from_utf8(&buf[..len]).unwrap()
            );
            std::io::stderr().flush()?;
            assert_eq!(len, DISCOVERY_REQUEST.len());
            assert_eq!(&buf[..len], DISCOVERY_REQUEST.as_bytes());
            trace!("mock unit recv ok");
            let resp = format!(
                "{},{},AirTouch5,{},{}",
                self.c.address, self.c.console_id, self.c.airtouch_id, self.c.name
            );
            debug!("mock unit sending to {}... ", sockaddr);
            std::io::stderr().flush()?;
            assert_eq!(
                self.socket.send_to(resp.as_bytes(), sockaddr).await?,
                resp.len()
            );
            trace!("mock unit send done");
            Ok(())
        }
    }

    /// A full request/reply exchange against the mock yields the mock's
    /// console data.
    #[tokio::test]
    async fn test_ok() {
        let mock = MockDiscoverable::new()
            .await
            .expect("error constructing mock");
        let expected = mock.c.clone();
        // Keep the handle: a panic or I/O error inside a detached task would
        // otherwise be lost and could never fail this test.
        let served = tokio::spawn(async move { mock.single().await });
        assert_matches!(discover().await, Ok(c) => { assert_eq!(c, expected); });
        served
            .await
            .expect("mock task panicked")
            .expect("mock exchange failed");
    }

    /// With nothing answering, the default timeout reports `NoResponse`.
    #[cfg(feature = "timeout")]
    #[tokio::test]
    async fn test_timeout_default() {
        assert_matches!(
            discover_timeout(None).await,
            Err(DiscoveryError::NoResponse)
        );
    }

    /// With nothing answering, an explicit timeout reports `NoResponse`.
    #[cfg(feature = "timeout")]
    #[tokio::test]
    async fn test_timeout_specific() {
        assert_matches!(
            discover_timeout(Some(Duration::from_millis(250))).await,
            Err(DiscoveryError::NoResponse)
        );
    }

    /// The socket is bound to the unspecified address on the requested port
    /// with broadcast enabled.
    #[tokio::test]
    async fn test_create_socket() {
        let sock = create_socket(LISTEN_PORT).await;
        assert_matches!(sock, Ok(udp) => {
            assert_matches!(udp.local_addr(), Ok(addr) => {
                assert_eq!(addr, std::net::SocketAddr::V4(std::net::SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, LISTEN_PORT)));
            });
            assert_matches!(udp.broadcast(), Ok(true));
        });
    }

    /// The real interface scan returns at least one address and never the
    /// loopback address. Requires a host with a broadcast-capable interface.
    #[tokio::test]
    #[serial_test::parallel]
    async fn test_broadcast_addresses() {
        let addrs = broadcast_addresses_real().expect("failed to get broadcast addresses");
        assert!(!addrs.is_empty(), "broadcast addresses should not be empty");
        assert!(
            !addrs.contains(&IpAddr::V4(Ipv4Addr::LOCALHOST)),
            "broadcast addresses should not include localhost"
        );
    }
}