use core::marker::PhantomData;
use super::{
ConnectionHandle, Pool,
error::ProtoError,
proto::{Flags, Message, Opcode, Question, ResponseCode},
};
// When `true`, `Endpoint::response` marks every response as unicast (and echoes the
// query's message ID), regardless of the question's unicast-response bit.
const FORCE_UNICAST_RESPONSES: bool = false;
/// An [`Endpoint`] backed by `slab::Slab` pools.
///
/// The outer slab holds one inner slab per connection. Each inner slab holds the
/// `u16` message IDs of the queries registered on that connection.
///
/// Only available with the `slab` feature.
#[cfg(feature = "slab")]
#[cfg_attr(docsrs, doc(cfg(feature = "slab")))]
pub type SlabEndpoint = Endpoint<slab::Slab<slab::Slab<u16>>, slab::Slab<u16>>;
/// Errors returned by [`Endpoint`] operations.
///
/// The endpoint methods instantiate this as `Error<S::Error, Q::Error>`. `S` is the
/// connection pool's error type and `Q` is the per-connection query pool's error type.
#[derive(Debug, thiserror::Error)]
pub enum Error<S, Q> {
/// The connection pool failed, e.g. while inserting a new connection in
/// `Endpoint::accept`.
#[error(transparent)]
Connection(S),
/// A connection's query pool failed, e.g. while registering a query in
/// `Endpoint::recv`.
#[error(transparent)]
Query(Q),
/// No connection is registered under the given handle.
///
/// Returned by `recv`, `drain_query` and `drain_connection`.
#[error("connection not found: {0}")]
ConnectionNotFound(ConnectionHandle),
/// The query's connection exists, but the query itself is not registered on it.
///
/// Returned by `drain_query`.
#[error("query {qid} not found on connection {cid}", qid = _0.qid, cid = _0.cid)]
QueryNotFound(QueryHandle),
/// `recv` received a message whose opcode is not `Opcode::Query`.
#[error("invalid opcode: {0:?}")]
InvalidOpcode(Opcode),
/// `recv` received a message whose response code is not `ResponseCode::NoError`.
#[error("invalid response code: {0:?}")]
InvalidResponseCode(ResponseCode),
/// `recv` received a message with the truncated (TC) flag set; handling of such
/// queries is not implemented.
///
/// Note: the variant name is misspelled ("Trancated"). It is kept as-is because
/// renaming a public enum variant would break downstream code that matches on it.
#[error("support for DNS requests with high truncated bit not implemented")]
TrancatedQuery,
/// A protocol-level error.
///
/// Convertible from `ProtoError` via `From`, so `?` works on `ProtoError` results.
#[error(transparent)]
Proto(#[from] ProtoError),
}
/// Identifies one query registered on an [`Endpoint`] connection.
///
/// A handle bundles three values:
/// - the key of the owning connection;
/// - the key of the query inside that connection's query pool;
/// - the DNS message ID the query arrived with.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct QueryHandle {
    /// Key of the query inside its connection's query pool.
    qid: usize,
    /// Message ID of the query as received.
    mid: u16,
    /// Key of the connection that owns the query.
    cid: usize,
}

impl QueryHandle {
    /// Assembles a handle from a connection key, a query key and a message ID.
    #[inline]
    const fn new(cid: usize, qid: usize, mid: u16) -> Self {
        QueryHandle { qid, mid, cid }
    }

    /// The message ID carried by the original query.
    #[inline]
    pub const fn message_id(&self) -> u16 {
        self.mid
    }

    /// The query's key within its connection's query pool.
    #[inline]
    pub const fn query_id(&self) -> usize {
        self.qid
    }

    /// The key of the connection this query belongs to.
    #[inline]
    pub const fn connection_id(&self) -> usize {
        self.cid
    }
}
/// A validated incoming query, as returned by [`Endpoint::recv`].
///
/// Pairs the parsed message with the handle under which the endpoint registered it.
#[derive(Debug, Eq, PartialEq)]
pub struct Query<'container, 'innards> {
    msg: Message<'container, 'innards>,
    query_handle: QueryHandle,
}

impl<'container, 'innards> Query<'container, 'innards> {
    /// Wraps a message together with its registration handle.
    #[inline]
    const fn new(msg: Message<'container, 'innards>, query_handle: QueryHandle) -> Self {
        Query { query_handle, msg }
    }

    /// The questions carried by the underlying message.
    #[inline]
    pub fn questions(&self) -> &[Question<'innards>] {
        let msg = &self.msg;
        msg.questions()
    }

    /// The handle identifying this query on its endpoint.
    #[inline]
    pub const fn query_handle(&self) -> QueryHandle {
        self.query_handle
    }
}
/// Pairs a single question with the handle of the query it belongs to.
#[derive(Debug, Eq, PartialEq)]
pub struct Response<'innards> {
    query_handle: QueryHandle,
    question: Question<'innards>,
}

impl<'innards> Response<'innards> {
    /// Creates a response entry for `question`, owned by the query `query_handle`.
    #[inline]
    pub const fn new(query_handle: QueryHandle, question: Question<'innards>) -> Self {
        Response {
            question,
            query_handle,
        }
    }

    /// The handle of the query this response answers.
    #[inline]
    pub const fn query_handle(&self) -> QueryHandle {
        self.query_handle
    }

    /// A borrow of the question being answered.
    #[inline]
    pub const fn question(&self) -> &Question<'innards> {
        let question = &self.question;
        question
    }
}
/// Describes how to send the answer to a query, as produced by `Endpoint::response`.
///
/// This is a small plain-value type made only of `Copy` fields, so it is `Copy`.
/// Callers can therefore keep it around or pass it by value freely.
// NOTE(review): consider also deriving `Debug`/`PartialEq`/`Eq` if `Flags` implements
// them; that cannot be confirmed from this file.
#[derive(Clone, Copy)]
pub struct Outgoing {
    /// Header flags to use for the response message.
    flags: Flags,
    /// Whether the response should be sent unicast rather than multicast.
    unicast: bool,
    /// Message ID for the response header.
    id: u16,
}

impl Outgoing {
    /// Bundles the flags, delivery mode and message ID of a response.
    #[inline]
    const fn new(flags: Flags, unicast: bool, id: u16) -> Self {
        Self { flags, unicast, id }
    }

    /// Header flags for the response message.
    #[inline]
    pub const fn flags(&self) -> Flags {
        self.flags
    }

    /// `true` if the response should be sent unicast.
    #[inline]
    pub const fn is_unicast(&self) -> bool {
        self.unicast
    }

    /// Message ID to put in the response header.
    #[inline]
    pub const fn id(&self) -> u16 {
        self.id
    }
}
/// Result of [`Endpoint::drain_connection`]: the removed connection's handle and the
/// query pool it still held.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub struct Closed<Q> {
/// The connection's query pool at the moment it was removed. It may still contain
/// queries that were never drained individually.
pub remainings: Q,
/// Handle of the connection that was removed.
pub connection_handle: ConnectionHandle,
}
/// Tracks accepted connections and the queries received on each of them.
///
/// `S` is the pool of connections. Each connection is a `Q` pool holding the `u16`
/// message IDs of its pending queries. The `Pool` bounds are placed on the `impl`
/// blocks rather than on the struct.
pub struct Endpoint<S, Q> {
connections: S,
// `Q` appears only through `S`'s bound, so it is recorded here to keep the type
// parameter used.
_q: PhantomData<Q>,
}
impl<S, Q> Default for Endpoint<S, Q>
where
S: Pool<Q>,
Q: Pool<u16>,
{
fn default() -> Self {
Self::new()
}
}
impl<S, Q> Endpoint<S, Q>
where
S: Pool<Q>,
Q: Pool<u16>,
{
pub fn new() -> Self {
Self {
connections: S::new(),
_q: PhantomData,
}
}
pub fn with_capacity(capacity: usize) -> Result<Self, S::Error> {
Ok(Self {
connections: S::with_capacity(capacity)?,
_q: PhantomData,
})
}
pub fn close(&mut self) {
self.connections.iter().for_each(|(_idx, conn)| {
if !conn.is_empty() {
#[cfg(feature = "tracing")]
tracing::warn!(
"mdns endpoint: connection {} closed with {} remaining queries",
_idx,
conn.len()
);
}
});
}
pub fn accept(&mut self) -> Result<ConnectionHandle, Error<S::Error, Q::Error>> {
let key = self
.connections
.insert(Q::new())
.map_err(Error::Connection)?;
Ok(ConnectionHandle(key))
}
pub fn recv<'container, 'innards>(
&mut self,
ch: ConnectionHandle,
msg: Message<'container, 'innards>,
) -> Result<Query<'container, 'innards>, Error<S::Error, Q::Error>> {
let id = msg.id();
let flags = msg.flags();
let opcode = flags.opcode();
if opcode != Opcode::Query {
#[cfg(feature = "tracing")]
tracing::error!(opcode = ?opcode, "mdns endpoint: received query with non-zero OpCode");
return Err(Error::InvalidOpcode(opcode));
}
let resp_code = flags.response_code();
if resp_code != ResponseCode::NoError {
#[cfg(feature = "tracing")]
tracing::error!(rcode = ?resp_code, "mdns endpoint: received query with non-zero response code");
return Err(Error::InvalidResponseCode(resp_code));
}
if flags.truncated() {
#[cfg(feature = "tracing")]
tracing::error!(
"mdns endpoint: support for mDNS requests with high truncated bit not implemented"
);
return Err(Error::TrancatedQuery);
}
if let Some(conn) = self.connections.get_mut(ch.0) {
let qid = conn.insert(id).map_err(Error::Query)?;
return Ok(Query::new(msg, QueryHandle::new(ch.into(), qid, id)));
}
Err(Error::ConnectionNotFound(ch))
}
pub fn response(
&mut self,
qh: QueryHandle,
question: Question<'_>,
) -> Result<Outgoing, Error<S::Error, Q::Error>> {
let mut flags = Flags::new();
flags
.set_response_code(ResponseCode::NoError)
.set_authoritative(true);
let qc = question.class();
let unicast = (qc & (1 << 15)) != 0 || FORCE_UNICAST_RESPONSES;
let mut id = 0;
if unicast {
id = qh.message_id();
}
Ok(Outgoing::new(flags, unicast, id))
}
pub fn drain_query(&mut self, qh: QueryHandle) -> Result<(), Error<S::Error, Q::Error>> {
match self.connections.get_mut(qh.cid) {
Some(q) => match q.try_remove(qh.qid) {
Some(_) => Ok(()),
None => Err(Error::QueryNotFound(qh)),
},
None => Err(Error::ConnectionNotFound(ConnectionHandle(qh.cid))),
}
}
pub fn drain_connection(
&mut self,
ch: ConnectionHandle,
) -> Result<Closed<Q>, Error<S::Error, Q::Error>> {
match self.connections.try_remove(ch.into()) {
Some(queries) => {
#[cfg(feature = "tracing")]
if !queries.is_empty() {
tracing::warn!(
"mdns endpoint: connection {} closed with {} remaining queries",
ch,
queries.len()
);
}
Ok(Closed {
remainings: queries,
connection_handle: ch,
})
}
None => Err(Error::ConnectionNotFound(ch)),
}
}
}