use crate::{
name,
server::{respond, Handler, HttpRequest, HttpSettings, ResponseData, Stream},
version,
};
use kern::Fail;
use rustls::{ServerConfig, ServerSession, Stream as RustlsStream};
use std::io::prelude::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, RwLock};
use std::thread;
/// Accepts TCP connections in an endless loop and serves each one on its own thread.
///
/// For every accepted connection the shared settings, TLS configuration and
/// shared state handles are cloned and moved into a freshly spawned thread
/// that runs [`handle_connection`]. The result of `handle_connection` is
/// discarded, and failed `accept` calls are skipped. The loop has no exit
/// condition, so this function does not return.
///
/// # Panics
///
/// Panics if the `listener` lock is poisoned.
pub fn accept_connections<T: Send + Sync + 'static>(
    listener: Arc<RwLock<TcpListener>>,
    http_settings: Arc<HttpSettings>,
    tls_config: Arc<ServerConfig>,
    handler: Handler<T>,
    shared: Arc<RwLock<T>>,
) {
    loop {
        // The read guard is a temporary of the scrutinee and therefore stays
        // alive until the selected arm has finished.
        match listener.read().unwrap().accept() {
            Ok((connection, _peer)) => {
                // Each worker thread owns its own handles to the shared data.
                let settings = Arc::clone(&http_settings);
                let config = Arc::clone(&tls_config);
                let state = Arc::clone(&shared);
                thread::spawn(move || {
                    // Per-connection errors are intentionally dropped.
                    let _ = handle_connection(connection, &settings, config, handler, state);
                });
            }
            // A failed accept is ignored; keep listening.
            Err(_) => continue,
        }
    }
}
/// Serves one HTTP request over a TLS session on an accepted TCP connection.
///
/// Applies the configured read and write timeouts to the socket, wraps it in a
/// TLS server session, reads the request header, builds an `HttpRequest`
/// and passes it to `handler` together with `shared`. The handler's response
/// is written back and flushed.
///
/// If the header cannot be read or the handler returns an error, an HTML error
/// page with status `400 Bad Request` is sent instead. The error text is
/// HTML-escaped before being embedded in the page.
///
/// # Errors
///
/// Returns a `Fail` when setting a timeout fails, when writing or flushing the
/// response fails, or when reading the header fails with the message
/// `received corrupt message` (reported as `Not a TLS connection`; no response
/// is written in that case).
pub fn handle_connection<T: Send + Sync + 'static>(
    mut stream: TcpStream,
    http_settings: &HttpSettings,
    tls_config: Arc<ServerConfig>,
    handler: Handler<T>,
    shared: Arc<RwLock<T>>,
) -> Result<(), Fail> {
    // Replaces the characters that are significant in HTML text and attribute
    // values, so error text cannot inject markup into the error page.
    fn escape_html(text: &str) -> String {
        let mut escaped = String::with_capacity(text.len());
        for c in text.chars() {
            match c {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&#39;"),
                _ => escaped.push(c),
            }
        }
        escaped
    }

    // Renders the body of the HTML error page for the given error text.
    fn error_page(message: &str) -> String {
        format!("<!DOCTYPE html><html><head><title>{0}</title></head><body><h3>HTTP server error</h3><p>{0}</p><hr><address>{1} v{2}</address></body></html>", escape_html(message), name(), version())
    }

    // Builds the complete `400 Bad Request` response for an error message.
    let error_response = |message: String| {
        respond(
            error_page(&message),
            "text/html",
            Some(ResponseData::new().set_status("400 Bad Request")),
        )
    };

    stream
        .set_read_timeout(http_settings.read_timeout)
        .or_else(Fail::from)?;
    stream
        .set_write_timeout(http_settings.write_timeout)
        .or_else(Fail::from)?;

    // From here on, all reads and writes go through the TLS session.
    let mut session = ServerSession::new(&tls_config);
    let mut stream = RustlsStream::new(&mut session, &mut stream);

    let response = match read_header(&mut stream, http_settings) {
        Ok((header, rest)) => {
            let http_request = HttpRequest::from(&header, rest, &mut stream, http_settings);
            match handler(http_request, shared) {
                Ok(response) => response,
                Err(err) => error_response(err.to_string()),
            }
        }
        Err(err) => {
            // NOTE(review): presumably this is the message the TLS layer reports
            // when the peer is not speaking TLS; in that case no HTTP response
            // is attempted. Confirm against the rustls version in use.
            if err.err_msg() == "received corrupt message" {
                return Fail::from("Not a TLS connection");
            }
            error_response(err.to_string())
        }
    };

    stream.write_all(&response).or_else(Fail::from)?;
    stream.flush().or_else(Fail::from)?;
    Ok(())
}
/// Reads from `stream` until the end of the HTTP header (`\r\n\r\n`) is found.
///
/// Data is read in chunks of at most `http_settings.header_buffer` bytes and
/// accumulated. After every read the accumulated bytes are searched for the
/// header terminator, resuming three bytes before the newly added data so that
/// a terminator split across two reads is still detected. Only the bytes that
/// were actually read are ever stored.
///
/// Returns the header as a string (including the terminating `\r\n\r\n`) and
/// all bytes that were read beyond the header.
///
/// # Errors
///
/// Returns a `Fail` when
/// - reading from the stream fails,
/// - the accumulated data would exceed `http_settings.max_header_size`
///   (`Max header size exceeded`),
/// - more than `http_settings.header_read_attempts` reads returned fewer bytes
///   than the buffer size without completing the header
///   (`Read header failed too often`), or
/// - the header is not valid UTF-8.
fn read_header(
    stream: &mut Stream,
    http_settings: &HttpSettings,
) -> Result<(String, Vec<u8>), Fail> {
    const TERMINATOR: &[u8; 4] = b"\r\n\r\n";

    let mut data: Vec<u8> = Vec::new();
    let mut buf = vec![0u8; http_settings.header_buffer];
    let mut read_fails = 0;

    // Index just past the header terminator within `data`.
    let header_end = loop {
        let length = stream.read(&mut buf).or_else(Fail::from)?;
        if data.len() + length > http_settings.max_header_size {
            return Fail::from("Max header size exceeded");
        }

        // A terminator may begin in the last three bytes of the previous data.
        let search_from = data.len().saturating_sub(TERMINATOR.len() - 1);
        data.extend_from_slice(&buf[..length]);

        let found = data[search_from..]
            .windows(TERMINATOR.len())
            .position(|window| window == TERMINATOR);
        if let Some(offset) = found {
            break search_from + offset + TERMINATOR.len();
        }

        // A read that did not fill the buffer counts as a failed attempt.
        if length < http_settings.header_buffer {
            read_fails += 1;
            if read_fails > http_settings.header_read_attempts {
                return Fail::from("Read header failed too often");
            }
        }
    };

    // Everything after the terminator already belongs to the request body.
    let rest = data.split_off(header_end);
    match String::from_utf8(data) {
        Ok(header) => Ok((header, rest)),
        Err(err) => Fail::from(err),
    }
}