1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use std::io;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::{
    mpsc::{self, Receiver},
    Arc, Mutex,
};
use std::time::Duration;

use failure::Fail;
use log::*;

use crate::byte_stream::ByteStream;
use crate::model;
use crate::model::{Error, ErrorKind};
use crate::tcp_listener_ext::*;

/// Iterator source of incoming TCP connections on an already-listening socket.
///
/// Each accepted stream gets `rw_timeout` applied to both reads and writes.
/// Iteration stops when a message arrives on `rx` or polling `rx` fails.
pub struct TcpAcceptor {
    listener: TcpListener,
    /// read/write timeout applied to every accepted stream (`None` = no timeout)
    rw_timeout: Option<Duration>,
    /// receive termination message
    rx: Arc<Mutex<Receiver<()>>>,
    /// timeout for accept
    accept_timeout: Option<Duration>,
}

impl TcpAcceptor {
    /// Wraps a listening `listener`.
    ///
    /// * `rw_timeout` - applied to each accepted stream via `set_read_timeout` and
    ///   `set_write_timeout`.
    /// * `rx` - termination channel, polled before every accept attempt.
    /// * `accept_timeout` - forwarded to the listener's `accept_timeout` extension
    ///   method on every attempt.
    fn new(
        listener: TcpListener,
        rw_timeout: Option<Duration>,
        rx: Arc<Mutex<Receiver<()>>>,
        accept_timeout: Option<Duration>,
    ) -> Self {
        Self {
            listener,
            rw_timeout,
            rx,
            accept_timeout,
        }
    }

    /// Accepts one connection and applies `self.rw_timeout` to the new stream.
    ///
    /// `self.accept_timeout` is passed to the listener's `accept_timeout`
    /// extension (presumably a bounded wait — see `tcp_listener_ext`).
    ///
    /// # Errors
    /// Propagates errors from the accept itself and from setting the stream's
    /// read/write timeouts.
    fn accept_timeout(&self) -> io::Result<(TcpStream, SocketAddr)> {
        // `Option<Duration>` is `Copy`, so the fields are passed by value; no clone needed.
        let (stream, addr) = self.listener.accept_timeout(self.accept_timeout)?;
        stream.set_read_timeout(self.rw_timeout)?;
        stream.set_write_timeout(self.rw_timeout)?;
        Ok((stream, addr))
    }
}

/// Polls the termination channel without blocking.
///
/// Returns `Ok(true)` when a termination message was waiting and `Ok(false)`
/// when the channel is empty. Returns an error when every sender has been
/// dropped, or when the mutex guarding the receiver is poisoned.
fn check_message(rx: &Arc<Mutex<Receiver<()>>>) -> Result<bool, Error> {
    // The guard is a temporary, so the lock is released at the end of this statement.
    let polled = rx.lock()?.try_recv();
    match polled {
        Ok(()) => Ok(true),
        Err(mpsc::TryRecvError::Empty) => Ok(false),
        Err(mpsc::TryRecvError::Disconnected) => {
            Err(ErrorKind::disconnected("acceptor").into())
        }
    }
}

/// Makes the enclosing function return `None` when the acceptor should stop.
///
/// It stops when a termination message was received on `$rx`, or when polling
/// it via `check_message` failed (e.g. the sender was dropped or the mutex is
/// poisoned). `$rx` must be a `&Arc<Mutex<Receiver<()>>>`.
macro_rules! check_done {
    ($rx:expr) => {
        match check_message($rx) {
            Ok(true) => return None,
            Ok(false) => {}
            Err(err) => {
                // A failed poll is still treated as a stop request, but record why so
                // an unexpected end of the accept loop can be diagnosed.
                warn!("acceptor stopping: {}", err);
                return None;
            }
        }
    };
}

impl Iterator for TcpAcceptor {
    type Item = (TcpStream, SocketAddr);
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            check_done!(&self.rx);
            match self.accept_timeout() {
                Ok(x) => return Some(x),
                Err(err) if err.kind() == io::ErrorKind::TimedOut => {
                    // trace!("accept timeout: {}", err);
                }
                Err(err) => {
                    error!("accept error: {}", err);
                    trace!("accept error: {:?}", err);
                    return None;
                }
            }
        }
    }
}

/// Binds a local address and yields the connections accepted on it.
///
/// `bind` returns an iterator of `(stream, SocketAddr)` pairs. The address is
/// presumably the peer's, as in `TcpBinder`. The iterator must be
/// `Send + 'static` so it can be moved to another thread.
pub trait Binder {
    /// Connection type produced for each accepted client.
    type Stream: ByteStream + 'static;

    /// Iterator yielding each accepted connection paired with an address.
    type Iter: Iterator<Item = (Self::Stream, SocketAddr)> + Send + 'static;

    /// Binds `addr` and returns the connection iterator.
    fn bind(&self, addr: SocketAddr) -> Result<Self::Iter, Error>;
}

/// `Binder` that produces `TcpAcceptor`s over plain TCP.
///
/// Every acceptor created by `bind` holds a clone of the same `Arc`'d receiver.
/// A single `()` message is therefore consumed by whichever acceptor polls it first.
pub struct TcpBinder {
    /// read/write timeout applied to each accepted stream
    rw_timeout: Option<Duration>,
    /// receiver for Acceptor termination message
    rx: Arc<Mutex<Receiver<()>>>,
    /// timeout for each accept attempt
    accept_timeout: Option<Duration>,
}

impl TcpBinder {
    /// Creates a binder.
    ///
    /// The arguments are stored as-is and handed to every acceptor produced by `bind`.
    pub fn new(
        rw_timeout: Option<Duration>,
        rx: Arc<Mutex<Receiver<()>>>,
        accept_timeout: Option<Duration>,
    ) -> Self {
        TcpBinder {
            accept_timeout,
            rw_timeout,
            rx,
        }
    }
}

impl Binder for TcpBinder {
    type Stream = TcpStream;
    type Iter = TcpAcceptor;

    /// Binds a listening TCP socket on `addr` with address reuse enabled.
    ///
    /// The socket is wrapped in a `TcpAcceptor` that shares this binder's
    /// timeouts and termination receiver.
    ///
    /// # Errors
    /// * A bind failure with `AddrInUse` / `AddrNotAvailable` is mapped to the
    ///   dedicated error kinds by `addr_error`. Any other bind failure is wrapped
    ///   under `ErrorKind::Io`.
    /// * Socket creation, option and listen failures are converted with
    ///   `From<io::Error>`.
    fn bind(&self, addr: SocketAddr) -> Result<Self::Iter, Error> {
        // `backlog` is passed directly to the `listen(2)` system call. If it is too
        // small, clients may fail to `connect(2)` under load. On Linux the kernel
        // silently caps it at `net.core.somaxconn`, so this is an upper bound.
        const BACKLOG: i32 = 256;

        // The socket family must match the address family: an IPv4 socket cannot
        // bind an IPv6 address, so pick the builder from `addr`.
        let builder = match addr {
            SocketAddr::V4(_) => net2::TcpBuilder::new_v4()?,
            SocketAddr::V6(_) => net2::TcpBuilder::new_v6()?,
        };
        builder
            .reuse_address(true)?
            .bind(&addr)
            .map_err(|err| addr_error(err, addr))?;

        let listener = builder.listen(BACKLOG)?;

        Ok(TcpAcceptor::new(
            listener,
            self.rw_timeout,
            self.rx.clone(),
            self.accept_timeout,
        ))
    }
}

/// Converts a failed `bind` on `addr` into a model error.
///
/// `AddrInUse` and `AddrNotAvailable` become dedicated error kinds carrying the
/// address. Any other failure keeps the original `io::Error` wrapped with
/// `ErrorKind::Io` context.
fn addr_error(io_err: io::Error, addr: SocketAddr) -> model::Error {
    let specific = match io_err.kind() {
        io::ErrorKind::AddrInUse => Some(ErrorKind::AddressAlreadInUse { addr }),
        io::ErrorKind::AddrNotAvailable => Some(ErrorKind::AddressNotAvailable { addr }),
        _ => None,
    };
    // Both branches produce the same context type before the final conversion.
    let context = match specific {
        Some(kind) => kind.into(),
        None => io_err.context(ErrorKind::Io),
    };
    context.into()
}