tinyhttp_internal/
config.rs

1use std::{
2    collections::HashMap,
3    net::TcpStream,
4    ops::Deref,
5    sync::OnceLock,
6};
7
8use crate::{middleware::MiddlewareResponse, request::Request};
9pub use dyn_clone::DynClone;
10use std::fmt::Debug;
11
12use crate::response::Response;
13
14use rusty_pool::{Builder, ThreadPool};
15
16#[cfg(not(feature = "async"))]
17use std::net::{Incoming, TcpListener};
18
19#[cfg(not(feature = "async"))]
20use crate::http::start_http;
21
22
23#[cfg(test)]
24use std::any::Any;
25
26type RouteVec = Vec<Box<dyn Route>>;
27
28type MiddlewareFn = fn(&mut Request) -> MiddlewareResponse;
29
30pub static PRE_MIDDLEWARE_CONST: OnceLock<Box<dyn FnMut(&mut Request) + Send + Sync>> =
31    OnceLock::new();
32
33pub static POST_MIDDLEWARE_CONST: OnceLock<Box<dyn FnMut(&mut Request) + Send + Sync>> =
34    OnceLock::new();
35
/// HTTP request methods a route can be registered for.
///
/// `Config::routes` files each route into a per-method table based on this value.
/// The type is `Copy`, comparable, and hashable so it can be tested with `==` and
/// used as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Method {
    /// The `GET` method.
    GET,
    /// The `POST` method.
    POST,
}
41
42pub trait ToResponse: DynClone + Sync + Send {
43    fn to_res(&self, res: Request, sock: &mut TcpStream) -> Response;
44}
45
46pub trait Route: DynClone + Sync + Send + ToResponse {
47    fn get_path(&self) -> &str;
48    fn get_method(&self) -> Method;
49    fn wildcard(&self) -> Option<String>;
50    fn clone_dyn(&self) -> Box<dyn Route>;
51
52    #[cfg(test)]
53    fn any(&self) -> &dyn Any;
54}
55
56impl Clone for Box<dyn Route> {
57    fn clone(&self) -> Self {
58        self.clone_dyn()
59    }
60}
61
62pub struct HttpListener {
63    pub(crate) socket: TcpListener,
64    pub config: Config,
65    pub pool: ThreadPool,
66    pub use_pool: bool,
67}
68
69impl HttpListener {
70    pub fn new<P: Into<TcpListener>>(socket: P, config: Config) -> HttpListener {
71        #[cfg(feature = "log")]
72        log::debug!("Using {} threads", num_cpus::get());
73
74        HttpListener {
75            socket: socket.into(),
76            config,
77            pool: ThreadPool::default(),
78            use_pool: true,
79        }
80    }
81
82    pub fn threads(mut self, threads: usize) -> Self {
83        let pool = Builder::new().core_size(threads).build();
84
85        self.pool = pool;
86        self
87    }
88
89    pub fn use_tp(mut self, r: bool) -> Self {
90        self.use_pool = r;
91        self
92    }
93
94    #[cfg(not(feature = "async"))]
95    pub fn start(self) {
96        let conf_clone = self.config.clone();
97        start_http(self, conf_clone);
98    }
99
100    #[cfg(not(feature = "async"))]
101    pub fn get_stream(&self) -> Incoming<'_> {
102        self.socket.incoming()
103    }
104}
105
106#[derive(Clone)]
107pub struct Routes {
108    routes: RouteVec,
109}
110
111impl Routes {
112    pub fn new<R: Into<RouteVec>>(routes: R) -> Routes {
113        let routes = routes.into();
114        Routes { routes }
115    }
116
117    pub fn get_stream(self) -> RouteVec {
118        self.routes
119    }
120}
121
122#[derive(Clone)]
123pub struct Config {
124    mount_point: Option<String>,
125    get_routes: Option<HashMap<String, Box<dyn Route>>>,
126    post_routes: Option<HashMap<String, Box<dyn Route>>>,
127    debug: bool,
128    pub ssl: bool,
129    ssl_chain: Option<String>,
130    ssl_priv: Option<String>,
131    headers: Option<HashMap<String, String>>,
132    gzip: bool,
133    spa: bool,
134    http2: bool,
135    middleware: Option<Vec<MiddlewareFn>>,
136}
137
138impl Default for Config {
139    fn default() -> Self {
140        Config::new()
141    }
142}
143
144impl Config {
145    /// Generates default settings (which don't work by itself)
146    ///
147    /// Chain with mount_point or routes
148    ///
149    /// ### Example:
150    /// ```ignore
151    /// use tinyhttp::prelude::*;
152    ///
153    /// #[get("/test")]
154    /// fn get_test() -> String {
155    ///   String::from("Hello, there!\n")
156    /// }
157    ///
158    /// let routes = Routes::new(vec![get_test()]);
159    /// let routes_config = Config::new().routes(routes);
160    /// /// or
161    /// let mount_config = Config::new().mount_point(".");
162    /// ```
163
164    pub fn new() -> Config {
165        //assert!(routes.len() > 0);
166
167        #[cfg(feature = "log")]
168        log::info!("tinyhttp version: {}", env!("CARGO_PKG_VERSION"));
169
170        Config {
171            mount_point: None,
172            get_routes: None,
173            post_routes: None,
174            debug: false,
175            ssl: false,
176            ssl_chain: None,
177            ssl_priv: None,
178            headers: None,
179            gzip: false,
180            spa: false,
181            http2: false,
182            middleware: None,
183        }
184    }
185
186    /// A mount point that will be searched when a request isn't defined with a get or post route
187    ///
188    /// ### Example:
189    /// ```ignore
190    /// let config = Config::new().mount_point(".")
191    /// /// if index.html exists in current directory, it will be returned if "/" or "/index.html" is requested.
192    /// ```
193
194    pub fn mount_point<P: Into<String>>(mut self, path: P) -> Self {
195        self.mount_point = Some(path.into());
196        self
197    }
198
199    /// Add routes with a Route member
200    ///
201    /// ### Example:
202    /// ```ignore
203    /// use tinyhttp::prelude::*;
204    ///
205    ///
206    /// #[get("/test")]
207    /// fn get_test() -> &'static str {
208    ///   "Hello, World!"
209    /// }
210    ///
211    /// #[post("/test")]
212    /// fn post_test() -> Vec<u8> {
213    ///   b"Hello, Post!".to_vec()
214    /// }
215    ///
216    /// fn main() {
217    ///   let socket = TcpListener::new(":::80").unwrap();
218    ///   let routes = Routes::new(vec![get_test(), post_test()]);
219    ///   let config = Config::new().routes(routes);
220    ///   let http = HttpListener::new(socket, config);
221    ///
222    ///   http.start();
223    /// }
224    /// ```
225
226    pub fn routes(mut self, routes: Routes) -> Self {
227        let mut get_routes = HashMap::new();
228        let mut post_routes = HashMap::new();
229        let routes = routes.get_stream();
230
231        for route in routes {
232            match route.get_method() {
233                Method::GET => {
234                    #[cfg(feature = "log")]
235                    log::info!("GET Route init!: {}", &route.get_path());
236
237                    get_routes.insert(route.get_path().to_string(), route);
238                }
239                Method::POST => {
240                    #[cfg(feature = "log")]
241                    log::info!("POST Route init!: {}", &route.get_path());
242                    post_routes.insert(route.get_path().to_string(), route);
243                }
244            }
245        }
246        if !get_routes.is_empty() {
247            self.get_routes = Some(get_routes);
248        } else {
249            self.get_routes = None;
250        }
251
252        if !post_routes.is_empty() {
253            self.post_routes = Some(post_routes);
254        } else {
255            self.post_routes = None;
256        }
257
258        self
259    }
260
261    /// Enables SSL
262    ///
263    /// ### Example:
264    /// ```ignore
265    /// let config = Config::new().ssl("./fullchain.pem", "./privkey.pem");
266    /// ```
267    /// This will only accept HTTPS connections
268
269    pub fn ssl(mut self, ssl_chain: String, ssl_priv: String) -> Self {
270        self.ssl_chain = Some(ssl_chain);
271        self.ssl_priv = Some(ssl_priv);
272        self.ssl = true;
273        self
274    }
275    pub fn debug(mut self) -> Self {
276        self.debug = true;
277        self
278    }
279
280    /// Define custom headers
281    ///
282    /// ```ignore
283    /// let config = Config::new().headers(vec!["Access-Control-Allow-Origin: *".into()]);
284    /// ```
285    pub fn headers(mut self, headers: Vec<String>) -> Self {
286        let mut hash_map: HashMap<String, String> = HashMap::new();
287        for i in headers {
288            let mut split = i.split_inclusive(": ");
289            hash_map.insert(
290                split.next().unwrap().to_string(),
291                split.next().unwrap().to_string() + "\r\n",
292            );
293        }
294
295        self.headers = Some(hash_map);
296        self
297    }
298
299    pub fn spa(mut self, res: bool) -> Self {
300        self.spa = res;
301        self
302    }
303
304    /// Enables gzip compression
305    pub fn gzip(mut self, res: bool) -> Self {
306        self.gzip = res;
307        self
308    }
309
310    pub fn http2(mut self, res: bool) -> Self {
311        self.http2 = res;
312        self
313    }
314
315    pub fn middleware(mut self, middleware: Vec<MiddlewareFn>) -> Self {
316        self.middleware = Some(middleware);
317        self
318    }
319
320    pub fn get_middleware(&self) -> Option<&[MiddlewareFn]> {
321        self.middleware.as_deref()
322    }
323
324    pub fn get_headers(&self) -> Option<&HashMap<String, String>> {
325        self.headers.as_ref()
326    }
327    pub fn get_gzip(&self) -> bool {
328        self.gzip
329    }
330    pub fn get_debug(&self) -> bool {
331        self.debug
332    }
333    pub fn get_mount(&self) -> Option<&String> {
334        self.mount_point.as_ref()
335    }
336    pub fn get_routes(&self, req_path: &str) -> Option<&dyn Route> {
337        let req_path = if req_path.ends_with('/') && req_path.matches('/').count() > 1 {
338            let mut chars = req_path.chars();
339            chars.next_back();
340            chars.as_str()
341        } else {
342            req_path
343        };
344
345        #[cfg(feature = "log")]
346        log::trace!("get_routes -> new_path: {}", &req_path);
347
348        let routes = self.get_routes.as_ref()?;
349
350        if let Some(route) = routes.get(req_path) {
351            return Some(route.deref());
352        }
353
354        if let Some((_, wildcard_route)) = routes
355            .iter()
356            .find(|(path, route)| req_path.starts_with(*path) && route.wildcard().is_some())
357        {
358            return Some(wildcard_route.deref());
359        }
360
361        None
362    }
363
364    pub fn post_routes(&self, req_path: &str) -> Option<&dyn Route> {
365        #[cfg(feature = "log")]
366        log::trace!("post_routes -> path: {}", req_path);
367
368        let req_path = if req_path.ends_with('/') && req_path.matches('/').count() > 1 {
369            let mut chars = req_path.chars();
370            chars.next_back();
371            chars.as_str()
372        } else {
373            req_path
374        };
375
376        #[cfg(feature = "log")]
377        log::trace!("get_routes -> new_path: {}", &req_path);
378
379        let routes = self.post_routes.as_ref()?;
380
381        if let Some(route) = routes.get(req_path) {
382            return Some(route.deref());
383        }
384
385        if let Some((_, wildcard_route)) = routes
386            .iter()
387            .find(|(path, route)| req_path.starts_with(*path) && route.wildcard().is_some())
388        {
389            return Some(wildcard_route.deref());
390        }
391
392        None
393    }
394
395    pub fn get_spa(&self) -> bool {
396        self.spa
397    }
398}