use std::path::Path;
use crate::error::OxideError;
use crate::types::{FunctionDefinition, ToolDefinition};
/// A tool backed by a compiled WebAssembly module, executed with wasmtime.
///
/// The module is compiled once (in `from_file`) and instantiated in a fresh
/// store on every `call`.
#[cfg(feature = "wasm-tools")]
pub struct WasmTool {
    /// Tool metadata returned by `definition()`; `wasm_tool_definition` is one
    /// way to build it.
    definition: ToolDefinition,
    /// Engine that compiled `module`; each `call` creates its store from it.
    engine: wasmtime::Engine,
    /// The compiled module, instantiated anew on every `call`.
    module: wasmtime::Module,
}
#[cfg(feature = "wasm-tools")]
impl WasmTool {
    /// Compiles the WebAssembly module at `path` and pairs it with `definition`.
    ///
    /// A new default [`wasmtime::Engine`] is created for every tool.
    ///
    /// # Errors
    ///
    /// Returns [`OxideError::Other`] if the module cannot be loaded or compiled.
    pub fn from_file(path: &Path, definition: ToolDefinition) -> Result<Self, OxideError> {
        let engine = wasmtime::Engine::default();
        let module = wasmtime::Module::from_file(&engine, path)
            .map_err(|e| OxideError::Other(format!("wasm load: {e}")))?;
        Ok(Self { definition, engine, module })
    }

    /// Runs the guest's `tool_call` export with `args` serialized as JSON and
    /// parses the guest's reply as JSON.
    ///
    /// Guest ABI required by this function:
    /// - a linear memory exported as `memory`;
    /// - `alloc(len: i32) -> i32`, returning the address of `len` writable bytes;
    /// - `tool_call(ptr: i32, len: i32) -> i32`, receiving the JSON input and
    ///   returning the address of a NUL-terminated JSON result in `memory`.
    ///
    /// The module is instantiated with no imports in a brand-new store for each
    /// call, so all guest state (including memory handed out by `alloc`) is
    /// dropped when this function returns.
    ///
    /// # Errors
    ///
    /// - [`OxideError::Other`] for instantiation failures, missing exports, guest
    ///   traps, input longer than `i32::MAX` bytes, a result pointer outside guest
    ///   memory, or a result with no NUL terminator.
    /// - [`OxideError::Serde`] if `args` cannot be serialized or the result is not
    ///   valid JSON.
    pub fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, OxideError> {
        use wasmtime::{Instance, Store};

        // NOTE(review): default engine config and a store without fuel/epoch
        // limits mean a guest that never returns blocks this thread — confirm the
        // loaded modules are trusted.
        let mut store = Store::new(&self.engine, ());
        let instance = Instance::new(&mut store, &self.module, &[])
            .map_err(|e| OxideError::Other(format!("wasm instantiate: {e}")))?;
        let memory = instance
            .get_memory(&mut store, "memory")
            .ok_or_else(|| OxideError::Other("wasm: no `memory` export".into()))?;
        let alloc = instance
            .get_typed_func::<i32, i32>(&mut store, "alloc")
            .map_err(|e| OxideError::Other(format!("wasm: `alloc` not found: {e}")))?;
        let tool_call_fn = instance
            .get_typed_func::<(i32, i32), i32>(&mut store, "tool_call")
            .map_err(|e| OxideError::Other(format!("wasm: `tool_call` not found: {e}")))?;

        let input_bytes = serde_json::to_vec(&args).map_err(OxideError::Serde)?;
        // The ABI carries lengths as i32; refuse oversized input rather than let
        // `as` silently truncate the length handed to the guest.
        if input_bytes.len() > i32::MAX as usize {
            return Err(OxideError::Other(format!(
                "wasm: input of {} bytes exceeds the i32 length limit",
                input_bytes.len()
            )));
        }
        let input_len = input_bytes.len() as i32;

        let input_ptr = alloc
            .call(&mut store, input_len)
            .map_err(|e| OxideError::Other(format!("wasm alloc failed: {e}")))?;
        memory
            .write(&mut store, Self::guest_offset(input_ptr), &input_bytes)
            .map_err(|e| OxideError::Other(format!("wasm memory write: {e}")))?;

        let result_ptr = tool_call_fn
            .call(&mut store, (input_ptr, input_len))
            .map_err(|e| OxideError::Other(format!("wasm tool_call failed: {e}")))?;

        let result_bytes = Self::read_nul_terminated(memory.data(&store), result_ptr)?;
        serde_json::from_slice(result_bytes).map_err(OxideError::Serde)
    }

    /// Returns the tool's metadata.
    pub fn definition(&self) -> &ToolDefinition {
        &self.definition
    }

    /// Converts a wasm32 address carried as `i32` into a host offset.
    ///
    /// wasm32 addresses are unsigned; reinterpreting the bits as `u32` keeps
    /// addresses >= 2 GiB correct instead of sign-extending them into huge values.
    fn guest_offset(ptr: i32) -> usize {
        ptr as u32 as usize
    }

    /// Returns the bytes of the NUL-terminated string at guest address `ptr`,
    /// excluding the terminator, without ever panicking on bad guest pointers.
    fn read_nul_terminated(mem: &[u8], ptr: i32) -> Result<&[u8], OxideError> {
        let start = Self::guest_offset(ptr);
        let tail = mem.get(start..).ok_or_else(|| {
            OxideError::Other(format!(
                "wasm: result pointer {start:#x} is outside guest memory ({} bytes)",
                mem.len()
            ))
        })?;
        let len = tail
            .iter()
            .position(|&b| b == 0)
            .ok_or_else(|| OxideError::Other("wasm: result string is not NUL-terminated".into()))?;
        Ok(&tail[..len])
    }
}
/// Placeholder `WasmTool` used when the `wasm-tools` feature is disabled.
///
/// Its only field is private, so it cannot be constructed outside this module;
/// `WasmTool::from_file` in this configuration always returns an error that
/// explains how to enable the feature.
#[cfg(not(feature = "wasm-tools"))]
#[derive(Debug)]
pub struct WasmTool {
    // Private unit field: blocks construction from outside this module.
    _private: (),
}
#[cfg(not(feature = "wasm-tools"))]
impl WasmTool {
    /// Always fails: loading WebAssembly tools requires the `wasm-tools` feature.
    ///
    /// The signature mirrors the feature-enabled constructor so callers compile in
    /// both configurations.
    ///
    /// # Errors
    ///
    /// Always returns [`OxideError::Other`] with a message naming the missing
    /// feature and how to enable it.
    // Parameter names match the feature-enabled version (for rustdoc parity), so
    // they are intentionally unused here.
    #[allow(unused_variables)]
    pub fn from_file(path: &Path, definition: ToolDefinition) -> Result<Self, OxideError> {
        Err(OxideError::Other(
            "WasmTool requires the `wasm-tools` feature: \
             oxide-agent = { features = [\"wasm-tools\"] }"
                .into(),
        ))
    }

    /// Exists only for API parity with the feature-enabled type.
    ///
    /// # Panics
    ///
    /// Always panics if reached. Without the feature, no `WasmTool` can be
    /// obtained (`from_file` always errors and the struct's field is private), so
    /// this is unreachable from outside this module.
    pub fn definition(&self) -> &ToolDefinition {
        unreachable!("WasmTool cannot be constructed without the `wasm-tools` feature")
    }
}
/// Builds a [`ToolDefinition`] whose `kind` is `"function"`.
///
/// * `name` — becomes the function's name.
/// * `description` — becomes the function's description.
/// * `parameters` — stored verbatim as the function's parameters value
///   (presumably a JSON Schema object — confirm with the consumer of
///   `ToolDefinition`).
pub fn wasm_tool_definition(
    name: impl Into<String>,
    description: impl Into<String>,
    parameters: serde_json::Value,
) -> ToolDefinition {
    let name: String = name.into();
    let description: String = description.into();
    let function = FunctionDefinition {
        name,
        description,
        parameters,
    };
    ToolDefinition {
        kind: "function".into(),
        function,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without the feature, the stub constructor must explain how to enable it.
    /// Gated at the function level so it is not reported as a vacuous pass when
    /// the feature is on.
    #[cfg(not(feature = "wasm-tools"))]
    #[test]
    fn stub_returns_meaningful_error_without_feature() {
        let err = WasmTool::from_file(
            Path::new("nonexistent.wasm"),
            wasm_tool_definition("t", "d", serde_json::json!({})),
        )
        .unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("wasm-tools"), "error should mention the feature flag");
    }

    /// The builder must copy its inputs into the nested function definition.
    #[test]
    fn wasm_tool_definition_populates_function_fields() {
        let params = serde_json::json!({ "type": "object", "properties": {} });
        let def = wasm_tool_definition("echo", "Echoes its input", params.clone());
        assert_eq!(def.function.name, "echo");
        assert_eq!(def.function.description, "Echoes its input");
        assert_eq!(def.function.parameters, params);
    }
}