1use kade_proto::{prelude::*, Result};
2use std::future::Future;
3use std::path::PathBuf;
4use std::sync::Arc;
5use tokio::net::{TcpListener, TcpStream};
6use tokio::sync::{broadcast, mpsc, Semaphore};
7use tokio::time::{self, Duration};
8use tracing::{debug, error, info, instrument};
9
10pub use kade_proto::DEFAULT_PORT;
/// Upper bound on the number of client connections served at the same time.
///
/// Used as the permit count of the connection-limiting semaphore; once this
/// many connections are open, the accept loop waits for one to close.
pub const MAX_CONNECTIONS: usize = 250;
12
/// Server-side listener state: accepts TCP connections and hands each one to
/// a `Handler` running in its own task.
// NOTE: field order is drop order (and `Debug` output order) — do not reorder
// casually.
#[derive(Debug)]
struct Listener {
    /// Holds the shared database; `db()` yields the handle given to each handler.
    // NOTE(review): presumably also tears down database background work when
    // dropped (hence "DropGuard") — confirm in `DbDropGuard`.
    db_holder: DbDropGuard,
    /// Socket accepting inbound client connections.
    listener: TcpListener,
    /// Caps concurrent connections: a permit is taken before each accept and
    /// released when that connection's task ends.
    limit_connections: Arc<Semaphore>,
    /// Shutdown broadcast. Each handler subscribes to it; `run` drops this
    /// sender once the server stops, closing the channel handlers listen on.
    notify_shutdown: broadcast::Sender<()>,
    /// Cloned into every handler. `run` waits on the matching receiver, which
    /// only yields once every clone has been dropped, i.e. every handler ended.
    shutdown_complete_tx: mpsc::Sender<()>,
    /// Optional on-disk location the database is loaded from at startup and
    /// saved to on shutdown.
    db_path: Option<PathBuf>,
}
22
/// Per-connection state: reads frames from one client and applies the
/// resulting commands to the shared database.
#[derive(Debug)]
struct Handler {
    /// Handle to the shared database, obtained from `DbDropGuard::db`.
    db: Db,
    /// Framed client connection wrapping the accepted socket.
    connection: Connection,
    /// Receives the server-wide shutdown notification.
    shutdown: Shutdown,
    /// Never read. Held only so that dropping the handler drops this sender,
    /// which is how the server learns this handler has finished. Struct fields
    /// drop in declaration order, so keeping it last releases it only after
    /// the connection and database handle are gone.
    _shutdown_complete: mpsc::Sender<()>,
}
30
31pub async fn run(listener: TcpListener, shutdown: impl Future, db_path: Option<PathBuf>) {
32 let (notify_shutdown, _) = broadcast::channel(1);
33 let (shutdown_complete_tx, mut shutdown_complete_rx) = mpsc::channel(1);
34 let db_holder = DbDropGuard::new();
35
36 if let Some(path) = &db_path {
37 if path.exists() {
38 info!("Loading database from {:?}", path);
39 db_holder.db().load_from(path).await.expect("Failed to load database");
40 }
41 }
42
43 let mut server = Listener {
44 db_path,
45 listener,
46 db_holder,
47 limit_connections: Arc::new(Semaphore::new(MAX_CONNECTIONS)),
48 notify_shutdown,
49 shutdown_complete_tx,
50 };
51
52 tokio::select! {
53 res = server.run() => {
54 if let Err(err) = res {
55 error!(cause = %err, "failed to accept");
56 }
57 }
58 _ = shutdown => {
59 info!("shutting down");
60 if let Some(path) = &server.db_path {
61 info!("Saving database to {:?}", path);
62 server.db_holder.db().dump_to(path).await.expect("Failed to save database");
63 }
64 }
65 }
66
67 let Listener {
68 shutdown_complete_tx,
69 notify_shutdown,
70 ..
71 } = server;
72
73 drop(notify_shutdown);
74 drop(shutdown_complete_tx);
75
76 let _ = shutdown_complete_rx.recv().await;
77}
78
79impl Listener {
80 async fn run(&mut self) -> Result<()> {
81 info!("accepting inbound connections");
82
83 loop {
84 let permit = self.limit_connections.clone().acquire_owned().await.unwrap();
85 let socket = self.accept().await?;
86
87 let mut handler = Handler {
88 db: self.db_holder.db(),
89 connection: Connection::new(socket),
90 shutdown: Shutdown::new(self.notify_shutdown.subscribe()),
91 _shutdown_complete: self.shutdown_complete_tx.clone(),
92 };
93
94 tokio::spawn(async move {
95 if let Err(err) = handler.run().await {
96 error!(cause = ?err, "connection error");
97 }
98 drop(permit);
99 });
100 }
101 }
102
103 async fn accept(&mut self) -> crate::Result<TcpStream> {
104 let mut backoff = 1;
105
106 loop {
107 match self.listener.accept().await {
108 Ok((socket, _)) => return Ok(socket),
109 Err(err) => {
110 if backoff > 64 {
111 return Err(err.into());
112 }
113 }
114 }
115
116 time::sleep(Duration::from_secs(backoff)).await;
117 backoff *= 2;
118 }
119 }
120}
121
122impl Handler {
123 #[instrument(skip(self))]
124 async fn run(&mut self) -> Result<()> {
125 while !self.shutdown.is_shutdown() {
126 let maybe_frame = tokio::select! {
127 res = self.connection.read_frame() => res?,
128 _ = self.shutdown.recv() => {
129 return Ok(());
130 }
131 };
132
133 let frame = match maybe_frame {
134 Some(frame) => frame,
135 None => return Ok(()),
136 };
137
138 let cmd = Command::from_frame(frame)?;
139 debug!(?cmd);
140
141 cmd.apply(&self.db, &mut self.connection, &mut self.shutdown).await?;
142 }
143
144 Ok(())
145 }
146}