aver-rt 0.4.4

Shared Rust runtime pieces for Aver-generated programs
Documentation
use crate::{AverList, AverStr, HttpHeaders, HttpRequest, HttpResponse};
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{TcpListener, TcpStream};
use std::time::Duration;

/// Runs a blocking HTTP server on `port`, answering every request with `handler`.
///
/// Connections are accepted and served strictly one after another on the
/// calling thread; each connection carries exactly one request/response
/// exchange.
///
/// # Errors
///
/// Returns an error string when the listening socket cannot be bound or when
/// accepting a connection fails. An accept failure terminates the loop.
pub fn listen<F>(port: i64, mut handler: F) -> Result<(), String>
where
    F: FnMut(HttpRequest) -> HttpResponse,
{
    let listener = bind_listener(port)?;
    for incoming in listener.incoming() {
        match incoming {
            Ok(mut stream) => serve_one(&mut stream, &mut handler),
            Err(e) => {
                return Err(format!(
                    "HttpServer.listen: failed to accept connection: {}",
                    e
                ));
            }
        }
    }
    Ok(())
}

/// Runs a blocking HTTP server on `port`, passing a clone of `context` to
/// `handler` together with every parsed request.
///
/// Connections are accepted and served strictly one after another on the
/// calling thread. The context is cloned only at the moment the handler is
/// actually invoked, so a connection whose request fails to parse costs no
/// clone at all.
///
/// # Errors
///
/// Returns an error string when the listening socket cannot be bound or when
/// accepting a connection fails. An accept failure terminates the loop.
pub fn listen_with<C, F>(port: i64, context: C, mut handler: F) -> Result<(), String>
where
    C: Clone,
    F: FnMut(C, HttpRequest) -> HttpResponse,
{
    let listener = bind_listener(port)?;
    for incoming in listener.incoming() {
        let mut stream = incoming
            .map_err(|e| format!("HttpServer.listen: failed to accept connection: {}", e))?;
        // `serve_one` calls the closure at most once per connection, so a single
        // clone made at call time is sufficient.
        serve_one(&mut stream, |request| handler(context.clone(), request));
    }
    Ok(())
}

/// Opens a TCP listener on all IPv4 interfaces (`0.0.0.0`) at `port`.
///
/// # Errors
///
/// Returns a message naming the port and the underlying I/O error when the
/// address cannot be bound (including a `port` that is not a valid port number).
fn bind_listener(port: i64) -> Result<TcpListener, String> {
    let address = format!("0.0.0.0:{}", port);
    match TcpListener::bind(address) {
        Ok(listener) => Ok(listener),
        Err(e) => Err(format!(
            "HttpServer.listen: failed to bind on {}: {}",
            port, e
        )),
    }
}

/// Handles one connection: configures timeouts, parses a single request,
/// invokes `handler`, and writes the resulting response.
///
/// Failures never propagate to the caller. A timeout-configuration failure is
/// answered with a 500, an unparsable request with a 400, and errors while
/// writing the response are deliberately ignored (best effort).
fn serve_one<F>(stream: &mut TcpStream, mut handler: F)
where
    F: FnMut(HttpRequest) -> HttpResponse,
{
    /// Builds a response with the given status and body and no custom headers.
    fn plain_reply(status: i64, body: String) -> HttpResponse {
        HttpResponse {
            status,
            body: AverStr::from(body),
            headers: HttpHeaders::default(),
        }
    }

    // Both directions share the same 30-second limit.
    let io_deadline = Some(Duration::from_secs(30));

    if let Err(e) = stream.set_read_timeout(io_deadline) {
        let reply = plain_reply(
            500,
            format!("HttpServer: failed to set read timeout: {}", e),
        );
        let _ = write_http_response(stream, &reply);
        return;
    }
    if let Err(e) = stream.set_write_timeout(io_deadline) {
        let reply = plain_reply(
            500,
            format!("HttpServer: failed to set write timeout: {}", e),
        );
        let _ = write_http_response(stream, &reply);
        return;
    }

    // Either the handler's answer or a 400 describing the parse failure.
    let response = match parse_http_request(stream) {
        Ok(request) => handler(request),
        Err(msg) => plain_reply(400, format!("Bad Request: {}", msg)),
    };
    let _ = write_http_response(stream, &response);
}

/// Reads and parses a single HTTP request (request line, headers, optional
/// body) from `stream`.
///
/// Header names are lowercased; repeated header lines with the same name are
/// kept as separate values in arrival order. The body is read only when a
/// `Content-Length` header is present and is decoded as UTF-8 with invalid
/// sequences replaced.
///
/// # Errors
///
/// Returns a descriptive message when:
/// - the stream cannot be cloned or read,
/// - the request line is empty or lacks method, path, or version,
/// - a header line has no `:` separator,
/// - `Content-Length` is not a valid number or exceeds the 10 MiB body limit,
/// - the request line plus headers exceed the 64 KiB head limit.
fn parse_http_request(stream: &mut TcpStream) -> Result<HttpRequest, String> {
    const BODY_LIMIT: usize = 10 * 1024 * 1024;
    // Budget for the request line plus all header lines. Without it a peer
    // could make `read_line` buffer an arbitrarily long line (or an endless
    // run of headers) and exhaust memory before any other check applies.
    const HEAD_LIMIT: u64 = 64 * 1024;

    /// Reads one line of the request head into `line` and returns the number
    /// of bytes read. Fails when the head budget runs out before the line is
    /// terminated.
    fn read_head_line<R: BufRead>(
        head: &mut std::io::Take<R>,
        line: &mut String,
        what: &str,
    ) -> Result<usize, String> {
        let read = head
            .read_line(line)
            .map_err(|e| format!("cannot read {}: {}", what, e))?;
        // A spent budget with no terminating newline means the line was cut
        // off by the limit (or the limit was hit before the blank line that
        // ends the head), not that the peer finished sending.
        if head.limit() == 0 && !line.ends_with('\n') {
            return Err(format!("request head exceeds {} bytes limit", HEAD_LIMIT));
        }
        Ok(read)
    }

    let reader_stream = stream
        .try_clone()
        .map_err(|e| format!("cannot clone TCP stream: {}", e))?;
    let mut reader = BufReader::new(reader_stream);
    // The limited view borrows `reader`; bytes buffered past the head stay in
    // `reader` and are available for the body read below.
    let mut head = reader.by_ref().take(HEAD_LIMIT);

    let mut request_line = String::new();
    let line_len = read_head_line(&mut head, &mut request_line, "request line")?;
    if line_len == 0 {
        return Err("empty request".to_string());
    }

    let request_line = request_line.trim_end_matches(&['\r', '\n'][..]);
    let mut request_parts = request_line.split_whitespace();
    let method = request_parts
        .next()
        .ok_or_else(|| "missing HTTP method".to_string())?
        .to_string();
    let path = request_parts
        .next()
        .ok_or_else(|| "missing request path".to_string())?
        .to_string();
    // The version token must be present but its value is not inspected.
    request_parts
        .next()
        .ok_or_else(|| "missing HTTP version".to_string())?;

    // Build Map<lowercase-name, List<value>>. Multiple lines with the
    // same name accumulate as separate values (RFC 9110 §5.3 — same
    // semantics as comma-joining for non-Set-Cookie, kept structural
    // so consumers can choose).
    let mut headers: HttpHeaders = HttpHeaders::default();
    let mut content_length = 0usize;
    let mut line = String::new();

    loop {
        line.clear();
        // End of stream before the blank line is treated as end of headers.
        if read_head_line(&mut head, &mut line, "header line")? == 0 {
            break;
        }

        let trimmed = line.trim_end_matches(&['\r', '\n'][..]);
        if trimmed.is_empty() {
            break;
        }

        let (name, value) = trimmed
            .split_once(':')
            .ok_or_else(|| format!("malformed header: '{}'", trimmed))?;
        let name = name.trim().to_ascii_lowercase();
        let value = value.trim().to_string();

        // `name` is already lowercased, so a plain comparison suffices.
        if name == "content-length" {
            content_length = value
                .parse::<usize>()
                .map_err(|_| format!("invalid Content-Length value: '{}'", value))?;
            if content_length > BODY_LIMIT {
                return Err(format!("request body exceeds {} bytes limit", BODY_LIMIT));
            }
        }

        let key = AverStr::from(name);
        let value = AverStr::from(value);
        let entry = match headers.get(&key) {
            Some(existing) => {
                let mut buf: Vec<AverStr> = existing.iter().cloned().collect();
                buf.push(value);
                AverList::from_vec(buf)
            }
            None => AverList::from_vec(vec![value]),
        };
        headers = headers.insert(key, entry);
    }

    // `content_length` was checked against BODY_LIMIT above, so this
    // allocation is bounded.
    let mut body_bytes = vec![0_u8; content_length];
    if content_length > 0 {
        reader
            .read_exact(&mut body_bytes)
            .map_err(|e| format!("cannot read request body: {}", e))?;
    }
    let body = String::from_utf8_lossy(&body_bytes).into_owned();

    Ok(HttpRequest {
        method: AverStr::from(method),
        path: AverStr::from(path),
        body: AverStr::from(body),
        headers,
    })
}

/// Returns the reason phrase to place after `status` on the response status line.
///
/// Well-known codes map to their standard phrase. Any other code falls back to
/// a generic phrase for its class (1xx–5xx) so that, for example, an unlisted
/// 4xx or 5xx status is never labelled "OK". Codes outside 100–599 yield
/// "Unknown".
fn status_reason(status: i64) -> &'static str {
    match status {
        100 => "Continue",
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        413 => "Content Too Large",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Entity",
        429 => "Too Many Requests",
        431 => "Request Header Fields Too Large",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        // Class-level fallbacks for codes without a dedicated entry above.
        100..=199 => "Informational",
        200..=299 => "OK",
        300..=399 => "Redirection",
        400..=499 => "Client Error",
        500..=599 => "Server Error",
        _ => "Unknown",
    }
}

/// Serializes `response` as an HTTP/1.1 message and writes it to `stream`.
///
/// Header emission order is: the response's own headers (any `Content-Length`
/// or `Connection` entries are dropped because the server sets them itself), a
/// default `Content-Type: text/plain; charset=utf-8` when none was emitted,
/// then `Content-Length` computed from the body and `Connection: close`.
///
/// # Errors
///
/// Propagates any I/O error raised while writing to or flushing the stream.
fn write_http_response(stream: &mut TcpStream, response: &HttpResponse) -> std::io::Result<()> {
    /// Appends a single `name: value` header line, CRLF-terminated, to `head`.
    fn append_header(head: &mut String, name: &str, value: &str) {
        head.push_str(name);
        head.push_str(": ");
        head.push_str(value);
        head.push_str("\r\n");
    }

    let mut head = format!(
        "HTTP/1.1 {} {}\r\n",
        response.status,
        status_reason(response.status)
    );

    // Only set once a Content-Type line has actually been emitted; a header
    // name with an empty value list emits nothing and therefore does not count.
    let mut saw_content_type = false;
    for (name, values) in response.headers.iter() {
        let server_managed = name.eq_ignore_ascii_case("Content-Length")
            || name.eq_ignore_ascii_case("Connection");
        if server_managed {
            continue;
        }
        for value in values.iter() {
            saw_content_type = saw_content_type || name.eq_ignore_ascii_case("Content-Type");
            append_header(&mut head, &name.to_string(), &value.to_string());
        }
    }

    if !saw_content_type {
        append_header(&mut head, "Content-Type", "text/plain; charset=utf-8");
    }
    append_header(
        &mut head,
        "Content-Length",
        &response.body.len().to_string(),
    );
    append_header(&mut head, "Connection", "close");
    head.push_str("\r\n");

    stream.write_all(head.as_bytes())?;
    stream.write_all(response.body.as_bytes())?;
    stream.flush()
}