nucleus_http/
lib.rs

1pub mod cookies;
2pub mod http;
3pub mod methods;
4pub mod request;
5pub mod response;
6pub mod routes;
7pub mod state;
8pub mod thread_pool;
9pub mod utils;
10pub mod virtual_host;
11
12use anyhow::Context;
13use bytes::{BufMut, BytesMut};
14use futures::StreamExt;
15use response::Response;
16use routes::Router;
17use rustls_acme::{caches::DirCache, AcmeConfig};
18use std::{
19    collections::HashMap,
20    fmt::Debug,
21    path::{Path, PathBuf},
22    sync::Arc,
23    time::Duration,
24    vec,
25};
26use tokio::{
27    self,
28    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
29    net::TcpListener,
30    select,
31    signal::unix::{signal, SignalKind},
32    sync::RwLock,
33    task::JoinHandle,
34    time::timeout,
35};
36use tokio_rustls::{
37    rustls::{self, Certificate, PrivateKey},
38    TlsAcceptor,
39};
40use tokio_util::sync::CancellationToken;
41
42pub struct Server<S> {
43    listener: TcpListener,
44    acceptor: Option<TlsAcceptor>,
45    router: Arc<RwLock<Router<S>>>,
46    virtual_hosts: Arc<RwLock<HashMap<String, virtual_host::VirtualHost<S>>>>,
47    cancel: CancellationToken,
48    doc_root: PathBuf,
49    timeout: Duration,
50}
51
52trait ConnectionStream: AsyncWrite + AsyncRead + Unpin + Send + Sync {}
53
54// Auto Implement Stream for all types that implent asyncRead + asyncWrite
55impl<T> ConnectionStream for T where T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync {}
56
57pub struct Connection {
58    stream: Box<dyn ConnectionStream>,
59    client_ip: std::net::SocketAddr,
60}
61
62impl Connection {
63    #[tracing::instrument(level = "debug", skip(self))]
64    pub async fn write_all(&mut self, src: &[u8]) -> tokio::io::Result<()> {
65        self.stream.write_all(src).await?;
66        Ok(())
67    }
68
69    #[tracing::instrument(level = "debug", skip(self, response))]
70    pub async fn write_response(&mut self, response: Response) -> tokio::io::Result<()> {
71        let response_buffer = response.to_send_buffer();
72        log::trace!("Writing: {}Bytes", response_buffer.len());
73        self.write_all(&response_buffer).await?;
74        Ok(())
75    }
76}
77
78impl<S> Server<S>
79where
80    S: Clone + Send + Sync + 'static,
81{
82    #[tracing::instrument(level = "debug", skip(router))]
83    pub async fn bind(
84        ip: &str,
85        router: Router<S>,
86        doc_root: impl AsRef<Path> + Debug,
87    ) -> Result<Self, tokio::io::Error> {
88        let listener = tokio::net::TcpListener::bind(ip).await?;
89        Ok(Server {
90            listener,
91            router: Arc::new(RwLock::new(router)),
92            virtual_hosts: Arc::new(RwLock::new(HashMap::new())),
93            acceptor: None,
94            cancel: CancellationToken::new(),
95            doc_root: PathBuf::from(doc_root.as_ref()),
96            timeout: Duration::from_secs(30),
97        })
98    }
99
100    #[tracing::instrument(level = "debug", skip(router))]
101    pub async fn bind_tls(
102        ip: &str,
103        cert: &Path,
104        key: &Path,
105        router: Router<S>,
106        doc_root: impl AsRef<Path> + Debug,
107    ) -> Result<Self, anyhow::Error> {
108        let files = vec![cert, key];
109        let context = format!("Opening: {:#?}, {:#?}", cert, key);
110        let (mut keys, certs) = load_keys_and_certs(&files).context(context)?;
111        let config = rustls::ServerConfig::builder()
112            .with_safe_defaults()
113            .with_no_client_auth()
114            .with_single_cert(certs, keys.remove(0))
115            .context("Loading Certs")?;
116        let acceptor = TlsAcceptor::from(Arc::new(config));
117        let listener = tokio::net::TcpListener::bind(ip)
118            .await
119            .context("binding tls")?;
120        Ok(Server {
121            listener,
122            router: Arc::new(RwLock::new(router)),
123            virtual_hosts: Arc::new(RwLock::new(HashMap::new())),
124            acceptor: Some(acceptor),
125            cancel: CancellationToken::new(),
126            doc_root: PathBuf::from(doc_root.as_ref()),
127            timeout: Duration::from_secs(60),
128        })
129    }
130
131    #[tracing::instrument(level = "debug", skip(router, domains))]
132    pub async fn bind_tls_alpn(
133        ip: &str,
134        router: Router<S>,
135        doc_root: impl AsRef<Path> + Debug,
136        domains: impl IntoIterator<Item = impl AsRef<str>>,
137        email: &str,
138    ) -> Result<Self, anyhow::Error> {
139        let contact = format!("mailto:{email}");
140        let acme = AcmeConfig::new(domains)
141            .contact_push(&contact)
142            .cache(DirCache::new("./rustls_acme_cache"));
143        let mut state = acme.state();
144        let resolver = state.resolver();
145        let config = rustls::ServerConfig::builder()
146            .with_safe_defaults()
147            .with_no_client_auth()
148            .with_cert_resolver(resolver);
149        tokio::spawn(async move {
150            loop {
151                match state.next().await.unwrap() {
152                    Ok(ok) => log::info!("event: {:?}", ok),
153                    Err(err) => log::error!("error: {:?}", err),
154                }
155            }
156        });
157        let acceptor = TlsAcceptor::from(Arc::new(config));
158        let listener = tokio::net::TcpListener::bind(ip)
159            .await
160            .context("binding tls")?;
161        Ok(Server {
162            listener,
163            router: Arc::new(RwLock::new(router)),
164            virtual_hosts: Arc::new(RwLock::new(HashMap::new())),
165            acceptor: Some(acceptor),
166            cancel: CancellationToken::new(),
167            doc_root: PathBuf::from(doc_root.as_ref()),
168            timeout: Duration::from_secs(60),
169        })
170    }
171
172    #[tracing::instrument(level = "debug", skip(self))]
173    pub fn virtual_hosts(&self) -> Arc<RwLock<HashMap<String, virtual_host::VirtualHost<S>>>> {
174        self.virtual_hosts.clone()
175    }
176
177    #[tracing::instrument(level = "debug", skip(self, virtual_host))]
178    pub async fn add_virtual_host(&mut self, virtual_host: virtual_host::VirtualHost<S>) {
179        let virtual_hosts = self.virtual_hosts();
180        let mut locked = virtual_hosts.write().await;
181        locked.insert(virtual_host.hostname().to_string(), virtual_host);
182    }
183
184    #[tracing::instrument(level = "debug", skip(self))]
185    pub async fn accept(&self) -> tokio::io::Result<Connection> {
186        let (stream, client_ip) = self.listener.accept().await?;
187        if let Some(acceptor) = &self.acceptor {
188            let acceptor = acceptor.clone();
189            match acceptor.accept(stream).await {
190                Ok(s) => Ok(Connection {
191                    client_ip,
192                    stream: Box::new(tokio_rustls::TlsStream::Server(s)),
193                }),
194                Err(_) => Err(tokio::io::Error::new(
195                    tokio::io::ErrorKind::Other,
196                    "Error Accepting TLS Stream",
197                )),
198            }
199        } else {
200            Ok(Connection {
201                client_ip,
202                stream: Box::new(stream),
203            })
204        }
205    }
206
207    #[tracing::instrument(level = "debug", skip(self, connection))]
208    fn serve_connection(&self, mut connection: Connection) -> JoinHandle<()> {
209        let router = self.router.clone();
210        let token = self.cancel.clone();
211        let doc_root = self.doc_root.clone();
212        let vhosts = self.virtual_hosts();
213        let ip = connection.client_ip;
214        let timeout_duration = self.timeout;
215        let read_loop = async move {
216            let mut request_bytes = BytesMut::with_capacity(1024);
217            let mut buffer = vec![0; 1024]; //Vector to avoid buffer on stack
218            while let Ok(stream_read_result) =
219                timeout(timeout_duration, connection.stream.read(&mut buffer)).await
220            {
221                match stream_read_result {
222                    Ok(0) => {
223                        tracing::debug!("{ip}: Connection Terminated by client");
224                        return;
225                    }
226                    Ok(n) => {
227                        //got some bytes append them and see if we need to do any proccessing
228                        for b in buffer.iter().take(n) {
229                            request_bytes.put_u8(*b);
230                        }
231                        let request_result =
232                            request::Request::from_bytes(request_bytes.clone().into());
233                        match request_result {
234                            Ok(r) => {
235                                let path = r.path();
236                                let host = r.hostname();
237                                tracing::info!(
238                                    "{ip}: {} {} Request for: {}",
239                                    r.method(),
240                                    r.version(),
241                                    path
242                                );
243
244                                let html_path = if let Some(vhost) = vhosts.read().await.get(host) {
245                                    vhost.root_dir().clone()
246                                } else {
247                                    doc_root.clone()
248                                };
249                                let router_locked = router.read().await;
250                                let response = router_locked.route(&r, &html_path).await;
251                                tracing::debug!("{ip}|{path}: Writing Response");
252                                if let Err(error) = connection.write_response(response).await {
253                                    // not clearing string here so we can try
254                                    // again, otherwise might be terminated
255                                    // connection which will be handled
256                                    tracing::error!(
257                                        "{ip}|{path}: Error Writing response: {}",
258                                        error.to_string()
259                                    );
260                                } else {
261                                    //clear buffer
262                                    tracing::trace!(
263                                        "{ip}|{path}: Wrote response, clearing request buffer"
264                                    );
265                                    if r.keep_alive() {
266                                        connection.stream.flush().await.expect("Error flushing");
267                                        request_bytes.clear();
268                                    } else {
269                                        tracing::debug!(
270                                            "{ip}|{path}: Shutting down Stream, no keep alive"
271                                        );
272                                        //returning should drop the connection and shutdown the socket
273                                        return;
274                                    }
275                                }
276                            }
277                            Err(e) => match e {
278                                request::Error::InvalidString
279                                | request::Error::MissingBlankLine => {}
280                                request::Error::WaitingOnBody(pb) => {
281                                    if let Some(bytes_left) = pb {
282                                        let free_bytes =
283                                            request_bytes.capacity() - request_bytes.len();
284                                        if free_bytes < bytes_left {
285                                            // we know body size preallocate for it
286                                            request_bytes.reserve(bytes_left - free_bytes);
287                                        }
288                                    }
289                                }
290                                _ => {
291                                    let error_res = format!("400 bad request: {}", e);
292                                    let req_string = String::from_utf8_lossy(&buffer);
293                                    tracing::warn!("{ip}: {} Request: {}", error_res, req_string);
294                                    let response = Response::error(
295                                        http::StatusCode::BAD_REQUEST,
296                                        error_res.into(),
297                                    );
298                                    if let Err(err) = connection.write_response(response).await {
299                                        tracing::error!(
300                                            "{ip}: Error Writing Data: {}",
301                                            err.to_string()
302                                        );
303                                    }
304                                    //returning should drop the connection and shutdown the socket
305                                    tracing::warn!("{ip}: Shutting down Stream, bad request");
306                                    return;
307                                }
308                            },
309                        }
310                    }
311                    Err(err) => {
312                        tracing::error!("{ip}: Socket read error: {}", err.to_string());
313                        return;
314                    }
315                }
316            }
317            tracing::debug!("{ip} Connection Server Read Timeout");
318            /*
319            let mut response = Response::error(
320                StatusCode::REQUEST_TIMEOUT,
321                "Client failed to send request in time".into(),
322            );
323            response.add_header(("Connection", "close"));
324            if let Err(error) = connection.write_response(response).await {
325                //just log error since we are dropping connection anyhow
326                tracing::debug!("{ip} Error Writing: {}", error);
327            }
328            */
329        };
330
331        tokio::spawn(async move {
332            select! {
333                _ = read_loop => {
334                }
335                _ = token.cancelled() => {
336                    tracing::debug!("shutting down listen thread");
337                }
338            }
339        })
340    }
341
342    #[tracing::instrument(level = "debug", skip(self))]
343    pub async fn serve(&self) -> tokio::io::Result<()> {
344        let accept_loop = async move {
345            loop {
346                let accept_attempt = self.accept().await;
347                match accept_attempt {
348                    Ok(connection) => {
349                        tracing::info!("Accepted Connection From {}", connection.client_ip);
350                        self.serve_connection(connection);
351                    }
352                    Err(e) => {
353                        tracing::error!("Error Accepting Connection: {}", e.to_string());
354                    }
355                }
356            }
357        };
358
359        let mut sigterm = signal(SignalKind::terminate()).unwrap();
360        select! {
361            _ = accept_loop => {
362                tracing::info!("shutting down due to acceptor exit");
363                Ok(())
364            }
365            _ = tokio::signal::ctrl_c() => {
366                tracing::info!("Received CTRL C shutting down");
367                self.cancel.cancel();
368                Ok(())
369            }
370            _ = sigterm.recv() => {
371                tracing::info!("Received SigTerm shutting down");
372                self.cancel.cancel();
373                Ok(())
374            }
375        }
376    }
377}
378fn load_keys_and_certs(paths: &Vec<&Path>) -> std::io::Result<(Vec<PrivateKey>, Vec<Certificate>)> {
379    let mut keys = vec![];
380    let mut certs = vec![];
381    for path in paths {
382        let items =
383            rustls_pemfile::read_all(&mut std::io::BufReader::new(std::fs::File::open(path)?))?;
384        for item in items {
385            match item {
386                rustls_pemfile::Item::RSAKey(key) => {
387                    keys.push(PrivateKey(key));
388                }
389                rustls_pemfile::Item::ECKey(key) => {
390                    keys.push(PrivateKey(key));
391                }
392                rustls_pemfile::Item::PKCS8Key(key) => {
393                    keys.push(PrivateKey(key));
394                }
395                rustls_pemfile::Item::X509Certificate(cert) => {
396                    certs.push(Certificate(cert));
397                }
398                _ => {}
399            }
400        }
401    }
402    Ok((keys, certs))
403}