patina-ai 0.23.0

Context orchestration for AI development - captures and evolves patterns over time
Documentation
//! Shared host trait logic — single implementation for all worlds.
//!
//! Each WASM world generates separate Rust types via bindgen, but the
//! logic behind host function implementations is identical. This module
//! centralizes that logic so security-sensitive changes happen in one
//! place, not 3-4x across worlds.
//!
//! F1 fix: eliminates ~700 lines of duplicated host trait logic.
//! F2 fix: path traversal protection in count_layer_files.

use std::path::PathBuf;

use super::command::QueryDispatchFn;
use super::{GrantedCapabilities, QueryScope};

// =========================================================================
// Log host support
// =========================================================================

/// Emit a plugin log message on the host's stderr.
///
/// The line has the shape `[plugin:<plugin_name>] <level_str>: <message>`.
/// `level_str` is printed verbatim; no level filtering happens here.
pub(super) fn log(plugin_name: &str, level_str: &str, message: &str) {
    // Build the full line first so it is handed to stderr in a single write.
    let line = format!("[plugin:{}] {}: {}", plugin_name, level_str, message);
    eprintln!("{}", line);
}

// =========================================================================
// Layer host support
// =========================================================================

/// Return the project root as a string, or `None` when no root is set.
///
/// Non-UTF-8 path segments are replaced lossily (see
/// `Path::to_string_lossy`), so the result is always valid UTF-8.
pub(super) fn find_project_root(project_root: &Option<PathBuf>) -> Option<String> {
    project_root
        .as_deref()
        .map(|root| root.to_string_lossy().into_owned())
}

/// Load the project configuration and return it serialized as JSON.
///
/// # Errors
///
/// Returns a string error when no project root is set, when
/// `crate::project::load_with_migration` fails, or when the loaded
/// configuration cannot be serialized.
pub(super) fn read_config(project_root: &Option<PathBuf>) -> Result<String, String> {
    let Some(root) = project_root.as_ref() else {
        return Err("no project root".to_string());
    };
    match crate::project::load_with_migration(root) {
        Ok(config) => {
            serde_json::to_string(&config).map_err(|e| format!("serialize config: {}", e))
        }
        Err(e) => Err(format!("load config: {}", e)),
    }
}

/// Detect the host environment and return it serialized as JSON.
///
/// # Errors
///
/// Returns a string error when `Environment::detect` fails or when the
/// detected environment cannot be serialized.
pub(super) fn detect_environment() -> Result<String, String> {
    match crate::environment::Environment::detect() {
        Ok(env) => serde_json::to_string(&env).map_err(|e| format!("serialize env: {}", e)),
        Err(e) => Err(format!("detect env: {}", e)),
    }
}

/// Return the `detected_tools` list stored in the project configuration.
///
/// Best-effort: yields an empty list when there is no project root, when
/// the configuration fails to load, or when it has no `environment`
/// section. Load errors are intentionally discarded.
pub(super) fn get_stored_tools(project_root: &Option<PathBuf>) -> Vec<String> {
    project_root
        .as_ref()
        .and_then(|root| crate::project::load_with_migration(root).ok())
        .and_then(|config| config.environment)
        .map(|env| env.detected_tools)
        .unwrap_or_default()
}

/// Count `.md` files directly inside `<project_root>/layer/<subdir>`.
///
/// Only the immediate directory entries are inspected (no recursion), and
/// an entry counts when its extension is exactly `md`.
///
/// F2 fix: `subdir` comes from a plugin and is treated as untrusted. It is
/// accepted only if every path component is a plain name or `.`. Parent
/// references (`..`), a root (`/…`) and — on Windows — drive or UNC
/// prefixes (`C:foo`, `\\server\share`) are rejected, because
/// `Path::join` would let any of them escape the `layer` directory.
///
/// Returns 0 when there is no project root, when `subdir` is rejected, or
/// when the directory cannot be read. Rejection is silent on purpose: no
/// error message that confirms a traversal was attempted. The count
/// saturates at `u32::MAX` instead of wrapping.
///
/// Note: symlinks are not resolved here; the check is purely lexical.
pub(super) fn count_layer_files(project_root: &Option<PathBuf>, subdir: &str) -> u32 {
    use std::path::{Component, Path};

    let Some(root) = project_root.as_ref() else {
        return 0;
    };

    // Allowlist rather than denylist: anything that is not a plain segment
    // or `.` (i.e. `..`, a root, or a Windows prefix) is refused.
    let sub = Path::new(subdir);
    let stays_inside_layer = sub
        .components()
        .all(|c| matches!(c, Component::Normal(_) | Component::CurDir));
    if !stays_inside_layer {
        return 0; // silent reject — no information leak
    }

    let Ok(entries) = std::fs::read_dir(root.join("layer").join(sub)) else {
        return 0;
    };
    let md_files = entries
        .filter_map(Result::ok)
        .filter(|entry| entry.path().extension().is_some_and(|ext| ext == "md"))
        .count();

    // `usize` may be wider than `u32`; saturate instead of truncating.
    u32::try_from(md_files).unwrap_or(u32::MAX)
}

/// Look up the project's UID via `crate::project::get_uid`.
///
/// Returns `None` when no project root is set; otherwise returns whatever
/// `get_uid` reports for that root.
pub(super) fn get_project_uid(project_root: &Option<PathBuf>) -> Option<String> {
    match project_root {
        Some(root) => crate::project::get_uid(root),
        None => None,
    }
}

/// Ask the named adapter whether an update is available for this project.
///
/// Resolves the adapter with `crate::adapters::get_adapter` and calls its
/// `check_for_updates` on the project root. When the adapter reports a
/// pair, only the first element (bound as `current` here) is returned; the
/// second element is dropped.
///
/// # Errors
///
/// Returns a string error when no project root is set or when the
/// adapter's check fails.
pub(super) fn check_adapter_version(
    project_root: &Option<PathBuf>,
    adapter_name: &str,
) -> Result<Option<String>, String> {
    let Some(root) = project_root.as_ref() else {
        return Err("no project root".to_string());
    };
    let adapter = crate::adapters::get_adapter(adapter_name);
    match adapter.check_for_updates(root) {
        Ok(update) => Ok(update.map(|(current, _)| current)),
        Err(e) => Err(format!("adapter check: {}", e)),
    }
}

// =========================================================================
// Query host support
// =========================================================================

/// Top-level keys in query params that are host-controlled and must not
/// be set by plugins.
///
/// `sanitize_query_params` removes every key listed here from the params
/// object before dispatch unless the plugin's scope is
/// `QueryScope::AllRepos`. Only top-level object keys are affected.
const SCOPE_RESERVED_KEYS: &[&str] = &["all_repos", "repo", "project_root", "db_path"];

/// Strip scope-reserved keys from a plugin's JSON query params.
///
/// Called before dispatching to the binary callback; kept free of
/// wasmtime types so it can be tested on its own.
///
/// Behavior by case:
/// - `QueryScope::AllRepos`: `params` is returned unchanged.
/// - `params` is not valid JSON: `params` is returned unchanged.
/// - otherwise: every key in `SCOPE_RESERVED_KEYS` is removed from the
///   top-level object (non-object JSON has nothing to remove) and the
///   value is re-serialized.
pub(super) fn sanitize_query_params(params: &str, scope: &QueryScope) -> String {
    // AllRepos scope: params pass through unmodified.
    if matches!(scope, QueryScope::AllRepos) {
        return params.to_string();
    }

    let Ok(mut parsed) = serde_json::from_str::<serde_json::Value>(params) else {
        return params.to_string();
    };

    // Any narrower scope: drop all host-controlled keys.
    if let Some(fields) = parsed.as_object_mut() {
        SCOPE_RESERVED_KEYS.iter().for_each(|reserved| {
            fields.remove(*reserved);
        });
    }

    match serde_json::to_string(&parsed) {
        Ok(clean) => clean,
        Err(_) => params.to_string(),
    }
}

/// Dispatch a plugin query after enforcing the plugin's capability grants.
///
/// Checks run in this order:
/// 1. `kind` must be in `grants.query_kinds` (call-time gate; per the
///    module's design, kinds are also validated at load time by
///    `check_capabilities`).
/// 2. If `params` is JSON with a boolean `all_repos: true`, the plugin's
///    scope must be `QueryScope::AllRepos`; a permitted use is written to
///    stderr as an audit line.
/// 3. `params` is passed through `sanitize_query_params` so reserved keys
///    cannot reach the callback under a narrower scope.
/// 4. The sanitized params are handed to the binary-provided `query_fn`.
///
/// # Errors
///
/// Returns a string error when the kind is not granted, when `all_repos`
/// is requested without `AllRepos` scope, when no dispatch function is
/// configured, or whatever error the dispatch function itself returns.
pub(super) fn query(
    plugin_name: &str,
    grants: &GrantedCapabilities,
    query_fn: &mut Option<QueryDispatchFn>,
    kind: &str,
    params: &str,
) -> Result<String, String> {
    // Call-time gating: kind must be in the granted set.
    let kind_granted = grants.query_kinds.contains(kind);
    if !kind_granted {
        return Err(format!(
            "query kind '{}' not granted for plugin '{}'",
            kind, plugin_name
        ));
    }

    // Scope enforcement: only a literal boolean `true` counts as a request
    // for all repos; unparseable params or other value types count as false.
    let wants_all_repos = serde_json::from_str::<serde_json::Value>(params)
        .ok()
        .and_then(|args| args.get("all_repos").and_then(|flag| flag.as_bool()))
        .unwrap_or(false);
    if wants_all_repos {
        if !matches!(grants.query_scope, QueryScope::AllRepos) {
            return Err("all_repos not allowed: plugin query_scope is current_project".to_string());
        }
        eprintln!("[plugin:{}] query: all_repos=true (audit)", plugin_name);
    }

    // Sanitize: strip scope-reserved keys so the callback can't bypass policy.
    let sanitized_params = sanitize_query_params(params, &grants.query_scope);

    // Delegate to the binary-provided dispatch function.
    match query_fn.as_mut() {
        Some(dispatch) => dispatch(kind, &sanitized_params),
        None => Err("query dispatch not configured".to_string()),
    }
}

// =========================================================================
// HTTP host support
// =========================================================================

/// Build an HTTP client with a restrictive redirect policy.
///
/// Shared by mother-child (instantiate_child) and task (run_task) engines.
///
/// The policy, evaluated on every redirect:
/// - **Hop limit:** a chain longer than `MAX_REDIRECTS` fails the request
///   with an error. A custom reqwest policy does not get the default
///   policy's limit, so without this a same-host redirect loop would
///   never terminate.
/// - **Same host only:** if the target host differs from the host of the
///   URL that issued the redirect, the redirect is not followed and the
///   3xx response is returned as-is (prevents allowlist bypass via open
///   redirectors).
/// - **No HTTPS downgrade:** a redirect from an `https` URL to any other
///   scheme is not followed, so the HTTPS-only rule enforced by
///   `validate_http_url` cannot be undone by the server.
///
/// # Errors
///
/// Returns an error when the underlying reqwest client cannot be built.
pub(super) fn build_http_client() -> anyhow::Result<reqwest::blocking::Client> {
    // Matches the hop limit of reqwest's default redirect policy.
    const MAX_REDIRECTS: usize = 10;

    let policy = reqwest::redirect::Policy::custom(|attempt| {
        // `previous()` holds the initial URL plus every hop taken so far.
        if attempt.previous().len() > MAX_REDIRECTS {
            return attempt.error("too many redirects");
        }

        let origin = attempt.previous().last();
        let same_host = attempt.url().host_str() == origin.and_then(|u| u.host_str());
        let downgraded = origin.is_some_and(|u| u.scheme() == "https")
            && attempt.url().scheme() != "https";

        if same_host && !downgraded {
            attempt.follow()
        } else {
            attempt.stop()
        }
    });

    reqwest::blocking::Client::builder()
        .redirect(policy)
        .build()
        .map_err(|e| anyhow::anyhow!("build HTTP client: {}", e))
}

/// Validate and parse an HTTP URL for domain-allowlisted access.
///
/// Returns the extracted host on success (IPv6 brackets, if any, removed;
/// otherwise exactly as the URL parser reports it). Enforces:
/// - HTTPS only (no plaintext HTTP)
/// - No localhost — including the fully-qualified form `localhost.` and
///   any name under the reserved `.localhost` domain
/// - No IP addresses (IPv4 or IPv6)
///
/// Pure function — testable independently of wasmtime.
///
/// # Errors
///
/// Returns a string error describing which rule the URL violated, or the
/// parser's message when the URL is malformed.
pub(super) fn validate_http_url(url: &str) -> Result<String, String> {
    let parsed = reqwest::Url::parse(url).map_err(|e| format!("invalid URL: {}", e))?;

    // HTTPS only
    if parsed.scheme() != "https" {
        return Err(format!("only HTTPS allowed, got '{}'", parsed.scheme()));
    }

    let host = parsed
        .host_str()
        .ok_or_else(|| "no host in URL".to_string())?;

    // No localhost. Compare without the optional trailing root dot so that
    // `localhost.` is caught, and treat `*.localhost` the same way since
    // that whole domain is reserved for loopback use.
    let dns_name = host.strip_suffix('.').unwrap_or(host);
    if dns_name == "localhost" || dns_name.ends_with(".localhost") {
        return Err("localhost not allowed".to_string());
    }

    // No IP addresses (IPv4 or IPv6)
    // host_str() returns brackets for IPv6 (e.g., "[::1]") — strip them
    let bare_host = host
        .strip_prefix('[')
        .and_then(|h| h.strip_suffix(']'))
        .unwrap_or(host);
    if bare_host.parse::<std::net::IpAddr>().is_ok() {
        return Err("IP addresses not allowed".to_string());
    }

    Ok(bare_host.to_string())
}

/// Result of an HTTP operation — plain types for cross-world portability.
///
/// Produced by `http_get` and `http_post`. A non-2xx status is still
/// reported here as a normal result, not as an error.
pub(super) struct HttpResult {
    /// Numeric HTTP status code of the response.
    pub status: u16,
    /// Response body decoded as text.
    pub body: String,
}

/// Perform an HTTP POST on behalf of a plugin, restricted to its allowlist.
///
/// Defense in depth: per the module's design, domains are validated at
/// load time (check_capabilities) and again here against
/// `grants.http_domains`. The URL itself must pass `validate_http_url`,
/// and redirect handling is whatever policy `http_client` was built with.
///
/// `body` is sent as the request body with `content_type` as the
/// `Content-Type` header. The response status and text body are returned
/// regardless of whether the status indicates success.
///
/// # Errors
///
/// Returns a string error when the URL is rejected, the domain is not in
/// the plugin's allowlist, the request fails, or the response body cannot
/// be read.
pub(super) fn http_post(
    http_client: &reqwest::blocking::Client,
    grants: &GrantedCapabilities,
    plugin_name: &str,
    url: &str,
    body: &str,
    content_type: &str,
) -> Result<HttpResult, String> {
    let domain = validate_http_url(url)?;
    let allowed = grants.http_domains.contains(&domain);
    if !allowed {
        return Err(format!(
            "domain '{}' not in allowlist for plugin '{}'",
            domain, plugin_name
        ));
    }

    let request = http_client
        .post(url)
        .header("Content-Type", content_type)
        .body(body.to_string());
    let reply = request
        .send()
        .map_err(|e| format!("HTTP POST failed: {}", e))?;

    // Capture the status first: reading the text consumes the response.
    let status = reply.status().as_u16();
    reply
        .text()
        .map(|text| HttpResult { status, body: text })
        .map_err(|e| format!("read body: {}", e))
}

/// Perform an HTTP GET on behalf of a plugin, restricted to its allowlist.
///
/// The URL must pass `validate_http_url` and its host must be present in
/// `grants.http_domains`. The response status and text body are returned
/// regardless of whether the status indicates success.
///
/// # Errors
///
/// Returns a string error when the URL is rejected, the domain is not in
/// the plugin's allowlist, the request fails, or the response body cannot
/// be read.
pub(super) fn http_get(
    http_client: &reqwest::blocking::Client,
    grants: &GrantedCapabilities,
    plugin_name: &str,
    url: &str,
) -> Result<HttpResult, String> {
    let domain = validate_http_url(url)?;
    let allowed = grants.http_domains.contains(&domain);
    if !allowed {
        return Err(format!(
            "domain '{}' not in allowlist for plugin '{}'",
            domain, plugin_name
        ));
    }

    let reply = http_client
        .get(url)
        .send()
        .map_err(|e| format!("HTTP GET failed: {}", e))?;

    // Capture the status first: reading the text consumes the response.
    let status = reply.status().as_u16();
    reply
        .text()
        .map(|text| HttpResult { status, body: text })
        .map_err(|e| format!("read body: {}", e))
}