// ntex_server/net/accept.rs

use std::time::{Duration, Instant};
use std::{cell::Cell, fmt, io, sync::Arc, sync::mpsc, thread};
use std::{collections::VecDeque, num::NonZeroUsize};

use ntex_polling::{Event, Events, Poller};
use ntex_rt::System;
use ntex_util::{future::Either, time::Millis, time::sleep};

use super::socket::{Connection, Listener, SocketAddr};
use super::{Server, ServerStatus, Token};

12const EXIT_TIMEOUT: Duration = Duration::from_millis(100);
13const ERR_TIMEOUT: Duration = Duration::from_millis(500);
14const ERR_SLEEP_TIMEOUT: Millis = Millis(525);
15
16#[derive(Debug)]
17pub enum AcceptorCommand {
18    Stop(oneshot::Sender<()>),
19    Terminate,
20    Pause,
21    Resume,
22    Timer,
23}
24
25#[derive(Debug)]
26struct ServerSocketInfo {
27    addr: SocketAddr,
28    token: Token,
29    sock: Listener,
30    registered: Cell<bool>,
31    timeout: Cell<Option<Instant>>,
32}
33
34#[derive(Debug, Clone)]
35pub struct AcceptNotify(Arc<Poller>, mpsc::Sender<AcceptorCommand>);
36
37impl AcceptNotify {
38    fn new(waker: Arc<Poller>, tx: mpsc::Sender<AcceptorCommand>) -> Self {
39        AcceptNotify(waker, tx)
40    }
41
42    pub fn send(&self, cmd: AcceptorCommand) {
43        let _ = self.1.send(cmd);
44        let _ = self.0.notify();
45    }
46}
47
48/// Streamin io accept loop
49pub struct AcceptLoop {
50    name: String,
51    testing: bool,
52    notify: AcceptNotify,
53    inner: Option<(mpsc::Receiver<AcceptorCommand>, Arc<Poller>)>,
54    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
55}
56
57impl Default for AcceptLoop {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl AcceptLoop {
64    /// Create accept loop
65    pub fn new() -> AcceptLoop {
66        // Create a poller instance
67        let poll = Arc::new(
68            Poller::new()
69                .map_err(|e| panic!("Cannot create Poller {e}"))
70                .unwrap(),
71        );
72
73        let (tx, rx) = mpsc::channel();
74        let notify = AcceptNotify::new(poll.clone(), tx);
75
76        AcceptLoop {
77            notify,
78            name: "ntex:accept".to_string(),
79            inner: Some((rx, poll)),
80            testing: false,
81            status_handler: None,
82        }
83    }
84
85    /// Set server name.
86    ///
87    /// Name is used for worker thread name
88    pub fn name<T: AsRef<str>>(&mut self, name: T) {
89        self.name = format!("{}:accept", name.as_ref());
90    }
91
92    /// Get notification api for the loop
93    pub fn notify(&self) -> AcceptNotify {
94        self.notify.clone()
95    }
96
97    pub fn set_status_handler<F>(&mut self, f: F)
98    where
99        F: FnMut(ServerStatus) + Send + 'static,
100    {
101        self.status_handler = Some(Box::new(f));
102    }
103
104    pub fn testing(&mut self) {
105        self.testing = true;
106    }
107
108    /// Start accept loop
109    pub fn start(mut self, socks: Vec<(Token, Listener)>, srv: Server) {
110        let (tx, rx_start) = oneshot::channel();
111        let (rx, poll) = self
112            .inner
113            .take()
114            .expect("AcceptLoop cannot be used multiple times");
115
116        Accept::start(
117            tx,
118            rx,
119            poll,
120            socks,
121            srv,
122            self.name.clone(),
123            self.notify.clone(),
124            self.testing,
125            self.status_handler.take(),
126        );
127
128        let _ = rx_start.recv();
129    }
130}
131
132impl fmt::Debug for AcceptLoop {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("AcceptLoop")
135            .field("name", &self.name)
136            .field("notify", &self.notify)
137            .field("inner", &self.inner)
138            .field("status_handler", &self.status_handler.is_some())
139            .finish()
140    }
141}
142
143struct Accept {
144    poller: Arc<Poller>,
145    rx: mpsc::Receiver<AcceptorCommand>,
146    tx: Option<oneshot::Sender<()>>,
147    sockets: Vec<ServerSocketInfo>,
148    srv: Server,
149    notify: AcceptNotify,
150    testing: bool,
151    backpressure: bool,
152    backlog: VecDeque<Connection>,
153    status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
154}
155
156impl Accept {
157    #[allow(clippy::too_many_arguments)]
158    fn start(
159        tx: oneshot::Sender<()>,
160        rx: mpsc::Receiver<AcceptorCommand>,
161        poller: Arc<Poller>,
162        socks: Vec<(Token, Listener)>,
163        srv: Server,
164        name: String,
165        notify: AcceptNotify,
166        testing: bool,
167        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
168    ) {
169        log::info!("Starting {name:?} accept loop");
170
171        // start accept thread
172        let sys = System::current();
173        let _ = thread::Builder::new().name(name).spawn(move || {
174            System::set_current(sys);
175            Accept::new(tx, rx, poller, socks, srv, notify, testing, status_handler).poll()
176        });
177    }
178
179    #[allow(clippy::too_many_arguments)]
180    fn new(
181        tx: oneshot::Sender<()>,
182        rx: mpsc::Receiver<AcceptorCommand>,
183        poller: Arc<Poller>,
184        socks: Vec<(Token, Listener)>,
185        srv: Server,
186        notify: AcceptNotify,
187        testing: bool,
188        status_handler: Option<Box<dyn FnMut(ServerStatus) + Send>>,
189    ) -> Accept {
190        let mut sockets = Vec::new();
191        for (hnd_token, lst) in socks.into_iter() {
192            sockets.push(ServerSocketInfo {
193                addr: lst.local_addr(),
194                sock: lst,
195                token: hnd_token,
196                registered: Cell::new(false),
197                timeout: Cell::new(None),
198            });
199        }
200
201        Accept {
202            poller,
203            rx,
204            sockets,
205            notify,
206            srv,
207            testing,
208            status_handler,
209            tx: Some(tx),
210            backpressure: true,
211            backlog: VecDeque::new(),
212        }
213    }
214
215    fn update_status(&mut self, st: ServerStatus) {
216        if let Some(ref mut hnd) = self.status_handler {
217            (*hnd)(st)
218        }
219    }
220
221    fn poll(mut self) {
222        // Create storage for events
223        let mut events = Events::with_capacity(NonZeroUsize::new(512).unwrap());
224
225        let mut timeout = Some(Duration::ZERO);
226        loop {
227            events.clear();
228
229            if let Err(e) = self.poller.wait(&mut events, timeout) {
230                if e.kind() != io::ErrorKind::Interrupted {
231                    panic!("Cannot wait for events in poller: {e}")
232                }
233            } else if timeout.is_some() {
234                timeout = None;
235                let _ = self.tx.take().unwrap().send(());
236            }
237
238            for idx in 0..self.sockets.len() {
239                if self.sockets[idx].registered.get() {
240                    let readd = self.accept(idx);
241                    if readd {
242                        self.add_source(idx);
243                    }
244                }
245            }
246
247            match self.process_cmd() {
248                Either::Left(_) => events.clear(),
249                Either::Right(rx) => {
250                    // cleanup
251                    for info in self.sockets.drain(..) {
252                        info.sock.remove_source()
253                    }
254                    log::info!("Accept loop has been stopped");
255
256                    if let Some(rx) = rx {
257                        if !self.testing {
258                            thread::sleep(EXIT_TIMEOUT);
259                        }
260                        let _ = rx.send(());
261                    }
262
263                    break;
264                }
265            }
266        }
267    }
268
269    fn add_source(&self, idx: usize) {
270        let info = &self.sockets[idx];
271
272        loop {
273            // try to register poller source
274            let result = if info.registered.get() {
275                self.poller.modify(&info.sock, Event::readable(idx))
276            } else {
277                unsafe { self.poller.add(&info.sock, Event::readable(idx)) }
278            };
279            if let Err(err) = result {
280                if err.kind() == io::ErrorKind::WouldBlock {
281                    continue;
282                }
283                log::error!("Cannot register socket listener: {err}");
284
285                // sleep after error
286                info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
287
288                let notify = self.notify.clone();
289                System::current().arbiter().spawn(Box::pin(async move {
290                    sleep(ERR_SLEEP_TIMEOUT).await;
291                    notify.send(AcceptorCommand::Timer);
292                }));
293            } else {
294                info.registered.set(true);
295            }
296
297            break;
298        }
299    }
300
301    fn remove_source(&self, key: usize) {
302        let info = &self.sockets[key];
303
304        let result = if info.registered.get() {
305            self.poller.modify(&info.sock, Event::none(key))
306        } else {
307            return;
308        };
309
310        // stop listening for incoming connections
311        if let Err(err) = result {
312            log::error!("Cannot stop socket listener for {} err: {}", info.addr, err);
313        }
314    }
315
316    fn process_timer(&mut self) {
317        let now = Instant::now();
318        for key in 0..self.sockets.len() {
319            let info = &mut self.sockets[key];
320            if let Some(inst) = info.timeout.get()
321                && now > inst
322                && !self.backpressure
323            {
324                log::info!("Resuming socket listener on {} after timeout", info.addr);
325                info.timeout.take();
326                self.add_source(key);
327            }
328        }
329    }
330
331    fn process_cmd(&mut self) -> Either<(), Option<oneshot::Sender<()>>> {
332        loop {
333            match self.rx.try_recv() {
334                Ok(cmd) => match cmd {
335                    AcceptorCommand::Stop(rx) => {
336                        if !self.backpressure {
337                            log::info!("Stopping accept loop");
338                            self.backpressure(true);
339                        }
340                        break Either::Right(Some(rx));
341                    }
342                    AcceptorCommand::Terminate => {
343                        log::info!("Stopping accept loop");
344                        self.backpressure(true);
345                        break Either::Right(None);
346                    }
347                    AcceptorCommand::Pause => {
348                        if !self.backpressure {
349                            log::info!("Pausing accept loop");
350                            self.backpressure(true);
351                        }
352                    }
353                    AcceptorCommand::Resume => {
354                        if self.backpressure {
355                            log::info!("Resuming accept loop");
356                            self.backpressure(false);
357                        }
358                    }
359                    AcceptorCommand::Timer => {
360                        self.process_timer();
361                    }
362                },
363                Err(err) => {
364                    break match err {
365                        mpsc::TryRecvError::Empty => Either::Left(()),
366                        mpsc::TryRecvError::Disconnected => {
367                            log::error!("Dropping accept loop");
368                            self.backpressure(true);
369                            Either::Right(None)
370                        }
371                    };
372                }
373            }
374        }
375    }
376
377    fn backpressure(&mut self, on: bool) {
378        self.update_status(if on {
379            ServerStatus::NotReady
380        } else {
381            ServerStatus::Ready
382        });
383
384        if self.backpressure && !on {
385            // handle backlog
386            while let Some(msg) = self.backlog.pop_front() {
387                if let Err(msg) = self.srv.process(msg) {
388                    log::trace!("Server is unavailable");
389                    self.backlog.push_front(msg);
390                    return;
391                }
392            }
393
394            // re-enable acceptors
395            self.backpressure = false;
396            for (key, info) in self.sockets.iter().enumerate() {
397                if info.timeout.get().is_none() {
398                    // socket with timeout will re-register itself after timeout
399                    log::info!(
400                        "Resuming socket listener on {} after back-pressure",
401                        info.addr
402                    );
403                    self.add_source(key);
404                }
405            }
406        } else if !self.backpressure && on {
407            self.backpressure = true;
408            for key in 0..self.sockets.len() {
409                // disable err timeout
410                let info = &mut self.sockets[key];
411                if info.timeout.take().is_none() {
412                    log::info!("Stopping socket listener on {}", info.addr);
413                    self.remove_source(key);
414                }
415            }
416        }
417    }
418
419    fn accept(&mut self, token: usize) -> bool {
420        loop {
421            if let Some(info) = self.sockets.get_mut(token) {
422                match info.sock.accept() {
423                    Ok(Some(io)) => {
424                        let msg = Connection {
425                            io,
426                            token: info.token,
427                        };
428                        if let Err(msg) = self.srv.process(msg) {
429                            log::trace!("Server is unavailable");
430                            self.backlog.push_back(msg);
431                            self.backpressure(true);
432                            return false;
433                        }
434                    }
435                    Ok(None) => return true,
436                    Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => return true,
437                    Err(ref e) if connection_error(e) => continue,
438                    Err(e) => {
439                        log::error!("Error accepting socket: {e}");
440
441                        // sleep after error
442                        info.timeout.set(Some(Instant::now() + ERR_TIMEOUT));
443
444                        let notify = self.notify.clone();
445                        System::current().arbiter().spawn(Box::pin(async move {
446                            sleep(ERR_SLEEP_TIMEOUT).await;
447                            notify.send(AcceptorCommand::Timer);
448                        }));
449                        return false;
450                    }
451                }
452            }
453        }
454    }
455}
456
/// Returns `true` for errors returned by `accept()` that concern only the
/// connection being accepted. After such an error, the next pending connection
/// can be accepted immediately.
///
/// Any other error puts the listener into a timed back-off before the next
/// `accept()`. This matters for resource-exhaustion errors such as `ENFILE`
/// and `EMFILE`, which would otherwise cause a tight loop.
fn connection_error(e: &io::Error) -> bool {
    matches!(
        e.kind(),
        io::ErrorKind::ConnectionRefused
            | io::ErrorKind::ConnectionAborted
            | io::ErrorKind::ConnectionReset
            | io::ErrorKind::InvalidInput
    )
}