use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Instant;
use eryx::JsonSchema;
use eryx::{CallbackError, Sandbox, TypedCallback};
use serde::Deserialize;
use serde_json::{Value, json};
// Arguments accepted by the `sleep` callback (scripts call it as `sleep(ms=100)`).
// NOTE(review): plain `//` comments are used deliberately — `///` doc comments may be
// picked up by the `JsonSchema` derive as schema descriptions; confirm before switching.
#[derive(Deserialize, JsonSchema)]
struct SleepArgs {
    // How long to sleep, in milliseconds (converted with `Duration::from_millis`).
    ms: u64,
}
/// Test callback that sleeps and records how many invocations overlap.
///
/// Both counters live behind `Arc`s so a caller can keep its own handle to
/// them after the callback has been moved into the sandbox.
struct SleepCallback {
    /// Number of `sleep` invocations currently in flight.
    concurrent_count: Arc<AtomicU32>,
    /// Highest value `concurrent_count` has reached so far.
    peak_concurrent: Arc<AtomicU32>,
}

impl SleepCallback {
    /// Creates a callback whose counters both start at zero.
    fn new() -> Self {
        let concurrent_count = Arc::new(AtomicU32::new(0));
        let peak_concurrent = Arc::new(AtomicU32::new(0));
        Self {
            concurrent_count,
            peak_concurrent,
        }
    }
}
impl TypedCallback for SleepCallback {
type Args = SleepArgs;
fn name(&self) -> &str {
"sleep"
}
fn description(&self) -> &str {
"Sleeps for the specified number of milliseconds"
}
fn invoke_typed(
&self,
args: SleepArgs,
) -> Pin<Box<dyn Future<Output = Result<Value, CallbackError>> + Send + '_>> {
let current = self.concurrent_count.fetch_add(1, Ordering::SeqCst) + 1;
self.peak_concurrent.fetch_max(current, Ordering::SeqCst);
let concurrent_count = self.concurrent_count.clone();
Box::pin(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(args.ms)).await;
concurrent_count.fetch_sub(1, Ordering::SeqCst);
Ok(json!({ "slept_ms": args.ms }))
})
}
}
/// Runs three 100 ms `sleep` callbacks twice and reports whether the sandbox
/// overlapped the gathered calls.
///
/// The first run awaits the calls one after another. The second runs them
/// through `asyncio.gather`. Verdicts are printed only; the function returns
/// `Ok(())` regardless of PASS/FAIL.
///
/// # Errors
///
/// Returns an error if the sandbox cannot be built or if either script fails
/// to execute.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    println!("=== Parallel Execution Test ===\n");
    let sleep_callback = SleepCallback::new();
    // Keep our own handle on the peak counter; the callback moves into the sandbox.
    let peak_concurrent = sleep_callback.peak_concurrent.clone();
    let sandbox = Sandbox::embedded().with_callback(sleep_callback).build()?;

    println!("Test 1: Sequential execution (3 x 100ms callbacks)");
    let start = Instant::now();
    let result = sandbox
        .execute(
            r#"
result1 = await sleep(ms=100)
result2 = await sleep(ms=100)
result3 = await sleep(ms=100)
print(f"Results: {result1}, {result2}, {result3}")
"#,
        )
        .await?;
    let sequential_duration = start.elapsed();
    println!(" Duration: {sequential_duration:?}");
    println!(" Output: {}", result.stdout);
    println!(" Callbacks invoked: {}", result.stats.callback_invocations);
    println!();

    // Reset the high-water mark so the peak reported below reflects Test 2 only.
    peak_concurrent.store(0, Ordering::SeqCst);

    println!("Test 2: Parallel execution (3 x 100ms callbacks via asyncio.gather)");
    let start = Instant::now();
    let result = sandbox
        .execute(
            r#"
import asyncio
results = await asyncio.gather(
sleep(ms=100),
sleep(ms=100),
sleep(ms=100),
)
print(f"Results: {results}")
"#,
        )
        .await?;
    let parallel_duration = start.elapsed();
    // Read once and reuse for both the report and the verdict below.
    let peak = peak_concurrent.load(Ordering::SeqCst);
    println!(" Duration: {parallel_duration:?}");
    println!(" Output: {}", result.stdout);
    println!(" Callbacks invoked: {}", result.stats.callback_invocations);
    println!(" Peak concurrent callbacks: {peak}");
    println!();

    println!("=== Results ===");
    println!("Sequential duration: {sequential_duration:?}");
    println!("Parallel duration: {parallel_duration:?}");
    // Use fractional seconds: `as_millis()` truncates, and a sub-millisecond
    // parallel run would turn an integer-millisecond ratio into `inf`.
    let speedup = sequential_duration.as_secs_f64() / parallel_duration.as_secs_f64();
    println!("Speedup: {speedup:.2}x");

    let sequential_ms = sequential_duration.as_millis();
    let parallel_ms = parallel_duration.as_millis();
    println!();
    if parallel_ms < 200 && sequential_ms > 250 {
        println!("✅ PASS: Parallel execution is working!");
        println!(" Sequential took {sequential_ms}ms (expected ~300ms)");
        println!(" Parallel took {parallel_ms}ms (expected ~100ms)");
    } else if parallel_ms >= sequential_ms {
        println!("❌ FAIL: Parallel execution is NOT faster than sequential!");
        println!(" This suggests callbacks are running sequentially.");
    } else {
        println!("⚠️ INCONCLUSIVE: Results are close, may need longer sleep times");
        println!(" Sequential: {sequential_ms}ms, Parallel: {parallel_ms}ms");
    }

    println!();
    if peak >= 3 {
        println!("✅ PASS: Peak concurrent callbacks = {peak} (3 ran simultaneously)");
    } else {
        println!("⚠️ Peak concurrent callbacks = {peak} (expected 3 for full parallelism)");
    }
    Ok(())
}