use std::net::Ipv4Addr;
/// IPv4 multicast group that mDNS queries and responses are sent to.
pub const MDNS_MULTICAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// UDP port used by mDNS.
pub const MDNS_PORT: u16 = 5353;
/// DNS resource-record type `A` (IPv4 address).
pub const DNS_TYPE_A: u16 = 1;
/// DNS resource-record type `AAAA` (IPv6 address).
pub const DNS_TYPE_AAAA: u16 = 28;
/// DNS query type `ANY`.
pub const DNS_TYPE_ANY: u16 = 255;
/// DNS class `IN` (Internet).
pub const DNS_CLASS_IN: u16 = 1;
/// Top bit of a record's class field, used by mDNS as the cache-flush flag.
pub const CACHE_FLUSH_BIT: u16 = 1 << 15;
/// Errors produced while decoding mDNS packets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MdnsProtocolError {
    /// The packet ended before a required field could be read.
    PacketTooShort {
        /// Minimum number of bytes that were required.
        expected: usize,
        /// Number of bytes actually available.
        actual: usize,
    },
    /// A domain name could not be decoded; carries the byte offset where decoding failed.
    InvalidDomainName(usize),
    /// The packet is not a query. `parse_query` itself reports responses as `Ok(None)`
    /// rather than with this variant.
    NotAQuery,
    /// A domain-name label was not valid UTF-8.
    InvalidUtf8,
}

impl std::fmt::Display for MdnsProtocolError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::PacketTooShort { expected, actual } => write!(
                f,
                "packet too short: expected at least {} bytes, got {}",
                expected, actual
            ),
            Self::InvalidDomainName(offset) => {
                write!(f, "invalid domain name at offset {}", offset)
            }
            Self::NotAQuery => write!(f, "not a query packet"),
            Self::InvalidUtf8 => write!(f, "invalid UTF-8 in domain name"),
        }
    }
}

impl std::error::Error for MdnsProtocolError {}
/// The first question of an incoming mDNS query, as produced by `parse_query`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MdnsQuery {
    /// Query ID copied from the packet header.
    pub id: u16,
    /// Queried name, dot-separated, lowercased by `parse_query`.
    pub domain: String,
    /// Question type (e.g. `DNS_TYPE_A`).
    pub query_type: u16,
    /// Whether the top bit of the question class (the unicast-response bit) was set.
    pub unicast_response: bool,
}
/// Parses the first question of an incoming mDNS packet.
///
/// Returns `Ok(None)` for packets a responder should ignore:
/// - responses (QR bit set);
/// - packets with a non-zero OPCODE or RCODE (RFC 6762 §18.3 and §18.11 say these MUST be
///   silently ignored);
/// - packets with no questions.
///
/// Only the first question is decoded; any further questions are ignored.
///
/// The returned domain is lowercased with ASCII-only rules. mDNS name comparison is
/// case-insensitive for ASCII letters only (RFC 6762 §16), so non-ASCII bytes are left as-is.
///
/// # Errors
///
/// - [`MdnsProtocolError::PacketTooShort`] if the 12-byte header or the question's
///   type/class fields are truncated.
/// - [`MdnsProtocolError::InvalidDomainName`] or [`MdnsProtocolError::InvalidUtf8`] if the
///   question name cannot be decoded.
pub fn parse_query(packet: &[u8]) -> Result<Option<MdnsQuery>, MdnsProtocolError> {
    // Size of the fixed DNS header; the first question starts right after it.
    const HEADER_LEN: usize = 12;
    const FLAG_QR: u16 = 0x8000;
    const OPCODE_MASK: u16 = 0x7800;
    const RCODE_MASK: u16 = 0x000F;
    // Top bit of the question class requests a unicast response.
    const QCLASS_UNICAST_RESPONSE: u16 = 0x8000;

    if packet.len() < HEADER_LEN {
        return Err(MdnsProtocolError::PacketTooShort {
            expected: HEADER_LEN,
            actual: packet.len(),
        });
    }

    let flags = u16::from_be_bytes([packet[2], packet[3]]);
    if (flags & FLAG_QR) != 0 {
        // A response, not a query.
        return Ok(None);
    }
    if (flags & (OPCODE_MASK | RCODE_MASK)) != 0 {
        // Non-standard opcode or non-zero rcode: must be silently ignored in mDNS.
        return Ok(None);
    }

    let id = u16::from_be_bytes([packet[0], packet[1]]);
    let qdcount = u16::from_be_bytes([packet[4], packet[5]]);
    if qdcount == 0 {
        return Ok(None);
    }

    let (mut domain, offset) = decode_domain_name(packet, HEADER_LEN)?;
    // QTYPE (2 bytes) + QCLASS (2 bytes) follow the name.
    let question_end = offset + 4;
    if packet.len() < question_end {
        return Err(MdnsProtocolError::PacketTooShort {
            expected: question_end,
            actual: packet.len(),
        });
    }
    let query_type = u16::from_be_bytes([packet[offset], packet[offset + 1]]);
    let query_class = u16::from_be_bytes([packet[offset + 2], packet[offset + 3]]);

    // ASCII-only folding: full Unicode `to_lowercase` would over-normalize non-ASCII names.
    domain.make_ascii_lowercase();

    Ok(Some(MdnsQuery {
        id,
        domain,
        query_type,
        unicast_response: (query_class & QCLASS_UNICAST_RESPONSE) != 0,
    }))
}
/// Builds a single-answer mDNS response carrying an `A` record for `query.domain`.
///
/// The header echoes `query.id`, uses flags `0x8400` (QR | AA), and declares zero questions,
/// one answer and no authority/additional records. The answer is always type `A`, whatever
/// `query.query_type` was. When `cache_flush` is true, the top bit of the record class is set.
#[must_use]
pub fn build_response(query: &MdnsQuery, ip: Ipv4Addr, ttl: u32, cache_flush: bool) -> Vec<u8> {
    let flush_bit = if cache_flush { CACHE_FLUSH_BIT } else { 0 };
    let rr_class = DNS_CLASS_IN | flush_bit;
    // ID (echoed), flags = QR | AA, QDCOUNT = 0, ANCOUNT = 1, NSCOUNT = 0, ARCOUNT = 0.
    let header_words: [u16; 6] = [query.id, 0x8400, 0, 1, 0, 0];

    let mut packet = Vec::with_capacity(64);
    for word in header_words.iter() {
        packet.extend_from_slice(&word.to_be_bytes());
    }
    packet.extend(encode_domain_name(&query.domain));

    // Answer record: TYPE, CLASS, TTL, RDLENGTH (4 bytes for IPv4), then RDATA.
    for word in [DNS_TYPE_A, rr_class].iter() {
        packet.extend_from_slice(&word.to_be_bytes());
    }
    packet.extend_from_slice(&ttl.to_be_bytes());
    packet.extend_from_slice(&4u16.to_be_bytes());
    packet.extend_from_slice(&ip.octets());
    packet
}
/// Builds an unsolicited mDNS announcement advertising `ip` as the `A` record of `domain`.
///
/// The packet has ID 0, flags `0x8400` (QR | AA), no questions and one answer. The answer's
/// class has the cache-flush bit set and the record uses the given `ttl`.
#[must_use]
pub fn build_announcement(domain: &str, ip: Ipv4Addr, ttl: u32) -> Vec<u8> {
    // ID = 0, flags = 0x8400, QDCOUNT = 0, ANCOUNT = 1, NSCOUNT = 0, ARCOUNT = 0.
    const HEADER: [u8; 12] = [
        0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00,
    ];
    const ANNOUNCE_CLASS: u16 = DNS_CLASS_IN | CACHE_FLUSH_BIT;

    let address = ip.octets();
    let mut out = Vec::with_capacity(64);
    out.extend_from_slice(&HEADER);
    out.extend(encode_domain_name(domain));
    out.extend_from_slice(&DNS_TYPE_A.to_be_bytes());
    out.extend_from_slice(&ANNOUNCE_CLASS.to_be_bytes());
    out.extend_from_slice(&ttl.to_be_bytes());
    // RDLENGTH followed by RDATA (the four address bytes).
    out.extend_from_slice(&(address.len() as u16).to_be_bytes());
    out.extend_from_slice(&address);
    out
}
/// Builds a "goodbye" packet for `domain`.
///
/// This is an announcement whose `A` record has TTL 0, the cache-flush bit set, and
/// `0.0.0.0` as its address.
///
/// NOTE(review): RFC 6762 §10.1 goodbyes normally repeat the record's real address. This
/// packet presumably relies on the cache-flush bit to evict the previously announced
/// address; confirm that receivers behave that way.
#[must_use]
pub fn build_goodbye(domain: &str) -> Vec<u8> {
    const GOODBYE_TTL: u32 = 0;
    let unspecified = Ipv4Addr::new(0, 0, 0, 0);
    build_announcement(domain, unspecified, GOODBYE_TTL)
}
/// Encodes a dot-separated domain name into uncompressed DNS wire format.
///
/// Each non-empty label is written as a length byte followed by its bytes, and the name ends
/// with a zero byte. Empty labels (leading, trailing or doubled dots) are skipped, so
/// `"a..b."` encodes like `"a.b"`, and `""` encodes as the root name `[0]`.
///
/// Labels longer than 63 bytes (the DNS label limit) are truncated to at most 63 bytes, cut
/// at a UTF-8 character boundary. Without the clamp, a long label's length byte could carry
/// the `0xC0` compression-pointer tag, or wrap around to a small value (even 0), and corrupt
/// the packet. The 255-byte limit on the whole name is not enforced here.
#[must_use]
pub fn encode_domain_name(domain: &str) -> Vec<u8> {
    // Maximum length of a single DNS label in bytes.
    const MAX_LABEL_LEN: usize = 63;

    // The encoded form is at most the input length plus one length byte and the terminator.
    let mut result = Vec::with_capacity(domain.len() + 2);
    for label in domain.split('.').filter(|label| !label.is_empty()) {
        let mut end = label.len().min(MAX_LABEL_LEN);
        while !label.is_char_boundary(end) {
            end -= 1;
        }
        // `end <= MAX_LABEL_LEN`, so this cast cannot truncate and the length byte can never
        // look like a compression pointer.
        result.push(end as u8);
        result.extend_from_slice(&label.as_bytes()[..end]);
    }
    result.push(0);
    result
}
/// Decodes a possibly-compressed DNS domain name that starts at byte `start` of `packet`.
///
/// Follows compression pointers (length bytes with the two top bits set). At most 10 pointer
/// hops are followed, so pointer loops end in an error instead of spinning.
///
/// Returns the labels joined with `.` (no trailing dot; the root name decodes to `""`). It
/// also returns the offset of the first byte after the name *at its original position*:
/// - just past the terminating zero if no pointer was followed;
/// - otherwise just past the first compression pointer.
///
/// # Errors
///
/// [`MdnsProtocolError::InvalidDomainName`], carrying the offending offset, when the name:
/// - runs past the end of `packet`;
/// - uses a reserved label type (length byte `0x40..=0xBF`);
/// - follows more than 10 compression pointers.
///
/// [`MdnsProtocolError::InvalidUtf8`] when a label is not valid UTF-8.
pub fn decode_domain_name(
    packet: &[u8],
    start: usize,
) -> Result<(String, usize), MdnsProtocolError> {
    // Upper bound on compression pointers followed; guards against pointer loops.
    const MAX_JUMPS: usize = 10;
    // Longest legal label; the `01` and `10` top-bit patterns above it are reserved.
    const MAX_LABEL_LEN: usize = 63;
    const POINTER_TAG: usize = 0xC0;

    let mut labels: Vec<&str> = Vec::new();
    let mut offset = start;
    // Where the caller should resume; fixed by the first pointer encountered, if any.
    let mut resume_at: Option<usize> = None;
    let mut jumps = 0;

    loop {
        let len = match packet.get(offset) {
            Some(&byte) => usize::from(byte),
            None => return Err(MdnsProtocolError::InvalidDomainName(offset)),
        };

        if len == 0 {
            let end = resume_at.unwrap_or(offset + 1);
            return Ok((labels.join("."), end));
        }

        if (len & POINTER_TAG) == POINTER_TAG {
            let low = match packet.get(offset + 1) {
                Some(&byte) => usize::from(byte),
                None => return Err(MdnsProtocolError::InvalidDomainName(offset)),
            };
            jumps += 1;
            if jumps > MAX_JUMPS {
                return Err(MdnsProtocolError::InvalidDomainName(offset));
            }
            if resume_at.is_none() {
                resume_at = Some(offset + 2);
            }
            offset = ((len & 0x3F) << 8) | low;
            continue;
        }

        if len > MAX_LABEL_LEN {
            // Reserved/extended label types are not valid length bytes.
            return Err(MdnsProtocolError::InvalidDomainName(offset));
        }

        let label_start = offset + 1;
        let label_end = label_start + len;
        let bytes = packet
            .get(label_start..label_end)
            .ok_or(MdnsProtocolError::InvalidDomainName(label_start))?;
        let label = std::str::from_utf8(bytes).map_err(|_| MdnsProtocolError::InvalidUtf8)?;
        labels.push(label);
        offset = label_end;
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the mDNS packet helpers in the parent module.
    use super::*;

    /// Builds a single-question query packet.
    ///
    /// Layout: 12-byte header (given ID and flags, QDCOUNT = 1, other counts 0), the encoded
    /// `name`, then QTYPE and QCLASS.
    fn query_packet(id: u16, flags: u16, name: &str, qtype: u16, qclass: u16) -> Vec<u8> {
        let mut packet = Vec::new();
        for word in [id, flags, 1, 0, 0, 0].iter() {
            packet.extend_from_slice(&word.to_be_bytes());
        }
        packet.extend(encode_domain_name(name));
        packet.extend_from_slice(&qtype.to_be_bytes());
        packet.extend_from_slice(&qclass.to_be_bytes());
        packet
    }

    #[test]
    fn test_encode_domain_name() {
        let encoded = encode_domain_name("nginx.arcbox.local");
        assert_eq!(encoded[0], 5);
        assert_eq!(&encoded[1..6], b"nginx");
        assert_eq!(encoded[6], 6);
        assert_eq!(&encoded[7..13], b"arcbox");
        assert_eq!(encoded[13], 5);
        assert_eq!(&encoded[14..19], b"local");
        assert_eq!(encoded[19], 0);
        assert_eq!(encoded.len(), 20);
    }

    #[test]
    fn test_encode_domain_name_empty_labels() {
        // Empty labels are skipped entirely, not merely tolerated.
        let encoded = encode_domain_name("test..arcbox.local.");
        assert_eq!(encoded, encode_domain_name("test.arcbox.local"));
    }

    #[test]
    fn test_decode_domain_name() {
        let packet = encode_domain_name("test.arcbox.local");
        let (domain, len) = decode_domain_name(&packet, 0).unwrap();
        assert_eq!(domain, "test.arcbox.local");
        assert_eq!(len, packet.len());
    }

    #[test]
    fn test_decode_root_name() {
        let (domain, end) = decode_domain_name(&[0], 0).unwrap();
        assert_eq!(domain, "");
        assert_eq!(end, 1);
    }

    #[test]
    fn test_decode_domain_name_with_compression() {
        let mut packet = Vec::new();
        packet.extend(encode_domain_name("arcbox.local"));
        let first_name_end = packet.len();
        packet.push(4);
        packet.extend(b"test");
        packet.push(0xC0);
        packet.push(0x00);
        let (domain, end) = decode_domain_name(&packet, first_name_end).unwrap();
        assert_eq!(domain, "test.arcbox.local");
        // Resume offset is just past the pointer, not past the pointed-to name.
        assert_eq!(end, packet.len());
    }

    #[test]
    fn test_decode_rejects_pointer_loop() {
        let packet = [0xC0, 0x00];
        assert!(matches!(
            decode_domain_name(&packet, 0),
            Err(MdnsProtocolError::InvalidDomainName(_))
        ));
    }

    #[test]
    fn test_decode_rejects_truncated_label() {
        let packet = [5, b'a', b'b'];
        assert!(matches!(
            decode_domain_name(&packet, 0),
            Err(MdnsProtocolError::InvalidDomainName(1))
        ));
    }

    #[test]
    fn test_build_announcement() {
        let packet =
            build_announcement("nginx.arcbox.local", Ipv4Addr::new(192, 168, 64, 10), 120);
        assert_eq!(&packet[0..2], &[0x00, 0x00]);
        assert_eq!(packet[2] & 0x80, 0x80);
        assert_eq!(packet[7], 1);
        let ip_start = packet.len() - 4;
        assert_eq!(&packet[ip_start..], &[192, 168, 64, 10]);
    }

    #[test]
    fn test_build_goodbye() {
        let packet = build_goodbye("nginx.arcbox.local");
        let rdlen_pos = packet.len() - 6;
        let ttl_pos = rdlen_pos - 4;
        assert_eq!(&packet[ttl_pos..ttl_pos + 4], &[0, 0, 0, 0]);
        assert_eq!(&packet[packet.len() - 4..], &[0, 0, 0, 0]);
    }

    #[test]
    fn test_build_response() {
        let query = MdnsQuery {
            id: 0x1234,
            domain: "test.arcbox.local".to_string(),
            query_type: DNS_TYPE_A,
            unicast_response: false,
        };
        let packet = build_response(&query, Ipv4Addr::new(10, 0, 0, 1), 120, true);
        assert_eq!(u16::from_be_bytes([packet[0], packet[1]]), 0x1234);
        assert_eq!(packet[2] & 0x80, 0x80);
        assert_eq!(u16::from_be_bytes([packet[6], packet[7]]), 1);
        assert_eq!(&packet[packet.len() - 4..], &[10, 0, 0, 1]);

        // The answer ends with CLASS(2) TTL(4) RDLENGTH(2) RDATA(4), so CLASS sits 12 bytes
        // from the end.
        let class_pos = packet.len() - 12;
        assert_eq!(
            u16::from_be_bytes([packet[class_pos], packet[class_pos + 1]]),
            DNS_CLASS_IN | CACHE_FLUSH_BIT
        );

        let plain = build_response(&query, Ipv4Addr::new(10, 0, 0, 1), 120, false);
        let class_pos = plain.len() - 12;
        assert_eq!(
            u16::from_be_bytes([plain[class_pos], plain[class_pos + 1]]),
            DNS_CLASS_IN
        );
    }

    #[test]
    fn test_parse_query() {
        let packet = query_packet(1, 0, "test.arcbox.local", DNS_TYPE_A, DNS_CLASS_IN);
        let query = parse_query(&packet).unwrap().unwrap();
        assert_eq!(query.id, 1);
        assert_eq!(query.domain, "test.arcbox.local");
        assert_eq!(query.query_type, 1);
        assert!(!query.unicast_response);
    }

    #[test]
    fn test_parse_query_lowercases_ascii() {
        let packet = query_packet(0, 0, "Test.ARCBOX.Local", DNS_TYPE_A, DNS_CLASS_IN);
        let query = parse_query(&packet).unwrap().unwrap();
        assert_eq!(query.domain, "test.arcbox.local");
    }

    #[test]
    fn test_parse_query_with_qu_bit() {
        let packet = query_packet(0, 0, "test.arcbox.local", DNS_TYPE_A, 0x8001);
        let query = parse_query(&packet).unwrap().unwrap();
        assert!(query.unicast_response);
    }

    #[test]
    fn test_parse_query_truncated_question() {
        let mut packet = query_packet(1, 0, "test.arcbox.local", DNS_TYPE_A, DNS_CLASS_IN);
        let cut = packet.len() - 2;
        packet.truncate(cut);
        // 12-byte header + 19-byte name + 4 bytes of type/class = 35 required.
        assert!(matches!(
            parse_query(&packet),
            Err(MdnsProtocolError::PacketTooShort {
                expected: 35,
                actual: 33
            })
        ));
    }

    #[test]
    fn test_parse_response_returns_none() {
        let packet = [
            0x00, 0x00, 0x84, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        assert!(parse_query(&packet).unwrap().is_none());
    }

    #[test]
    fn test_parse_query_too_short() {
        let packet = [0x00, 0x01, 0x00];
        assert!(matches!(
            parse_query(&packet),
            Err(MdnsProtocolError::PacketTooShort { .. })
        ));
    }

    #[test]
    fn test_parse_query_no_questions() {
        let packet = [
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        assert!(parse_query(&packet).unwrap().is_none());
    }

    #[test]
    fn test_constants() {
        assert_eq!(MDNS_MULTICAST_ADDR, Ipv4Addr::new(224, 0, 0, 251));
        assert_eq!(MDNS_PORT, 5353);
        assert_eq!(DNS_TYPE_A, 1);
        assert_eq!(DNS_TYPE_AAAA, 28);
        assert_eq!(DNS_TYPE_ANY, 255);
    }
}