use std::sync::Arc;
use std::time::Instant;
use async_trait::async_trait;
use tokio::sync::mpsc;
use crate::callback::Callback;
use crate::callback_handler::{run_callback_handler, run_output_collector, run_trace_collector};
use crate::error::Error;
use crate::sandbox::{ExecuteResult, ExecuteStats, Sandbox};
use crate::wasm::{CallbackRequest, OutputRequest, TraceRequest};
use super::Session;
use super::executor::{PythonStateSnapshot, SessionExecutor};
/// A session that executes code through its own `SessionExecutor` while
/// borrowing configuration from a `Sandbox` for the lifetime `'a`.
pub struct InProcessSession<'a> {
/// Borrowed sandbox; this file reads its callbacks, preamble, secrets,
/// resource limits and trace/output handlers from it.
sandbox: &'a Sandbox,
/// Executor that runs the code and reports the execution count.
executor: SessionExecutor,
/// `true` once the sandbox preamble has been prepended to an execution;
/// set back to `false` by `reset`.
preamble_executed: bool,
}
impl std::fmt::Debug for InProcessSession<'_> {
    /// Writes a non-exhaustive debug view containing the executor's execution
    /// count and the preamble flag; the remaining fields are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let execution_count = self.executor.execution_count();

        let mut builder = formatter.debug_struct("InProcessSession");
        builder.field("execution_count", &execution_count);
        builder.field("preamble_executed", &self.preamble_executed);
        builder.finish_non_exhaustive()
    }
}
impl<'a> InProcessSession<'a> {
#[tracing::instrument(
name = "InProcessSession::new",
skip(sandbox),
fields(
callbacks = sandbox.callbacks().len(),
has_preamble = !sandbox.preamble().is_empty(),
)
)]
pub async fn new(sandbox: &'a Sandbox) -> Result<Self, Error> {
let callbacks: Vec<Arc<dyn Callback>> = sandbox.callbacks().values().cloned().collect();
let mut executor = SessionExecutor::new(sandbox.executor().clone(), &callbacks).await?;
executor.set_execution_timeout(sandbox.resource_limits().execution_timeout);
Ok(Self {
sandbox,
executor,
preamble_executed: false,
})
}
#[tracing::instrument(
name = "InProcessSession::execute",
skip(self, code),
fields(
code_len = code.len(),
execution_count = self.executor.execution_count(),
)
)]
async fn execute_internal(&mut self, code: &str) -> Result<ExecuteResult, Error> {
let start = Instant::now();
let full_code = if !self.preamble_executed && !self.sandbox.preamble().is_empty() {
self.preamble_executed = true;
format!("{}\n\n# User code\n{}", self.sandbox.preamble(), code)
} else {
code.to_string()
};
let (callback_tx, callback_rx) = mpsc::channel::<CallbackRequest>(32);
let (trace_tx, trace_rx) = mpsc::unbounded_channel::<TraceRequest>();
let callbacks_arc = self.sandbox.callbacks_arc();
let resource_limits = self.sandbox.resource_limits().clone();
let secrets_arc = std::sync::Arc::new(self.sandbox.secrets().clone());
let callback_secrets = std::sync::Arc::clone(&secrets_arc);
let callback_handler = tokio::spawn(async move {
run_callback_handler(
callback_rx,
callbacks_arc,
resource_limits,
callback_secrets,
)
.await
});
let trace_handler = self.sandbox.trace_handler().clone();
let trace_secrets = self.sandbox.secrets().clone();
let trace_collector = tokio::spawn(async move {
run_trace_collector(trace_rx, trace_handler, trace_secrets).await
});
let (output_tx, output_rx) = mpsc::unbounded_channel::<OutputRequest>();
let output_handler_ref = self.sandbox.output_handler().clone();
let output_secrets = self.sandbox.secrets().clone();
let scrub_stdout = self.sandbox.scrub_stdout();
let scrub_stderr = self.sandbox.scrub_stderr();
let output_collector = tokio::spawn(async move {
run_output_collector(
output_rx,
output_handler_ref,
output_secrets,
scrub_stdout,
scrub_stderr,
)
.await
});
let callbacks: Vec<Arc<dyn Callback>> =
self.sandbox.callbacks().values().cloned().collect();
let execution_result = self
.executor
.execute(&full_code)
.with_callbacks(&callbacks, callback_tx)
.with_tracing(trace_tx)
.with_output_streaming(output_tx)
.run()
.await;
let callback_invocations = callback_handler.await.unwrap_or(0);
let trace_events = trace_collector.await.unwrap_or_default();
let _ = output_collector.await;
let duration = start.elapsed();
match execution_result {
Ok(output) => {
tracing::info!(
duration_ms = duration.as_millis() as u64,
callback_invocations,
peak_memory_bytes = output.peak_memory_bytes,
fuel_consumed = ?output.fuel_consumed,
"Session execution completed"
);
Ok(ExecuteResult {
stdout: output.stdout,
stderr: output.stderr,
trace: trace_events,
stats: ExecuteStats {
duration,
callback_invocations,
peak_memory_bytes: Some(output.peak_memory_bytes),
fuel_consumed: output.fuel_consumed,
},
})
}
Err(error) => Err(error),
}
}
#[must_use]
pub fn execution_count(&self) -> u32 {
self.executor.execution_count()
}
pub async fn snapshot_state(&mut self) -> Result<PythonStateSnapshot, Error> {
self.executor.snapshot_state().await
}
pub async fn restore_state(&mut self, snapshot: &PythonStateSnapshot) -> Result<(), Error> {
self.executor.restore_state(snapshot).await
}
pub async fn clear_state(&mut self) -> Result<(), Error> {
self.executor.clear_state().await
}
}
#[async_trait]
impl Session for InProcessSession<'_> {
    /// Executes `code` by forwarding to the inherent `execute_internal` method.
    async fn execute(&mut self, code: &str) -> Result<ExecuteResult, Error> {
        self.execute_internal(code).await
    }

    /// Resets the executor with the sandbox's current callbacks and, once that
    /// succeeds, clears the preamble flag so the preamble is prepended again on
    /// the next execution.
    async fn reset(&mut self) -> Result<(), Error> {
        let registered: Vec<Arc<dyn Callback>> = self
            .sandbox
            .callbacks()
            .values()
            .map(Arc::clone)
            .collect();

        self.executor.reset(&registered).await?;
        self.preamble_executed = false;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the session type has a non-zero size.
    #[test]
    fn test_in_process_session_size() {
        let session_size = std::mem::size_of::<InProcessSession<'_>>();
        assert!(session_size > 0);
    }
}