use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use cortexai_core::errors::ToolError;
use cortexai_core::tool::{ExecutionContext, Tool, ToolSchema};
use wasmtime::{Engine, Linker, Module, Store, Trap, Val};
use crate::sandbox_error::SandboxError;
/// Default memory ceiling used by `SandboxConfig::default`: 64 MiB.
const DEFAULT_MAX_MEMORY_BYTES: usize = 64 * 1024 * 1024;
/// Default execution-time budget used by `SandboxConfig::default`: 30 seconds.
const DEFAULT_MAX_EXECUTION_TIME: Duration = Duration::from_secs(30);
/// Resource limits and capability switches for running a wasm tool.
///
/// Fields are private: read them through the accessor methods and assemble a
/// value with `SandboxConfig::builder()`. `Default` yields 64 MiB of memory,
/// a 30 second time budget, no allowed paths, no network and no fuel limit.
///
/// NOTE(review): in the visible part of this file only `max_fuel` is applied
/// to a wasmtime store; confirm where the remaining limits are enforced.
#[derive(Debug, Clone)]
pub struct SandboxConfig {
/// Memory ceiling, in bytes.
max_memory_bytes: usize,
/// Execution-time budget for a call.
max_execution_time: Duration,
/// `(guest, host)` string pairs in the order they were added through the
/// builder; presumably guest-to-host path mappings — TODO confirm.
allowed_paths: Vec<(String, String)>,
/// Whether network access is permitted.
allow_network: bool,
/// Fuel budget per store; `None` leaves fuel consumption disabled.
max_fuel: Option<u64>,
}
impl SandboxConfig {
pub fn max_memory_bytes(&self) -> usize {
self.max_memory_bytes
}
pub fn max_execution_time(&self) -> Duration {
self.max_execution_time
}
pub fn allowed_paths(&self) -> &[(String, String)] {
&self.allowed_paths
}
pub fn allow_network(&self) -> bool {
self.allow_network
}
pub fn max_fuel(&self) -> Option<u64> {
self.max_fuel
}
pub fn builder() -> SandboxConfigBuilder {
SandboxConfigBuilder::default()
}
}
#[derive(Debug, Clone)]
pub struct SandboxConfigBuilder {
config: SandboxConfig,
}
impl Default for SandboxConfigBuilder {
fn default() -> Self {
Self {
config: SandboxConfig::default(),
}
}
}
impl SandboxConfigBuilder {
pub fn max_memory_bytes(mut self, bytes: usize) -> Self {
self.config.max_memory_bytes = bytes;
self
}
pub fn max_execution_time(mut self, duration: Duration) -> Self {
self.config.max_execution_time = duration;
self
}
pub fn allowed_path(mut self, guest: String, host: String) -> Self {
self.config.allowed_paths.push((guest, host));
self
}
pub fn allow_network(mut self, allow: bool) -> Self {
self.config.allow_network = allow;
self
}
pub fn max_fuel(mut self, fuel: u64) -> Self {
self.config.max_fuel = Some(fuel);
self
}
pub fn build(self) -> SandboxConfig {
self.config
}
}
impl Default for SandboxConfig {
    /// Builds the baseline configuration: the default memory and time budgets,
    /// no allowed paths, network disabled, and no fuel limit.
    fn default() -> Self {
        // Capabilities start switched off; callers opt in through the builder.
        let allowed_paths = Vec::new();
        let (allow_network, max_fuel) = (false, None);
        Self {
            max_memory_bytes: DEFAULT_MAX_MEMORY_BYTES,
            max_execution_time: DEFAULT_MAX_EXECUTION_TIME,
            allowed_paths,
            allow_network,
            max_fuel,
        }
    }
}
/// A compiled wasm module together with the engine and sandbox configuration
/// it was loaded with (see `ToolSandbox::load_module`).
#[derive(Clone)]
pub struct SandboxedModule {
/// Compiled module produced by `Module::new`.
module: Module,
/// Clone of the engine the module was compiled with.
engine: Engine,
/// Configuration shared with the `ToolSandbox` that loaded the module.
config: Arc<SandboxConfig>,
}
impl std::fmt::Debug for SandboxedModule {
    /// Prints only the configuration; the remaining fields are elided with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SandboxedModule");
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
impl SandboxedModule {
    /// Configuration the module was loaded under.
    pub fn config(&self) -> &SandboxConfig {
        self.config.as_ref()
    }

    /// Compiled module, for instantiation inside this crate.
    pub(crate) fn module(&self) -> &Module {
        &self.module
    }

    /// Engine the module was compiled with; stores for this module are built from it.
    pub(crate) fn engine(&self) -> &Engine {
        &self.engine
    }
}
/// Compiles and runs wasm modules under one engine and one `SandboxConfig`.
pub struct ToolSandbox {
/// Engine built in `new`; fuel consumption is enabled only when the
/// configuration carries a fuel limit.
engine: Engine,
/// Configuration shared with every module loaded through this sandbox.
config: Arc<SandboxConfig>,
}
impl ToolSandbox {
    /// Creates a sandbox for `config`.
    ///
    /// Fuel consumption is switched on in the engine only when `config` has a
    /// fuel limit.
    ///
    /// NOTE(review): the memory, time, path and network settings are stored
    /// but not applied anywhere in this impl — confirm enforcement elsewhere.
    ///
    /// # Panics
    ///
    /// Panics if wasmtime cannot build an engine from the generated settings.
    pub fn new(config: SandboxConfig) -> Self {
        let mut engine_config = wasmtime::Config::new();
        engine_config.consume_fuel(config.max_fuel().is_some());
        let engine = Engine::new(&engine_config)
            .expect("Failed to create wasmtime engine");
        Self {
            engine,
            config: Arc::new(config),
        }
    }

    /// Configuration this sandbox was created with.
    pub fn config(&self) -> &SandboxConfig {
        &self.config
    }

    /// Compiles `wasm_bytes` with this sandbox's engine.
    ///
    /// # Errors
    ///
    /// Returns `SandboxError::CompilationFailed` with wasmtime's message when
    /// the bytes cannot be compiled.
    pub fn load_module(&self, wasm_bytes: &[u8]) -> Result<SandboxedModule, SandboxError> {
        let module = Module::new(&self.engine, wasm_bytes)
            .map_err(|e| SandboxError::CompilationFailed(e.to_string()))?;
        Ok(SandboxedModule {
            module,
            engine: self.engine.clone(),
            config: Arc::clone(&self.config),
        })
    }

    /// Instantiates `sandboxed` in a fresh store and calls its export `fn_name`.
    ///
    /// `input` is decoded by `parse_i32_params` into the call's `i32`
    /// arguments, and the returned values are serialised by `encode_results`.
    ///
    /// # Errors
    ///
    /// * `SandboxError::ExecutionFailed` if the fuel limit cannot be applied
    ///   to the store, instantiation fails, the export is missing, or the
    ///   call fails for a reason other than fuel exhaustion.
    /// * `SandboxError::FuelExhausted` if the call traps with `Trap::OutOfFuel`.
    pub fn execute(
        &self,
        sandboxed: &SandboxedModule,
        fn_name: &str,
        input: &[u8],
    ) -> Result<Vec<u8>, SandboxError> {
        let mut store = Store::new(sandboxed.engine(), ());
        // Fail closed: never run guest code if the configured fuel limit could
        // not be installed (e.g. the module's engine has no fuel metering).
        self.try_configure_store(&mut store)?;
        let linker = Linker::new(sandboxed.engine());
        let instance = linker
            .instantiate(&mut store, sandboxed.module())
            .map_err(|e| SandboxError::ExecutionFailed(e.to_string()))?;
        let func = instance
            .get_func(&mut store, fn_name)
            .ok_or_else(|| {
                SandboxError::ExecutionFailed(format!("export '{}' not found", fn_name))
            })?;
        let param_vals: Vec<Val> = parse_i32_params(input)
            .into_iter()
            .map(Val::I32)
            .collect();
        // Placeholder slots, one per declared result; the call overwrites them.
        let result_count = func.ty(&store).results().len();
        let mut results = vec![Val::I32(0); result_count];
        func.call(&mut store, &param_vals, &mut results)
            .map_err(|e| map_execution_error(e, &self.config))?;
        Ok(encode_results(&results))
    }

    /// Installs the configured fuel budget on `store`, reporting failure.
    ///
    /// Does nothing when no fuel limit is configured.
    fn try_configure_store(&self, store: &mut Store<()>) -> Result<(), SandboxError> {
        match self.config.max_fuel() {
            Some(fuel) => store.set_fuel(fuel).map_err(|e| {
                SandboxError::ExecutionFailed(format!("failed to apply fuel limit: {}", e))
            }),
            None => Ok(()),
        }
    }

    /// Best-effort form of `try_configure_store`: a failure to install the
    /// fuel budget is discarded. Kept with its original signature for callers
    /// that do not propagate an error.
    fn configure_store(&self, store: &mut Store<()>) {
        self.try_configure_store(store).ok();
    }
}
/// Decodes `input` as consecutive little-endian `i32` values.
///
/// Bytes are consumed four at a time; a trailing group of fewer than four
/// bytes is ignored.
fn parse_i32_params(input: &[u8]) -> Vec<i32> {
    let mut values = Vec::with_capacity(input.len() / 4);
    for word in input.chunks_exact(4) {
        // `chunks_exact(4)` only yields 4-byte slices, so these indices are in range.
        values.push(i32::from_le_bytes([word[0], word[1], word[2], word[3]]));
    }
    values
}
fn map_execution_error(error: wasmtime::Error, config: &SandboxConfig) -> SandboxError {
if let Some(trap) = error.downcast_ref::<Trap>() {
if *trap == Trap::OutOfFuel {
return SandboxError::FuelExhausted {
fuel_limit: config.max_fuel().unwrap_or(0),
};
}
}
SandboxError::ExecutionFailed(error.to_string())
}
/// Serialises wasm result values into one byte buffer.
///
/// Each `I32`, `I64`, `F32` and `F64` value appends the `to_le_bytes()` of its
/// payload, in order. Values of any other kind contribute no bytes.
fn encode_results(results: &[Val]) -> Vec<u8> {
    results.iter().fold(Vec::new(), |mut encoded, value| {
        match value {
            Val::I32(n) => encoded.extend_from_slice(&n.to_le_bytes()),
            Val::I64(n) => encoded.extend_from_slice(&n.to_le_bytes()),
            Val::F32(n) => encoded.extend_from_slice(&n.to_le_bytes()),
            Val::F64(n) => encoded.extend_from_slice(&n.to_le_bytes()),
            _ => {}
        }
        encoded
    })
}
/// A wasm module exposed through the `Tool` trait.
pub struct SandboxedTool {
/// Sandbox whose store configuration is applied before each `__tool_execute` call.
sandbox: ToolSandbox,
/// Module that is instantiated afresh for every call.
module: SandboxedModule,
/// Schema parsed from the module's `__tool_schema` export at construction time.
schema: ToolSchema,
}
impl std::fmt::Debug for SandboxedTool {
    /// Prints only the schema; the remaining fields are elided with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("SandboxedTool");
        out.field("schema", &self.schema);
        out.finish_non_exhaustive()
    }
}
impl SandboxedTool {
pub fn new(
sandbox: ToolSandbox,
module: SandboxedModule,
) -> Result<Self, SandboxError> {
let schema = Self::extract_schema(&sandbox, &module)?;
Ok(Self {
sandbox,
module,
schema,
})
}
fn extract_schema(
_sandbox: &ToolSandbox,
module: &SandboxedModule,
) -> Result<ToolSchema, SandboxError> {
let mut store = Store::new(module.engine(), ());
let linker = Linker::new(module.engine());
let instance = linker
.instantiate(&mut store, module.module())
.map_err(|e| SandboxError::ExecutionFailed(e.to_string()))?;
let schema_fn = instance
.get_func(&mut store, "__tool_schema")
.ok_or_else(|| {
SandboxError::ExecutionFailed(
"module missing '__tool_schema' export".to_string(),
)
})?;
let mut results = [Val::I32(0), Val::I32(0)];
schema_fn
.call(&mut store, &[], &mut results)
.map_err(|e| SandboxError::ExecutionFailed(e.to_string()))?;
let ptr = results[0].unwrap_i32() as u32 as usize;
let len = results[1].unwrap_i32() as u32 as usize;
let memory = instance
.get_memory(&mut store, "memory")
.ok_or_else(|| {
SandboxError::ExecutionFailed(
"module missing 'memory' export".to_string(),
)
})?;
let data = memory.data(&store);
if ptr + len > data.len() {
return Err(SandboxError::ExecutionFailed(
"schema pointer out of bounds".to_string(),
));
}
let json_bytes = &data[ptr..ptr + len];
let schema: ToolSchema = serde_json::from_slice(json_bytes).map_err(|e| {
SandboxError::ExecutionFailed(format!("invalid schema JSON: {}", e))
})?;
Ok(schema)
}
fn call_execute(
&self,
args_json: &[u8],
) -> Result<serde_json::Value, SandboxError> {
let mut store = Store::new(self.module.engine(), ());
self.sandbox.configure_store(&mut store);
let linker = Linker::new(self.module.engine());
let instance = linker
.instantiate(&mut store, self.module.module())
.map_err(|e| SandboxError::ExecutionFailed(e.to_string()))?;
let memory = instance
.get_memory(&mut store, "memory")
.ok_or_else(|| {
SandboxError::ExecutionFailed(
"module missing 'memory' export".to_string(),
)
})?;
let args_offset: usize = 1024;
let mem_data = memory.data_mut(&mut store);
if args_offset + args_json.len() > mem_data.len() {
return Err(SandboxError::MemoryLimitExceeded {
limit_bytes: mem_data.len(),
});
}
mem_data[args_offset..args_offset + args_json.len()]
.copy_from_slice(args_json);
let exec_fn = instance
.get_func(&mut store, "__tool_execute")
.ok_or_else(|| {
SandboxError::ExecutionFailed(
"module missing '__tool_execute' export".to_string(),
)
})?;
let params = [
Val::I32(args_offset as i32),
Val::I32(args_json.len() as i32),
];
let mut results = [Val::I32(0), Val::I32(0)];
exec_fn
.call(&mut store, ¶ms, &mut results)
.map_err(|e| map_execution_error(e, self.module.config()))?;
let out_ptr = results[0].unwrap_i32() as u32 as usize;
let out_len = results[1].unwrap_i32() as u32 as usize;
let data = memory.data(&store);
if out_ptr + out_len > data.len() {
return Err(SandboxError::ExecutionFailed(
"result pointer out of bounds".to_string(),
));
}
let result_bytes = &data[out_ptr..out_ptr + out_len];
serde_json::from_slice(result_bytes).map_err(|e| {
SandboxError::ExecutionFailed(format!("invalid result JSON: {}", e))
})
}
}
#[async_trait]
impl Tool for SandboxedTool {
    /// Returns a copy of the schema captured when the tool was constructed.
    fn schema(&self) -> ToolSchema {
        self.schema.clone()
    }

    /// Serialises `arguments` to JSON bytes, runs the module's execute entry
    /// point, and returns the JSON value it produced.
    ///
    /// Serialisation failures map to `ToolError::InvalidArguments`; any
    /// sandbox error maps to `ToolError::ExecutionFailed` with its display text.
    ///
    /// NOTE(review): the wasm call runs synchronously inside this async fn.
    async fn execute(
        &self,
        _context: &ExecutionContext,
        arguments: serde_json::Value,
    ) -> Result<serde_json::Value, ToolError> {
        let payload = match serde_json::to_vec(&arguments) {
            Ok(bytes) => bytes,
            Err(err) => return Err(ToolError::InvalidArguments(err.to_string())),
        };
        match self.call_execute(&payload) {
            Ok(value) => Ok(value),
            Err(err) => Err(ToolError::ExecutionFailed(err.to_string())),
        }
    }
}