1use std::sync::{Arc, Mutex};
2use std::time::Duration;
3
4use bytes::BytesMut;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{TcpListener, TcpStream};
7use tokio::stream::StreamExt;
8
9use crate::commands::CommandDictionary;
10use crate::flag::Flag;
11use crate::platform::Platform;
12use crate::request::Request;
13use crate::watch::{Average, Watch};
14
/// Maximum time a single read from a client may block before the client loop re-checks
/// whether the platform is still running and the connection is still active.
const READ_WAIT_TIMEOUT: Duration = Duration::from_millis(500);

/// Initial capacity (in bytes) of a connection's input buffer; buffers which grew beyond this
/// are replaced by a fresh one of this size once they have been drained.
const DEFAULT_BUFFER_SIZE: usize = 8 * 1024;
17
18struct Connection {
19 peer_address: String,
20 commands: Average,
21 active: Flag,
22}
23
24impl PartialEq for Connection {
25 fn eq(&self, other: &Self) -> bool {
26 self.peer_address == other.peer_address
27 }
28}
29
30impl Connection {
31 fn is_active(&self) -> bool {
32 self.active.read()
33 }
34}
35
36pub struct ConnectionInfo {
37 pub peer_address: String,
38 pub commands: Average,
39}
40
41pub struct Server {
42 running: Flag,
43 current_address: Mutex<Option<String>>,
44 platform: Arc<Platform>,
45 connections: Mutex<Vec<Arc<Connection>>>,
46}
47
48impl Server {
49 pub fn install(platform: &Arc<Platform>) -> Arc<Self> {
50 let server = Arc::new(Server {
51 running: Flag::new(false),
52 current_address: Mutex::new(None),
53 platform: platform.clone(),
54 connections: Mutex::new(Vec::new()),
55 });
56
57 platform.register::<Server>(server.clone());
58
59 server
60 }
61
62 pub fn connections(&self) -> Vec<ConnectionInfo> {
63 let mut result = Vec::new();
64 for connection in self.connections.lock().unwrap().iter() {
65 result.push(ConnectionInfo {
66 peer_address: connection.peer_address.clone(),
67 commands: connection.commands.clone(),
68 });
69 }
70
71 result
72 }
73
74 pub fn kill(&self, peer_address: &str) -> bool {
75 self.connections
76 .lock()
77 .unwrap()
78 .iter()
79 .find(|c| &c.peer_address == peer_address)
80 .map(|c| c.active.change(false))
81 .is_some()
82 }
83
84 fn add_connection(&self, connection: Arc<Connection>) {
85 self.connections.lock().unwrap().push(connection);
86 }
87
88 fn remove_connection(&self, connection: Arc<Connection>) {
89 let mut mut_connections = self.connections.lock().unwrap();
90 if let Some(index) = mut_connections
91 .iter()
92 .position(|other| *other == connection)
93 {
94 mut_connections.remove(index);
95 }
96 }
97
98 fn is_running(&self) -> bool {
99 self.running.read()
100 }
101
102 fn address(&self) -> String {
103 "0.0.0.0:2410".to_owned()
104 }
110
111 pub fn fork(server: &Arc<Server>) {
112 let cloned_server = server.clone();
113 tokio::spawn(async move {
114 cloned_server.event_loop().await;
115 });
116 }
117
118 pub async fn fork_and_await(server: &Arc<Server>) {
119 Server::fork(server);
120
121 while !server.running.read() {
122 tokio::time::delay_for(Duration::from_secs(1)).await;
123 }
124 }
125
126 pub async fn event_loop(&self) {
127 let mut address = String::new();
128
129 while self.platform.is_running.read() {
130 if !self.is_running() {
131 address = self.address();
132 self.running.change(true);
133 }
134
135 if let Ok(mut listener) = TcpListener::bind(&address).await {
136 log::info!("Opened server socket on {}...", &address);
137 *self.current_address.lock().unwrap() = Some(address.clone());
138 self.server_loop(&mut listener).await;
139 log::info!("Closing server socket on {}.", &address);
140 } else {
141 log::error!(
142 "Cannot open server address: {}... Retrying in 5s.",
143 &address
144 );
145 tokio::time::delay_for(Duration::from_secs(5)).await;
146 }
147 }
148 }
149
150 async fn server_loop(&self, listener: &mut TcpListener) {
151 let mut incoming = listener.incoming();
152 let mut kill_flag = self.platform.is_running.listener();
153 while self.platform.is_running.read() && self.is_running() {
156 tokio::select! {
157 stream = incoming.next() => {
158 if let Some(Ok(stream)) = stream {
159 self.handle_new_connection(stream);
160 } else {
161 return;
162 }
163 }
164 _ = kill_flag.expect() => { return; }
175 }
176 }
177 }
178
179 fn handle_new_connection(&self, stream: TcpStream) {
180 let connection = Arc::new(Connection {
181 peer_address: stream
182 .peer_addr()
183 .map(|addr| addr.to_string())
184 .unwrap_or("<unknown>".to_string()),
185 commands: Average::new(),
186 active: Flag::new(true),
187 });
188
189 let platform = self.platform.clone();
190 tokio::spawn(async move {
191 let server = platform.require::<Server>();
192 log::info!("Opened connection from {}...", connection.peer_address);
193 server.add_connection(connection.clone());
194 if let Err(error) = Server::client_loop(platform, connection.clone(), stream).await {
195 log::warn!(
196 "An IO error occurred in connection {}: {}",
197 connection.peer_address,
198 error
199 );
200 }
201 log::info!("Closing connection to {}...", connection.peer_address);
202 server.remove_connection(connection);
203 });
204 }
205
206 async fn client_loop(
207 platform: Arc<Platform>,
208 connection: Arc<Connection>,
209 mut stream: TcpStream,
210 ) -> anyhow::Result<()> {
211 let mut dispatcher = platform.require::<CommandDictionary>().dispatcher();
212 let mut input_buffer = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE);
213 let (mut reader, mut writer) = stream.split();
214
215 while platform.is_running.read() && connection.is_active() {
216 match tokio::time::timeout(READ_WAIT_TIMEOUT, reader.read_buf(&mut input_buffer)).await
217 {
218 Ok(Ok(bytes_read)) if bytes_read > 0 => match Request::parse(&mut input_buffer) {
219 Ok(Some(request)) => {
220 let watch = Watch::start();
221 match dispatcher.invoke(request).await {
222 Ok(response_data) => {
223 connection.commands.add(watch.micros());
224 writer.write_all(response_data.as_ref()).await?;
225 writer.flush().await?;
226 }
227 Err(error) => {
228 let error_message = error
229 .to_string()
230 .replace("\r", " ")
231 .to_string()
232 .replace("\n", " ");
233 writer
234 .write_all(format!("-SERVER: {}\r\n", error_message).as_bytes())
235 .await?;
236 writer.flush().await?;
237 return Ok(());
238 }
239 }
240
241 if input_buffer.len() == 0 && input_buffer.capacity() > DEFAULT_BUFFER_SIZE
242 {
243 input_buffer = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE);
244 }
245 }
246 Err(error) => return Err(error),
247 _ => (),
248 },
249 Ok(Ok(0)) => return Ok(()),
250 Ok(Err(error)) => {
251 return Err(anyhow::anyhow!(
252 "An error occurred while reading from the client: {}",
253 error
254 ));
255 }
256 _ => (),
257 }
258 }
259
260 Ok(())
261 }
262}
263
#[cfg(test)]
mod tests {
    use tokio::time::Duration;

    use crate::commands::CommandDictionary;
    use crate::ping;
    use crate::platform::Platform;
    use crate::server::Server;
    use crate::watch::Watch;

    /// Boots a platform with the PING command and the server, sends a single PING through a
    /// real redis client and expects PONG as answer.
    #[test]
    fn commands_can_be_issued() {
        tokio_test::block_on(async {
            let platform = Platform::new();
            CommandDictionary::install(&platform);
            ping::install(&platform);

            let server = Server::install(&platform);
            Server::fork_and_await(&server).await;

            // The redis client is blocking, so it is kept off the async executor.
            let client_task = tokio::task::spawn_blocking(move || {
                let client = redis::Client::open("redis://127.0.0.1:2410/").unwrap();
                let mut connection = client
                    .get_connection_with_timeout(Duration::from_secs(10))
                    .unwrap();

                let watch = Watch::start();
                let pong: String = redis::cmd("PING").query(&mut connection).unwrap();
                assert_eq!(pong, "PONG");
                println!("Request took: {}", watch);

                // Stops the server (and the platform) once the check is done.
                platform.is_running.change(false);
            });

            client_task.await.unwrap();
        });
    }
}