mdns_proto/
server.rs

1use core::marker::PhantomData;
2
3use super::{
4  ConnectionHandle, Pool,
5  error::ProtoError,
6  proto::{Flags, Message, Opcode, Question, ResponseCode},
7};
8
/// When `true`, [`Endpoint::response`] marks every response as unicast, regardless of
/// whether the top (unicast-response) bit of the question's class is set.
const FORCE_UNICAST_RESPONSES: bool = false;
10
/// An endpoint for handling mDNS queries and responses.
///
/// This `Endpoint` uses a slab for managing connections and queries: the outer slab holds
/// one inner slab per connection, and each inner slab stores the message ids of that
/// connection's pending queries.
#[cfg(feature = "slab")]
#[cfg_attr(docsrs, doc(cfg(feature = "slab")))]
pub type SlabEndpoint = Endpoint<slab::Slab<slab::Slab<u16>>, slab::Slab<u16>>;
17
18/// The error type for the server.
19#[derive(Debug, thiserror::Error)]
20pub enum Error<S, Q> {
21  /// The server is full and cannot hold any more connections.
22  #[error(transparent)]
23  Connection(S),
24  /// The connection is full and cannot hold any more queries.
25  #[error(transparent)]
26  Query(Q),
27  /// The connection is not found.
28  #[error("connection not found: {0}")]
29  ConnectionNotFound(ConnectionHandle),
30  /// The query is not found.
31  #[error("query {qid} not found on connection {cid}", qid = _0.qid, cid = _0.cid)]
32  QueryNotFound(QueryHandle),
33  /// Returned when the a query has an invalid opcode.
34  #[error("invalid opcode: {0:?}")]
35  InvalidOpcode(Opcode),
36  /// Returned when the a query has an invalid response code.
37  #[error("invalid response code: {0:?}")]
38  InvalidResponseCode(ResponseCode),
39  /// Returned when a query with a high truncated bit is received.
40  #[error("support for DNS requests with high truncated bit not implemented")]
41  TrancatedQuery,
42  /// Protocol error
43  #[error(transparent)]
44  Proto(#[from] ProtoError),
45}
46
/// Internal identifier for a query registered on a connection of an endpoint.
///
/// It records which connection the query arrived on, the slot the query occupies in that
/// connection's query pool, and the message id of the DNS message that carried it.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct QueryHandle {
  qid: usize, // the query id (slot in the connection's query pool)
  mid: u16,   // the message id of the query message
  cid: usize, // the connection id
}

impl QueryHandle {
  /// Creates a handle from a connection id, a query id and a message id.
  #[inline]
  const fn new(cid: usize, qid: usize, mid: u16) -> Self {
    Self { cid, qid, mid }
  }

  /// Returns the message id associated with the query handle.
  #[inline]
  pub const fn message_id(&self) -> u16 {
    self.mid
  }

  /// Returns the query id associated with the query handle.
  #[inline]
  pub const fn query_id(&self) -> usize {
    self.qid
  }

  /// Returns the connection id associated with the query handle.
  #[inline]
  pub const fn connection_id(&self) -> usize {
    self.cid
  }
}
79
80/// A query event
81#[derive(Debug, Eq, PartialEq)]
82pub struct Query<'container, 'innards> {
83  msg: Message<'container, 'innards>,
84  query_handle: QueryHandle,
85}
86
87impl<'container, 'innards> Query<'container, 'innards> {
88  #[inline]
89  const fn new(msg: Message<'container, 'innards>, query_handle: QueryHandle) -> Self {
90    Self { msg, query_handle }
91  }
92
93  /// Returns the question associated with the query event.
94  #[inline]
95  pub fn questions(&self) -> &[Question<'innards>] {
96    self.msg.questions()
97  }
98
99  /// Returns the query handle associated with the query event.
100  #[inline]
101  pub const fn query_handle(&self) -> QueryHandle {
102    self.query_handle
103  }
104}
105
106/// A response event
107#[derive(Debug, Eq, PartialEq)]
108pub struct Response<'innards> {
109  query_handle: QueryHandle,
110  question: Question<'innards>,
111}
112
113impl<'innards> Response<'innards> {
114  /// Creates a new response event.
115  #[inline]
116  pub const fn new(query_handle: QueryHandle, question: Question<'innards>) -> Self {
117    Self {
118      query_handle,
119      question,
120    }
121  }
122
123  /// Returns the query handle associated with the response event.
124  #[inline]
125  pub const fn query_handle(&self) -> QueryHandle {
126    self.query_handle
127  }
128
129  /// Returns the question associated with the response event.
130  #[inline]
131  pub const fn question(&self) -> &Question<'innards> {
132    &self.question
133  }
134}
135
136/// An outgoing event
137pub struct Outgoing {
138  flags: Flags,
139  unicast: bool,
140  id: u16,
141}
142
143impl Outgoing {
144  /// Creates a new outgoing event.
145  #[inline]
146  const fn new(flags: Flags, unicast: bool, id: u16) -> Self {
147    Self { flags, unicast, id }
148  }
149
150  /// Returns the message flags should be used for the outgoing [`Message`].
151  #[inline]
152  pub const fn flags(&self) -> Flags {
153    self.flags
154  }
155
156  /// Returns `true` if the outgoing event is unicast.
157  #[inline]
158  pub const fn is_unicast(&self) -> bool {
159    self.unicast
160  }
161
162  /// Returns the message id should be used for the outgoing [`Message`].
163  ///
164  /// - `0` for multicast response
165  /// - other values for unicast response
166  #[inline]
167  pub const fn id(&self) -> u16 {
168    self.id
169  }
170}
171
/// The result of closing a connection, as returned by [`Endpoint::drain_connection`].
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub struct Closed<Q> {
  /// The remaining queries associated with the connection, if any (the connection's query
  /// pool as it was when the connection was removed).
  pub remainings: Q,
  /// The closed connection handle.
  pub connection_handle: ConnectionHandle,
}
180
181/// The main entry point to the library
182///
183/// This object performs no I/O whatsoever. Instead, it consumes incoming packets and
184/// connection-generated events via `handle` and `handle_event`.
185pub struct Endpoint<S, Q> {
186  connections: S,
187  _q: PhantomData<Q>,
188}
189
190impl<S, Q> Default for Endpoint<S, Q>
191where
192  S: Pool<Q>,
193  Q: Pool<u16>,
194{
195  fn default() -> Self {
196    Self::new()
197  }
198}
199
200impl<S, Q> Endpoint<S, Q>
201where
202  S: Pool<Q>,
203  Q: Pool<u16>,
204{
205  /// Create a new server endpoint
206  pub fn new() -> Self {
207    Self {
208      connections: S::new(),
209      _q: PhantomData,
210    }
211  }
212
213  /// Create a new server endpoint with a specific capacity
214  pub fn with_capacity(capacity: usize) -> Result<Self, S::Error> {
215    Ok(Self {
216      connections: S::with_capacity(capacity)?,
217      _q: PhantomData,
218    })
219  }
220
221  /// Close the endpoint
222  pub fn close(&mut self) {
223    self.connections.iter().for_each(|(_idx, conn)| {
224      if !conn.is_empty() {
225        #[cfg(feature = "tracing")]
226        tracing::warn!(
227          "mdns endpoint: connection {} closed with {} remaining queries",
228          _idx,
229          conn.len()
230        );
231      }
232    });
233  }
234
235  /// Accept a new connection
236  pub fn accept(&mut self) -> Result<ConnectionHandle, Error<S::Error, Q::Error>> {
237    let key = self
238      .connections
239      .insert(Q::new())
240      .map_err(Error::Connection)?;
241    Ok(ConnectionHandle(key))
242  }
243
244  /// Handle an incoming query message
245  pub fn recv<'container, 'innards>(
246    &mut self,
247    ch: ConnectionHandle,
248    msg: Message<'container, 'innards>,
249  ) -> Result<Query<'container, 'innards>, Error<S::Error, Q::Error>> {
250    let id = msg.id();
251    let flags = msg.flags();
252    let opcode = flags.opcode();
253
254    if opcode != Opcode::Query {
255      // "In both multicast query and multicast response messages, the OPCODE MUST
256      // be zero on transmission (only standard queries are currently supported
257      // over multicast).  Multicast DNS messages received with an OPCODE other
258      // than zero MUST be silently ignored."  Note: OpcodeQuery == 0
259      #[cfg(feature = "tracing")]
260      tracing::error!(opcode = ?opcode, "mdns endpoint: received query with non-zero OpCode");
261      return Err(Error::InvalidOpcode(opcode));
262    }
263
264    let resp_code = flags.response_code();
265    if resp_code != ResponseCode::NoError {
266      // "In both multicast query and multicast response messages, the Response
267      // Code MUST be zero on transmission.  Multicast DNS messages received with
268      // non-zero Response Codes MUST be silently ignored."
269      #[cfg(feature = "tracing")]
270      tracing::error!(rcode = ?resp_code, "mdns endpoint: received query with non-zero response code");
271      return Err(Error::InvalidResponseCode(resp_code));
272    }
273
274    // TODO(reddaly): Handle "TC (Truncated) Bit":
275    //    In query messages, if the TC bit is set, it means that additional
276    //    Known-Answer records may be following shortly.  A responder SHOULD
277    //    record this fact, and wait for those additional Known-Answer records,
278    //    before deciding whether to respond.  If the TC bit is clear, it means
279    //    that the querying host has no additional Known Answers.
280    if flags.truncated() {
281      #[cfg(feature = "tracing")]
282      tracing::error!(
283        "mdns endpoint: support for mDNS requests with high truncated bit not implemented"
284      );
285      return Err(Error::TrancatedQuery);
286    }
287
288    if let Some(conn) = self.connections.get_mut(ch.0) {
289      let qid = conn.insert(id).map_err(Error::Query)?;
290      return Ok(Query::new(msg, QueryHandle::new(ch.into(), qid, id)));
291    }
292
293    Err(Error::ConnectionNotFound(ch))
294  }
295
296  /// Generate a response for a question
297  pub fn response(
298    &mut self,
299    qh: QueryHandle,
300    question: Question<'_>,
301  ) -> Result<Outgoing, Error<S::Error, Q::Error>> {
302    let mut flags = Flags::new();
303    flags
304      .set_response_code(ResponseCode::NoError)
305      .set_authoritative(true);
306
307    // Handle unicast and multicast responses.
308    // TODO(reddaly): The decision about sending over unicast vs. multicast is not
309    // yet fully compliant with RFC 6762.  For example, the unicast bit should be
310    // ignored if the records in question are close to TTL expiration.  For now,
311    // we just use the unicast bit to make the decision, as per the spec:
312    //     RFC 6762, section 18.12.  Repurposing of Top Bit of qclass in Query
313    //     Section
314    //
315    //     In the Query Section of a Multicast DNS query, the top bit of the
316    //     qclass field is used to indicate that unicast responses are preferred
317    //     for this particular question.  (See Section 5.4.)
318    let qc = question.class();
319    let unicast = (qc & (1 << 15)) != 0 || FORCE_UNICAST_RESPONSES;
320
321    // 18.1: ID (Query Identifier)
322    // 0 for multicast response, query.Id for unicast response
323    let mut id = 0;
324    if unicast {
325      id = qh.message_id();
326    }
327
328    Ok(Outgoing::new(flags, unicast, id))
329  }
330
331  /// Handle a query drain event
332  pub fn drain_query(&mut self, qh: QueryHandle) -> Result<(), Error<S::Error, Q::Error>> {
333    match self.connections.get_mut(qh.cid) {
334      Some(q) => match q.try_remove(qh.qid) {
335        Some(_) => Ok(()),
336        None => Err(Error::QueryNotFound(qh)),
337      },
338      None => Err(Error::ConnectionNotFound(ConnectionHandle(qh.cid))),
339    }
340  }
341
342  /// Handle a connection drain event
343  pub fn drain_connection(
344    &mut self,
345    ch: ConnectionHandle,
346  ) -> Result<Closed<Q>, Error<S::Error, Q::Error>> {
347    match self.connections.try_remove(ch.into()) {
348      Some(queries) => {
349        #[cfg(feature = "tracing")]
350        if !queries.is_empty() {
351          tracing::warn!(
352            "mdns endpoint: connection {} closed with {} remaining queries",
353            ch,
354            queries.len()
355          );
356        }
357        Ok(Closed {
358          remainings: queries,
359          connection_handle: ch,
360        })
361      }
362      None => Err(Error::ConnectionNotFound(ch)),
363    }
364  }
365}