use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use tokio::sync::{oneshot, Semaphore};
pub mod js;
pub mod python;
pub mod sandbox;
/// Scripting languages a job can be written in.
///
/// The worker loop in this file dispatches `Python` jobs to `python::run`
/// and `JavaScript` jobs to `js::run`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScriptLanguage {
/// Executed through the `python` module.
Python,
/// Executed through the `js` module.
JavaScript,
}
impl ScriptLanguage {
pub fn as_str(&self) -> &'static str {
match self {
Self::Python => "python",
Self::JavaScript => "javascript",
}
}
}
/// Settings for a `ScriptEngine`.
///
/// Only the first six fields are read in this file; the remaining ones are
/// carried to the interpreter modules through `Job::config`.
#[derive(Debug, Clone)]
pub struct ScriptConfig {
/// When `false`, every run returns a "scripting engine is disabled" error
/// without queueing anything.
pub enabled: bool,
/// Number of worker threads `ScriptEngine::new` tries to spawn
/// (values below 1 are treated as 1).
pub num_workers: usize,
/// Capacity of the bounded job queue (values below 1 are treated as 1).
pub queue_capacity: usize,
/// Number of semaphore permits limiting how many runs may be in flight at
/// once (values below 1 are treated as 1).
pub max_concurrent: usize,
/// Timeout used for a run when the caller supplies no override.
pub default_timeout: Duration,
/// Byte limit applied separately to `stdout` and `stderr` of a result
/// received from a worker. A truncation marker is appended after the cut,
/// so a truncated string can be slightly longer than this limit.
pub max_output_bytes: usize,
/// Not read in this file; presumably controls whether scripts may use the
/// network — confirm in the interpreter/sandbox modules.
pub allow_network: bool,
/// Not read in this file; presumably controls whether scripts may touch
/// the filesystem — confirm in the interpreter/sandbox modules.
pub allow_filesystem: bool,
/// Not read in this file; presumably controls whether `ScriptContext::html`
/// is exposed to scripts — confirm in the interpreter modules.
pub inject_page_html: bool,
/// Not read in this file; presumably a byte cap on the injected HTML —
/// confirm in the interpreter modules.
pub html_max_bytes: usize,
}
impl Default for ScriptConfig {
    /// Builds the default configuration: engine disabled, four workers, four
    /// concurrent runs, a 64-slot queue, a five second timeout and a 64 KiB
    /// output limit.
    fn default() -> Self {
        const KIB: usize = 1024;
        // Worker threads and concurrency permits share the same default.
        const POOL_SIZE: usize = 4;

        Self {
            enabled: false,
            num_workers: POOL_SIZE,
            queue_capacity: 64,
            max_concurrent: POOL_SIZE,
            default_timeout: Duration::from_secs(5),
            max_output_bytes: 64 * KIB,
            allow_network: false,
            // NOTE(review): filesystem access defaults to on while network
            // access defaults to off — confirm this asymmetry is intended.
            allow_filesystem: true,
            inject_page_html: true,
            html_max_bytes: 32 * KIB,
        }
    }
}
/// Optional data handed to a script run and forwarded to the worker in
/// `Job::context`.
///
/// None of the fields are interpreted in this file; the per-field notes below
/// are inferred from the names and should be confirmed against the
/// interpreter modules.
#[derive(Debug, Clone, Default)]
pub struct ScriptContext {
/// Presumably the URL of the current page — TODO confirm.
pub url: Option<String>,
/// Presumably the title of the current page — TODO confirm.
pub title: Option<String>,
/// Presumably the HTML of the current page — TODO confirm.
pub html: Option<String>,
/// Presumably a JSON-encoded memory/state blob — TODO confirm.
pub memory_json: Option<String>,
}
/// Outcome of one script run, as returned by `ScriptEngine::run_python` and
/// `ScriptEngine::run_javascript`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScriptResult {
/// Language name; the constructors in this file fill it from
/// `ScriptLanguage::as_str`.
pub language: String,
/// `false` for every error and timeout result built in this file.
pub success: bool,
/// Presumably the script's captured standard output (produced by the
/// interpreter modules); truncated to `ScriptConfig::max_output_bytes`.
pub stdout: String,
/// Error text: engine-level failures and timeouts put their message here.
/// Truncated to `ScriptConfig::max_output_bytes`.
pub stderr: String,
/// `None` for error and timeout results; presumably the script's returned
/// value otherwise — confirm in the interpreter modules.
pub value: Option<serde_json::Value>,
/// Duration of the run in milliseconds. The worker loop replaces a value
/// of 0 with the time elapsed since the job's `started_at`.
pub elapsed_ms: u64,
/// `true` for results built by `ScriptResult::timeout`.
pub timed_out: bool,
}
impl ScriptResult {
    /// Builds a failed result whose `stderr` is `msg`.
    ///
    /// `stdout` is empty, `value` is `None` and `timed_out` is `false`.
    pub(crate) fn error(language: ScriptLanguage, msg: impl Into<String>, elapsed_ms: u64) -> Self {
        Self {
            language: language.as_str().to_string(),
            stderr: msg.into(),
            elapsed_ms,
            // The derived defaults supply `success: false`, empty `stdout`,
            // `value: None` and `timed_out: false`.
            ..Self::default()
        }
    }

    /// Builds a failed result with `timed_out` set and the fixed
    /// "script timed out" message in `stderr`.
    pub(crate) fn timeout(language: ScriptLanguage, elapsed_ms: u64) -> Self {
        Self {
            timed_out: true,
            ..Self::error(language, "script timed out", elapsed_ms)
        }
    }

    /// Limits `stdout` and `stderr` to `max_output_bytes` each (a truncation
    /// marker is appended to whichever stream gets cut).
    pub(crate) fn truncate_output(&mut self, max_output_bytes: usize) {
        for stream in [&mut self.stdout, &mut self.stderr] {
            truncate_utf8(stream, max_output_bytes);
        }
    }
}
/// Shortens `s` to at most `max` bytes without splitting a UTF-8 character,
/// then appends a truncation marker.
///
/// Strings that already fit in `max` bytes are left untouched. The marker is
/// added after the cut, so the final length may exceed `max`.
fn truncate_utf8(s: &mut String, max: usize) {
    if s.len() > max {
        // Walk back from `max` to the nearest character boundary; index 0 is
        // always a boundary, so the search cannot fail.
        let boundary = (0..=max)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        s.truncate(boundary);
        s.push_str("\n…[output truncated]");
    }
}
/// One unit of work sent from `ScriptEngine::run` to a worker thread.
pub(crate) struct Job {
/// Selects which interpreter module the worker calls.
pub language: ScriptLanguage,
/// Source code of the script to run.
pub code: String,
/// Optional caller-supplied data forwarded to the interpreter.
pub context: ScriptContext,
/// The engine's shared configuration.
pub config: Arc<ScriptConfig>,
/// Flag that `ScriptEngine::run` sets to `true` when its timeout fires;
/// presumably polled by the interpreters to abort — confirm there.
pub interrupt: Arc<AtomicBool>,
/// Instant captured at the start of `ScriptEngine::run`; the worker uses it
/// to compute elapsed time.
pub started_at: std::time::Instant,
/// Handle to the Tokio runtime `ScriptEngine::run` was called on. Not used
/// in this file.
pub runtime: tokio::runtime::Handle,
/// One-shot channel on which the worker sends the result.
pub reply: oneshot::Sender<ScriptResult>,
}
/// Handle for running scripts on a pool of worker threads.
///
/// `Clone` is derived; the fields are `Arc`s and a channel sender, so a clone
/// is another handle to the same configuration, queue and semaphore.
#[derive(Clone)]
pub struct ScriptEngine {
/// Configuration shared with every job.
config: Arc<ScriptConfig>,
/// Sending half of the bounded job queue read by the worker threads.
tx: flume::Sender<Job>,
/// Semaphore sized from `ScriptConfig::max_concurrent`; a permit is held
/// for the duration of each run.
permits: Arc<Semaphore>,
}
impl std::fmt::Debug for ScriptEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ScriptEngine")
.field("enabled", &self.config.enabled)
.field("num_workers", &self.config.num_workers)
.field("queue_capacity", &self.config.queue_capacity)
.field("max_concurrent", &self.config.max_concurrent)
.finish()
}
}
impl ScriptEngine {
pub fn new(config: ScriptConfig) -> Self {
let config = Arc::new(config);
let (tx, rx) = flume::bounded::<Job>(config.queue_capacity.max(1));
let permits = Arc::new(Semaphore::new(config.max_concurrent.max(1)));
for i in 0..config.num_workers.max(1) {
let rx = rx.clone();
let name = format!("spider-agent-script-{i}");
let spawn_result = std::thread::Builder::new()
.name(name.clone())
.stack_size(2 * 1024 * 1024)
.spawn(move || worker_loop(rx));
if let Err(e) = spawn_result {
log::error!("failed to spawn script worker {name}: {e}");
}
}
Self {
config,
tx,
permits,
}
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn config(&self) -> &ScriptConfig {
&self.config
}
pub async fn run_python(
&self,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
) -> ScriptResult {
self.run(ScriptLanguage::Python, code, context, timeout_override)
.await
}
pub async fn run_javascript(
&self,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
) -> ScriptResult {
self.run(ScriptLanguage::JavaScript, code, context, timeout_override)
.await
}
async fn run(
&self,
language: ScriptLanguage,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
) -> ScriptResult {
let start = std::time::Instant::now();
if !self.config.enabled {
return ScriptResult::error(language, "scripting engine is disabled", 0);
}
let _permit = match self.permits.clone().acquire_owned().await {
Ok(p) => p,
Err(_) => return ScriptResult::error(language, "permit acquire failed", 0),
};
let (reply_tx, reply_rx) = oneshot::channel();
let interrupt = Arc::new(AtomicBool::new(false));
let runtime = tokio::runtime::Handle::current();
let job = Job {
language,
code,
context,
config: self.config.clone(),
interrupt: interrupt.clone(),
started_at: start,
runtime,
reply: reply_tx,
};
if self.tx.send_async(job).await.is_err() {
return ScriptResult::error(
language,
"script worker pool is shut down",
elapsed(start),
);
}
let deadline = timeout_override.unwrap_or(self.config.default_timeout);
match tokio::time::timeout(deadline, reply_rx).await {
Ok(Ok(mut result)) => {
result.truncate_output(self.config.max_output_bytes);
result
}
Ok(Err(_)) => {
ScriptResult::error(
language,
"script worker dropped reply channel",
elapsed(start),
)
}
Err(_) => {
interrupt.store(true, Ordering::Relaxed);
ScriptResult::timeout(language, elapsed(start))
}
}
}
}
/// Whole milliseconds elapsed since `start`, saturating at `u64::MAX`.
fn elapsed(start: std::time::Instant) -> u64 {
    // `as_millis` yields a `u128`; clamp before narrowing so the cast can
    // never wrap around.
    start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64
}
fn worker_loop(rx: flume::Receiver<Job>) {
log::debug!(
"script worker started on thread {:?}",
std::thread::current().name()
);
while let Ok(job) = rx.recv() {
let started_at = job.started_at;
let language = job.language;
let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| match language {
ScriptLanguage::Python => python::run(&job),
ScriptLanguage::JavaScript => js::run(&job),
}));
let mut result = match outcome {
Ok(Ok(r)) => r,
Ok(Err(err)) => ScriptResult::error(
language,
format!("internal error: {err}"),
elapsed(started_at),
),
Err(panic_payload) => {
let msg = panic_message(panic_payload);
log::error!("script worker caught panic: {msg}");
ScriptResult::error(
language,
format!("interpreter panic: {msg}"),
elapsed(started_at),
)
}
};
if result.elapsed_ms == 0 {
result.elapsed_ms = elapsed(started_at);
}
let _ = job.reply.send(result);
}
log::debug!(
"script worker stopped on thread {:?}",
std::thread::current().name()
);
}
/// Extracts a readable message from a panic payload.
///
/// Handles the two payload types produced by `panic!` (`&'static str` and
/// `String`); anything else yields `"unknown panic"`.
fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
    let from_literal = payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_string());
    from_literal
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Text `truncate_utf8` appends after cutting a string.
    const MARKER: &str = "\n…[output truncated]";

    /// Cutting inside multi-byte text must leave the kept prefix ending on a
    /// character boundary.
    #[test]
    fn truncate_utf8_safe() {
        let mut text = String::from("héllo, wörld");
        truncate_utf8(&mut text, 5);
        let kept = text.len() - MARKER.len();
        assert!(text.is_char_boundary(kept));
        assert!(text.starts_with("h"));
    }

    /// Scripting must be opt-in.
    #[test]
    fn engine_disabled_by_default() {
        let defaults = ScriptConfig::default();
        assert!(!defaults.enabled);
    }
}