Skip to main content

heliosdb_proxy/plugins/
host_functions.rs

1//! Host Functions
2//!
3//! Host functions that plugins can call to interact with the proxy.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10
11/// Host function registry
12pub struct HostFunctionRegistry {
13    /// Registered functions by namespace
14    functions: RwLock<HashMap<String, HashMap<String, HostFunction>>>,
15
16    /// Call statistics
17    stats: RwLock<HostFunctionStats>,
18}
19
20impl HostFunctionRegistry {
21    /// Create a new registry with default functions
22    pub fn new() -> Self {
23        let registry = Self {
24            functions: RwLock::new(HashMap::new()),
25            stats: RwLock::new(HostFunctionStats::default()),
26        };
27
28        // Register default functions
29        registry.register_defaults();
30        registry
31    }
32
33    /// Register default host functions
34    fn register_defaults(&self) {
35        // Helios namespace - core functions
36        self.register("helios", "log", HostFunction::Log);
37        self.register("helios", "metric_inc", HostFunction::MetricInc);
38        self.register("helios", "metric_gauge", HostFunction::MetricGauge);
39        self.register("helios", "get_config", HostFunction::GetConfig);
40        self.register("helios", "get_time", HostFunction::GetTime);
41
42        // Query namespace
43        self.register("query", "execute", HostFunction::QueryExecute);
44        self.register("query", "prepare", HostFunction::QueryPrepare);
45        self.register("query", "get_tables", HostFunction::QueryGetTables);
46        self.register("query", "normalize", HostFunction::QueryNormalize);
47
48        // Cache namespace
49        self.register("cache", "get", HostFunction::CacheGet);
50        self.register("cache", "set", HostFunction::CacheSet);
51        self.register("cache", "delete", HostFunction::CacheDelete);
52        self.register("cache", "exists", HostFunction::CacheExists);
53
54        // HTTP namespace (requires http_fetch permission)
55        self.register("http", "fetch", HostFunction::HttpFetch);
56        self.register("http", "post", HostFunction::HttpPost);
57
58        // Crypto namespace
59        self.register("crypto", "hash", HostFunction::CryptoHash);
60        self.register("crypto", "hmac", HostFunction::CryptoHmac);
61        self.register("crypto", "random", HostFunction::CryptoRandom);
62
63        // KV namespace (plugin-local storage)
64        self.register("kv", "get", HostFunction::KvGet);
65        self.register("kv", "set", HostFunction::KvSet);
66        self.register("kv", "delete", HostFunction::KvDelete);
67        self.register("kv", "list", HostFunction::KvList);
68    }
69
70    /// Register a host function
71    pub fn register(&self, namespace: &str, name: &str, function: HostFunction) {
72        let mut functions = self.functions.write();
73        functions
74            .entry(namespace.to_string())
75            .or_default()
76            .insert(name.to_string(), function);
77    }
78
79    /// Get a host function
80    pub fn get(&self, namespace: &str, name: &str) -> Option<HostFunction> {
81        let functions = self.functions.read();
82        functions
83            .get(namespace)
84            .and_then(|ns| ns.get(name))
85            .cloned()
86    }
87
88    /// Check if a function exists
89    pub fn exists(&self, namespace: &str, name: &str) -> bool {
90        let functions = self.functions.read();
91        functions
92            .get(namespace)
93            .map(|ns| ns.contains_key(name))
94            .unwrap_or(false)
95    }
96
97    /// List all functions in a namespace
98    pub fn list_namespace(&self, namespace: &str) -> Vec<String> {
99        let functions = self.functions.read();
100        functions
101            .get(namespace)
102            .map(|ns| ns.keys().cloned().collect())
103            .unwrap_or_default()
104    }
105
106    /// List all namespaces
107    pub fn list_namespaces(&self) -> Vec<String> {
108        let functions = self.functions.read();
109        functions.keys().cloned().collect()
110    }
111
112    /// Record a function call
113    pub fn record_call(&self, namespace: &str, name: &str, duration: Duration, success: bool) {
114        let mut stats = self.stats.write();
115        let key = format!("{}:{}", namespace, name);
116
117        let entry = stats.calls.entry(key).or_default();
118        entry.total_calls += 1;
119        entry.total_duration += duration;
120
121        if success {
122            entry.successful_calls += 1;
123        } else {
124            entry.failed_calls += 1;
125        }
126
127        if duration > entry.max_duration {
128            entry.max_duration = duration;
129        }
130    }
131
132    /// Get call statistics
133    pub fn get_stats(&self) -> HostFunctionStats {
134        self.stats.read().clone()
135    }
136}
137
138impl Default for HostFunctionRegistry {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
/// Identifies a host function that can be exposed to plugins.
///
/// Every variant except [`HostFunction::Custom`] is registered by default
/// under the namespace/name noted on the variant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HostFunction {
    // ---- Helios core (`helios` namespace) ----
    /// `helios` / `log`
    Log,
    /// `helios` / `metric_inc`
    MetricInc,
    /// `helios` / `metric_gauge`
    MetricGauge,
    /// `helios` / `get_config`
    GetConfig,
    /// `helios` / `get_time`
    GetTime,

    // ---- Query operations (`query` namespace) ----
    /// `query` / `execute`
    QueryExecute,
    /// `query` / `prepare`
    QueryPrepare,
    /// `query` / `get_tables`
    QueryGetTables,
    /// `query` / `normalize`
    QueryNormalize,

    // ---- Cache operations (`cache` namespace) ----
    /// `cache` / `get`
    CacheGet,
    /// `cache` / `set`
    CacheSet,
    /// `cache` / `delete`
    CacheDelete,
    /// `cache` / `exists`
    CacheExists,

    // ---- HTTP operations (`http` namespace) ----
    /// `http` / `fetch`
    HttpFetch,
    /// `http` / `post`
    HttpPost,

    // ---- Crypto operations (`crypto` namespace) ----
    /// `crypto` / `hash`
    CryptoHash,
    /// `crypto` / `hmac`
    CryptoHmac,
    /// `crypto` / `random`
    CryptoRandom,

    // ---- KV storage (`kv` namespace) ----
    /// `kv` / `get`
    KvGet,
    /// `kv` / `set`
    KvSet,
    /// `kv` / `delete`
    KvDelete,
    /// `kv` / `list`
    KvList,

    /// A function that is not one of the built-ins, carrying a
    /// caller-chosen string.
    Custom(String),
}
184
185impl HostFunction {
186    /// Get the required permission for this function
187    pub fn required_permission(&self) -> Option<super::sandbox::Permission> {
188        use super::sandbox::Permission;
189
190        match self {
191            // No permission required
192            HostFunction::Log => None,
193            HostFunction::GetTime => None,
194            HostFunction::GetConfig => None,
195
196            // Metrics permission
197            HostFunction::MetricInc | HostFunction::MetricGauge => Some(Permission::Metrics),
198
199            // Query permission
200            HostFunction::QueryExecute
201            | HostFunction::QueryPrepare
202            | HostFunction::QueryGetTables
203            | HostFunction::QueryNormalize => Some(Permission::QueryExecute),
204
205            // Cache permissions
206            HostFunction::CacheGet | HostFunction::CacheExists => Some(Permission::CacheRead),
207            HostFunction::CacheSet | HostFunction::CacheDelete => Some(Permission::CacheWrite),
208
209            // HTTP permission
210            HostFunction::HttpFetch | HostFunction::HttpPost => Some(Permission::HttpFetch),
211
212            // Crypto permission
213            HostFunction::CryptoHash | HostFunction::CryptoHmac | HostFunction::CryptoRandom => {
214                Some(Permission::Crypto)
215            }
216
217            // KV permissions
218            HostFunction::KvGet | HostFunction::KvList => Some(Permission::KvRead),
219            HostFunction::KvSet | HostFunction::KvDelete => Some(Permission::KvWrite),
220
221            // Custom functions require custom permission
222            HostFunction::Custom(_) => Some(Permission::Custom("custom".to_string())),
223        }
224    }
225
226    /// Get the function signature (for documentation)
227    pub fn signature(&self) -> &'static str {
228        match self {
229            HostFunction::Log => "log(level: i32, message_ptr: i32, message_len: i32)",
230            HostFunction::MetricInc => "metric_inc(name_ptr: i32, name_len: i32, value: f64)",
231            HostFunction::MetricGauge => "metric_gauge(name_ptr: i32, name_len: i32, value: f64)",
232            HostFunction::GetConfig => "get_config(key_ptr: i32, key_len: i32) -> i32",
233            HostFunction::GetTime => "get_time() -> i64",
234
235            HostFunction::QueryExecute => "query_execute(query_ptr: i32, query_len: i32) -> i32",
236            HostFunction::QueryPrepare => "query_prepare(query_ptr: i32, query_len: i32) -> i32",
237            HostFunction::QueryGetTables => {
238                "query_get_tables(query_ptr: i32, query_len: i32) -> i32"
239            }
240            HostFunction::QueryNormalize => {
241                "query_normalize(query_ptr: i32, query_len: i32) -> i32"
242            }
243
244            HostFunction::CacheGet => "cache_get(key_ptr: i32, key_len: i32) -> i32",
245            HostFunction::CacheSet => {
246                "cache_set(key_ptr: i32, key_len: i32, value_ptr: i32, value_len: i32, ttl: i64)"
247            }
248            HostFunction::CacheDelete => "cache_delete(key_ptr: i32, key_len: i32)",
249            HostFunction::CacheExists => "cache_exists(key_ptr: i32, key_len: i32) -> i32",
250
251            HostFunction::HttpFetch => "http_fetch(url_ptr: i32, url_len: i32) -> i32",
252            HostFunction::HttpPost => {
253                "http_post(url_ptr: i32, url_len: i32, body_ptr: i32, body_len: i32) -> i32"
254            }
255
256            HostFunction::CryptoHash => {
257                "crypto_hash(algo_ptr: i32, algo_len: i32, data_ptr: i32, data_len: i32) -> i32"
258            }
259            HostFunction::CryptoHmac => {
260                "crypto_hmac(key_ptr: i32, key_len: i32, data_ptr: i32, data_len: i32) -> i32"
261            }
262            HostFunction::CryptoRandom => "crypto_random(len: i32) -> i32",
263
264            HostFunction::KvGet => "kv_get(key_ptr: i32, key_len: i32) -> i32",
265            HostFunction::KvSet => {
266                "kv_set(key_ptr: i32, key_len: i32, value_ptr: i32, value_len: i32)"
267            }
268            HostFunction::KvDelete => "kv_delete(key_ptr: i32, key_len: i32)",
269            HostFunction::KvList => "kv_list(prefix_ptr: i32, prefix_len: i32) -> i32",
270
271            HostFunction::Custom(_) => "custom(...)",
272        }
273    }
274}
275
276/// Host function call statistics
277#[derive(Debug, Clone, Default)]
278pub struct HostFunctionStats {
279    /// Per-function statistics
280    pub calls: HashMap<String, FunctionCallStats>,
281}
282
/// Aggregated counters and timings for a single host function.
#[derive(Debug, Clone, Default)]
pub struct FunctionCallStats {
    /// Number of calls recorded, successful or not.
    pub total_calls: u64,

    /// Number of calls recorded as successful.
    pub successful_calls: u64,

    /// Number of calls recorded as failed.
    pub failed_calls: u64,

    /// Sum of the durations of all recorded calls.
    pub total_duration: Duration,

    /// Longest single call duration recorded.
    pub max_duration: Duration,
}

impl FunctionCallStats {
    /// Returns the mean call duration, truncated to whole nanoseconds.
    ///
    /// Returns [`Duration::ZERO`] when no calls have been recorded.
    ///
    /// The division is carried out on the full 64-bit call count. Casting the
    /// count to `u32` (as `Duration / u32` would require) silently truncates
    /// counts above `u32::MAX`, yielding a wrong average or, when the
    /// truncated value is zero, a division-by-zero panic.
    pub fn avg_duration(&self) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        if self.total_calls == 0 {
            return Duration::ZERO;
        }

        let avg_nanos = self.total_duration.as_nanos() / u128::from(self.total_calls);

        // The mean never exceeds the total, so the whole-second part fits in
        // `u64`; the remainder is below 10^9 and therefore fits in `u32`.
        Duration::new(
            (avg_nanos / NANOS_PER_SEC) as u64,
            (avg_nanos % NANOS_PER_SEC) as u32,
        )
    }

    /// Returns the fraction of recorded calls that succeeded, in `0.0..=1.0`.
    ///
    /// With no recorded calls the rate is reported as `1.0`.
    pub fn success_rate(&self) -> f64 {
        if self.total_calls == 0 {
            return 1.0;
        }

        self.successful_calls as f64 / self.total_calls as f64
    }
}
321
322/// Host function context (passed to function implementations)
323pub struct HostFunctionContext {
324    /// Plugin name
325    pub plugin_name: String,
326
327    /// Request ID
328    pub request_id: String,
329
330    /// Plugin memory (for reading/writing)
331    pub memory: Arc<RwLock<Vec<u8>>>,
332
333    /// Plugin configuration
334    pub config: HashMap<String, serde_json::Value>,
335
336    /// Call start time
337    pub start_time: Instant,
338}
339
340impl HostFunctionContext {
341    /// Read a string from plugin memory
342    pub fn read_string(&self, ptr: i32, len: i32) -> Result<String, HostFunctionError> {
343        let memory = self.memory.read();
344        let start = ptr as usize;
345        let end = start + len as usize;
346
347        if end > memory.len() {
348            return Err(HostFunctionError::MemoryAccessError(
349                "Read out of bounds".to_string(),
350            ));
351        }
352
353        String::from_utf8(memory[start..end].to_vec())
354            .map_err(|e| HostFunctionError::InvalidData(e.to_string()))
355    }
356
357    /// Read bytes from plugin memory
358    pub fn read_bytes(&self, ptr: i32, len: i32) -> Result<Vec<u8>, HostFunctionError> {
359        let memory = self.memory.read();
360        let start = ptr as usize;
361        let end = start + len as usize;
362
363        if end > memory.len() {
364            return Err(HostFunctionError::MemoryAccessError(
365                "Read out of bounds".to_string(),
366            ));
367        }
368
369        Ok(memory[start..end].to_vec())
370    }
371
372    /// Write bytes to plugin memory
373    pub fn write_bytes(&self, ptr: i32, data: &[u8]) -> Result<(), HostFunctionError> {
374        let mut memory = self.memory.write();
375        let start = ptr as usize;
376        let end = start + data.len();
377
378        if end > memory.len() {
379            return Err(HostFunctionError::MemoryAccessError(
380                "Write out of bounds".to_string(),
381            ));
382        }
383
384        memory[start..end].copy_from_slice(data);
385        Ok(())
386    }
387
388    /// Allocate memory in plugin
389    pub fn allocate(&self, size: usize) -> Result<i32, HostFunctionError> {
390        let mut memory = self.memory.write();
391        let ptr = memory.len() as i32;
392        let new_size = memory.len() + size;
393        memory.resize(new_size, 0);
394        Ok(ptr)
395    }
396
397    /// Get elapsed time since call start
398    pub fn elapsed(&self) -> Duration {
399        self.start_time.elapsed()
400    }
401}
402
/// Errors a host function call can produce.
#[derive(Debug, Clone)]
pub enum HostFunctionError {
    /// Plugin memory was accessed outside its bounds; carries a description.
    MemoryAccessError(String),

    /// Data handed over by the plugin could not be interpreted; carries a
    /// description.
    InvalidData(String),

    /// The caller lacks a required permission; carries a description.
    PermissionDenied(String),

    /// The requested function is not known; carries a description.
    FunctionNotFound(String),

    /// The function failed while running; carries a description.
    ExecutionError(String),

    /// The call timed out.
    Timeout,
}

/// Renders each variant as `"<label>: <message>"`; `Timeout` renders as the
/// bare word `"Timeout"`.
impl std::fmt::Display for HostFunctionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::Timeout => return write!(f, "Timeout"),
            Self::MemoryAccessError(msg) => ("Memory access error", msg),
            Self::InvalidData(msg) => ("Invalid data", msg),
            Self::PermissionDenied(msg) => ("Permission denied", msg),
            Self::FunctionNotFound(msg) => ("Function not found", msg),
            Self::ExecutionError(msg) => ("Execution error", msg),
        };

        write!(f, "{}: {}", label, detail)
    }
}

/// Marker impl so the error can be used as a `std::error::Error`.
impl std::error::Error for HostFunctionError {}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context for a plugin named "test" whose memory is `size`
    /// zeroed bytes.
    fn context_with_memory(size: usize) -> HostFunctionContext {
        HostFunctionContext {
            plugin_name: "test".to_string(),
            request_id: "req-1".to_string(),
            memory: Arc::new(RwLock::new(vec![0u8; size])),
            config: HashMap::new(),
            start_time: Instant::now(),
        }
    }

    /// A fresh registry exposes the built-in functions.
    #[test]
    fn test_host_function_registry_new() {
        let registry = HostFunctionRegistry::new();

        for (namespace, name) in [("helios", "log"), ("cache", "get"), ("query", "execute")] {
            assert!(registry.exists(namespace, name));
        }
    }

    /// Namespace and function listings include the built-ins.
    #[test]
    fn test_host_function_registry_list() {
        let registry = HostFunctionRegistry::new();

        let namespaces = registry.list_namespaces();
        for expected in ["helios", "cache", "query"] {
            assert!(namespaces.iter().any(|ns| ns == expected));
        }

        let helios_funcs = registry.list_namespace("helios");
        assert!(helios_funcs.iter().any(|name| name == "log"));
    }

    /// Logging needs no permission; HTTP and cache access do.
    #[test]
    fn test_host_function_required_permission() {
        assert!(HostFunction::Log.required_permission().is_none());
        assert!(HostFunction::HttpFetch.required_permission().is_some());
        assert!(HostFunction::CacheGet.required_permission().is_some());
    }

    /// The `Log` signature mentions the function name and its level argument.
    #[test]
    fn test_host_function_signature() {
        let sig = HostFunction::Log.signature();
        for needle in ["log", "level"] {
            assert!(sig.contains(needle));
        }
    }

    /// Average duration and success rate are derived from the counters.
    #[test]
    fn test_function_call_stats() {
        let stats = FunctionCallStats {
            total_calls: 10,
            successful_calls: 9,
            failed_calls: 1,
            total_duration: Duration::from_millis(100),
            ..FunctionCallStats::default()
        };

        assert_eq!(stats.avg_duration(), Duration::from_millis(10));
        assert!((stats.success_rate() - 0.9).abs() < 0.001);
    }

    /// Bytes written to plugin memory can be read back raw and as a string.
    #[test]
    fn test_host_function_context_memory() {
        let ctx = context_with_memory(1024);

        // Write and read back
        ctx.write_bytes(0, b"hello").unwrap();
        assert_eq!(ctx.read_bytes(0, 5).unwrap(), b"hello");

        // Read string
        assert_eq!(ctx.read_string(0, 5).unwrap(), "hello");
    }

    /// Reading past the end of plugin memory is reported as an error.
    #[test]
    fn test_host_function_context_out_of_bounds() {
        let ctx = context_with_memory(10);

        // Try to read beyond memory
        assert!(ctx.read_bytes(5, 10).is_err());
    }

    /// Recorded calls are aggregated under the "namespace:name" key.
    #[test]
    fn test_record_call() {
        let registry = HostFunctionRegistry::new();

        let samples = [(50, true), (100, true), (75, false)];
        for (micros, success) in samples {
            registry.record_call("helios", "log", Duration::from_micros(micros), success);
        }

        let stats = registry.get_stats();
        let log_stats = stats.calls.get("helios:log").unwrap();

        assert_eq!(log_stats.total_calls, 3);
        assert_eq!(log_stats.successful_calls, 2);
        assert_eq!(log_stats.failed_calls, 1);
        assert_eq!(log_stats.max_duration, Duration::from_micros(100));
    }
}
544}