cc-lb-runtime-protocol 0.1.0

cc-lb plugin protocol runtime — handshake, self-check, dispatch, identity, and host functions for Extism plugins targeting the cc-lb host.
Documentation
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{SystemTime, UNIX_EPOCH};

use extism::{Function, PTR, UserData, ValType, host_fn};
use ring::rand::{SecureRandom, SystemRandom};

/// Largest number of random bytes a plugin may request in a single
/// `cc_lb_random_bytes` call; requests above this cap are rejected with an error.
const RANDOM_BYTES_MAX: u64 = 65_536;

/// In-memory key/value store backing the `cc_lb_storage_get` and
/// `cc_lb_storage_put` host functions.
///
/// Every plugin's entries live in one map; an entry is attributed to a plugin by
/// the `plugin_name:` prefix that `scoped_key` puts in front of the plugin-supplied
/// key.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct HostState {
    /// Scoped key (`plugin_name:key`) mapped to the stored bytes, guarded by a mutex.
    storage: Mutex<HashMap<String, Vec<u8>>>,
}

impl HostState {
    pub fn new() -> Self {
        Self::default()
    }

    fn get(&self, plugin_name: &str, key: &str) -> Result<Vec<u8>, extism::Error> {
        let scoped = scoped_key(plugin_name, key);
        let storage = self
            .storage
            .lock()
            .map_err(|_| extism::Error::msg("storage lock poisoned"))?;
        Ok(storage.get(&scoped).cloned().unwrap_or_default())
    }

    fn put(
        &self,
        plugin_name: &str,
        key: &str,
        value: Vec<u8>,
        quota_bytes: usize,
    ) -> Result<(), extism::Error> {
        let scoped = scoped_key(plugin_name, key);
        let mut storage = self
            .storage
            .lock()
            .map_err(|_| extism::Error::msg("storage lock poisoned"))?;
        let existing_entry = storage
            .get(&scoped)
            .map_or(0, |stored_value| scoped.len() + stored_value.len());
        let used = storage
            .iter()
            .filter(|(stored_key, _)| {
                stored_key.starts_with(plugin_name)
                    && stored_key.as_bytes().get(plugin_name.len()) == Some(&b':')
            })
            .map(|(stored_key, stored_value)| stored_key.len() + stored_value.len())
            .sum::<usize>();
        let next = used
            .saturating_sub(existing_entry)
            .saturating_add(scoped.len().saturating_add(value.len()));
        if next > quota_bytes {
            return Err(extism::Error::msg("scoped storage quota exceeded"));
        }
        storage.insert(scoped, value);
        Ok(())
    }
}

/// Per-plugin data carried as Extism user data into every host function
/// registered by `functions`.
#[non_exhaustive]
#[derive(Clone)]
pub struct HostFunctionContext {
    /// Name used to tag logs and metrics and to scope storage keys.
    plugin_name: String,
    /// Quota, in bytes, passed to `HostState::put` on every storage write.
    storage_quota_bytes: usize,
    /// Handle to the storage used by the storage host functions.
    state: Arc<HostState>,
}

impl HostFunctionContext {
    /// Builds the context for one plugin.
    ///
    /// * `plugin_name` — name attached to the plugin's logs and metrics and used as
    ///   its storage key prefix.
    /// * `storage_quota_bytes` — quota enforced on each `cc_lb_storage_put` call.
    /// * `state` — storage handle the plugin's storage host functions operate on.
    pub fn new(plugin_name: String, storage_quota_bytes: usize, state: Arc<HostState>) -> Self {
        Self {
            plugin_name,
            storage_quota_bytes,
            state,
        }
    }
}

/// Builds the host functions exposed to a plugin.
///
/// `context` is wrapped in a single `UserData`; each function receives a clone of
/// that handle. The returned vector contains, in order: `cc_lb_log`,
/// `cc_lb_metric_incr`, `cc_lb_storage_get`, `cc_lb_storage_put`,
/// `cc_lb_now_unix_ms` and `cc_lb_random_bytes`.
pub fn functions(context: HostFunctionContext) -> Vec<Function> {
    let shared = UserData::new(context);

    let log = Function::new("cc_lb_log", [PTR, PTR], [], shared.clone(), cc_lb_log);
    let metric_incr = Function::new(
        "cc_lb_metric_incr",
        [PTR, PTR],
        [],
        shared.clone(),
        cc_lb_metric_incr,
    );
    let storage_get = Function::new(
        "cc_lb_storage_get",
        [PTR],
        [PTR],
        shared.clone(),
        cc_lb_storage_get,
    );
    let storage_put = Function::new(
        "cc_lb_storage_put",
        [PTR, PTR],
        [],
        shared.clone(),
        cc_lb_storage_put,
    );
    let now_unix_ms = Function::new(
        "cc_lb_now_unix_ms",
        [],
        [ValType::I64],
        shared.clone(),
        cc_lb_now_unix_ms,
    );
    // The last registration takes ownership of the handle instead of cloning it.
    let random_bytes = Function::new(
        "cc_lb_random_bytes",
        [PTR],
        [PTR],
        shared,
        cc_lb_random_bytes,
    );

    vec![
        log,
        metric_incr,
        storage_get,
        storage_put,
        now_unix_ms,
        random_bytes,
    ]
}

// Host function `cc_lb_log`: emits `msg` as a tracing event tagged with the calling
// plugin's name. `level` selects the severity ("trace", "debug", "warn", "error");
// any other value, including "info", is logged at info level.
host_fn!(pub cc_lb_log(data: HostFunctionContext; level: String, msg: String) {
    let HostFunctionContext { plugin_name: plugin, .. } = context_from_user_data(&data)?;
    match level.as_str() {
        "error" => tracing::error!(plugin = %plugin, message = %msg),
        "warn" => tracing::warn!(plugin = %plugin, message = %msg),
        "debug" => tracing::debug!(plugin = %plugin, message = %msg),
        "trace" => tracing::trace!(plugin = %plugin, message = %msg),
        _ => tracing::info!(plugin = %plugin, message = %msg),
    }
    Ok(())
});

// Host function `cc_lb_metric_incr`: adds `value` to the counter called `name`,
// labelled with the calling plugin's name under the "plugin" key.
//
// `name` and the context are both owned here and not used afterwards, so they are
// moved into the metric key rather than cloned.
host_fn!(pub cc_lb_metric_incr(data: HostFunctionContext; name: String, value: u64) {
    let context = context_from_user_data(&data)?;
    metrics::counter!(name, "plugin" => context.plugin_name).increment(value);
    Ok(())
});

// Host function `cc_lb_storage_get`: returns the bytes the calling plugin stored
// under `key`, or an empty vector when nothing is stored there.
host_fn!(pub cc_lb_storage_get(data: HostFunctionContext; key: String) -> Vec<u8> {
    let HostFunctionContext { plugin_name, state, .. } = context_from_user_data(&data)?;
    state.get(&plugin_name, &key)
});

// Host function `cc_lb_storage_put`: stores `value` under `key` for the calling
// plugin, subject to the storage quota carried in the plugin's context. Fails when
// the quota would be exceeded.
host_fn!(pub cc_lb_storage_put(data: HostFunctionContext; key: String, value: Vec<u8>) {
    let HostFunctionContext {
        plugin_name,
        storage_quota_bytes,
        state,
    } = context_from_user_data(&data)?;
    state.put(&plugin_name, &key, value, storage_quota_bytes)
});

// Host function `cc_lb_now_unix_ms`: milliseconds elapsed since the Unix epoch
// according to the host's system clock. A clock set before the epoch reports 0; a
// value too large for `u64` is clamped to `u64::MAX`.
host_fn!(pub cc_lb_now_unix_ms(_data: HostFunctionContext; ) -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let millis = u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX);
    Ok(millis)
});

// Host function `cc_lb_random_bytes`: returns `n` bytes from the system's secure
// random source. Requests larger than `RANDOM_BYTES_MAX` are rejected before any
// buffer is allocated.
host_fn!(pub cc_lb_random_bytes(_data: HostFunctionContext; n: u64) -> Vec<u8> {
    if n <= RANDOM_BYTES_MAX {
        // The cast cannot truncate: `n` is at most `RANDOM_BYTES_MAX` here.
        let mut buffer = vec![0_u8; n as usize];
        match SystemRandom::new().fill(&mut buffer) {
            Ok(()) => Ok(buffer),
            Err(_) => Err(extism::Error::msg("secure random generation failed")),
        }
    } else {
        Err(extism::Error::msg("random byte request exceeds host cap"))
    }
});

fn context_from_user_data(
    data: &UserData<HostFunctionContext>,
) -> Result<HostFunctionContext, extism::Error> {
    let data = data.get()?;
    data.lock()
        .map(|context| context.clone())
        .map_err(|_| extism::Error::msg("host function context lock poisoned"))
}

/// Builds the storage key for `key` in `plugin_name`'s scope: the plugin name, a
/// `:` separator, then the key.
fn scoped_key(plugin_name: &str, key: &str) -> String {
    [plugin_name, key].join(":")
}