1use std::{fmt, io, net, sync::Arc};
2
3use socket2::{Domain, SockAddr, Socket, Type};
4
5use ntex_io::Io;
6use ntex_rt::System;
7use ntex_service::{ServiceFactory, cfg::SharedCfg};
8use ntex_util::time::Millis;
9
10use crate::{Server, WorkerPool};
11
12use super::accept::AcceptLoop;
13use super::config::{Config, ServiceConfig};
14use super::factory::{self, FactoryServiceType};
15use super::factory::{OnAccept, OnAcceptWrapper, OnWorkerStart, OnWorkerStartWrapper};
16use super::{Connection, ServerStatus, Stream, StreamServer, Token, socket::Listener};
17
/// Builder for a stream server.
///
/// Collects listening sockets, their service factories, worker-pool and
/// accept-loop settings, and optional hooks; [`ServerBuilder::run`] then
/// starts the worker pool and the accept loop with everything registered.
pub struct ServerBuilder {
    /// Server name; also propagated to the accept loop and the worker pool.
    name: String,
    /// Token source; `next()` is called on it to obtain a token for each
    /// registered socket.
    token: Token,
    /// Listen backlog handed to `bind_addr` and `ServiceConfig::new`.
    backlog: i32,
    /// Service factories, one per `bind` / `listen` / `listen_uds` /
    /// `configure` call.
    services: Vec<FactoryServiceType>,
    /// Registered listeners as `(token, service name, listener)`.
    sockets: Vec<(Token, String, Listener)>,
    /// Hooks added via `on_worker_start` (several may be registered).
    on_worker_start: Vec<Box<dyn OnWorkerStart + Send>>,
    /// Hook set via `on_accept`; a later call replaces an earlier one.
    on_accept: Option<Box<dyn OnAccept + Send>>,
    /// Accept loop that `run` starts with the registered sockets.
    accept: AcceptLoop,
    /// Worker pool that `run` starts with the stream server.
    pool: WorkerPool,
}
33
34impl Default for ServerBuilder {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40impl fmt::Debug for ServerBuilder {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 f.debug_struct("ServerBuilder")
43 .field("name", &self.name)
44 .field("token", &self.token)
45 .field("backlog", &self.backlog)
46 .field("sockets", &self.sockets)
47 .field("accept", &self.accept)
48 .field("worker-pool", &self.pool)
49 .finish()
50 }
51}
52
53impl ServerBuilder {
54 pub fn new() -> ServerBuilder {
56 let sys = System::current();
57 let mut accept = AcceptLoop::default();
58 accept.name(sys.name());
59 if sys.testing() {
60 accept.testing()
61 }
62
63 ServerBuilder {
64 accept,
65 name: sys.name().to_string(),
66 token: Token(0),
67 services: Vec::new(),
68 sockets: Vec::new(),
69 on_accept: None,
70 on_worker_start: Vec::new(),
71 backlog: 2048,
72 pool: WorkerPool::default().name(sys.name()),
73 }
74 }
75
76 pub fn name<T: AsRef<str>>(mut self, name: T) -> Self {
80 self.name = name.as_ref().to_string();
81 self.accept.name(self.name.as_str());
82 self.pool = self.pool.name(self.name.as_str());
83 self
84 }
85
86 pub fn workers(mut self, num: usize) -> Self {
91 self.pool = self.pool.workers(num);
92 self
93 }
94
95 pub fn backlog(mut self, num: i32) -> Self {
106 self.backlog = num;
107 self
108 }
109
110 pub fn maxconn(self, num: usize) -> Self {
117 super::max_concurrent_connections(num);
118 self
119 }
120
121 pub fn stop_runtime(mut self) -> Self {
125 self.pool = self.pool.stop_runtime();
126 self
127 }
128
129 pub fn disable_signals(mut self) -> Self {
133 self.pool = self.pool.disable_signals();
134 self
135 }
136
137 pub fn enable_affinity(mut self) -> Self {
141 self.pool = self.pool.enable_affinity();
142 self
143 }
144
145 pub fn shutdown_timeout<T: Into<Millis>>(mut self, timeout: T) -> Self {
153 self.pool = self.pool.shutdown_timeout(timeout);
154 self
155 }
156
157 pub fn status_handler<F>(mut self, handler: F) -> Self
161 where
162 F: FnMut(ServerStatus) + Send + 'static,
163 {
164 self.accept.set_status_handler(handler);
165 self
166 }
167
168 pub async fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder>
174 where
175 F: AsyncFn(ServiceConfig) -> io::Result<()>,
176 {
177 let cfg = ServiceConfig::new(self.token, self.backlog);
178
179 f(cfg.clone()).await?;
180
181 let (token, sockets, factory) = cfg.into_factory();
182 self.token = token;
183 self.sockets.extend(sockets);
184 self.services.push(factory);
185
186 Ok(self)
187 }
188
189 pub fn on_worker_start<F, E>(mut self, f: F) -> Self
194 where
195 F: AsyncFn() -> Result<(), E> + Send + Clone + 'static,
196 E: fmt::Display + 'static,
197 {
198 self.on_worker_start.push(OnWorkerStartWrapper::create(f));
199 self
200 }
201
202 pub fn on_accept<F, E>(mut self, f: F) -> Self
206 where
207 F: AsyncFn(Arc<str>, Stream) -> Result<Stream, E> + Send + Clone + 'static,
208 E: fmt::Display + 'static,
209 {
210 self.on_accept = Some(OnAcceptWrapper::create(f));
211 self
212 }
213
214 pub fn bind<F, U, N, R>(mut self, name: N, addr: U, factory: F) -> io::Result<Self>
216 where
217 U: net::ToSocketAddrs,
218 N: AsRef<str>,
219 F: AsyncFn(Config) -> R + Send + Clone + 'static,
220 R: ServiceFactory<Io, SharedCfg> + 'static,
221 {
222 let sockets = bind_addr(addr, self.backlog)?;
223
224 let mut tokens = Vec::new();
225 for lst in sockets {
226 let token = self.token.next();
227 self.sockets
228 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
229 tokens.push((token, SharedCfg::default()));
230 }
231
232 self.services.push(factory::create_factory_service(
233 name.as_ref().to_string(),
234 tokens,
235 factory,
236 ));
237
238 Ok(self)
239 }
240
241 #[cfg(unix)]
242 pub fn bind_uds<F, U, N, R>(self, name: N, addr: U, factory: F) -> io::Result<Self>
244 where
245 N: AsRef<str>,
246 U: AsRef<std::path::Path>,
247 F: AsyncFn(Config) -> R + Send + Clone + 'static,
248 R: ServiceFactory<Io, SharedCfg> + 'static,
249 {
250 use std::os::unix::net::UnixListener;
251
252 if let Err(e) = std::fs::remove_file(addr.as_ref()) {
255 if e.kind() != std::io::ErrorKind::NotFound {
257 return Err(e);
258 }
259 }
260
261 let lst = UnixListener::bind(addr)?;
262 self.listen_uds(name, lst, factory)
263 }
264
265 #[cfg(unix)]
266 pub fn listen_uds<F, N: AsRef<str>, R>(
270 mut self,
271 name: N,
272 lst: std::os::unix::net::UnixListener,
273 factory: F,
274 ) -> io::Result<Self>
275 where
276 F: AsyncFn(Config) -> R + Send + Clone + 'static,
277 R: ServiceFactory<Io, SharedCfg> + 'static,
278 {
279 let token = self.token.next();
280 self.services.push(factory::create_factory_service(
281 name.as_ref().to_string(),
282 vec![(token, SharedCfg::default())],
283 factory,
284 ));
285 self.sockets
286 .push((token, name.as_ref().to_string(), Listener::from_uds(lst)));
287 Ok(self)
288 }
289
290 pub fn listen<F, N: AsRef<str>, R>(
292 mut self,
293 name: N,
294 lst: net::TcpListener,
295 factory: F,
296 ) -> io::Result<Self>
297 where
298 F: AsyncFn(Config) -> R + Send + Clone + 'static,
299 R: ServiceFactory<Io, SharedCfg> + 'static,
300 {
301 let token = self.token.next();
302 self.services.push(factory::create_factory_service(
303 name.as_ref().to_string(),
304 vec![(token, SharedCfg::default())],
305 factory,
306 ));
307 self.sockets
308 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
309 Ok(self)
310 }
311
312 pub fn config<N, U>(mut self, name: N, cfg: U) -> Self
314 where
315 N: AsRef<str>,
316 U: Into<SharedCfg>,
317 {
318 let cfg = cfg.into();
319 let mut token = None;
320 for sock in &self.sockets {
321 if sock.1 == name.as_ref() {
322 token = Some(sock.0);
323 break;
324 }
325 }
326
327 if let Some(token) = token {
328 for svc in &mut self.services {
329 if svc.name(token) == name.as_ref() {
330 svc.set_config(token, cfg);
331 }
332 }
333 } else {
334 panic!("Cannot find service by name {:?}", name.as_ref());
335 }
336
337 self
338 }
339
340 pub fn run(self) -> Server<Connection> {
342 if self.sockets.is_empty() {
343 panic!("Server should have at least one bound socket");
344 } else {
345 let srv = StreamServer::new(
346 self.accept.notify(),
347 self.services,
348 self.on_worker_start,
349 self.on_accept,
350 );
351 let svc = self.pool.run(srv);
352
353 let sockets = self
354 .sockets
355 .into_iter()
356 .map(|sock| {
357 log::info!("Starting \"{}\" service on {}", sock.1, sock.2);
358 (sock.0, sock.2)
359 })
360 .collect();
361 self.accept.start(sockets, svc.clone());
362
363 svc
364 }
365 }
366}
367
368pub fn bind_addr<S: net::ToSocketAddrs>(
369 addr: S,
370 backlog: i32,
371) -> io::Result<Vec<net::TcpListener>> {
372 let mut err = None;
373 let mut succ = false;
374 let mut sockets = Vec::new();
375 for addr in addr.to_socket_addrs()? {
376 match create_tcp_listener(addr, backlog) {
377 Ok(lst) => {
378 succ = true;
379 sockets.push(lst);
380 }
381 Err(e) => err = Some(e),
382 }
383 }
384
385 if !succ {
386 if let Some(e) = err.take() {
387 Err(e)
388 } else {
389 Err(io::Error::new(
390 io::ErrorKind::InvalidInput,
391 "Cannot bind to address.",
392 ))
393 }
394 } else {
395 Ok(sockets)
396 }
397}
398
399pub fn create_tcp_listener(
400 addr: net::SocketAddr,
401 backlog: i32,
402) -> io::Result<net::TcpListener> {
403 let builder = match addr {
404 net::SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
405 net::SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
406 };
407
408 #[cfg(not(windows))]
412 builder.set_reuse_address(true)?;
413
414 builder.bind(&SockAddr::from(addr))?;
415 builder.listen(backlog)?;
416 Ok(net::TcpListener::from(builder))
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bind_addr() {
        // With no addresses to try there is no OS error to report, so the
        // fallback `InvalidInput` error must be returned.
        let addrs: Vec<net::SocketAddr> = Vec::new();
        let err = bind_addr(&addrs[..], 10).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn test_create_tcp_listener() {
        // Port 0 lets the OS choose a free port, so the test cannot collide
        // with other listeners.
        let addr: net::SocketAddr = "127.0.0.1:0".parse().unwrap();
        let lst = create_tcp_listener(addr, 16).unwrap();
        let local = lst.local_addr().unwrap();
        assert!(local.ip().is_loopback());
        assert_ne!(local.port(), 0);
    }

    #[ntex::test]
    async fn test_debug() {
        let builder = ServerBuilder::default();
        assert!(format!("{builder:?}").contains("ServerBuilder"));
    }
}