choki/
lib.rs

1use bumpalo::Bump;
2use src::utils::logger;
3
4use std::cmp::Ordering;
5use std::collections::HashMap;
6use std::fmt::write;
7use std::fs::File;
8use std::hash::Hash;
9use std::path::{ self, Path };
10use std::time::{ Duration, Instant };
11use std::{ fs, io, thread, vec };
12use std::{ io::Write, net::* };
13
14use std::io::{ BufRead, BufReader, Error, Read };
15use structs::*;
16use threadpool::ThreadPool;
17
18extern crate num_cpus;
19
20pub mod src;
21
22use src::request::Request;
23use src::response::Response;
24use src::*;
25
26pub struct Server<T: Clone + std::marker::Send + 'static> {
27    active: bool,
28    pub max_content_length: usize,
29    pub endpoints: Vec<EndPoint<T>>,
30    pub static_endpoints: HashMap<String, String>,
31
32    pub public_var: Option<T>,
33
34    middleware: Option<
35        fn(url: &Url, req: &Request, res: &mut Response, public_var: &Option<T>) -> bool
36    >,
37    logger: Option<fn(input: &HttpServerError)>,
38}
39
40impl<T: Clone + std::marker::Send + 'static> Server<T> {
41    ///max_content_length is the max length of the request in bytes.
42    ///
43    ///For example if the max is set to 1024 but the request is 1 000 000 it will close it straight away.
44    pub fn new(max_content_length: Option<usize>, public_var: Option<T>) -> Server<T> {
45        return Server {
46            active: false,
47            max_content_length: max_content_length.unwrap_or_default(),
48            endpoints: Vec::new(),
49            static_endpoints: HashMap::new(),
50            public_var: public_var,
51            middleware: None,
52            logger: None,
53        };
54    }
55    ///Add function as middleware (just before sending response). The response is a bool. If it's true, the request will continue; if it's false, it will stop.
56    pub fn use_middleware(
57        &mut self,
58        handle: fn(url: &Url, req: &Request, res: &mut Response, public_var: &Option<T>) -> bool
59    ) {
60        self.middleware = Some(handle);
61    }
62    ///Add your own custom error logger function
63    pub fn use_logger(&mut self, handle: fn(input: &HttpServerError)) {
64        self.logger = Some(handle);
65    }
66    ///Creates a new static url
67    /// For example a folder named "images" on path /images every image in that folder will be exposed like "/images/example.png"
68    pub fn new_static(&mut self, path: &str, folder: &str) -> Result<(), HttpServerError> {
69        if self.active == true {
70            return Err(HttpServerError::new("Server is already running!"));
71        }
72        let path_: &Path = Path::new(&folder);
73
74        if path_.is_dir() == false || path_.exists() == false {
75            return Err(
76                HttpServerError::new("Folder does not exist or the path provided is a file!")
77            );
78        }
79        let mut path = path.to_owned();
80        if
81            (self.endpoints.len() > 0 &&
82                self.endpoints.iter().any(|x| x.path == path && x.req_type == RequestType::Get)) ||
83            (self.static_endpoints.len() > 0 && self.static_endpoints.iter().any(|x| x.0 == &path))
84        {
85            return Err(HttpServerError::new("Endpoint already exists!"));
86        }
87        if path.len() > 1 && path.ends_with("/") {
88            path.remove(path.len() - 1);
89        }
90        self.static_endpoints.insert(path, folder.to_owned());
91        Ok(())
92    }
93    fn new_endpoint(
94        &mut self,
95        path: &str,
96        req_type: RequestType,
97        handle: fn(req: Request, res: Response, public_var: Option<T>)
98    ) -> Result<(), HttpServerError> {
99        if self.active == true {
100            return Err(HttpServerError::new("Server is already running!"));
101        }
102        let mut path = path.to_owned();
103        if
104            (self.endpoints.len() > 0 &&
105                self.endpoints.iter().any(|x| x.path == path && x.req_type == req_type)) ||
106            (self.static_endpoints.len() > 0 && self.static_endpoints.iter().any(|x| x.0 == &path))
107        {
108            return Err(HttpServerError::new("Endpoint already exists!"));
109        }
110        if path.len() > 1 && path.ends_with("/") {
111            path.remove(path.len() - 1);
112        }
113        self.endpoints.push(EndPoint::new(path, req_type, handle));
114        Ok(())
115    }
116
117    ///Creates a new GET endpoint
118    pub fn get(
119        &mut self,
120        path: &str,
121        handle: fn(req: Request, res: Response, public_var: Option<T>)
122    ) -> Result<(), HttpServerError> {
123        self.new_endpoint(path, RequestType::Get, handle)
124    }
125
126    ///Creates a new POST endpoint
127    pub fn post(
128        &mut self,
129        path: &str,
130        handle: fn(req: Request, res: Response, public_var: Option<T>)
131    ) -> Result<(), HttpServerError> {
132        self.new_endpoint(path, RequestType::Post, handle)
133    }
134    ///Creates a new PUT endpoint
135    pub fn put(
136        &mut self,
137        path: &str,
138        handle: fn(req: Request, res: Response, public_var: Option<T>)
139    ) -> Result<(), HttpServerError> {
140        self.new_endpoint(path, RequestType::Put, handle)
141    }
142    ///Creates a new DELETE endpoint
143    pub fn delete(
144        &mut self,
145        path: &str,
146        handle: fn(req: Request, res: Response, public_var: Option<T>)
147    ) -> Result<(), HttpServerError> {
148        self.new_endpoint(path, RequestType::Delete, handle)
149    }
150
151    ///Universal endpoint creator
152    pub fn on(
153        &mut self,
154        req_type: RequestType,
155        path: &str,
156        handle: fn(req: Request, res: Response, public_var: Option<T>)
157    ) -> Result<(), HttpServerError> {
158        self.new_endpoint(path, req_type, handle)
159    }
160    ///Starts listening on the given port.
161    /// If no provided threads will use cpu threads as value. The higher the value the higher the cpu usage.
162    /// The on_complete function is executed after the listener has started.
163    pub fn listen(
164        &mut self,
165        port: u32,
166        address: Option<&str>,
167        threads: Option<usize>,
168        on_complete: fn()
169    ) -> Result<(), HttpServerError> {
170        if port > 65_535 {
171            return Err(HttpServerError::new("Invalid port: port must be 0-65,535"));
172        }
173        if self.active == true {
174            return Err(HttpServerError::new("The server is already running!"));
175        }
176        self.active = true;
177
178        let pool: ThreadPool = ThreadPool::new(threads.unwrap_or(num_cpus::get()));
179        let mut routes = self.endpoints.clone();
180        let static_routes = self.static_endpoints.clone();
181
182        let middleware = self.middleware.clone();
183        let logger = self.logger.unwrap_or(logger::eprint);
184
185        let max_content_length = self.max_content_length.clone();
186        let public_var = self.public_var.clone();
187
188        let address = address.unwrap_or("0.0.0.0").to_owned();
189
190        order_routes(&mut routes); // Order them so the first one are without params
191
192        thread::spawn(move || {
193            let tcp: TcpListener = TcpListener::bind(format!("{}:{}", address, port)).unwrap();
194
195            for stream in tcp.incoming() {
196                let routes_clone = routes.clone();
197                let static_routes_clone = static_routes.clone();
198                let max_content_length_clone = max_content_length.clone();
199                let public_var_clone = public_var.clone();
200
201                pool.execute(move || {
202                    let stream = stream.unwrap();
203                    let res = Self::handle_request(
204                        stream,
205                        max_content_length_clone,
206                        routes_clone,
207                        static_routes_clone,
208                        middleware,
209
210                        public_var_clone
211                    );
212                    if res.is_err() {
213                        logger(&res.unwrap_err());
214                    }
215                });
216            }
217        });
218
219        on_complete();
220
221        Ok(())
222    }
223    fn handle_request(
224        stream: TcpStream,
225        max_content_length: usize,
226        routes: Vec<EndPoint<T>>,
227        static_routes: HashMap<String, String>,
228        middleware: Option<
229            fn(url: &Url, req: &Request, res: &mut Response, public_var: &Option<T>) -> bool
230        >,
231        public_var: Option<T>
232    ) -> Result<(), HttpServerError> {
233        let bump = Bump::new(); // Allocator
234
235        let mut bfreader: BufReader<TcpStream> = BufReader::new(
236            stream.try_clone().expect("Failed to create buffer reader")
237        );
238
239        let mut headers_string: String = "".to_string();
240
241        let mut line = "".to_owned();
242        loop {
243            match bfreader.read_line(&mut line) {
244                Ok(size) => {
245                    if size <= 2 {
246                        break;
247                    }
248                }
249                Err(e) => {
250                    return Err(HttpServerError::new("Error reading request headers!"));
251                }
252            }
253
254            headers_string.push_str(&line);
255
256            line = "".to_string();
257        }
258
259        let lines: Vec<&str> = headers_string.lines().collect();
260
261        if lines.len() == 0 {
262            return Err(HttpServerError::new("No headers!"));
263        }
264        let req_url = Url::parse(lines[0]).unwrap();
265
266        let mut req = Request::parse(&lines, Some(req_url.query), None)?;
267
268        if let Some(socket) = stream.peer_addr().ok() {
269            req.ip = Some(socket.ip().to_string());
270        }
271        let content_encoding = req.content_encoding.clone();
272        let mut res = Response::new(stream.try_clone().unwrap(), content_encoding.clone());
273        // Check if supported req type
274        let content_type = req.content_type.clone().unwrap_or(ContentType::None);
275
276        let has_body = content_type != ContentType::None && req.content_length > 0;
277
278        if req_url.req_type == RequestType::Unknown {
279            if has_body {
280                req.read_only_body(&mut bfreader);
281            }
282
283            res.send_code(ResponseCode::MethodNotAllowed);
284            return Err(HttpServerError::new("Method not allowed!"));
285        }
286        // Check if body in GET or HEAD
287        if
288            has_body &&
289            (req_url.req_type == RequestType::Get || req_url.req_type == RequestType::Head)
290        {
291            req.read_only_body(&mut bfreader);
292            res.send_code(ResponseCode::BadRequest);
293            return Err(HttpServerError::new("Bad request!"));
294        }
295        //Check if over content length
296        if max_content_length > 0 && req.content_length > max_content_length && has_body {
297            req.read_only_body(&mut bfreader);
298            res.send_code(ResponseCode::ContentTooLarge);
299            return Err(HttpServerError::new("Content too large!"));
300        }
301        // Middleware
302        if middleware.is_some() {
303            let result = middleware.unwrap()(
304                &(Url {
305                    path: req_url.path.clone(),
306                    req_type: req_url.req_type.clone(),
307                    query: HashMap::new(),
308                }),
309                &req,
310                &mut res,
311                &public_var
312            );
313            if result == false {
314                return Ok(());
315            }
316        }
317        //
318        let mut matching_routes: Vec<EndPoint<T>> = Vec::new();
319        let mut params: HashMap<String, String> = HashMap::new();
320        // Check for matching pattern
321        for route in routes {
322            let match_pattern = Url::match_patern(&req_url.path.clone(), &route.path.clone());
323            if match_pattern.0 == true {
324                matching_routes.push(route);
325                if params.is_empty() {
326                    params = match_pattern.1;
327                }
328            }
329        }
330
331        if matching_routes.len() > 0 {
332            let routes: Vec<EndPoint<T>> = matching_routes
333                .into_iter()
334                .filter(|route| route.req_type == req_url.req_type)
335                .collect();
336
337            if routes.len() == 0 {
338                if has_body {
339                    req.read_only_body(&mut bfreader);
340                }
341
342                res.send_code(ResponseCode::MethodNotAllowed);
343                return Err(HttpServerError::new("Method not allowed!"));
344            }
345            let route = &routes[0];
346
347            req.params = params;
348
349            if has_body {
350                req.extract_body(&mut bfreader, bump);
351            }
352
353            (route.handle)(req, res, public_var);
354
355            return Ok(());
356        }
357
358        let mut sent = false;
359        for route in static_routes {
360            if req_url.path.starts_with(&route.0) {
361                let parts: Vec<&str> = req_url.path.split(&route.0).collect();
362                if parts.len() == 0 {
363                    continue;
364                }
365                let path_str = route.1 + parts[1];
366                let path = Path::new(&path_str);
367
368                if path.exists() && path.is_file() {
369                    match File::open(path) {
370                        Ok(file) => {
371                            let metadata = file.metadata();
372
373                            let bfreader = BufReader::new(file);
374
375                            let mut size: Option<u64> = None;
376                            if metadata.is_ok() {
377                                size = Some(metadata.unwrap().len());
378                            }
379                            res.pipe_stream(bfreader, None, size.as_ref());
380                        }
381                        Err(_err) => {
382                            res.send_code(ResponseCode::NotFound);
383                        }
384                    }
385                } else {
386                    res.send_code(ResponseCode::NotFound);
387                }
388
389                sent = true;
390                break;
391            }
392        }
393        if sent == false {
394            res.send_code(ResponseCode::NotFound);
395            return Err(HttpServerError::new("Not found!"));
396        }
397        return Ok(());
398    }
399    ///Locks the thread from stoping (put it in the end of the main file to keep the server running);
400    pub fn lock() {
401        let dur = Duration::from_secs(5);
402        loop {
403            std::thread::sleep(dur);
404        }
405    }
406}
407
408fn order_routes<T: Clone + Send + 'static>(routes: &mut Vec<EndPoint<T>>) {
409    routes.sort_by(|a, b| {
410        let path_a = &a.path;
411        let path_b = &b.path;
412        if path_a.contains("[") && path_a.contains("]") && !path_b.contains("[") {
413            return Ordering::Greater;
414        }
415        if path_b.contains("[") && path_b.contains("]") && !path_a.contains("[") {
416            return Ordering::Less;
417        }
418        if path_a.len() > path_b.len() {
419            return Ordering::Greater;
420        }
421        if path_a.len() < path_b.len() {
422            return Ordering::Less;
423        }
424        return Ordering::Equal;
425    });
426}