jupiter_rs/
server.rs

1use std::sync::{Arc, Mutex};
2use std::time::Duration;
3
4use bytes::BytesMut;
5use tokio::io::{AsyncReadExt, AsyncWriteExt};
6use tokio::net::{TcpListener, TcpStream};
7use tokio::stream::StreamExt;
8
9use crate::commands::CommandDictionary;
10use crate::flag::Flag;
11use crate::platform::Platform;
12use crate::request::Request;
13use crate::watch::{Average, Watch};
14
15const READ_WAIT_TIMEOUT: Duration = Duration::from_millis(500);
16const DEFAULT_BUFFER_SIZE: usize = 8192;
17
/// Bookkeeping record for a single client connection accepted by the server.
struct Connection {
    /// Remote address of the client as text, or `"<unknown>"` when it cannot be determined.
    peer_address: String,
    /// Receives the duration (in microseconds) of each successfully dispatched command.
    commands: Average,
    /// Cleared by `Server::kill`; the client loop stops once this reads `false`.
    active: Flag,
}
23
24impl PartialEq for Connection {
25    fn eq(&self, other: &Self) -> bool {
26        self.peer_address == other.peer_address
27    }
28}
29
30impl Connection {
31    fn is_active(&self) -> bool {
32        self.active.read()
33    }
34}
35
/// Snapshot of a connection's public details, as returned by `Server::connections`.
pub struct ConnectionInfo {
    /// Peer address copied from the underlying connection.
    pub peer_address: String,
    /// Clone of the connection's command timing statistics.
    pub commands: Average,
}
40
/// TCP server which accepts client connections and tracks them while they are served.
pub struct Server {
    /// Set to `true` by the event loop once it has picked the address to bind.
    running: Flag,
    /// Address of the most recently opened server socket; `None` until the first successful bind.
    current_address: Mutex<Option<String>>,
    /// Platform the server is registered with; its `is_running` flag controls shutdown.
    platform: Arc<Platform>,
    /// Connections which are currently being served.
    connections: Mutex<Vec<Arc<Connection>>>,
}
47
48impl Server {
49    pub fn install(platform: &Arc<Platform>) -> Arc<Self> {
50        let server = Arc::new(Server {
51            running: Flag::new(false),
52            current_address: Mutex::new(None),
53            platform: platform.clone(),
54            connections: Mutex::new(Vec::new()),
55        });
56
57        platform.register::<Server>(server.clone());
58
59        server
60    }
61
62    pub fn connections(&self) -> Vec<ConnectionInfo> {
63        let mut result = Vec::new();
64        for connection in self.connections.lock().unwrap().iter() {
65            result.push(ConnectionInfo {
66                peer_address: connection.peer_address.clone(),
67                commands: connection.commands.clone(),
68            });
69        }
70
71        result
72    }
73
74    pub fn kill(&self, peer_address: &str) -> bool {
75        self.connections
76            .lock()
77            .unwrap()
78            .iter()
79            .find(|c| &c.peer_address == peer_address)
80            .map(|c| c.active.change(false))
81            .is_some()
82    }
83
84    fn add_connection(&self, connection: Arc<Connection>) {
85        self.connections.lock().unwrap().push(connection);
86    }
87
88    fn remove_connection(&self, connection: Arc<Connection>) {
89        let mut mut_connections = self.connections.lock().unwrap();
90        if let Some(index) = mut_connections
91            .iter()
92            .position(|other| *other == connection)
93        {
94            mut_connections.remove(index);
95        }
96    }
97
98    fn is_running(&self) -> bool {
99        self.running.read()
100    }
101
102    fn address(&self) -> String {
103        "0.0.0.0:2410".to_owned()
104        // format!("{}:{}",
105        //         self.config.string("server.host", "0.0.0.0"),
106        //         self.config.checked_number("server.port",
107        //                                    |port| port >= 0 && port <= std::u16::MAX as i64,
108        //                                    2410))
109    }
110
111    pub fn fork(server: &Arc<Server>) {
112        let cloned_server = server.clone();
113        tokio::spawn(async move {
114            cloned_server.event_loop().await;
115        });
116    }
117
118    pub async fn fork_and_await(server: &Arc<Server>) {
119        Server::fork(server);
120
121        while !server.running.read() {
122            tokio::time::delay_for(Duration::from_secs(1)).await;
123        }
124    }
125
126    pub async fn event_loop(&self) {
127        let mut address = String::new();
128
129        while self.platform.is_running.read() {
130            if !self.is_running() {
131                address = self.address();
132                self.running.change(true);
133            }
134
135            if let Ok(mut listener) = TcpListener::bind(&address).await {
136                log::info!("Opened server socket on {}...", &address);
137                *self.current_address.lock().unwrap() = Some(address.clone());
138                self.server_loop(&mut listener).await;
139                log::info!("Closing server socket on {}.", &address);
140            } else {
141                log::error!(
142                    "Cannot open server address: {}... Retrying in 5s.",
143                    &address
144                );
145                tokio::time::delay_for(Duration::from_secs(5)).await;
146            }
147        }
148    }
149
150    async fn server_loop(&self, listener: &mut TcpListener) {
151        let mut incoming = listener.incoming();
152        let mut kill_flag = self.platform.is_running.listener();
153        // let mut config_changed_flag = platform.config().on_change();
154
155        while self.platform.is_running.read() && self.is_running() {
156            tokio::select! {
157                stream = incoming.next() => {
158                    if let Some(Ok(stream)) = stream {
159                        self.handle_new_connection(stream);
160                    } else {
161                        return;
162                    }
163                }
164                // _ = config_changed_flag.recv() => {
165                //      let new_address = server.address();
166                //      if let Some(current_address) = &*server.current_address.lock().unwrap() {
167                //         if current_address != &new_address {
168                //             info!("Server address has changed. Restarting server socket...");
169                //             server.running.store(false, Ordering::SeqCst);
170                //             return;
171                //         }
172                //      }
173                // }
174                _ = kill_flag.expect() => { return; }
175            }
176        }
177    }
178
179    fn handle_new_connection(&self, stream: TcpStream) {
180        let connection = Arc::new(Connection {
181            peer_address: stream
182                .peer_addr()
183                .map(|addr| addr.to_string())
184                .unwrap_or("<unknown>".to_string()),
185            commands: Average::new(),
186            active: Flag::new(true),
187        });
188
189        let platform = self.platform.clone();
190        tokio::spawn(async move {
191            let server = platform.require::<Server>();
192            log::info!("Opened connection from {}...", connection.peer_address);
193            server.add_connection(connection.clone());
194            if let Err(error) = Server::client_loop(platform, connection.clone(), stream).await {
195                log::warn!(
196                    "An IO error occurred in connection {}: {}",
197                    connection.peer_address,
198                    error
199                );
200            }
201            log::info!("Closing connection to {}...", connection.peer_address);
202            server.remove_connection(connection);
203        });
204    }
205
206    async fn client_loop(
207        platform: Arc<Platform>,
208        connection: Arc<Connection>,
209        mut stream: TcpStream,
210    ) -> anyhow::Result<()> {
211        let mut dispatcher = platform.require::<CommandDictionary>().dispatcher();
212        let mut input_buffer = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE);
213        let (mut reader, mut writer) = stream.split();
214
215        while platform.is_running.read() && connection.is_active() {
216            match tokio::time::timeout(READ_WAIT_TIMEOUT, reader.read_buf(&mut input_buffer)).await
217            {
218                Ok(Ok(bytes_read)) if bytes_read > 0 => match Request::parse(&mut input_buffer) {
219                    Ok(Some(request)) => {
220                        let watch = Watch::start();
221                        match dispatcher.invoke(request).await {
222                            Ok(response_data) => {
223                                connection.commands.add(watch.micros());
224                                writer.write_all(response_data.as_ref()).await?;
225                                writer.flush().await?;
226                            }
227                            Err(error) => {
228                                let error_message = error
229                                    .to_string()
230                                    .replace("\r", " ")
231                                    .to_string()
232                                    .replace("\n", " ");
233                                writer
234                                    .write_all(format!("-SERVER: {}\r\n", error_message).as_bytes())
235                                    .await?;
236                                writer.flush().await?;
237                                return Ok(());
238                            }
239                        }
240
241                        if input_buffer.len() == 0 && input_buffer.capacity() > DEFAULT_BUFFER_SIZE
242                        {
243                            input_buffer = BytesMut::with_capacity(DEFAULT_BUFFER_SIZE);
244                        }
245                    }
246                    Err(error) => return Err(error),
247                    _ => (),
248                },
249                Ok(Ok(0)) => return Ok(()),
250                Ok(Err(error)) => {
251                    return Err(anyhow::anyhow!(
252                        "An error occurred while reading from the client: {}",
253                        error
254                    ));
255                }
256                _ => (),
257            }
258        }
259
260        Ok(())
261    }
262}
263
#[cfg(test)]
mod tests {
    use tokio::time::Duration;

    use crate::commands::CommandDictionary;
    use crate::ping;
    use crate::platform::Platform;
    use crate::server::Server;
    use crate::watch::Watch;

    /// Starts a server, sends a PING using a blocking redis client and expects a PONG.
    #[test]
    fn commands_can_be_issued() {
        tokio_test::block_on(async {
            let platform = Platform::new();
            CommandDictionary::install(&platform);
            ping::install(&platform);

            let server = Server::install(&platform);
            Server::fork_and_await(&server).await;

            // The redis client used here is blocking, therefore it runs on the blocking pool.
            let client_task = tokio::task::spawn_blocking(move || {
                let client = redis::Client::open("redis://127.0.0.1:2410/").unwrap();
                let mut connection = client
                    .get_connection_with_timeout(Duration::from_secs(10))
                    .unwrap();

                let watch = Watch::start();
                let reply: String = redis::cmd("PING").query(&mut connection).unwrap();
                assert_eq!(reply, "PONG");
                println!("Request took: {}", watch);

                // Shut the platform down so that the server terminates.
                platform.is_running.change(false);
            });

            client_task.await.unwrap();
        });
    }
}