1use std::collections::HashMap;
2use std::sync::{Arc, Mutex};
3use std::time::{SystemTime, UNIX_EPOCH};
4
5use extism::{Function, PTR, UserData, ValType, host_fn};
6use ring::rand::{SecureRandom, SystemRandom};
7
/// Largest number of bytes a single `cc_lb_random_bytes` call may request (64 KiB).
const RANDOM_BYTES_MAX: u64 = 64 * 1024;
9
10#[non_exhaustive]
11#[derive(Debug, Default)]
12pub struct HostState {
13 storage: Mutex<HashMap<String, Vec<u8>>>,
14}
15
16impl HostState {
17 pub fn new() -> Self {
18 Self::default()
19 }
20
21 fn get(&self, plugin_name: &str, key: &str) -> Result<Vec<u8>, extism::Error> {
22 let scoped = scoped_key(plugin_name, key);
23 let storage = self
24 .storage
25 .lock()
26 .map_err(|_| extism::Error::msg("storage lock poisoned"))?;
27 Ok(storage.get(&scoped).cloned().unwrap_or_default())
28 }
29
30 fn put(
31 &self,
32 plugin_name: &str,
33 key: &str,
34 value: Vec<u8>,
35 quota_bytes: usize,
36 ) -> Result<(), extism::Error> {
37 let scoped = scoped_key(plugin_name, key);
38 let mut storage = self
39 .storage
40 .lock()
41 .map_err(|_| extism::Error::msg("storage lock poisoned"))?;
42 let existing_entry = storage
43 .get(&scoped)
44 .map_or(0, |stored_value| scoped.len() + stored_value.len());
45 let used = storage
46 .iter()
47 .filter(|(stored_key, _)| {
48 stored_key.starts_with(plugin_name)
49 && stored_key.as_bytes().get(plugin_name.len()) == Some(&b':')
50 })
51 .map(|(stored_key, stored_value)| stored_key.len() + stored_value.len())
52 .sum::<usize>();
53 let next = used
54 .saturating_sub(existing_entry)
55 .saturating_add(scoped.len().saturating_add(value.len()));
56 if next > quota_bytes {
57 return Err(extism::Error::msg("scoped storage quota exceeded"));
58 }
59 storage.insert(scoped, value);
60 Ok(())
61 }
62}
63
64#[non_exhaustive]
65#[derive(Clone)]
66pub struct HostFunctionContext {
67 plugin_name: String,
68 storage_quota_bytes: usize,
69 state: Arc<HostState>,
70}
71
72impl HostFunctionContext {
73 pub fn new(plugin_name: String, storage_quota_bytes: usize, state: Arc<HostState>) -> Self {
74 Self {
75 plugin_name,
76 storage_quota_bytes,
77 state,
78 }
79 }
80}
81
82pub fn functions(context: HostFunctionContext) -> Vec<Function> {
83 let user_data = UserData::new(context);
84 vec![
85 Function::new("cc_lb_log", [PTR, PTR], [], user_data.clone(), cc_lb_log),
86 Function::new(
87 "cc_lb_metric_incr",
88 [PTR, PTR],
89 [],
90 user_data.clone(),
91 cc_lb_metric_incr,
92 ),
93 Function::new(
94 "cc_lb_storage_get",
95 [PTR],
96 [PTR],
97 user_data.clone(),
98 cc_lb_storage_get,
99 ),
100 Function::new(
101 "cc_lb_storage_put",
102 [PTR, PTR],
103 [],
104 user_data.clone(),
105 cc_lb_storage_put,
106 ),
107 Function::new(
108 "cc_lb_now_unix_ms",
109 [],
110 [ValType::I64],
111 user_data.clone(),
112 cc_lb_now_unix_ms,
113 ),
114 Function::new(
115 "cc_lb_random_bytes",
116 [PTR],
117 [PTR],
118 user_data,
119 cc_lb_random_bytes,
120 ),
121 ]
122}
123
124host_fn!(pub cc_lb_log(data: HostFunctionContext; level: String, msg: String) {
125 let context = context_from_user_data(&data)?;
126 match level.as_str() {
127 "trace" => tracing::trace!(plugin = %context.plugin_name, message = %msg),
128 "debug" => tracing::debug!(plugin = %context.plugin_name, message = %msg),
129 "warn" => tracing::warn!(plugin = %context.plugin_name, message = %msg),
130 "error" => tracing::error!(plugin = %context.plugin_name, message = %msg),
131 _ => tracing::info!(plugin = %context.plugin_name, message = %msg),
132 }
133 Ok(())
134});
135
136host_fn!(pub cc_lb_metric_incr(data: HostFunctionContext; name: String, value: u64) {
137 let context = context_from_user_data(&data)?;
138 metrics::counter!(name.clone(), "plugin" => context.plugin_name.clone()).increment(value);
139 Ok(())
140});
141
142host_fn!(pub cc_lb_storage_get(data: HostFunctionContext; key: String) -> Vec<u8> {
143 let context = context_from_user_data(&data)?;
144 context.state.get(&context.plugin_name, &key)
145});
146
147host_fn!(pub cc_lb_storage_put(data: HostFunctionContext; key: String, value: Vec<u8>) {
148 let context = context_from_user_data(&data)?;
149 context
150 .state
151 .put(&context.plugin_name, &key, value, context.storage_quota_bytes)?;
152 Ok(())
153});
154
155host_fn!(pub cc_lb_now_unix_ms(_data: HostFunctionContext; ) -> u64 {
156 let now = SystemTime::now()
157 .duration_since(UNIX_EPOCH)
158 .unwrap_or_default()
159 .as_millis()
160 .try_into()
161 .unwrap_or(u64::MAX);
162 Ok(now)
163});
164
165host_fn!(pub cc_lb_random_bytes(_data: HostFunctionContext; n: u64) -> Vec<u8> {
166 if n > RANDOM_BYTES_MAX {
167 return Err(extism::Error::msg("random byte request exceeds host cap"));
168 }
169 let mut bytes = vec![0_u8; n as usize];
170 SystemRandom::new()
171 .fill(&mut bytes)
172 .map_err(|_| extism::Error::msg("secure random generation failed"))?;
173 Ok(bytes)
174});
175
176fn context_from_user_data(
177 data: &UserData<HostFunctionContext>,
178) -> Result<HostFunctionContext, extism::Error> {
179 let data = data.get()?;
180 data.lock()
181 .map(|context| context.clone())
182 .map_err(|_| extism::Error::msg("host function context lock poisoned"))
183}
184
/// Joins a plugin name and a key into a single `"{plugin_name}:{key}"` string.
fn scoped_key(plugin_name: &str, key: &str) -> String {
    let mut scoped = String::with_capacity(plugin_name.len() + 1 + key.len());
    scoped.push_str(plugin_name);
    scoped.push(':');
    scoped.push_str(key);
    scoped
}