1use std::{fmt, future::Future, io, net, sync::Arc};
2
3use socket2::{Domain, SockAddr, Socket, Type};
4
5use ntex_net::Io;
6use ntex_service::ServiceFactory;
7use ntex_util::time::Millis;
8
9use crate::{Server, WorkerPool};
10
11use super::accept::AcceptLoop;
12use super::config::{Config, ServiceConfig};
13use super::factory::{self, FactoryServiceType};
14use super::factory::{OnAccept, OnAcceptWrapper, OnWorkerStart, OnWorkerStartWrapper};
15use super::{socket::Listener, Connection, ServerStatus, Stream, StreamServer, Token};
16
/// Builder for a stream server listening on TCP and (on unix) Unix domain sockets.
///
/// Collects listening sockets, the service factories that handle them, worker
/// pool settings and accept-loop settings. `ServerBuilder::run` consumes the
/// builder and starts the server.
pub struct ServerBuilder {
    // Next token to hand out; every registered socket gets its own token.
    token: Token,
    // Listen backlog used when this builder creates sockets itself.
    backlog: i32,
    // Service factories, one per `bind` / `listen` / `configure` call.
    services: Vec<FactoryServiceType>,
    // Registered listeners as (token, service name, listener).
    sockets: Vec<(Token, String, Listener)>,
    // Worker-start callbacks, handed to `StreamServer::new` in `run`.
    on_worker_start: Vec<Box<dyn OnWorkerStart + Send>>,
    // Optional accept hook, handed to `StreamServer::new`; last registration wins.
    on_accept: Option<Box<dyn OnAccept + Send>>,
    // Accept loop that receives the registered sockets in `run`.
    accept: AcceptLoop,
    // Worker pool that runs the `StreamServer`.
    pool: WorkerPool,
}
31
32impl Default for ServerBuilder {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl fmt::Debug for ServerBuilder {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 f.debug_struct("ServerBuilder")
41 .field("token", &self.token)
42 .field("backlog", &self.backlog)
43 .field("sockets", &self.sockets)
44 .field("accept", &self.accept)
45 .field("worker-pool", &self.pool)
46 .finish()
47 }
48}
49
50impl ServerBuilder {
51 pub fn new() -> ServerBuilder {
53 ServerBuilder {
54 token: Token(0),
55 services: Vec::new(),
56 sockets: Vec::new(),
57 on_accept: None,
58 on_worker_start: Vec::new(),
59 accept: AcceptLoop::default(),
60 backlog: 2048,
61 pool: WorkerPool::new(),
62 }
63 }
64
65 pub fn workers(mut self, num: usize) -> Self {
70 self.pool = self.pool.workers(num);
71 self
72 }
73
74 pub fn backlog(mut self, num: i32) -> Self {
85 self.backlog = num;
86 self
87 }
88
89 pub fn maxconn(self, num: usize) -> Self {
96 super::max_concurrent_connections(num);
97 self
98 }
99
100 pub fn stop_runtime(mut self) -> Self {
104 self.pool = self.pool.stop_runtime();
105 self
106 }
107
108 pub fn disable_signals(mut self) -> Self {
112 self.pool = self.pool.disable_signals();
113 self
114 }
115
116 pub fn enable_affinity(mut self) -> Self {
120 self.pool = self.pool.enable_affinity();
121 self
122 }
123
124 pub fn shutdown_timeout<T: Into<Millis>>(mut self, timeout: T) -> Self {
132 self.pool = self.pool.shutdown_timeout(timeout);
133 self
134 }
135
136 pub fn status_handler<F>(mut self, handler: F) -> Self
140 where
141 F: FnMut(ServerStatus) + Send + 'static,
142 {
143 self.accept.set_status_handler(handler);
144 self
145 }
146
147 pub fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder>
153 where
154 F: Fn(&mut ServiceConfig) -> io::Result<()>,
155 {
156 let mut cfg = ServiceConfig::new(self.token, self.backlog);
157
158 f(&mut cfg)?;
159
160 let (token, sockets, factory) = cfg.into_factory();
161 self.token = token;
162 self.sockets.extend(sockets);
163 self.services.push(factory);
164
165 Ok(self)
166 }
167
168 pub async fn configure_async<F, R>(mut self, f: F) -> io::Result<ServerBuilder>
174 where
175 F: Fn(ServiceConfig) -> R,
176 R: Future<Output = io::Result<()>>,
177 {
178 let cfg = ServiceConfig::new(self.token, self.backlog);
179
180 f(cfg.clone()).await?;
181
182 let (token, sockets, factory) = cfg.into_factory();
183 self.token = token;
184 self.sockets.extend(sockets);
185 self.services.push(factory);
186
187 Ok(self)
188 }
189
190 pub fn on_worker_start<F, R, E>(mut self, f: F) -> Self
195 where
196 F: Fn() -> R + Send + Clone + 'static,
197 R: Future<Output = Result<(), E>> + 'static,
198 E: fmt::Display + 'static,
199 {
200 self.on_worker_start.push(OnWorkerStartWrapper::create(f));
201 self
202 }
203
204 pub fn on_accept<F, R, E>(mut self, f: F) -> Self
208 where
209 F: Fn(Arc<str>, Stream) -> R + Send + Clone + 'static,
210 R: Future<Output = Result<Stream, E>> + 'static,
211 E: fmt::Display + 'static,
212 {
213 self.on_accept = Some(OnAcceptWrapper::create(f));
214 self
215 }
216
217 pub fn bind<F, U, N, R>(mut self, name: N, addr: U, factory: F) -> io::Result<Self>
219 where
220 U: net::ToSocketAddrs,
221 N: AsRef<str>,
222 F: Fn(Config) -> R + Send + Clone + 'static,
223 R: ServiceFactory<Io> + 'static,
224 {
225 let sockets = bind_addr(addr, self.backlog)?;
226
227 let mut tokens = Vec::new();
228 for lst in sockets {
229 let token = self.token.next();
230 self.sockets
231 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
232 tokens.push((token, ""));
233 }
234
235 self.services.push(factory::create_factory_service(
236 name.as_ref().to_string(),
237 tokens,
238 factory,
239 ));
240
241 Ok(self)
242 }
243
244 #[cfg(unix)]
245 pub fn bind_uds<F, U, N, R>(self, name: N, addr: U, factory: F) -> io::Result<Self>
247 where
248 N: AsRef<str>,
249 U: AsRef<std::path::Path>,
250 F: Fn(Config) -> R + Send + Clone + 'static,
251 R: ServiceFactory<Io> + 'static,
252 {
253 use std::os::unix::net::UnixListener;
254
255 if let Err(e) = std::fs::remove_file(addr.as_ref()) {
258 if e.kind() != std::io::ErrorKind::NotFound {
260 return Err(e);
261 }
262 }
263
264 let lst = UnixListener::bind(addr)?;
265 self.listen_uds(name, lst, factory)
266 }
267
268 #[cfg(unix)]
269 pub fn listen_uds<F, N: AsRef<str>, R>(
273 mut self,
274 name: N,
275 lst: std::os::unix::net::UnixListener,
276 factory: F,
277 ) -> io::Result<Self>
278 where
279 F: Fn(Config) -> R + Send + Clone + 'static,
280 R: ServiceFactory<Io> + 'static,
281 {
282 let token = self.token.next();
283 self.services.push(factory::create_factory_service(
284 name.as_ref().to_string(),
285 vec![(token, "")],
286 factory,
287 ));
288 self.sockets
289 .push((token, name.as_ref().to_string(), Listener::from_uds(lst)));
290 Ok(self)
291 }
292
293 pub fn listen<F, N: AsRef<str>, R>(
295 mut self,
296 name: N,
297 lst: net::TcpListener,
298 factory: F,
299 ) -> io::Result<Self>
300 where
301 F: Fn(Config) -> R + Send + Clone + 'static,
302 R: ServiceFactory<Io> + 'static,
303 {
304 let token = self.token.next();
305 self.services.push(factory::create_factory_service(
306 name.as_ref().to_string(),
307 vec![(token, "")],
308 factory,
309 ));
310 self.sockets
311 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
312 Ok(self)
313 }
314
315 pub fn set_tag<N: AsRef<str>>(mut self, name: N, tag: &'static str) -> Self {
317 let mut token = None;
318 for sock in &self.sockets {
319 if sock.1 == name.as_ref() {
320 token = Some(sock.0);
321 break;
322 }
323 }
324
325 if let Some(token) = token {
326 for svc in &mut self.services {
327 if svc.name(token) == name.as_ref() {
328 svc.set_tag(token, tag);
329 }
330 }
331 } else {
332 panic!("Cannot find service by name {:?}", name.as_ref());
333 }
334
335 self
336 }
337
338 pub fn run(self) -> Server<Connection> {
340 if self.sockets.is_empty() {
341 panic!("Server should have at least one bound socket");
342 } else {
343 let srv = StreamServer::new(
344 self.accept.notify(),
345 self.services,
346 self.on_worker_start,
347 self.on_accept,
348 );
349 let svc = self.pool.run(srv);
350
351 let sockets = self
352 .sockets
353 .into_iter()
354 .map(|sock| {
355 log::info!("Starting \"{}\" service on {}", sock.1, sock.2);
356 (sock.0, sock.2)
357 })
358 .collect();
359 self.accept.start(sockets, svc.clone());
360
361 svc
362 }
363 }
364}
365
366pub fn bind_addr<S: net::ToSocketAddrs>(
367 addr: S,
368 backlog: i32,
369) -> io::Result<Vec<net::TcpListener>> {
370 let mut err = None;
371 let mut succ = false;
372 let mut sockets = Vec::new();
373 for addr in addr.to_socket_addrs()? {
374 match create_tcp_listener(addr, backlog) {
375 Ok(lst) => {
376 succ = true;
377 sockets.push(lst);
378 }
379 Err(e) => err = Some(e),
380 }
381 }
382
383 if !succ {
384 if let Some(e) = err.take() {
385 Err(e)
386 } else {
387 Err(io::Error::new(
388 io::ErrorKind::InvalidInput,
389 "Cannot bind to address.",
390 ))
391 }
392 } else {
393 Ok(sockets)
394 }
395}
396
397pub fn create_tcp_listener(
398 addr: net::SocketAddr,
399 backlog: i32,
400) -> io::Result<net::TcpListener> {
401 let builder = match addr {
402 net::SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
403 net::SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
404 };
405
406 #[cfg(not(windows))]
410 builder.set_reuse_address(true)?;
411
412 builder.bind(&SockAddr::from(addr))?;
413 builder.listen(backlog)?;
414 Ok(net::TcpListener::from(builder))
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty address list resolves to nothing, so binding must fail.
    #[test]
    fn test_bind_addr() {
        let no_addrs: [net::SocketAddr; 0] = [];
        let outcome = bind_addr(&no_addrs[..], 10);
        assert!(outcome.is_err());
    }

    /// The `Debug` rendering names the builder type.
    #[test]
    fn test_debug() {
        let rendered = format!("{:?}", ServerBuilder::default());
        assert!(rendered.contains("ServerBuilder"));
    }
}