use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;
use wasmtime::Store;
use wasmtime::component::ResourceTable;
use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder};
use crate::callback::Callback;
use crate::error::Error;
use crate::wasm::{
CallbackRequest, ExecutionOutput, ExecutorState, HostCallbackInfo, MemoryTracker, NetRequest,
PythonExecutor, Sandbox as SandboxBindings, TraceRequest,
};
pub const MAX_SNAPSHOT_SIZE: usize = 10 * 1024 * 1024;
#[derive(Debug, Clone)]
pub struct PythonStateSnapshot {
data: Vec<u8>,
metadata: SnapshotMetadata,
}
#[derive(Debug, Clone)]
pub struct SnapshotMetadata {
pub timestamp_ms: u64,
pub size_bytes: usize,
}
impl PythonStateSnapshot {
fn new(data: Vec<u8>) -> Self {
let timestamp_ms = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map(|d| d.as_millis() as u64)
.unwrap_or(0);
let size_bytes = data.len();
Self {
data,
metadata: SnapshotMetadata {
timestamp_ms,
size_bytes,
},
}
}
#[must_use]
pub fn data(&self) -> &[u8] {
&self.data
}
#[must_use]
pub fn metadata(&self) -> &SnapshotMetadata {
&self.metadata
}
#[must_use]
pub fn size(&self) -> usize {
self.data.len()
}
#[must_use]
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = Vec::with_capacity(8 + self.data.len());
bytes.extend_from_slice(&self.metadata.timestamp_ms.to_le_bytes());
bytes.extend_from_slice(&self.data);
bytes
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
if bytes.len() < 8 {
return Err(Error::Snapshot("Snapshot data too short".to_string()));
}
let timestamp_ms = u64::from_le_bytes(
bytes[..8]
.try_into()
.map_err(|_| Error::Snapshot("Invalid timestamp".to_string()))?,
);
let data = bytes[8..].to_vec();
let size_bytes = data.len();
Ok(Self {
data,
metadata: SnapshotMetadata {
timestamp_ms,
size_bytes,
},
})
}
}
/// Fluent builder for a single execution on a [`SessionExecutor`].
///
/// Created by [`SessionExecutor::execute`]; configured with the optional
/// `with_*` methods and consumed by [`SessionExecuteBuilder::run`].
pub struct SessionExecuteBuilder<'a> {
    // Exclusive borrow of the session for the duration of the build + run.
    session: &'a mut SessionExecutor,
    // Python source code to execute.
    code: String,
    // Host callbacks exposed to the guest for this execution.
    callbacks: Vec<Arc<dyn Callback>>,
    // Channel the guest uses to invoke host callbacks, if enabled.
    callback_tx: Option<mpsc::Sender<CallbackRequest>>,
    // Channel for trace events, if tracing is enabled.
    trace_tx: Option<mpsc::UnboundedSender<TraceRequest>>,
    // Channel for proxied network requests, if networking is enabled.
    net_tx: Option<mpsc::Sender<NetRequest>>,
    // Channel for streamed stdout/stderr, if output streaming is enabled.
    output_tx: Option<mpsc::UnboundedSender<crate::wasm::OutputRequest>>,
    // Per-execution fuel cap; overrides the session-wide limit when set.
    fuel_limit: Option<u64>,
}
// Manual Debug: the builder holds non-Debug trait objects and channels, so
// only summary information (lengths and channel presence) is reported.
impl std::fmt::Debug for SessionExecuteBuilder<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionExecuteBuilder")
            .field("code_len", &self.code.len())
            .field("callbacks_count", &self.callbacks.len())
            .field("has_callback_tx", &self.callback_tx.is_some())
            .field("has_trace_tx", &self.trace_tx.is_some())
            .field("has_net_tx", &self.net_tx.is_some())
            .field("has_output_tx", &self.output_tx.is_some())
            .field("fuel_limit", &self.fuel_limit)
            .finish_non_exhaustive()
    }
}
impl<'a> SessionExecuteBuilder<'a> {
    /// Starts a builder for running `code` on `session`, with no optional
    /// channels attached and no per-execution fuel cap.
    fn new(session: &'a mut SessionExecutor, code: impl Into<String>) -> Self {
        Self {
            session,
            code: code.into(),
            callbacks: Vec::new(),
            callback_tx: None,
            trace_tx: None,
            net_tx: None,
            output_tx: None,
            fuel_limit: None,
        }
    }

    /// Registers host callbacks and the channel used to dispatch them.
    #[must_use]
    pub fn with_callbacks(
        self,
        callbacks: &[Arc<dyn Callback>],
        callback_tx: mpsc::Sender<CallbackRequest>,
    ) -> Self {
        Self {
            callbacks: callbacks.to_vec(),
            callback_tx: Some(callback_tx),
            ..self
        }
    }

    /// Enables trace-event streaming over `trace_tx`.
    #[must_use]
    pub fn with_tracing(self, trace_tx: mpsc::UnboundedSender<TraceRequest>) -> Self {
        Self {
            trace_tx: Some(trace_tx),
            ..self
        }
    }

    /// Enables network request proxying over `net_tx`.
    #[must_use]
    pub fn with_network(self, net_tx: mpsc::Sender<NetRequest>) -> Self {
        Self {
            net_tx: Some(net_tx),
            ..self
        }
    }

    /// Enables incremental stdout/stderr streaming over `output_tx`.
    #[must_use]
    pub fn with_output_streaming(
        self,
        output_tx: mpsc::UnboundedSender<crate::wasm::OutputRequest>,
    ) -> Self {
        Self {
            output_tx: Some(output_tx),
            ..self
        }
    }

    /// Caps fuel for this execution only, overriding the session default.
    #[must_use]
    pub fn with_fuel_limit(self, fuel: u64) -> Self {
        Self {
            fuel_limit: Some(fuel),
            ..self
        }
    }

    /// Consumes the builder and executes the configured code on the session.
    #[tracing::instrument(
        name = "SessionExecuteBuilder::run",
        skip(self),
        fields(
            code_len = self.code.len(),
            callbacks = self.callbacks.len(),
            fuel_limit = ?self.fuel_limit,
        )
    )]
    pub async fn run(self) -> Result<ExecutionOutput, Error> {
        let Self {
            session,
            code,
            callbacks,
            callback_tx,
            trace_tx,
            net_tx,
            output_tx,
            fuel_limit,
        } = self;
        session
            .execute_internal(
                &code,
                &callbacks,
                callback_tx,
                trace_tx,
                net_tx,
                output_tx,
                fuel_limit,
            )
            .await
    }
}
#[cfg(feature = "vfs")]
/// A host directory or file exposed inside the sandbox at a guest path.
#[derive(Debug, Clone)]
pub struct VolumeMount {
    /// Path on the host filesystem to expose.
    pub host_path: std::path::PathBuf,
    /// Path at which the mount appears inside the guest.
    pub guest_path: String,
    /// When true, the mount is exposed with read-only permissions.
    pub read_only: bool,
}
#[cfg(feature = "vfs")]
impl VolumeMount {
    /// Creates a read-write mount of `host_path` at `guest_path`.
    #[must_use]
    pub fn new(host_path: impl Into<std::path::PathBuf>, guest_path: impl Into<String>) -> Self {
        Self {
            host_path: host_path.into(),
            guest_path: guest_path.into(),
            read_only: false,
        }
    }

    /// Creates a read-only mount of `host_path` at `guest_path`.
    #[must_use]
    pub fn read_only(
        host_path: impl Into<std::path::PathBuf>,
        guest_path: impl Into<String>,
    ) -> Self {
        Self {
            read_only: true,
            ..Self::new(host_path, guest_path)
        }
    }
}
#[cfg(feature = "vfs")]
/// Configuration for the virtual filesystem exposed to the sandbox.
#[derive(Debug, Clone)]
pub struct VfsConfig {
    /// Guest path where the VFS storage is mounted (default `/data`).
    pub mount_path: String,
    /// Directory permissions applied to the VFS mount.
    pub dir_perms: eryx_vfs::DirPerms,
    /// File permissions applied to the VFS mount.
    pub file_perms: eryx_vfs::FilePerms,
    /// Additional host volumes to expose inside the guest.
    pub volumes: Vec<VolumeMount>,
}
#[cfg(feature = "vfs")]
// Default: mount at "/data" with full permissions and no extra volumes.
impl Default for VfsConfig {
    fn default() -> Self {
        Self {
            mount_path: "/data".to_string(),
            dir_perms: eryx_vfs::DirPerms::all(),
            file_perms: eryx_vfs::FilePerms::all(),
            volumes: Vec::new(),
        }
    }
}
#[cfg(feature = "vfs")]
impl VfsConfig {
    /// Creates a config mounted at `mount_path` with default (full)
    /// permissions and no host volumes.
    #[must_use]
    pub fn new(mount_path: impl Into<String>) -> Self {
        Self {
            mount_path: mount_path.into(),
            ..Default::default()
        }
    }

    /// Overrides the directory permissions of the VFS mount.
    #[must_use]
    pub fn with_dir_perms(self, perms: eryx_vfs::DirPerms) -> Self {
        Self {
            dir_perms: perms,
            ..self
        }
    }

    /// Overrides the file permissions of the VFS mount.
    #[must_use]
    pub fn with_file_perms(self, perms: eryx_vfs::FilePerms) -> Self {
        Self {
            file_perms: perms,
            ..self
        }
    }

    /// Appends a single host volume mount.
    #[must_use]
    pub fn with_volume(self, volume: VolumeMount) -> Self {
        let mut volumes = self.volumes;
        volumes.push(volume);
        Self { volumes, ..self }
    }

    /// Appends every volume yielded by `volumes`.
    #[must_use]
    pub fn with_volumes(self, volumes: impl IntoIterator<Item = VolumeMount>) -> Self {
        let mut merged = self.volumes;
        merged.extend(volumes);
        Self {
            volumes: merged,
            ..self
        }
    }
}
/// A long-lived sandboxed Python session: one wasmtime store and component
/// instance reused across executions so interpreter state persists.
pub struct SessionExecutor {
    // Shared engine/component wrapper used to (re)instantiate the sandbox.
    executor: Arc<PythonExecutor>,
    // Taken (set to None) while an execution is in flight; restored after.
    store: Option<Store<ExecutorState>>,
    // Component bindings; taken/restored in lockstep with `store`.
    bindings: Option<SandboxBindings>,
    // Number of executions performed since creation or last reset.
    execution_count: u32,
    // Optional wall-clock timeout applied to each execution.
    execution_timeout: Option<Duration>,
    // Optional session-wide fuel cap (overridable per execution).
    fuel_limit: Option<u64>,
    // Backing storage for the virtual filesystem, when the feature is on.
    #[cfg(feature = "vfs")]
    vfs_storage: Option<eryx_vfs::ArcStorage>,
    // VFS mount configuration, when the feature is on.
    #[cfg(feature = "vfs")]
    vfs_config: Option<VfsConfig>,
}
// Manual Debug: the store, bindings, and executor are not Debug-friendly,
// so only presence flags and scalar settings are reported.
impl std::fmt::Debug for SessionExecutor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SessionExecutor")
            .field("execution_count", &self.execution_count)
            .field("has_store", &self.store.is_some())
            .field("has_bindings", &self.bindings.is_some())
            .field("execution_timeout", &self.execution_timeout)
            .field("fuel_limit", &self.fuel_limit)
            .finish_non_exhaustive()
    }
}
/// Converts host callbacks into the serializable descriptors handed to the
/// guest (name, description, and JSON-encoded parameter schema).
fn build_callback_infos(callbacks: &[Arc<dyn Callback>]) -> Vec<HostCallbackInfo> {
    let mut infos = Vec::with_capacity(callbacks.len());
    for cb in callbacks {
        // A schema that fails to serialize degrades to an empty JSON object
        // rather than aborting callback registration.
        let schema = serde_json::to_string(&cb.parameters_schema())
            .unwrap_or_else(|_| "{}".to_string());
        infos.push(HostCallbackInfo {
            name: cb.name().to_string(),
            description: cb.description().to_string(),
            parameters_schema_json: schema,
        });
    }
    infos
}
/// Builds the WASI context for a session: inherits host stdout/stderr and
/// mounts the Python stdlib and each site-packages directory read-only.
///
/// # Errors
///
/// Returns [`Error::WasmEngine`] if any preopened directory cannot be
/// mounted.
fn build_wasi_context(executor: &PythonExecutor) -> Result<WasiCtx, Error> {
    let mut wasi_builder = WasiCtxBuilder::new();
    wasi_builder.inherit_stdout().inherit_stderr();
    let site_packages_paths = executor.python_site_packages_paths();
    // Collect the guest paths of every mount we are about to create so they
    // can be joined into PYTHONPATH below.
    let mut pythonpath_parts = Vec::new();
    if executor.python_stdlib_path().is_some() {
        pythonpath_parts.push("/python-stdlib".to_string());
    }
    for i in 0..site_packages_paths.len() {
        pythonpath_parts.push(format!("/site-packages-{i}"));
    }
    // NOTE(review): PYTHONHOME/PYTHONPATH are only exported when a stdlib
    // path exists; site-packages mounted without a stdlib are visible on the
    // filesystem but absent from PYTHONPATH — confirm this is intentional.
    if let Some(stdlib_path) = executor.python_stdlib_path() {
        wasi_builder.env("PYTHONHOME", "/python-stdlib");
        if !pythonpath_parts.is_empty() {
            wasi_builder.env("PYTHONPATH", pythonpath_parts.join(":"));
        }
        wasi_builder
            .preopened_dir(
                stdlib_path,
                "/python-stdlib",
                DirPerms::READ,
                FilePerms::READ,
            )
            .map_err(|e| Error::WasmEngine(format!("Failed to mount Python stdlib: {e}")))?;
    }
    // Each site-packages directory gets its own numbered read-only mount.
    for (i, site_packages_path) in site_packages_paths.iter().enumerate() {
        let mount_path = format!("/site-packages-{i}");
        wasi_builder
            .preopened_dir(
                site_packages_path,
                &mount_path,
                DirPerms::READ,
                FilePerms::READ,
            )
            .map_err(|e| Error::WasmEngine(format!("Failed to mount {mount_path}: {e}")))?;
    }
    Ok(wasi_builder.build())
}
#[cfg(feature = "vfs")]
/// Builds a hybrid VFS context: virtual storage mounted at the configured
/// path, real read-only mounts for the Python stdlib and site-packages, and
/// any user-configured host volumes.
///
/// # Errors
///
/// Returns [`Error::WasmEngine`] if a configured volume cannot be mounted.
/// Stdlib/site-packages directories that fail to open are silently skipped
/// (the session still works without them on the real filesystem).
fn build_hybrid_vfs_context(
    executor: &PythonExecutor,
    vfs_storage: eryx_vfs::ArcStorage,
    vfs_config: &VfsConfig,
) -> std::result::Result<eryx_vfs::HybridVfsCtx<eryx_vfs::ArcStorage>, Error> {
    let mut ctx = eryx_vfs::HybridVfsCtx::new(vfs_storage);
    // The virtual (in-storage) mount, e.g. "/data".
    ctx.add_vfs_preopen(
        &vfs_config.mount_path,
        vfs_config.dir_perms,
        vfs_config.file_perms,
    );
    // Real read-only mounts mirroring what build_wasi_context exposes.
    if let Some(stdlib_path) = executor.python_stdlib_path()
        && let Ok(real_dir) =
            eryx_vfs::RealDir::open_ambient(stdlib_path, DirPerms::READ, FilePerms::READ)
    {
        ctx.add_real_preopen("/python-stdlib", real_dir);
    }
    for (i, site_packages_path) in executor.python_site_packages_paths().iter().enumerate() {
        let mount_path = format!("/site-packages-{i}");
        if let Ok(real_dir) =
            eryx_vfs::RealDir::open_ambient(site_packages_path, DirPerms::READ, FilePerms::READ)
        {
            ctx.add_real_preopen(&mount_path, real_dir);
        }
    }
    // User volumes: read-only mounts get READ perms, otherwise full perms.
    // Unlike the stdlib mounts above, a volume that fails to mount is a
    // hard error — the caller asked for it explicitly.
    for volume in &vfs_config.volumes {
        let (dir_perms, file_perms) = if volume.read_only {
            (DirPerms::READ, FilePerms::READ)
        } else {
            (DirPerms::all(), FilePerms::all())
        };
        // A single host file and a directory need different preopen calls.
        let result = if volume.host_path.is_file() {
            ctx.add_real_file_preopen_path(
                &volume.guest_path,
                &volume.host_path,
                dir_perms,
                file_perms,
            )
        } else {
            ctx.add_real_preopen_path(&volume.guest_path, &volume.host_path, dir_perms, file_perms)
        };
        result.map_err(|e| {
            Error::WasmEngine(format!(
                "Failed to mount volume {} -> {}: {e}",
                volume.host_path.display(),
                volume.guest_path,
            ))
        })?;
    }
    Ok(ctx)
}
impl SessionExecutor {
    /// Creates a session executor with a fresh store and an instantiated
    /// sandbox component.
    ///
    /// # Errors
    ///
    /// Returns an error if the WASI context cannot be built or the component
    /// fails to instantiate.
    #[tracing::instrument(
        name = "SessionExecutor::new",
        skip(executor, callbacks),
        fields(callbacks = callbacks.len())
    )]
    pub async fn new(
        executor: Arc<PythonExecutor>,
        callbacks: &[Arc<dyn Callback>],
    ) -> Result<Self, Error> {
        #[cfg(feature = "vfs")]
        {
            Self::new_internal(executor, callbacks, None, None).await
        }
        #[cfg(not(feature = "vfs"))]
        {
            Self::new_internal(executor, callbacks).await
        }
    }

    /// Creates a session executor backed by the given VFS storage with the
    /// default [`VfsConfig`].
    ///
    /// # Errors
    ///
    /// Returns an error if the WASI/VFS context cannot be built or the
    /// component fails to instantiate.
    #[cfg(feature = "vfs")]
    pub async fn new_with_vfs(
        executor: Arc<PythonExecutor>,
        callbacks: &[Arc<dyn Callback>],
        vfs_storage: eryx_vfs::ArcStorage,
    ) -> Result<Self, Error> {
        Self::new_internal(executor, callbacks, Some(vfs_storage), None).await
    }

    /// Creates a session executor backed by the given VFS storage and an
    /// explicit [`VfsConfig`].
    ///
    /// # Errors
    ///
    /// Returns an error if the WASI/VFS context cannot be built or the
    /// component fails to instantiate.
    #[cfg(feature = "vfs")]
    pub async fn new_with_vfs_config(
        executor: Arc<PythonExecutor>,
        callbacks: &[Arc<dyn Callback>],
        vfs_storage: eryx_vfs::ArcStorage,
        vfs_config: VfsConfig,
    ) -> Result<Self, Error> {
        Self::new_internal(executor, callbacks, Some(vfs_storage), Some(vfs_config)).await
    }

    // Shared constructor (vfs build): sets up state, store limits, and
    // instantiates the component. Missing storage/config fall back to an
    // in-memory store and the default config.
    #[cfg(feature = "vfs")]
    async fn new_internal(
        executor: Arc<PythonExecutor>,
        callbacks: &[Arc<dyn Callback>],
        vfs_storage: Option<eryx_vfs::ArcStorage>,
        vfs_config: Option<VfsConfig>,
    ) -> Result<Self, Error> {
        let callback_infos = build_callback_infos(callbacks);
        let wasi = build_wasi_context(&executor)?;
        let vfs_storage = vfs_storage.unwrap_or_else(|| {
            eryx_vfs::ArcStorage::new(std::sync::Arc::new(eryx_vfs::InMemoryStorage::new()))
        });
        let vfs_config = vfs_config.unwrap_or_default();
        let hybrid_vfs_ctx = Some(build_hybrid_vfs_context(
            &executor,
            vfs_storage.clone(),
            &vfs_config,
        )?);
        let state = ExecutorState::new(
            wasi,
            ResourceTable::new(),
            None,
            None,
            callback_infos,
            MemoryTracker::new(None),
            hybrid_vfs_ctx,
        );
        let mut store = Store::new(executor.engine(), state);
        store.limiter(|state| &mut state.memory_tracker);
        // Effectively-unbounded defaults; real limits are applied per execute.
        store.set_epoch_deadline(u64::MAX / 2);
        store
            .set_fuel(u64::MAX)
            .map_err(|e| Error::WasmEngine(format!("Failed to set fuel: {e}")))?;
        let bindings = executor
            .instance_pre()
            .instantiate_async(&mut store)
            .await
            .map_err(|e| Error::WasmEngine(format!("Failed to instantiate component: {e}")))?;
        Ok(Self {
            executor,
            store: Some(store),
            bindings: Some(bindings),
            execution_count: 0,
            execution_timeout: None,
            fuel_limit: None,
            vfs_storage: Some(vfs_storage),
            vfs_config: Some(vfs_config),
        })
    }

    // Shared constructor (non-vfs build): same setup minus the VFS context.
    #[cfg(not(feature = "vfs"))]
    async fn new_internal(
        executor: Arc<PythonExecutor>,
        callbacks: &[Arc<dyn Callback>],
    ) -> Result<Self, Error> {
        let callback_infos = build_callback_infos(callbacks);
        let wasi = build_wasi_context(&executor)?;
        let state = ExecutorState::new(
            wasi,
            ResourceTable::new(),
            None,
            None,
            callback_infos,
            MemoryTracker::new(None),
        );
        let mut store = Store::new(executor.engine(), state);
        store.limiter(|state| &mut state.memory_tracker);
        // Effectively-unbounded defaults; real limits are applied per execute.
        store.set_epoch_deadline(u64::MAX / 2);
        store
            .set_fuel(u64::MAX)
            .map_err(|e| Error::WasmEngine(format!("Failed to set fuel: {e}")))?;
        let bindings = executor
            .instance_pre()
            .instantiate_async(&mut store)
            .await
            .map_err(|e| Error::WasmEngine(format!("Failed to instantiate component: {e}")))?;
        Ok(Self {
            executor,
            store: Some(store),
            bindings: Some(bindings),
            execution_count: 0,
            execution_timeout: None,
            fuel_limit: None,
        })
    }

    /// Sets (or clears) the wall-clock timeout applied to each execution.
    pub fn set_execution_timeout(&mut self, timeout: Option<Duration>) {
        self.execution_timeout = timeout;
    }

    /// Returns the configured execution timeout, if any.
    #[must_use]
    pub fn execution_timeout(&self) -> Option<Duration> {
        self.execution_timeout
    }

    /// Sets (or clears) the session-wide fuel limit.
    pub fn set_fuel_limit(&mut self, limit: Option<u64>) {
        self.fuel_limit = limit;
    }

    /// Returns the session-wide fuel limit, if any.
    #[must_use]
    pub fn fuel_limit(&self) -> Option<u64> {
        self.fuel_limit
    }

    /// Starts building an execution of `code`; finish with
    /// [`SessionExecuteBuilder::run`].
    #[must_use]
    pub fn execute(&mut self, code: impl Into<String>) -> SessionExecuteBuilder<'_> {
        SessionExecuteBuilder::new(self, code)
    }

    // Core execution path shared by all builder configurations.
    #[allow(clippy::too_many_arguments)]
    async fn execute_internal(
        &mut self,
        code: &str,
        callbacks: &[Arc<dyn Callback>],
        callback_tx: Option<mpsc::Sender<CallbackRequest>>,
        trace_tx: Option<mpsc::UnboundedSender<TraceRequest>>,
        net_tx: Option<mpsc::Sender<NetRequest>>,
        output_tx: Option<mpsc::UnboundedSender<crate::wasm::OutputRequest>>,
        per_execute_fuel_limit: Option<u64>,
    ) -> Result<ExecutionOutput, Error> {
        let start = Instant::now();
        // Store and bindings are taken out for the duration of the call so a
        // concurrent execution attempt fails fast instead of aliasing them.
        let mut store = self.store.take().ok_or_else(|| {
            Error::Execution("Store not available (concurrent execution?)".to_string())
        })?;
        let bindings = self
            .bindings
            .take()
            .ok_or_else(|| Error::Execution("Bindings not available".to_string()))?;
        // Use the shared helper instead of duplicating the mapping inline.
        let callback_infos = build_callback_infos(callbacks);
        {
            let state = store.data_mut();
            state.set_callback_tx(callback_tx);
            state.set_trace_tx(trace_tx);
            state.set_net_tx(net_tx);
            state.set_output_tx(output_tx);
            state.set_callbacks(callback_infos);
            state.reset_memory_tracker();
        }
        self.execution_count += 1;
        // The per-execute limit wins over the session-wide default.
        let fuel_limit = per_execute_fuel_limit.or(self.fuel_limit);
        let initial_fuel = fuel_limit.unwrap_or(u64::MAX);
        store
            .set_fuel(initial_fuel)
            .map_err(|e| Error::Initialization(format!("Failed to set fuel: {e}")))?;
        tracing::debug!(
            code_len = code.len(),
            execution_count = self.execution_count,
            fuel_limit = ?fuel_limit,
            "SessionExecutor: executing Python code"
        );
        let execution_timeout = self.execution_timeout;
        // With a timeout, a background thread bumps the engine epoch every
        // 10ms so a hung guest traps at its deadline; the returned flag is
        // used to stop the thread after the call completes.
        let epoch_ticker = if let Some(timeout) = execution_timeout {
            const EPOCH_TICK_MS: u64 = 10;
            let ticks_until_timeout = timeout.as_millis() as u64 / EPOCH_TICK_MS;
            let ticks = ticks_until_timeout.max(1);
            store.set_epoch_deadline(ticks);
            store.epoch_deadline_trap();
            let engine = self.executor.engine().clone();
            let stop_flag = Arc::new(AtomicBool::new(false));
            let stop_flag_clone = Arc::clone(&stop_flag);
            std::thread::spawn(move || {
                while !stop_flag_clone.load(Ordering::Relaxed) {
                    std::thread::sleep(Duration::from_millis(EPOCH_TICK_MS));
                    engine.increment_epoch();
                }
            });
            Some(stop_flag)
        } else {
            store.set_epoch_deadline(u64::MAX / 2);
            store.epoch_deadline_trap();
            None::<Arc<AtomicBool>>
        };
        let code_owned = code.to_string();
        let mut async_timeout_elapsed = false;
        // An outer tokio timeout backs up the epoch mechanism in case the
        // execution is blocked in host async code rather than WASM.
        let result = if let Some(timeout) = execution_timeout {
            match tokio::time::timeout(
                timeout,
                store.run_concurrent(async |accessor| {
                    bindings.call_execute(accessor, code_owned).await
                }),
            )
            .await
            {
                Ok(result) => result,
                Err(_elapsed) => {
                    async_timeout_elapsed = true;
                    Err(wasmtime::Error::msg("async timeout elapsed"))
                }
            }
        } else {
            store
                .run_concurrent(async |accessor| bindings.call_execute(accessor, code_owned).await)
                .await
        };
        if let Some(stop_flag) = epoch_ticker {
            stop_flag.store(true, Ordering::Relaxed);
        }
        // Detach the per-execute channels and read peak memory before the
        // store is returned, regardless of success or failure.
        let peak_memory = {
            let state = store.data_mut();
            state.set_callback_tx(None);
            state.set_trace_tx(None);
            state.set_net_tx(None);
            state.set_output_tx(None);
            state.peak_memory_bytes()
        };
        let remaining_fuel = store.get_fuel().unwrap_or(0);
        let fuel_consumed = Some(initial_fuel.saturating_sub(remaining_fuel));
        self.store = Some(store);
        self.bindings = Some(bindings);
        // Map wasmtime errors to domain errors: interrupt trap or elapsed
        // outer timeout -> Timeout; out-of-fuel trap -> FuelExhausted.
        let wasmtime_result = result.map_err(|e| {
            if async_timeout_elapsed
                || e.downcast_ref::<wasmtime::Trap>() == Some(&wasmtime::Trap::Interrupt)
            {
                Error::Timeout(execution_timeout.unwrap_or_default())
            } else if e.downcast_ref::<wasmtime::Trap>() == Some(&wasmtime::Trap::OutOfFuel) {
                let consumed = initial_fuel.saturating_sub(remaining_fuel);
                let limit = fuel_limit.unwrap_or(u64::MAX);
                Error::FuelExhausted { consumed, limit }
            } else {
                Error::Execution(format!("WASM execution error: {e:?}"))
            }
        })?;
        let wit_output = wasmtime_result
            .map_err(|e| Error::Execution(format!("WASM execution error: {e:?}")))?
            .map_err(Error::Execution)?;
        let duration = start.elapsed();
        Ok(ExecutionOutput::new(
            wit_output.stdout,
            wit_output.stderr,
            peak_memory,
            duration,
            // NOTE(review): the meaning of this `0` positional argument is
            // not visible here — confirm against ExecutionOutput::new.
            0,
            fuel_consumed,
        ))
    }

    /// Returns the number of executions since creation or the last reset.
    #[must_use]
    pub fn execution_count(&self) -> u32 {
        self.execution_count
    }

    /// Discards all interpreter state and reinstantiates the component with
    /// the given callbacks; `execution_timeout` and `fuel_limit` settings
    /// are untouched, and the execution counter returns to zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the WASI context cannot be rebuilt or the
    /// component fails to reinstantiate.
    #[tracing::instrument(
        name = "SessionExecutor::reset",
        skip(self, callbacks),
        fields(
            callbacks = callbacks.len(),
            execution_count = self.execution_count,
        )
    )]
    pub async fn reset(&mut self, callbacks: &[Arc<dyn Callback>]) -> Result<(), Error> {
        let callback_infos = build_callback_infos(callbacks);
        let wasi = build_wasi_context(&self.executor)?;
        #[cfg(not(feature = "vfs"))]
        let state = ExecutorState::new(
            wasi,
            ResourceTable::new(),
            None,
            None,
            callback_infos,
            MemoryTracker::new(None),
        );
        #[cfg(feature = "vfs")]
        let state = {
            // Reuse the session's VFS storage/config so files survive a reset.
            let vfs_storage = self.vfs_storage.clone().unwrap_or_else(|| {
                eryx_vfs::ArcStorage::new(std::sync::Arc::new(eryx_vfs::InMemoryStorage::new()))
            });
            let vfs_config = self.vfs_config.clone().unwrap_or_default();
            ExecutorState::new(
                wasi,
                ResourceTable::new(),
                None,
                None,
                callback_infos,
                MemoryTracker::new(None),
                Some(build_hybrid_vfs_context(
                    &self.executor,
                    vfs_storage,
                    &vfs_config,
                )?),
            )
        };
        let mut store = Store::new(self.executor.engine(), state);
        store.limiter(|state| &mut state.memory_tracker);
        store.set_epoch_deadline(u64::MAX / 2);
        store
            .set_fuel(u64::MAX)
            .map_err(|e| Error::WasmEngine(format!("Failed to set fuel: {e}")))?;
        let bindings = self
            .executor
            .instance_pre()
            .instantiate_async(&mut store)
            .await
            .map_err(|e| Error::WasmEngine(format!("Failed to reinstantiate component: {e}")))?;
        // `execution_timeout` and `fuel_limit` are never written above, so
        // the previous save/restore of those fields was redundant.
        self.store = Some(store);
        self.bindings = Some(bindings);
        self.execution_count = 0;
        Ok(())
    }

    /// Returns the underlying store, if no execution is in flight.
    #[must_use]
    pub fn store(&self) -> Option<&Store<ExecutorState>> {
        self.store.as_ref()
    }

    /// Returns the underlying store mutably, if no execution is in flight.
    #[must_use]
    pub fn store_mut(&mut self) -> Option<&mut Store<ExecutorState>> {
        self.store.as_mut()
    }

    /// Captures the guest interpreter's state as a snapshot.
    ///
    /// # Errors
    ///
    /// Returns an error if the store/bindings are unavailable, the guest
    /// snapshot call fails, or the payload exceeds [`MAX_SNAPSHOT_SIZE`].
    pub async fn snapshot_state(&mut self) -> Result<PythonStateSnapshot, Error> {
        let mut store = self
            .store
            .take()
            .ok_or_else(|| Error::WasmEngine("Store not available".to_string()))?;
        let bindings = self
            .bindings
            .take()
            .ok_or_else(|| Error::WasmEngine("Bindings not available".to_string()))?;
        tracing::debug!("SessionExecutor: capturing state snapshot");
        let result = store
            .run_concurrent(async |accessor| bindings.call_snapshot_state(accessor).await)
            .await;
        // Always put the store/bindings back before inspecting the result.
        self.store = Some(store);
        self.bindings = Some(bindings);
        let wasmtime_result =
            result.map_err(|e| Error::WasmEngine(format!("WASM snapshot error: {e}")))?;
        let inner_result =
            wasmtime_result.map_err(|e| Error::WasmEngine(format!("WASM snapshot error: {e}")))?;
        let data = inner_result.map_err(Error::Snapshot)?;
        if data.len() > MAX_SNAPSHOT_SIZE {
            return Err(Error::Snapshot(format!(
                "Snapshot too large: {} bytes (max {} bytes)",
                data.len(),
                MAX_SNAPSHOT_SIZE
            )));
        }
        tracing::debug!(size_bytes = data.len(), "State snapshot captured");
        Ok(PythonStateSnapshot::new(data))
    }

    /// Restores the guest interpreter's state from a snapshot.
    ///
    /// # Errors
    ///
    /// Returns an error if the store/bindings are unavailable or the guest
    /// rejects the snapshot.
    pub async fn restore_state(&mut self, snapshot: &PythonStateSnapshot) -> Result<(), Error> {
        let mut store = self
            .store
            .take()
            .ok_or_else(|| Error::WasmEngine("Store not available".to_string()))?;
        let bindings = self
            .bindings
            .take()
            .ok_or_else(|| Error::WasmEngine("Bindings not available".to_string()))?;
        tracing::debug!(
            size_bytes = snapshot.size(),
            "SessionExecutor: restoring state snapshot"
        );
        let data = snapshot.data().to_vec();
        let result = store
            .run_concurrent(async |accessor| bindings.call_restore_state(accessor, data).await)
            .await;
        // Always put the store/bindings back before inspecting the result.
        self.store = Some(store);
        self.bindings = Some(bindings);
        let wasmtime_result =
            result.map_err(|e| Error::WasmEngine(format!("WASM restore error: {e}")))?;
        let inner_result =
            wasmtime_result.map_err(|e| Error::WasmEngine(format!("WASM restore error: {e}")))?;
        inner_result.map_err(Error::Snapshot)?;
        tracing::debug!("State snapshot restored");
        Ok(())
    }

    /// Clears the guest interpreter's state without reinstantiating.
    ///
    /// # Errors
    ///
    /// Returns an error if the store/bindings are unavailable or the guest
    /// clear call fails.
    pub async fn clear_state(&mut self) -> Result<(), Error> {
        let mut store = self
            .store
            .take()
            .ok_or_else(|| Error::WasmEngine("Store not available".to_string()))?;
        let bindings = self
            .bindings
            .take()
            .ok_or_else(|| Error::WasmEngine("Bindings not available".to_string()))?;
        tracing::debug!("SessionExecutor: clearing state");
        let result = store
            .run_concurrent(async |accessor| bindings.call_clear_state(accessor).await)
            .await;
        // Always put the store/bindings back before inspecting the result.
        self.store = Some(store);
        self.bindings = Some(bindings);
        let wasmtime_result =
            result.map_err(|e| Error::WasmEngine(format!("WASM clear state error: {e}")))?;
        wasmtime_result.map_err(|e| Error::WasmEngine(format!("WASM clear state error: {e}")))?;
        tracing::debug!("State cleared");
        Ok(())
    }
}
impl ExecutorState {
    // Constructor for the non-vfs build: channels start detached and are
    // attached per execution by SessionExecutor::execute_internal.
    #[cfg(not(feature = "vfs"))]
    pub(crate) fn new(
        wasi: WasiCtx,
        table: ResourceTable,
        callback_tx: Option<mpsc::Sender<CallbackRequest>>,
        trace_tx: Option<mpsc::UnboundedSender<TraceRequest>>,
        callbacks: Vec<HostCallbackInfo>,
        memory_tracker: MemoryTracker,
    ) -> Self {
        Self {
            wasi,
            table,
            callback_tx,
            trace_tx,
            callbacks,
            memory_tracker,
            // Network and output channels are only set per execution.
            net_tx: None, output_tx: None, }
    }
    // Constructor for the vfs build: identical, plus the hybrid VFS context.
    #[cfg(feature = "vfs")]
    pub(crate) fn new(
        wasi: WasiCtx,
        table: ResourceTable,
        callback_tx: Option<mpsc::Sender<CallbackRequest>>,
        trace_tx: Option<mpsc::UnboundedSender<TraceRequest>>,
        callbacks: Vec<HostCallbackInfo>,
        memory_tracker: MemoryTracker,
        hybrid_vfs_ctx: Option<eryx_vfs::HybridVfsCtx<eryx_vfs::ArcStorage>>,
    ) -> Self {
        Self {
            wasi,
            table,
            callback_tx,
            trace_tx,
            callbacks,
            memory_tracker,
            // Network and output channels are only set per execution.
            net_tx: None, output_tx: None, hybrid_vfs_ctx,
        }
    }
    // Attaches/detaches the callback dispatch channel.
    pub(crate) fn set_callback_tx(&mut self, tx: Option<mpsc::Sender<CallbackRequest>>) {
        self.callback_tx = tx;
    }
    // Attaches/detaches the trace-event channel.
    pub(crate) fn set_trace_tx(&mut self, tx: Option<mpsc::UnboundedSender<TraceRequest>>) {
        self.trace_tx = tx;
    }
    // Attaches/detaches the network-proxy channel.
    pub(crate) fn set_net_tx(&mut self, tx: Option<mpsc::Sender<NetRequest>>) {
        self.net_tx = tx;
    }
    // Attaches/detaches the streamed-output channel.
    pub(crate) fn set_output_tx(
        &mut self,
        tx: Option<mpsc::UnboundedSender<crate::wasm::OutputRequest>>,
    ) {
        self.output_tx = tx;
    }
    // Replaces the callback descriptors advertised to the guest.
    pub(crate) fn set_callbacks(&mut self, callbacks: Vec<HostCallbackInfo>) {
        self.callbacks = callbacks;
    }
    // Reports the peak memory observed by the store limiter.
    pub(crate) fn peak_memory_bytes(&self) -> u64 {
        self.memory_tracker.peak_memory_bytes()
    }
    // Resets the peak-memory watermark before a new execution.
    pub(crate) fn reset_memory_tracker(&self) {
        self.memory_tracker.reset();
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Smoke test: Debug-formatting the placeholder string does not panic.
    #[test]
    fn test_session_executor_debug() {
        let _fmt = format!("{:?}", "SessionExecutor placeholder");
    }

    /// A snapshot survives a `to_bytes`/`from_bytes` round trip.
    #[test]
    fn test_python_state_snapshot_roundtrip() {
        let data = vec![1, 2, 3, 4, 5];
        let snapshot = PythonStateSnapshot::new(data.clone());
        assert_eq!(snapshot.data(), data.as_slice());
        assert_eq!(snapshot.size(), data.len());
        assert!(snapshot.metadata().timestamp_ms > 0);

        let encoded = snapshot.to_bytes();
        let restored = PythonStateSnapshot::from_bytes(&encoded).expect("from_bytes failed");
        assert_eq!(restored.data(), data.as_slice());
        assert_eq!(
            restored.metadata().timestamp_ms,
            snapshot.metadata().timestamp_ms
        );
    }

    /// Fewer than 8 header bytes must be rejected.
    #[test]
    fn test_python_state_snapshot_from_bytes_too_short() {
        assert!(PythonStateSnapshot::from_bytes(&[1, 2, 3]).is_err());
    }

    /// Metadata reflects payload size and a nonzero capture time.
    #[test]
    fn test_snapshot_metadata() {
        let snapshot = PythonStateSnapshot::new(vec![0; 100]);
        let meta = snapshot.metadata();
        assert_eq!(meta.size_bytes, 100);
        assert!(meta.timestamp_ms > 0);
    }
}