ironcladserver/
lib.rs

1use async_std::{prelude::*, task};
2use futures::stream::StreamExt;
3use std::{
4    error::Error,
5    fmt, fs,
6    io::prelude::*,
7    net::{TcpListener, TcpStream},
8    process,
9    sync::{mpsc, Arc, Mutex},
10    thread,
11    time::Duration,
12};
13pub mod cli;
14pub mod error;
15use crate::cli::{Config, ServerConfigArguments};
16
/// How a [`Server`] handles concurrent connections.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ServerConcurrency {
    /// Connections are served as tasks on the async runtime (`Server::start_async`).
    RunningAsync,
    /// Connections are dispatched to a fixed-size pool of worker threads (`Server::start_tp`).
    RunningThreadPool,
}
21pub struct Server {
22    ip_port: String,
23    pub concurrency: ServerConcurrency,
24    workers_pool: Option<ThreadPool>,
25}
26
27impl Server {
28    /// Reads a ip address, port and concurrency settings from Config (i.e. user cli input)
29    /// and returns the Server object
30    ///
31    pub fn init(config: Config) -> Result<Server, Box<dyn Error>> {
32        let ip_addr = config
33            .args_opts_map
34            .get(&ServerConfigArguments::IpAddress)
35            .unwrap();
36        let port = config
37            .args_opts_map
38            .get(&ServerConfigArguments::Port)
39            .unwrap();
40        let ip_port = format!("{}:{}", ip_addr, port);
41
42        let (concurrency, workers_pool) =
43            match config.args_opts_map.get(&ServerConfigArguments::ThreadPool) {
44                Some(value) => {
45                    let pool_size: usize = match value.parse() {
46                        Ok(size) => size,
47                        Err(_) => process::exit(0), // TODO: change this to an Error in error.rs
48                    };
49                    (
50                        ServerConcurrency::RunningThreadPool,
51                        Some(ThreadPool::new(pool_size)?),
52                    )
53                }
54                None => (ServerConcurrency::RunningAsync, None),
55            };
56        Ok(Server {
57            ip_port,
58            concurrency,
59            workers_pool,
60        })
61    }
62
63    /// Starts the server with a thread pool
64    pub fn start_tp(&self) -> Result<(), Box<dyn Error>> {
65        let listener = TcpListener::bind(&self.ip_port)?;
66        println!("Started the server with a thread pool.");
67
68        for stream in listener.incoming() {
69            let stream = stream?;
70            match &self.workers_pool {
71                Some(pool) => {
72                    pool.execute(|| {
73                        handle_connection_tp(stream);
74                    });
75                }
76                None => {
77                    process::exit(0); // TODO: change this to an Error in error.rs
78                }
79            }
80        }
81        Ok(())
82    }
83    /// Starts the server using async
84    ///
85    pub async fn start_async(&self) -> Result<(), Box<dyn Error>> {
86        let listener = async_std::net::TcpListener::bind(&self.ip_port).await?;
87        println!("Started the server and serving requests using async.");
88
89        listener
90            .incoming()
91            .for_each_concurrent(/* limit */ None, |tcpstream| async move {
92                let tcpstream = tcpstream.unwrap();
93                handle_connection_async(tcpstream).await;
94            })
95            .await;
96
97        Ok(())
98    }
99
100    /// Chaining POC
101    pub fn start1(self) -> Self {
102        println!("ttttest");
103        self
104    }
105}
106
107async fn handle_connection_async(mut stream: async_std::net::TcpStream) {
108    let mut buffer = [0; 1024];
109    match stream.read(&mut buffer).await {
110        Ok(_bytes_read) => {
111            //println!("Read {} bytes", bytes_read);
112
113            let get = b"GET / HTTP/1.1\r\n";
114            let sleep = b"GET /sleep HTTP/1.1\r\n";
115
116            let (status_line, filename) = if buffer.starts_with(get) {
117                ("HTTP/1.1 200 OK\r\n\r\n", r"resources\html\home.html")
118            } else if buffer.starts_with(sleep) {
119                task::sleep(Duration::from_secs(5)).await;
120                ("HTTP/1.1 200 OK\r\n\r\n", r"resources\html\home.html")
121            } else {
122                ("HTTP/1.1 404 NOT FOUND\r\n\r\n", r"resources\html\404.html")
123            };
124            let contents = fs::read_to_string(filename).unwrap();
125
126            let response = format!("{status_line}{contents}");
127            stream.write(response.as_bytes()).await.unwrap();
128            stream.flush().await.unwrap();
129        }
130        Err(e) => {
131            // Handle the error in some way.
132            eprintln!("Error reading from stream: {}", e);
133        }
134    }
135}
136
137fn handle_connection_tp(mut stream: TcpStream) {
138    let mut buffer = [0; 1024];
139    match stream.read(&mut buffer) {
140        Ok(_bytes_read) => {
141            //println!("Read {} bytes", bytes_read);
142
143            let get = b"GET / HTTP/1.1\r\n";
144            let sleep = b"GET /sleep HTTP/1.1\r\n";
145
146            let (status_line, filename) = if buffer.starts_with(get) {
147                ("HTTP/1.1 200 OK", r"resources\html\home.html")
148            } else if buffer.starts_with(sleep) {
149                thread::sleep(Duration::from_secs(5));
150                ("HTTP/1.1 200 OK", r"resources\html\home.html")
151            } else {
152                ("HTTP/1.1 404 NOT FOUND", r"resources\html\404.html")
153            };
154
155            let contents = fs::read_to_string(filename).unwrap();
156
157            let response = format!(
158                "{}\r\nContent-Length: {}\r\n\r\n{}",
159                status_line,
160                contents.len(),
161                contents
162            );
163
164            stream.write_all(response.as_bytes()).unwrap();
165            stream.flush().unwrap();
166        }
167        Err(e) => {
168            // Handle the error in some way.
169            eprintln!("Error reading from stream: {}", e);
170        }
171    }
172}
173
174pub struct ThreadPool {
175    workers: Vec<Worker>,
176    sender: Option<mpsc::Sender<Job>>,
177}
178
179type Job = Box<dyn FnOnce() + Send + 'static>;
180
/// Error returned by `ThreadPool::new` when the requested pool cannot be built.
#[derive(Debug)]
struct PoolCreationError {
    /// Human-readable reason; this is exactly what `Display` prints.
    msg: String,
}

impl PoolCreationError {
    /// Creates an error carrying a copy of `msg`.
    fn new(msg: &str) -> Self {
        Self {
            msg: msg.to_owned(),
        }
    }
}

impl fmt::Display for PoolCreationError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.msg)
    }
}

// Marker impl: the default `source`/`description` are sufficient, the message lives in `Display`.
impl Error for PoolCreationError {}
201
202impl ThreadPool {
203    /// Create a new ThreadPool.
204    ///
205    /// The size is the number of threads in the pool.
206    ///
207    /// # Panics
208    ///
209    /// The `new` function will panic if the size is zero.
210    fn new(size: usize) -> Result<ThreadPool, PoolCreationError> {
211        if size < 1 {
212            return Err(PoolCreationError::new(
213                "Number of workers in pool must be greater than 0.",
214            ));
215        }
216
217        let (sender, receiver) = mpsc::channel();
218
219        let receiver = Arc::new(Mutex::new(receiver));
220
221        let mut workers = Vec::with_capacity(size);
222
223        for id in 0..size {
224            workers.push(Worker::new(id, Arc::clone(&receiver)));
225        }
226
227        Ok(ThreadPool {
228            workers,
229            sender: Some(sender),
230        })
231    }
232
233    fn execute<F>(&self, f: F)
234    where
235        F: FnOnce() + Send + 'static,
236    {
237        let job = Box::new(f);
238
239        self.sender.as_ref().unwrap().send(job).unwrap();
240    }
241}
242
243impl Drop for ThreadPool {
244    fn drop(&mut self) {
245        drop(self.sender.take());
246
247        for worker in &mut self.workers {
248            println!("Shutting down worker {}", worker.id);
249
250            if let Some(thread) = worker.thread.take() {
251                thread.join().unwrap();
252            }
253        }
254    }
255}
256
257struct Worker {
258    id: usize,
259    thread: Option<thread::JoinHandle<()>>,
260}
261
262impl Worker {
263    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
264        let thread = thread::spawn(move || loop {
265            let message = receiver.lock().unwrap().recv();
266
267            match message {
268                Ok(job) => {
269                    println!("Worker {id} got a job; executing.");
270
271                    job();
272                }
273                Err(_) => {
274                    println!("Worker {id} disconnected; shutting down.");
275                    break;
276                }
277            }
278        });
279
280        Worker {
281            id,
282            thread: Some(thread),
283        }
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Binding to an address that is not a valid IP must make `start_tp` return an error.
    #[test]
    fn invalid_server_ip() {
        // "256.256.256.256" is rejected by IPv4 parsers and is not a resolvable host name, so the
        // bind fails on every platform. Shorthand forms such as "127.0.1" are accepted as
        // 127.0.0.1 by some resolvers (e.g. glibc). With those, the bind succeeds and `start_tp`
        // would serve forever, hanging the test.
        let mut config_args_opts_map: HashMap<ServerConfigArguments, String> = HashMap::new();
        config_args_opts_map.insert(
            ServerConfigArguments::IpAddress,
            String::from("256.256.256.256"),
        );
        config_args_opts_map.insert(ServerConfigArguments::Port, String::from("7878"));
        config_args_opts_map.insert(ServerConfigArguments::ThreadPool, String::from("10"));
        let test_config: Config = Config {
            program: "ironcladserver",
            command: cli::ServerCommand::Start,
            args_opts_map: config_args_opts_map,
        };

        let server = Server::init(test_config).unwrap();
        let result = server.start_tp();

        // The OS error text is platform-specific (Windows reports "No such host is known.
        // (os error 11001)"), so only the failure itself is asserted.
        assert!(result.is_err(), "binding to an invalid IP address should fail");
    }
}