// ntex_server/net/accept.rs

1#![allow(clippy::missing_panics_doc)]
2use std::time::{Duration, Instant};
3use std::{cell::Cell, fmt, io, sync::Arc, sync::mpsc, thread};
4use std::{collections::VecDeque, num::NonZeroUsize};
5
6use ntex_polling::{Event, Events, Poller};
7use ntex_rt::System;
8use ntex_util::{future::Either, time::Millis, time::sleep};
9
10use super::socket::{Connection, Listener, SocketAddr};
11use super::{Server, ServerStatus, Token};
12
13const EXIT_TIMEOUT: Duration = Duration::from_millis(100);
14const ERR_TIMEOUT: Duration = Duration::from_millis(500);
15const ERR_SLEEP_TIMEOUT: Millis = Millis(525);
16
17#[derive(Debug)]
18pub enum AcceptorCommand {
19    Stop(oneshot::Sender<()>),
20    Terminate,
21    Pause,
22    Resume,
23    Timer,
24}
25
26#[derive(Debug)]
27struct ServerSocketInfo {
28    addr: SocketAddr,
29    token: Token,
30    sock: Listener,
31    registered: Cell<bool>,
32    timeout: Cell<Option<Instant>>,
33}
34
35#[derive(Debug, Clone)]
36pub struct AcceptNotify(Arc<Poller>, mpsc::Sender<AcceptorCommand>);
37
38impl AcceptNotify {
39    fn new(waker: Arc<Poller>, tx: mpsc::Sender<AcceptorCommand>) -> Self {
40        AcceptNotify(waker, tx)
41    }
42
43    pub fn send(&self, cmd: AcceptorCommand) {
44        let _ = self.1.send(cmd);
45        let _ = self.0.notify();
46    }
47}
48
49/// Streamin io accept loop
50pub struct AcceptLoop {
51    name: String,
52    testing: bool,
53    notify: AcceptNotify,
54    inner: Option<(mpsc::Receiver<AcceptorCommand>, Arc<Poller>)>,
55    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
56}
57
58impl Default for AcceptLoop {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl AcceptLoop {
65    /// Create accept loop
66    pub fn new() -> AcceptLoop {
67        // Create a poller instance
68        let poll = Arc::new(
69            Poller::new()
70                .map_err(|e| panic!("Cannot create Poller {e}"))
71                .unwrap(),
72        );
73
74        let (tx, rx) = mpsc::channel();
75        let notify = AcceptNotify::new(poll.clone(), tx);
76
77        AcceptLoop {
78            notify,
79            name: "ntex:accept".to_string(),
80            inner: Some((rx, poll)),
81            testing: false,
82            status_handler: None,
83        }
84    }
85
86    /// Set server name.
87    ///
88    /// Name is used for worker thread name
89    pub fn name<T: AsRef<str>>(&mut self, name: T) {
90        self.name = format!("{}:accept", name.as_ref());
91    }
92
93    /// Get notification api for the loop
94    pub fn notify(&self) -> AcceptNotify {
95        self.notify.clone()
96    }
97
98    pub fn set_status_handler<F>(&mut self, f: F)
99    where
100        F: FnMut(ServerStatus) + Send + 'static,
101    {
102        self.status_handler = Some(Box::new(f));
103    }
104
105    pub fn testing(&mut self) {
106        self.testing = true;
107    }
108
109    /// Start accept loop
110    pub fn start(mut self, socks: Vec<(Token, Listener)>, srv: Server) {
111        let (tx, rx_start) = oneshot::channel();
112        let (rx, poll) = self
113            .inner
114            .take()
115            .expect("AcceptLoop cannot be used multiple times");
116
117        Accept::start(
118            tx,
119            rx,
120            poll,
121            socks,
122            srv,
123            self.name.clone(),
124            self.notify.clone(),
125            self.testing,
126            self.status_handler.take(),
127        );
128
129        let _ = rx_start.recv();
130    }
131}
132
133impl fmt::Debug for AcceptLoop {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_struct("AcceptLoop")
136            .field("name", &self.name)
137            .field("notify", &self.notify)
138            .field("inner", &self.inner)
139            .field("status_handler", &self.status_handler.is_some())
140            .finish()
141    }
142}
143
144struct Accept {
145    name: String,
146    poller: Arc<Poller>,
147    rx: mpsc::Receiver<AcceptorCommand>,
148    tx: Option<oneshot::Sender<()>>,
149    sockets: Vec<ServerSocketInfo>,
150    srv: Server,
151    notify: AcceptNotify,
152    testing: bool,
153    backpressure: bool,
154    backlog: VecDeque<Connection>,
155    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
156}
157
158impl Accept {
159    #[allow(clippy::too_many_arguments)]
160    fn start(
161        tx: oneshot::Sender<()>,
162        rx: mpsc::Receiver<AcceptorCommand>,
163        poller: Arc<Poller>,
164        socks: Vec<(Token, Listener)>,
165        srv: Server,
166        name: String,
167        notify: AcceptNotify,
168        testing: bool,
169        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
170    ) {
171        log::info!("Starting {name:?} accept loop");
172
173        // start accept thread
174        let sys = System::current();
175        let _ = thread::Builder::new().name(name.clone()).spawn(move || {
176            System::set_current(sys);
177            Accept::new(
178                name,
179                tx,
180                rx,
181                poller,
182                socks,
183                srv,
184                notify,
185                testing,
186                status_handler,
187            )
188            .poll();
189        });
190    }
191
192    #[allow(clippy::too_many_arguments)]
193    fn new(
194        name: String,
195        tx: oneshot::Sender<()>,
196        rx: mpsc::Receiver<AcceptorCommand>,
197        poller: Arc<Poller>,
198        socks: Vec<(Token, Listener)>,
199        srv: Server,
200        notify: AcceptNotify,
201        testing: bool,
202        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
203    ) -> Accept {
204        let mut sockets = Vec::new();
205        for (hnd_token, lst) in socks {
206            sockets.push(ServerSocketInfo {
207                addr: lst.local_addr(),
208                sock: lst,
209                token: hnd_token,
210                registered: Cell::new(false),
211                timeout: Cell::new(None),
212            });
213        }
214
215        Accept {
216            name,
217            poller,
218            rx,
219            sockets,
220            notify,
221            srv,
222            testing,
223            status_handler,
224            tx: Some(tx),
225            backpressure: true,
226            backlog: VecDeque::new(),
227        }
228    }
229
230    fn update_status(&mut self, st: ServerStatus) {
231        if let Some(ref mut hnd) = self.status_handler {
232            (*hnd)(st);
233        }
234    }
235
236    fn poll(mut self) {
237        // Create storage for events
238        let mut events = Events::with_capacity(NonZeroUsize::new(512).unwrap());
239
240        // notify start
241        for idx in 0..self.sockets.len() {
242            self.add_source(idx);
243        }
244        if let Some(tx) = self.tx.take() {
245            thread::sleep(Duration::from_millis(25));
246            let _ = tx.send(());
247        }
248
249        loop {
250            for idx in 0..self.sockets.len() {
251                if self.sockets[idx].registered.get() {
252                    let readd = self.accept(idx);
253                    if readd {
254                        self.add_source(idx);
255                    }
256                }
257            }
258
259            if let Either::Right(rx) = self.process_cmd() {
260                // cleanup
261                for info in self.sockets.drain(..) {
262                    info.sock.remove_source();
263                }
264                log::info!("Accept loop {:?} has been stopped", self.name);
265
266                if let Some(rx) = rx {
267                    if !self.testing {
268                        thread::sleep(EXIT_TIMEOUT);
269                    }
270                    let _ = rx.send(());
271                }
272
273                break;
274            }
275
276            events.clear();
277            if let Err(e) = self.poller.wait(&mut events, None) {
278                assert!(
279                    e.kind() == io::ErrorKind::Interrupted,
280                    "Cannot wait for events in poller: {e}"
281                );
282            }
283        }
284    }
285
286    fn add_source(&self, idx: usize) {
287        let info = &self.sockets[idx];
288
289        loop {
290            // try to register poller source
291            let result = if info.registered.get() {
292                self.poller.modify(&info.sock, Event::readable(idx))
293            } else {
294                unsafe { self.poller.add(&info.sock, Event::readable(idx)) }
295            };
296            if let Err(err) = result {
297                if err.kind() == io::ErrorKind::WouldBlock {
298                    continue;
299                }
300                log::error!("Cannot register socket listener: {err}");
301
302                // sleep after error
303                info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
304
305                let notify = self.notify.clone();
306                System::current().handle().spawn(async move {
307                    sleep(ERR_SLEEP_TIMEOUT).await;
308                    notify.send(AcceptorCommand::Timer);
309                });
310            } else {
311                info.registered.set(true);
312            }
313
314            break;
315        }
316    }
317
318    fn remove_source(&self, key: usize) {
319        let info = &self.sockets[key];
320
321        let result = if info.registered.get() {
322            self.poller.modify(&info.sock, Event::none(key))
323        } else {
324            return;
325        };
326
327        // stop listening for incoming connections
328        if let Err(err) = result {
329            log::error!("Cannot stop socket listener for {} err: {}", info.addr, err);
330        }
331    }
332
333    fn process_timer(&mut self) {
334        let now = Instant::now();
335        for key in 0..self.sockets.len() {
336            let info = &mut self.sockets[key];
337            if let Some(inst) = info.timeout.get()
338                && now > inst
339                && !self.backpressure
340            {
341                log::info!("Resuming socket listener on {} after timeout", info.addr);
342                info.timeout.take();
343                self.add_source(key);
344            }
345        }
346    }
347
348    fn process_cmd(&mut self) -> Either<(), Option<oneshot::Sender<()>>> {
349        loop {
350            match self.rx.try_recv() {
351                Ok(cmd) => match cmd {
352                    AcceptorCommand::Stop(rx) => {
353                        if !self.backpressure {
354                            log::info!("Stopping accept loop {:?}", self.name);
355                            self.backpressure(true);
356                        }
357                        break Either::Right(Some(rx));
358                    }
359                    AcceptorCommand::Terminate => {
360                        log::info!("Stopping accept loop {:?}", self.name);
361                        self.backpressure(true);
362                        break Either::Right(None);
363                    }
364                    AcceptorCommand::Pause => {
365                        if !self.backpressure {
366                            log::info!("Pausing accept loop {:?}", self.name);
367                            self.backpressure(true);
368                        }
369                    }
370                    AcceptorCommand::Resume => {
371                        if self.backpressure {
372                            log::info!("Resuming accept loop {:?}", self.name);
373                            self.backpressure(false);
374                        }
375                    }
376                    AcceptorCommand::Timer => {
377                        self.process_timer();
378                    }
379                },
380                Err(err) => {
381                    break match err {
382                        mpsc::TryRecvError::Empty => Either::Left(()),
383                        mpsc::TryRecvError::Disconnected => {
384                            log::error!("Dropping accept loop {:?}", self.name);
385                            self.backpressure(true);
386                            Either::Right(None)
387                        }
388                    };
389                }
390            }
391        }
392    }
393
394    fn backpressure(&mut self, on: bool) {
395        self.update_status(if on {
396            ServerStatus::NotReady
397        } else {
398            ServerStatus::Ready
399        });
400
401        if self.backpressure && !on {
402            // handle backlog
403            while let Some(msg) = self.backlog.pop_front() {
404                if let Err(msg) = self.srv.process(msg) {
405                    log::trace!("Server is unavailable");
406                    self.backlog.push_front(msg);
407                    return;
408                }
409            }
410
411            // re-enable acceptors
412            self.backpressure = false;
413            for (key, info) in self.sockets.iter().enumerate() {
414                if info.timeout.get().is_none() {
415                    // socket with timeout will re-register itself after timeout
416                    log::info!(
417                        "Resuming socket listener on {} after back-pressure",
418                        info.addr
419                    );
420                    self.add_source(key);
421                }
422            }
423        } else if !self.backpressure && on {
424            self.backpressure = true;
425            for key in 0..self.sockets.len() {
426                // disable err timeout
427                let info = &mut self.sockets[key];
428                if info.timeout.take().is_none() {
429                    log::info!("Stopping socket listener on {}", info.addr);
430                    self.remove_source(key);
431                }
432            }
433        }
434    }
435
436    fn accept(&mut self, token: usize) -> bool {
437        loop {
438            if let Some(info) = self.sockets.get_mut(token) {
439                match info.sock.accept() {
440                    Ok(Some(io)) => {
441                        let msg = Connection {
442                            io,
443                            token: info.token,
444                        };
445                        if let Err(msg) = self.srv.process(msg) {
446                            log::trace!("Server is unavailable");
447                            self.backlog.push_back(msg);
448                            self.backpressure(true);
449                            return false;
450                        }
451                    }
452                    Ok(None) => return true,
453                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return true,
454                    Err(ref e) if connection_error(e) => (),
455                    Err(e) => {
456                        log::error!("Error accepting socket: {e}");
457
458                        // sleep after error
459                        info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
460
461                        let notify = self.notify.clone();
462                        System::current().handle().spawn(async move {
463                            sleep(ERR_SLEEP_TIMEOUT).await;
464                            notify.send(AcceptorCommand::Timer);
465                        });
466                        return false;
467                    }
468                }
469            }
470        }
471    }
472}
473
/// Returns `true` for `accept()` errors that concern only a single connection.
///
/// After such an error the next pending connection may still be accepted right
/// away, so the accept loop simply carries on.
///
/// Any other error puts the listener to sleep for a while before `accept()` is
/// retried. That pause matters for resource-exhaustion errors such as `ENFILE`
/// and `EMFILE`; without it the loop could spin on the failing call.
fn connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    PER_CONNECTION.contains(&e.kind())
}