1use std::{fmt, io, net, sync::Arc};
2
3use socket2::{Domain, SockAddr, Socket, Type};
4
5use ntex_io::Io;
6use ntex_service::{ServiceFactory, cfg::SharedCfg};
7use ntex_util::time::Millis;
8
9use crate::{Server, WorkerPool};
10
11use super::accept::AcceptLoop;
12use super::config::{Config, ServiceConfig};
13use super::factory::{self, FactoryServiceType};
14use super::factory::{OnAccept, OnAcceptWrapper, OnWorkerStart, OnWorkerStartWrapper};
15use super::{Connection, ServerStatus, Stream, StreamServer, Token, socket::Listener};
16
/// Builder that collects listeners, service factories and worker-pool
/// settings, and finally starts the server via `run()`.
pub struct ServerBuilder {
    /// Token counter: `next()` is called on it to obtain a token for every
    /// listener registered through this builder.
    token: Token,
    /// Listen backlog handed to newly bound TCP sockets and to `ServiceConfig`.
    backlog: i32,
    /// Service factories registered by `bind`, `listen`, `listen_uds` and `configure`.
    services: Vec<FactoryServiceType>,
    /// Registered listeners as `(token, service name, listener)` entries.
    sockets: Vec<(Token, String, Listener)>,
    /// Callbacks added with `on_worker_start()`; handed to the server in `run()`.
    on_worker_start: Vec<Box<dyn OnWorkerStart + Send>>,
    /// Optional callback set with `on_accept()`; handed to the server in `run()`.
    on_accept: Option<Box<dyn OnAccept + Send>>,
    /// Accept loop that receives the listeners when `run()` is called.
    accept: AcceptLoop,
    /// Worker pool configuration; `run()` uses it to start the server.
    pool: WorkerPool,
}
31
32impl Default for ServerBuilder {
33 fn default() -> Self {
34 Self::new()
35 }
36}
37
38impl fmt::Debug for ServerBuilder {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 f.debug_struct("ServerBuilder")
41 .field("token", &self.token)
42 .field("backlog", &self.backlog)
43 .field("sockets", &self.sockets)
44 .field("accept", &self.accept)
45 .field("worker-pool", &self.pool)
46 .finish()
47 }
48}
49
50impl ServerBuilder {
51 pub fn new() -> ServerBuilder {
53 ServerBuilder {
54 token: Token(0),
55 services: Vec::new(),
56 sockets: Vec::new(),
57 on_accept: None,
58 on_worker_start: Vec::new(),
59 accept: AcceptLoop::default(),
60 backlog: 2048,
61 pool: WorkerPool::new(),
62 }
63 }
64
65 pub fn workers(mut self, num: usize) -> Self {
70 self.pool = self.pool.workers(num);
71 self
72 }
73
74 pub fn backlog(mut self, num: i32) -> Self {
85 self.backlog = num;
86 self
87 }
88
89 pub fn maxconn(self, num: usize) -> Self {
96 super::max_concurrent_connections(num);
97 self
98 }
99
100 pub fn stop_runtime(mut self) -> Self {
104 self.pool = self.pool.stop_runtime();
105 self
106 }
107
108 pub fn disable_signals(mut self) -> Self {
112 self.pool = self.pool.disable_signals();
113 self
114 }
115
116 pub fn enable_affinity(mut self) -> Self {
120 self.pool = self.pool.enable_affinity();
121 self
122 }
123
124 pub fn shutdown_timeout<T: Into<Millis>>(mut self, timeout: T) -> Self {
132 self.pool = self.pool.shutdown_timeout(timeout);
133 self
134 }
135
136 pub fn status_handler<F>(mut self, handler: F) -> Self
140 where
141 F: FnMut(ServerStatus) + Send + 'static,
142 {
143 self.accept.set_status_handler(handler);
144 self
145 }
146
147 pub async fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder>
153 where
154 F: AsyncFn(ServiceConfig) -> io::Result<()>,
155 {
156 let cfg = ServiceConfig::new(self.token, self.backlog);
157
158 f(cfg.clone()).await?;
159
160 let (token, sockets, factory) = cfg.into_factory();
161 self.token = token;
162 self.sockets.extend(sockets);
163 self.services.push(factory);
164
165 Ok(self)
166 }
167
168 pub fn on_worker_start<F, E>(mut self, f: F) -> Self
173 where
174 F: AsyncFn() -> Result<(), E> + Send + Clone + 'static,
175 E: fmt::Display + 'static,
176 {
177 self.on_worker_start.push(OnWorkerStartWrapper::create(f));
178 self
179 }
180
181 pub fn on_accept<F, E>(mut self, f: F) -> Self
185 where
186 F: AsyncFn(Arc<str>, Stream) -> Result<Stream, E> + Send + Clone + 'static,
187 E: fmt::Display + 'static,
188 {
189 self.on_accept = Some(OnAcceptWrapper::create(f));
190 self
191 }
192
193 pub fn bind<F, U, N, R>(mut self, name: N, addr: U, factory: F) -> io::Result<Self>
195 where
196 U: net::ToSocketAddrs,
197 N: AsRef<str>,
198 F: AsyncFn(Config) -> R + Send + Clone + 'static,
199 R: ServiceFactory<Io, SharedCfg> + 'static,
200 {
201 let sockets = bind_addr(addr, self.backlog)?;
202
203 let mut tokens = Vec::new();
204 for lst in sockets {
205 let token = self.token.next();
206 self.sockets
207 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
208 tokens.push((token, SharedCfg::default()));
209 }
210
211 self.services.push(factory::create_factory_service(
212 name.as_ref().to_string(),
213 tokens,
214 factory,
215 ));
216
217 Ok(self)
218 }
219
220 #[cfg(unix)]
221 pub fn bind_uds<F, U, N, R>(self, name: N, addr: U, factory: F) -> io::Result<Self>
223 where
224 N: AsRef<str>,
225 U: AsRef<std::path::Path>,
226 F: AsyncFn(Config) -> R + Send + Clone + 'static,
227 R: ServiceFactory<Io, SharedCfg> + 'static,
228 {
229 use std::os::unix::net::UnixListener;
230
231 if let Err(e) = std::fs::remove_file(addr.as_ref()) {
234 if e.kind() != std::io::ErrorKind::NotFound {
236 return Err(e);
237 }
238 }
239
240 let lst = UnixListener::bind(addr)?;
241 self.listen_uds(name, lst, factory)
242 }
243
244 #[cfg(unix)]
245 pub fn listen_uds<F, N: AsRef<str>, R>(
249 mut self,
250 name: N,
251 lst: std::os::unix::net::UnixListener,
252 factory: F,
253 ) -> io::Result<Self>
254 where
255 F: AsyncFn(Config) -> R + Send + Clone + 'static,
256 R: ServiceFactory<Io, SharedCfg> + 'static,
257 {
258 let token = self.token.next();
259 self.services.push(factory::create_factory_service(
260 name.as_ref().to_string(),
261 vec![(token, SharedCfg::default())],
262 factory,
263 ));
264 self.sockets
265 .push((token, name.as_ref().to_string(), Listener::from_uds(lst)));
266 Ok(self)
267 }
268
269 pub fn listen<F, N: AsRef<str>, R>(
271 mut self,
272 name: N,
273 lst: net::TcpListener,
274 factory: F,
275 ) -> io::Result<Self>
276 where
277 F: AsyncFn(Config) -> R + Send + Clone + 'static,
278 R: ServiceFactory<Io, SharedCfg> + 'static,
279 {
280 let token = self.token.next();
281 self.services.push(factory::create_factory_service(
282 name.as_ref().to_string(),
283 vec![(token, SharedCfg::default())],
284 factory,
285 ));
286 self.sockets
287 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
288 Ok(self)
289 }
290
291 pub fn config<N, U>(mut self, name: N, cfg: U) -> Self
293 where
294 N: AsRef<str>,
295 U: Into<SharedCfg>,
296 {
297 let cfg = cfg.into();
298 let mut token = None;
299 for sock in &self.sockets {
300 if sock.1 == name.as_ref() {
301 token = Some(sock.0);
302 break;
303 }
304 }
305
306 if let Some(token) = token {
307 for svc in &mut self.services {
308 if svc.name(token) == name.as_ref() {
309 svc.set_config(token, cfg);
310 }
311 }
312 } else {
313 panic!("Cannot find service by name {:?}", name.as_ref());
314 }
315
316 self
317 }
318
319 #[doc(hidden)]
320 pub fn testing(mut self) -> Self {
322 self.accept.testing();
323 self
324 }
325
326 pub fn run(self) -> Server<Connection> {
328 if self.sockets.is_empty() {
329 panic!("Server should have at least one bound socket");
330 } else {
331 let srv = StreamServer::new(
332 self.accept.notify(),
333 self.services,
334 self.on_worker_start,
335 self.on_accept,
336 );
337 let svc = self.pool.run(srv);
338
339 let sockets = self
340 .sockets
341 .into_iter()
342 .map(|sock| {
343 log::info!("Starting \"{}\" service on {}", sock.1, sock.2);
344 (sock.0, sock.2)
345 })
346 .collect();
347 self.accept.start(sockets, svc.clone());
348
349 svc
350 }
351 }
352}
353
354pub fn bind_addr<S: net::ToSocketAddrs>(
355 addr: S,
356 backlog: i32,
357) -> io::Result<Vec<net::TcpListener>> {
358 let mut err = None;
359 let mut succ = false;
360 let mut sockets = Vec::new();
361 for addr in addr.to_socket_addrs()? {
362 match create_tcp_listener(addr, backlog) {
363 Ok(lst) => {
364 succ = true;
365 sockets.push(lst);
366 }
367 Err(e) => err = Some(e),
368 }
369 }
370
371 if !succ {
372 if let Some(e) = err.take() {
373 Err(e)
374 } else {
375 Err(io::Error::new(
376 io::ErrorKind::InvalidInput,
377 "Cannot bind to address.",
378 ))
379 }
380 } else {
381 Ok(sockets)
382 }
383}
384
385pub fn create_tcp_listener(
386 addr: net::SocketAddr,
387 backlog: i32,
388) -> io::Result<net::TcpListener> {
389 let builder = match addr {
390 net::SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
391 net::SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
392 };
393
394 #[cfg(not(windows))]
398 builder.set_reuse_address(true)?;
399
400 builder.bind(&SockAddr::from(addr))?;
401 builder.listen(backlog)?;
402 Ok(net::TcpListener::from(builder))
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Binding an empty address list must fail: there is nothing to bind to.
    #[test]
    fn test_bind_addr() {
        let no_addrs: [net::SocketAddr; 0] = [];
        let outcome = bind_addr(&no_addrs[..], 10);
        assert!(outcome.is_err());
    }

    /// The `Debug` output names the type.
    #[test]
    fn test_debug() {
        let rendered = format!("{:?}", ServerBuilder::default());
        assert!(rendered.contains("ServerBuilder"));
    }
}
420}