use bytes::Bytes;
use http::StatusCode;
#[cfg(feature = "log")]
use log::{debug, info, warn};
use may::net::{TcpListener, TcpStream};
use num_cpus;
use std::io::{self, Read, Write};
use std::net::{SocketAddr, ToSocketAddrs};
use std::sync::atomic::{AtomicBool, Ordering};
use std::{panic, sync::Arc};
use crate::http::{Request, Response};
use crate::runtime::service::{ArcService, Service, ServiceResult};
/// Tunables for [`Server`]: request size limit, socket read timeout and `may` runtime sizing.
#[derive(Clone, Debug)]
pub struct ServerConfig {
    /// Maximum number of bytes accepted for a request. The connection handler uses it both
    /// as the cap on buffered request-head bytes and as the cap on the declared
    /// `Content-Length`.
    pub max_body_size: usize,
    /// Read timeout, in seconds, applied to each accepted connection's socket.
    pub read_timeout_secs: u64,
    /// Worker count passed to `may::config().set_workers` when the server runs.
    pub workers: usize,
    /// Coroutine stack size passed to `may::config().set_stack_size` when the server runs.
    /// NOTE(review): units (bytes vs. words) are whatever `may` defines — confirm.
    pub stack_size: usize,
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
max_body_size: 8192,
read_timeout_secs: 30,
workers: num_cpus::get(),
stack_size: 64 * 1024,
}
}
}
/// An HTTP/1.x server that dispatches every parsed request to a shared service.
pub struct Server {
    /// Request handler shared (by cloning the handle) with every connection coroutine.
    service: ArcService,
    /// Size limits, timeout and runtime settings used by `run` and each connection.
    config: ServerConfig,
    /// Accept-loop flag; `shutdown` clears it.
    running: Arc<AtomicBool>,
}
impl Server {
pub fn new(service: impl Service, max_body_size: usize) -> Self {
let mut config = ServerConfig::default();
config.max_body_size = max_body_size;
Self {
service: Arc::new(service),
running: Arc::new(AtomicBool::new(true)),
config,
}
}
pub fn with_config(service: impl Service, config: ServerConfig) -> Self {
Self {
service: Arc::new(service),
running: Arc::new(AtomicBool::new(true)),
config,
}
}
pub fn shutdown(&self) {
self.running.store(false, Ordering::SeqCst);
}
pub fn run(&self, addr: impl ToSocketAddrs) -> io::Result<()> {
may::config().set_workers(self.config.workers);
may::config().set_stack_size(self.config.stack_size);
#[cfg(feature = "log")]
info!(
"Feather Runtime Started on {}",
addr.to_socket_addrs()?.next().unwrap_or(SocketAddr::from(([0, 0, 0, 0], 80)))
);
let listener = TcpListener::bind(addr)?;
while self.running.load(Ordering::SeqCst) {
match listener.accept() {
Ok((stream, addr)) => {
#[cfg(feature = "log")]
debug!("New connection from {}", addr);
let service = self.service.clone();
let config = self.config.clone();
may::go!(move || {
let result = panic::catch_unwind(panic::AssertUnwindSafe(|| Self::conn_handler(stream, service, config)));
match result {
Ok(Ok(())) => (), Ok(Err(e)) => {
#[cfg(feature = "log")]
log::error!("Connection handler error: {}", e);
}
Err(e) => {
let msg = e.downcast_ref::<String>().map(|s| s.as_str()).unwrap_or("Unknown panic");
#[cfg(feature = "log")]
log::error!("Connection handler panic: {}", msg);
}
}
});
}
Err(e) => {
#[cfg(feature = "log")]
warn!("Failed to accept connection: {}", e);
}
}
}
#[cfg(feature = "log")]
info!("Server shutting down");
Ok(())
}
fn send_error(stream: &mut TcpStream, status: StatusCode, message: &str) -> io::Result<()> {
let mut response = Response::default();
response.set_status(status.as_u16());
response.send_text(message);
response.add_header("X-Content-Type-Options", "nosniff").ok();
response.add_header("X-Frame-Options", "DENY").ok();
response.add_header("Connection", "close").ok();
stream.write_all(&response.to_raw())
}
fn conn_handler(mut stream: TcpStream, service: ArcService, config: ServerConfig) -> io::Result<()> {
let mut keep_alive = true;
let mut pipeline_buffer: Vec<u8> = Vec::new();
let remote_addr = stream.local_addr()?;
while keep_alive {
stream.set_read_timeout(Some(std::time::Duration::from_secs(config.read_timeout_secs)))?;
let body = pipeline_buffer;
pipeline_buffer = Vec::new();
let mut buffer = body;
let mut temp = [0u8; 4096];
loop {
let prev_len = buffer.len();
let n = stream.read(&mut temp)?;
if n == 0 {
return Ok(()); }
buffer.extend_from_slice(&temp[..n]);
let check_from = prev_len.saturating_sub(3);
if buffer[check_from..].windows(4).any(|w| w == b"\r\n\r\n") {
break;
}
if buffer.len() > config.max_body_size {
Self::send_error(&mut stream, StatusCode::PAYLOAD_TOO_LARGE, "Headers too large")?;
return Ok(());
}
}
let header_end = buffer.windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
let headers_raw = &buffer[..header_end];
let mut body = buffer[header_end..].to_vec();
let temp_request = match Request::parse(headers_raw, Bytes::new(), remote_addr) {
Ok(r) => r,
Err(e) => {
Self::send_error(&mut stream, StatusCode::BAD_REQUEST, &format!("Invalid request: {}", e))?;
return Ok(());
}
};
if temp_request.headers.get(http::header::TRANSFER_ENCODING).map(|v| v.as_bytes().eq_ignore_ascii_case(b"chunked")).unwrap_or(false) {
Self::send_error(&mut stream, StatusCode::NOT_IMPLEMENTED, "Chunked transfer encoding not supported")?;
return Ok(());
}
keep_alive = match (temp_request.version, temp_request.headers.get(http::header::CONNECTION)) {
(http::Version::HTTP_11, Some(v)) if v.as_bytes().eq_ignore_ascii_case(b"close") => false,
(http::Version::HTTP_11, _) => true,
_ => false,
};
let content_length = temp_request.headers.get(http::header::CONTENT_LENGTH).and_then(|v| v.to_str().ok()).and_then(|v| v.parse::<usize>().ok()).unwrap_or(0);
if content_length > config.max_body_size {
Self::send_error(&mut stream, StatusCode::PAYLOAD_TOO_LARGE, "Request body too large")?;
return Ok(());
}
if body.len() > content_length {
pipeline_buffer = body.split_off(content_length);
}
while body.len() < content_length {
let n = stream.read(&mut temp)?;
if n == 0 {
Self::send_error(&mut stream, StatusCode::BAD_REQUEST, "Unexpected EOF while reading request body")?;
return Ok(());
}
body.extend_from_slice(&temp[..n]);
}
if body.len() > content_length {
pipeline_buffer = body.split_off(content_length);
}
let request = match Request::parse(headers_raw, Bytes::from(body), remote_addr) {
Ok(r) => r,
Err(e) => {
Self::send_error(&mut stream, StatusCode::BAD_REQUEST, &format!("Invalid request: {}", e))?;
return Ok(());
}
};
let result = service.handle(request, None);
match result {
Ok(ServiceResult::Response(response)) => {
let raw = response.to_raw();
stream.write_all(&raw)?;
stream.flush()?;
if !keep_alive {
return Ok(());
}
if let Some(conn) = response.headers.get(http::header::CONNECTION) {
if conn.as_bytes().eq_ignore_ascii_case(b"close") {
return Ok(());
}
}
}
Ok(ServiceResult::Consumed) => return Ok(()),
Err(e) => {
Self::send_error(&mut stream, StatusCode::INTERNAL_SERVER_ERROR, &format!("Internal error: {}", e))?;
return Ok(());
}
}
}
Ok(())
}
}