use crate::error::{NetError, NetResult};
use std::net::{Ipv4Addr, Ipv6Addr};
/// IPv4 multicast group used for mDNS traffic (224.0.0.251).
pub const MDNS_IPV4_MULTICAST: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// IPv6 multicast group used for mDNS traffic (ff02::fb).
pub const MDNS_IPV6_MULTICAST: Ipv6Addr = Ipv6Addr::new(0xFF02, 0, 0, 0, 0, 0, 0, 0x00FB);
/// UDP port used for mDNS traffic.
pub const MDNS_PORT: u16 = 5353;
/// Wire type code for A records (mapped to `MdnsType::A`).
const DNS_TYPE_A: u16 = 1;
/// Wire type code for AAAA records (mapped to `MdnsType::Aaaa`).
const DNS_TYPE_AAAA: u16 = 28;
/// Wire type code for PTR records (mapped to `MdnsType::Ptr`).
const DNS_TYPE_PTR: u16 = 12;
/// Wire type code for SRV records (mapped to `MdnsType::Srv`).
const DNS_TYPE_SRV: u16 = 33;
/// Wire type code for TXT records (mapped to `MdnsType::Txt`).
const DNS_TYPE_TXT: u16 = 16;
/// Class code IN, written into every question and resource record built in this file.
const DNS_CLASS_IN: u16 = 1;
/// Top bit of the question class field; `MdnsQuery::build_query` ORs it into
/// the class to set the mDNS "QU" (unicast-response requested) bit.
const DNS_QU_BIT: u16 = 0x8000;
/// Resource record type of a parsed mDNS record.
///
/// Known wire type codes get a dedicated variant; every other code is kept
/// verbatim in [`MdnsType::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MdnsType {
/// Type code 1; payload is an IPv4 address.
A,
/// Type code 28; payload is an IPv6 address.
Aaaa,
/// Type code 12; payload is a domain name.
Ptr,
/// Type code 33; payload is priority, weight, port and a target name.
Srv,
/// Type code 16; payload is a list of length-prefixed strings.
Txt,
/// Any other type code, carried as the raw 16-bit value.
Unknown(u16),
}
impl MdnsType {
fn from_u16(v: u16) -> Self {
match v {
DNS_TYPE_A => Self::A,
DNS_TYPE_AAAA => Self::Aaaa,
DNS_TYPE_PTR => Self::Ptr,
DNS_TYPE_SRV => Self::Srv,
DNS_TYPE_TXT => Self::Txt,
other => Self::Unknown(other),
}
}
fn as_u16(self) -> u16 {
match self {
Self::A => DNS_TYPE_A,
Self::Aaaa => DNS_TYPE_AAAA,
Self::Ptr => DNS_TYPE_PTR,
Self::Srv => DNS_TYPE_SRV,
Self::Txt => DNS_TYPE_TXT,
Self::Unknown(v) => v,
}
}
}
/// Decoded RDATA payload of a parsed record.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MdnsData {
/// IPv4 address taken from an A record.
A(Ipv4Addr),
/// IPv6 address taken from an AAAA record.
Aaaa(Ipv6Addr),
/// Domain name a PTR record points at (labels joined with `.`).
Ptr(String),
/// Fields of an SRV record, in wire order.
Srv {
priority: u16,
weight: u16,
port: u16,
/// Target host name (labels joined with `.`).
target: String,
},
/// Strings of a TXT record. The parser also uses this variant for records
/// of unknown type, storing a single `raw:<hex>` string of the RDATA.
Txt(Vec<String>),
}
/// One resource record extracted from a DNS/mDNS message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MdnsRecord {
/// Owner name with labels joined by `.` and no trailing dot.
pub name: String,
/// Record type derived from the wire type code.
pub record_type: MdnsType,
/// TTL field exactly as read from the wire.
pub ttl: u32,
/// Decoded RDATA.
pub data: MdnsData,
}
/// Description of a service to announce with [`MdnsAnnouncer::announce`].
///
/// Name fields may be given with or without a trailing dot; the announcer
/// appends one when it is missing.
#[derive(Debug, Clone)]
pub struct ServiceInfo {
/// Service type name, e.g. `_http._tcp.local.`; owner of the PTR record.
pub service_type: String,
/// Full instance name; PTR target and owner of the SRV and TXT records.
pub instance_name: String,
/// Port written into the SRV record.
pub port: u16,
/// Host name; SRV target and owner of the A record.
pub host: String,
/// Strings written into the TXT record.
pub txt_records: Vec<String>,
/// When `Some`, an additional A record for `host` is announced.
pub ipv4: Option<Ipv4Addr>,
}
/// Builds unsolicited mDNS response packets that advertise a service.
pub struct MdnsAnnouncer;

impl MdnsAnnouncer {
    /// Serialises an announcement for `service`.
    ///
    /// The packet has ID 0, flags `0x8400` and carries, in this order, a PTR
    /// record (service type -> instance), an SRV record, a TXT record and,
    /// when `service.ipv4` is set, an A record for the host.
    pub fn announce(service: &ServiceInfo) -> Vec<u8> {
        // PTR and TXT use the long TTL, SRV and A the short one.
        const TTL_LONG: u32 = 4500;
        const TTL_SHORT: u32 = 120;

        let service_name = ensure_fqdn(&service.service_type);
        let instance = ensure_fqdn(&service.instance_name);
        let host = ensure_fqdn(&service.host);

        let answers: u16 = if service.ipv4.is_some() { 4 } else { 3 };
        let mut builder = DnsPacketBuilder::new();
        builder.write_header(0, 0x8400, 0, answers, 0, 0);

        // PTR: service type -> instance name.
        let ptr_target = encode_dns_name(&instance);
        builder.write_rr(&service_name, DNS_TYPE_PTR, DNS_CLASS_IN, TTL_LONG, &ptr_target);

        // SRV: priority 0, weight 0, port, then the target host name.
        let srv_target = encode_dns_name(&host);
        let mut srv = Vec::with_capacity(6 + srv_target.len());
        for field in [0u16, 0u16, service.port] {
            srv.extend_from_slice(&field.to_be_bytes());
        }
        srv.extend_from_slice(&srv_target);
        builder.write_rr(&instance, DNS_TYPE_SRV, DNS_CLASS_IN, TTL_SHORT, &srv);

        // TXT: length-prefixed strings.
        let txt = encode_txt_records(&service.txt_records);
        builder.write_rr(&instance, DNS_TYPE_TXT, DNS_CLASS_IN, TTL_LONG, &txt);

        // A: only when an address was supplied (matches `answers` above).
        if let Some(addr) = service.ipv4 {
            builder.write_rr(&host, DNS_TYPE_A, DNS_CLASS_IN, TTL_SHORT, &addr.octets());
        }

        builder.into_bytes()
    }
}
/// Builds mDNS query packets.
pub struct MdnsQuery;

impl MdnsQuery {
    /// Serialises a single-question PTR query for `service_type`.
    ///
    /// The header has ID 0 and all flags clear; the question class is IN
    /// with the QU bit set.
    pub fn build_query(service_type: &str) -> Vec<u8> {
        let question_name = encode_dns_name(&ensure_fqdn(service_type));

        let mut builder = DnsPacketBuilder::new();
        builder.write_header(0, 0x0000, 1, 0, 0, 0);
        builder.extend(&question_name);
        // QTYPE followed by QCLASS.
        for field in [DNS_TYPE_PTR, DNS_CLASS_IN | DNS_QU_BIT] {
            builder.write_u16(field);
        }
        builder.into_bytes()
    }
}
/// Parses DNS/mDNS messages into a flat list of resource records.
pub struct MdnsParser;

impl MdnsParser {
    /// Parses `data` as a DNS message and returns every resource record from
    /// the answer, authority and additional sections, in wire order.
    ///
    /// Questions are skipped. If the message ends before the number of
    /// records announced in the header has been read, the records parsed so
    /// far are returned.
    ///
    /// # Errors
    ///
    /// Returns a parse error when the message is shorter than the 12-byte
    /// header, when a question is truncated, or when any resource record
    /// fails to decode.
    pub fn parse(data: &[u8]) -> NetResult<Vec<MdnsRecord>> {
        const HEADER_LEN: usize = 12;
        // QTYPE + QCLASS following each question name.
        const QUESTION_FIXED_LEN: usize = 4;
        // Smallest possible RR: 1-byte root name + 10-byte fixed header.
        const MIN_RR_LEN: usize = 11;

        if data.len() < HEADER_LEN {
            return Err(NetError::parse(0, "DNS message too short (need ≥12 bytes)"));
        }

        // Header layout: id, flags, then four big-endian u16 section counts.
        let count_at =
            |index: usize| usize::from(u16::from_be_bytes([data[index], data[index + 1]]));
        let qd_count = count_at(4);
        let an_count = count_at(6);
        let ns_count = count_at(8);
        let ar_count = count_at(10);

        let mut offset = HEADER_LEN;
        for _ in 0..qd_count {
            offset = skip_name(data, offset)?;
            if offset + QUESTION_FIXED_LEN > data.len() {
                return Err(NetError::parse(
                    offset as u64,
                    "Truncated DNS question section",
                ));
            }
            offset += QUESTION_FIXED_LEN;
        }

        let total_rrs = an_count + ns_count + ar_count;
        // The counts come straight from an untrusted header; never reserve
        // more slots than the remaining bytes could possibly encode.
        let plausible_rrs = data.len().saturating_sub(offset) / MIN_RR_LEN;
        let mut records = Vec::with_capacity(total_rrs.min(plausible_rrs));

        for _ in 0..total_rrs {
            // Tolerate headers that announce more records than are present.
            if offset >= data.len() {
                break;
            }
            let (record, next_offset) = parse_rr(data, offset)?;
            records.push(record);
            offset = next_offset;
        }
        Ok(records)
    }
}
/// Append-only byte buffer with helpers for writing DNS wire structures.
/// All multi-byte integers are written big-endian.
struct DnsPacketBuilder {
    buf: Vec<u8>,
}

impl DnsPacketBuilder {
    /// Creates an empty builder with 512 bytes reserved.
    fn new() -> Self {
        let buf = Vec::with_capacity(512);
        Self { buf }
    }

    /// Appends `v` as two big-endian bytes.
    fn write_u16(&mut self, v: u16) {
        self.extend(&v.to_be_bytes());
    }

    /// Appends `v` as four big-endian bytes.
    fn write_u32(&mut self, v: u32) {
        self.extend(&v.to_be_bytes());
    }

    /// Appends `bytes` unchanged.
    fn extend(&mut self, bytes: &[u8]) {
        self.buf.extend(bytes);
    }

    /// Appends the six 16-bit header fields in wire order.
    fn write_header(
        &mut self,
        id: u16,
        flags: u16,
        qd_count: u16,
        an_count: u16,
        ns_count: u16,
        ar_count: u16,
    ) {
        for field in [id, flags, qd_count, an_count, ns_count, ar_count] {
            self.write_u16(field);
        }
    }

    /// Appends one resource record: encoded owner name, type, class, TTL,
    /// RDLENGTH and the RDATA bytes.
    fn write_rr(&mut self, name: &str, rtype: u16, rclass: u16, ttl: u32, rdata: &[u8]) {
        // RDLENGTH is a 16-bit field; the cast keeps only the low 16 bits.
        let rdlength = rdata.len() as u16;
        let owner = encode_dns_name(name);

        self.extend(&owner);
        self.write_u16(rtype);
        self.write_u16(rclass);
        self.write_u32(ttl);
        self.write_u16(rdlength);
        self.extend(rdata);
    }

    /// Consumes the builder and returns the accumulated bytes.
    fn into_bytes(self) -> Vec<u8> {
        let Self { buf } = self;
        buf
    }
}
/// Encodes a dotted domain name as uncompressed DNS wire labels.
///
/// Each label is written as a length byte followed by its bytes, and the
/// name is terminated by a zero byte. A trailing dot is optional; the empty
/// name and `"."` both encode to the single root byte `0x00`.
///
/// The output is always well-formed:
/// * empty labels (as in `"a..b"` or `".local"`) are skipped, because a
///   zero-length label would terminate the name early;
/// * labels longer than 63 bytes are cut to at most 63 bytes on a UTF-8
///   character boundary, because larger length bytes would collide with the
///   compression-pointer encoding.
fn encode_dns_name(name: &str) -> Vec<u8> {
    const MAX_LABEL_LEN: usize = 63;

    let mut out = Vec::with_capacity(name.len() + 2);
    for label in name.split('.').filter(|label| !label.is_empty()) {
        // Back up to a char boundary so a clamped label stays valid UTF-8.
        let mut end = label.len().min(MAX_LABEL_LEN);
        while !label.is_char_boundary(end) {
            end -= 1;
        }
        let bytes = &label.as_bytes()[..end];
        if bytes.is_empty() {
            continue;
        }
        // `bytes.len()` is at most 63 here, so the cast cannot truncate.
        out.push(bytes.len() as u8);
        out.extend_from_slice(bytes);
    }
    out.push(0u8);
    out
}
/// Returns `name` as an owned string that ends with a dot, appending one
/// only when it is missing.
fn ensure_fqdn(name: &str) -> String {
    let mut fqdn = String::with_capacity(name.len() + 1);
    fqdn.push_str(name);
    if !fqdn.ends_with('.') {
        fqdn.push('.');
    }
    fqdn
}
/// Encodes TXT strings as RDATA: each string becomes a length byte followed
/// by its bytes.
///
/// An empty list encodes to a single zero byte (one empty string). Strings
/// longer than 255 bytes are cut to at most 255 bytes on a UTF-8 character
/// boundary, because the length prefix is a single byte; without the cut
/// the prefix would wrap while the full string was still written, which
/// desynchronises everything after it.
fn encode_txt_records(records: &[String]) -> Vec<u8> {
    const MAX_TXT_STRING_LEN: usize = 255;

    if records.is_empty() {
        return vec![0x00];
    }

    let capacity = records
        .iter()
        .map(|record| record.len().min(MAX_TXT_STRING_LEN) + 1)
        .sum();
    let mut out = Vec::with_capacity(capacity);
    for record in records {
        // Back up to a char boundary so a clamped string stays valid UTF-8.
        let mut end = record.len().min(MAX_TXT_STRING_LEN);
        while !record.is_char_boundary(end) {
            end -= 1;
        }
        let bytes = &record.as_bytes()[..end];
        // `bytes.len()` is at most 255 here, so the cast cannot truncate.
        out.push(bytes.len() as u8);
        out.extend_from_slice(bytes);
    }
    out
}
/// Decodes a possibly compressed DNS name that starts at `offset` in `data`.
///
/// Returns the labels joined with `.` (no trailing dot; the root name is the
/// empty string) together with the offset of the first byte after the name
/// at its original location. When a compression pointer is followed, that
/// offset is the byte right after the first pointer.
///
/// # Errors
///
/// Fails when the name, a pointer or a label runs past the end of `data`,
/// when a pointer targets an offset outside `data`, when more than 128
/// pointers are followed, when a label is not valid UTF-8, or when a length
/// byte has the reserved top-bit patterns `01`/`10`.
fn decode_dns_name(data: &[u8], offset: usize) -> NetResult<(String, usize)> {
    const MAX_POINTER_HOPS: u32 = 128;

    let mut parts: Vec<&str> = Vec::new();
    let mut cursor = offset;
    // Set once, at the first pointer: where the caller continues reading.
    let mut resume_at: Option<usize> = None;
    let mut jumps = 0u32;

    loop {
        let byte = match data.get(cursor) {
            Some(&b) => b,
            None => return Err(NetError::parse(cursor as u64, "DNS name truncated")),
        };
        if byte == 0 {
            // Root label terminates the name.
            break;
        }

        match byte & 0xC0 {
            // Compression pointer: 14-bit offset into the message.
            0xC0 => {
                let low = match data.get(cursor + 1) {
                    Some(&b) => b,
                    None => {
                        return Err(NetError::parse(cursor as u64, "DNS pointer truncated"));
                    }
                };
                if resume_at.is_none() {
                    resume_at = Some(cursor + 2);
                }
                let ptr = (usize::from(byte & 0x3F) << 8) | usize::from(low);
                if ptr >= data.len() {
                    return Err(NetError::parse(
                        cursor as u64,
                        format!("DNS pointer {ptr} out of bounds (data len {})", data.len()),
                    ));
                }
                jumps += 1;
                if jumps > MAX_POINTER_HOPS {
                    return Err(NetError::parse(
                        cursor as u64,
                        "DNS name pointer loop detected",
                    ));
                }
                cursor = ptr;
            }
            // Ordinary label: length byte followed by that many bytes.
            0x00 => {
                let start = cursor + 1;
                let stop = start + usize::from(byte);
                let raw = match data.get(start..stop) {
                    Some(raw) => raw,
                    None => return Err(NetError::parse(start as u64, "DNS label truncated")),
                };
                let label = std::str::from_utf8(raw)
                    .map_err(|_| NetError::parse(start as u64, "DNS label is not valid UTF-8"))?;
                parts.push(label);
                cursor = stop;
            }
            // Top bits 01 or 10 are not a valid label type.
            _ => {
                return Err(NetError::parse(
                    cursor as u64,
                    format!("Invalid DNS label length byte: {byte:#04x}"),
                ));
            }
        }
    }

    // Without any pointer the name ends right after its terminating zero.
    Ok((parts.join("."), resume_at.unwrap_or(cursor + 1)))
}
/// Returns the offset just past the DNS name starting at `offset`,
/// discarding the decoded text. Fails exactly when `decode_dns_name` fails.
fn skip_name(data: &[u8], offset: usize) -> NetResult<usize> {
    decode_dns_name(data, offset).map(|(_, end)| end)
}
/// Parses one resource record starting at `offset` in the message `data`.
///
/// Returns the decoded record and the offset of the byte following its
/// RDATA. The class field is skipped. Records of an unrecognised type are
/// returned as `MdnsData::Txt` holding one `raw:<hex>` string of the RDATA.
///
/// # Errors
///
/// Fails when the owner name cannot be decoded, when the fixed header or
/// the RDATA is truncated, when an A/AAAA record's RDATA is not exactly
/// 4/16 bytes, when an SRV record's RDATA is shorter than 6 bytes, when a
/// PTR/SRV target name cannot be decoded or extends past the record's
/// RDATA, or when TXT data is malformed.
fn parse_rr(data: &[u8], offset: usize) -> NetResult<(MdnsRecord, usize)> {
    let (name, mut pos) = decode_dns_name(data, offset)?;

    // Fixed part after the name: type(2) class(2) ttl(4) rdlength(2).
    if pos + 10 > data.len() {
        return Err(NetError::parse(pos as u64, "RR header truncated"));
    }
    let rtype_raw = u16::from_be_bytes([data[pos], data[pos + 1]]);
    pos += 2;
    // Class (including the mDNS cache-flush bit) is not used by this parser.
    pos += 2;
    let ttl = u32::from_be_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]);
    pos += 4;
    let rdlength = u16::from_be_bytes([data[pos], data[pos + 1]]) as usize;
    pos += 2;

    if pos + rdlength > data.len() {
        return Err(NetError::parse(
            pos as u64,
            format!("RDATA truncated: need {rdlength} bytes at offset {pos}"),
        ));
    }
    let rdata_start = pos;
    let rdata_end = pos + rdlength;
    let rdata = &data[rdata_start..rdata_end];

    let record_type = MdnsType::from_u16(rtype_raw);
    let record_data = match record_type {
        MdnsType::A => {
            // The conversion succeeds only for exactly four bytes.
            let octets: [u8; 4] = rdata.try_into().map_err(|_| {
                NetError::parse(rdata_start as u64, "A record RDATA must be 4 bytes")
            })?;
            MdnsData::A(Ipv4Addr::from(octets))
        }
        MdnsType::Aaaa => {
            // The conversion succeeds only for exactly sixteen bytes.
            let octets: [u8; 16] = rdata.try_into().map_err(|_| {
                NetError::parse(rdata_start as u64, "AAAA record RDATA must be 16 bytes")
            })?;
            MdnsData::Aaaa(Ipv6Addr::from(octets))
        }
        MdnsType::Ptr => MdnsData::Ptr(decode_rdata_name(
            data,
            rdata_start,
            rdata_end,
            "PTR target name overruns RDATA",
        )?),
        MdnsType::Srv => {
            if rdata.len() < 6 {
                return Err(NetError::parse(
                    rdata_start as u64,
                    "SRV record RDATA must be ≥6 bytes",
                ));
            }
            let priority = u16::from_be_bytes([rdata[0], rdata[1]]);
            let weight = u16::from_be_bytes([rdata[2], rdata[3]]);
            let port = u16::from_be_bytes([rdata[4], rdata[5]]);
            let target = decode_rdata_name(
                data,
                rdata_start + 6,
                rdata_end,
                "SRV target name overruns RDATA",
            )?;
            MdnsData::Srv {
                priority,
                weight,
                port,
                target,
            }
        }
        MdnsType::Txt => MdnsData::Txt(parse_txt_rdata(rdata)?),
        MdnsType::Unknown(_) => MdnsData::Txt(vec![format!("raw:{}", hex_encode(rdata))]),
    };

    Ok((
        MdnsRecord {
            name,
            record_type,
            ttl,
            data: record_data,
        },
        rdata_end,
    ))
}

/// Decodes a name embedded in RDATA. The name is read from the whole
/// message so compression pointers resolve, but its in-place encoding must
/// end at or before `rdata_end`; otherwise `overrun_msg` is reported.
fn decode_rdata_name(
    data: &[u8],
    start: usize,
    rdata_end: usize,
    overrun_msg: &'static str,
) -> NetResult<String> {
    let (name, name_end) = decode_dns_name(data, start)?;
    if name_end > rdata_end {
        return Err(NetError::parse(start as u64, overrun_msg));
    }
    Ok(name)
}
/// Splits TXT RDATA into its length-prefixed strings.
///
/// Zero-length strings are dropped from the result. Offsets in errors are
/// relative to the start of `rdata`.
///
/// # Errors
///
/// Fails when a string's declared length runs past the end of `rdata` or
/// when a string is not valid UTF-8.
fn parse_txt_rdata(rdata: &[u8]) -> NetResult<Vec<String>> {
    let mut entries = Vec::new();
    let mut cursor = 0usize;

    while let Some(&len_byte) = rdata.get(cursor) {
        let start = cursor + 1;
        let stop = start + usize::from(len_byte);
        let raw = match rdata.get(start..stop) {
            Some(raw) => raw,
            None => return Err(NetError::parse(start as u64, "TXT string truncated")),
        };
        let text = std::str::from_utf8(raw)
            .map_err(|_| NetError::parse(start as u64, "TXT entry is not valid UTF-8"))?;
        if !text.is_empty() {
            entries.push(text.to_owned());
        }
        cursor = stop;
    }

    Ok(entries)
}
/// Renders `bytes` as lowercase hexadecimal, two digits per byte and no
/// separators.
fn hex_encode(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0F)]));
    }
    out
}
// Unit tests for name encoding/decoding, packet building and the parser.
#[cfg(test)]
mod tests {
use super::*;
// A two-label name encodes to length-prefixed labels plus a zero terminator.
#[test]
fn test_encode_dns_name_simple() {
let encoded = encode_dns_name("example.com");
assert_eq!(encoded, b"\x07example\x03com\x00");
}
// A trailing dot must not change the encoding.
#[test]
fn test_encode_dns_name_fqdn() {
let a = encode_dns_name("example.com");
let b = encode_dns_name("example.com.");
assert_eq!(a, b, "trailing dot should not affect encoding");
}
// The root name "." encodes to a single zero byte.
#[test]
fn test_encode_dns_name_root() {
let encoded = encode_dns_name(".");
assert_eq!(encoded, b"\x00");
}
// Decoding an uncompressed name returns the dotted text and the end offset.
#[test]
fn test_decode_dns_name_simple() {
let data = b"\x07example\x03com\x00";
let (name, end) = decode_dns_name(data, 0).expect("ok");
assert_eq!(name, "example.com");
assert_eq!(end, data.len());
}
// "www" followed by a pointer to offset 0 resolves to the full name; the
// returned end offset is just past the two pointer bytes.
#[test]
fn test_decode_dns_name_with_pointer() {
let mut data = Vec::new();
data.extend_from_slice(b"\x07example\x03com\x00"); data.extend_from_slice(b"\x03www\xC0\x00"); let (name, end) = decode_dns_name(&data, 13).expect("ok");
assert_eq!(name, "www.example.com");
assert_eq!(end, 13 + 6);
}
// A lone zero byte decodes to the empty (root) name.
#[test]
fn test_decode_dns_name_empty_label() {
let data = b"\x00";
let (name, end) = decode_dns_name(data, 0).expect("ok");
assert_eq!(name, "");
assert_eq!(end, 1);
}
// A pointer that targets itself must be rejected instead of looping forever.
#[test]
fn test_decode_dns_name_pointer_loop_error() {
let data = [0xC0u8, 0x00];
let result = decode_dns_name(&data, 0);
assert!(result.is_err());
}
// The query header has QR clear and exactly one question.
#[test]
fn test_build_query_parses_back() {
let query = MdnsQuery::build_query("_http._tcp.local");
assert!(query.len() > 12);
let flags = u16::from_be_bytes([query[2], query[3]]);
assert_eq!(flags & 0x8000, 0, "QR must be 0 for query");
let qd_count = u16::from_be_bytes([query[4], query[5]]);
assert_eq!(qd_count, 1);
}
// The question asks for PTR and has the QU bit set in its class.
#[test]
fn test_build_query_qu_bit_set() {
let query = MdnsQuery::build_query("_http._tcp.local.");
let (_, name_end) = decode_dns_name(&query, 12).expect("ok");
let qtype = u16::from_be_bytes([query[name_end], query[name_end + 1]]);
let qclass = u16::from_be_bytes([query[name_end + 2], query[name_end + 3]]);
assert_eq!(qtype, DNS_TYPE_PTR, "PTR query type expected");
assert_eq!(qclass & DNS_QU_BIT, DNS_QU_BIT, "QU bit must be set");
}
// Announcements are responses (QR set) with the AA bit set.
#[test]
fn test_announce_produces_response_flag() {
let service = ServiceInfo {
service_type: "_http._tcp.local.".to_string(),
instance_name: "Test Server._http._tcp.local.".to_string(),
port: 8080,
host: "testhost.local.".to_string(),
txt_records: vec!["path=/".to_string()],
ipv4: Some(Ipv4Addr::new(192, 168, 1, 1)),
};
let packet = MdnsAnnouncer::announce(&service);
assert!(packet.len() > 12);
let flags = u16::from_be_bytes([packet[2], packet[3]]);
assert_eq!(flags & 0x8000, 0x8000, "QR bit must be set (response)");
assert_eq!(flags & 0x0400, 0x0400, "AA bit must be set");
}
// With an IPv4 address the header announces four answers.
#[test]
fn test_announce_answer_count_with_ipv4() {
let service = ServiceInfo {
service_type: "_http._tcp.local.".to_string(),
instance_name: "My._http._tcp.local.".to_string(),
port: 80,
host: "myhost.local.".to_string(),
txt_records: vec![],
ipv4: Some(Ipv4Addr::new(10, 0, 0, 1)),
};
let packet = MdnsAnnouncer::announce(&service);
let an_count = u16::from_be_bytes([packet[6], packet[7]]);
assert_eq!(an_count, 4, "PTR + SRV + TXT + A = 4");
}
// Without an IPv4 address the A record is omitted and the count is three.
#[test]
fn test_announce_answer_count_without_ipv4() {
let service = ServiceInfo {
service_type: "_http._tcp.local.".to_string(),
instance_name: "My._http._tcp.local.".to_string(),
port: 80,
host: "myhost.local.".to_string(),
txt_records: vec![],
ipv4: None,
};
let packet = MdnsAnnouncer::announce(&service);
let an_count = u16::from_be_bytes([packet[6], packet[7]]);
assert_eq!(an_count, 3, "PTR + SRV + TXT = 3");
}
// Round trip: an announced packet parses back into four records with the
// expected PTR owner name and A address.
#[test]
fn test_parse_announced_packet() {
let service = ServiceInfo {
service_type: "_http._tcp.local.".to_string(),
instance_name: "WebApp._http._tcp.local.".to_string(),
port: 3000,
host: "webapp.local.".to_string(),
txt_records: vec!["version=1.0".to_string(), "secure=no".to_string()],
ipv4: Some(Ipv4Addr::new(192, 168, 100, 5)),
};
let packet = MdnsAnnouncer::announce(&service);
let records = MdnsParser::parse(&packet).expect("parse ok");
assert_eq!(records.len(), 4, "expect PTR + SRV + TXT + A");
let ptr = records.iter().find(|r| r.record_type == MdnsType::Ptr);
assert!(ptr.is_some(), "PTR record expected");
if let Some(rec) = ptr {
assert_eq!(rec.name, "_http._tcp.local");
}
let a_rec = records.iter().find(|r| r.record_type == MdnsType::A);
assert!(a_rec.is_some(), "A record expected");
if let Some(rec) = a_rec {
assert_eq!(rec.data, MdnsData::A(Ipv4Addr::new(192, 168, 100, 5)));
}
}
// Round trip: the SRV record carries the announced port.
#[test]
fn test_parse_srv_record_fields() {
let service = ServiceInfo {
service_type: "_myapp._tcp.local.".to_string(),
instance_name: "MyApp._myapp._tcp.local.".to_string(),
port: 9000,
host: "serverbox.local.".to_string(),
txt_records: vec![],
ipv4: None,
};
let packet = MdnsAnnouncer::announce(&service);
let records = MdnsParser::parse(&packet).expect("parse ok");
let srv = records.iter().find(|r| r.record_type == MdnsType::Srv);
assert!(srv.is_some());
if let Some(rec) = srv {
if let MdnsData::Srv { port, .. } = rec.data {
assert_eq!(port, 9000);
} else {
panic!("Expected SRV data");
}
}
}
// Round trip: the TXT record contains every announced string.
#[test]
fn test_parse_txt_record_values() {
let service = ServiceInfo {
service_type: "_chat._tcp.local.".to_string(),
instance_name: "ChatApp._chat._tcp.local.".to_string(),
port: 5222,
host: "chatserver.local.".to_string(),
txt_records: vec!["user=alice".to_string(), "room=main".to_string()],
ipv4: None,
};
let packet = MdnsAnnouncer::announce(&service);
let records = MdnsParser::parse(&packet).expect("parse ok");
let txt = records.iter().find(|r| r.record_type == MdnsType::Txt);
assert!(txt.is_some());
if let Some(rec) = txt {
if let MdnsData::Txt(ref strings) = rec.data {
assert!(strings.contains(&"user=alice".to_string()));
assert!(strings.contains(&"room=main".to_string()));
} else {
panic!("Expected TXT data");
}
}
}
// Input shorter than the 12-byte header is rejected.
#[test]
fn test_parse_error_too_short() {
let result = MdnsParser::parse(&[0u8; 5]);
assert!(result.is_err());
}
// A bare header with all counts zero parses to an empty record list.
#[test]
fn test_parse_minimal_response_no_records() {
let header = [0u8, 1, 0x84, 0x00, 0, 0, 0, 0, 0, 0, 0, 0];
let records = MdnsParser::parse(&header).expect("ok");
assert!(records.is_empty());
}
// A trailing dot is appended only when missing.
#[test]
fn test_ensure_fqdn() {
assert_eq!(ensure_fqdn("example.com"), "example.com.");
assert_eq!(ensure_fqdn("example.com."), "example.com.");
}
// An empty TXT list encodes to a single zero-length string.
#[test]
fn test_encode_txt_empty() {
let encoded = encode_txt_records(&[]);
assert_eq!(encoded, vec![0x00]);
}
// Encoding then parsing TXT strings returns the original list.
#[test]
fn test_encode_decode_txt_roundtrip() {
let strings = vec!["key=value".to_string(), "flag".to_string()];
let encoded = encode_txt_records(&strings);
let decoded = parse_txt_rdata(&encoded).expect("ok");
assert_eq!(decoded, strings);
}
}