coa_website/
server.rs

1use std::sync::Arc;
2use tokio::{io::{AsyncWriteExt, BufReader}, net::{TcpListener, TcpStream}, sync::Mutex};
3
4use crate::{config::Config, http::{HttpVersion, StatusCode}, logger::Logger, request::{HttpRequest, RequestError}, response::HttpResponse, router::Router};
5
6/// Asynchronous HTTP server that listens on a TCP socket and handles requests.
7/// It will serves static and dynamic routes via a `Router` and logs each request with a `Logger`
8pub struct Server<'a> {
9    addr: &'a str,
10    router: Arc<Mutex<Router>>,
11    logger: Arc<Mutex<Logger>>,
12}
13
14impl Default for Server<'_> {
15    fn default() -> Self {
16        Self { addr: "127.0.0.1:8080", router: Default::default(), logger: Default::default() }
17    }
18}
19
20impl<'a> Server<'a> {
21
22    /// Creates a new server bound to `addr`, reading the configuration from a file named **config.toml** or used the default configuration
23    /// 
24    /// # Errors
25    /// Returns an error if the configuration or logger initialization fails
26    pub fn with_address(addr: &'a str) -> Result<Self, String> {
27        let config = Config::from_file("config.toml").unwrap_or_default();
28        let logger = Logger::new(&config.log_file).map_err(|e| format!("Error with the log file : {:?}", e.kind()))?;
29        let router = Router::new(config, Default::default());
30        Ok(Self { addr, router: Arc::new(Mutex::new(router)), logger: Arc::new(Mutex::new(logger))})
31    }
32
33    /// Creates a new server bound to `addr`, reading the configuration `config_file`
34    /// 
35    /// # Errors
36    /// Returns an error if the configuration or logger initialization fails
37    pub fn with_address_and_config(addr: &'a str, config_file: &str) -> Result<Self, String> {
38        let config = Config::from_file(config_file)?;
39        let logger = Logger::new(&config.log_file).map_err(|e| format!("Error with the log file : {:?}", e.kind()))?;
40        let router = Router::new(config, Default::default());
41        Ok(Self { addr, router: Arc::new(Mutex::new(router)), logger: Arc::new(Mutex::new(logger))})
42    }
43
44    /// Starts the async server loop
45    pub async fn start(&mut self) -> std::io::Result<()> {
46        let listener = TcpListener::bind(self.addr).await?;
47        println!("Server listening on {}", self.addr);
48
49        loop {
50            let (stream, addr) = listener.accept().await?;
51            let ip = addr.ip().to_string();
52            let logger = Arc::clone(&self.logger);
53            let router = Arc::clone(&self.router);
54            
55
56            tokio::spawn(async move {
57                if let Err(e) = handle_client(stream, router, logger, ip).await {
58                    eprintln!("Error while handling client: {e}")
59                };
60            });
61        }
62    }
63}
64
65/// Handles a single TCP connection, it can process multiple HTTP requests if keep-alive is used
66pub async fn handle_client(stream: TcpStream, router: Arc<Mutex<Router>>, logger: Arc<Mutex<Logger>>, ip: String) -> std::io::Result<()> {
67    let mut req_count: u8 = 0;
68    let (reader, mut writer) = stream.into_split();
69    let mut reader = BufReader::new(reader);
70
71    loop {
72        let mut close = true;
73        let response = match HttpRequest::from_async_reader(&mut reader).await {
74            Ok(req) => {
75                let mut res = router.lock().await.handle_request(&req);
76                logger.lock().await.log(&ip, &format!("{:?}", req.method), &req.path, res.status.code())?;
77                let connection = req.header("connection").unwrap_or_default().to_lowercase();
78                if req.version != HttpVersion::Http1_0 && connection == "keep-alive" && req_count < 200 {
79                    req_count += 1;
80                    res.headers.insert("Connection".into(), "keep-alive".into());
81                    res.headers.insert("Keep-Alive".into(), "timeout=5, max=200".into());
82                    close = false;
83                } else {
84                    res.headers.insert("Connection".into(), "close".into());
85                }
86                res
87            },
88            Err(RequestError::MalformedRequest) => {
89                logger.lock().await.log(&ip, "", "", StatusCode::BadRequest.code())?;
90                HttpResponse::builder().with_status(StatusCode::BadRequest).build().unwrap()
91            },
92            Err(RequestError::UnsupportedVersion) => {
93                logger.lock().await.log(&ip, "", "", StatusCode::HttpVersionNotSupported.code())?;
94                HttpResponse::builder().with_status(StatusCode::HttpVersionNotSupported).build().unwrap()
95            },
96            Err(RequestError::UnsupportedMethod) => {
97                logger.lock().await.log(&ip, "", "", StatusCode::MethodNotAllowed.code())?;
98                HttpResponse::method_not_allowed().build().unwrap()
99            },
100            Err(_) => {
101                logger.lock().await.log(&ip, "", "", StatusCode::InternalServerError.code())?;
102                HttpResponse::internal_server_error().build().unwrap()
103            }
104        };
105    
106        writer.write_all(&response.to_bytes()).await?;
107        writer.flush().await?;
108
109        
110        if close {
111            break;
112        }
113    }
114
115    Ok(())
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    #[test]
    fn test_server_with_invalid_config() {
        // Construction is synchronous, so no async runtime is needed.
        let res = Server::with_address_and_config("127.0.0.1:0", "no_config_file.toml");
        assert!(res.is_err());
    }

    #[test]
    fn test_default_init() {
        // Constructs a server from an explicit config file. Construction also creates the
        // configured log file.
        let config_file = "test_default_config.toml";
        let log_file = "test_log.log";
        fs::write(config_file, "root_folder = './public'\nindex_file='index.html'\nnot_found_file='404.html'\nlog_file='test_log.log'\n[mime_types]\nhtml='text/html'").unwrap();

        // Drop the server right away so that the log file is closed before it is removed.
        let result = Server::with_address_and_config("127.0.0.1:0", config_file).map(drop);

        // Clean up before asserting, so that a failure does not leave the files behind.
        let _ = fs::remove_file(config_file);
        let _ = fs::remove_file(log_file);
        assert_eq!(result, Ok(()));
    }
}