1use std::{fmt, future::Future, io, net};
2
3use socket2::{Domain, SockAddr, Socket, Type};
4
5use ntex_net::Io;
6use ntex_service::ServiceFactory;
7use ntex_util::time::Millis;
8
9use crate::{Server, WorkerPool};
10
11use super::accept::AcceptLoop;
12use super::config::{Config, ServiceConfig};
13use super::factory::{self, FactoryServiceType, OnWorkerStart, OnWorkerStartWrapper};
14use super::{socket::Listener, Connection, ServerStatus, StreamServer, Token};
15
/// Builder for a stream server.
///
/// Collects listening sockets, the service factories that serve them, and
/// worker-pool / accept-loop settings. [`ServerBuilder::run`] consumes the
/// builder and starts the server.
// NOTE: fields are dropped in declaration order; keep the order stable.
pub struct ServerBuilder {
    /// Token counter; `next()` is called on it to allocate a token for each
    /// newly registered socket.
    token: Token,
    /// Listen backlog used when this builder binds TCP sockets itself.
    backlog: i32,
    /// Service factories, each associated with one or more socket tokens.
    services: Vec<FactoryServiceType>,
    /// Registered sockets as `(token, service name, listener)`.
    sockets: Vec<(Token, String, Listener)>,
    /// Callbacks registered through `on_worker_start`.
    on_worker_start: Vec<Box<dyn OnWorkerStart + Send>>,
    /// Accept loop; receives the registered sockets when the server runs.
    accept: AcceptLoop,
    /// Worker pool that runs the server.
    pool: WorkerPool,
}
29
30impl Default for ServerBuilder {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl fmt::Debug for ServerBuilder {
37 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38 f.debug_struct("ServerBuilder")
39 .field("token", &self.token)
40 .field("backlog", &self.backlog)
41 .field("sockets", &self.sockets)
42 .field("accept", &self.accept)
43 .field("worker-pool", &self.pool)
44 .finish()
45 }
46}
47
48impl ServerBuilder {
49 pub fn new() -> ServerBuilder {
51 ServerBuilder {
52 token: Token(0),
53 services: Vec::new(),
54 sockets: Vec::new(),
55 on_worker_start: Vec::new(),
56 accept: AcceptLoop::default(),
57 backlog: 2048,
58 pool: WorkerPool::new(),
59 }
60 }
61
62 pub fn workers(mut self, num: usize) -> Self {
67 self.pool = self.pool.workers(num);
68 self
69 }
70
71 pub fn backlog(mut self, num: i32) -> Self {
82 self.backlog = num;
83 self
84 }
85
86 pub fn maxconn(self, num: usize) -> Self {
93 super::max_concurrent_connections(num);
94 self
95 }
96
97 pub fn stop_runtime(mut self) -> Self {
101 self.pool = self.pool.stop_runtime();
102 self
103 }
104
105 pub fn disable_signals(mut self) -> Self {
109 self.pool = self.pool.disable_signals();
110 self
111 }
112
113 pub fn enable_affinity(mut self) -> Self {
117 self.pool = self.pool.enable_affinity();
118 self
119 }
120
121 pub fn shutdown_timeout<T: Into<Millis>>(mut self, timeout: T) -> Self {
129 self.pool = self.pool.shutdown_timeout(timeout);
130 self
131 }
132
133 pub fn status_handler<F>(mut self, handler: F) -> Self
137 where
138 F: FnMut(ServerStatus) + Send + 'static,
139 {
140 self.accept.set_status_handler(handler);
141 self
142 }
143
144 pub fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder>
150 where
151 F: Fn(&mut ServiceConfig) -> io::Result<()>,
152 {
153 let mut cfg = ServiceConfig::new(self.token, self.backlog);
154
155 f(&mut cfg)?;
156
157 let (token, sockets, factory) = cfg.into_factory();
158 self.token = token;
159 self.sockets.extend(sockets);
160 self.services.push(factory);
161
162 Ok(self)
163 }
164
165 pub async fn configure_async<F, R>(mut self, f: F) -> io::Result<ServerBuilder>
171 where
172 F: Fn(ServiceConfig) -> R,
173 R: Future<Output = io::Result<()>>,
174 {
175 let cfg = ServiceConfig::new(self.token, self.backlog);
176
177 f(cfg.clone()).await?;
178
179 let (token, sockets, factory) = cfg.into_factory();
180 self.token = token;
181 self.sockets.extend(sockets);
182 self.services.push(factory);
183
184 Ok(self)
185 }
186
187 pub fn on_worker_start<F, R, E>(mut self, f: F) -> Self
192 where
193 F: Fn() -> R + Send + Clone + 'static,
194 R: Future<Output = Result<(), E>> + 'static,
195 E: fmt::Display + 'static,
196 {
197 self.on_worker_start.push(OnWorkerStartWrapper::create(f));
198 self
199 }
200
201 pub fn bind<F, U, N, R>(mut self, name: N, addr: U, factory: F) -> io::Result<Self>
203 where
204 U: net::ToSocketAddrs,
205 N: AsRef<str>,
206 F: Fn(Config) -> R + Send + Clone + 'static,
207 R: ServiceFactory<Io> + 'static,
208 {
209 let sockets = bind_addr(addr, self.backlog)?;
210
211 let mut tokens = Vec::new();
212 for lst in sockets {
213 let token = self.token.next();
214 self.sockets
215 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
216 tokens.push((token, ""));
217 }
218
219 self.services.push(factory::create_factory_service(
220 name.as_ref().to_string(),
221 tokens,
222 factory,
223 ));
224
225 Ok(self)
226 }
227
228 #[cfg(unix)]
229 pub fn bind_uds<F, U, N, R>(self, name: N, addr: U, factory: F) -> io::Result<Self>
231 where
232 N: AsRef<str>,
233 U: AsRef<std::path::Path>,
234 F: Fn(Config) -> R + Send + Clone + 'static,
235 R: ServiceFactory<Io> + 'static,
236 {
237 use std::os::unix::net::UnixListener;
238
239 if let Err(e) = std::fs::remove_file(addr.as_ref()) {
242 if e.kind() != std::io::ErrorKind::NotFound {
244 return Err(e);
245 }
246 }
247
248 let lst = UnixListener::bind(addr)?;
249 self.listen_uds(name, lst, factory)
250 }
251
252 #[cfg(unix)]
253 pub fn listen_uds<F, N: AsRef<str>, R>(
257 mut self,
258 name: N,
259 lst: std::os::unix::net::UnixListener,
260 factory: F,
261 ) -> io::Result<Self>
262 where
263 F: Fn(Config) -> R + Send + Clone + 'static,
264 R: ServiceFactory<Io> + 'static,
265 {
266 let token = self.token.next();
267 self.services.push(factory::create_factory_service(
268 name.as_ref().to_string(),
269 vec![(token, "")],
270 factory,
271 ));
272 self.sockets
273 .push((token, name.as_ref().to_string(), Listener::from_uds(lst)));
274 Ok(self)
275 }
276
277 pub fn listen<F, N: AsRef<str>, R>(
279 mut self,
280 name: N,
281 lst: net::TcpListener,
282 factory: F,
283 ) -> io::Result<Self>
284 where
285 F: Fn(Config) -> R + Send + Clone + 'static,
286 R: ServiceFactory<Io> + 'static,
287 {
288 let token = self.token.next();
289 self.services.push(factory::create_factory_service(
290 name.as_ref().to_string(),
291 vec![(token, "")],
292 factory,
293 ));
294 self.sockets
295 .push((token, name.as_ref().to_string(), Listener::from_tcp(lst)));
296 Ok(self)
297 }
298
299 pub fn set_tag<N: AsRef<str>>(mut self, name: N, tag: &'static str) -> Self {
301 let mut token = None;
302 for sock in &self.sockets {
303 if sock.1 == name.as_ref() {
304 token = Some(sock.0);
305 break;
306 }
307 }
308
309 if let Some(token) = token {
310 for svc in &mut self.services {
311 if svc.name(token) == name.as_ref() {
312 svc.set_tag(token, tag);
313 }
314 }
315 } else {
316 panic!("Cannot find service by name {:?}", name.as_ref());
317 }
318
319 self
320 }
321
322 pub fn run(self) -> Server<Connection> {
324 if self.sockets.is_empty() {
325 panic!("Server should have at least one bound socket");
326 } else {
327 let srv = StreamServer::new(
328 self.accept.notify(),
329 self.services,
330 self.on_worker_start,
331 );
332 let svc = self.pool.run(srv);
333
334 let sockets = self
335 .sockets
336 .into_iter()
337 .map(|sock| {
338 log::info!("Starting \"{}\" service on {}", sock.1, sock.2);
339 (sock.0, sock.2)
340 })
341 .collect();
342 self.accept.start(sockets, svc.clone());
343
344 svc
345 }
346 }
347}
348
349pub fn bind_addr<S: net::ToSocketAddrs>(
350 addr: S,
351 backlog: i32,
352) -> io::Result<Vec<net::TcpListener>> {
353 let mut err = None;
354 let mut succ = false;
355 let mut sockets = Vec::new();
356 for addr in addr.to_socket_addrs()? {
357 match create_tcp_listener(addr, backlog) {
358 Ok(lst) => {
359 succ = true;
360 sockets.push(lst);
361 }
362 Err(e) => err = Some(e),
363 }
364 }
365
366 if !succ {
367 if let Some(e) = err.take() {
368 Err(e)
369 } else {
370 Err(io::Error::new(
371 io::ErrorKind::InvalidInput,
372 "Cannot bind to address.",
373 ))
374 }
375 } else {
376 Ok(sockets)
377 }
378}
379
380pub fn create_tcp_listener(
381 addr: net::SocketAddr,
382 backlog: i32,
383) -> io::Result<net::TcpListener> {
384 let builder = match addr {
385 net::SocketAddr::V4(_) => Socket::new(Domain::IPV4, Type::STREAM, None)?,
386 net::SocketAddr::V6(_) => Socket::new(Domain::IPV6, Type::STREAM, None)?,
387 };
388
389 #[cfg(not(windows))]
393 builder.set_reuse_address(true)?;
394
395 builder.bind(&SockAddr::from(addr))?;
396 builder.listen(backlog)?;
397 Ok(net::TcpListener::from(builder))
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bind_addr() {
        // An empty address list resolves to nothing, so binding must fail.
        let no_addrs: &[net::SocketAddr] = &[];
        assert!(bind_addr(no_addrs, 10).is_err());
    }

    #[test]
    fn test_debug() {
        let rendered = format!("{:?}", ServerBuilder::default());
        assert!(rendered.contains("ServerBuilder"));
    }
}