wcgi-host 0.1.2

Utilities for implementing WCGI (WebAssembly Common Gateway Interface) support in hosts.
Documentation
//! Common abstractions for implementing a WCGI host.
//!
//! # Cargo Features
//!
//! - `schemars` - will enable JSON Schema generation for certain types using the
//!   [`schemars`](https://crates.io/crates/schemars) crate

use std::{
    collections::HashMap,
    fmt::{self, Display, Formatter},
    str::FromStr,
};

use tokio::io::AsyncBufRead;
use wasmparser::Payload;

mod cgi;
mod wcgi;

/// The CGI dialect to use when running a CGI workload.
///
/// The dialect selects which set of routines is used to prepare the guest's
/// environment variables and to parse the response header it writes to
/// stdout. Every dialect has a canonical string name, available through
/// [`CgiDialect::to_str()`].
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
// NOTE(review): the crate-level docs call the Cargo feature `schemars`, but
// the gate below checks for `schema` - confirm which name the manifest
// actually declares.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub enum CgiDialect {
    /// The "official" CGI dialect, as defined by
    /// [RFC 3875](https://www.ietf.org/rfc/rfc3875).
    ///
    /// This is the [`Default`] dialect.
    #[default]
    Rfc3875,
    /// The WCGI dialect.
    ///
    /// Its environment preparation and response-header parsing are
    /// implemented separately from the RFC 3875 ones.
    Wcgi,
}

impl CgiDialect {
    /// The name of the WebAssembly custom section that
    /// [`CgiDialect::from_wasm()`] looks for when detecting a module's dialect.
    // The `'static` lifetime is written out because eliding it in an
    // associated constant triggers rustc's
    // `elided_lifetimes_in_associated_constant` future-compatibility lint.
    pub const CUSTOM_SECTION_NAME: &'static str = "cgi-dialect";

    /// Try to identify which [`CgiDialect`] should be used based on the
    /// WebAssembly module's binary representation.
    ///
    /// Returns `None` when the binary contains no
    /// [`CgiDialect::CUSTOM_SECTION_NAME`] custom section whose contents name
    /// a known dialect.
    ///
    /// # Implementation Notes
    ///
    /// This currently works by looking through the WebAssembly binary for a
    /// custom section called [`CgiDialect::CUSTOM_SECTION_NAME`] and matching
    /// its contents against the known CGI dialect names. If several such
    /// sections exist, the first one holding a recognised name wins.
    ///
    /// This approach is a stop-gap because it requires altering the binary to
    /// include the custom section. In the future, the CGI dialect should be
    /// specified using some external mechanism like metadata.
    pub fn from_wasm(wasm: &[u8]) -> Option<CgiDialect> {
        wasmparser::Parser::new(0)
            .parse_all(wasm)
            .filter_map(|payload| match payload {
                Ok(Payload::CustomSection(custom))
                    if custom.name() == CgiDialect::CUSTOM_SECTION_NAME =>
                {
                    Some(custom.data())
                }
                // Detection is best-effort: parse errors and every other kind
                // of payload are skipped rather than reported.
                _ => None,
            })
            // A section that is not valid UTF-8, or that holds an unknown
            // name, is ignored so that a later section can still match.
            .find_map(|data| {
                std::str::from_utf8(data)
                    .ok()
                    .and_then(|name| name.parse::<CgiDialect>().ok())
            })
    }

    /// Prepare the environment variables for a request, using this dialect's
    /// conventions.
    ///
    /// The request `parts` and the `env` map are handed to the
    /// dialect-specific routine, which is expected to record the variables
    /// for the guest in `env`.
    pub fn prepare_environment_variables(
        self,
        parts: http::request::Parts,
        env: &mut HashMap<String, String>,
    ) {
        match self {
            CgiDialect::Rfc3875 => cgi::prepare_environment_variables(parts, env),
            CgiDialect::Wcgi => wcgi::prepare_environment_variables(parts, env),
        }
    }

    /// Extract the [`http::response::Parts`] from a CGI script's stdout.
    ///
    /// # Errors
    ///
    /// Returns a [`CgiError`] when the dialect-specific parser fails.
    ///
    /// # Note
    ///
    /// This might stall if reading from stdout stalls. Care should be taken to
    /// avoid waiting forever (e.g. by adding a timeout).
    pub async fn extract_response_header(
        self,
        stdout: &mut (impl AsyncBufRead + Unpin),
    ) -> Result<http::response::Parts, CgiError> {
        match self {
            CgiDialect::Rfc3875 => cgi::extract_response_header(stdout).await,
            CgiDialect::Wcgi => wcgi::extract_response_header(stdout).await,
        }
    }

    /// The canonical string name of this dialect.
    ///
    /// These are the same names that the [`FromStr`] implementation accepts.
    pub const fn to_str(self) -> &'static str {
        match self {
            CgiDialect::Rfc3875 => "rfc-3875",
            CgiDialect::Wcgi => "wcgi",
        }
    }
}

/// Parses the canonical dialect names `"rfc-3875"` and `"wcgi"`.
///
/// Matching is exact: any other string (including different casing or
/// surrounding whitespace) is rejected with [`UnknownCgiDialect`].
impl FromStr for CgiDialect {
    type Err = UnknownCgiDialect;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let dialect = match s {
            "rfc-3875" => Self::Rfc3875,
            "wcgi" => Self::Wcgi,
            _ => return Err(UnknownCgiDialect),
        };

        Ok(dialect)
    }
}

/// Formats the dialect as its canonical name, as given by `to_str()`.
impl Display for CgiDialect {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Written verbatim; width/fill flags on the formatter are not applied.
        f.write_str(self.to_str())
    }
}

/// The error produced when a string does not name a known CGI dialect.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownCgiDialect;

// There is no underlying cause, so the default `source()` is sufficient.
impl std::error::Error for UnknownCgiDialect {}

impl Display for UnknownCgiDialect {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Unknown CGI dialect")
    }
}

/// The error type returned by [`CgiDialect::extract_response_header()`].
#[derive(Debug)]
pub enum CgiError {
    /// The STDOUT pipe could not be read; wraps the underlying I/O error.
    StdoutRead(std::io::Error),
    /// A header could not be parsed.
    InvalidHeaders {
        /// The underlying error, exposed as this error's source.
        error: http::Error,
        /// The `header` part of the `header: value` pair shown in the error
        /// message (presumably the header's name - verify against the parser).
        header: String,
        /// The `value` part of the `header: value` pair shown in the error
        /// message.
        value: String,
    },
    /// A WCGI header could not be parsed.
    MalformedWcgiHeader {
        /// The underlying WCGI error.
        error: ::wcgi::WcgiError,
        /// The offending header text, as shown in the error message.
        header: String,
    },
}

impl std::error::Error for CgiError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            CgiError::StdoutRead(e) => Some(e),
            CgiError::InvalidHeaders { error, .. } => Some(error),
            CgiError::MalformedWcgiHeader { error, .. } => error.source(),
        }
    }
}

impl fmt::Display for CgiError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            CgiError::StdoutRead(_) => write!(f, "Unable to read the STDOUT pipe"),
            CgiError::InvalidHeaders { header, value, .. } => {
                write!(f, "Unable to parse header ({header}: {value})")
            }
            CgiError::MalformedWcgiHeader { header, .. } => {
                write!(f, "Unable to parse WCGI header ({header})")
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Formatting a dialect and parsing the resulting text must give back the
    /// dialect we started with.
    #[test]
    fn round_trip_cgi_dialect_to_string() {
        for original in [CgiDialect::Rfc3875, CgiDialect::Wcgi] {
            let text = original.to_string();
            let parsed = text.parse::<CgiDialect>().unwrap();

            assert_eq!(parsed, original);
        }
    }
}