use crate::error::{KernelError, KernelResult};
use crate::kernel_api::RowId;
use crate::wasm_runtime::WasmPluginCapabilities;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
/// Outcome of a host-function call made by a WASM plugin.
///
/// Each variant maps to a numeric status via [`HostCallResult::status_code`]:
/// both success forms are `0`, failures are negative.
#[derive(Debug, Clone, PartialEq)]
pub enum HostCallResult {
    /// Call succeeded and produced a byte payload.
    Success(Vec<u8>),
    /// Call succeeded with no payload.
    Ok,
    /// Caller lacked the required capability (status `-1`).
    PermissionDenied(String),
    /// Requested item (e.g. host function) does not exist (status `-2`).
    NotFound(String),
    /// Arguments could not be parsed or were empty (status `-3`).
    InvalidArgs(String),
    /// Any other failure (status `-4`).
    Error(String),
}

impl HostCallResult {
    /// Numeric status for this result: `0` on success, negative on failure.
    pub fn status_code(&self) -> i32 {
        match self {
            Self::Success(_) | Self::Ok => 0,
            Self::PermissionDenied(_) => -1,
            Self::NotFound(_) => -2,
            Self::InvalidArgs(_) => -3,
            Self::Error(_) => -4,
        }
    }

    /// Borrowed payload for [`HostCallResult::Success`]; `None` for every other variant.
    pub fn data(&self) -> Option<&[u8]> {
        if let Self::Success(bytes) = self {
            Some(bytes.as_slice())
        } else {
            None
        }
    }
}
/// Per-call state a host function sees on behalf of one plugin.
///
/// Carries the plugin's identity and granted capabilities, and accumulates
/// an audit trail through [`HostFunctionContext::audit`].
pub struct HostFunctionContext {
    /// Name of the calling plugin; embedded in permission-denied messages.
    pub plugin_name: String,
    /// Capabilities consulted by the `check_*` methods.
    pub capabilities: WasmPluginCapabilities,
    /// Audit records, appended in call order.
    pub audit_log: Vec<AuditEntry>,
    /// Optional transaction id (initialized to `None`).
    pub transaction_id: Option<u64>,
    /// Key/value byte storage attached to this context (initialized empty).
    pub session_vars: HashMap<String, Vec<u8>>,
}

/// One record in [`HostFunctionContext::audit_log`].
#[derive(Debug, Clone)]
pub struct AuditEntry {
    /// Wall-clock time of the record in microseconds since the Unix epoch
    /// (`0` if the system clock reads earlier than the epoch).
    pub timestamp_us: u64,
    /// Name of the host function that was called.
    pub function: String,
    /// Table or collection targeted by the call, when one was known.
    pub table: Option<String>,
    /// Status code the caller recorded for the call.
    pub status: i32,
    /// Number of rows/results the caller reported.
    pub rows_affected: u64,
}

impl HostFunctionContext {
    /// Creates a context for `plugin_name` with an empty audit log, no
    /// transaction and no session variables.
    pub fn new(plugin_name: &str, capabilities: WasmPluginCapabilities) -> Self {
        Self {
            capabilities,
            plugin_name: String::from(plugin_name),
            audit_log: Vec::new(),
            transaction_id: None,
            session_vars: HashMap::new(),
        }
    }

    /// Shared body of the `check_*` methods: `Ok(())` when `allowed`,
    /// otherwise a `KernelError::Plugin` reading
    /// `plugin '<name>' not authorized <action>`. `action` is only evaluated
    /// on denial so the success path does not allocate.
    fn require_capability(
        &self,
        allowed: bool,
        action: impl FnOnce() -> String,
    ) -> KernelResult<()> {
        if allowed {
            return Ok(());
        }
        Err(KernelError::Plugin {
            message: format!("plugin '{}' not authorized {}", self.plugin_name, action()),
        })
    }

    /// Verifies the plugin may read `table`.
    ///
    /// # Errors
    /// `KernelError::Plugin` when `capabilities.can_read(table)` is false.
    pub fn check_read(&self, table: &str) -> KernelResult<()> {
        self.require_capability(self.capabilities.can_read(table), || {
            format!("to read table '{}'", table)
        })
    }

    /// Verifies the plugin may write `table`.
    ///
    /// # Errors
    /// `KernelError::Plugin` when `capabilities.can_write(table)` is false.
    pub fn check_write(&self, table: &str) -> KernelResult<()> {
        self.require_capability(self.capabilities.can_write(table), || {
            format!("to write table '{}'", table)
        })
    }

    /// Verifies the plugin holds the vector-search capability.
    ///
    /// # Errors
    /// `KernelError::Plugin` when `capabilities.can_vector_search` is false.
    pub fn check_vector_search(&self) -> KernelResult<()> {
        self.require_capability(self.capabilities.can_vector_search, || {
            "for vector search".to_string()
        })
    }

    /// Verifies the plugin holds the index-search capability.
    ///
    /// # Errors
    /// `KernelError::Plugin` when `capabilities.can_index_search` is false.
    pub fn check_index_search(&self) -> KernelResult<()> {
        self.require_capability(self.capabilities.can_index_search, || {
            "for index search".to_string()
        })
    }

    /// Appends an [`AuditEntry`] stamped with the current wall-clock time.
    pub fn audit(&mut self, function: &str, table: Option<&str>, status: i32, rows: u64) {
        let timestamp_us = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|since_epoch| since_epoch.as_micros() as u64)
            .unwrap_or(0);
        let entry = AuditEntry {
            timestamp_us,
            function: function.to_owned(),
            table: table.map(str::to_owned),
            status,
            rows_affected: rows,
        };
        self.audit_log.push(entry);
    }
}
/// A callable the host exposes to WASM plugins.
///
/// Implementations are stored as `Arc<dyn HostFunction>` in
/// `HostFunctionRegistry`; the `Send + Sync` bound presumably exists so that
/// registry can be shared across threads — confirm against its users.
pub trait HostFunction: Send + Sync {
    /// Identifier the function is registered and looked up under.
    fn name(&self) -> &str;
    /// Runs the function on the raw argument bytes `args`; `ctx` provides the
    /// caller's capabilities and receives audit records.
    fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult;
    /// Short human-readable summary (surfaced by `HostFunctionRegistry::list`).
    fn description(&self) -> &str;
}
/// Host function `soch_read`: reads rows from a table.
///
/// Arguments are UTF-8 text whose first line names the table. The caller
/// must hold read permission for that table.
///
/// NOTE(review): the payload is currently fixed mock data; no storage is queried.
pub struct SochRead {
    _marker: std::marker::PhantomData<()>,
}

impl SochRead {
    /// Creates the (stateless) host function.
    pub fn new() -> Self {
        SochRead {
            _marker: Default::default(),
        }
    }
}

impl Default for SochRead {
    fn default() -> Self {
        Self::new()
    }
}

impl HostFunction for SochRead {
    fn name(&self) -> &str {
        "soch_read"
    }

    fn description(&self) -> &str {
        "Read rows from a table with optional key filter"
    }

    /// Audits every outcome: `-3` for non-UTF-8 args, `-1` for a denied
    /// table, `0` with one row on success.
    fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
        const FN_NAME: &str = "soch_read";

        let text = match std::str::from_utf8(args) {
            Ok(text) => text,
            Err(_) => {
                ctx.audit(FN_NAME, None, -3, 0);
                return HostCallResult::InvalidArgs(String::from("invalid UTF-8 in arguments"));
            }
        };
        // Missing first line means an empty table name, which is left to the
        // permission check to reject.
        let table = text.lines().next().unwrap_or_default();

        match ctx.check_read(table) {
            Err(denied) => {
                ctx.audit(FN_NAME, Some(table), -1, 0);
                HostCallResult::PermissionDenied(denied.to_string())
            }
            Ok(_) => {
                let rows = "table[1]{id,name}:\n(1,\"mock_row\")".as_bytes().to_vec();
                ctx.audit(FN_NAME, Some(table), 0, 1);
                HostCallResult::Success(rows)
            }
        }
    }
}
/// Host function `soch_write`: writes rows to a table.
///
/// Arguments are UTF-8 text. The first line names the table, and every
/// following line counts as one row. On success the payload is the row
/// count as a little-endian `u64`.
pub struct SochWrite {
    _marker: std::marker::PhantomData<()>,
}

impl SochWrite {
    /// Creates the (stateless) host function.
    pub fn new() -> Self {
        SochWrite {
            _marker: Default::default(),
        }
    }
}

impl Default for SochWrite {
    fn default() -> Self {
        Self::new()
    }
}

impl HostFunction for SochWrite {
    fn name(&self) -> &str {
        "soch_write"
    }

    fn description(&self) -> &str {
        "Write rows to a table"
    }

    /// Audits every outcome: `-3` for non-UTF-8 args, `-1` for a denied
    /// table, `0` with the row count on success.
    fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
        const FN_NAME: &str = "soch_write";

        let text = match std::str::from_utf8(args) {
            Ok(text) => text,
            Err(_) => {
                ctx.audit(FN_NAME, None, -3, 0);
                return HostCallResult::InvalidArgs(String::from("invalid UTF-8 in arguments"));
            }
        };

        // Consume the header line; whatever remains in `lines` are the rows.
        let mut lines = text.lines();
        let table = lines.next().unwrap_or("");

        if let Err(denied) = ctx.check_write(table) {
            ctx.audit(FN_NAME, Some(table), -1, 0);
            return HostCallResult::PermissionDenied(denied.to_string());
        }

        let written = lines.count() as u64;
        ctx.audit(FN_NAME, Some(table), 0, written);
        HostCallResult::Success(Vec::from(written.to_le_bytes()))
    }
}
/// Host function `vector_search`: similarity search over a vector collection.
///
/// Arguments are UTF-8 text whose first line names the collection; `"default"`
/// is used when the arguments are empty. The payload is a flat sequence of
/// records, each the row id's little-endian bytes followed by the `f32`
/// distance's little-endian bytes.
///
/// NOTE(review): results are currently fixed mock data; no index is queried.
pub struct VectorSearch {
    _marker: std::marker::PhantomData<()>,
}

impl VectorSearch {
    /// Creates the (stateless) host function.
    pub fn new() -> Self {
        Self {
            _marker: std::marker::PhantomData,
        }
    }
}

impl Default for VectorSearch {
    fn default() -> Self {
        Self::new()
    }
}

impl HostFunction for VectorSearch {
    fn name(&self) -> &str {
        "vector_search"
    }

    fn description(&self) -> &str {
        "Perform vector similarity search"
    }

    /// Audits every outcome.
    ///
    /// - `-1`: missing capability (checked before the arguments are parsed).
    /// - `-3`: arguments are not UTF-8.
    /// - `0`: success, with the number of hits returned.
    fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
        if let Err(e) = ctx.check_vector_search() {
            ctx.audit("vector_search", None, -1, 0);
            return HostCallResult::PermissionDenied(e.to_string());
        }

        // Reject malformed arguments like `soch_read`/`soch_write` do, instead
        // of silently falling back to searching the "default" collection.
        let args_str = match std::str::from_utf8(args) {
            Ok(s) => s,
            Err(_) => {
                ctx.audit("vector_search", None, -3, 0);
                return HostCallResult::InvalidArgs("invalid UTF-8 in arguments".to_string());
            }
        };
        let collection = args_str.lines().next().unwrap_or("default");

        let mock_results: Vec<(RowId, f32)> = vec![(1, 0.1), (2, 0.2), (3, 0.3)];
        // Derive the audited hit count from the results rather than hard-coding it.
        let hit_count = mock_results.len() as u64;
        let record_size = std::mem::size_of::<RowId>() + std::mem::size_of::<f32>();
        let mut result = Vec::with_capacity(mock_results.len() * record_size);
        for (row_id, distance) in mock_results {
            result.extend_from_slice(&row_id.to_le_bytes());
            result.extend_from_slice(&distance.to_le_bytes());
        }

        ctx.audit("vector_search", Some(collection), 0, hit_count);
        HostCallResult::Success(result)
    }
}
/// Host function `emit_metric`: accepts an observability metric from a plugin.
///
/// The payload is only checked for non-emptiness; its contents are not
/// decoded here. Each accepted call increments an internal counter readable
/// via [`EmitMetric::total_emitted`].
pub struct EmitMetric {
    metrics_emitted: AtomicU64,
}

impl EmitMetric {
    /// Creates the host function with its counter at zero.
    pub fn new() -> Self {
        EmitMetric {
            metrics_emitted: AtomicU64::new(0),
        }
    }

    /// Number of metric payloads accepted so far.
    pub fn total_emitted(&self) -> u64 {
        self.metrics_emitted.load(Ordering::Relaxed)
    }
}

impl Default for EmitMetric {
    fn default() -> Self {
        Self::new()
    }
}

impl HostFunction for EmitMetric {
    fn name(&self) -> &str {
        "emit_metric"
    }

    fn description(&self) -> &str {
        "Emit an observability metric (counter, gauge, or histogram)"
    }

    /// Audits `-3` for an empty payload, otherwise counts the metric and audits `0`.
    fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
        if !args.is_empty() {
            // Relaxed is enough: the counter is a standalone statistic.
            self.metrics_emitted.fetch_add(1, Ordering::Relaxed);
            ctx.audit("emit_metric", None, 0, 1);
            return HostCallResult::Ok;
        }
        ctx.audit("emit_metric", None, -3, 0);
        HostCallResult::InvalidArgs("empty metric data".to_string())
    }
}
/// Host function `log_message`: captures plugin log lines in memory.
///
/// The argument layout is a level byte followed by the message bytes. The
/// level is stored as-is and not interpreted here. Captured entries can be
/// read with [`LogMessage::captured_logs`] and discarded with
/// [`LogMessage::clear_logs`].
pub struct LogMessage {
    logs: parking_lot::RwLock<Vec<(u8, String)>>,
}

impl LogMessage {
    /// Creates the host function with no captured logs.
    pub fn new() -> Self {
        Self {
            logs: parking_lot::RwLock::new(Vec::new()),
        }
    }

    /// Snapshot copy of every captured `(level, message)` pair, in capture order.
    pub fn captured_logs(&self) -> Vec<(u8, String)> {
        self.logs.read().clone()
    }

    /// Discards all captured logs.
    pub fn clear_logs(&self) {
        self.logs.write().clear();
    }
}

impl Default for LogMessage {
    fn default() -> Self {
        Self::new()
    }
}

impl HostFunction for LogMessage {
    fn name(&self) -> &str {
        "log_message"
    }

    fn description(&self) -> &str {
        "Log a message at specified level"
    }

    /// Audits every outcome.
    ///
    /// - `-3`: empty payload.
    /// - `0`: the line was captured.
    ///
    /// A message that is not valid UTF-8 is stored as `"<invalid UTF-8>"`.
    fn execute(&self, ctx: &mut HostFunctionContext, args: &[u8]) -> HostCallResult {
        let (level, payload) = match args.split_first() {
            Some((&level, payload)) => (level, payload),
            None => {
                // Audit rejected calls too, like every other host function does.
                ctx.audit("log_message", None, -3, 0);
                return HostCallResult::InvalidArgs("empty log data".to_string());
            }
        };
        let message = std::str::from_utf8(payload).unwrap_or("<invalid UTF-8>");
        self.logs.write().push((level, message.to_string()));
        ctx.audit("log_message", None, 0, 0);
        HostCallResult::Ok
    }
}
/// Name-indexed collection of the host functions exposed to plugins.
pub struct HostFunctionRegistry {
    functions: HashMap<String, Arc<dyn HostFunction>>,
}

impl Default for HostFunctionRegistry {
    fn default() -> Self {
        Self::new()
    }
}

impl HostFunctionRegistry {
    /// Creates a registry pre-populated with the built-in host functions:
    /// `soch_read`, `soch_write`, `vector_search`, `emit_metric` and `log_message`.
    pub fn new() -> Self {
        let mut registry = Self {
            functions: HashMap::new(),
        };
        registry.register(Arc::new(SochRead::new()));
        registry.register(Arc::new(SochWrite::new()));
        registry.register(Arc::new(VectorSearch::new()));
        registry.register(Arc::new(EmitMetric::new()));
        registry.register(Arc::new(LogMessage::new()));
        registry
    }

    /// Adds `func` under the key returned by its `name()`; a function already
    /// registered under that name is replaced.
    pub fn register(&mut self, func: Arc<dyn HostFunction>) {
        self.functions.insert(func.name().to_string(), func);
    }

    /// Looks up a host function by name, returning a shared handle to it.
    pub fn get(&self, name: &str) -> Option<Arc<dyn HostFunction>> {
        self.functions.get(name).cloned()
    }

    /// `(name, description)` for every registered function, sorted by name.
    ///
    /// Sorting makes the output deterministic; raw `HashMap` iteration order
    /// is arbitrary and varies between runs.
    pub fn list(&self) -> Vec<(&str, &str)> {
        let mut entries: Vec<(&str, &str)> = self
            .functions
            .values()
            .map(|f| (f.name(), f.description()))
            .collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        entries
    }

    /// Dispatches to the function registered as `name`.
    ///
    /// An unknown `name` yields [`HostCallResult::NotFound`]. It is also
    /// recorded in `ctx`'s audit log (status `-2`), so failed lookups leave
    /// the same trace as failed permission checks.
    pub fn execute(
        &self,
        name: &str,
        ctx: &mut HostFunctionContext,
        args: &[u8],
    ) -> HostCallResult {
        match self.functions.get(name) {
            Some(func) => func.execute(ctx, args),
            None => {
                let result =
                    HostCallResult::NotFound(format!("host function '{}' not found", name));
                ctx.audit(name, None, result.status_code(), 0);
                result
            }
        }
    }
}
pub mod wire {
pub fn encode_string(s: &str) -> Vec<u8> {
let mut buf = Vec::with_capacity(4 + s.len());
buf.extend_from_slice(&(s.len() as u32).to_le_bytes());
buf.extend_from_slice(s.as_bytes());
buf
}
pub fn decode_string(data: &[u8]) -> Option<(&str, &[u8])> {
if data.len() < 4 {
return None;
}
let len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
if data.len() < 4 + len {
return None;
}
let s = std::str::from_utf8(&data[4..4 + len]).ok()?;
Some((s, &data[4 + len..]))
}
pub fn encode_row_id(id: u64) -> [u8; 8] {
id.to_le_bytes()
}
pub fn decode_row_id(data: &[u8]) -> Option<(u64, &[u8])> {
if data.len() < 8 {
return None;
}
let id = u64::from_le_bytes([
data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
]);
Some((id, &data[8..]))
}
pub fn encode_f32_vec(v: &[f32]) -> Vec<u8> {
let mut buf = Vec::with_capacity(4 + v.len() * 4);
buf.extend_from_slice(&(v.len() as u32).to_le_bytes());
for f in v {
buf.extend_from_slice(&f.to_le_bytes());
}
buf
}
pub fn decode_f32_vec(data: &[u8]) -> Option<(Vec<f32>, &[u8])> {
if data.len() < 4 {
return None;
}
let len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
if data.len() < 4 + len * 4 {
return None;
}
let mut vec = Vec::with_capacity(len);
for i in 0..len {
let offset = 4 + i * 4;
let f = f32::from_le_bytes([
data[offset],
data[offset + 1],
data[offset + 2],
data[offset + 3],
]);
vec.push(f);
}
Some((vec, &data[4 + len * 4..]))
}
}
#[cfg(test)]
mod tests {
    // Unit tests for status codes, capability checks, each built-in host
    // function, the registry, the audit log, and the `wire` codecs.
    use super::*;
    // Every variant maps to its numeric status; both success forms are 0.
    #[test]
    fn test_host_call_result_status() {
        assert_eq!(HostCallResult::Ok.status_code(), 0);
        assert_eq!(HostCallResult::Success(vec![]).status_code(), 0);
        assert_eq!(
            HostCallResult::PermissionDenied("".to_string()).status_code(),
            -1
        );
        assert_eq!(HostCallResult::NotFound("".to_string()).status_code(), -2);
        assert_eq!(
            HostCallResult::InvalidArgs("".to_string()).status_code(),
            -3
        );
        assert_eq!(HostCallResult::Error("".to_string()).status_code(), -4);
    }
    // Allow-listed tables pass; unlisted tables and the disabled index-search
    // capability are denied.
    #[test]
    fn test_host_context_permission_checks() {
        let caps = WasmPluginCapabilities {
            can_read_table: vec!["users".to_string()],
            can_write_table: vec!["logs".to_string()],
            can_vector_search: true,
            can_index_search: false,
            ..Default::default()
        };
        let ctx = HostFunctionContext::new("test_plugin", caps);
        assert!(ctx.check_read("users").is_ok());
        assert!(ctx.check_read("other").is_err());
        assert!(ctx.check_write("logs").is_ok());
        assert!(ctx.check_write("users").is_err());
        assert!(ctx.check_vector_search().is_ok());
        assert!(ctx.check_index_search().is_err());
    }
    // soch_read succeeds only for a table listed in `can_read_table`.
    #[test]
    fn test_soch_read_permission() {
        let caps = WasmPluginCapabilities {
            can_read_table: vec!["allowed_table".to_string()],
            ..Default::default()
        };
        let mut ctx = HostFunctionContext::new("test", caps);
        let read_fn = SochRead::new();
        let result = read_fn.execute(&mut ctx, b"allowed_table\n");
        assert_eq!(result.status_code(), 0);
        let result = read_fn.execute(&mut ctx, b"denied_table\n");
        assert_eq!(result.status_code(), -1);
    }
    // soch_write returns the number of row lines after the table line as a
    // little-endian u64, and denies tables not in `can_write_table`.
    #[test]
    fn test_soch_write_permission() {
        let caps = WasmPluginCapabilities {
            can_write_table: vec!["writable".to_string()],
            ..Default::default()
        };
        let mut ctx = HostFunctionContext::new("test", caps);
        let write_fn = SochWrite::new();
        let result = write_fn.execute(&mut ctx, b"writable\nrow1\nrow2\n");
        assert_eq!(result.status_code(), 0);
        assert_eq!(result.data().unwrap(), &2u64.to_le_bytes());
        let result = write_fn.execute(&mut ctx, b"readonly\nrow1\n");
        assert_eq!(result.status_code(), -1);
    }
    // The mock search yields three records; each is expected to be an 8-byte
    // row id plus a 4-byte f32 distance.
    #[test]
    fn test_vector_search() {
        let caps = WasmPluginCapabilities {
            can_vector_search: true,
            ..Default::default()
        };
        let mut ctx = HostFunctionContext::new("test", caps);
        let search_fn = VectorSearch::new();
        let result = search_fn.execute(&mut ctx, b"collection\n");
        assert_eq!(result.status_code(), 0);
        let data = result.data().unwrap();
        assert_eq!(data.len(), 3 * (8 + 4));
    }
    // A non-empty payload is accepted and bumps the emitted counter.
    #[test]
    fn test_emit_metric() {
        let caps = WasmPluginCapabilities::default();
        let mut ctx = HostFunctionContext::new("test", caps);
        let metric_fn = EmitMetric::new();
        let result = metric_fn.execute(&mut ctx, b"\x01metric_name\x00\x00\x00\x00");
        assert_eq!(result.status_code(), 0);
        assert_eq!(metric_fn.total_emitted(), 1);
    }
    // The first byte is captured as the level and the rest as the message.
    #[test]
    fn test_log_message() {
        let caps = WasmPluginCapabilities::default();
        let mut ctx = HostFunctionContext::new("test", caps);
        let log_fn = LogMessage::new();
        let result = log_fn.execute(&mut ctx, b"\x01hello world");
        assert_eq!(result.status_code(), 0);
        let logs = log_fn.captured_logs();
        assert_eq!(logs.len(), 1);
        assert_eq!(logs[0].0, 1);
        assert_eq!(logs[0].1, "hello world");
    }
    // `new()` pre-registers all five built-ins; unknown names are absent.
    #[test]
    fn test_host_function_registry() {
        let registry = HostFunctionRegistry::new();
        assert!(registry.get("soch_read").is_some());
        assert!(registry.get("soch_write").is_some());
        assert!(registry.get("vector_search").is_some());
        assert!(registry.get("emit_metric").is_some());
        assert!(registry.get("log_message").is_some());
        assert!(registry.get("unknown").is_none());
        let list = registry.list();
        assert!(list.len() >= 5);
    }
    // Dispatch by name; an unregistered name yields NotFound (-2).
    #[test]
    fn test_registry_execute() {
        let registry = HostFunctionRegistry::new();
        let caps = WasmPluginCapabilities {
            can_read_table: vec!["test".to_string()],
            ..Default::default()
        };
        let mut ctx = HostFunctionContext::new("plugin", caps);
        let result = registry.execute("soch_read", &mut ctx, b"test\n");
        assert_eq!(result.status_code(), 0);
        let result = registry.execute("nonexistent", &mut ctx, b"");
        assert_eq!(result.status_code(), -2);
    }
    // One successful read appends exactly one audit entry describing the call.
    #[test]
    fn test_audit_log() {
        let caps = WasmPluginCapabilities {
            can_read_table: vec!["audit_test".to_string()],
            ..Default::default()
        };
        let mut ctx = HostFunctionContext::new("test", caps);
        let read_fn = SochRead::new();
        let _ = read_fn.execute(&mut ctx, b"audit_test\n");
        assert_eq!(ctx.audit_log.len(), 1);
        assert_eq!(ctx.audit_log[0].function, "soch_read");
        assert_eq!(ctx.audit_log[0].table, Some("audit_test".to_string()));
        assert_eq!(ctx.audit_log[0].status, 0);
    }
    // Round-trip each encoder through its decoder and check nothing is left over.
    mod wire_tests {
        use super::super::wire::*;
        #[test]
        fn test_encode_decode_string() {
            let s = "hello world";
            let encoded = encode_string(s);
            let (decoded, rest) = decode_string(&encoded).unwrap();
            assert_eq!(decoded, s);
            assert!(rest.is_empty());
        }
        #[test]
        fn test_encode_decode_row_id() {
            let id = 0x123456789ABCDEF0u64;
            let encoded = encode_row_id(id);
            let (decoded, rest) = decode_row_id(&encoded).unwrap();
            assert_eq!(decoded, id);
            assert!(rest.is_empty());
        }
        #[test]
        fn test_encode_decode_f32_vec() {
            let v = vec![1.0, 2.0, 3.0, 4.0];
            let encoded = encode_f32_vec(&v);
            let (decoded, rest) = decode_f32_vec(&encoded).unwrap();
            assert_eq!(decoded, v);
            assert!(rest.is_empty());
        }
        // Empty input is too short for any decoder's fixed-size header.
        #[test]
        fn test_decode_empty() {
            assert!(decode_string(&[]).is_none());
            assert!(decode_row_id(&[]).is_none());
            assert!(decode_f32_vec(&[]).is_none());
        }
    }
}