use core::pin::pin;
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use domain::base::Name;
use embassy_futures::select::select3;
use embassy_time::{Duration, Timer};
use zbus::zvariant::{ObjectPath, OwnedObjectPath};
use zbus::Connection;
use crate::error::Error;
use crate::transport::network::mdns::{DottedName, MdnsRemoteService};
use crate::transport::network::MatterLocalService;
use crate::utils::select::Coalesce;
use crate::utils::zbus_proxies::resolve::manager::ManagerProxy;
use crate::Matter;
/// Pause between repeated resolver queries while a browse/resolve request is pending.
const QUERY_POLL_INTERVAL_MS: u64 = 250;

/// Interface index passed to resolve1 to mean "any interface".
const IF_INDEX_ANY: i32 = 0;

/// Address family meaning "no preference" (`AF_UNSPEC`): IPv4 and IPv6 alike.
const AF_UNSPEC: i32 = 0;

/// DNS class `IN` (RFC 1035 §3.2.4).
const DNS_CLASS_IN: u16 = 0x0001;

/// DNS resource record type `PTR` (RFC 1035 §3.2.2).
const DNS_TYPE_PTR: u16 = 0x000C;
/// mDNS responder and querier that delegates to a resolver reached over D-Bus
/// (the `resolve` manager proxy — presumably systemd-resolved's resolve1).
///
/// Services advertised by Matter are registered through the proxy and remembered
/// here so they can be unregistered once Matter stops advertising them.
pub struct ResolveMdns {
    /// Services currently registered, each with the D-Bus object path returned by
    /// the registration call (needed to unregister it later).
    services: HashMap<MatterLocalService, OwnedObjectPath>,
    /// D-Bus connection used to create resolver manager proxies.
    connection: Connection,
}
impl ResolveMdns {
    /// Creates an instance that talks to the resolver over `connection`.
    ///
    /// Nothing is registered or queried until [`ResolveMdns::run`] is awaited.
    pub fn new(connection: Connection) -> Self {
        Self {
            services: HashMap::new(),
            connection,
        }
    }

    /// Runs the mDNS responder, resolver and browser loops concurrently.
    ///
    /// Each of the three loops only ever returns through an error, so this future
    /// completes when the first of them fails. Its result is forwarded via
    /// `select3(..).coalesce()`.
    ///
    /// # Errors
    ///
    /// Returns the error of whichever loop finishes first.
    pub async fn run(&mut self, matter: &Matter<'_>) -> Result<(), Error> {
        // `run_respond` holds `&mut self` for the whole run, so the query loops work
        // on their own clone of the connection instead of borrowing `self`.
        let connection = self.connection.clone();

        let mut respond = pin!(self.run_respond(matter));
        let mut resolve = pin!(Self::run_resolve(matter, &connection));
        let mut browse = pin!(Self::run_browse(matter, &connection));

        select3(&mut respond, &mut resolve, &mut browse)
            .coalesce()
            .await
    }

    /// Keeps the resolver's registered services in sync with the services `matter`
    /// wants to advertise.
    ///
    /// After every `wait_mdns()` wake-up, the current service set is snapshotted and
    /// reconciled via `update_services`.
    ///
    /// # Errors
    ///
    /// Propagates failures from collecting the services or from (de)registration.
    async fn run_respond(&mut self, matter: &Matter<'_>) -> Result<(), Error> {
        loop {
            matter.transport().wait_mdns().await;

            let mut services = HashSet::new();
            matter.mdns_services(|service| {
                services.insert(service);
                Ok(())
            })?;

            info!("mDNS services changed, updating...");
            self.update_services(matter, &services).await?;
            info!("mDNS services updated");
        }
    }

    /// Serves browse requests for commissionable nodes (`_matterc._udp.local`).
    ///
    /// While a browse request is in flight, the service type's PTR records are queried
    /// every `QUERY_POLL_INTERVAL_MS` milliseconds. Each discovered instance is resolved
    /// and deposited into the transport. Failed queries are ignored and retried on the
    /// next poll.
    ///
    /// # Errors
    ///
    /// Fails only if a resolver manager proxy cannot be created.
    async fn run_browse(matter: &Matter<'_>, connection: &Connection) -> Result<(), Error> {
        loop {
            // NOTE(review): the browse filter is not applied here; every
            // `_matterc._udp` instance found is deposited.
            let _filter = matter.transport().wait_mdns_browse_request().await;
            let resolve = ManagerProxy::new(connection).await?;

            while matter.transport().mdns_browse_in_flight() {
                if let Ok((records, _flags)) = resolve
                    .resolve_record(
                        IF_INDEX_ANY,
                        "_matterc._udp.local",
                        DNS_CLASS_IN,
                        DNS_TYPE_PTR,
                        0,
                    )
                    .await
                {
                    // The RDATA of each PTR record is the wire-format name of one instance.
                    for instance in records
                        .into_iter()
                        .filter_map(|(_if, _rt, _rc, rdata)| parse_dns_name(&rdata))
                    {
                        let Some((name, type_, domain)) = parse_service_instance(&instance)
                        else {
                            continue;
                        };

                        if let Ok((srv_data, txt_data, _cn, _ct, _cd, _fl)) = resolve
                            .resolve_service(IF_INDEX_ANY, &name, &type_, &domain, AF_UNSPEC, 0)
                            .await
                        {
                            Self::deposit_browse(matter, &instance, &srv_data, &txt_data);
                        }

                        // Stop resolving further instances once the request is finished.
                        if !matter.transport().mdns_browse_in_flight() {
                            break;
                        }
                    }
                }

                Timer::after(Duration::from_millis(QUERY_POLL_INTERVAL_MS)).await;
            }
        }
    }

    /// Serves requests to resolve one specific service instance.
    ///
    /// While the request is in flight, the instance is resolved every
    /// `QUERY_POLL_INTERVAL_MS` milliseconds. Successful answers are deposited into
    /// the transport. Failed queries are ignored and retried on the next poll.
    ///
    /// # Errors
    ///
    /// Fails only if a resolver manager proxy cannot be created.
    async fn run_resolve(matter: &Matter<'_>, connection: &Connection) -> Result<(), Error> {
        loop {
            let service = matter.transport().wait_mdns_resolve_request().await;

            let mut name_buf: heapless::String<128> = heapless::String::new();
            service.instance_name(&mut name_buf);
            // The first dotted component is passed as the instance label; the service
            // type and the `local` domain are passed to the resolver separately.
            let label = name_buf.split('.').next().unwrap_or("").to_string();
            let service_type = service.service_type();

            let resolve = ManagerProxy::new(connection).await?;
            while matter.transport().mdns_resolve_in_flight() {
                if let Ok((srv_data, txt_data, _cn, _ct, _cd, _fl)) = resolve
                    .resolve_service(IF_INDEX_ANY, &label, service_type, "local", AF_UNSPEC, 0)
                    .await
                {
                    Self::deposit_resolve(matter, name_buf.as_str(), &srv_data, &txt_data);
                }
                Timer::after(Duration::from_millis(QUERY_POLL_INTERVAL_MS)).await;
            }
        }
    }

    /// Yields `(port, addresses, link-local scope id)` for every SRV entry of a
    /// `ResolveService` reply that has at least one parsable address.
    ///
    /// The scope id is 0 when the entry has no usable IPv6 link-local address.
    #[allow(clippy::type_complexity)] // tuple layout mirrors the D-Bus reply of ResolveService
    fn usable_srv_entries(
        srv_data: &[(u16, u16, u16, String, Vec<(i32, i32, Vec<u8>)>, String)],
    ) -> impl Iterator<Item = (u16, Vec<IpAddr>, u32)> + '_ {
        srv_data
            .iter()
            .filter_map(|(_priority, _weight, port, _host, addresses, _canonical)| {
                let ips = all_addresses(addresses);
                (!ips.is_empty()).then(|| (*port, ips, link_local_scope_id(addresses)))
            })
    }

    /// Deposits a resolved `_matterc._udp` browse result into the transport, one
    /// `MdnsRemoteService` per SRV entry that has usable addresses.
    #[allow(clippy::type_complexity)] // tuple layout mirrors the D-Bus reply of ResolveService
    fn deposit_browse(
        matter: &Matter<'_>,
        instance_name: &str,
        srv_data: &[(u16, u16, u16, String, Vec<(i32, i32, Vec<u8>)>, String)],
        txt_data: &[Vec<u8>],
    ) {
        // The TXT data is shared by all SRV entries, so decode it only once.
        let txt = txt_iter(txt_data);

        for (port, ips, scope_id) in Self::usable_srv_entries(srv_data) {
            matter
                .transport()
                .try_deposit_mdns_browse(&MdnsRemoteService {
                    instance_name: DottedName(instance_name),
                    port: Some(port),
                    addrs: ips.iter().copied(),
                    txt: txt.iter().copied(),
                    scope_id,
                });
        }
    }

    /// Deposits the resolution result of a specific instance into the transport, one
    /// `MdnsRemoteService` per SRV entry that has usable addresses.
    #[allow(clippy::type_complexity)] // tuple layout mirrors the D-Bus reply of ResolveService
    fn deposit_resolve(
        matter: &Matter<'_>,
        instance_name: &str,
        srv_data: &[(u16, u16, u16, String, Vec<(i32, i32, Vec<u8>)>, String)],
        txt_data: &[Vec<u8>],
    ) {
        // The TXT data is shared by all SRV entries, so decode it only once.
        let txt = txt_iter(txt_data);

        for (port, ips, scope_id) in Self::usable_srv_entries(srv_data) {
            matter
                .transport()
                .try_deposit_mdns_resolve(&MdnsRemoteService {
                    instance_name: DottedName(instance_name),
                    port: Some(port),
                    addrs: ips.iter().copied(),
                    txt: txt.iter().copied(),
                    scope_id,
                });
        }
    }

    /// Reconciles the registered services with `services`.
    ///
    /// Services in `services` that are not yet registered get registered. Registered
    /// services missing from `services` get deregistered and forgotten.
    ///
    /// # Errors
    ///
    /// Stops at the first failing (de)registration. A service whose deregistration
    /// failed stays recorded, so it is retried on the next update.
    async fn update_services(
        &mut self,
        matter: &Matter<'_>,
        services: &HashSet<MatterLocalService>,
    ) -> Result<(), Error> {
        for service in services {
            if !self.services.contains_key(service) {
                info!("Registering mDNS service: {:?}", service);
                let path = self.register(matter, service).await?;
                self.services.insert(service.clone(), path);
            }
        }

        // Collect first so the map is not borrowed while entries are removed.
        let stale: Vec<MatterLocalService> = self
            .services
            .keys()
            .filter(|service| !services.contains(*service))
            .cloned()
            .collect();

        for service in stale {
            if let Some(path) = self.services.get(&service) {
                info!("Deregistering mDNS service: {:?}", service);
                self.deregister(path.as_ref()).await?;
            }
            // Forget the entry only after deregistration succeeded.
            self.services.remove(&service);
        }

        Ok(())
    }

    /// Registers `service` with the resolver and returns the D-Bus object path of the
    /// new registration, which is needed later to unregister it.
    ///
    /// The registration id is `rs-matter-<service name>`.
    ///
    /// # Errors
    ///
    /// Fails if the service description cannot be built, if the proxy cannot be
    /// created, or if the registration call fails.
    async fn register(
        &self,
        matter: &Matter<'_>,
        service: &MatterLocalService,
    ) -> Result<OwnedObjectPath, Error> {
        // Scratch buffer handed to `MatterLocalService::service`.
        let mut buf = [0u8; 512];
        let (service, _) = service.service(matter.dev_det(), matter.port(), &mut buf)?;

        let resolve = ManagerProxy::new(&self.connection).await?;

        let txt = service
            .txt_kvs
            .clone()
            .map(|(k, v)| (k, v.as_bytes()))
            .collect::<HashMap<_, _>>();
        let id = format!("rs-matter-{}", service.name);

        let path = resolve
            .register_service(
                &id,
                service.name,
                service.service_protocol,
                service.port,
                0, // priority
                0, // weight
                &[txt],
            )
            .await?;

        Ok(path)
    }

    /// Unregisters the service registration at D-Bus object `path`.
    ///
    /// # Errors
    ///
    /// Fails if the proxy cannot be created or the unregister call fails.
    async fn deregister(&self, path: ObjectPath<'_>) -> Result<(), Error> {
        let resolve = ManagerProxy::new(&self.connection).await?;
        resolve.unregister_service(&path).await?;
        Ok(())
    }
}
/// Parses every `(ifindex, family, raw bytes)` address of a resolver reply, in order,
/// dropping entries that `parse_ip_address` cannot decode.
fn all_addresses(addresses: &[(i32, i32, Vec<u8>)]) -> Vec<IpAddr> {
    let mut ips = Vec::new();
    for (_ifindex, family, raw) in addresses {
        if let Some(ip) = parse_ip_address(*family, raw) {
            ips.push(ip);
        }
    }
    ips
}
/// Returns the interface index of the first IPv6 unicast link-local address that has a
/// positive interface index, for use as an IPv6 scope id. Returns 0 if there is none.
fn link_local_scope_id(addresses: &[(i32, i32, Vec<u8>)]) -> u32 {
    let is_link_local_v6 = |family: i32, raw: &[u8]| {
        matches!(
            parse_ip_address(family, raw),
            Some(IpAddr::V6(v6)) if v6.is_unicast_link_local()
        )
    };

    addresses
        .iter()
        .filter(|(ifindex, _, _)| *ifindex > 0)
        .find(|(_, family, raw)| is_link_local_v6(*family, raw))
        // Lossless: the filter above guarantees a positive index.
        .map_or(0, |(ifindex, _, _)| *ifindex as u32)
}
/// Decodes DNS-SD TXT strings into `(key, value)` pairs, in order.
///
/// Each string is either `key=value` (the value is everything after the first `=`)
/// or a bare boolean attribute `key`, which yields `(key, "")`. Following RFC 6763
/// §6.4:
/// - strings that are not valid UTF-8 are skipped;
/// - strings with an empty key (the empty string, or a string starting with `=`)
///   are ignored;
/// - when a key occurs more than once (compared ASCII-case-insensitively), only its
///   first occurrence is kept.
fn txt_iter(txt_data: &[Vec<u8>]) -> Vec<(&str, &str)> {
    let mut pairs: Vec<(&str, &str)> = Vec::with_capacity(txt_data.len());

    for entry in txt_data {
        let Ok(text) = std::str::from_utf8(entry) else {
            continue;
        };
        let (key, value) = text.split_once('=').unwrap_or((text, ""));

        let duplicate = pairs.iter().any(|(seen, _)| seen.eq_ignore_ascii_case(key));
        if key.is_empty() || duplicate {
            continue;
        }

        pairs.push((key, value));
    }

    pairs
}
/// Decodes an uncompressed wire-format DNS name (such as the RDATA of a PTR record)
/// into dotted text without a trailing dot, e.g. `"MyDevice._matterc._udp.local"`.
///
/// Returns `None` in three cases:
/// - `data` is not a valid wire-format name;
/// - the name has no non-root labels;
/// - any label is not valid UTF-8. Dropping such a label would silently produce a
///   different name.
///
/// Labels are joined with `.` verbatim, so a dot inside a label is not escaped.
fn parse_dns_name(data: &[u8]) -> Option<String> {
    let name = Name::from_slice(data).ok()?;

    let labels = name
        .iter()
        // Only the root label is empty; it is represented by the lack of a trailing dot.
        .filter(|label| !label.is_empty())
        .map(|label| core::str::from_utf8(label.as_slice()).ok())
        .collect::<Option<Vec<&str>>>()?;

    (!labels.is_empty()).then(|| labels.join("."))
}
/// Splits a browsed DNS-SD instance name into `(instance label, service type, domain)`.
///
/// For example, `"MyDevice._matterc._udp.local"` becomes
/// `("MyDevice", "_matterc._udp", "local")`.
///
/// Only commissionable-node instances (`_matterc._udp`) in the `local` domain are
/// accepted. Trailing dots of a fully-qualified name are ignored. The suffix is
/// matched ASCII-case-insensitively, because DNS names are case-insensitive.
/// The returned service type and domain are always the canonical lowercase strings.
/// The instance label may itself contain dots.
///
/// Returns `None` for any other service type or domain, for trailing garbage after
/// `local`, or when the instance label is empty.
fn parse_service_instance(instance: &str) -> Option<(String, String, String)> {
    const SERVICE_TYPE: &str = "_matterc._udp";
    const DOMAIN: &str = "local";
    // "." + SERVICE_TYPE + "." + DOMAIN
    const SUFFIX: &str = "._matterc._udp.local";

    let instance = instance.trim_end_matches('.');
    let split = instance.len().checked_sub(SUFFIX.len())?;

    // `get` yields `None` instead of panicking when `split` is not on a char boundary.
    let suffix = instance.get(split..)?;
    if !suffix.eq_ignore_ascii_case(SUFFIX) {
        return None;
    }

    let name = &instance[..split];
    if name.is_empty() {
        return None;
    }

    Some((name.to_string(), SERVICE_TYPE.to_string(), DOMAIN.to_string()))
}
/// Decodes a raw address as reported by the resolver: 4 bytes for `AF_INET`,
/// 16 bytes for `AF_INET6`.
///
/// Returns `None` in two cases:
/// - the family is unknown;
/// - the byte length does not match the family exactly. A wrong-sized buffer is
///   malformed data and must not be silently truncated.
fn parse_ip_address(family: i32, addr_bytes: &[u8]) -> Option<IpAddr> {
    // Linux address-family numbers.
    const AF_INET: i32 = 2;
    const AF_INET6: i32 = 10;

    match family {
        AF_INET => <[u8; 4]>::try_from(addr_bytes).ok().map(IpAddr::from),
        AF_INET6 => <[u8; 16]>::try_from(addr_bytes).ok().map(IpAddr::from),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::{Ipv4Addr, Ipv6Addr};

    /// Encodes `labels` as an absolute, uncompressed wire-format DNS name.
    fn wire_name(labels: &[&str]) -> Vec<u8> {
        let mut out = Vec::new();
        for label in labels {
            assert!(label.len() < 64, "DNS labels are at most 63 octets");
            out.push(label.len() as u8);
            out.extend_from_slice(label.as_bytes());
        }
        out.push(0);
        out
    }

    /// Shorthand for an expected `parse_service_instance` result.
    fn instance(name: &str, service_type: &str, domain: &str) -> Option<(String, String, String)> {
        Some((
            name.to_string(),
            service_type.to_string(),
            domain.to_string(),
        ))
    }

    #[test]
    fn parse_dns_name_simple() {
        let data = [5, b'l', b'o', b'c', b'a', b'l', 0];
        // Also pins the `wire_name` helper used by the tests below.
        assert_eq!(wire_name(&["local"]), data);
        assert_eq!(parse_dns_name(&data), Some("local".to_string()));
    }

    #[test]
    fn parse_dns_name_multi_label() {
        let data = wire_name(&["_matterc", "_udp", "local"]);
        assert_eq!(
            parse_dns_name(&data),
            Some("_matterc._udp.local".to_string())
        );
    }

    #[test]
    fn parse_dns_name_service_instance() {
        let data = wire_name(&["Matter Device", "_matterc", "_udp", "local"]);
        assert_eq!(
            parse_dns_name(&data),
            Some("Matter Device._matterc._udp.local".to_string())
        );
    }

    #[test]
    fn parse_dns_name_empty() {
        // Only the root label.
        assert_eq!(parse_dns_name(&[0]), None);
    }

    #[test]
    fn parse_dns_name_truncated() {
        // The length byte announces 10 octets but only 5 follow.
        let data = [10, b'h', b'e', b'l', b'l', b'o'];
        assert_eq!(parse_dns_name(&data), None);
    }

    #[test]
    fn parse_dns_name_with_spaces() {
        let data = wire_name(&["Matter Device"]);
        assert_eq!(parse_dns_name(&data), Some("Matter Device".to_string()));
    }

    #[test]
    fn parse_service_instance_simple() {
        assert_eq!(
            parse_service_instance("MyDevice._matterc._udp.local"),
            instance("MyDevice", "_matterc._udp", "local")
        );
    }

    #[test]
    fn parse_service_instance_with_spaces() {
        assert_eq!(
            parse_service_instance("Matter Test Device._matterc._udp.local"),
            instance("Matter Test Device", "_matterc._udp", "local")
        );
    }

    #[test]
    fn parse_service_instance_with_trailing_dot() {
        assert_eq!(
            parse_service_instance("MyDevice._matterc._udp.local."),
            instance("MyDevice", "_matterc._udp", "local")
        );
    }

    #[test]
    fn parse_service_instance_invalid_no_matterc() {
        assert_eq!(parse_service_instance("MyDevice._http._tcp.local"), None);
    }

    #[test]
    fn parse_service_instance_invalid_no_local() {
        assert_eq!(
            parse_service_instance("MyDevice._matterc._udp.example.com"),
            None
        );
    }

    #[test]
    fn parse_ip_address_ipv4() {
        assert_eq!(
            parse_ip_address(2, &[192, 168, 1, 100]),
            Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 100)))
        );
    }

    #[test]
    fn parse_ip_address_ipv4_localhost() {
        assert_eq!(
            parse_ip_address(2, &[127, 0, 0, 1]),
            Some(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
        );
    }

    #[test]
    fn parse_ip_address_ipv6_localhost() {
        let addr_bytes = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
        assert_eq!(
            parse_ip_address(10, &addr_bytes),
            Some(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)))
        );
    }

    #[test]
    fn parse_ip_address_ipv6_link_local() {
        let addr_bytes = [0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1];
        assert_eq!(
            parse_ip_address(10, &addr_bytes),
            Some(IpAddr::V6(Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)))
        );
    }

    #[test]
    fn parse_ip_address_ipv4_too_short() {
        assert_eq!(parse_ip_address(2, &[192, 168, 1]), None);
    }

    #[test]
    fn parse_ip_address_ipv6_too_short() {
        assert_eq!(parse_ip_address(10, &[0, 0, 0, 0, 0, 0, 0, 0]), None);
    }

    #[test]
    fn parse_ip_address_unknown_family() {
        assert_eq!(parse_ip_address(99, &[192, 168, 1, 100]), None);
    }

    #[test]
    fn txt_iter_splits_key_value_pairs() {
        let data = vec![b"D=840".to_vec(), b"flag".to_vec(), b"k=a=b".to_vec()];
        assert_eq!(
            txt_iter(&data),
            vec![("D", "840"), ("flag", ""), ("k", "a=b")]
        );
    }

    #[test]
    fn all_addresses_skips_unparsable_entries() {
        let link_local = Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1);
        let addresses = vec![
            (1, 2, vec![10, 0, 0, 1]),
            (1, 99, vec![1, 2, 3, 4]),
            (2, 10, link_local.octets().to_vec()),
            (3, 10, vec![0; 8]),
        ];
        assert_eq!(
            all_addresses(&addresses),
            vec![
                IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)),
                IpAddr::V6(link_local)
            ]
        );
    }

    #[test]
    fn link_local_scope_id_uses_interface_of_link_local_ipv6() {
        let link_local = Ipv6Addr::new(0xfe80, 0, 0, 0, 0, 0, 0, 1).octets().to_vec();
        let global = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1).octets().to_vec();

        let addresses = vec![
            (3, 2, vec![10, 0, 0, 1]),
            (4, 10, global),
            (5, 10, link_local.clone()),
        ];
        assert_eq!(link_local_scope_id(&addresses), 5);

        // Without a positive interface index there is no usable scope.
        assert_eq!(link_local_scope_id(&[(0, 10, link_local)]), 0);
        assert_eq!(link_local_scope_id(&[]), 0);
    }
}