Skip to main content

heliosdb_proxy/plugins/
host_functions.rs

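//! Host Functions
//!
//! Host functions that plugins can call to interact with the proxy. This module holds
//! the registry that maps `(namespace, name)` pairs to those functions, and each
//! function's required permission and documentation signature. It also holds per-call
//! statistics and the call context used to read and write plugin memory.
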
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10
11/// Host function registry
12pub struct HostFunctionRegistry {
13    /// Registered functions by namespace
14    functions: RwLock<HashMap<String, HashMap<String, HostFunction>>>,
15
16    /// Call statistics
17    stats: RwLock<HostFunctionStats>,
18}
19
20impl HostFunctionRegistry {
21    /// Create a new registry with default functions
22    pub fn new() -> Self {
23        let registry = Self {
24            functions: RwLock::new(HashMap::new()),
25            stats: RwLock::new(HostFunctionStats::default()),
26        };
27
28        // Register default functions
29        registry.register_defaults();
30        registry
31    }
32
33    /// Register default host functions
34    fn register_defaults(&self) {
35        // Helios namespace - core functions
36        self.register("helios", "log", HostFunction::Log);
37        self.register("helios", "metric_inc", HostFunction::MetricInc);
38        self.register("helios", "metric_gauge", HostFunction::MetricGauge);
39        self.register("helios", "get_config", HostFunction::GetConfig);
40        self.register("helios", "get_time", HostFunction::GetTime);
41
42        // Query namespace
43        self.register("query", "execute", HostFunction::QueryExecute);
44        self.register("query", "prepare", HostFunction::QueryPrepare);
45        self.register("query", "get_tables", HostFunction::QueryGetTables);
46        self.register("query", "normalize", HostFunction::QueryNormalize);
47
48        // Cache namespace
49        self.register("cache", "get", HostFunction::CacheGet);
50        self.register("cache", "set", HostFunction::CacheSet);
51        self.register("cache", "delete", HostFunction::CacheDelete);
52        self.register("cache", "exists", HostFunction::CacheExists);
53
54        // HTTP namespace (requires http_fetch permission)
55        self.register("http", "fetch", HostFunction::HttpFetch);
56        self.register("http", "post", HostFunction::HttpPost);
57
58        // Crypto namespace
59        self.register("crypto", "hash", HostFunction::CryptoHash);
60        self.register("crypto", "hmac", HostFunction::CryptoHmac);
61        self.register("crypto", "random", HostFunction::CryptoRandom);
62
63        // KV namespace (plugin-local storage)
64        self.register("kv", "get", HostFunction::KvGet);
65        self.register("kv", "set", HostFunction::KvSet);
66        self.register("kv", "delete", HostFunction::KvDelete);
67        self.register("kv", "list", HostFunction::KvList);
68    }
69
70    /// Register a host function
71    pub fn register(&self, namespace: &str, name: &str, function: HostFunction) {
72        let mut functions = self.functions.write();
73        functions
74            .entry(namespace.to_string())
75            .or_insert_with(HashMap::new)
76            .insert(name.to_string(), function);
77    }
78
79    /// Get a host function
80    pub fn get(&self, namespace: &str, name: &str) -> Option<HostFunction> {
81        let functions = self.functions.read();
82        functions
83            .get(namespace)
84            .and_then(|ns| ns.get(name))
85            .cloned()
86    }
87
88    /// Check if a function exists
89    pub fn exists(&self, namespace: &str, name: &str) -> bool {
90        let functions = self.functions.read();
91        functions
92            .get(namespace)
93            .map(|ns| ns.contains_key(name))
94            .unwrap_or(false)
95    }
96
97    /// List all functions in a namespace
98    pub fn list_namespace(&self, namespace: &str) -> Vec<String> {
99        let functions = self.functions.read();
100        functions
101            .get(namespace)
102            .map(|ns| ns.keys().cloned().collect())
103            .unwrap_or_default()
104    }
105
106    /// List all namespaces
107    pub fn list_namespaces(&self) -> Vec<String> {
108        let functions = self.functions.read();
109        functions.keys().cloned().collect()
110    }
111
112    /// Record a function call
113    pub fn record_call(&self, namespace: &str, name: &str, duration: Duration, success: bool) {
114        let mut stats = self.stats.write();
115        let key = format!("{}:{}", namespace, name);
116
117        let entry = stats.calls.entry(key).or_insert_with(FunctionCallStats::default);
118        entry.total_calls += 1;
119        entry.total_duration += duration;
120
121        if success {
122            entry.successful_calls += 1;
123        } else {
124            entry.failed_calls += 1;
125        }
126
127        if duration > entry.max_duration {
128            entry.max_duration = duration;
129        }
130    }
131
132    /// Get call statistics
133    pub fn get_stats(&self) -> HostFunctionStats {
134        self.stats.read().clone()
135    }
136}
137
138impl Default for HostFunctionRegistry {
139    fn default() -> Self {
140        Self::new()
141    }
142}
143
/// A host function that a plugin can invoke.
///
/// `HostFunctionRegistry::new` registers each built-in variant under the
/// `namespace`/`name` given in its doc. `Custom` covers functions outside the built-in
/// set. Its string payload is presumably the function name; confirm with callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HostFunction {
    // ---- `helios` namespace: core ----
    /// `helios`/`log`
    Log,
    /// `helios`/`metric_inc`
    MetricInc,
    /// `helios`/`metric_gauge`
    MetricGauge,
    /// `helios`/`get_config`
    GetConfig,
    /// `helios`/`get_time`
    GetTime,

    // ---- `query` namespace ----
    /// `query`/`execute`
    QueryExecute,
    /// `query`/`prepare`
    QueryPrepare,
    /// `query`/`get_tables`
    QueryGetTables,
    /// `query`/`normalize`
    QueryNormalize,

    // ---- `cache` namespace ----
    /// `cache`/`get`
    CacheGet,
    /// `cache`/`set`
    CacheSet,
    /// `cache`/`delete`
    CacheDelete,
    /// `cache`/`exists`
    CacheExists,

    // ---- `http` namespace ----
    /// `http`/`fetch`
    HttpFetch,
    /// `http`/`post`
    HttpPost,

    // ---- `crypto` namespace ----
    /// `crypto`/`hash`
    CryptoHash,
    /// `crypto`/`hmac`
    CryptoHmac,
    /// `crypto`/`random`
    CryptoRandom,

    // ---- `kv` namespace: plugin-local storage ----
    /// `kv`/`get`
    KvGet,
    /// `kv`/`set`
    KvSet,
    /// `kv`/`delete`
    KvDelete,
    /// `kv`/`list`
    KvList,

    /// A function outside the built-in set.
    Custom(String),
}
184
185impl HostFunction {
186    /// Get the required permission for this function
187    pub fn required_permission(&self) -> Option<super::sandbox::Permission> {
188        use super::sandbox::Permission;
189
190        match self {
191            // No permission required
192            HostFunction::Log => None,
193            HostFunction::GetTime => None,
194            HostFunction::GetConfig => None,
195
196            // Metrics permission
197            HostFunction::MetricInc | HostFunction::MetricGauge => Some(Permission::Metrics),
198
199            // Query permission
200            HostFunction::QueryExecute
201            | HostFunction::QueryPrepare
202            | HostFunction::QueryGetTables
203            | HostFunction::QueryNormalize => Some(Permission::QueryExecute),
204
205            // Cache permissions
206            HostFunction::CacheGet | HostFunction::CacheExists => Some(Permission::CacheRead),
207            HostFunction::CacheSet | HostFunction::CacheDelete => Some(Permission::CacheWrite),
208
209            // HTTP permission
210            HostFunction::HttpFetch | HostFunction::HttpPost => Some(Permission::HttpFetch),
211
212            // Crypto permission
213            HostFunction::CryptoHash | HostFunction::CryptoHmac | HostFunction::CryptoRandom => {
214                Some(Permission::Crypto)
215            }
216
217            // KV permissions
218            HostFunction::KvGet | HostFunction::KvList => Some(Permission::KvRead),
219            HostFunction::KvSet | HostFunction::KvDelete => Some(Permission::KvWrite),
220
221            // Custom functions require custom permission
222            HostFunction::Custom(_) => Some(Permission::Custom("custom".to_string())),
223        }
224    }
225
226    /// Get the function signature (for documentation)
227    pub fn signature(&self) -> &'static str {
228        match self {
229            HostFunction::Log => "log(level: i32, message_ptr: i32, message_len: i32)",
230            HostFunction::MetricInc => "metric_inc(name_ptr: i32, name_len: i32, value: f64)",
231            HostFunction::MetricGauge => "metric_gauge(name_ptr: i32, name_len: i32, value: f64)",
232            HostFunction::GetConfig => "get_config(key_ptr: i32, key_len: i32) -> i32",
233            HostFunction::GetTime => "get_time() -> i64",
234
235            HostFunction::QueryExecute => {
236                "query_execute(query_ptr: i32, query_len: i32) -> i32"
237            }
238            HostFunction::QueryPrepare => {
239                "query_prepare(query_ptr: i32, query_len: i32) -> i32"
240            }
241            HostFunction::QueryGetTables => {
242                "query_get_tables(query_ptr: i32, query_len: i32) -> i32"
243            }
244            HostFunction::QueryNormalize => {
245                "query_normalize(query_ptr: i32, query_len: i32) -> i32"
246            }
247
248            HostFunction::CacheGet => "cache_get(key_ptr: i32, key_len: i32) -> i32",
249            HostFunction::CacheSet => {
250                "cache_set(key_ptr: i32, key_len: i32, value_ptr: i32, value_len: i32, ttl: i64)"
251            }
252            HostFunction::CacheDelete => "cache_delete(key_ptr: i32, key_len: i32)",
253            HostFunction::CacheExists => "cache_exists(key_ptr: i32, key_len: i32) -> i32",
254
255            HostFunction::HttpFetch => "http_fetch(url_ptr: i32, url_len: i32) -> i32",
256            HostFunction::HttpPost => {
257                "http_post(url_ptr: i32, url_len: i32, body_ptr: i32, body_len: i32) -> i32"
258            }
259
260            HostFunction::CryptoHash => {
261                "crypto_hash(algo_ptr: i32, algo_len: i32, data_ptr: i32, data_len: i32) -> i32"
262            }
263            HostFunction::CryptoHmac => {
264                "crypto_hmac(key_ptr: i32, key_len: i32, data_ptr: i32, data_len: i32) -> i32"
265            }
266            HostFunction::CryptoRandom => "crypto_random(len: i32) -> i32",
267
268            HostFunction::KvGet => "kv_get(key_ptr: i32, key_len: i32) -> i32",
269            HostFunction::KvSet => {
270                "kv_set(key_ptr: i32, key_len: i32, value_ptr: i32, value_len: i32)"
271            }
272            HostFunction::KvDelete => "kv_delete(key_ptr: i32, key_len: i32)",
273            HostFunction::KvList => "kv_list(prefix_ptr: i32, prefix_len: i32) -> i32",
274
275            HostFunction::Custom(_) => "custom(...)",
276        }
277    }
278}
279
280/// Host function call statistics
281#[derive(Debug, Clone, Default)]
282pub struct HostFunctionStats {
283    /// Per-function statistics
284    pub calls: HashMap<String, FunctionCallStats>,
285}
286
/// Aggregated statistics for the calls made to a single host function.
#[derive(Debug, Clone, Default)]
pub struct FunctionCallStats {
    /// Total number of recorded calls. The registry increments it on every call, so it
    /// equals `successful_calls + failed_calls`.
    pub total_calls: u64,

    /// Number of calls recorded as successful.
    pub successful_calls: u64,

    /// Number of calls recorded as failed.
    pub failed_calls: u64,

    /// Sum of the durations of all recorded calls.
    pub total_duration: Duration,

    /// Longest single recorded call duration.
    pub max_duration: Duration,
}

impl FunctionCallStats {
    /// Returns the mean call duration, or [`Duration::ZERO`] when no calls were recorded.
    ///
    /// Any `total_calls` value is handled correctly, including counts above
    /// `u32::MAX`. In that case the result is truncated to whole nanoseconds.
    pub fn avg_duration(&self) -> Duration {
        if self.total_calls == 0 {
            Duration::ZERO
        } else if self.total_calls <= u64::from(u32::MAX) {
            // Lossless cast: the count was just checked to fit in `u32`.
            self.total_duration / self.total_calls as u32
        } else {
            // `Duration` only divides by `u32`. A plain `as u32` cast would truncate the
            // count: the average would be wrong, and at exact multiples of 2^32 the
            // divisor would be zero and the division would panic. Divide in
            // nanoseconds instead. With more than `u32::MAX` calls the quotient is at
            // most `Duration::MAX.as_nanos() / 2^32` (about 4.3e18), which fits in `u64`.
            let nanos = self.total_duration.as_nanos() / u128::from(self.total_calls);
            Duration::from_nanos(nanos as u64)
        }
    }

    /// Returns the fraction of calls that succeeded, in `0.0..=1.0`.
    ///
    /// Returns `1.0` when no calls have been recorded.
    pub fn success_rate(&self) -> f64 {
        if self.total_calls == 0 {
            1.0
        } else {
            self.successful_calls as f64 / self.total_calls as f64
        }
    }
}
325
326/// Host function context (passed to function implementations)
327pub struct HostFunctionContext {
328    /// Plugin name
329    pub plugin_name: String,
330
331    /// Request ID
332    pub request_id: String,
333
334    /// Plugin memory (for reading/writing)
335    pub memory: Arc<RwLock<Vec<u8>>>,
336
337    /// Plugin configuration
338    pub config: HashMap<String, serde_json::Value>,
339
340    /// Call start time
341    pub start_time: Instant,
342}
343
344impl HostFunctionContext {
345    /// Read a string from plugin memory
346    pub fn read_string(&self, ptr: i32, len: i32) -> Result<String, HostFunctionError> {
347        let memory = self.memory.read();
348        let start = ptr as usize;
349        let end = start + len as usize;
350
351        if end > memory.len() {
352            return Err(HostFunctionError::MemoryAccessError(
353                "Read out of bounds".to_string(),
354            ));
355        }
356
357        String::from_utf8(memory[start..end].to_vec())
358            .map_err(|e| HostFunctionError::InvalidData(e.to_string()))
359    }
360
361    /// Read bytes from plugin memory
362    pub fn read_bytes(&self, ptr: i32, len: i32) -> Result<Vec<u8>, HostFunctionError> {
363        let memory = self.memory.read();
364        let start = ptr as usize;
365        let end = start + len as usize;
366
367        if end > memory.len() {
368            return Err(HostFunctionError::MemoryAccessError(
369                "Read out of bounds".to_string(),
370            ));
371        }
372
373        Ok(memory[start..end].to_vec())
374    }
375
376    /// Write bytes to plugin memory
377    pub fn write_bytes(&self, ptr: i32, data: &[u8]) -> Result<(), HostFunctionError> {
378        let mut memory = self.memory.write();
379        let start = ptr as usize;
380        let end = start + data.len();
381
382        if end > memory.len() {
383            return Err(HostFunctionError::MemoryAccessError(
384                "Write out of bounds".to_string(),
385            ));
386        }
387
388        memory[start..end].copy_from_slice(data);
389        Ok(())
390    }
391
392    /// Allocate memory in plugin
393    pub fn allocate(&self, size: usize) -> Result<i32, HostFunctionError> {
394        let mut memory = self.memory.write();
395        let ptr = memory.len() as i32;
396        let new_size = memory.len() + size;
397        memory.resize(new_size, 0);
398        Ok(ptr)
399    }
400
401    /// Get elapsed time since call start
402    pub fn elapsed(&self) -> Duration {
403        self.start_time.elapsed()
404    }
405}
406
/// Errors returned by host function implementations and by the
/// `HostFunctionContext` memory helpers.
#[derive(Debug, Clone)]
pub enum HostFunctionError {
    /// Plugin memory could not be accessed, e.g. an out-of-bounds read or write.
    MemoryAccessError(String),

    /// Data received from the plugin was malformed, e.g. not valid UTF-8.
    InvalidData(String),

    /// A permission check failed.
    PermissionDenied(String),

    /// No host function matched the request.
    FunctionNotFound(String),

    /// The host function failed while executing.
    ExecutionError(String),

    /// The operation timed out.
    Timeout,
}

impl std::fmt::Display for HostFunctionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MemoryAccessError(detail) => write!(f, "Memory access error: {}", detail),
            Self::InvalidData(detail) => write!(f, "Invalid data: {}", detail),
            Self::PermissionDenied(detail) => write!(f, "Permission denied: {}", detail),
            Self::FunctionNotFound(detail) => write!(f, "Function not found: {}", detail),
            Self::ExecutionError(detail) => write!(f, "Execution error: {}", detail),
            Self::Timeout => f.write_str("Timeout"),
        }
    }
}

/// `HostFunctionError` has no underlying cause, so the default `source()` (`None`)
/// is used.
impl std::error::Error for HostFunctionError {}
443
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context for plugin "test" / request "req-1" backed by `size` zeroed
    /// bytes of plugin memory.
    fn context_with_memory(size: usize) -> HostFunctionContext {
        HostFunctionContext {
            plugin_name: "test".to_string(),
            request_id: "req-1".to_string(),
            memory: Arc::new(RwLock::new(vec![0u8; size])),
            config: HashMap::new(),
            start_time: Instant::now(),
        }
    }

    #[test]
    fn test_host_function_registry_new() {
        let registry = HostFunctionRegistry::new();

        // A sample of the built-ins that `new()` registers.
        for &(namespace, name) in &[("helios", "log"), ("cache", "get"), ("query", "execute")] {
            assert!(registry.exists(namespace, name));
        }
    }

    #[test]
    fn test_host_function_registry_list() {
        let registry = HostFunctionRegistry::new();

        let namespaces = registry.list_namespaces();
        for expected in &["helios", "cache", "query"] {
            assert!(namespaces.iter().any(|ns| ns == expected));
        }

        let helios_funcs = registry.list_namespace("helios");
        assert!(helios_funcs.iter().any(|func| func == "log"));
    }

    #[test]
    fn test_host_function_required_permission() {
        assert!(HostFunction::Log.required_permission().is_none());
        for gated in &[HostFunction::HttpFetch, HostFunction::CacheGet] {
            assert!(gated.required_permission().is_some());
        }
    }

    #[test]
    fn test_host_function_signature() {
        let sig = HostFunction::Log.signature();
        for needle in &["log", "level"] {
            assert!(sig.contains(*needle));
        }
    }

    #[test]
    fn test_function_call_stats() {
        let stats = FunctionCallStats {
            total_calls: 10,
            successful_calls: 9,
            failed_calls: 1,
            total_duration: Duration::from_millis(100),
            ..Default::default()
        };

        assert_eq!(stats.avg_duration(), Duration::from_millis(10));
        assert!((stats.success_rate() - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_host_function_context_memory() {
        let ctx = context_with_memory(1024);

        // Round-trip raw bytes, then decode the same region as a string.
        ctx.write_bytes(0, b"hello").unwrap();
        assert_eq!(ctx.read_bytes(0, 5).unwrap(), b"hello");
        assert_eq!(ctx.read_string(0, 5).unwrap(), "hello");
    }

    #[test]
    fn test_host_function_context_out_of_bounds() {
        let ctx = context_with_memory(10);

        // Bytes 5..15 extend past the 10-byte buffer.
        assert!(ctx.read_bytes(5, 10).is_err());
    }

    #[test]
    fn test_record_call() {
        let registry = HostFunctionRegistry::new();

        for &(micros, ok) in &[(50, true), (100, true), (75, false)] {
            registry.record_call("helios", "log", Duration::from_micros(micros), ok);
        }

        let stats = registry.get_stats();
        let log_stats = &stats.calls["helios:log"];

        assert_eq!(log_stats.total_calls, 3);
        assert_eq!(log_stats.successful_calls, 2);
        assert_eq!(log_stats.failed_calls, 1);
        assert_eq!(log_stats.max_duration, Duration::from_micros(100));
    }
}