Skip to main content

cc_lb_runtime_protocol/
host_functions.rs

1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use extism::{Function, PTR, UserData, ValType, host_fn};
6use ring::rand::{SecureRandom, SystemRandom};
7
/// Largest number of bytes a single `cc_lb_random_bytes` call may request (64 KiB).
const RANDOM_BYTES_MAX: u64 = 64 * 1024;

10#[non_exhaustive]
11#[derive(Debug, Default)]
12pub struct HostState {
13    storage: Mutex<HashMap<String, Vec<u8>>>,
14}
15
16impl HostState {
17    pub fn new() -> Self {
18        Self::default()
19    }
20
21    fn get(&self, plugin_name: &str, key: &str) -> Result<Vec<u8>, extism::Error> {
22        let scoped = scoped_key(plugin_name, key);
23        let storage = self
24            .storage
25            .lock()
26            .map_err(|_| extism::Error::msg("storage lock poisoned"))?;
27        Ok(storage.get(&scoped).cloned().unwrap_or_default())
28    }
29
30    fn put(
31        &self,
32        plugin_name: &str,
33        key: &str,
34        value: Vec<u8>,
35        quota_bytes: usize,
36    ) -> Result<(), extism::Error> {
37        let scoped = scoped_key(plugin_name, key);
38        let mut storage = self
39            .storage
40            .lock()
41            .map_err(|_| extism::Error::msg("storage lock poisoned"))?;
42        let existing_entry = storage
43            .get(&scoped)
44            .map_or(0, |stored_value| scoped.len() + stored_value.len());
45        let used = storage
46            .iter()
47            .filter(|(stored_key, _)| {
48                stored_key.starts_with(plugin_name)
49                    && stored_key.as_bytes().get(plugin_name.len()) == Some(&b':')
50            })
51            .map(|(stored_key, stored_value)| stored_key.len() + stored_value.len())
52            .sum::<usize>();
53        let next = used
54            .saturating_sub(existing_entry)
55            .saturating_add(scoped.len().saturating_add(value.len()));
56        if next > quota_bytes {
57            return Err(extism::Error::msg("scoped storage quota exceeded"));
58        }
59        storage.insert(scoped, value);
60        Ok(())
61    }
62}
63
64#[non_exhaustive]
65#[derive(Clone)]
66pub struct HostFunctionContext {
67    plugin_name: String,
68    storage_quota_bytes: usize,
69    state: Arc<HostState>,
70}
71
72impl HostFunctionContext {
73    pub fn new(plugin_name: String, storage_quota_bytes: usize, state: Arc<HostState>) -> Self {
74        Self {
75            plugin_name,
76            storage_quota_bytes,
77            state,
78        }
79    }
80}
81
82pub fn functions(context: HostFunctionContext) -> Vec<Function> {
83    let user_data = UserData::new(context);
84    vec![
85        Function::new("cc_lb_log", [PTR, PTR], [], user_data.clone(), cc_lb_log),
86        Function::new(
87            "cc_lb_metric_incr",
88            [PTR, PTR],
89            [],
90            user_data.clone(),
91            cc_lb_metric_incr,
92        ),
93        Function::new(
94            "cc_lb_storage_get",
95            [PTR],
96            [PTR],
97            user_data.clone(),
98            cc_lb_storage_get,
99        ),
100        Function::new(
101            "cc_lb_storage_put",
102            [PTR, PTR],
103            [],
104            user_data.clone(),
105            cc_lb_storage_put,
106        ),
107        Function::new(
108            "cc_lb_now_unix_ms",
109            [],
110            [ValType::I64],
111            user_data.clone(),
112            cc_lb_now_unix_ms,
113        ),
114        Function::new(
115            "cc_lb_random_bytes",
116            [PTR],
117            [PTR],
118            user_data,
119            cc_lb_random_bytes,
120        ),
121    ]
122}
123
124host_fn!(pub cc_lb_log(data: HostFunctionContext; level: String, msg: String) {
125    let context = context_from_user_data(&data)?;
126    match level.as_str() {
127        "trace" => tracing::trace!(plugin = %context.plugin_name, message = %msg),
128        "debug" => tracing::debug!(plugin = %context.plugin_name, message = %msg),
129        "warn" => tracing::warn!(plugin = %context.plugin_name, message = %msg),
130        "error" => tracing::error!(plugin = %context.plugin_name, message = %msg),
131        _ => tracing::info!(plugin = %context.plugin_name, message = %msg),
132    }
133    Ok(())
134});
135
136host_fn!(pub cc_lb_metric_incr(data: HostFunctionContext; name: String, value: u64) {
137    let context = context_from_user_data(&data)?;
138    metrics::counter!(name.clone(), "plugin" => context.plugin_name.clone()).increment(value);
139    Ok(())
140});
141
142host_fn!(pub cc_lb_storage_get(data: HostFunctionContext; key: String) -> Vec<u8> {
143    let context = context_from_user_data(&data)?;
144    context.state.get(&context.plugin_name, &key)
145});
146
147host_fn!(pub cc_lb_storage_put(data: HostFunctionContext; key: String, value: Vec<u8>) {
148    let context = context_from_user_data(&data)?;
149    context
150        .state
151        .put(&context.plugin_name, &key, value, context.storage_quota_bytes)?;
152    Ok(())
153});
154
155host_fn!(pub cc_lb_now_unix_ms(_data: HostFunctionContext; ) -> u64 {
156    let now = SystemTime::now()
157        .duration_since(UNIX_EPOCH)
158        .unwrap_or_default()
159        .as_millis()
160        .try_into()
161        .unwrap_or(u64::MAX);
162    Ok(now)
163});
164
165host_fn!(pub cc_lb_random_bytes(_data: HostFunctionContext; n: u64) -> Vec<u8> {
166    if n > RANDOM_BYTES_MAX {
167        return Err(extism::Error::msg("random byte request exceeds host cap"));
168    }
169    let mut bytes = vec![0_u8; n as usize];
170    SystemRandom::new()
171        .fill(&mut bytes)
172        .map_err(|_| extism::Error::msg("secure random generation failed"))?;
173    Ok(bytes)
174});
175
176fn context_from_user_data(
177    data: &UserData<HostFunctionContext>,
178) -> Result<HostFunctionContext, extism::Error> {
179    let data = data.get()?;
180    data.lock()
181        .map(|context| context.clone())
182        .map_err(|_| extism::Error::msg("host function context lock poisoned"))
183}
184
/// Joins `plugin_name` and `key` with a `:` separator, e.g. `("a", "b")` becomes `"a:b"`.
fn scoped_key(plugin_name: &str, key: &str) -> String {
    [plugin_name, key].join(":")
}