Skip to main content

ntex_server/net/
accept.rs

1#![allow(clippy::missing_panics_doc)]
2use std::time::{Duration, Instant};
3use std::{cell::Cell, fmt, io, sync::Arc, sync::mpsc, thread};
4use std::{collections::VecDeque, num::NonZeroUsize};
5
6use ntex_polling::{Event, Events, Poller};
7use ntex_rt::System;
8use ntex_util::{future::Either, time::Millis, time::sleep};
9
10use super::socket::{Connection, Listener, SocketAddr};
11use super::{Server, ServerStatus, Token};
12
13const EXIT_TIMEOUT: Duration = Duration::from_millis(100);
14const ERR_TIMEOUT: Duration = Duration::from_millis(500);
15const ERR_SLEEP_TIMEOUT: Millis = Millis(525);
16
17#[derive(Debug)]
18pub enum AcceptorCommand {
19    Stop(oneshot::Sender<()>),
20    Terminate,
21    Pause,
22    Resume,
23    Timer,
24}
25
26#[derive(Debug)]
27struct ServerSocketInfo {
28    addr: SocketAddr,
29    token: Token,
30    sock: Listener,
31    registered: Cell<bool>,
32    timeout: Cell<Option<Instant>>,
33}
34
35#[derive(Debug, Clone)]
36pub struct AcceptNotify(Arc<Poller>, mpsc::Sender<AcceptorCommand>);
37
38impl AcceptNotify {
39    fn new(waker: Arc<Poller>, tx: mpsc::Sender<AcceptorCommand>) -> Self {
40        AcceptNotify(waker, tx)
41    }
42
43    pub fn send(&self, cmd: AcceptorCommand) {
44        let _ = self.1.send(cmd);
45        let _ = self.0.notify();
46    }
47}
48
49/// Streamin io accept loop
50pub struct AcceptLoop {
51    name: String,
52    testing: bool,
53    notify: AcceptNotify,
54    inner: Option<(mpsc::Receiver<AcceptorCommand>, Arc<Poller>)>,
55    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
56}
57
58impl Default for AcceptLoop {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl AcceptLoop {
65    /// Create accept loop
66    pub fn new() -> AcceptLoop {
67        // Create a poller instance
68        let poll = Arc::new(
69            Poller::new()
70                .map_err(|e| panic!("Cannot create Poller {e}"))
71                .unwrap(),
72        );
73
74        let (tx, rx) = mpsc::channel();
75        let notify = AcceptNotify::new(poll.clone(), tx);
76
77        AcceptLoop {
78            notify,
79            name: "ntex:accept".to_string(),
80            inner: Some((rx, poll)),
81            testing: false,
82            status_handler: None,
83        }
84    }
85
86    /// Set server name.
87    ///
88    /// Name is used for worker thread name
89    pub fn name<T: AsRef<str>>(&mut self, name: T) {
90        self.name = format!("{}:accept", name.as_ref());
91    }
92
93    /// Get notification api for the loop
94    pub fn notify(&self) -> AcceptNotify {
95        self.notify.clone()
96    }
97
98    pub fn set_status_handler<F>(&mut self, f: F)
99    where
100        F: FnMut(ServerStatus) + Send + 'static,
101    {
102        self.status_handler = Some(Box::new(f));
103    }
104
105    pub fn testing(&mut self) {
106        self.testing = true;
107    }
108
109    /// Start accept loop
110    pub fn start(mut self, socks: Vec<(Token, Listener)>, srv: Server) {
111        let (tx, rx_start) = oneshot::channel();
112        let (rx, poll) = self
113            .inner
114            .take()
115            .expect("AcceptLoop cannot be used multiple times");
116
117        Accept::start(
118            tx,
119            rx,
120            poll,
121            socks,
122            srv,
123            self.name.clone(),
124            self.notify.clone(),
125            self.testing,
126            self.status_handler.take(),
127        );
128
129        let _ = rx_start.recv();
130    }
131}
132
133impl fmt::Debug for AcceptLoop {
134    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
135        f.debug_struct("AcceptLoop")
136            .field("name", &self.name)
137            .field("notify", &self.notify)
138            .field("inner", &self.inner)
139            .field("status_handler", &self.status_handler.is_some())
140            .finish()
141    }
142}
143
144struct Accept {
145    name: String,
146    poller: Arc<Poller>,
147    rx: mpsc::Receiver<AcceptorCommand>,
148    tx: Option<oneshot::Sender<()>>,
149    sockets: Vec<ServerSocketInfo>,
150    srv: Server,
151    notify: AcceptNotify,
152    testing: bool,
153    backpressure: bool,
154    backlog: VecDeque<Connection>,
155    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
156}
157
158impl Accept {
159    #[allow(clippy::too_many_arguments)]
160    fn start(
161        tx: oneshot::Sender<()>,
162        rx: mpsc::Receiver<AcceptorCommand>,
163        poller: Arc<Poller>,
164        socks: Vec<(Token, Listener)>,
165        srv: Server,
166        name: String,
167        notify: AcceptNotify,
168        testing: bool,
169        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
170    ) {
171        log::info!("Starting {name:?} accept loop");
172
173        // start accept thread
174        let sys = System::current();
175        let _ = thread::Builder::new().name(name.clone()).spawn(move || {
176            System::set_current(sys);
177            Accept::new(
178                name,
179                tx,
180                rx,
181                poller,
182                socks,
183                srv,
184                notify,
185                testing,
186                status_handler,
187            )
188            .poll();
189        });
190    }
191
192    #[allow(clippy::too_many_arguments)]
193    fn new(
194        name: String,
195        tx: oneshot::Sender<()>,
196        rx: mpsc::Receiver<AcceptorCommand>,
197        poller: Arc<Poller>,
198        socks: Vec<(Token, Listener)>,
199        srv: Server,
200        notify: AcceptNotify,
201        testing: bool,
202        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
203    ) -> Accept {
204        let mut sockets = Vec::new();
205        for (hnd_token, lst) in socks {
206            sockets.push(ServerSocketInfo {
207                addr: lst.local_addr(),
208                sock: lst,
209                token: hnd_token,
210                registered: Cell::new(false),
211                timeout: Cell::new(None),
212            });
213        }
214
215        Accept {
216            name,
217            poller,
218            rx,
219            sockets,
220            notify,
221            srv,
222            testing,
223            status_handler,
224            tx: Some(tx),
225            backpressure: true,
226            backlog: VecDeque::new(),
227        }
228    }
229
230    fn update_status(&mut self, st: ServerStatus) {
231        if let Some(ref mut hnd) = self.status_handler {
232            (*hnd)(st);
233        }
234    }
235
236    fn poll(mut self) {
237        // Create storage for events
238        let mut events = Events::with_capacity(NonZeroUsize::new(512).unwrap());
239
240        // notify start
241        for idx in 0..self.sockets.len() {
242            self.add_source(idx);
243        }
244        if let Some(tx) = self.tx.take() {
245            thread::sleep(Duration::from_millis(25));
246            let _ = tx.send(());
247        }
248
249        loop {
250            events.clear();
251
252            if let Err(e) = self.poller.wait(&mut events, None) {
253                assert!(
254                    e.kind() == io::ErrorKind::Interrupted,
255                    "Cannot wait for events in poller: {e}"
256                );
257            }
258
259            for idx in 0..self.sockets.len() {
260                if self.sockets[idx].registered.get() {
261                    let readd = self.accept(idx);
262                    if readd {
263                        self.add_source(idx);
264                    }
265                }
266            }
267
268            match self.process_cmd() {
269                Either::Left(()) => events.clear(),
270                Either::Right(rx) => {
271                    // cleanup
272                    for info in self.sockets.drain(..) {
273                        info.sock.remove_source();
274                    }
275                    log::info!("Accept loop {:?} has been stopped", self.name);
276
277                    if let Some(rx) = rx {
278                        if !self.testing {
279                            thread::sleep(EXIT_TIMEOUT);
280                        }
281                        let _ = rx.send(());
282                    }
283
284                    break;
285                }
286            }
287        }
288    }
289
290    fn add_source(&self, idx: usize) {
291        let info = &self.sockets[idx];
292
293        loop {
294            // try to register poller source
295            let result = if info.registered.get() {
296                self.poller.modify(&info.sock, Event::readable(idx))
297            } else {
298                unsafe { self.poller.add(&info.sock, Event::readable(idx)) }
299            };
300            if let Err(err) = result {
301                if err.kind() == io::ErrorKind::WouldBlock {
302                    continue;
303                }
304                log::error!("Cannot register socket listener: {err}");
305
306                // sleep after error
307                info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
308
309                let notify = self.notify.clone();
310                System::current().handle().spawn(async move {
311                    sleep(ERR_SLEEP_TIMEOUT).await;
312                    notify.send(AcceptorCommand::Timer);
313                });
314            } else {
315                info.registered.set(true);
316            }
317
318            break;
319        }
320    }
321
322    fn remove_source(&self, key: usize) {
323        let info = &self.sockets[key];
324
325        let result = if info.registered.get() {
326            self.poller.modify(&info.sock, Event::none(key))
327        } else {
328            return;
329        };
330
331        // stop listening for incoming connections
332        if let Err(err) = result {
333            log::error!("Cannot stop socket listener for {} err: {}", info.addr, err);
334        }
335    }
336
337    fn process_timer(&mut self) {
338        let now = Instant::now();
339        for key in 0..self.sockets.len() {
340            let info = &mut self.sockets[key];
341            if let Some(inst) = info.timeout.get()
342                && now > inst
343                && !self.backpressure
344            {
345                log::info!("Resuming socket listener on {} after timeout", info.addr);
346                info.timeout.take();
347                self.add_source(key);
348            }
349        }
350    }
351
352    fn process_cmd(&mut self) -> Either<(), Option<oneshot::Sender<()>>> {
353        loop {
354            match self.rx.try_recv() {
355                Ok(cmd) => match cmd {
356                    AcceptorCommand::Stop(rx) => {
357                        if !self.backpressure {
358                            log::info!("Stopping accept loop {:?}", self.name);
359                            self.backpressure(true);
360                        }
361                        break Either::Right(Some(rx));
362                    }
363                    AcceptorCommand::Terminate => {
364                        log::info!("Stopping accept loop {:?}", self.name);
365                        self.backpressure(true);
366                        break Either::Right(None);
367                    }
368                    AcceptorCommand::Pause => {
369                        if !self.backpressure {
370                            log::info!("Pausing accept loop {:?}", self.name);
371                            self.backpressure(true);
372                        }
373                    }
374                    AcceptorCommand::Resume => {
375                        if self.backpressure {
376                            log::info!("Resuming accept loop {:?}", self.name);
377                            self.backpressure(false);
378                        }
379                    }
380                    AcceptorCommand::Timer => {
381                        self.process_timer();
382                    }
383                },
384                Err(err) => {
385                    break match err {
386                        mpsc::TryRecvError::Empty => Either::Left(()),
387                        mpsc::TryRecvError::Disconnected => {
388                            log::error!("Dropping accept loop {:?}", self.name);
389                            self.backpressure(true);
390                            Either::Right(None)
391                        }
392                    };
393                }
394            }
395        }
396    }
397
398    fn backpressure(&mut self, on: bool) {
399        self.update_status(if on {
400            ServerStatus::NotReady
401        } else {
402            ServerStatus::Ready
403        });
404
405        if self.backpressure && !on {
406            // handle backlog
407            while let Some(msg) = self.backlog.pop_front() {
408                if let Err(msg) = self.srv.process(msg) {
409                    log::trace!("Server is unavailable");
410                    self.backlog.push_front(msg);
411                    return;
412                }
413            }
414
415            // re-enable acceptors
416            self.backpressure = false;
417            for (key, info) in self.sockets.iter().enumerate() {
418                if info.timeout.get().is_none() {
419                    // socket with timeout will re-register itself after timeout
420                    log::info!(
421                        "Resuming socket listener on {} after back-pressure",
422                        info.addr
423                    );
424                    self.add_source(key);
425                }
426            }
427        } else if !self.backpressure && on {
428            self.backpressure = true;
429            for key in 0..self.sockets.len() {
430                // disable err timeout
431                let info = &mut self.sockets[key];
432                if info.timeout.take().is_none() {
433                    log::info!("Stopping socket listener on {}", info.addr);
434                    self.remove_source(key);
435                }
436            }
437        }
438    }
439
440    fn accept(&mut self, token: usize) -> bool {
441        loop {
442            if let Some(info) = self.sockets.get_mut(token) {
443                match info.sock.accept() {
444                    Ok(Some(io)) => {
445                        let msg = Connection {
446                            io,
447                            token: info.token,
448                        };
449                        if let Err(msg) = self.srv.process(msg) {
450                            log::trace!("Server is unavailable");
451                            self.backlog.push_back(msg);
452                            self.backpressure(true);
453                            return false;
454                        }
455                    }
456                    Ok(None) => return true,
457                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return true,
458                    Err(ref e) if connection_error(e) => (),
459                    Err(e) => {
460                        log::error!("Error accepting socket: {e}");
461
462                        // sleep after error
463                        info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
464
465                        let notify = self.notify.clone();
466                        System::current().handle().spawn(async move {
467                            sleep(ERR_SLEEP_TIMEOUT).await;
468                            notify.send(AcceptorCommand::Timer);
469                        });
470                        return false;
471                    }
472                }
473            }
474        }
475    }
476}
477
/// Returns `true` for `accept()` errors after which the next `accept()` can be
/// tried immediately.
///
/// Two kinds of error qualify:
/// - errors tied to a single connection (refused, aborted, reset, invalid
///   input), after which the next connection might be ready;
/// - `Interrupted` (EINTR), where the call should simply be retried.
///
/// All other errors incur a timeout before the next `accept()` is performed.
/// The timeout is useful to handle resource exhaustion errors like ENFILE
/// and EMFILE, which would otherwise cause a tight loop.
fn connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::InvalidInput
            | io::ErrorKind::Interrupted
    )
}