grweb 0.1.3

A high-performance Rust Web framework based on gorust coroutine runtime
Documentation
use crate::{
    AppConfig, ConnectionPool, Method, Response, Router, ServerConfig, SharedPool,
    WebSocket, LoggingConfig,
};
use gorust::{go, runtime};
use grorm::ConnectionPool as dbConnectionPool;
use grlog::{LoggerBuilder, Target};
use log::{error, info};
use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

pub struct Server {
    config: AppConfig,
    router: Arc<Router>,
    pool: SharedPool,
}

impl Server {
    pub fn new(app_config: AppConfig, mut router: Router) -> Self {
        let log_config = &app_config.logging;
        init_logger_level(log_config);

        let config = app_config;
        let config_server = &config.server;
        let pool = Arc::new(ConnectionPool::new(config_server.max_connections));
        router.set_pool(pool.clone());
        Self {
            config,
            router: Arc::new(router),
            pool,
        }
    }

    pub fn with_db_pool(mut self, db_pool: Arc<dbConnectionPool>) -> Self {
        Arc::get_mut(&mut self.router).unwrap().set_db_pool(db_pool);
        self
    }

    pub fn pool(&self) -> &SharedPool {
        &self.pool
    }

    #[runtime]
    pub fn run(self) -> std::io::Result<()> {
        let addr = self.config.server.addr();
        let listener = TcpListener::bind(&addr)?;

        let addr_for_log = addr.clone();
        go(move || {println!("Server listening on {}", addr_for_log)});

        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_flag = shutdown.clone();
        let shutdown_addr = self.config.server.addr();

        std::thread::spawn(move || {
            while gorust::scheduler::Scheduler::is_running() {
                std::thread::sleep(Duration::from_millis(100));
            }
            shutdown_flag.store(true, Ordering::SeqCst);
            let addr: std::net::SocketAddr = shutdown_addr.parse().unwrap();
            for _ in 0..10 {
                if TcpStream::connect_timeout(&addr, Duration::from_millis(100)).is_ok() {
                    break;
                }
                std::thread::sleep(Duration::from_millis(100));
            }
        });

        let router = self.router.clone();
        let config = Arc::new(self.config.server);
        let pool = self.pool.clone();

        for stream in listener.incoming() {
            if shutdown.load(Ordering::SeqCst) {
                info!("Server stopped");
                break;
            }
            match stream {
                Ok(stream) => {
                    if shutdown.load(Ordering::SeqCst) {
                        break;
                    }
                    if !pool.try_acquire() {
                        let _ = send_503_and_close(stream);
                        continue;
                    }
                    let router = router.clone();
                    let config = config.clone();
                    let pool = pool.clone();
                    go(move || {
                        handle_connection(stream, &router, &config);
                        pool.release();
                    });
                }
                Err(e) => {
                    if shutdown.load(Ordering::SeqCst) {
                        break;
                    }
                    error!("Connection failed: {}", e);
                }
            }
        }

        Ok(())
    }
}



fn send_503_and_close(mut stream: TcpStream) -> std::io::Result<()> {
    let response =
        b"HTTP/1.1 503 Service Unavailable\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
    stream.write_all(response)?;
    stream.flush()
}

fn handle_connection(mut stream: TcpStream, router: &Router, config: &ServerConfig) {
    if config.tcp_nodelay {
        let _ = stream.set_nodelay(true);
    }

    let keep_alive_timeout = Duration::from_secs(config.keep_alive_timeout);
    let mut buffer = vec![0u8; config.read_buffer_size];
    let mut keep_alive = true;

    while keep_alive {
        let _ = stream.set_read_timeout(Some(keep_alive_timeout));

        match stream.read(&mut buffer) {
            Ok(0) => return,
            Ok(n) => {
                if let Some((method, path, body, _header_end, req_keep_alive, headers)) =
                    parse_http_request(&buffer[..n])
                {
                    if is_websocket_upgrade(&method, &headers) {
                        if let Some(ws_handler) = router.find_ws(&path) {
                            if let Some(key) = headers.get("Sec-WebSocket-Key") {
                                if let Some(ws) = WebSocket::accept(stream, key) {
                                    ws_handler(ws);
                                }
                                return;
                            }
                        }
                    }

                    keep_alive = req_keep_alive;

                    let req_data = if body.is_empty() {
                        Vec::new()
                    } else {
                        body.to_vec()
                    };
                    let response = router.handle_request(method, path, req_data, headers);
                    let response_bytes = format_response_fast(&response, keep_alive);
                    if stream.write_all(&response_bytes).is_err() {
                        return;
                    }
                    let _ = stream.flush();
                } else {
                    return;
                }
            }
            Err(_) => return,
        }
    }
}

fn is_websocket_upgrade(method: &Method, headers: &HashMap<String, String>) -> bool {
    if *method != Method::GET {
        return false;
    }
    let upgrade = headers
        .get("Upgrade")
        .map(|v| v.to_lowercase() == "websocket")
        .unwrap_or(false);
    let connection = headers
        .get("Connection")
        .map(|v| v.to_lowercase().contains("upgrade"))
        .unwrap_or(false);
    let has_key = headers.contains_key("Sec-WebSocket-Key");
    let version = headers
        .get("Sec-WebSocket-Version")
        .map(|v| v == "13")
        .unwrap_or(false);

    upgrade && connection && has_key && version
}

fn parse_http_request(
    buffer: &[u8],
) -> Option<(Method, String, &[u8], usize, bool, HashMap<String, String>)> {
    let header_end = find_headers_end(buffer)?;

    let request_line_end = memchr::memchr(b'\n', buffer)?;
    let request_line = &buffer[..request_line_end];

    let first_space = memchr::memchr(b' ', request_line)?;
    let second_space =
        memchr::memchr(b' ', &request_line[first_space + 1..]).map(|p| first_space + 1 + p)?;

    let method_bytes = &request_line[..first_space];
    let path_bytes = &request_line[first_space + 1..second_space];
    let version_bytes = &request_line[second_space + 1..];

    let method = match method_bytes {
        b"GET" => Method::GET,
        b"POST" => Method::POST,
        b"PUT" => Method::PUT,
        b"DELETE" => Method::DELETE,
        b"PATCH" => Method::PATCH,
        b"HEAD" => Method::HEAD,
        b"OPTIONS" => Method::OPTIONS,
        _ => Method::GET,
    };

    let path = String::from_utf8_lossy(path_bytes).to_string();
    let body = if header_end < buffer.len() {
        &buffer[header_end..]
    } else {
        &[]
    };

    let is_http11 = version_bytes == b"HTTP/1.1";
    let headers_slice = &buffer[request_line_end + 1..header_end];
    let has_connection_close = has_header_value(headers_slice, b"Connection", b"close");
    let has_keep_alive = has_header_value(headers_slice, b"Connection", b"keep-alive");

    let keep_alive = if is_http11 {
        !has_connection_close
    } else {
        has_keep_alive
    };

    let headers = parse_headers(headers_slice);

    Some((method, path, body, header_end, keep_alive, headers))
}

fn find_headers_end(buffer: &[u8]) -> Option<usize> {
    buffer
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .map(|p| p + 4)
}

fn parse_headers(headers_slice: &[u8]) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    let mut pos = 0;
    while pos < headers_slice.len() {
        let line_end = match memchr::memchr(b'\n', &headers_slice[pos..]) {
            Some(p) => pos + p,
            None => headers_slice.len(),
        };
        let line = &headers_slice[pos..line_end];
        let line = if line.ends_with(b"\r") {
            &line[..line.len() - 1]
        } else {
            line
        };

        if let Some(colon) = memchr::memchr(b':', line) {
            let name = String::from_utf8_lossy(&line[..colon]).to_string();
            let value = String::from_utf8_lossy(line[colon + 1..].trim_ascii()).to_string();
            headers.insert(name, value);
        }

        pos = line_end + 1;
    }
    headers
}

fn has_header_value(headers: &[u8], name: &[u8], value: &[u8]) -> bool {
    let mut pos = 0;
    let name_lower = name.to_ascii_lowercase();
    while pos < headers.len() {
        let line_end = match memchr::memchr(b'\n', &headers[pos..]) {
            Some(p) => pos + p,
            None => headers.len(),
        };
        let line = &headers[pos..line_end];
        let line = if line.ends_with(b"\r") {
            &line[..line.len() - 1]
        } else {
            line
        };

        if let Some(colon) = memchr::memchr(b':', line) {
            let header_name = &line[..colon];
            if header_name.len() == name.len() && header_name.to_ascii_lowercase() == name_lower {
                let header_value = &line[colon + 1..];
                let header_value = header_value.trim_ascii();
                if header_value.to_ascii_lowercase() == value.to_ascii_lowercase() {
                    return true;
                }
            }
        }

        pos = line_end + 1;
    }
    false
}

fn format_response_fast(response: &Response, keep_alive: bool) -> Vec<u8> {
    let status_line: &[u8] = match response.status {
        200 => b"HTTP/1.1 200 OK\r\n",
        201 => b"HTTP/1.1 201 Created\r\n",
        204 => b"HTTP/1.1 204 No Content\r\n",
        400 => b"HTTP/1.1 400 Bad Request\r\n",
        404 => b"HTTP/1.1 404 Not Found\r\n",
        500 => b"HTTP/1.1 500 Internal Server Error\r\n",
        _ => b"HTTP/1.1 200 OK\r\n",
    };

    let body_len = response.body.len();

    let mut cl_buf = itoa::Buffer::new();
    let content_length_str = cl_buf.format(body_len);

    let cl_header = b"Content-Length: ";
    let cl_suffix = b"\r\n";
    let connection: &[u8] = if keep_alive {
        b"Connection: keep-alive\r\n"
    } else {
        b"Connection: close\r\n"
    };

    let mut total_len = status_line.len()
        + cl_header.len()
        + content_length_str.len()
        + cl_suffix.len()
        + connection.len()
        + body_len
        + 2;

    for (k, v) in &response.headers {
        total_len += k.len() + v.len() + 4;
    }

    let mut result = Vec::with_capacity(total_len);

    result.extend_from_slice(status_line);
    result.extend_from_slice(cl_header);
    result.extend_from_slice(content_length_str.as_bytes());
    result.extend_from_slice(cl_suffix);
    result.extend_from_slice(connection);

    for (k, v) in &response.headers {
        result.extend_from_slice(k.as_bytes());
        result.extend_from_slice(b": ");
        result.extend_from_slice(v.as_bytes());
        result.extend_from_slice(b"\r\n");
    }

    result.extend_from_slice(b"\r\n");

    if body_len > 0 {
        result.extend_from_slice(&response.body);
    }

    result
}


fn init_logger_level(log_config: &LoggingConfig) {
    let level = match log_config.level.to_lowercase().as_str() {
            "trace" => log::LevelFilter::Trace,
            "debug" => log::LevelFilter::Debug,
            "info" => log::LevelFilter::Info,
            "warn" => log::LevelFilter::Warn,
            "error" => log::LevelFilter::Error,
            "off" => log::LevelFilter::Off,
            _ => log::LevelFilter::Info,
        };

        let target = match log_config.output.to_lowercase().as_str() {
            "file" => {
                if let Some(ref file_path) = log_config.file {
                    Target::File(PathBuf::from(file_path))
                } else {
                    Target::Stderr
                }
            }
            "stdout" => Target::Stdout,
            _ => Target::Stderr,
        };

        let mut builder = LoggerBuilder::new();
        builder.filter_level(level).target(target);
        builder.try_init();
}