1use core::marker::PhantomData;
2
3use super::{
4 ConnectionHandle, Pool,
5 error::ProtoError,
6 proto::{Flags, Message, Opcode, Question, ResponseCode},
7};
8
/// Forces every response computed by `Endpoint::response` to be unicast.
///
/// When `false`, a response is unicast only if bit 15 of the question's class
/// is set; when `true`, it is unicast regardless of that bit.
const FORCE_UNICAST_RESPONSES: bool = false;
10
/// An [`Endpoint`] backed by [`slab::Slab`].
///
/// The outer slab holds one inner slab per accepted connection. Each inner
/// slab holds the message IDs of that connection's outstanding queries.
#[cfg(feature = "slab")]
#[cfg_attr(docsrs, doc(cfg(feature = "slab")))]
pub type SlabEndpoint = Endpoint<slab::Slab<slab::Slab<u16>>, slab::Slab<u16>>;
17
18#[derive(Debug, thiserror::Error)]
20pub enum Error<S, Q> {
21 #[error(transparent)]
23 Connection(S),
24 #[error(transparent)]
26 Query(Q),
27 #[error("connection not found: {0}")]
29 ConnectionNotFound(ConnectionHandle),
30 #[error("query {qid} not found on connection {cid}", qid = _0.qid, cid = _0.cid)]
32 QueryNotFound(QueryHandle),
33 #[error("invalid opcode: {0:?}")]
35 InvalidOpcode(Opcode),
36 #[error("invalid response code: {0:?}")]
38 InvalidResponseCode(ResponseCode),
39 #[error("support for DNS requests with high truncated bit not implemented")]
41 TrancatedQuery,
42 #[error(transparent)]
44 Proto(#[from] ProtoError),
45}
46
/// Identifies a query recorded by `Endpoint::recv`.
///
/// A handle names three things: the connection the query arrived on, the
/// query's key in that connection's query pool, and the query message's ID.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct QueryHandle {
    /// Key returned by the connection's query pool when the query was inserted.
    qid: usize,
    /// Message ID of the query message.
    mid: u16,
    /// Key of the owning connection in the endpoint's connection pool.
    cid: usize,
}

impl QueryHandle {
    /// Assembles a handle from connection key, query key and message ID.
    #[inline]
    const fn new(cid: usize, qid: usize, mid: u16) -> Self {
        Self { qid, mid, cid }
    }

    /// The message ID of the query this handle refers to.
    #[inline]
    pub const fn message_id(&self) -> u16 {
        self.mid
    }

    /// The key of the query inside its connection's query pool.
    #[inline]
    pub const fn query_id(&self) -> usize {
        self.qid
    }

    /// The key of the connection the query was received on.
    #[inline]
    pub const fn connection_id(&self) -> usize {
        self.cid
    }
}
79
80#[derive(Debug, Eq, PartialEq)]
82pub struct Query<'container, 'innards> {
83 msg: Message<'container, 'innards>,
84 query_handle: QueryHandle,
85}
86
87impl<'container, 'innards> Query<'container, 'innards> {
88 #[inline]
89 const fn new(msg: Message<'container, 'innards>, query_handle: QueryHandle) -> Self {
90 Self { msg, query_handle }
91 }
92
93 #[inline]
95 pub fn questions(&self) -> &[Question<'innards>] {
96 self.msg.questions()
97 }
98
99 #[inline]
101 pub const fn query_handle(&self) -> QueryHandle {
102 self.query_handle
103 }
104}
105
106#[derive(Debug, Eq, PartialEq)]
108pub struct Response<'innards> {
109 query_handle: QueryHandle,
110 question: Question<'innards>,
111}
112
113impl<'innards> Response<'innards> {
114 #[inline]
116 pub const fn new(query_handle: QueryHandle, question: Question<'innards>) -> Self {
117 Self {
118 query_handle,
119 question,
120 }
121 }
122
123 #[inline]
125 pub const fn query_handle(&self) -> QueryHandle {
126 self.query_handle
127 }
128
129 #[inline]
131 pub const fn question(&self) -> &Question<'innards> {
132 &self.question
133 }
134}
135
136pub struct Outgoing {
138 flags: Flags,
139 unicast: bool,
140 id: u16,
141}
142
143impl Outgoing {
144 #[inline]
146 const fn new(flags: Flags, unicast: bool, id: u16) -> Self {
147 Self { flags, unicast, id }
148 }
149
150 #[inline]
152 pub const fn flags(&self) -> Flags {
153 self.flags
154 }
155
156 #[inline]
158 pub const fn is_unicast(&self) -> bool {
159 self.unicast
160 }
161
162 #[inline]
167 pub const fn id(&self) -> u16 {
168 self.id
169 }
170}
171
172#[derive(Debug, Eq, PartialEq, Clone, Copy)]
174pub struct Closed<Q> {
175 pub remainings: Q,
177 pub connection_handle: ConnectionHandle,
179}
180
/// Tracks accepted connections and the outstanding queries received on each.
///
/// `S` is the connection pool. Each connection is a `Q` pool holding the
/// message IDs of that connection's outstanding queries.
#[derive(Debug)]
pub struct Endpoint<S, Q> {
    /// One query pool per accepted connection, keyed by connection id.
    connections: S,
    /// `Q` appears only in `S`'s trait bounds, so it must be recorded here
    /// for the struct to use its type parameter.
    _q: PhantomData<Q>,
}
189
190impl<S, Q> Default for Endpoint<S, Q>
191where
192 S: Pool<Q>,
193 Q: Pool<u16>,
194{
195 fn default() -> Self {
196 Self::new()
197 }
198}
199
200impl<S, Q> Endpoint<S, Q>
201where
202 S: Pool<Q>,
203 Q: Pool<u16>,
204{
205 pub fn new() -> Self {
207 Self {
208 connections: S::new(),
209 _q: PhantomData,
210 }
211 }
212
213 pub fn with_capacity(capacity: usize) -> Result<Self, S::Error> {
215 Ok(Self {
216 connections: S::with_capacity(capacity)?,
217 _q: PhantomData,
218 })
219 }
220
221 pub fn close(&mut self) {
223 self.connections.iter().for_each(|(_idx, conn)| {
224 if !conn.is_empty() {
225 #[cfg(feature = "tracing")]
226 tracing::warn!(
227 "mdns endpoint: connection {} closed with {} remaining queries",
228 _idx,
229 conn.len()
230 );
231 }
232 });
233 }
234
235 pub fn accept(&mut self) -> Result<ConnectionHandle, Error<S::Error, Q::Error>> {
237 let key = self
238 .connections
239 .insert(Q::new())
240 .map_err(Error::Connection)?;
241 Ok(ConnectionHandle(key))
242 }
243
244 pub fn recv<'container, 'innards>(
246 &mut self,
247 ch: ConnectionHandle,
248 msg: Message<'container, 'innards>,
249 ) -> Result<Query<'container, 'innards>, Error<S::Error, Q::Error>> {
250 let id = msg.id();
251 let flags = msg.flags();
252 let opcode = flags.opcode();
253
254 if opcode != Opcode::Query {
255 #[cfg(feature = "tracing")]
260 tracing::error!(opcode = ?opcode, "mdns endpoint: received query with non-zero OpCode");
261 return Err(Error::InvalidOpcode(opcode));
262 }
263
264 let resp_code = flags.response_code();
265 if resp_code != ResponseCode::NoError {
266 #[cfg(feature = "tracing")]
270 tracing::error!(rcode = ?resp_code, "mdns endpoint: received query with non-zero response code");
271 return Err(Error::InvalidResponseCode(resp_code));
272 }
273
274 if flags.truncated() {
281 #[cfg(feature = "tracing")]
282 tracing::error!(
283 "mdns endpoint: support for mDNS requests with high truncated bit not implemented"
284 );
285 return Err(Error::TrancatedQuery);
286 }
287
288 if let Some(conn) = self.connections.get_mut(ch.0) {
289 let qid = conn.insert(id).map_err(Error::Query)?;
290 return Ok(Query::new(msg, QueryHandle::new(ch.into(), qid, id)));
291 }
292
293 Err(Error::ConnectionNotFound(ch))
294 }
295
296 pub fn response(
298 &mut self,
299 qh: QueryHandle,
300 question: Question<'_>,
301 ) -> Result<Outgoing, Error<S::Error, Q::Error>> {
302 let mut flags = Flags::new();
303 flags
304 .set_response_code(ResponseCode::NoError)
305 .set_authoritative(true);
306
307 let qc = question.class();
319 let unicast = (qc & (1 << 15)) != 0 || FORCE_UNICAST_RESPONSES;
320
321 let mut id = 0;
324 if unicast {
325 id = qh.message_id();
326 }
327
328 Ok(Outgoing::new(flags, unicast, id))
329 }
330
331 pub fn drain_query(&mut self, qh: QueryHandle) -> Result<(), Error<S::Error, Q::Error>> {
333 match self.connections.get_mut(qh.cid) {
334 Some(q) => match q.try_remove(qh.qid) {
335 Some(_) => Ok(()),
336 None => Err(Error::QueryNotFound(qh)),
337 },
338 None => Err(Error::ConnectionNotFound(ConnectionHandle(qh.cid))),
339 }
340 }
341
342 pub fn drain_connection(
344 &mut self,
345 ch: ConnectionHandle,
346 ) -> Result<Closed<Q>, Error<S::Error, Q::Error>> {
347 match self.connections.try_remove(ch.into()) {
348 Some(queries) => {
349 #[cfg(feature = "tracing")]
350 if !queries.is_empty() {
351 tracing::warn!(
352 "mdns endpoint: connection {} closed with {} remaining queries",
353 ch,
354 queries.len()
355 );
356 }
357 Ok(Closed {
358 remainings: queries,
359 connection_handle: ch,
360 })
361 }
362 None => Err(Error::ConnectionNotFound(ch)),
363 }
364 }
365}