cc-lb-runtime-wasmtime 0.1.1

Wasmtime-based plugin runtime for cc-lb. Host-side wasm plugin admission + dispatch.
use std::sync::Arc;

use cc_lb_plugin_wire::metadata::PluginMetadata;
use cc_lb_plugin_wire::schema::{HookKind, WireVersion};
use cc_lb_plugin_wire::v1::{
    ArchivedFilterResponse, ArchivedShapeResponse, FilterRequest, Header, ObserveEvent, Principal,
    ShapeRequest, Upstream,
};
use rkyv::rancor::Error as RkyvError;
use rkyv::util::AlignedVec;
use wasmtime::InstancePre;

use crate::budget::StoreBudget;
use crate::cache::{call_filter_hook, call_observe_hook, call_shape_hook};
use crate::cell::PluginCell;
use crate::engine::HostState;
use crate::error::WasmtimeRuntimeError;

pub(crate) fn probe_hook_dispatch(
    instance_pre: Arc<InstancePre<HostState>>,
    hook: HookKind,
    wire_version: WireVersion,
    metadata: &PluginMetadata,
    memory_max_pages: u32,
) -> Result<(), WasmtimeRuntimeError> {
    let cell = Arc::new(PluginCell {
        version_id: 0,
        instance_pre,
        metadata: metadata.clone(),
        memory_max_pages,
        store_budget: Arc::new(StoreBudget::new(1)),
        plugin_name: Arc::from(metadata.name.as_str()),
        content_hash: [0; 32],
    });
    match (hook, wire_version) {
        (HookKind::Filter, WireVersion::V1) => probe_filter_v1(&cell),
        (HookKind::Shape, WireVersion::V1) => probe_shape_v1(&cell),
        (HookKind::Observe, WireVersion::V1) => probe_observe_v1(&cell),
    }
}

/// Probes the v1 filter hook of `cell` with a sample request.
///
/// The sample `FilterRequest` is serialized with rkyv and handed to
/// `call_filter_hook`. The bytes that come back are copied into an
/// `AlignedVec<16>` and must be accepted by `rkyv::access` as an
/// `ArchivedFilterResponse`.
///
/// # Errors
///
/// Any failure (encoding, the hook call itself, or decoding the response)
/// is reported through `probe_failed` tagged with `HookKind::Filter`.
fn probe_filter_v1(cell: &Arc<PluginCell>) -> Result<(), WasmtimeRuntimeError> {
    let fail = |reason: String| probe_failed(HookKind::Filter, reason);

    let request = sample_filter_request();
    let encoded = rkyv::to_bytes::<RkyvError>(&request)
        .map_err(|err| fail(format!("encode FilterRequest: {err}")))?;

    let raw = call_filter_hook(cell, encoded.as_slice()).map_err(|err| fail(err.to_string()))?;

    // Copy the hook output into an aligned buffer before handing it to rkyv.
    let mut buffer = AlignedVec::<16>::with_capacity(raw.len());
    buffer.extend_from_slice(&raw);

    if let Err(err) = rkyv::access::<ArchivedFilterResponse, RkyvError>(&buffer) {
        return Err(fail(format!("decode FilterResponse: {err}")));
    }
    Ok(())
}

/// Probes the v1 shape hook of `cell` with a sample request.
///
/// The sample `ShapeRequest` is serialized with rkyv and handed to
/// `call_shape_hook`. The bytes that come back are copied into an
/// `AlignedVec<16>` and must be accepted by `rkyv::access` as an
/// `ArchivedShapeResponse`.
///
/// # Errors
///
/// Any failure (encoding, the hook call itself, or decoding the response)
/// is reported through `probe_failed` tagged with `HookKind::Shape`.
fn probe_shape_v1(cell: &Arc<PluginCell>) -> Result<(), WasmtimeRuntimeError> {
    let fail = |reason: String| probe_failed(HookKind::Shape, reason);

    let request = sample_shape_request();
    let encoded = rkyv::to_bytes::<RkyvError>(&request)
        .map_err(|err| fail(format!("encode ShapeRequest: {err}")))?;

    let raw = call_shape_hook(cell, encoded.as_slice()).map_err(|err| fail(err.to_string()))?;

    // Copy the hook output into an aligned buffer before handing it to rkyv.
    let mut buffer = AlignedVec::<16>::with_capacity(raw.len());
    buffer.extend_from_slice(&raw);

    if let Err(err) = rkyv::access::<ArchivedShapeResponse, RkyvError>(&buffer) {
        return Err(fail(format!("decode ShapeResponse: {err}")));
    }
    Ok(())
}

/// Probes the v1 observe hook of `cell` with a sample event.
///
/// The sample `ObserveEvent` is serialized with rkyv and handed to
/// `call_observe_hook`; whatever the hook returns on success is discarded.
///
/// # Errors
///
/// An encoding failure or a failed hook call is reported through
/// `probe_failed` tagged with `HookKind::Observe`.
fn probe_observe_v1(cell: &Arc<PluginCell>) -> Result<(), WasmtimeRuntimeError> {
    let fail = |reason: String| probe_failed(HookKind::Observe, reason);

    let event = sample_observe_event();
    let encoded = rkyv::to_bytes::<RkyvError>(&event)
        .map_err(|err| fail(format!("encode ObserveEvent: {err}")))?;

    match call_observe_hook(cell, encoded.as_slice()) {
        Ok(_) => Ok(()),
        Err(err) => Err(fail(err.to_string())),
    }
}

/// Builds the fixed `FilterRequest` fed to the filter probe.
///
/// The request is a `POST` to `/v1/messages` with no query string, a single
/// `content-type: application/json` header, a small JSON body, the principal
/// from `synth_principal`, and an empty candidate list.
fn sample_filter_request() -> FilterRequest {
    let content_type = hdr("content-type", "application/json");
    let body = br#"{"model":"claude-3-haiku-20240307","messages":[]}"#
        .to_vec()
        .into_boxed_slice();

    FilterRequest {
        request_id: Box::<str>::from("probe-req-1"),
        method: Box::<str>::from("POST"),
        path: Box::<str>::from("/v1/messages"),
        query: None,
        headers: vec![content_type].into_boxed_slice(),
        body,
        principal: synth_principal(),
        candidates: Vec::new().into_boxed_slice(),
    }
}

/// Builds the fixed `ShapeRequest` fed to the shape probe.
///
/// The request is a `POST` to `/v1/messages` with no query string, a single
/// `content-type: application/json` header, a small JSON body, the principal
/// from `synth_principal`, and an `Upstream::AnthropicDirect` upstream whose
/// base URL is `https://example.test`.
fn sample_shape_request() -> ShapeRequest {
    let content_type = hdr("content-type", "application/json");
    let body = br#"{"model":"claude-3-haiku-20240307","messages":[]}"#
        .to_vec()
        .into_boxed_slice();
    let base_url = Box::<str>::from("https://example.test");

    ShapeRequest {
        request_id: Box::<str>::from("probe-req-1"),
        method: Box::<str>::from("POST"),
        path: Box::<str>::from("/v1/messages"),
        query: None,
        headers: vec![content_type].into_boxed_slice(),
        body,
        principal: synth_principal(),
        upstream: Upstream::AnthropicDirect {
            base_url: Some(base_url),
        },
    }
}

/// Builds the fixed `ObserveEvent` fed to the observe probe: a
/// `RequestStarted` event for request `probe-req-1` with the downstream
/// user agent `cc-lb-probe/1.0`.
fn sample_observe_event() -> ObserveEvent {
    let request_id = Box::<str>::from("probe-req-1");
    let user_agent = Box::<str>::from("cc-lb-probe/1.0");

    ObserveEvent::RequestStarted {
        request_id,
        downstream_user_agent: Some(user_agent),
    }
}

/// Builds the placeholder `Principal` attached to the sample requests:
/// id `probe-principal`, kind `api_key`, and no claims.
fn synth_principal() -> Principal {
    let id = Box::<str>::from("probe-principal");
    let kind = Box::<str>::from("api_key");

    Principal {
        id,
        kind,
        claims: Vec::new().into_boxed_slice(),
    }
}

/// Builds a `Header` from a name and a value.
///
/// `name` is converted to an owned `String` and stored as a boxed `str`;
/// the bytes of `value` are copied into a boxed slice.
fn hdr(name: impl Into<String>, value: impl AsRef<[u8]>) -> Header {
    let owned_name: String = name.into();
    let raw_value: &[u8] = value.as_ref();

    Header {
        name: owned_name.into_boxed_str(),
        value: Box::<[u8]>::from(raw_value),
    }
}

/// Wraps a probe failure for `hook` in `WasmtimeRuntimeError::ProbeFailed`,
/// recording the hook via `HookKind::as_str` alongside the given `reason`.
fn probe_failed(hook: HookKind, reason: String) -> WasmtimeRuntimeError {
    let hook = hook.as_str();
    WasmtimeRuntimeError::ProbeFailed { hook, reason }
}