use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use tokio::sync::{oneshot, Semaphore};
pub mod js;
pub mod python;
pub mod sandbox;
/// Shared counters for script executions and script-initiated HTTP fetches.
///
/// A single instance is created per `ScriptEngine` (wrapped in `Arc`), shared by
/// every clone of that engine and handed to each dispatched `Job`. All counters
/// start at zero via `Default`. `snapshot` loads each field with
/// `Ordering::Relaxed`, so a snapshot is not a consistent cross-field view.
#[derive(Debug, Default)]
pub struct ScriptUsage {
/// Jobs handed to the worker queue (bumped by `ScriptEngine::run` just before the enqueue attempt).
pub scripts_run: AtomicU64,
/// Runs that ended in a timeout, whether hit by the engine's reply deadline or reported by the interpreter.
pub scripts_timed_out: AtomicU64,
/// Runs counted as failed by `ScriptEngine::run` (timeouts included).
pub scripts_failed: AtomicU64,
/// NOTE(review): not updated in this file — presumably incremented by the interpreters' fetch bindings; confirm.
pub fetch_calls: AtomicU64,
/// NOTE(review): not updated in this file — presumably failed script fetches; confirm.
pub fetch_errors: AtomicU64,
/// NOTE(review): not updated in this file — presumably response bytes received by script fetches; confirm.
pub fetch_bytes_in: AtomicU64,
}
/// Plain-value copy of [`ScriptUsage`], produced by `ScriptUsage::snapshot`
/// (exposed as `ScriptEngine::usage_snapshot`). Each field mirrors the
/// same-named counter on `ScriptUsage`; derives serde for reporting.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScriptUsageSnapshot {
pub scripts_run: u64,
pub scripts_timed_out: u64,
pub scripts_failed: u64,
pub fetch_calls: u64,
pub fetch_errors: u64,
pub fetch_bytes_in: u64,
}
impl ScriptUsage {
pub(crate) fn snapshot(&self) -> ScriptUsageSnapshot {
ScriptUsageSnapshot {
scripts_run: self.scripts_run.load(Ordering::Relaxed),
scripts_timed_out: self.scripts_timed_out.load(Ordering::Relaxed),
scripts_failed: self.scripts_failed.load(Ordering::Relaxed),
fetch_calls: self.fetch_calls.load(Ordering::Relaxed),
fetch_errors: self.fetch_errors.load(Ordering::Relaxed),
fetch_bytes_in: self.fetch_bytes_in.load(Ordering::Relaxed),
}
}
}
/// Lazily-built, process-wide HTTP client tagged with the
/// `spider_agent_script/<crate version>` user agent.
///
/// The client is built once on first call. If building fails, that `None` is
/// cached and every later call also returns `None`.
pub(crate) fn fallback_http_client() -> Option<&'static reqwest::Client> {
    use std::sync::OnceLock;

    static CLIENT: OnceLock<Option<reqwest::Client>> = OnceLock::new();

    let slot = CLIENT.get_or_init(|| {
        let builder = reqwest::Client::builder()
            .user_agent(concat!("spider_agent_script/", env!("CARGO_PKG_VERSION")));
        builder.build().ok()
    });
    slot.as_ref()
}
/// Interpreter a script targets. `worker_loop` dispatches `Python` to
/// `python::run` and `JavaScript` to `js::run`; `as_str` gives the lowercase
/// name stored in `ScriptResult::language`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ScriptLanguage {
Python,
JavaScript,
}
impl ScriptLanguage {
pub fn as_str(&self) -> &'static str {
match self {
Self::Python => "python",
Self::JavaScript => "javascript",
}
}
}
/// Settings for a `ScriptEngine` and the scripts it runs.
#[derive(Debug, Clone)]
pub struct ScriptConfig {
    /// When false, every `run_*` call returns an error result immediately.
    pub enabled: bool,
    /// Worker threads spawned by `ScriptEngine::new` (clamped to at least 1).
    pub num_workers: usize,
    /// Capacity of the bounded job queue (clamped to at least 1).
    pub queue_capacity: usize,
    /// Maximum `run_*` calls holding a permit at once (clamped to at least 1).
    pub max_concurrent: usize,
    /// How long a call waits for its reply when no per-call timeout is given.
    pub default_timeout: Duration,
    /// How long a call waits for a concurrency permit before giving up.
    pub permit_acquire_timeout: Duration,
    /// Byte cap applied to a result's stdout and stderr (UTF-8 safe).
    pub max_output_bytes: usize,
    /// NOTE(review): consumed outside this file — presumably gates script network access; confirm.
    pub allow_network: bool,
    /// NOTE(review): consumed outside this file — presumably gates script filesystem access; confirm.
    pub allow_filesystem: bool,
    /// NOTE(review): consumed outside this file — presumably exposes `ScriptContext::html` to scripts; confirm.
    pub inject_page_html: bool,
    /// NOTE(review): consumed outside this file — presumably caps injected HTML size in bytes; confirm.
    pub html_max_bytes: usize,
}

impl Default for ScriptConfig {
    /// Scripting disabled; 4 workers / 4 concurrent runs over a 64-slot queue;
    /// 5 s script timeout and 30 s permit wait; 64 KiB output cap; network off,
    /// filesystem on, page HTML injected up to 32 KiB.
    ///
    /// NOTE(review): `allow_filesystem` defaults to `true` while `allow_network`
    /// defaults to `false` — confirm that asymmetry is intended.
    fn default() -> Self {
        const KIB: usize = 1024;
        ScriptConfig {
            enabled: false,
            // Pool sizing.
            num_workers: 4,
            queue_capacity: 64,
            max_concurrent: 4,
            // Time limits.
            default_timeout: Duration::from_secs(5),
            permit_acquire_timeout: Duration::from_secs(30),
            // Output cap and sandbox policy.
            max_output_bytes: 64 * KIB,
            allow_network: false,
            allow_filesystem: true,
            inject_page_html: true,
            html_max_bytes: 32 * KIB,
        }
    }
}
/// Optional page/agent context passed along with a script in its `Job`.
///
/// NOTE(review): how each field is exposed to the script is decided by the
/// interpreter modules (`python`, `js`), which are not visible here — confirm.
#[derive(Debug, Clone, Default)]
pub struct ScriptContext {
/// Presumably the current page URL.
pub url: Option<String>,
/// Presumably the current page title.
pub title: Option<String>,
/// Presumably the page HTML (see `ScriptConfig::inject_page_html` / `html_max_bytes`).
pub html: Option<String>,
/// Presumably agent memory serialized as a JSON string.
pub memory_json: Option<String>,
}
/// Outcome of one script run, as returned by the `ScriptEngine::run_*` methods.
///
/// Engine-side failures (disabled engine, permit/queue trouble, timeouts,
/// interpreter errors or panics) are also reported through this type rather
/// than as a Rust error.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ScriptResult {
/// Lowercase language name; the constructors here fill it from `ScriptLanguage::as_str`.
pub language: String,
/// Whether the script completed successfully.
pub success: bool,
/// Captured standard output; results coming back from a worker are truncated to `ScriptConfig::max_output_bytes`.
pub stdout: String,
/// Captured standard error, or the engine/worker error message on failure.
pub stderr: String,
/// NOTE(review): set by the interpreter modules — presumably the script's JSON-converted return value; confirm.
pub value: Option<serde_json::Value>,
/// Milliseconds since the start of `ScriptEngine::run` (permit wait included) when filled in by the engine or worker.
pub elapsed_ms: u64,
/// True when the run was cut off by a timeout.
pub timed_out: bool,
}
impl ScriptResult {
    /// Builds a failed result for `language` whose `stderr` carries `msg`.
    ///
    /// `stdout` is empty, `value` is `None`, and `timed_out` is `false`.
    pub(crate) fn error(language: ScriptLanguage, msg: impl Into<String>, elapsed_ms: u64) -> Self {
        Self {
            language: language.as_str().to_owned(),
            stderr: msg.into(),
            elapsed_ms,
            ..Self::default()
        }
    }

    /// Builds a failed result marked `timed_out`, with the fixed message
    /// `"script timed out"` in `stderr`.
    pub(crate) fn timeout(language: ScriptLanguage, elapsed_ms: u64) -> Self {
        Self {
            timed_out: true,
            ..Self::error(language, "script timed out", elapsed_ms)
        }
    }

    /// Caps `stdout` and then `stderr` at `max_output_bytes` bytes each, cutting
    /// on a UTF-8 boundary and appending a truncation marker when shortened.
    pub(crate) fn truncate_output(&mut self, max_output_bytes: usize) {
        for stream in [&mut self.stdout, &mut self.stderr] {
            truncate_utf8(stream, max_output_bytes);
        }
    }
}
/// Shortens `s` in place to at most `max` bytes, backing off to the nearest
/// preceding UTF-8 character boundary, then appends a truncation marker.
///
/// Strings already within `max` bytes are left untouched. The marker is added
/// after the cut, so a truncated string ends up longer than `max`.
fn truncate_utf8(s: &mut String, max: usize) {
    if s.len() > max {
        // Byte 0 is always a char boundary, so the search always succeeds.
        let cut = (0..=max)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0);
        s.truncate(cut);
        s.push_str("\n…[output truncated]");
    }
}
/// One script execution request, sent from `ScriptEngine::run` to a worker
/// thread over the job queue and answered on `reply`.
pub(crate) struct Job {
/// Selects `python::run` or `js::run` in `worker_loop`.
pub language: ScriptLanguage,
/// Script source text.
pub code: String,
/// Page/agent context supplied by the caller.
pub context: ScriptContext,
/// The engine's shared configuration.
pub config: Arc<ScriptConfig>,
/// Raised by `ScriptEngine::run` when the caller's reply deadline elapses; interpreters presumably poll it to abort — confirm.
pub interrupt: Arc<AtomicBool>,
/// Captured at the top of `ScriptEngine::run` (before permit acquisition); used for `elapsed_ms`.
pub started_at: std::time::Instant,
/// Handle to the caller's Tokio runtime, presumably for async work (e.g. fetches) issued from the worker thread — confirm.
pub runtime: tokio::runtime::Handle,
/// One-shot channel the worker sends the finished `ScriptResult` on.
pub reply: oneshot::Sender<ScriptResult>,
/// HTTP client for this run: the per-call override or the engine default.
pub client: reqwest::Client,
/// The engine's shared usage counters.
pub usage: Arc<ScriptUsage>,
}
/// Front end to a pool of script worker threads.
///
/// Clones share the same configuration, job queue, concurrency permits, and
/// usage counters.
#[derive(Clone)]
pub struct ScriptEngine {
/// Shared configuration passed to every job.
config: Arc<ScriptConfig>,
/// Sending side of the bounded job queue read by the worker threads.
tx: flume::Sender<Job>,
/// Limits how many `run` calls are in flight at once (`max_concurrent`).
permits: Arc<Semaphore>,
/// Client handed to jobs that do not supply their own.
default_client: reqwest::Client,
/// Counters shared with every dispatched job.
usage: Arc<ScriptUsage>,
}
impl std::fmt::Debug for ScriptEngine {
    /// Prints only the enable flag and pool-sizing settings; the queue,
    /// permits, HTTP client, and usage counters are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cfg = &*self.config;
        let mut out = f.debug_struct("ScriptEngine");
        out.field("enabled", &cfg.enabled);
        out.field("num_workers", &cfg.num_workers);
        out.field("queue_capacity", &cfg.queue_capacity);
        out.field("max_concurrent", &cfg.max_concurrent);
        out.finish()
    }
}
impl ScriptEngine {
pub fn new(config: ScriptConfig) -> Self {
let config = Arc::new(config);
let (tx, rx) = flume::bounded::<Job>(config.queue_capacity.max(1));
let permits = Arc::new(Semaphore::new(config.max_concurrent.max(1)));
for i in 0..config.num_workers.max(1) {
let rx = rx.clone();
let name = format!("spider-agent-script-{i}");
let spawn_result = std::thread::Builder::new()
.name(name.clone())
.stack_size(2 * 1024 * 1024)
.spawn(move || worker_loop(rx));
if let Err(e) = spawn_result {
log::error!("failed to spawn script worker {name}: {e}");
}
}
let default_client = fallback_http_client()
.cloned()
.unwrap_or_else(reqwest::Client::new);
Self {
config,
tx,
permits,
default_client,
usage: Arc::new(ScriptUsage::default()),
}
}
pub fn is_enabled(&self) -> bool {
self.config.enabled
}
pub fn config(&self) -> &ScriptConfig {
&self.config
}
pub fn with_client(mut self, client: reqwest::Client) -> Self {
self.default_client = client;
self
}
pub fn usage_snapshot(&self) -> ScriptUsageSnapshot {
self.usage.snapshot()
}
pub fn usage_handle(&self) -> Arc<ScriptUsage> {
self.usage.clone()
}
pub async fn run_python(
&self,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
) -> ScriptResult {
self.run(
ScriptLanguage::Python,
code,
context,
timeout_override,
None,
)
.await
}
pub async fn run_python_with_client(
&self,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
client: reqwest::Client,
) -> ScriptResult {
self.run(
ScriptLanguage::Python,
code,
context,
timeout_override,
Some(client),
)
.await
}
pub async fn run_javascript(
&self,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
) -> ScriptResult {
self.run(
ScriptLanguage::JavaScript,
code,
context,
timeout_override,
None,
)
.await
}
pub async fn run_javascript_with_client(
&self,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
client: reqwest::Client,
) -> ScriptResult {
self.run(
ScriptLanguage::JavaScript,
code,
context,
timeout_override,
Some(client),
)
.await
}
async fn run(
&self,
language: ScriptLanguage,
code: String,
context: ScriptContext,
timeout_override: Option<Duration>,
client_override: Option<reqwest::Client>,
) -> ScriptResult {
let start = std::time::Instant::now();
if !self.config.enabled {
return ScriptResult::error(language, "scripting engine is disabled", 0);
}
let _permit = match tokio::time::timeout(
self.config.permit_acquire_timeout,
self.permits.clone().acquire_owned(),
)
.await
{
Ok(Ok(p)) => p,
Ok(Err(_)) => return ScriptResult::error(language, "permit acquire failed", 0),
Err(_) => {
return ScriptResult::error(
language,
"permit acquire timed out — workers may be stuck",
elapsed(start),
);
}
};
let (reply_tx, reply_rx) = oneshot::channel();
let interrupt = Arc::new(AtomicBool::new(false));
let runtime = tokio::runtime::Handle::current();
let job = Job {
language,
code,
context,
config: self.config.clone(),
interrupt: interrupt.clone(),
started_at: start,
runtime,
reply: reply_tx,
client: client_override.unwrap_or_else(|| self.default_client.clone()),
usage: self.usage.clone(),
};
self.usage.scripts_run.fetch_add(1, Ordering::Relaxed);
if self.tx.send_async(job).await.is_err() {
return ScriptResult::error(
language,
"script worker pool is shut down",
elapsed(start),
);
}
let deadline = timeout_override.unwrap_or(self.config.default_timeout);
match tokio::time::timeout(deadline, reply_rx).await {
Ok(Ok(mut result)) => {
result.truncate_output(self.config.max_output_bytes);
if !result.success {
self.usage.scripts_failed.fetch_add(1, Ordering::Relaxed);
if result.timed_out {
self.usage.scripts_timed_out.fetch_add(1, Ordering::Relaxed);
}
}
result
}
Ok(Err(_)) => {
self.usage.scripts_failed.fetch_add(1, Ordering::Relaxed);
ScriptResult::error(
language,
"script worker dropped reply channel",
elapsed(start),
)
}
Err(_) => {
interrupt.store(true, Ordering::Relaxed);
self.usage.scripts_failed.fetch_add(1, Ordering::Relaxed);
self.usage.scripts_timed_out.fetch_add(1, Ordering::Relaxed);
ScriptResult::timeout(language, elapsed(start))
}
}
}
}
/// Milliseconds elapsed since `start`.
///
/// `Duration::as_millis` returns `u128`. The value saturates at `u64::MAX`
/// instead of being silently truncated by an `as` cast.
fn elapsed(start: std::time::Instant) -> u64 {
    u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX)
}
fn worker_loop(rx: flume::Receiver<Job>) {
log::debug!(
"script worker started on thread {:?}",
std::thread::current().name()
);
while let Ok(job) = rx.recv() {
let started_at = job.started_at;
let language = job.language;
let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| match language {
ScriptLanguage::Python => python::run(&job),
ScriptLanguage::JavaScript => js::run(&job),
}));
let mut result = match outcome {
Ok(Ok(r)) => r,
Ok(Err(err)) => ScriptResult::error(
language,
format!("internal error: {err}"),
elapsed(started_at),
),
Err(panic_payload) => {
let msg = panic_message(panic_payload);
log::error!("script worker caught panic: {msg}");
ScriptResult::error(
language,
format!("interpreter panic: {msg}"),
elapsed(started_at),
)
}
};
if result.elapsed_ms == 0 {
result.elapsed_ms = elapsed(started_at);
}
let _ = job.reply.send(result);
}
python::cleanup_thread_local();
js::cleanup_thread_local();
log::debug!(
"script worker stopped on thread {:?}",
std::thread::current().name()
);
}
/// Extracts a readable message from a caught panic payload.
///
/// Payloads from `panic!("literal")` are `&'static str`, and payloads from
/// formatted panics are `String`. Anything else yields `"unknown panic"`.
fn panic_message(payload: Box<dyn std::any::Any + Send>) -> String {
    payload
        .downcast_ref::<&'static str>()
        .map(|text| (*text).to_owned())
        .or_else(|| payload.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| "unknown panic".to_string())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Marker `truncate_utf8` appends after cutting.
    const MARKER: &str = "\n…[output truncated]";

    #[test]
    fn truncate_utf8_safe() {
        // "é" is two bytes (1..3), so a 5-byte cap lands right after "héll".
        let mut s = "héllo, wörld".to_string();
        truncate_utf8(&mut s, 5);
        assert_eq!(s, format!("héll{MARKER}"));
    }

    #[test]
    fn truncate_utf8_backs_off_to_char_boundary() {
        // Byte 2 falls inside "é", so the cut must back off to byte 1.
        let mut s = "héllo".to_string();
        truncate_utf8(&mut s, 2);
        assert_eq!(s, format!("h{MARKER}"));
    }

    #[test]
    fn truncate_utf8_leaves_short_strings_alone() {
        let mut s = "abc".to_string();
        truncate_utf8(&mut s, 3);
        assert_eq!(s, "abc");
    }

    #[test]
    fn truncate_output_caps_both_streams() {
        let mut result = ScriptResult {
            stdout: "a".repeat(10),
            stderr: "b".repeat(3),
            ..Default::default()
        };
        result.truncate_output(4);
        assert_eq!(result.stdout, format!("aaaa{MARKER}"));
        assert_eq!(result.stderr, "bbb");
    }

    #[test]
    fn engine_disabled_by_default() {
        let cfg = ScriptConfig::default();
        assert!(!cfg.enabled);
    }
}