1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
//! Defines components for running a DNS server.
//!
//! The `server` module currently provides a barebones server implementation
//! with the following limitations:
//!
//! * Support for UDP only.
//! * No support for AXFR.
//! * No support for recursion.

use {WireDecoder, WireEncoder, WireMessage, std, wire};
use std::net::UdpSocket;
use std::sync::Arc;

// TODO: Replace println statements with rigorous logging.

/// Specifies an error raised by `Server`, either while setting up its socket
/// or while receiving a request and sending its response.
#[derive(Debug)]
pub enum ServerError {
    /// An I/O operation on the server's UDP socket failed.
    Io {
        /// The underlying I/O error reported by the operating system.
        inner: std::io::Error,
        /// Human-readable description of the operation that failed.
        what: String,
    },
}

impl std::error::Error for ServerError {
    // Kept so existing callers of the (deprecated) `Error::description` still
    // get the operation context rather than the trait's default placeholder.
    fn description(&self) -> &str {
        match *self {
            ServerError::Io { ref what, .. } => what,
        }
    }
}

impl std::fmt::Display for ServerError {
    /// Formats the error as `"<what>: <inner>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Read the context straight from the variant instead of routing through
        // the deprecated `Error::description` via a bare trait-object cast
        // (`&std::error::Error`), which no longer compiles on the 2021 edition.
        match *self {
            ServerError::Io { ref inner, ref what } => write!(f, "{}: {}", what, inner),
        }
    }
}

/// Handles server events—e.g., by doing a DNS lookup to respond to a DNS
/// request.
///
/// Applications implement the `Handler` trait and hand a reference to it to
/// `Server::new`. The server calls `handle_query` once for each incoming
/// datagram that decodes as a valid message and for which a response encoder
/// could be created; datagrams that fail either step never reach the handler.
pub trait Handler {
    /// Error type associated with this handler.
    ///
    /// NOTE(review): `Server` in this module never uses this type, and
    /// `handle_query` is infallible — confirm whether other code relies on it.
    type Error: std::error::Error;

    /// Builds the response to `query`.
    ///
    /// `encoder` was created by `WireEncoder::new_response` from `query` and is
    /// typed as being at the answer section (`wire::marker::AnswerSection`).
    /// The implementation writes whatever records it wants and must return the
    /// encoder in the `Done` state; the server then sends the encoder's bytes
    /// (`as_bytes`) back to the peer that sent `query`.
    fn handle_query<'a>(&self,
                        query: &WireMessage,
                        encoder: WireEncoder<'a, wire::marker::Response, wire::marker::AnswerSection>)
                        -> WireEncoder<'a, wire::marker::Response, wire::marker::Done>;
}

/// Manages all networking of a DNS server.
///
/// A `Server` owns a bound UDP socket and borrows the application's `Handler`,
/// which it consults to build a response for every incoming query. Because of
/// `#[derive(Debug)]`, `Server` implements `Debug` only when `H` does.
#[derive(Debug)]
pub struct Server<'a, H: 'a + Handler> {
    /// The UDP socket on which requests are received and responses are sent.
    socket: Arc<UdpSocket>,
    /// Application callback used to build responses; borrowed for the
    /// server's whole lifetime.
    handler: &'a H,
}

impl<'a, H: Handler> Server<'a, H> {
    /// Creates a server that answers queries with `h`, listening for UDP on all
    /// IPv4 interfaces at port 53.
    ///
    /// Equivalent to `Server::bind(h, "0.0.0.0:53")`.
    ///
    /// # Errors
    ///
    /// Returns `ServerError::Io` if the UDP socket cannot be bound (e.g. the
    /// port is already in use or the process lacks permission to bind it).
    pub fn new(h: &'a H) -> Result<Self, ServerError> {
        Self::bind(h, "0.0.0.0:53")
    }

    /// Creates a server that answers queries with `h`, listening for UDP on
    /// `addr` (for example `"127.0.0.1:5353"`).
    ///
    /// # Errors
    ///
    /// Returns `ServerError::Io` if the UDP socket cannot be bound to `addr`.
    pub fn bind<A>(h: &'a H, addr: A) -> Result<Self, ServerError>
        where A: std::net::ToSocketAddrs + std::fmt::Display
    {
        let socket = UdpSocket::bind(&addr).map_err(|e| {
                ServerError::Io {
                    inner: e,
                    what: format!("Failed to bind UDP socket to {}", addr),
                }
            })?;
        Ok(Server {
            socket: Arc::new(socket),
            handler: h,
        })
    }

    /// Runs the receive/respond loop on the calling thread, consuming the
    /// server.
    ///
    /// Each received datagram is decoded as a DNS message; datagrams that fail
    /// to decode are logged and dropped, as are messages for which no response
    /// encoder can be created. For every remaining query the handler builds a
    /// response, which is sent back to the sender. Send failures are logged and
    /// do not stop the loop.
    ///
    /// # Errors
    ///
    /// Only returns when receiving from the socket fails with an error other
    /// than an interruption or a UDP connection reset; that error is returned
    /// as `ServerError::Io`.
    pub fn serve(self) -> Result<(), ServerError> {
        // Classic (pre-EDNS) DNS-over-UDP message size limit.
        // NOTE(review): larger datagrams are truncated by the OS on Unix-like
        // systems; on Windows they are reported as an error, which ends `serve`.
        const MAX_UDP_MESSAGE_LEN: usize = 512;
        let mut ibuffer: [u8; MAX_UDP_MESSAGE_LEN] = [0; MAX_UDP_MESSAGE_LEN];
        loop {
            let (recv_len, peer_addr) = match self.socket.recv_from(&mut ibuffer) {
                Ok(received) => received,
                // A signal interrupted the call before any data arrived; retry.
                Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                // On Windows, an ICMP "port unreachable" provoked by an earlier
                // send_to (e.g. a client that went away before its reply) shows
                // up as ConnectionReset on the next recv_from. It does not mean
                // this socket is broken, so keep serving other clients.
                Err(ref e) if e.kind() == std::io::ErrorKind::ConnectionReset => {
                    println!("Ignoring connection reset on UDP socket: {}", e);
                    continue;
                }
                Err(e) => {
                    return Err(ServerError::Io {
                        inner: e,
                        what: String::from("Failed to receive from UDP socket"),
                    });
                }
            };
            let ipayload = &ibuffer[..recv_len];

            let mut decoder = WireDecoder::new(ipayload);
            let request = match decoder.decode_message() {
                Ok(x) => x,
                Err(e) => {
                    println!("Received invalid message: {}", e);
                    continue;
                }
            };

            // A fresh, zeroed output buffer for every response.
            let mut obuffer: [u8; MAX_UDP_MESSAGE_LEN] = [0; MAX_UDP_MESSAGE_LEN];
            let encoder = match WireEncoder::new_response(&mut obuffer[..], &request) {
                Ok(x) => x,
                Err(_) => continue, // TODO: Should send a SERVFAIL or FORMERR here, probably.
            };

            let encoder = self.handler.handle_query(&request, encoder);
            let opayload = encoder.as_bytes();

            match self.socket.send_to(opayload, peer_addr) {
                Ok(send_len) => {
                    if send_len != opayload.len() {
                        println!("Sent unexpected number of bytes on UDP socket: Expected to send {}, actually sent \
                                  {}",
                                 opayload.len(),
                                 send_len);
                    }
                }
                Err(e) => {
                    println!("Failed to send on UDP socket: {}", e);
                }
            }
        }
    }
}