Skip to main content

ntex_server/net/
accept.rs

1#![allow(clippy::missing_panics_doc)]
2use std::time::{Duration, Instant};
3use std::{cell::Cell, fmt, io, sync::Arc, sync::mpsc, thread};
4use std::{collections::VecDeque, num::NonZeroUsize};
5
6use ntex_polling::{Event, Events, Poller};
7use ntex_rt::System;
8use ntex_util::{future::Either, time::Millis, time::sleep};
9
10use super::socket::{Connection, Listener, SocketAddr};
11use super::{Server, ServerStatus, Token};
12
/// Pause taken by the accept thread before acknowledging a `Stop` command,
/// after the listeners have been removed (skipped in testing mode).
13const EXIT_TIMEOUT: Duration = Duration::from_millis(100);
/// How long a socket stays parked after a registration or accept error.
14const ERR_TIMEOUT: Duration = Duration::from_millis(500);
/// Delay before the `Timer` command is sent after such an error. It is
/// slightly longer than `ERR_TIMEOUT`, so the parked deadline has already
/// passed when the timer is processed.
15const ERR_SLEEP_TIMEOUT: Millis = Millis(525);
16
/// Commands understood by the accept loop.
17#[derive(Debug)]
18pub enum AcceptorCommand {
    /// Stop the loop; the sender is signalled after the listeners have been
    /// removed.
19    Stop(oneshot::Sender<()>),
    /// Stop the loop without sending a completion signal.
20    Terminate,
    /// Turn back-pressure on (stop listening for new connections).
21    Pause,
    /// Turn back-pressure off (drain the backlog and listen again).
22    Resume,
    /// Re-register sockets whose error timeout has elapsed.
23    Timer,
24}
25
/// Per-listener bookkeeping kept by the accept loop.
26#[derive(Debug)]
27struct ServerSocketInfo {
    /// Local address of the listener, used in log messages.
28    addr: SocketAddr,
    /// Token attached to every connection accepted from this listener.
29    token: Token,
    /// The listening socket itself.
30    sock: Listener,
    /// Set once the socket has been successfully added to the poller.
31    registered: Cell<bool>,
    /// Deadline set after an error; while present the socket is parked and
    /// is re-registered by the timer instead of by back-pressure release.
32    timeout: Cell<Option<Instant>>,
33}
34
35#[derive(Debug, Clone)]
36pub struct AcceptNotify(Arc<Poller>, mpsc::Sender<AcceptorCommand>);
37
38impl AcceptNotify {
39    fn new(waker: Arc<Poller>, tx: mpsc::Sender<AcceptorCommand>) -> Self {
40        AcceptNotify(waker, tx)
41    }
42
43    pub fn send(&self, cmd: AcceptorCommand) {
44        let _ = self.1.send(cmd);
45        let _ = self.0.notify();
46    }
47}
48
49/// Streamin io accept loop
50pub struct AcceptLoop {
51    name: String,
52    testing: bool,
53    notify: AcceptNotify,
54    inner: Option<(mpsc::Receiver<AcceptorCommand>, Arc<Poller>)>,
55    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
56}
57
58impl Default for AcceptLoop {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl AcceptLoop {
65    /// Create accept loop
66    pub fn new() -> AcceptLoop {
67        // Create a poller instance
68        let poll = Arc::new(
69            Poller::new()
70                .map_err(|e| panic!("Cannot create Poller {e}"))
71                .unwrap(),
72        );
73
74        let (tx, rx) = mpsc::channel();
75        let notify = AcceptNotify::new(poll.clone(), tx);
76
77        AcceptLoop {
78            notify,
79            name: "ntex:accept".to_string(),
80            inner: Some((rx, poll)),
81            testing: false,
82            status_handler: None,
83        }
84    }
85
86    /// Set server name.
87    ///
88    /// Name is used for worker thread name
89    pub fn name<T: AsRef<str>>(&mut self, name: T) {
90        self.name = format!("{}:accept", name.as_ref());
91    }
92
93    /// Get notification api for the loop
94    pub fn notify(&self) -> AcceptNotify {
95        self.notify.clone()
96    }
97
98    pub fn set_status_handler<F>(&mut self, f: F)
99    where
100        F: FnMut(ServerStatus) + Send + 'static,
101    {
102        self.status_handler = Some(Box::new(f));
103    }
104
105    pub fn testing(&mut self) {
106        self.testing = true;
107    }
108
109    /// Start accept loop
110    pub fn start(mut self, socks: Vec<(Token, Listener)>, srv: Server) {
111        let (tx, rx_start) = oneshot::channel();
112        let (rx, poll) = self
113            .inner
114            .take()
115            .expect("AcceptLoop cannot be used multiple times");
116
117        Accept::start(
118            tx,
119            rx,
120            poll,
121            socks,
122            srv,
123            self.name.clone(),
124            self.notify.clone(),
125            self.testing,
126            self.status_handler.take(),
127        );
128
129        let _ = rx_start.recv();
130    }
131}
132
133impl fmt::Debug for AcceptLoop {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_struct("AcceptLoop")
136            .field("name", &self.name)
137            .field("notify", &self.notify)
138            .field("inner", &self.inner)
139            .field("status_handler", &self.status_handler.is_some())
140            .finish()
141    }
142}
143
/// State owned by the accept thread.
144struct Accept {
    /// Poller the listeners are registered with.
145    poller: Arc<Poller>,
    /// Receiving half of the command channel.
146    rx: mpsc::Receiver<AcceptorCommand>,
    /// Start-up signal; taken and sent after the first successful poller wait.
147    tx: Option<oneshot::Sender<()>>,
    /// Listeners served by this loop, indexed by poller key.
148    sockets: Vec<ServerSocketInfo>,
    /// Server that accepted connections are handed to.
149    srv: Server,
    /// Notifier used to send `Timer` commands back to this loop.
150    notify: AcceptNotify,
    /// When set, the pause before acknowledging a stop command is skipped.
151    testing: bool,
    /// `true` while accepting is suspended; the loop starts in this state.
152    backpressure: bool,
    /// Connections the server refused; retried when back-pressure is released.
153    backlog: VecDeque<Connection>,
    /// Optional callback that receives server status changes.
154    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
155}
156
157impl Accept {
158    #[allow(clippy::too_many_arguments)]
159    fn start(
160        tx: oneshot::Sender<()>,
161        rx: mpsc::Receiver<AcceptorCommand>,
162        poller: Arc<Poller>,
163        socks: Vec<(Token, Listener)>,
164        srv: Server,
165        name: String,
166        notify: AcceptNotify,
167        testing: bool,
168        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
169    ) {
170        log::info!("Starting {name:?} accept loop");
171
172        // start accept thread
173        let sys = System::current();
174        let _ = thread::Builder::new().name(name).spawn(move || {
175            System::set_current(sys);
176            Accept::new(tx, rx, poller, socks, srv, notify, testing, status_handler).poll();
177        });
178    }
179
180    #[allow(clippy::too_many_arguments)]
181    fn new(
182        tx: oneshot::Sender<()>,
183        rx: mpsc::Receiver<AcceptorCommand>,
184        poller: Arc<Poller>,
185        socks: Vec<(Token, Listener)>,
186        srv: Server,
187        notify: AcceptNotify,
188        testing: bool,
189        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
190    ) -> Accept {
191        let mut sockets = Vec::new();
192        for (hnd_token, lst) in socks {
193            sockets.push(ServerSocketInfo {
194                addr: lst.local_addr(),
195                sock: lst,
196                token: hnd_token,
197                registered: Cell::new(false),
198                timeout: Cell::new(None),
199            });
200        }
201
202        Accept {
203            poller,
204            rx,
205            sockets,
206            notify,
207            srv,
208            testing,
209            status_handler,
210            tx: Some(tx),
211            backpressure: true,
212            backlog: VecDeque::new(),
213        }
214    }
215
216    fn update_status(&mut self, st: ServerStatus) {
217        if let Some(ref mut hnd) = self.status_handler {
218            (*hnd)(st);
219        }
220    }
221
222    fn poll(mut self) {
223        // Create storage for events
224        let mut events = Events::with_capacity(NonZeroUsize::new(512).unwrap());
225
226        let mut timeout = Some(Duration::ZERO);
227        loop {
228            events.clear();
229
230            if let Err(e) = self.poller.wait(&mut events, timeout) {
231                assert!(
232                    e.kind() == io::ErrorKind::Interrupted,
233                    "Cannot wait for events in poller: {e}"
234                );
235            } else if timeout.is_some() {
236                timeout = None;
237                let _ = self.tx.take().unwrap().send(());
238            }
239
240            for idx in 0..self.sockets.len() {
241                if self.sockets[idx].registered.get() {
242                    let readd = self.accept(idx);
243                    if readd {
244                        self.add_source(idx);
245                    }
246                }
247            }
248
249            match self.process_cmd() {
250                Either::Left(()) => events.clear(),
251                Either::Right(rx) => {
252                    // cleanup
253                    for info in self.sockets.drain(..) {
254                        info.sock.remove_source();
255                    }
256                    log::info!("Accept loop has been stopped");
257
258                    if let Some(rx) = rx {
259                        if !self.testing {
260                            thread::sleep(EXIT_TIMEOUT);
261                        }
262                        let _ = rx.send(());
263                    }
264
265                    break;
266                }
267            }
268        }
269    }
270
271    fn add_source(&self, idx: usize) {
272        let info = &self.sockets[idx];
273
274        loop {
275            // try to register poller source
276            let result = if info.registered.get() {
277                self.poller.modify(&info.sock, Event::readable(idx))
278            } else {
279                unsafe { self.poller.add(&info.sock, Event::readable(idx)) }
280            };
281            if let Err(err) = result {
282                if err.kind() == io::ErrorKind::WouldBlock {
283                    continue;
284                }
285                log::error!("Cannot register socket listener: {err}");
286
287                // sleep after error
288                info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
289
290                let notify = self.notify.clone();
291                System::current().arbiter().spawn(Box::pin(async move {
292                    sleep(ERR_SLEEP_TIMEOUT).await;
293                    notify.send(AcceptorCommand::Timer);
294                }));
295            } else {
296                info.registered.set(true);
297            }
298
299            break;
300        }
301    }
302
303    fn remove_source(&self, key: usize) {
304        let info = &self.sockets[key];
305
306        let result = if info.registered.get() {
307            self.poller.modify(&info.sock, Event::none(key))
308        } else {
309            return;
310        };
311
312        // stop listening for incoming connections
313        if let Err(err) = result {
314            log::error!("Cannot stop socket listener for {} err: {}", info.addr, err);
315        }
316    }
317
318    fn process_timer(&mut self) {
319        let now = Instant::now();
320        for key in 0..self.sockets.len() {
321            let info = &mut self.sockets[key];
322            if let Some(inst) = info.timeout.get()
323                && now > inst
324                && !self.backpressure
325            {
326                log::info!("Resuming socket listener on {} after timeout", info.addr);
327                info.timeout.take();
328                self.add_source(key);
329            }
330        }
331    }
332
333    fn process_cmd(&mut self) -> Either<(), Option<oneshot::Sender<()>>> {
334        loop {
335            match self.rx.try_recv() {
336                Ok(cmd) => match cmd {
337                    AcceptorCommand::Stop(rx) => {
338                        if !self.backpressure {
339                            log::info!("Stopping accept loop");
340                            self.backpressure(true);
341                        }
342                        break Either::Right(Some(rx));
343                    }
344                    AcceptorCommand::Terminate => {
345                        log::info!("Stopping accept loop");
346                        self.backpressure(true);
347                        break Either::Right(None);
348                    }
349                    AcceptorCommand::Pause => {
350                        if !self.backpressure {
351                            log::info!("Pausing accept loop");
352                            self.backpressure(true);
353                        }
354                    }
355                    AcceptorCommand::Resume => {
356                        if self.backpressure {
357                            log::info!("Resuming accept loop");
358                            self.backpressure(false);
359                        }
360                    }
361                    AcceptorCommand::Timer => {
362                        self.process_timer();
363                    }
364                },
365                Err(err) => {
366                    break match err {
367                        mpsc::TryRecvError::Empty => Either::Left(()),
368                        mpsc::TryRecvError::Disconnected => {
369                            log::error!("Dropping accept loop");
370                            self.backpressure(true);
371                            Either::Right(None)
372                        }
373                    };
374                }
375            }
376        }
377    }
378
379    fn backpressure(&mut self, on: bool) {
380        self.update_status(if on {
381            ServerStatus::NotReady
382        } else {
383            ServerStatus::Ready
384        });
385
386        if self.backpressure && !on {
387            // handle backlog
388            while let Some(msg) = self.backlog.pop_front() {
389                if let Err(msg) = self.srv.process(msg) {
390                    log::trace!("Server is unavailable");
391                    self.backlog.push_front(msg);
392                    return;
393                }
394            }
395
396            // re-enable acceptors
397            self.backpressure = false;
398            for (key, info) in self.sockets.iter().enumerate() {
399                if info.timeout.get().is_none() {
400                    // socket with timeout will re-register itself after timeout
401                    log::info!(
402                        "Resuming socket listener on {} after back-pressure",
403                        info.addr
404                    );
405                    self.add_source(key);
406                }
407            }
408        } else if !self.backpressure && on {
409            self.backpressure = true;
410            for key in 0..self.sockets.len() {
411                // disable err timeout
412                let info = &mut self.sockets[key];
413                if info.timeout.take().is_none() {
414                    log::info!("Stopping socket listener on {}", info.addr);
415                    self.remove_source(key);
416                }
417            }
418        }
419    }
420
421    fn accept(&mut self, token: usize) -> bool {
422        loop {
423            if let Some(info) = self.sockets.get_mut(token) {
424                match info.sock.accept() {
425                    Ok(Some(io)) => {
426                        let msg = Connection {
427                            io,
428                            token: info.token,
429                        };
430                        if let Err(msg) = self.srv.process(msg) {
431                            log::trace!("Server is unavailable");
432                            self.backlog.push_back(msg);
433                            self.backpressure(true);
434                            return false;
435                        }
436                    }
437                    Ok(None) => return true,
438                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return true,
439                    Err(ref e) if connection_error(e) => (),
440                    Err(e) => {
441                        log::error!("Error accepting socket: {e}");
442
443                        // sleep after error
444                        info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
445
446                        let notify = self.notify.clone();
447                        System::current().arbiter().spawn(Box::pin(async move {
448                            sleep(ERR_SLEEP_TIMEOUT).await;
449                            notify.send(AcceptorCommand::Timer);
450                        }));
451                        return false;
452                    }
453                }
454            }
455        }
456    }
457}
458
/// Tell whether an `accept()` error is per-connection, meaning the next
/// pending connection may still be accepted right away.
///
/// The per-connection kinds are `ConnectionRefused`, `ConnectionAborted`,
/// `ConnectionReset` and `InvalidInput`.
///
/// All other errors will incur a timeout before the next `accept()` is
/// performed. The timeout is useful to handle resource exhaustion errors like
/// ENFILE and EMFILE; without it the loop could spin in a tight loop.
fn connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::InvalidInput
    )
}