Skip to main content

ntex_server/net/
builder.rs

1#![allow(clippy::missing_panics_doc)]
2use std::{fmt, io, net, sync::Arc};
3
4use ntex_io::Io;
5use ntex_rt::System;
6use ntex_service::{ServiceFactory, cfg::SharedCfg};
7use ntex_util::time::Millis;
8use socket2::{Domain, SockAddr, Socket, Type};
9
10use crate::{Server, WorkerPool};
11
12use super::accept::AcceptLoop;
13use super::config::{Config, ServiceConfig};
14use super::factory::{self, FactoryServiceType};
15use super::factory::{OnAccept, OnAcceptWrapper, OnWorkerStart, OnWorkerStartWrapper};
16use super::{Connection, ServerStatus, Stream, StreamServer, Token, socket::Listener};
17
18/// Streaming service builder
19///
20/// This type can be used to construct an instance of `net streaming server` through a
21/// builder-like pattern.
22pub struct ServerBuilder {
23    name: String,
24    token: Token,
25    backlog: i32,
26    services: Vec<FactoryServiceType>,
27    sockets: Vec<(Token, String, Listener)>,
28    on_worker_start: Vec<Box<dyn OnWorkerStart + Send>>,
29    on_accept: Option<Box<dyn OnAccept + Send>>,
30    accept: AcceptLoop,
31    pool: WorkerPool,
32}
33
34impl Default for ServerBuilder {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40impl fmt::Debug for ServerBuilder {
41    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42        f.debug_struct("ServerBuilder")
43            .field("name", &self.name)
44            .field("token", &self.token)
45            .field("backlog", &self.backlog)
46            .field("sockets", &self.sockets)
47            .field("accept", &self.accept)
48            .field("worker-pool", &self.pool)
49            .finish()
50    }
51}
52
53impl ServerBuilder {
54    #[must_use]
55    /// Create new Server builder instance
56    pub fn new() -> ServerBuilder {
57        let sys = System::current();
58        let mut accept = AcceptLoop::default();
59        accept.name(sys.name());
60        if sys.testing() {
61            accept.testing();
62        }
63
64        ServerBuilder {
65            accept,
66            name: sys.name().to_string(),
67            token: Token(0),
68            services: Vec::new(),
69            sockets: Vec::new(),
70            on_accept: None,
71            on_worker_start: Vec::new(),
72            backlog: 2048,
73            pool: WorkerPool::default().name(sys.name()),
74        }
75    }
76
77    #[must_use]
78    /// Set server name.
79    ///
80    /// Name is used for worker thread name
81    pub fn name<T: AsRef<str>>(mut self, name: T) -> Self {
82        self.name = name.as_ref().to_string();
83        self.accept.name(self.name.as_str());
84        self.pool = self.pool.name(self.name.as_str());
85        self
86    }
87
88    #[must_use]
89    /// Set number of workers to start.
90    ///
91    /// By default server uses number of available logical cpu as workers
92    /// count.
93    pub fn workers(mut self, num: usize) -> Self {
94        self.pool = self.pool.workers(num);
95        self
96    }
97
98    #[must_use]
99    /// Set the maximum number of pending connections.
100    ///
101    /// This refers to the number of clients that can be waiting to be served.
102    /// Exceeding this number results in the client getting an error when
103    /// attempting to connect. It should only affect servers under significant
104    /// load.
105    ///
106    /// Generally set in the 64-2048 range. Default value is 2048.
107    ///
108    /// This method should be called before `bind()` method call.
109    pub fn backlog(mut self, num: i32) -> Self {
110        self.backlog = num;
111        self
112    }
113
114    #[must_use]
115    /// Sets the maximum per-worker number of concurrent connections.
116    ///
117    /// All socket listeners will stop accepting connections when this limit is
118    /// reached for each worker.
119    ///
120    /// By default max connections is set to a 25k per worker.
121    pub fn maxconn(self, num: usize) -> Self {
122        super::max_concurrent_connections(num);
123        self
124    }
125
126    #[must_use]
127    /// Stop ntex runtime when server get dropped.
128    ///
129    /// By default "stop runtime" is disabled.
130    pub fn stop_runtime(mut self) -> Self {
131        self.pool = self.pool.stop_runtime();
132        self
133    }
134
135    #[must_use]
136    /// Disable signal handling.
137    ///
138    /// By default signal handling is enabled.
139    pub fn disable_signals(mut self) -> Self {
140        self.pool = self.pool.disable_signals();
141        self
142    }
143
144    #[must_use]
145    /// Enable cpu affinity
146    ///
147    /// By default affinity is disabled.
148    pub fn enable_affinity(mut self) -> Self {
149        self.pool = self.pool.enable_affinity();
150        self
151    }
152
153    #[must_use]
154    /// Timeout for graceful workers shutdown.
155    ///
156    /// After receiving a stop signal, workers have this much time to finish
157    /// serving requests. Workers still alive after the timeout are force
158    /// dropped.
159    ///
160    /// By default shutdown timeout sets to 30 seconds.
161    pub fn shutdown_timeout<T: Into<Millis>>(mut self, timeout: T) -> Self {
162        self.pool = self.pool.shutdown_timeout(timeout);
163        self
164    }
165
166    #[must_use]
167    /// Set server status handler.
168    ///
169    /// Server calls this handler on every inner status update.
170    pub fn status_handler<F>(mut self, handler: F) -> Self
171    where
172        F: FnMut(ServerStatus) + Send + 'static,
173    {
174        self.accept.set_status_handler(handler);
175        self
176    }
177
178    /// Execute external async configuration as part of the server building
179    /// process.
180    ///
181    /// This function is useful for moving parts of configuration to a
182    /// different module or even library.
183    pub async fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder>
184    where
185        F: AsyncFn(ServiceConfig) -> io::Result<()>,
186    {
187        let cfg = ServiceConfig::new(self.token, self.backlog);
188
189        f(cfg.clone()).await?;
190
191        let (token, sockets, factory) = cfg.into_factory();
192        self.token = token;
193        self.sockets.extend(sockets);
194        self.services.push(factory);
195
196        Ok(self)
197    }
198
199    #[must_use]
200    /// Register async service configuration function.
201    ///
202    /// This function get called during worker runtime configuration stage.
203    /// It get executed in the worker thread.
204    pub fn on_worker_start<F, E>(mut self, f: F) -> Self
205    where
206        F: AsyncFn() -> Result<(), E> + Send + Clone + 'static,
207        E: fmt::Display + 'static,
208    {
209        self.on_worker_start.push(OnWorkerStartWrapper::create(f));
210        self
211    }
212
213    #[must_use]
214    /// Register on-accept callback function.
215    ///
216    /// This function get called with accepted stream.
217    pub fn on_accept<F, E>(mut self, f: F) -> Self
218    where
219        F: AsyncFn(Arc<str>, Stream) -> Result<Stream, E> + Send + Clone + 'static,
220        E: fmt::Display + 'static,
221    {
222        self.on_accept = Some(OnAcceptWrapper::create(f));
223        self
224    }
225
226    /// Add new service to the server.
227    pub fn bind<F, U, N, R>(mut self, name: N, addr: U, factory: F) -> io::Result<Self>
228    where
229        U: net::ToSocketAddrs,
230        N: AsRef<str>,
231        F: AsyncFn(Config) -> R + Send + Clone + 'static,
232        R: ServiceFactory<Io, SharedCfg> + 'static,
233    {
234        let sockets = bind_addr(addr, self.backlog)?;
235
236        let mut tokens = Vec::new();
237        for lst in sockets {
238            let token = self.token.next();
239            self.sockets
240                .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
241            tokens.push((token, SharedCfg::default()));
242        }
243
244        self.services.push(factory::create_factory_service(
245            name.as_ref().to_string(),
246            tokens,
247            factory,
248        ));
249
250        Ok(self)
251    }
252
253    #[cfg(unix)]
254    /// Add new unix domain service to the server.
255    pub fn bind_uds<F, U, N, R>(self, name: N, addr: U, factory: F) -> io::Result<Self>
256    where
257        N: AsRef<str>,
258        U: AsRef<std::path::Path>,
259        F: AsyncFn(Config) -> R + Send + Clone + 'static,
260        R: ServiceFactory<Io, SharedCfg> + 'static,
261    {
262        use std::os::unix::net::UnixListener;
263
264        // The path must not exist when we try to bind.
265        // Try to remove it to avoid bind error.
266        if let Err(e) = std::fs::remove_file(addr.as_ref()) {
267            // NotFound is expected and not an issue. Anything else is.
268            if e.kind() != std::io::ErrorKind::NotFound {
269                return Err(e);
270            }
271        }
272
273        let lst = UnixListener::bind(addr)?;
274        self.listen_uds(name, lst, factory)
275    }
276
277    #[cfg(unix)]
278    /// Add new unix domain service to the server.
279    /// Useful when running as a systemd service and
280    /// a socket FD can be acquired using the systemd crate.
281    pub fn listen_uds<F, N: AsRef<str>, R>(
282        mut self,
283        name: N,
284        lst: std::os::unix::net::UnixListener,
285        factory: F,
286    ) -> io::Result<Self>
287    where
288        F: AsyncFn(Config) -> R + Send + Clone + 'static,
289        R: ServiceFactory<Io, SharedCfg> + 'static,
290    {
291        let token = self.token.next();
292        self.services.push(factory::create_factory_service(
293            name.as_ref().to_string(),
294            vec![(token, SharedCfg::default())],
295            factory,
296        ));
297        self.sockets
298            .push((token, name.as_ref().to_string(), Listener::from_uds(lst)));
299        Ok(self)
300    }
301
302    /// Add new service to the server.
303    pub fn listen<F, N: AsRef<str>, R>(
304        mut self,
305        name: N,
306        lst: net::TcpListener,
307        factory: F,
308    ) -> io::Result<Self>
309    where
310        F: AsyncFn(Config) -> R + Send + Clone + 'static,
311        R: ServiceFactory<Io, SharedCfg> + 'static,
312    {
313        let token = self.token.next();
314        self.services.push(factory::create_factory_service(
315            name.as_ref().to_string(),
316            vec![(token, SharedCfg::default())],
317            factory,
318        ));
319        self.sockets
320            .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
321        Ok(self)
322    }
323
324    #[must_use]
325    /// Set shared config for named service
326    ///
327    /// # Panics
328    ///
329    /// Panics if named service is not registered
330    pub fn config<N, U>(mut self, name: N, cfg: U) -> Self
331    where
332        N: AsRef<str>,
333        U: Into<SharedCfg>,
334    {
335        let cfg = cfg.into();
336        let mut token = None;
337        for sock in &self.sockets {
338            if sock.1 == name.as_ref() {
339                token = Some(sock.0);
340                break;
341            }
342        }
343
344        if let Some(token) = token {
345            for svc in &mut self.services {
346                if svc.name(token) == name.as_ref() {
347                    svc.set_config(token, cfg);
348                }
349            }
350        } else {
351            panic!("Cannot find service by name {:?}", name.as_ref());
352        }
353
354        self
355    }
356
357    /// Starts processing incoming connections and return server controller.
358    pub fn run(self) -> Server<Connection> {
359        if self.sockets.is_empty() {
360            panic!("Server should have at least one bound socket");
361        } else {
362            let srv = StreamServer::new(
363                self.accept.notify(),
364                self.services,
365                self.on_worker_start,
366                self.on_accept,
367            );
368            let svc = self.pool.run(srv);
369
370            let sockets = self
371                .sockets
372                .into_iter()
373                .map(|sock| {
374                    log::info!("Starting \"{}\" service on {}", sock.1, sock.2);
375                    (sock.0, sock.2)
376                })
377                .collect();
378            self.accept.start(sockets, svc.clone());
379
380            svc
381        }
382    }
383}
384
385pub fn bind_addr<S: net::ToSocketAddrs>(
386    addr: S,
387    backlog: i32,
388) -> io::Result<Vec<net::TcpListener>> {
389    let mut err = None;
390    let mut succ = false;
391    let mut sockets = Vec::new();
392    for addr in addr.to_socket_addrs()? {
393        match create_tcp_listener(addr, backlog) {
394            Ok(lst) => {
395                succ = true;
396                sockets.push(lst);
397            }
398            Err(e) => err = Some(e),
399        }
400    }
401
402    if succ {
403        Ok(sockets)
404    } else if let Some(e) = err.take() {
405        Err(e)
406    } else {
407        Err(io::Error::new(
408            io::ErrorKind::InvalidInput,
409            "Cannot bind to address.",
410        ))
411    }
412}
413
414pub fn create_tcp_listener(
415    addr: net::SocketAddr,
416    backlog: i32,
417) -> io::Result<net::TcpListener> {
418    let builder = match addr {
419        net::SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
420        net::SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
421    };
422
423    // On Windows, this allows rebinding sockets which are actively in use,
424    // which allows “socket hijacking”, so we explicitly don't set it here.
425    // https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse
426    #[cfg(not(windows))]
427    builder.set_reuse_address(true)?;
428
429    builder.bind(&SockAddr::from(addr))?;
430    builder.listen(backlog)?;
431    Ok(net::TcpListener::from(builder))
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty address list produces no listener, which is an error.
    #[test]
    fn test_bind_addr() {
        let no_addrs: Vec<net::SocketAddr> = Vec::new();
        let outcome = bind_addr(no_addrs.as_slice(), 10);
        assert!(outcome.is_err());
    }

    /// The `Debug` output names the builder type.
    #[ntex::test]
    async fn test_debug() {
        let rendered = format!("{:?}", ServerBuilder::default());
        assert!(rendered.contains("ServerBuilder"));
    }
}