1use {WireDecoder, WireEncoder, WireMessage, std, wire};
10use std::net::UdpSocket;
11use std::sync::Arc;
12
/// Errors returned by `Server::new` and `Server::serve`.
#[derive(Debug)]
pub enum ServerError {
    /// An I/O operation on the server's UDP socket failed.
    Io {
        /// The underlying I/O error reported by the operating system.
        inner: std::io::Error,
        /// Human-readable description of the operation that failed.
        what: String,
    },
}

impl std::error::Error for ServerError {
    // Kept for callers that still use the (deprecated) `description` API.
    fn description(&self) -> &str {
        match *self {
            ServerError::Io { ref what, .. } => what,
        }
    }

    /// Exposes the wrapped `std::io::Error` so callers can walk the error
    /// chain or downcast it to inspect its `ErrorKind`.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match *self {
            ServerError::Io { ref inner, .. } => Some(inner),
        }
    }
}

impl std::fmt::Display for ServerError {
    /// Formats as `"<what>: <inner>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Match directly instead of going through the deprecated
        // `description()` via a bare trait object (rejected by edition 2021).
        match *self {
            ServerError::Io { ref inner, ref what } => write!(f, "{}: {}", what, inner),
        }
    }
}
38
39pub trait Handler {
45 type Error: std::error::Error;
46 fn handle_query<'a>(&self,
47 query: &WireMessage,
48 encoder: WireEncoder<'a, wire::marker::Response, wire::marker::AnswerSection>)
49 -> WireEncoder<'a, wire::marker::Response, wire::marker::Done>;
50}
51
52#[derive(Debug)]
54pub struct Server<'a, H: 'a + Handler> {
55 socket: Arc<UdpSocket>,
56 handler: &'a H,
57}
58
59impl<'a, H: Handler> Server<'a, H> {
60 pub fn new(h: &'a H) -> Result<Self, ServerError> {
61 let addr = "0.0.0.0:53";
63 let socket = UdpSocket::bind(addr).map_err(|e| {
64 ServerError::Io {
65 inner: e,
66 what: format!("Failed to bind UDP socket to {}", addr),
67 }
68 })?;
69 Ok(Server {
70 socket: Arc::new(socket),
71 handler: h,
72 })
73 }
74
75 pub fn serve(self) -> Result<(), ServerError> {
76 const MAX_UDP_MESSAGE_LEN: usize = 512;
77 let mut ibuffer: [u8; MAX_UDP_MESSAGE_LEN] = [0; MAX_UDP_MESSAGE_LEN];
78 loop {
80 let (recv_len, peer_addr) = self.socket
81 .recv_from(&mut ibuffer)
82 .map_err(|e| {
83 ServerError::Io {
84 inner: e,
85 what: String::from("Failed to receive from UDP socket"),
86 }
87 })?;
88 let ipayload = &ibuffer[..recv_len];
89
90 let mut decoder = WireDecoder::new(ipayload);
91 let request = match decoder.decode_message() {
92 Ok(x) => x,
93 Err(e) => {
94 println!("Received invalid message: {}", e);
95 continue;
96 }
97 };
98
99 let mut obuffer: [u8; MAX_UDP_MESSAGE_LEN] = [0; MAX_UDP_MESSAGE_LEN];
102 let encoder = match WireEncoder::new_response(&mut obuffer[..], &request) {
103 Ok(x) => x,
104 Err(_) => continue, };
106
107 let encoder = self.handler.handle_query(&request, encoder);
108 let opayload = encoder.as_bytes();
109
110 match self.socket.send_to(opayload, peer_addr) {
111 Ok(send_len) => {
112 if send_len != opayload.len() {
113 println!("Sent unexpected number of bytes on UDP socket: Expected to send {}, actually sent \
114 {}",
115 opayload.len(),
116 send_len);
117 }
118 }
119 Err(e) => {
120 println!("Failed to send on UDP socket: {}", e);
121 }
122 }
123 }
124 }
125}