1use std::sync::Arc;
2use tokio::{io::{AsyncWriteExt, BufReader}, net::{TcpListener, TcpStream}, sync::Mutex};
3
4use crate::{config::Config, http::{HttpVersion, StatusCode}, logger::Logger, request::{HttpRequest, RequestError}, response::HttpResponse, router::Router};
5
6pub struct Server<'a> {
9 addr: &'a str,
10 router: Arc<Mutex<Router>>,
11 logger: Arc<Mutex<Logger>>,
12}
13
14impl Default for Server<'_> {
15 fn default() -> Self {
16 Self { addr: "127.0.0.1:8080", router: Default::default(), logger: Default::default() }
17 }
18}
19
20impl<'a> Server<'a> {
21
22 pub fn with_address(addr: &'a str) -> Result<Self, String> {
27 let config = Config::from_file("config.toml").unwrap_or_default();
28 let logger = Logger::new(&config.log_file).map_err(|e| format!("Error with the log file : {:?}", e.kind()))?;
29 let router = Router::new(config, Default::default());
30 Ok(Self { addr, router: Arc::new(Mutex::new(router)), logger: Arc::new(Mutex::new(logger))})
31 }
32
33 pub fn with_address_and_config(addr: &'a str, config_file: &str) -> Result<Self, String> {
38 let config = Config::from_file(config_file)?;
39 let logger = Logger::new(&config.log_file).map_err(|e| format!("Error with the log file : {:?}", e.kind()))?;
40 let router = Router::new(config, Default::default());
41 Ok(Self { addr, router: Arc::new(Mutex::new(router)), logger: Arc::new(Mutex::new(logger))})
42 }
43
44 pub async fn start(&mut self) -> std::io::Result<()> {
46 let listener = TcpListener::bind(self.addr).await?;
47 println!("Server listening on {}", self.addr);
48
49 loop {
50 let (stream, addr) = listener.accept().await?;
51 let ip = addr.ip().to_string();
52 let logger = Arc::clone(&self.logger);
53 let router = Arc::clone(&self.router);
54
55
56 tokio::spawn(async move {
57 if let Err(e) = handle_client(stream, router, logger, ip).await {
58 eprintln!("Error while handling client: {e}")
59 };
60 });
61 }
62 }
63}
64
65pub async fn handle_client(stream: TcpStream, router: Arc<Mutex<Router>>, logger: Arc<Mutex<Logger>>, ip: String) -> std::io::Result<()> {
67 let mut req_count: u8 = 0;
68 let (reader, mut writer) = stream.into_split();
69 let mut reader = BufReader::new(reader);
70
71 loop {
72 let mut close = true;
73 let response = match HttpRequest::from_async_reader(&mut reader).await {
74 Ok(req) => {
75 let mut res = router.lock().await.handle_request(&req);
76 logger.lock().await.log(&ip, &format!("{:?}", req.method), &req.path, res.status.code())?;
77 let connection = req.header("connection").unwrap_or_default().to_lowercase();
78 if req.version != HttpVersion::Http1_0 && connection == "keep-alive" && req_count < 200 {
79 req_count += 1;
80 res.headers.insert("Connection".into(), "keep-alive".into());
81 res.headers.insert("Keep-Alive".into(), "timeout=5, max=200".into());
82 close = false;
83 } else {
84 res.headers.insert("Connection".into(), "close".into());
85 }
86 res
87 },
88 Err(RequestError::MalformedRequest) => {
89 logger.lock().await.log(&ip, "", "", StatusCode::BadRequest.code())?;
90 HttpResponse::builder().with_status(StatusCode::BadRequest).build().unwrap()
91 },
92 Err(RequestError::UnsupportedVersion) => {
93 logger.lock().await.log(&ip, "", "", StatusCode::HttpVersionNotSupported.code())?;
94 HttpResponse::builder().with_status(StatusCode::HttpVersionNotSupported).build().unwrap()
95 },
96 Err(RequestError::UnsupportedMethod) => {
97 logger.lock().await.log(&ip, "", "", StatusCode::MethodNotAllowed.code())?;
98 HttpResponse::method_not_allowed().build().unwrap()
99 },
100 Err(_) => {
101 logger.lock().await.log(&ip, "", "", StatusCode::InternalServerError.code())?;
102 HttpResponse::internal_server_error().build().unwrap()
103 }
104 };
105
106 writer.write_all(&response.to_bytes()).await?;
107 writer.flush().await?;
108
109
110 if close {
111 break;
112 }
113 }
114
115 Ok(())
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Construction must fail when given a config path that is not expected to exist.
    #[tokio::test]
    async fn test_server_with_invalid_config() {
        let res = Server::with_address_and_config("127.0.0.1:0", "no_config_file.toml");
        assert!(res.is_err());
    }

    /// Construction must succeed with a freshly written configuration file.
    #[test]
    fn test_default_init() {
        let config_file = "test_default_config.toml";
        fs::write(config_file, "root_folder = './public'\nindex_file='index.html'\nnot_found_file='404.html'\nlog_file='test_log.log'\n[mime_types]\nhtml='text/html'").unwrap();

        // Keep only the outcome so the server (and its logger) is dropped
        // before the files are removed.
        let built = Server::with_address_and_config("127.0.0.1:0", config_file).is_ok();

        // Clean up before asserting so a failure does not leave fixtures behind.
        let _ = fs::remove_file(config_file);
        let _ = fs::remove_file("test_log.log");

        assert!(built);
    }
}
139}