// caxton 0.1.4
//
// A secure WebAssembly runtime for multi-agent systems
// Documentation
use crate::domain_types::{MaxExports, MaxImportFunctions};
use nutype::nutype;
#[allow(unused_imports)]
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// A validated function name (1 to 255 characters) classified as safe.
///
/// Length validation is performed by `nutype`: `try_new` rejects empty names and names
/// longer than 255 characters (counted in `char`s, not bytes). Holding a
/// `SafeFunctionName` does not by itself grant access; the security policies in this
/// module still check names against their own allow-lists.
#[nutype(
    validate(len_char_min = 1, len_char_max = 255),
    derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Display)
)]
pub struct SafeFunctionName(String);

impl SafeFunctionName {
    /// Reports whether this name belongs to the `agent_message_*` family.
    ///
    /// Returns `true` when the name begins with the `agent_message_` prefix
    /// (for example `agent_message_send`).
    pub fn is_messaging_function(&self) -> bool {
        let rendered = self.to_string();
        rendered.strip_prefix("agent_message_").is_some()
    }

    /// Reports whether this name is one of the built-in agent host functions.
    ///
    /// The recognised names are `agent_get_id`, `agent_get_timestamp`, `agent_log`,
    /// `agent_message_send` and `agent_message_receive`. The comparison is exact.
    pub fn is_standard_function(&self) -> bool {
        const STANDARD_FUNCTIONS: [&str; 5] = [
            "agent_get_id",
            "agent_get_timestamp",
            "agent_log",
            "agent_message_send",
            "agent_message_receive",
        ];

        let rendered = self.to_string();
        STANDARD_FUNCTIONS
            .iter()
            .any(|&known| known == rendered.as_str())
    }
}

/// A validated function name (1 to 255 characters) classified as unsafe.
///
/// Validation is identical to `SafeFunctionName`: `try_new` rejects empty names and
/// names longer than 255 `char`s. The strict policy in this module rejects every
/// unsafe name. The relaxed policy only allows names present in its allow-list, which
/// it populates with safe names only.
#[nutype(
    validate(len_char_min = 1, len_char_max = 255),
    derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Display)
)]
pub struct UnsafeFunctionName(String);

impl UnsafeFunctionName {
    /// Returns `true` for names in the `memory_*` family (e.g. `memory_grow`).
    pub fn is_memory_function(&self) -> bool {
        self.to_string().starts_with("memory_")
    }

    /// Returns `true` for names starting with `system_` or `process_`
    /// (e.g. `system_call`, `process_exit`).
    pub fn is_system_function(&self) -> bool {
        // `to_string` goes through `Display` and allocates, so render the name once
        // and reuse it for both prefix checks.
        let name = self.to_string();
        name.starts_with("system_") || name.starts_with("process_")
    }
}

/// A function name tagged with its safety classification.
///
/// [`FunctionName::categorize_function`] builds values of this type from raw strings.
/// The variants are public, so values can also be constructed directly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FunctionName {
    /// A name not on the hard-coded unsafe list used by `categorize_function`.
    Safe(SafeFunctionName),
    /// A name on the hard-coded unsafe list (e.g. `memory_grow`, `fd_write`, `environ_get`).
    Unsafe(UnsafeFunctionName),
}

impl FunctionName {
    /// Names that are always classified as [`FunctionName::Unsafe`].
    const UNSAFE_FUNCTIONS: &'static [&'static str] = &[
        "memory_grow",
        "memory_copy",
        "table_grow",
        "table_copy",
        "process_exit",
        "system_call",
        "fd_write",
        "fd_read",
        "environ_get",
        "environ_sizes_get",
    ];

    /// Create a function name from a string, categorizing it as safe or unsafe.
    ///
    /// Names on the hard-coded unsafe list become [`FunctionName::Unsafe`]. Every other
    /// name becomes [`FunctionName::Safe`].
    ///
    /// # Panics
    ///
    /// Panics if `name` is empty or longer than 255 characters, because neither
    /// [`SafeFunctionName`] nor [`UnsafeFunctionName`] accepts such names. Use
    /// [`FunctionName::try_categorize_function`] when `name` comes from untrusted input
    /// (for example, import names read from a module).
    pub fn categorize_function(name: &str) -> Self {
        Self::try_categorize_function(name)
            .expect("function name must be between 1 and 255 characters long")
    }

    /// Non-panicking variant of [`FunctionName::categorize_function`].
    ///
    /// Returns `None` when `name` fails length validation (empty or longer than
    /// 255 characters). Otherwise it returns the same classification as
    /// `categorize_function`.
    #[must_use]
    pub fn try_categorize_function(name: &str) -> Option<Self> {
        if Self::UNSAFE_FUNCTIONS.contains(&name) {
            return UnsafeFunctionName::try_new(name.to_string())
                .ok()
                .map(FunctionName::Unsafe);
        }

        // Both wrappers currently share the same validation, so the fallback only matters
        // if they ever diverge. It keeps such a name representable and classifies it as
        // unsafe (fail-closed).
        SafeFunctionName::try_new(name.to_string())
            .map(FunctionName::Safe)
            .ok()
            .or_else(|| {
                UnsafeFunctionName::try_new(name.to_string())
                    .ok()
                    .map(FunctionName::Unsafe)
            })
    }

    /// Returns `true` for the [`FunctionName::Safe`] variant.
    pub fn is_safe(&self) -> bool {
        matches!(self, FunctionName::Safe(_))
    }

    /// Returns the underlying name as an owned `String`.
    ///
    /// Despite the `as_` prefix, this allocates via `Display`. The signature is kept for
    /// compatibility with existing callers.
    pub fn as_str(&self) -> String {
        match self {
            FunctionName::Safe(name) => name.to_string(),
            FunctionName::Unsafe(name) => name.to_string(),
        }
    }
}

/// Security policy with a fixed, minimal allow-list.
///
/// [`StrictSecurityPolicy::new`] allows only `agent_get_id`, `agent_get_timestamp` and
/// `agent_log`. [`FunctionName::Unsafe`] names are always rejected. Networking and
/// threads are always reported as disabled, and fuel metering as enabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StrictSecurityPolicy {
    /// Import-function limit. In this file it is only checked to be non-zero (by
    /// `SecurityLevel::validate`); where the limit is enforced is not visible here.
    pub max_import_functions: MaxImportFunctions,
    /// Export limit. In this file it is only checked to be non-zero (by
    /// `SecurityLevel::validate`); where the limit is enforced is not visible here.
    pub max_exports: MaxExports,
    /// Safe function names consulted by `is_function_allowed`. The field is public, so
    /// callers can modify the list after construction.
    pub allowed_functions: Vec<SafeFunctionName>,
}

impl StrictSecurityPolicy {
    /// Build a strict policy with the given limits.
    ///
    /// The allow-list is fixed to `agent_get_id`, `agent_get_timestamp` and `agent_log`.
    ///
    /// # Panics
    ///
    /// Panics if one of the hard-coded names fails validation. That cannot happen,
    /// because each literal is non-empty and well under 255 characters.
    pub fn new(max_import_functions: MaxImportFunctions, max_exports: MaxExports) -> Self {
        const BASELINE: [&str; 3] = ["agent_get_id", "agent_get_timestamp", "agent_log"];

        Self {
            max_import_functions,
            max_exports,
            allowed_functions: BASELINE
                .iter()
                .map(|&raw| SafeFunctionName::try_new(raw.to_string()).unwrap())
                .collect(),
        }
    }

    /// Networking is never permitted under the strict policy.
    pub fn enable_networking(&self) -> bool {
        false
    }

    /// Threads are never permitted under the strict policy.
    pub fn enable_threads(&self) -> bool {
        false
    }

    /// Fuel metering is always on under the strict policy.
    pub fn enable_fuel_metering(&self) -> bool {
        true
    }

    /// Returns `true` only for a [`FunctionName::Safe`] name present in
    /// `allowed_functions`. Unsafe names are always rejected.
    pub fn is_function_allowed(&self, function: &FunctionName) -> bool {
        matches!(
            function,
            FunctionName::Safe(name) if self.allowed_functions.contains(name)
        )
    }
}

/// Security policy with a larger allow-list and optional networking/threads.
///
/// See [`RelaxedSecurityPolicy::new`] for how the allow-list is populated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RelaxedSecurityPolicy {
    /// Import-function limit. In this file it is only checked to be non-zero (by
    /// `SecurityLevel::validate`); where the limit is enforced is not visible here.
    pub max_import_functions: MaxImportFunctions,
    /// Export limit. In this file it is only checked to be non-zero (by
    /// `SecurityLevel::validate`); where the limit is enforced is not visible here.
    pub max_exports: MaxExports,
    /// Exact set of names accepted by `is_function_allowed`.
    pub allowed_functions: HashSet<FunctionName>,
    /// When `true`, `enable_fuel_metering` returns `false`.
    pub enable_threads: bool,
    /// Only read at construction time to decide whether the `network_*` names are added.
    /// Changing this field afterwards does not update `allowed_functions`.
    pub enable_networking: bool,
}

impl RelaxedSecurityPolicy {
    /// Build a relaxed policy.
    ///
    /// The allow-list always contains the five agent host functions: `agent_get_id`,
    /// `agent_get_timestamp`, `agent_log`, `agent_message_send` and
    /// `agent_message_receive`. When `enable_networking` is `true`, it also contains
    /// `network_connect`, `network_send` and `network_receive`. All entries are stored
    /// as [`FunctionName::Safe`]. `enable_threads` is stored as given; enabling it turns
    /// fuel metering off.
    ///
    /// # Panics
    ///
    /// Panics if one of the hard-coded names fails validation. That cannot happen,
    /// because each literal is non-empty and well under 255 characters.
    pub fn new(
        max_import_functions: MaxImportFunctions,
        max_exports: MaxExports,
        enable_threads: bool,
        enable_networking: bool,
    ) -> Self {
        const AGENT_FUNCTIONS: [&str; 5] = [
            "agent_get_id",
            "agent_get_timestamp",
            "agent_log",
            "agent_message_send",
            "agent_message_receive",
        ];
        const NETWORK_FUNCTIONS: [&str; 3] = ["network_connect", "network_send", "network_receive"];

        let network_extras: &[&str] = if enable_networking {
            &NETWORK_FUNCTIONS
        } else {
            &[]
        };

        let allowed_functions: HashSet<FunctionName> = AGENT_FUNCTIONS
            .iter()
            .chain(network_extras)
            .map(|&raw| FunctionName::Safe(SafeFunctionName::try_new(raw.to_string()).unwrap()))
            .collect();

        Self {
            max_import_functions,
            max_exports,
            allowed_functions,
            enable_threads,
            enable_networking,
        }
    }

    /// Fuel metering is on exactly when threads are disabled.
    pub fn enable_fuel_metering(&self) -> bool {
        !self.enable_threads
    }

    /// Returns `true` when `function` is an exact member of `allowed_functions`.
    pub fn is_function_allowed(&self, function: &FunctionName) -> bool {
        self.allowed_functions.contains(function)
    }
}

/// The active security policy: either strict or relaxed.
///
/// [`SecurityLevel::validate`] checks internal consistency, and
/// [`SecurityLevel::is_function_allowed`] forwards to the wrapped policy.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SecurityLevel {
    /// Fixed minimal allow-list; threads and networking off, fuel metering on.
    Strict(StrictSecurityPolicy),
    /// Larger allow-list with optional networking and threads.
    Relaxed(RelaxedSecurityPolicy),
}

impl SecurityLevel {
    pub fn validate(&self) -> bool {
        match self {
            SecurityLevel::Strict(policy) => {
                policy.max_import_functions.into_inner() > 0 && policy.max_exports.into_inner() > 0
            }
            SecurityLevel::Relaxed(policy) => {
                (policy.enable_fuel_metering() || !policy.enable_threads)
                    && policy.max_exports.into_inner() > 0
                    && policy.max_import_functions.into_inner() > 0
            }
        }
    }

    pub fn is_function_allowed(&self, function: &FunctionName) -> bool {
        match self {
            SecurityLevel::Strict(policy) => policy.is_function_allowed(function),
            SecurityLevel::Relaxed(policy) => policy.is_function_allowed(function),
        }
    }
}

/// A [`SecurityLevel`] that has passed [`SecurityLevel::validate`].
///
/// `level` is private, so code outside this module can only obtain a value through
/// [`ValidatedSecurityPolicy::new`] (or by cloning an existing, already-validated value).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidatedSecurityPolicy {
    level: SecurityLevel,
}

impl ValidatedSecurityPolicy {
    /// Wrap `level` if it passes [`SecurityLevel::validate`], otherwise return `None`.
    #[must_use]
    pub fn new(level: SecurityLevel) -> Option<Self> {
        // The borrow taken by `validate` ends before the closure moves `level`.
        level.validate().then(|| Self { level })
    }

    /// Borrow the validated security level.
    pub fn level(&self) -> &SecurityLevel {
        &self.level
    }
}