stargazer 1.1.1

A fast and easy-to-use Gemini server
// stargazer - A Gemini Server
// Copyright (C) 2021 Sashanoraa <sasha@noraa.gay>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

use crate::router::{CGIRoute, Request, RoutePath};
use crate::{
    error::{error, Context, ErrorConv, GemError, Result},
    router::SCGIRoute,
    EXEC,
};
use crate::{get_file::parse_path, router::SCGIAddress};
use crate::{LogInfo, CONF};
use base64::Engine as _;
use async_io::Timer;
use async_net::{unix::UnixStream, TcpStream};
use async_process::{Command, Stdio};
use blocking::unblock;
use futures_lite::{
    future::poll_once,
    io::{AsyncReadExt, AsyncWriteExt},
    AsyncRead, AsyncWrite, FutureExt,
};
use futures_rustls::rustls::{
    ProtocolVersion, ServerConnection, SupportedCipherSuite,
};
use futures_rustls::server::TlsStream;
use is_executable::IsExecutable;
use log::{debug, error, warn};
use sha2::{Digest, Sha256};
use std::time::SystemTime;
use std::{collections::HashMap, net::IpAddr};
use std::{
    ffi::{OsStr, OsString},
    time::Duration,
};
use std::{
    io::ErrorKind,
    path::{Path, PathBuf},
};

/// Serve a request by running a CGI script and streaming its stdout to the
/// client.
///
/// Script lookup: if `cgi_route.root` is itself an executable file, it is the
/// script and both `SCRIPT_NAME` and `PATH_INFO` are just `/`. Otherwise the
/// request path is appended to `cgi_route.root` one segment at a time until an
/// executable file is found; the segments consumed so far become
/// `SCRIPT_NAME` and the remaining ones become `PATH_INFO`.
///
/// The script runs with `cgi_route.root` as its working directory, receives
/// the variables from `common_vars` plus `GATEWAY_INTERFACE`, `SCRIPT_NAME`
/// and `PATH_INFO`, and (on Unix, when `cgi_route.user` is set) with that
/// user's uid and primary gid. A detached watcher task kills the script if it
/// outlives `cgi_route.timeout`, or, when no timeout is configured, if it is
/// still running one second after this function has returned.
///
/// # Returns
/// A `LogInfo` whose `size` is the number of bytes forwarded to the client.
/// `status` and `meta` are parsed from the start of the response when
/// `CONF.cgi_resp_logging` is enabled, and are `0` / empty otherwise or when
/// the response has no parsable header.
///
/// # Errors
/// `GemError::NotFound` if no executable is found along the request path.
/// Also fails if the path cannot be parsed or canonicalized, the process
/// cannot be spawned, reading from the script or writing to the client fails,
/// or the client closes the connection while the script is silent.
pub async fn serve_cgi<'a>(
    cgi_route: &'static CGIRoute,
    domain: &'a str,
    req: &'a Request,
    remote_addr: IpAddr,
    server_port: u16,
    stream: &'a mut TlsStream<TcpStream>,
) -> Result<LogInfo> {
    let os_str = parse_path(&req.path)?;
    let url_path = Path::new(&os_str).to_owned();

    // Filesystem probes are blocking calls, so they are handed to `unblock`.
    let is_exec = unblock(move || {
        cgi_route.root.is_file() && cgi_route.root.is_executable()
    })
    .await;
    let (script_name, path_info) = if is_exec {
        // The route root is the script itself; nothing of the URL is consumed.
        (PathBuf::new(), PathBuf::new())
    } else {
        unblock(move || {
            let mut found = false;
            let mut script_name = PathBuf::new();
            let mut iter = url_path.iter();
            // `&mut iter` so that the segments left over after the `break`
            // can still be consumed below.
            for segment in &mut iter {
                script_name.push(segment);
                let script_path = cgi_route.root.join(&script_name);
                if script_path.is_file() && script_path.is_executable() {
                    found = true;
                    break;
                }
            }
            // Whatever follows the script in the URL is passed on as PATH_INFO.
            let mut path_info = PathBuf::new();
            for segment in iter {
                path_info.push(segment);
            }
            if found {
                Ok((script_name, path_info))
            } else {
                Err(GemError::NotFound)
            }
        })
        .await?
    };

    let (_, session) = stream.get_ref();
    let script_path = cgi_route.root.join(&script_name).canonicalize()?;
    let mut cmd = Command::new(&script_path);
    cmd.current_dir(&cgi_route.root.canonicalize()?);
    cmd.envs(
        common_vars(domain, req, remote_addr, server_port, session).await?,
    );
    cmd.env("GATEWAY_INTERFACE", "CGI/1.1");
    // SCRIPT_NAME and PATH_INFO are handed to the script with a leading `/`.
    fn add_leading_slash(s: &OsStr) -> OsString {
        let mut new = OsString::from("/");
        new.push(s);
        new
    }
    cmd.env("SCRIPT_NAME", add_leading_slash(script_name.as_os_str()));
    cmd.env("PATH_INFO", add_leading_slash(path_info.as_os_str()));
    cmd.stdout(Stdio::piped());
    cmd.stderr(Stdio::piped());

    cfg_if::cfg_if! {
        if #[cfg(unix)] {
            use async_process::unix::CommandExt;
            // Drop privileges to the configured user, if any.
            if let Some(user) = &cgi_route.user {
                cmd.uid(user.uid());
                cmd.gid(user.primary_group_id());
            }
        }
    }
    cfg_if::cfg_if! {
        // A bare `linux` is not a cfg name the compiler sets, so the target
        // has to be selected through `target_os`.
        if #[cfg(target_os = "linux")] {
            // SAFETY: the closure runs in the child between fork and exec and
            // only issues a single prctl(2) system call.
            unsafe {
                cmd.pre_exec(|| {
                    // Have the kernel send SIGHUP to the script if the server
                    // goes away. Best effort: the return value is ignored.
                    libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGHUP);
                    Ok(())
                });
            }
        }
    }

    debug!("Starting CGI process `{}` ...", script_path.display());
    let mut child = cmd.spawn()?;
    let mut buf = [0u8; 4096];
    // When this function returns, `_drop_chan` is dropped, which makes
    // `wait_chan.recv()` in the watcher task below return.
    let (_drop_chan, wait_chan): (async_channel::Sender<()>, _) =
        async_channel::bounded(1);
    let mut child_stdout =
        child.stdout.take().context("Child stdout not captured")?;
    let mut child_stderr =
        child.stderr.take().context("Child stderr not captured")?;
    // Watcher task: owns `child` and makes sure it does not run forever.
    EXEC.spawn(async move {
            if let Some(timeout) = cgi_route.timeout {
                // A timeout is configured: sleep for that long, then check.
                Timer::after(Duration::from_secs(timeout)).await;
            } else {
                // No timeout: wait until `serve_cgi` has returned.
                let _ = wait_chan.recv().await;
                // Wait a second for the CGI process to exit on its own
                Timer::after(Duration::from_secs(1)).await;
                if let Ok(Some(_)) = child.try_status() {
                    // Everything is good, it exited on its own
                    return
                }
            }
            // If the child hasn't exited after the timeout
            if let Ok(None) | Err(_) = child.try_status() {
                debug!("CGI process {} outlived timeout or needs to be killed \
                    due to an error, killing...", child.id());
                cfg_if::cfg_if! {
                    // On *nix systems SIGTERM the process, wait 3 seconds and
                    // if it hasn't exited SIGKILL it. On non-*nix systems,
                    // just kill it.
                    if #[cfg(unix)] {
                        if unsafe { libc::kill(child.id() as i32, libc::SIGTERM) } != 0 {
                            debug!("Sending SIGTERM was unsuccessful, sending SIGKILL");
                            if let Err(e) = child.kill().context("Error killing CGI process") {
                                error!("{:#}", e);
                            }
                        }
                        Timer::after(Duration::from_secs(3)).await;
                        if let Ok(None) | Err(_) = child.try_status() {
                            debug!("Process didn't exit after 3 seconds, sending SIGKILL");
                            if let Err(e) = child.kill().context("Error killing CGI process") {
                                error!("{:#}", e);
                            }
                        }
                    } else {
                        if let Err(e) = child.kill().context("Error killing CGI process") {
                            error!("{:#}", e);
                        }
                    }
                }
            }
        })
        .detach();
    let mut body_size = 0usize;
    // Cache the beginning of the response so that it can be logged
    let mut response_cache = Vec::new();
    let mut response_cache_size = 0usize;
    loop {
        // If the client disconnects we want to notice, leave this loop and
        // let the CGI process be killed. The problem: while the child writes
        // nothing, a plain read of its stdout would wait forever. So the read
        // is raced against a 5 second timer. When the timer wins we probe the
        // client socket. Writing 0 bytes would succeed even on a closed
        // connection, so instead we try to read 1 byte, polled exactly once
        // so that we never wait on it. Pending means the socket is probably
        // still open; a completed read of 0 bytes means the client is gone.
        let bytes_read = async { child_stdout.read(&mut buf).await.map(Some) }
            .or(async {
                Timer::after(Duration::from_secs(5)).await;
                Ok(None)
            })
            .await
            .context("Error reading CGI output")?;
        let bytes_read = match bytes_read {
            // End of the script's output.
            Some(0) => break,
            Some(v) => v,
            // The timer fired first: check whether the client is still there.
            None => {
                if let Some(ret) = poll_once(stream.read(&mut [0])).await {
                    if ret.into_io_error()? == 0 {
                        return Err(error("Client closed connection"))
                            .into_io_error();
                    }
                }
                continue;
            }
        };
        stream
            .write_all(&buf[..bytes_read])
            .await
            .context("Error writing CGI output")
            .into_io_error()?;
        body_size += bytes_read;
        // Only keep the start of the response if cgi-response-logging is
        // enabled in the config
        if CONF.cgi_resp_logging && response_cache_size <= 1027 {
            response_cache_size += bytes_read;
            response_cache.extend_from_slice(&buf[..bytes_read]);
        }
    }

    // Parse response header for logging
    let (status, meta) = if CONF.cgi_resp_logging {
        // NOTE(review): stderr is only drained here, after stdout has reached
        // EOF and only when logging is enabled — confirm that a script
        // writing a lot to stderr cannot stall on a full pipe.
        let mut error_out = String::new();
        child_stderr
            .read_to_string(&mut error_out)
            .await
            .context("Error reading child stderr")?;
        let error_out = error_out.trim();
        if !error_out.is_empty() {
            warn!("CGI script error output:\n{}", error_out);
        }
        parse_header(&response_cache).unwrap_or((0, vec![]))
    } else {
        (0, vec![])
    };

    debug!("CGI Done");
    Ok(LogInfo {
        size: body_size,
        status,
        meta,
    })
}

/// Marker trait bundling the capabilities required of an SCGI connection
/// (async read and write, `Unpin`, `Send`) under a single name, so that a TCP
/// stream and a Unix socket stream can both be stored as
/// `Box<dyn SCGIClient>`.
trait SCGIClient: AsyncRead + AsyncWrite + Unpin + Send {}
impl SCGIClient for TcpStream {}
impl SCGIClient for UnixStream {}

/// Serve a request by forwarding it to an SCGI server and relaying the
/// server's response to the client.
///
/// The request sent to the SCGI server consists of the variables from
/// `common_vars` plus `SCRIPT_NAME` and `PATH_INFO`, which depend on how the
/// route was matched:
///
/// * `Prefix(path)`: `SCRIPT_NAME` is the prefix, `PATH_INFO` is the request
///   path with that prefix removed.
/// * `Exact(path)`: `SCRIPT_NAME` is the path, `PATH_INFO` is `/`.
/// * `Regex(regex)`: `SCRIPT_NAME` is the regex's string form, `PATH_INFO` is
///   the full request path.
/// * `All`: `SCRIPT_NAME` is `/`, `PATH_INFO` is the full request path.
///
/// # Returns
/// A `LogInfo` whose `size` is the number of bytes forwarded to the client.
/// `status` and `meta` are parsed from the start of the response when
/// `CONF.cgi_resp_logging` is enabled, and are `0` / empty otherwise or when
/// the response has no parsable header.
///
/// # Errors
/// Fails if the variables cannot be built, the prefix cannot be stripped from
/// the request path, the SCGI server cannot be reached, reading from it fails
/// with anything other than a connection abort/reset, writing to the client
/// fails, or the client closes the connection while the server is silent.
pub async fn serve_scgi<'a>(
    scgi_route: &'static SCGIRoute,
    domain: &'a str,
    req: &'a Request,
    route_path: &'a RoutePath,
    remote_addr: IpAddr,
    server_port: u16,
    stream: &'a mut TlsStream<TcpStream>,
) -> Result<LogInfo> {
    let (_, session) = stream.get_ref();
    let mut vars =
        common_vars(domain, req, remote_addr, server_port, session).await?;
    match route_path {
        RoutePath::Prefix(path) => {
            vars.insert("SCRIPT_NAME", path.to_owned());
            vars.insert(
                "PATH_INFO",
                req.path
                    .strip_prefix(path)
                    .context("Routing Error")?
                    .to_owned(),
            );
        }
        RoutePath::Exact(path) => {
            vars.insert("SCRIPT_NAME", path.to_owned());
            vars.insert("PATH_INFO", "/".to_owned());
        }
        RoutePath::Regex(regex) => {
            vars.insert("SCRIPT_NAME", regex.to_string());
            vars.insert("PATH_INFO", req.path.to_owned());
        }
        RoutePath::All => {
            vars.insert("SCRIPT_NAME", "/".to_owned());
            vars.insert("PATH_INFO", req.path.to_owned());
        }
    }
    // Header block: NUL-terminated name/value pairs, starting with
    // CONTENT_LENGTH=0 and SCGI=1.
    let mut req_buf: Vec<u8> = Vec::new();
    #[allow(clippy::octal_escapes)]
    // Every `\0` below is a NUL byte followed by a literal character; hex
    // escapes were judged less readable here.
    req_buf.extend(b"CONTENT_LENGTH\00\0SCGI\01\0");
    for (key, value) in vars.iter() {
        req_buf.extend(key.as_bytes());
        req_buf.push(0);
        req_buf.extend(value.as_bytes());
        req_buf.push(0);
    }
    // Frame the header block as `<length>:<headers>,`.
    let mut req = Vec::new();
    req.extend(format!("{}:", req_buf.len()).as_bytes());
    req.extend(req_buf);
    req.push(",".as_bytes()[0]);
    // Connect over TCP or a Unix socket, whichever the route is configured
    // with; both are used through the same boxed trait object from here on.
    let mut client_stream: Box<dyn SCGIClient> = match &scgi_route.addr {
        SCGIAddress::Tcp(addr) => {
            Box::new(TcpStream::connect(addr).await.with_context(|| {
                format!("Couldn't connect to SCGI server {}", addr)
            })?)
        }
        SCGIAddress::Unix(path) => {
            Box::new(UnixStream::connect(path).await.with_context(|| {
                format!("Couldn't connect to SCGI server {}", path.display())
            })?)
        }
    };
    client_stream.write_all(&req).await?;
    let mut buf = [0u8; 4096];
    let mut body_size = 0usize;
    // Cache the beginning of the response so that it can be logged
    let mut response_cache = Vec::new();
    let mut response_cache_size = 0usize;
    loop {
        // Race the read from the SCGI server against a 5 second timer so that
        // a silent server does not keep us from noticing that the client has
        // gone away. When the timer wins, the client socket is probed with a
        // single non-waiting 1-byte read: a completed read of 0 bytes means
        // the client disconnected.
        let bytes_read = async { client_stream.read(&mut buf).await.map(Some) }
            .or(async {
                Timer::after(Duration::from_secs(5)).await;
                Ok(None)
            })
            .await;
        let bytes_read = match bytes_read {
            // End of the server's response.
            Ok(Some(0)) => break,
            Ok(Some(bytes_read)) => bytes_read,
            // The timer fired first: check whether the client is still there.
            Ok(None) => {
                if let Some(ret) = poll_once(stream.read(&mut [0])).await {
                    if ret.into_io_error()? == 0 {
                        return Err(error("Client closed connection"))
                            .into_io_error();
                    }
                }
                continue;
            }
            // A server that aborts or resets the connection is treated like
            // end of response rather than as a failure.
            Err(e) => match e.kind() {
                ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset => {
                    break
                }
                _ => return Err(e).context("Error reading SCGI output"),
            },
        };
        stream
            .write_all(&buf[..bytes_read])
            .await
            .context("Error writing SCGI output")
            .into_io_error()?;
        body_size += bytes_read;
        // Only keep the start of the response if cgi-response-logging is
        // enabled in the config
        if CONF.cgi_resp_logging && response_cache_size <= 1027 {
            response_cache_size += bytes_read;
            response_cache.extend_from_slice(&buf[..bytes_read]);
        }
    }

    // Parse response header for logging
    let (status, meta) = if CONF.cgi_resp_logging {
        parse_header(&response_cache).unwrap_or((0, vec![]))
    } else {
        (0, vec![])
    };

    Ok(LogInfo {
        size: body_size,
        status,
        meta,
    })
}

async fn common_vars(
    domain: &str,
    req: &Request,
    remote_addr: IpAddr,
    server_port: u16,
    session: &ServerConnection,
) -> Result<HashMap<&'static str, String>> {
    let tls_version = tls_version_str(
        &session
            .protocol_version()
            .context("Error reading TLS version")?,
    );
    let cipher_suite = cipher_str(
        &session
            .negotiated_cipher_suite()
            .context("Error reading cipher suite")?,
    );
    let mut vars = HashMap::new();
    vars.insert("SERVER_PROTOCOL", "GEMINI".to_owned());
    vars.insert(
        "SERVER_SOFTWARE",
        format!("stargazer/{}", env!("CARGO_PKG_VERSION")),
    );
    vars.insert("SERVER_NAME", domain.to_owned());
    vars.insert("HOSTNAME", domain.to_owned());
    vars.insert("REMOTE_HOST", remote_addr.to_string());
    vars.insert("REMOTE_ADDR", remote_addr.to_string());
    vars.insert("SERVER_PORT", server_port.to_string());
    vars.insert("GEMINI_URL", req.to_string());
    vars.insert("QUERY_STRING", req.query.clone());
    vars.insert("TLS_VERSION", tls_version);
    vars.insert("TLS_CIPHER", cipher_suite);

    if let Some(cert_raw) =
        session.peer_certificates().and_then(|list| list.get(0))
    {
        match x509_parser::parse_x509_certificate(cert_raw.as_ref())
            .context("Error parsing client cert")
        {
            Ok((_, cert)) => {
                let not_before =
                    format_datetime(cert.validity().not_before.timestamp());
                let not_after =
                    format_datetime(cert.validity().not_after.timestamp());
                let mut hasher = Sha256::new();
                hasher.update(cert_raw);
                let hash = hasher.finalize();
                let enc = base64::engine::general_purpose::STANDARD.encode(hash);
                let remote_user = cert
                    .subject()
                    .iter_common_name()
                    .next()
                    .and_then(|cn| cn.as_str().ok())
                    .unwrap_or("");

                vars.insert("AUTH_TYPE", "CERTIFICATE".to_owned());
                vars.insert("TLS_CLIENT_HASH", enc);
                vars.insert("TLS_CLIENT_NOT_BEFORE", not_before);
                vars.insert("TLS_CLIENT_NOT_AFTER", not_after);
                vars.insert("TLS_CLIENT_SUBJECT", cert.subject().to_string());
                vars.insert("TLS_CLIENT_ISSUER", cert.issuer().to_string());
                vars.insert("REMOTE_USER", remote_user.to_owned());
            }
            Err(e) => log::warn!("{:#}", e),
        }
    }

    Ok(vars)
}

/// Name of a TLS/SSL protocol version as exposed in `TLS_VERSION`.
///
/// Versions not listed here are reported as `"UNKNOWN"`.
fn tls_version_str(v: &ProtocolVersion) -> String {
    let name = match *v {
        ProtocolVersion::SSLv2 => "SSLv2",
        ProtocolVersion::SSLv3 => "SSLv3",
        ProtocolVersion::TLSv1_0 => "TLSv1.0",
        ProtocolVersion::TLSv1_1 => "TLSv1.1",
        ProtocolVersion::TLSv1_2 => "TLSv1.2",
        ProtocolVersion::TLSv1_3 => "TLSv1.3",
        _ => "UNKNOWN",
    };
    name.to_owned()
}

/// Name of a negotiated cipher suite as exposed in `TLS_CIPHER`.
///
/// Suites not listed in the table are reported as `"UNKNOWN"`.
fn cipher_str(cipher: &SupportedCipherSuite) -> String {
    use futures_rustls::rustls::cipher_suite::*;

    // Checked top to bottom; the first suite equal to `cipher` wins.
    let known: [(&SupportedCipherSuite, &str); 9] = [
        (&TLS13_AES_128_GCM_SHA256, "TLS_AES_128_GCM_SHA256"),
        (&TLS13_AES_256_GCM_SHA384, "TLS_AES_256_GCM_SHA384"),
        (
            &TLS13_CHACHA20_POLY1305_SHA256,
            "TLS_CHACHA20_POLY1305_SHA256",
        ),
        (
            &TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
            "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
        ),
        (
            &TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
            "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
        ),
        (
            &TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
            "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
        ),
        (
            &TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
            "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
        ),
        (
            &TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
            "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
        ),
        (
            &TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
            "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
        ),
    ];

    known
        .iter()
        .find(|(suite, _)| *cipher == **suite)
        .map_or("UNKNOWN", |(_, name)| *name)
        .to_string()
}

/// Format a Unix timestamp (seconds since 1970-01-01T00:00:00Z) as an
/// RFC 3339 string with second precision.
///
/// The timestamps passed in come from client certificates, so they are not
/// trustworthy. Values outside what RFC 3339 can express are clamped:
/// anything before the epoch is reported as the epoch, anything after year
/// 9999 as the last second of year 9999.
fn format_datetime(timestamp: i64) -> String {
    // 9999-12-31T23:59:59Z, the last instant with a four-digit year.
    const MAX_RFC3339_SECS: i64 = 253_402_300_799;

    // A plain `timestamp as u64` turns a negative (pre-1970) value into an
    // enormous offset, and adding that to `UNIX_EPOCH` panics on overflow.
    // After clamping the value is non-negative, so the cast is lossless.
    let secs = timestamp.clamp(0, MAX_RFC3339_SECS) as u64;
    humantime::format_rfc3339_seconds(
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs),
    )
    .to_string()
}

/// Parse response header for logging
///
/// Expects `buf` to start with a two-digit status, one separator byte, and
/// the meta text terminated by CRLF. The separator byte itself is skipped
/// without being checked.
///
/// # Returns
/// The status and the meta bytes. The returned meta starts right after the
/// separator and includes the terminating CRLF.
///
/// # Errors
/// Fails if `buf` is shorter than five bytes, the first two bytes are not a
/// number that fits in a `u8`, or no CRLF is found after the separator.
fn parse_header(buf: &[u8]) -> anyhow::Result<(u8, Vec<u8>)> {
    // Offset of the first meta byte: two status digits plus the separator.
    const META_START: usize = 3;

    if buf.len() >= 5 {
        let status = std::str::from_utf8(&buf[..2])?.parse()?;
        let meta = &buf[META_START..];

        // The terminator is a CR immediately followed by an LF. A lone CR
        // must not pair up with some LF further down the response.
        if let Some(cr) = meta.windows(2).position(|pair| pair == b"\r\n") {
            return Ok((status, meta[..cr + 2].to_vec()));
        }
    }

    Err(anyhow::anyhow!("CGI response doesn't have a header"))
}