1use std::pin::Pin;
2use std::task::{Context, Poll};
3use std::time::Duration;
4use std::{io, mem, net};
5
6use scrappy_rt::net::TcpStream;
7use scrappy_rt::time::{delay_until, Instant};
8use scrappy_rt::{spawn, System};
9use futures::channel::mpsc::{unbounded, UnboundedReceiver};
10use futures::channel::oneshot;
11use futures::future::ready;
12use futures::stream::FuturesUnordered;
13use futures::{ready, Future, FutureExt, Stream, StreamExt};
14use log::{error, info};
15use net2::TcpBuilder;
16use num_cpus;
17
18use crate::accept::{AcceptLoop, AcceptNotify, Command};
19use crate::config::{ConfiguredService, ServiceConfig};
20use crate::server::{Server, ServerCommand};
21use crate::service::{InternalServiceFactory, ServiceFactory, StreamNewService};
22use crate::signals::{Signal, Signals};
23use crate::socket::StdListener;
24use crate::worker::{self, Worker, WorkerAvailability, WorkerClient};
25use crate::Token;
26
/// Builder that collects listeners, service factories and tuning options,
/// then starts the workers and the accept loop in `run`.
///
/// After `run` the builder itself is spawned as a future that processes the
/// `ServerCommand`s arriving on `cmd`.
pub struct ServerBuilder {
    /// Number of workers `run` starts; initialised from `num_cpus::get()`.
    threads: usize,
    /// Token source; `next()` is called once for every registered listener.
    token: Token,
    /// Backlog handed to `bind_addr` and `ServiceConfig::new`; starts at 2048.
    backlog: i32,
    /// Running workers, stored as `(worker index, client handle)`.
    workers: Vec<(usize, WorkerClient)>,
    /// Registered service factories; cloned for each worker via `clone_factory`.
    services: Vec<Box<dyn InternalServiceFactory>>,
    /// Registered listeners as `(token, service name, listener)`; emptied by `run`.
    sockets: Vec<(Token, String, StdListener)>,
    /// Accept loop that receives the listeners and the worker handles.
    accept: AcceptLoop,
    /// When set, the `System` is stopped about 300 ms after the server stops.
    exit: bool,
    /// Timeout passed to every worker on start; starts at 30 seconds.
    shutdown_timeout: Duration,
    /// When set, `run` does not start signal handling.
    no_signals: bool,
    /// Receiving end of the command channel, polled by the `Future` impl.
    cmd: UnboundedReceiver<ServerCommand>,
    /// Server handle built from the sending end of the command channel.
    server: Server,
    /// Senders registered through `ServerCommand::Notify`; fired when a stop is handled.
    notify: Vec<oneshot::Sender<()>>,
}
43
44impl Default for ServerBuilder {
45 fn default() -> Self {
46 Self::new()
47 }
48}
49
50impl ServerBuilder {
51 pub fn new() -> ServerBuilder {
53 let (tx, rx) = unbounded();
54 let server = Server::new(tx);
55
56 ServerBuilder {
57 threads: num_cpus::get(),
58 token: Token(0),
59 workers: Vec::new(),
60 services: Vec::new(),
61 sockets: Vec::new(),
62 accept: AcceptLoop::new(server.clone()),
63 backlog: 2048,
64 exit: false,
65 shutdown_timeout: Duration::from_secs(30),
66 no_signals: false,
67 cmd: rx,
68 notify: Vec::new(),
69 server,
70 }
71 }
72
73 pub fn workers(mut self, num: usize) -> Self {
78 self.threads = num;
79 self
80 }
81
82 pub fn backlog(mut self, num: i32) -> Self {
93 self.backlog = num;
94 self
95 }
96
97 pub fn maxconn(self, num: usize) -> Self {
104 worker::max_concurrent_connections(num);
105 self
106 }
107
108 pub fn system_exit(mut self) -> Self {
110 self.exit = true;
111 self
112 }
113
114 pub fn disable_signals(mut self) -> Self {
116 self.no_signals = true;
117 self
118 }
119
120 pub fn shutdown_timeout(mut self, sec: u64) -> Self {
128 self.shutdown_timeout = Duration::from_secs(sec);
129 self
130 }
131
132 pub fn configure<F>(mut self, f: F) -> io::Result<ServerBuilder>
138 where
139 F: Fn(&mut ServiceConfig) -> io::Result<()>,
140 {
141 let mut cfg = ServiceConfig::new(self.threads, self.backlog);
142
143 f(&mut cfg)?;
144
145 if let Some(apply) = cfg.apply {
146 let mut srv = ConfiguredService::new(apply);
147 for (name, lst) in cfg.services {
148 let token = self.token.next();
149 srv.stream(token, name.clone(), lst.local_addr()?);
150 self.sockets.push((token, name, StdListener::Tcp(lst)));
151 }
152 self.services.push(Box::new(srv));
153 }
154 self.threads = cfg.threads;
155
156 Ok(self)
157 }
158
159 pub fn bind<F, U, N: AsRef<str>>(mut self, name: N, addr: U, factory: F) -> io::Result<Self>
161 where
162 F: ServiceFactory<TcpStream>,
163 U: net::ToSocketAddrs,
164 {
165 let sockets = bind_addr(addr, self.backlog)?;
166
167 for lst in sockets {
168 let token = self.token.next();
169 self.services.push(StreamNewService::create(
170 name.as_ref().to_string(),
171 token,
172 factory.clone(),
173 lst.local_addr()?,
174 ));
175 self.sockets
176 .push((token, name.as_ref().to_string(), StdListener::Tcp(lst)));
177 }
178 Ok(self)
179 }
180
181 #[cfg(all(unix))]
182 pub fn bind_uds<F, U, N>(self, name: N, addr: U, factory: F) -> io::Result<Self>
184 where
185 F: ServiceFactory<scrappy_rt::net::UnixStream>,
186 N: AsRef<str>,
187 U: AsRef<std::path::Path>,
188 {
189 use std::os::unix::net::UnixListener;
190
191 if let Err(e) = std::fs::remove_file(addr.as_ref()) {
194 if e.kind() != std::io::ErrorKind::NotFound {
196 return Err(e);
197 }
198 }
199
200 let lst = UnixListener::bind(addr)?;
201 self.listen_uds(name, lst, factory)
202 }
203
204 #[cfg(all(unix))]
205 pub fn listen_uds<F, N: AsRef<str>>(
209 mut self,
210 name: N,
211 lst: std::os::unix::net::UnixListener,
212 factory: F,
213 ) -> io::Result<Self>
214 where
215 F: ServiceFactory<scrappy_rt::net::UnixStream>,
216 {
217 use std::net::{IpAddr, Ipv4Addr, SocketAddr};
218 let token = self.token.next();
219 let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
220 self.services.push(StreamNewService::create(
221 name.as_ref().to_string(),
222 token,
223 factory.clone(),
224 addr,
225 ));
226 self.sockets
227 .push((token, name.as_ref().to_string(), StdListener::Uds(lst)));
228 Ok(self)
229 }
230
231 pub fn listen<F, N: AsRef<str>>(
233 mut self,
234 name: N,
235 lst: net::TcpListener,
236 factory: F,
237 ) -> io::Result<Self>
238 where
239 F: ServiceFactory<TcpStream>,
240 {
241 let token = self.token.next();
242 self.services.push(StreamNewService::create(
243 name.as_ref().to_string(),
244 token,
245 factory,
246 lst.local_addr()?,
247 ));
248 self.sockets
249 .push((token, name.as_ref().to_string(), StdListener::Tcp(lst)));
250 Ok(self)
251 }
252
253 #[doc(hidden)]
254 pub fn start(self) -> Server {
255 self.run()
256 }
257
258 pub fn run(mut self) -> Server {
260 if self.sockets.is_empty() {
261 panic!("Server should have at least one bound socket");
262 } else {
263 info!("Starting {} workers", self.threads);
264
265 let mut workers = Vec::new();
267 for idx in 0..self.threads {
268 let worker = self.start_worker(idx, self.accept.get_notify());
269 workers.push(worker.clone());
270 self.workers.push((idx, worker));
271 }
272
273 for sock in &self.sockets {
275 info!("Starting \"{}\" service on {}", sock.1, sock.2);
276 }
277 self.accept.start(
278 mem::replace(&mut self.sockets, Vec::new())
279 .into_iter()
280 .map(|t| (t.0, t.2))
281 .collect(),
282 workers,
283 );
284
285 if !self.no_signals {
287 Signals::start(self.server.clone()).unwrap();
288 }
289
290 let server = self.server.clone();
292 spawn(self);
293 server
294 }
295 }
296
297 fn start_worker(&self, idx: usize, notify: AcceptNotify) -> WorkerClient {
298 let avail = WorkerAvailability::new(notify);
299 let services: Vec<Box<dyn InternalServiceFactory>> =
300 self.services.iter().map(|v| v.clone_factory()).collect();
301
302 Worker::start(idx, services, avail, self.shutdown_timeout)
303 }
304
305 fn handle_cmd(&mut self, item: ServerCommand) {
306 match item {
307 ServerCommand::Pause(tx) => {
308 self.accept.send(Command::Pause);
309 let _ = tx.send(());
310 }
311 ServerCommand::Resume(tx) => {
312 self.accept.send(Command::Resume);
313 let _ = tx.send(());
314 }
315 ServerCommand::Signal(sig) => {
316 match sig {
319 Signal::Int => {
320 info!("SIGINT received, exiting");
321 self.exit = true;
322 self.handle_cmd(ServerCommand::Stop {
323 graceful: false,
324 completion: None,
325 })
326 }
327 Signal::Term => {
328 info!("SIGTERM received, stopping");
329 self.exit = true;
330 self.handle_cmd(ServerCommand::Stop {
331 graceful: true,
332 completion: None,
333 })
334 }
335 Signal::Quit => {
336 info!("SIGQUIT received, exiting");
337 self.exit = true;
338 self.handle_cmd(ServerCommand::Stop {
339 graceful: false,
340 completion: None,
341 })
342 }
343 _ => (),
344 }
345 }
346 ServerCommand::Notify(tx) => {
347 self.notify.push(tx);
348 }
349 ServerCommand::Stop {
350 graceful,
351 completion,
352 } => {
353 let exit = self.exit;
354
355 self.accept.send(Command::Stop);
357 let notify = std::mem::replace(&mut self.notify, Vec::new());
358
359 if !self.workers.is_empty() && graceful {
361 spawn(
362 self.workers
363 .iter()
364 .map(move |worker| worker.1.stop(graceful))
365 .collect::<FuturesUnordered<_>>()
366 .collect::<Vec<_>>()
367 .then(move |_| {
368 if let Some(tx) = completion {
369 let _ = tx.send(());
370 }
371 for tx in notify {
372 let _ = tx.send(());
373 }
374 if exit {
375 spawn(
376 async {
377 delay_until(
378 Instant::now() + Duration::from_millis(300),
379 )
380 .await;
381 System::current().stop();
382 }
383 .boxed(),
384 );
385 }
386 ready(())
387 }),
388 )
389 } else {
390 if self.exit {
392 spawn(
393 delay_until(Instant::now() + Duration::from_millis(300)).then(
394 |_| {
395 System::current().stop();
396 ready(())
397 },
398 ),
399 );
400 }
401 if let Some(tx) = completion {
402 let _ = tx.send(());
403 }
404 for tx in notify {
405 let _ = tx.send(());
406 }
407 }
408 }
409 ServerCommand::WorkerFaulted(idx) => {
410 let mut found = false;
411 for i in 0..self.workers.len() {
412 if self.workers[i].0 == idx {
413 self.workers.swap_remove(i);
414 found = true;
415 break;
416 }
417 }
418
419 if found {
420 error!("Worker has died {:?}, restarting", idx);
421
422 let mut new_idx = self.workers.len();
423 'found: loop {
424 for i in 0..self.workers.len() {
425 if self.workers[i].0 == new_idx {
426 new_idx += 1;
427 continue 'found;
428 }
429 }
430 break;
431 }
432
433 let worker = self.start_worker(new_idx, self.accept.get_notify());
434 self.workers.push((new_idx, worker.clone()));
435 self.accept.send(Command::Worker(worker));
436 }
437 }
438 }
439 }
440}
441
442impl Future for ServerBuilder {
443 type Output = ();
444
445 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
446 loop {
447 match ready!(Pin::new(&mut self.cmd).poll_next(cx)) {
448 Some(it) => self.as_mut().get_mut().handle_cmd(it),
449 None => {
450 return Poll::Pending;
451 }
452 }
453 }
454 }
455}
456
457pub(super) fn bind_addr<S: net::ToSocketAddrs>(
458 addr: S,
459 backlog: i32,
460) -> io::Result<Vec<net::TcpListener>> {
461 let mut err = None;
462 let mut succ = false;
463 let mut sockets = Vec::new();
464 for addr in addr.to_socket_addrs()? {
465 match create_tcp_listener(addr, backlog) {
466 Ok(lst) => {
467 succ = true;
468 sockets.push(lst);
469 }
470 Err(e) => err = Some(e),
471 }
472 }
473
474 if !succ {
475 if let Some(e) = err.take() {
476 Err(e)
477 } else {
478 Err(io::Error::new(
479 io::ErrorKind::Other,
480 "Can not bind to address.",
481 ))
482 }
483 } else {
484 Ok(sockets)
485 }
486}
487
488fn create_tcp_listener(addr: net::SocketAddr, backlog: i32) -> io::Result<net::TcpListener> {
489 let builder = match addr {
490 net::SocketAddr::V4(_) => TcpBuilder::new_v4()?,
491 net::SocketAddr::V6(_) => TcpBuilder::new_v6()?,
492 };
493 builder.reuse_address(true)?;
494 builder.bind(addr)?;
495 Ok(builder.listen(backlog)?)
496}