1use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
2
3use dns_protocol::{Cursor, Deserialize, Label, Message, Question, ResourceType};
4
5use super::{
6 Srv, Txt,
7 error::{ProtoError, proto_error_parse},
8};
9
/// A single decoded resource record, as yielded by [`Endpoint::recv`].
///
/// Only the record types `Endpoint::recv` understands are represented here:
/// `A`, `AAAA`, `PTR`, `TXT` and `SRV`.
#[derive(Debug, Clone, Copy)]
pub enum Response<'a> {
    /// An `A` record.
    A {
        /// Owner name of the record.
        name: Label<'a>,
        /// IPv4 address built from the record's 4 data bytes.
        addr: Ipv4Addr,
    },
    /// An `AAAA` record.
    AAAA {
        /// Owner name of the record.
        name: Label<'a>,
        /// IPv6 address built from the record's 16 data bytes.
        addr: Ipv6Addr,
        /// Scope id of the sender's socket address. `Endpoint::recv` sets this only
        /// when `addr` is unicast or multicast link-local and the message arrived
        /// from an IPv6 socket address; otherwise it is `None`.
        zone: Option<u32>,
    },
    /// A `PTR` record, holding the target name decoded from the record data.
    /// Unlike the other variants, the record's owner name is not kept.
    Ptr(Label<'a>),
    /// A `TXT` record.
    Txt {
        /// Owner name of the record.
        name: Label<'a>,
        /// TXT payload, as produced by `Txt::from_bytes` on the record data.
        txt: Txt<'a, 'a>,
    },
    /// An `SRV` record.
    Srv {
        /// Owner name of the record.
        name: Label<'a>,
        /// SRV payload, as parsed by `Srv::from_bytes` from the record data.
        srv: Srv<'a>,
    },
}
46
/// Link-local classification for IPv6 addresses.
///
/// Used to decide whether an `AAAA` answer must be tagged with a zone (scope id).
trait Ipv6AddrExt {
    /// Returns `true` for unicast link-local addresses, i.e. those in `fe80::/10`.
    fn is_unicast_link_local(&self) -> bool;
    /// Returns `true` for multicast addresses (`ff00::/8`) whose 4-bit scope field
    /// equals `2` (link-local). The flags nibble is ignored.
    fn is_multicast_link_local(&self) -> bool;
}

impl Ipv6AddrExt for Ipv6Addr {
    #[inline]
    fn is_unicast_link_local(&self) -> bool {
        // The first 10 bits must be `1111 1110 10`.
        const PREFIX_MASK: u16 = 0xffc0;
        const PREFIX: u16 = 0xfe80;

        let leading = self.segments()[0];
        (leading & PREFIX_MASK) == PREFIX
    }

    #[inline]
    fn is_multicast_link_local(&self) -> bool {
        // The high byte must be `ff` (multicast). The low nibble of the second byte
        // is the scope, and must be `2`. The flags nibble in between is masked out.
        const MULTICAST_AND_SCOPE_MASK: u16 = 0xff0f;
        const LINK_LOCAL_MULTICAST: u16 = 0xff02;

        let leading = self.segments()[0];
        (leading & MULTICAST_AND_SCOPE_MASK) == LINK_LOCAL_MULTICAST
    }
}
65
66pub struct Endpoint;
68
69impl Endpoint {
70 pub fn prepare_question(name: Label<'_>, unicast_response: bool) -> Question<'_> {
72 let qclass = if unicast_response {
79 let base: u16 = 1;
80 base | (1 << 15)
81 } else {
82 1
83 };
84
85 Question::new(name, ResourceType::Ptr, qclass)
86 }
87
88 pub fn recv<'innards>(
90 from: SocketAddr,
91 msg: &Message<'_, 'innards>,
92 ) -> impl Iterator<Item = Result<Response<'innards>, ProtoError>> {
93 msg
95 .answers()
96 .iter()
97 .chain(msg.additional().iter())
98 .filter_map(move |record| {
99 let record_name = record.name();
100 match record.ty() {
101 ResourceType::A => {
102 let src = record.data();
103 let res: Result<[u8; 4], _> = src.try_into();
104
105 match res {
106 Ok(ip) => Some(Ok(Response::A {
107 name: record_name,
108 addr: Ipv4Addr::from(ip),
109 })),
110 Err(_) => {
111 #[cfg(feature = "tracing")]
112 tracing::error!("mdns endpoint: invalid A record data");
113 Some(Err(proto_error_parse("A")))
114 }
115 }
116 }
117 ResourceType::AAAA => {
118 let src = record.data();
119 let res: Result<[u8; 16], _> = src.try_into();
120
121 match res {
122 Ok(ip) => {
123 let ip = Ipv6Addr::from(ip);
124 let mut zone = None;
125 if Ipv6AddrExt::is_unicast_link_local(&ip) || ip.is_multicast_link_local() {
130 if let SocketAddr::V6(addr) = from {
131 zone = Some(addr.scope_id());
132 }
133 }
134
135 Some(Ok(Response::AAAA {
136 name: record_name,
137 addr: ip,
138 zone,
139 }))
140 }
141 Err(_) => {
142 #[cfg(feature = "tracing")]
143 tracing::error!("mdns endpoint: invalid AAAA record data");
144 Some(Err(proto_error_parse("AAAA")))
145 }
146 }
147 }
148 ResourceType::Ptr => {
149 let mut label = Label::default();
150 let cursor = Cursor::new(record.data());
151 Some(label.deserialize(cursor).map(|_| Response::Ptr(label)))
152 }
153 ResourceType::Srv => {
154 let data = record.data();
155
156 Some(Srv::from_bytes(data).map(|srv| Response::Srv {
157 name: record_name,
158 srv,
159 }))
160 }
161 ResourceType::Txt => {
162 let data = record.data();
163 Some(Ok(Response::Txt {
164 name: record_name,
165 txt: Txt::from_bytes(data),
166 }))
167 }
168 _ => None,
169 }
170 })
171 }
172}