1pub mod cookies;
2pub mod http;
3pub mod methods;
4pub mod request;
5pub mod response;
6pub mod routes;
7pub mod state;
8pub mod thread_pool;
9pub mod utils;
10pub mod virtual_host;
11
12use anyhow::Context;
13use bytes::{BufMut, BytesMut};
14use futures::StreamExt;
15use response::Response;
16use routes::Router;
17use rustls_acme::{caches::DirCache, AcmeConfig};
18use std::{
19 collections::HashMap,
20 fmt::Debug,
21 path::{Path, PathBuf},
22 sync::Arc,
23 time::Duration,
24 vec,
25};
26use tokio::{
27 self,
28 io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
29 net::TcpListener,
30 select,
31 signal::unix::{signal, SignalKind},
32 sync::RwLock,
33 task::JoinHandle,
34 time::timeout,
35};
36use tokio_rustls::{
37 rustls::{self, Certificate, PrivateKey},
38 TlsAcceptor,
39};
40use tokio_util::sync::CancellationToken;
41
42pub struct Server<S> {
43 listener: TcpListener,
44 acceptor: Option<TlsAcceptor>,
45 router: Arc<RwLock<Router<S>>>,
46 virtual_hosts: Arc<RwLock<HashMap<String, virtual_host::VirtualHost<S>>>>,
47 cancel: CancellationToken,
48 doc_root: PathBuf,
49 timeout: Duration,
50}
51
52trait ConnectionStream: AsyncWrite + AsyncRead + Unpin + Send + Sync {}
53
54impl<T> ConnectionStream for T where T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync {}
56
57pub struct Connection {
58 stream: Box<dyn ConnectionStream>,
59 client_ip: std::net::SocketAddr,
60}
61
62impl Connection {
63 #[tracing::instrument(level = "debug", skip(self))]
64 pub async fn write_all(&mut self, src: &[u8]) -> tokio::io::Result<()> {
65 self.stream.write_all(src).await?;
66 Ok(())
67 }
68
69 #[tracing::instrument(level = "debug", skip(self, response))]
70 pub async fn write_response(&mut self, response: Response) -> tokio::io::Result<()> {
71 let response_buffer = response.to_send_buffer();
72 log::trace!("Writing: {}Bytes", response_buffer.len());
73 self.write_all(&response_buffer).await?;
74 Ok(())
75 }
76}
77
78impl<S> Server<S>
79where
80 S: Clone + Send + Sync + 'static,
81{
82 #[tracing::instrument(level = "debug", skip(router))]
83 pub async fn bind(
84 ip: &str,
85 router: Router<S>,
86 doc_root: impl AsRef<Path> + Debug,
87 ) -> Result<Self, tokio::io::Error> {
88 let listener = tokio::net::TcpListener::bind(ip).await?;
89 Ok(Server {
90 listener,
91 router: Arc::new(RwLock::new(router)),
92 virtual_hosts: Arc::new(RwLock::new(HashMap::new())),
93 acceptor: None,
94 cancel: CancellationToken::new(),
95 doc_root: PathBuf::from(doc_root.as_ref()),
96 timeout: Duration::from_secs(30),
97 })
98 }
99
100 #[tracing::instrument(level = "debug", skip(router))]
101 pub async fn bind_tls(
102 ip: &str,
103 cert: &Path,
104 key: &Path,
105 router: Router<S>,
106 doc_root: impl AsRef<Path> + Debug,
107 ) -> Result<Self, anyhow::Error> {
108 let files = vec![cert, key];
109 let context = format!("Opening: {:#?}, {:#?}", cert, key);
110 let (mut keys, certs) = load_keys_and_certs(&files).context(context)?;
111 let config = rustls::ServerConfig::builder()
112 .with_safe_defaults()
113 .with_no_client_auth()
114 .with_single_cert(certs, keys.remove(0))
115 .context("Loading Certs")?;
116 let acceptor = TlsAcceptor::from(Arc::new(config));
117 let listener = tokio::net::TcpListener::bind(ip)
118 .await
119 .context("binding tls")?;
120 Ok(Server {
121 listener,
122 router: Arc::new(RwLock::new(router)),
123 virtual_hosts: Arc::new(RwLock::new(HashMap::new())),
124 acceptor: Some(acceptor),
125 cancel: CancellationToken::new(),
126 doc_root: PathBuf::from(doc_root.as_ref()),
127 timeout: Duration::from_secs(60),
128 })
129 }
130
131 #[tracing::instrument(level = "debug", skip(router, domains))]
132 pub async fn bind_tls_alpn(
133 ip: &str,
134 router: Router<S>,
135 doc_root: impl AsRef<Path> + Debug,
136 domains: impl IntoIterator<Item = impl AsRef<str>>,
137 email: &str,
138 ) -> Result<Self, anyhow::Error> {
139 let contact = format!("mailto:{email}");
140 let acme = AcmeConfig::new(domains)
141 .contact_push(&contact)
142 .cache(DirCache::new("./rustls_acme_cache"));
143 let mut state = acme.state();
144 let resolver = state.resolver();
145 let config = rustls::ServerConfig::builder()
146 .with_safe_defaults()
147 .with_no_client_auth()
148 .with_cert_resolver(resolver);
149 tokio::spawn(async move {
150 loop {
151 match state.next().await.unwrap() {
152 Ok(ok) => log::info!("event: {:?}", ok),
153 Err(err) => log::error!("error: {:?}", err),
154 }
155 }
156 });
157 let acceptor = TlsAcceptor::from(Arc::new(config));
158 let listener = tokio::net::TcpListener::bind(ip)
159 .await
160 .context("binding tls")?;
161 Ok(Server {
162 listener,
163 router: Arc::new(RwLock::new(router)),
164 virtual_hosts: Arc::new(RwLock::new(HashMap::new())),
165 acceptor: Some(acceptor),
166 cancel: CancellationToken::new(),
167 doc_root: PathBuf::from(doc_root.as_ref()),
168 timeout: Duration::from_secs(60),
169 })
170 }
171
172 #[tracing::instrument(level = "debug", skip(self))]
173 pub fn virtual_hosts(&self) -> Arc<RwLock<HashMap<String, virtual_host::VirtualHost<S>>>> {
174 self.virtual_hosts.clone()
175 }
176
177 #[tracing::instrument(level = "debug", skip(self, virtual_host))]
178 pub async fn add_virtual_host(&mut self, virtual_host: virtual_host::VirtualHost<S>) {
179 let virtual_hosts = self.virtual_hosts();
180 let mut locked = virtual_hosts.write().await;
181 locked.insert(virtual_host.hostname().to_string(), virtual_host);
182 }
183
184 #[tracing::instrument(level = "debug", skip(self))]
185 pub async fn accept(&self) -> tokio::io::Result<Connection> {
186 let (stream, client_ip) = self.listener.accept().await?;
187 if let Some(acceptor) = &self.acceptor {
188 let acceptor = acceptor.clone();
189 match acceptor.accept(stream).await {
190 Ok(s) => Ok(Connection {
191 client_ip,
192 stream: Box::new(tokio_rustls::TlsStream::Server(s)),
193 }),
194 Err(_) => Err(tokio::io::Error::new(
195 tokio::io::ErrorKind::Other,
196 "Error Accepting TLS Stream",
197 )),
198 }
199 } else {
200 Ok(Connection {
201 client_ip,
202 stream: Box::new(stream),
203 })
204 }
205 }
206
207 #[tracing::instrument(level = "debug", skip(self, connection))]
208 fn serve_connection(&self, mut connection: Connection) -> JoinHandle<()> {
209 let router = self.router.clone();
210 let token = self.cancel.clone();
211 let doc_root = self.doc_root.clone();
212 let vhosts = self.virtual_hosts();
213 let ip = connection.client_ip;
214 let timeout_duration = self.timeout;
215 let read_loop = async move {
216 let mut request_bytes = BytesMut::with_capacity(1024);
217 let mut buffer = vec![0; 1024]; while let Ok(stream_read_result) =
219 timeout(timeout_duration, connection.stream.read(&mut buffer)).await
220 {
221 match stream_read_result {
222 Ok(0) => {
223 tracing::debug!("{ip}: Connection Terminated by client");
224 return;
225 }
226 Ok(n) => {
227 for b in buffer.iter().take(n) {
229 request_bytes.put_u8(*b);
230 }
231 let request_result =
232 request::Request::from_bytes(request_bytes.clone().into());
233 match request_result {
234 Ok(r) => {
235 let path = r.path();
236 let host = r.hostname();
237 tracing::info!(
238 "{ip}: {} {} Request for: {}",
239 r.method(),
240 r.version(),
241 path
242 );
243
244 let html_path = if let Some(vhost) = vhosts.read().await.get(host) {
245 vhost.root_dir().clone()
246 } else {
247 doc_root.clone()
248 };
249 let router_locked = router.read().await;
250 let response = router_locked.route(&r, &html_path).await;
251 tracing::debug!("{ip}|{path}: Writing Response");
252 if let Err(error) = connection.write_response(response).await {
253 tracing::error!(
257 "{ip}|{path}: Error Writing response: {}",
258 error.to_string()
259 );
260 } else {
261 tracing::trace!(
263 "{ip}|{path}: Wrote response, clearing request buffer"
264 );
265 if r.keep_alive() {
266 connection.stream.flush().await.expect("Error flushing");
267 request_bytes.clear();
268 } else {
269 tracing::debug!(
270 "{ip}|{path}: Shutting down Stream, no keep alive"
271 );
272 return;
274 }
275 }
276 }
277 Err(e) => match e {
278 request::Error::InvalidString
279 | request::Error::MissingBlankLine => {}
280 request::Error::WaitingOnBody(pb) => {
281 if let Some(bytes_left) = pb {
282 let free_bytes =
283 request_bytes.capacity() - request_bytes.len();
284 if free_bytes < bytes_left {
285 request_bytes.reserve(bytes_left - free_bytes);
287 }
288 }
289 }
290 _ => {
291 let error_res = format!("400 bad request: {}", e);
292 let req_string = String::from_utf8_lossy(&buffer);
293 tracing::warn!("{ip}: {} Request: {}", error_res, req_string);
294 let response = Response::error(
295 http::StatusCode::BAD_REQUEST,
296 error_res.into(),
297 );
298 if let Err(err) = connection.write_response(response).await {
299 tracing::error!(
300 "{ip}: Error Writing Data: {}",
301 err.to_string()
302 );
303 }
304 tracing::warn!("{ip}: Shutting down Stream, bad request");
306 return;
307 }
308 },
309 }
310 }
311 Err(err) => {
312 tracing::error!("{ip}: Socket read error: {}", err.to_string());
313 return;
314 }
315 }
316 }
317 tracing::debug!("{ip} Connection Server Read Timeout");
318 };
330
331 tokio::spawn(async move {
332 select! {
333 _ = read_loop => {
334 }
335 _ = token.cancelled() => {
336 tracing::debug!("shutting down listen thread");
337 }
338 }
339 })
340 }
341
342 #[tracing::instrument(level = "debug", skip(self))]
343 pub async fn serve(&self) -> tokio::io::Result<()> {
344 let accept_loop = async move {
345 loop {
346 let accept_attempt = self.accept().await;
347 match accept_attempt {
348 Ok(connection) => {
349 tracing::info!("Accepted Connection From {}", connection.client_ip);
350 self.serve_connection(connection);
351 }
352 Err(e) => {
353 tracing::error!("Error Accepting Connection: {}", e.to_string());
354 }
355 }
356 }
357 };
358
359 let mut sigterm = signal(SignalKind::terminate()).unwrap();
360 select! {
361 _ = accept_loop => {
362 tracing::info!("shutting down due to acceptor exit");
363 Ok(())
364 }
365 _ = tokio::signal::ctrl_c() => {
366 tracing::info!("Received CTRL C shutting down");
367 self.cancel.cancel();
368 Ok(())
369 }
370 _ = sigterm.recv() => {
371 tracing::info!("Received SigTerm shutting down");
372 self.cancel.cancel();
373 Ok(())
374 }
375 }
376 }
377}
378fn load_keys_and_certs(paths: &Vec<&Path>) -> std::io::Result<(Vec<PrivateKey>, Vec<Certificate>)> {
379 let mut keys = vec![];
380 let mut certs = vec![];
381 for path in paths {
382 let items =
383 rustls_pemfile::read_all(&mut std::io::BufReader::new(std::fs::File::open(path)?))?;
384 for item in items {
385 match item {
386 rustls_pemfile::Item::RSAKey(key) => {
387 keys.push(PrivateKey(key));
388 }
389 rustls_pemfile::Item::ECKey(key) => {
390 keys.push(PrivateKey(key));
391 }
392 rustls_pemfile::Item::PKCS8Key(key) => {
393 keys.push(PrivateKey(key));
394 }
395 rustls_pemfile::Item::X509Certificate(cert) => {
396 certs.push(Certificate(cert));
397 }
398 _ => {}
399 }
400 }
401 }
402 Ok((keys, certs))
403}