use lex_bytecode::vm::{EffectHandler, Vm};
use lex_bytecode::{Program, Value};
use std::cell::RefCell;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::builtins::try_pure_builtin;
use crate::policy::Policy;
/// Destination for the lines produced by the `io.print` effect.
///
/// Implementors must be `Send` so that a handler owning a sink can be moved
/// across threads.
pub trait IoSink: Send {
    /// Emits one line of output. `StdoutSink` appends the newline itself.
    fn print_line(&mut self, s: &str);
}

/// Sink that forwards every line to the process's standard output.
pub struct StdoutSink;

impl IoSink for StdoutSink {
    fn print_line(&mut self, s: &str) {
        println!("{s}");
    }
}

/// Sink that records every line in memory, oldest first, instead of printing it.
#[derive(Default)]
pub struct CapturedSink {
    /// Lines received so far, in arrival order.
    pub lines: Vec<String>,
}

impl IoSink for CapturedSink {
    fn print_line(&mut self, s: &str) {
        let owned = String::from(s);
        self.lines.push(owned);
    }
}
/// Effect handler that checks each effect against a [`Policy`] and then performs it.
pub struct DefaultHandler {
/// Capability policy consulted before effects run (the `--allow-*` settings).
policy: Policy,
/// Destination for `io.print`; `new` installs `StdoutSink`, `with_sink` replaces it.
pub sink: Box<dyn IoSink>,
/// When set, `io.read` paths are resolved beneath this directory (see `resolve_read_path`).
pub read_root: Option<PathBuf>,
/// NOTE(review): initialised to 0 and not read or updated by any code visible here
/// (`budget.*` effects simply return `Unit`) — confirm whether budgeting is implemented.
pub budget_used: RefCell<u64>,
/// Program that `net.serve`, `net.serve_tls` and `net.serve_ws` run request handlers
/// from; those effects fail when this is `None`.
pub program: Option<Arc<Program>>,
/// Registry used by `chat.broadcast` / `chat.send`; both fail when this is `None`.
pub chat_registry: Option<Arc<crate::ws::ChatRegistry>>,
}
impl DefaultHandler {
pub fn new(policy: Policy) -> Self {
Self {
policy,
sink: Box::new(StdoutSink),
read_root: None,
budget_used: RefCell::new(0),
program: None,
chat_registry: None,
}
}
/// Attaches the program that the `net.serve*` effects run request handlers from.
pub fn with_program(mut self, program: Arc<Program>) -> Self {
    self.program = Some(program);
    self
}

/// Attaches the registry that `chat.broadcast` / `chat.send` operate on.
pub fn with_chat_registry(mut self, registry: Arc<crate::ws::ChatRegistry>) -> Self {
    self.chat_registry = Some(registry);
    self
}

/// Replaces the `io.print` destination (stdout by default).
pub fn with_sink(mut self, sink: Box<dyn IoSink>) -> Self {
    self.sink = sink;
    self
}

/// Resolves subsequent `io.read` paths beneath `root`.
pub fn with_read_root(mut self, root: PathBuf) -> Self {
    self.read_root = Some(root);
    self
}
/// Succeeds when effect kind `kind` was granted via `--allow-effects`.
///
/// # Errors
/// Returns a message naming the missing effect kind otherwise.
fn ensure_kind_allowed(&self, kind: &str) -> Result<(), String> {
    if !self.policy.allow_effects.contains(kind) {
        return Err(format!("effect `{kind}` not in --allow-effects"));
    }
    Ok(())
}
/// Maps a path passed to `io.read` onto the filesystem.
///
/// Without a `read_root`, the path is used verbatim. With a root, the path is
/// interpreted relative to that root; a leading `/` (or drive prefix) is
/// ignored. It is resolved lexically so that `..` can never climb above the
/// root: surplus `..` segments are clamped at the root instead of escaping it.
/// Symlinks inside the root are not resolved.
fn resolve_read_path(&self, p: &str) -> PathBuf {
    use std::path::Component;
    match &self.read_root {
        Some(root) => {
            let mut confined = PathBuf::new();
            for comp in std::path::Path::new(p).components() {
                match comp {
                    Component::Normal(seg) => confined.push(seg),
                    // Pop one confined segment; at the root this is a no-op.
                    Component::ParentDir => {
                        confined.pop();
                    }
                    // Root, drive prefix and `.` contribute no segment.
                    Component::RootDir | Component::Prefix(_) | Component::CurDir => {}
                }
            }
            root.join(confined)
        }
        None => PathBuf::from(p),
    }
}
fn dispatch_log(&mut self, op: &str, args: Vec<Value>) -> Result<Value, String> {
match op {
"debug" | "info" | "warn" | "error" => {
let msg = expect_str(args.first())?;
let level = match op {
"debug" => LogLevel::Debug,
"info" => LogLevel::Info,
"warn" => LogLevel::Warn,
_ => LogLevel::Error,
};
emit_log(level, msg);
Ok(Value::Unit)
}
"set_level" => {
let s = expect_str(args.first())?;
match parse_log_level(s) {
Some(l) => {
log_state().lock().unwrap().level = l;
Ok(ok(Value::Unit))
}
None => Ok(err(Value::Str(format!(
"log.set_level: unknown level `{s}`; expected debug|info|warn|error")))),
}
}
"set_format" => {
let s = expect_str(args.first())?;
let fmt = match s {
"text" => LogFormat::Text,
"json" => LogFormat::Json,
other => return Ok(err(Value::Str(format!(
"log.set_format: unknown format `{other}`; expected text|json")))),
};
log_state().lock().unwrap().format = fmt;
Ok(ok(Value::Unit))
}
"set_sink" => {
let path = expect_str(args.first())?;
if path == "-" {
log_state().lock().unwrap().sink = LogSink::Stderr;
return Ok(ok(Value::Unit));
}
if let Err(e) = self.ensure_fs_write_path(path) {
return Ok(err(Value::Str(e)));
}
match std::fs::OpenOptions::new()
.create(true).append(true).open(path)
{
Ok(f) => {
log_state().lock().unwrap().sink = LogSink::File(std::sync::Arc::new(Mutex::new(f)));
Ok(ok(Value::Unit))
}
Err(e) => Ok(err(Value::Str(format!("log.set_sink `{path}`: {e}")))),
}
}
other => Err(format!("unsupported log.{other}")),
}
}
fn dispatch_process(&mut self, op: &str, args: Vec<Value>) -> Result<Value, String> {
match op {
"spawn" => {
let cmd = expect_str(args.first())?.to_string();
let raw_args = match args.get(1) {
Some(Value::List(items)) => items.clone(),
_ => return Err("process.spawn: args must be List[Str]".into()),
};
let str_args: Result<Vec<String>, String> = raw_args.iter().map(|v| match v {
Value::Str(s) => Ok(s.clone()),
other => Err(format!("process.spawn: arg must be Str, got {other:?}")),
}).collect();
let str_args = str_args?;
let opts = match args.get(2) {
Some(Value::Record(r)) => r.clone(),
_ => return Err("process.spawn: missing or invalid opts record".into()),
};
if !self.policy.allow_proc.is_empty() {
let basename = std::path::Path::new(&cmd)
.file_name()
.and_then(|s| s.to_str())
.unwrap_or(&cmd);
if !self.policy.allow_proc.iter().any(|a| a == basename) {
return Ok(err(Value::Str(format!(
"process.spawn: `{cmd}` not in --allow-proc {:?}",
self.policy.allow_proc
))));
}
}
let mut command = std::process::Command::new(&cmd);
command.args(&str_args);
command.stdin(std::process::Stdio::piped());
command.stdout(std::process::Stdio::piped());
command.stderr(std::process::Stdio::piped());
if let Some(Value::Variant { name, args: vargs }) = opts.get("cwd") {
if name == "Some" {
if let Some(Value::Str(s)) = vargs.first() {
command.current_dir(s);
}
}
}
if let Some(Value::Map(env)) = opts.get("env") {
for (k, v) in env {
if let (lex_bytecode::MapKey::Str(ks), Value::Str(vs)) = (k, v) {
command.env(ks, vs);
}
}
}
let stdin_payload: Option<Vec<u8>> = match opts.get("stdin") {
Some(Value::Variant { name, args: vargs }) if name == "Some" => {
match vargs.first() {
Some(Value::Bytes(b)) => Some(b.clone()),
_ => None,
}
}
_ => None,
};
let mut child = match command.spawn() {
Ok(c) => c,
Err(e) => return Ok(err(Value::Str(format!("process.spawn `{cmd}`: {e}")))),
};
if let Some(payload) = stdin_payload {
if let Some(mut stdin) = child.stdin.take() {
use std::io::Write;
let _ = stdin.write_all(&payload);
}
}
let stdout = child.stdout.take().map(std::io::BufReader::new);
let stderr = child.stderr.take().map(std::io::BufReader::new);
let handle = next_process_handle();
process_registry().lock().unwrap().insert(handle, ProcessState {
child,
stdout,
stderr,
});
Ok(ok(Value::Int(handle as i64)))
}
"read_stdout_line" => Self::read_line_op(args, true),
"read_stderr_line" => Self::read_line_op(args, false),
"wait" => {
let h = expect_process_handle(args.first())?;
let arc = process_registry().lock().unwrap()
.touch_get(h)
.ok_or_else(|| "process.wait: closed or unknown ProcessHandle".to_string())?;
let status = {
let mut state = arc.lock().unwrap();
state.child.wait().map_err(|e| format!("process.wait: {e}"))?
};
process_registry().lock().unwrap().remove(h);
let mut rec = indexmap::IndexMap::new();
rec.insert("code".into(), Value::Int(status.code().unwrap_or(-1) as i64));
#[cfg(unix)]
{
use std::os::unix::process::ExitStatusExt;
rec.insert("signaled".into(), Value::Bool(status.signal().is_some()));
}
#[cfg(not(unix))]
{
rec.insert("signaled".into(), Value::Bool(false));
}
Ok(Value::Record(rec))
}
"kill" => {
let h = expect_process_handle(args.first())?;
let _signal = expect_str(args.get(1))?;
let arc = process_registry().lock().unwrap()
.touch_get(h)
.ok_or_else(|| "process.kill: closed or unknown ProcessHandle".to_string())?;
let mut state = arc.lock().unwrap();
match state.child.kill() {
Ok(_) => Ok(ok(Value::Unit)),
Err(e) => Ok(err(Value::Str(format!("process.kill: {e}")))),
}
}
"run" => {
let cmd = expect_str(args.first())?.to_string();
let raw_args = match args.get(1) {
Some(Value::List(items)) => items.clone(),
_ => return Err("process.run: args must be List[Str]".into()),
};
let str_args: Result<Vec<String>, String> = raw_args.iter().map(|v| match v {
Value::Str(s) => Ok(s.clone()),
other => Err(format!("process.run: arg must be Str, got {other:?}")),
}).collect();
let str_args = str_args?;
if !self.policy.allow_proc.is_empty() {
let basename = std::path::Path::new(&cmd)
.file_name()
.and_then(|s| s.to_str())
.unwrap_or(&cmd);
if !self.policy.allow_proc.iter().any(|a| a == basename) {
return Ok(err(Value::Str(format!(
"process.run: `{cmd}` not in --allow-proc {:?}",
self.policy.allow_proc
))));
}
}
match std::process::Command::new(&cmd).args(&str_args).output() {
Ok(o) => {
let mut rec = indexmap::IndexMap::new();
rec.insert("stdout".into(), Value::Str(
String::from_utf8_lossy(&o.stdout).to_string()));
rec.insert("stderr".into(), Value::Str(
String::from_utf8_lossy(&o.stderr).to_string()));
rec.insert("exit_code".into(), Value::Int(
o.status.code().unwrap_or(-1) as i64));
Ok(ok(Value::Record(rec)))
}
Err(e) => Ok(err(Value::Str(format!("process.run `{cmd}`: {e}")))),
}
}
other => Err(format!("unsupported process.{other}")),
}
}
/// Reads one line from a spawned child's stdout (`is_stdout`) or stderr.
///
/// Returns `some(line)` with the trailing `\n` / `\r\n` removed. Returns
/// `none()` at EOF or when that stream is not available.
///
/// Bytes that are not valid UTF-8 are replaced with U+FFFD, matching
/// `process.run`, which decodes output lossily. Previously `read_line` failed the
/// whole call on such bytes (and had already consumed them).
///
/// # Errors
/// Returns `Err` for a bad, closed or unknown handle, or an I/O failure.
fn read_line_op(args: Vec<Value>, is_stdout: bool) -> Result<Value, String> {
    let stream = if is_stdout { "stdout" } else { "stderr" };
    let h = expect_process_handle(args.first())?;
    let arc = process_registry().lock().unwrap()
        .touch_get(h)
        .ok_or_else(|| format!("process.read_{stream}_line: closed or unknown ProcessHandle"))?;
    let mut state = arc.lock().unwrap();
    let reader_opt = if is_stdout {
        state.stdout.as_mut().map(|r| -> &mut dyn std::io::BufRead { r })
    } else {
        state.stderr.as_mut().map(|r| -> &mut dyn std::io::BufRead { r })
    };
    let Some(reader) = reader_opt else {
        return Ok(none());
    };
    let mut buf: Vec<u8> = Vec::new();
    match reader.read_until(b'\n', &mut buf) {
        Ok(0) => Ok(none()),
        Ok(_) => {
            if buf.last() == Some(&b'\n') {
                buf.pop();
            }
            if buf.last() == Some(&b'\r') {
                buf.pop();
            }
            Ok(some(Value::Str(String::from_utf8_lossy(&buf).into_owned())))
        }
        Err(e) => Err(format!("process.read_{stream}_line: {e}")),
    }
}
/// Handles the `fs.*` operations.
///
/// Read-side ops are checked with `ensure_fs_walk_path`: `exists`, `is_file`,
/// `is_dir`, `stat`, `list_dir`, `walk`, `glob`, and the source of `copy`.
/// Write-side ops are checked with `ensure_fs_write_path`: `mkdir_p`, `remove`,
/// and the destination of `copy`.
///
/// The predicates `exists`/`is_file`/`is_dir` otherwise return a bare `Bool`,
/// so a policy violation is raised as `Err` (like `io.read` does) rather than
/// returned as an `err(..)` value of a different type. All other ops report
/// denials and I/O failures as `err(Str)` values. `glob` silently drops matches
/// outside the allow-list.
///
/// # Errors
/// Also returns `Err` for malformed arguments and unknown operations.
fn dispatch_fs(&mut self, op: &str, args: Vec<Value>) -> Result<Value, String> {
    match op {
        "exists" | "is_file" | "is_dir" => {
            let path = expect_str(args.first())?;
            self.ensure_fs_walk_path(path)?;
            let p = std::path::Path::new(path);
            let answer = match op {
                "exists" => p.exists(),
                "is_file" => p.is_file(),
                _ => p.is_dir(),
            };
            Ok(Value::Bool(answer))
        }
        "stat" => {
            let path = expect_str(args.first())?;
            if let Err(e) = self.ensure_fs_walk_path(path) {
                return Ok(err(Value::Str(e)));
            }
            match std::fs::metadata(path) {
                Ok(md) => {
                    // Seconds since the Unix epoch; 0 when the platform can't say.
                    let mtime = md
                        .modified()
                        .ok()
                        .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
                        .map(|d| d.as_secs() as i64)
                        .unwrap_or(0);
                    let mut rec = indexmap::IndexMap::new();
                    rec.insert("size".into(), Value::Int(md.len() as i64));
                    rec.insert("mtime".into(), Value::Int(mtime));
                    rec.insert("is_dir".into(), Value::Bool(md.is_dir()));
                    rec.insert("is_file".into(), Value::Bool(md.is_file()));
                    Ok(ok(Value::Record(rec)))
                }
                Err(e) => Ok(err(Value::Str(format!("fs.stat `{path}`: {e}")))),
            }
        }
        "list_dir" => {
            let path = expect_str(args.first())?;
            if let Err(e) = self.ensure_fs_walk_path(path) {
                return Ok(err(Value::Str(e)));
            }
            let rd = match std::fs::read_dir(path) {
                Ok(rd) => rd,
                Err(e) => return Ok(err(Value::Str(format!("fs.list_dir `{path}`: {e}")))),
            };
            let mut entries: Vec<Value> = Vec::new();
            for ent in rd {
                match ent {
                    Ok(e) => entries.push(Value::Str(e.path().to_string_lossy().into_owned())),
                    Err(e) => return Ok(err(Value::Str(format!("fs.list_dir: {e}")))),
                }
            }
            Ok(ok(Value::List(entries)))
        }
        "walk" => {
            let path = expect_str(args.first())?;
            if let Err(e) = self.ensure_fs_walk_path(path) {
                return Ok(err(Value::Str(e)));
            }
            let mut paths: Vec<Value> = Vec::new();
            for ent in walkdir::WalkDir::new(path) {
                match ent {
                    Ok(e) => paths.push(Value::Str(e.path().to_string_lossy().into_owned())),
                    Err(e) => return Ok(err(Value::Str(format!("fs.walk: {e}")))),
                }
            }
            Ok(ok(Value::List(paths)))
        }
        "glob" => {
            let pattern = expect_str(args.first())?;
            let entries = match glob::glob(pattern) {
                Ok(e) => e,
                Err(e) => return Ok(err(Value::Str(format!("fs.glob: {e}")))),
            };
            let mut paths: Vec<Value> = Vec::new();
            for ent in entries {
                match ent {
                    Ok(p) => {
                        let s = p.to_string_lossy().into_owned();
                        // Same allow-list rule as every other read-side fs op.
                        if self.ensure_fs_walk_path(&s).is_ok() {
                            paths.push(Value::Str(s));
                        }
                    }
                    Err(e) => return Ok(err(Value::Str(format!("fs.glob: {e}")))),
                }
            }
            Ok(ok(Value::List(paths)))
        }
        "mkdir_p" => {
            let path = expect_str(args.first())?;
            if let Err(e) = self.ensure_fs_write_path(path) {
                return Ok(err(Value::Str(e)));
            }
            match std::fs::create_dir_all(path) {
                Ok(_) => Ok(ok(Value::Unit)),
                Err(e) => Ok(err(Value::Str(format!("fs.mkdir_p `{path}`: {e}")))),
            }
        }
        "remove" => {
            let path = expect_str(args.first())?;
            if let Err(e) = self.ensure_fs_write_path(path) {
                return Ok(err(Value::Str(e)));
            }
            let p = std::path::Path::new(path);
            let result = if p.is_dir() {
                std::fs::remove_dir_all(p)
            } else {
                std::fs::remove_file(p)
            };
            match result {
                Ok(_) => Ok(ok(Value::Unit)),
                Err(e) => Ok(err(Value::Str(format!("fs.remove `{path}`: {e}")))),
            }
        }
        "copy" => {
            let src = expect_str(args.first())?;
            let dst = expect_str(args.get(1))?;
            if let Err(e) = self.ensure_fs_walk_path(src) {
                return Ok(err(Value::Str(e)));
            }
            if let Err(e) = self.ensure_fs_write_path(dst) {
                return Ok(err(Value::Str(e)));
            }
            match std::fs::copy(src, dst) {
                Ok(_) => Ok(ok(Value::Unit)),
                Err(e) => Ok(err(Value::Str(format!("fs.copy {src} -> {dst}: {e}")))),
            }
        }
        other => Err(format!("unsupported fs.{other}")),
    }
}
fn ensure_fs_walk_path(&self, path: &str) -> Result<(), String> {
if self.policy.allow_fs_read.is_empty() {
return Ok(());
}
let p = std::path::Path::new(path);
if self.policy.allow_fs_read.iter().any(|a| p.starts_with(a)) {
Ok(())
} else {
Err(format!("fs path `{path}` outside --allow-fs-read"))
}
}
fn ensure_fs_write_path(&self, path: &str) -> Result<(), String> {
if self.policy.allow_fs_write.is_empty() {
return Ok(());
}
let p = std::path::Path::new(path);
if self.policy.allow_fs_write.iter().any(|a| p.starts_with(a)) {
Ok(())
} else {
Err(format!("fs path `{path}` outside --allow-fs-write"))
}
}
/// Succeeds when the host of `url` is permitted by `--allow-net-host`.
///
/// An empty allow-list permits every host. DNS names are case-insensitive, so
/// hosts are compared ASCII case-insensitively: `https://API.example.com/`
/// matches an entry `api.example.com`. A URL whose host cannot be extracted
/// is checked as the empty host.
///
/// # Errors
/// Returns a message naming the host and the allow-list.
fn ensure_host_allowed(&self, url: &str) -> Result<(), String> {
    if self.policy.allow_net_host.is_empty() {
        return Ok(());
    }
    let host = extract_host(url).unwrap_or("");
    if self.policy.allow_net_host.iter().any(|h| h.eq_ignore_ascii_case(host)) {
        Ok(())
    } else {
        Err(format!(
            "net call to host `{host}` not in --allow-net-host {:?}",
            self.policy.allow_net_host,
        ))
    }
}
}
/// Extracts the host from an `http://` or `https://` URL.
///
/// The scheme is matched ASCII case-insensitively; any other scheme yields
/// `None`. The authority ends at the first `/`, `?` or `#`. A `user:pass@`
/// userinfo prefix is skipped and a trailing `:port` is removed. Bracketed IPv6
/// literals are returned with their brackets (`[::1]`).
///
/// This feeds the `--allow-net-host` check, so it must agree with the host the
/// HTTP client actually connects to. In `http://allowed.com:1@evil.com/` that
/// host is `evil.com`; the previous implementation returned `allowed.com`.
fn extract_host(url: &str) -> Option<&str> {
    let rest = ["http://", "https://"].iter().find_map(|scheme| {
        url.get(..scheme.len())
            .filter(|head| head.eq_ignore_ascii_case(scheme))
            .map(|_| &url[scheme.len()..])
    })?;
    let authority_end = rest
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(rest.len());
    let authority = &rest[..authority_end];
    // Userinfo may itself contain `:` and `@`; the host starts after the last `@`.
    let host_port = match authority.rfind('@') {
        Some(i) => &authority[i + 1..],
        None => authority,
    };
    if host_port.starts_with('[') {
        // IPv6 literal: the host runs through the closing bracket.
        return Some(match host_port.find(']') {
            Some(i) => &host_port[..=i],
            None => host_port,
        });
    }
    Some(match host_port.rsplit_once(':') {
        Some((h, _)) => h,
        None => host_port,
    })
}
impl EffectHandler for DefaultHandler {
    /// Routes effect `kind.op(args)` to its implementation after the policy check.
    ///
    /// Pure builtins are tried first and need no capability. Some kinds map onto
    /// specific `--allow-effects` names:
    /// * `process` needs `proc`.
    /// * `log.set_sink` needs `io` and `fs_write`.
    /// * `fs.copy` needs `fs_walk` and `fs_write`.
    /// * `datetime.now` needs `time`; `crypto.random` needs `random`.
    /// * `agent.*` needs an `llm_*`, `a2a` or `mcp` kind.
    /// * `http.send/get/post` need `net`.
    ///
    /// Every other effect requires its own kind name.
    ///
    /// # Errors
    /// Returns `Err` for missing capabilities, malformed arguments and unknown
    /// effects. Recoverable runtime failures (I/O, network, database) are returned
    /// as `err(..)` result values instead.
    fn dispatch(&mut self, kind: &str, op: &str, args: Vec<Value>) -> Result<Value, String> {
        if let Some(r) = try_pure_builtin(kind, op, &args) {
            return r;
        }
        if kind == "process" {
            self.ensure_kind_allowed("proc")?;
            return self.dispatch_process(op, args);
        }
        if kind == "log" {
            let effect_kind = match op {
                "debug" | "info" | "warn" | "error" => "log",
                "set_level" | "set_format" => "io",
                "set_sink" => {
                    // Redirecting logs to a file also needs filesystem write access.
                    self.ensure_kind_allowed("io")?;
                    self.ensure_kind_allowed("fs_write")?;
                    return self.dispatch_log(op, args);
                }
                other => return Err(format!("unsupported log.{other}")),
            };
            self.ensure_kind_allowed(effect_kind)?;
            return self.dispatch_log(op, args);
        }
        if kind == "fs" {
            let effect_kind = match op {
                "exists" | "is_file" | "is_dir" | "stat"
                | "list_dir" | "walk" | "glob" => "fs_walk",
                "mkdir_p" | "remove" => "fs_write",
                "copy" => {
                    self.ensure_kind_allowed("fs_walk")?;
                    self.ensure_kind_allowed("fs_write")?;
                    return self.dispatch_fs(op, args);
                }
                other => return Err(format!("unsupported fs.{other}")),
            };
            self.ensure_kind_allowed(effect_kind)?;
            return self.dispatch_fs(op, args);
        }
        if kind == "datetime" && op == "now" {
            self.ensure_kind_allowed("time")?;
            let now = chrono::Utc::now();
            let nanos = now.timestamp_nanos_opt().unwrap_or(i64::MAX);
            return Ok(Value::Int(nanos));
        }
        if kind == "crypto" && op == "random" {
            self.ensure_kind_allowed("random")?;
            let n = expect_int(args.first())?;
            if !(0..=1_048_576).contains(&n) {
                return Err("crypto.random: n must be in 0..=1048576".into());
            }
            use rand::{rngs::OsRng, TryRngCore};
            let mut buf = vec![0u8; n as usize];
            OsRng.try_fill_bytes(&mut buf)
                .map_err(|e| format!("crypto.random: OS RNG: {e}"))?;
            return Ok(Value::Bytes(buf));
        }
        if kind == "agent" {
            let effect_kind = match op {
                "local_complete" => "llm_local",
                "cloud_complete" => "llm_cloud",
                "send_a2a" => "a2a",
                "call_mcp" => "mcp",
                other => return Err(format!("unsupported agent.{other}")),
            };
            self.ensure_kind_allowed(effect_kind)?;
            if op == "call_mcp" {
                return Ok(dispatch_call_mcp(args));
            }
            return Ok(ok(Value::Str(format!("<{effect_kind} stub>"))));
        }
        if kind == "http" && matches!(op, "send" | "get" | "post") {
            self.ensure_kind_allowed("net")?;
            return match op {
                "send" => {
                    let req = expect_record(args.first())?;
                    Ok(http_send_record(self, req))
                }
                "get" => {
                    let url = expect_str(args.first())?.to_string();
                    self.ensure_host_allowed(&url)?;
                    Ok(http_send_simple("GET", &url, None, "", None))
                }
                "post" => {
                    let url = expect_str(args.first())?.to_string();
                    let body = expect_bytes(args.get(1))?.clone();
                    let content_type = expect_str(args.get(2))?.to_string();
                    self.ensure_host_allowed(&url)?;
                    Ok(http_send_simple("POST", &url, Some(body), &content_type, None))
                }
                _ => unreachable!(),
            };
        }
        self.ensure_kind_allowed(kind)?;
        match (kind, op) {
            ("io", "print") => {
                let line = expect_str(args.first())?;
                self.sink.print_line(line);
                Ok(Value::Unit)
            }
            ("io", "read") => {
                let path = expect_str(args.first())?.to_string();
                let resolved = self.resolve_read_path(&path);
                if !self.policy.allow_fs_read.is_empty() {
                    // `Path::starts_with` is purely lexical, so a `..` component could
                    // climb out of an allowed root; refuse such paths outright.
                    let climbs = resolved
                        .components()
                        .any(|c| matches!(c, std::path::Component::ParentDir));
                    if climbs
                        || !self.policy.allow_fs_read.iter().any(|a| resolved.starts_with(a))
                    {
                        return Err(format!("read of `{path}` outside --allow-fs-read"));
                    }
                }
                match std::fs::read_to_string(&resolved) {
                    Ok(s) => Ok(ok(Value::Str(s))),
                    Err(e) => Ok(err(Value::Str(format!("{e}")))),
                }
            }
            ("io", "write") => {
                let path = expect_str(args.first())?.to_string();
                let contents = expect_str(args.get(1))?.to_string();
                // Same allow-list rule as `fs.*`; the message is kept for callers.
                if self.ensure_fs_write_path(&path).is_err() {
                    return Err(format!("write to `{path}` outside --allow-fs-write"));
                }
                match std::fs::write(&path, contents) {
                    Ok(_) => Ok(ok(Value::Unit)),
                    Err(e) => Ok(err(Value::Str(format!("{e}")))),
                }
            }
            ("time", "now") => {
                let secs = SystemTime::now().duration_since(UNIX_EPOCH)
                    .map_err(|e| format!("time: {e}"))?.as_secs();
                Ok(Value::Int(secs as i64))
            }
            ("rand", "int_in") => {
                let lo = expect_int(args.first())?;
                let hi = expect_int(args.get(1))?;
                // Deterministic stub returning the midpoint. Widen to i128 so `lo + hi`
                // cannot overflow; the midpoint of two i64 values always fits in i64.
                let mid = (i128::from(lo) + i128::from(hi)) / 2;
                Ok(Value::Int(mid as i64))
            }
            ("budget", _) => {
                Ok(Value::Unit)
            }
            ("net", "get") => {
                let url = expect_str(args.first())?.to_string();
                self.ensure_host_allowed(&url)?;
                Ok(http_request("GET", &url, None))
            }
            ("net", "post") => {
                let url = expect_str(args.first())?.to_string();
                let body = expect_str(args.get(1))?.to_string();
                self.ensure_host_allowed(&url)?;
                Ok(http_request("POST", &url, Some(&body)))
            }
            ("net", "serve") => {
                let port = match args.first() {
                    Some(Value::Int(n)) if (0..=65535).contains(n) => *n as u16,
                    _ => return Err("net.serve(port, handler): port must be Int 0..=65535".into()),
                };
                let handler_name = expect_str(args.get(1))?.to_string();
                let program = self.program.clone()
                    .ok_or_else(|| "net.serve requires a Program reference; use DefaultHandler::with_program".to_string())?;
                let policy = self.policy.clone();
                serve_http(port, handler_name, program, policy, None)
            }
            ("net", "serve_tls") => {
                let port = match args.first() {
                    Some(Value::Int(n)) if (0..=65535).contains(n) => *n as u16,
                    _ => return Err("net.serve_tls(port, cert, key, handler): port must be Int 0..=65535".into()),
                };
                let cert_path = expect_str(args.get(1))?.to_string();
                let key_path = expect_str(args.get(2))?.to_string();
                let handler_name = expect_str(args.get(3))?.to_string();
                let program = self.program.clone()
                    .ok_or_else(|| "net.serve_tls requires a Program reference".to_string())?;
                let policy = self.policy.clone();
                let cert = std::fs::read(&cert_path)
                    .map_err(|e| format!("net.serve_tls: read cert {cert_path}: {e}"))?;
                let key = std::fs::read(&key_path)
                    .map_err(|e| format!("net.serve_tls: read key {key_path}: {e}"))?;
                serve_http(port, handler_name, program, policy, Some(TlsConfig { cert, key }))
            }
            ("net", "serve_ws") => {
                let port = match args.first() {
                    Some(Value::Int(n)) if (0..=65535).contains(n) => *n as u16,
                    _ => return Err("net.serve_ws(port, on_message): port must be Int 0..=65535".into()),
                };
                let handler_name = expect_str(args.get(1))?.to_string();
                let program = self.program.clone()
                    .ok_or_else(|| "net.serve_ws requires a Program reference".to_string())?;
                let policy = self.policy.clone();
                let registry = Arc::new(crate::ws::ChatRegistry::default());
                crate::ws::serve_ws(port, handler_name, program, policy, registry)
            }
            ("chat", "broadcast") => {
                let registry = self.chat_registry.as_ref()
                    .ok_or_else(|| "chat.broadcast called outside a net.serve_ws handler".to_string())?;
                let room = expect_str(args.first())?;
                let body = expect_str(args.get(1))?;
                crate::ws::chat_broadcast(registry, room, body);
                Ok(Value::Unit)
            }
            ("chat", "send") => {
                let registry = self.chat_registry.as_ref()
                    .ok_or_else(|| "chat.send called outside a net.serve_ws handler".to_string())?;
                let conn_id = match args.first() {
                    Some(Value::Int(n)) if *n >= 0 => *n as u64,
                    _ => return Err("chat.send: conn_id must be non-negative Int".into()),
                };
                let body = expect_str(args.get(1))?;
                Ok(Value::Bool(crate::ws::chat_send(registry, conn_id, body)))
            }
            ("kv", "open") => {
                let path = expect_str(args.first())?.to_string();
                if self.ensure_fs_write_path(&path).is_err() {
                    return Ok(err(Value::Str(format!(
                        "kv.open: `{path}` outside --allow-fs-write"))));
                }
                match sled::open(&path) {
                    Ok(db) => {
                        let handle = next_kv_handle();
                        kv_registry().lock().unwrap().insert(handle, db);
                        Ok(ok(Value::Int(handle as i64)))
                    }
                    Err(e) => Ok(err(Value::Str(format!("kv.open: {e}")))),
                }
            }
            ("kv", "close") => {
                let h = expect_kv_handle(args.first())?;
                kv_registry().lock().unwrap().remove(h);
                Ok(Value::Unit)
            }
            ("kv", "get") => {
                let h = expect_kv_handle(args.first())?;
                let key = expect_str(args.get(1))?;
                let mut reg = kv_registry().lock().unwrap();
                let db = reg.touch_get(h).ok_or_else(|| "kv.get: closed or unknown Kv handle".to_string())?;
                match db.get(key.as_bytes()) {
                    Ok(Some(ivec)) => Ok(some(Value::Bytes(ivec.to_vec()))),
                    Ok(None) => Ok(none()),
                    Err(e) => Err(format!("kv.get: {e}")),
                }
            }
            ("kv", "put") => {
                let h = expect_kv_handle(args.first())?;
                let key = expect_str(args.get(1))?.to_string();
                let val = expect_bytes(args.get(2))?.clone();
                let mut reg = kv_registry().lock().unwrap();
                let db = reg.touch_get(h).ok_or_else(|| "kv.put: closed or unknown Kv handle".to_string())?;
                match db.insert(key.as_bytes(), val) {
                    Ok(_) => Ok(ok(Value::Unit)),
                    Err(e) => Ok(err(Value::Str(format!("kv.put: {e}")))),
                }
            }
            ("kv", "delete") => {
                let h = expect_kv_handle(args.first())?;
                let key = expect_str(args.get(1))?;
                let mut reg = kv_registry().lock().unwrap();
                let db = reg.touch_get(h).ok_or_else(|| "kv.delete: closed or unknown Kv handle".to_string())?;
                match db.remove(key.as_bytes()) {
                    Ok(_) => Ok(ok(Value::Unit)),
                    Err(e) => Ok(err(Value::Str(format!("kv.delete: {e}")))),
                }
            }
            ("kv", "contains") => {
                let h = expect_kv_handle(args.first())?;
                let key = expect_str(args.get(1))?;
                let mut reg = kv_registry().lock().unwrap();
                let db = reg.touch_get(h).ok_or_else(|| "kv.contains: closed or unknown Kv handle".to_string())?;
                match db.contains_key(key.as_bytes()) {
                    Ok(present) => Ok(Value::Bool(present)),
                    Err(e) => Err(format!("kv.contains: {e}")),
                }
            }
            ("kv", "list_prefix") => {
                let h = expect_kv_handle(args.first())?;
                let prefix = expect_str(args.get(1))?;
                let mut reg = kv_registry().lock().unwrap();
                let db = reg.touch_get(h).ok_or_else(|| "kv.list_prefix: closed or unknown Kv handle".to_string())?;
                let mut keys: Vec<Value> = Vec::new();
                for kv in db.scan_prefix(prefix.as_bytes()) {
                    let (k, _) = kv.map_err(|e| format!("kv.list_prefix: {e}"))?;
                    let s = String::from_utf8_lossy(&k).to_string();
                    keys.push(Value::Str(s));
                }
                Ok(Value::List(keys))
            }
            ("sql", "open") => {
                let path = expect_str(args.first())?.to_string();
                if path != ":memory:" && self.ensure_fs_write_path(&path).is_err() {
                    return Ok(err(Value::Str(format!(
                        "sql.open: `{path}` outside --allow-fs-write"))));
                }
                match rusqlite::Connection::open(&path) {
                    Ok(conn) => {
                        let handle = next_sql_handle();
                        sql_registry().lock().unwrap().insert(handle, conn);
                        Ok(ok(Value::Int(handle as i64)))
                    }
                    Err(e) => Ok(err(Value::Str(format!("sql.open: {e}")))),
                }
            }
            ("sql", "close") => {
                let h = expect_sql_handle(args.first())?;
                sql_registry().lock().unwrap().remove(h);
                Ok(Value::Unit)
            }
            ("sql", "exec") => {
                let h = expect_sql_handle(args.first())?;
                let stmt = expect_str(args.get(1))?.to_string();
                let params = expect_str_list(args.get(2))?;
                let arc = sql_registry().lock().unwrap()
                    .touch_get(h)
                    .ok_or_else(|| "sql.exec: closed or unknown Db handle".to_string())?;
                let conn = arc.lock().unwrap();
                let bind: Vec<&dyn rusqlite::ToSql> = params.iter()
                    .map(|s| s as &dyn rusqlite::ToSql)
                    .collect();
                match conn.execute(&stmt, rusqlite::params_from_iter(bind.iter())) {
                    Ok(n) => Ok(ok(Value::Int(n as i64))),
                    Err(e) => Ok(err(Value::Str(format!("sql.exec: {e}")))),
                }
            }
            ("sql", "query") => {
                let h = expect_sql_handle(args.first())?;
                let stmt_str = expect_str(args.get(1))?.to_string();
                let params = expect_str_list(args.get(2))?;
                let arc = sql_registry().lock().unwrap()
                    .touch_get(h)
                    .ok_or_else(|| "sql.query: closed or unknown Db handle".to_string())?;
                let conn = arc.lock().unwrap();
                Ok(sql_run_query(&conn, &stmt_str, &params))
            }
            ("proc", "spawn") => {
                let cmd = expect_str(args.first())?.to_string();
                let raw_args = match args.get(1) {
                    Some(Value::List(items)) => items,
                    Some(other) => return Err(format!(
                        "proc.spawn: args must be List[Str], got {other:?}")),
                    None => return Err("proc.spawn: missing args list".into()),
                };
                let str_args: Vec<String> = raw_args.iter().map(|v| match v {
                    Value::Str(s) => Ok(s.clone()),
                    other => Err(format!("proc.spawn: arg must be Str, got {other:?}")),
                }).collect::<Result<Vec<_>, _>>()?;
                if !self.policy.allow_proc.is_empty() {
                    let basename = std::path::Path::new(&cmd)
                        .file_name()
                        .and_then(|s| s.to_str())
                        .unwrap_or(&cmd);
                    if !self.policy.allow_proc.iter().any(|a| a == basename) {
                        return Ok(err(Value::Str(format!(
                            "proc.spawn: `{cmd}` not in --allow-proc {:?}",
                            self.policy.allow_proc
                        ))));
                    }
                }
                if str_args.len() > 1024 {
                    return Ok(err(Value::Str(
                        "proc.spawn: arg-count exceeds 1024".into())));
                }
                if str_args.iter().any(|a| a.len() > 65_536) {
                    return Ok(err(Value::Str(
                        "proc.spawn: per-arg length exceeds 64 KiB".into())));
                }
                let output = std::process::Command::new(&cmd)
                    .args(&str_args)
                    .output();
                match output {
                    Ok(o) => {
                        let mut rec = indexmap::IndexMap::new();
                        rec.insert("stdout".into(), Value::Str(
                            String::from_utf8_lossy(&o.stdout).to_string()));
                        rec.insert("stderr".into(), Value::Str(
                            String::from_utf8_lossy(&o.stderr).to_string()));
                        rec.insert("exit_code".into(), Value::Int(
                            o.status.code().unwrap_or(-1) as i64));
                        Ok(ok(Value::Record(rec)))
                    }
                    Err(e) => Ok(err(Value::Str(format!("spawn `{cmd}`: {e}")))),
                }
            }
            other => Err(format!("unsupported effect {}.{}", other.0, other.1)),
        }
    }
}
/// Certificate and private key handed to `serve_http` for `net.serve_tls`.
/// NOTE(review): the bytes are passed unchanged to `tiny_http::SslConfig`,
/// presumably PEM-encoded — confirm the expected format.
pub struct TlsConfig {
/// Certificate bytes, as read from the path given to `net.serve_tls`.
pub cert: Vec<u8>,
/// Private-key bytes, as read from the path given to `net.serve_tls`.
pub key: Vec<u8>,
}
fn serve_http(
port: u16,
handler_name: String,
program: Arc<Program>,
policy: Policy,
tls: Option<TlsConfig>,
) -> Result<Value, String> {
let (server, scheme) = match tls {
None => (
tiny_http::Server::http(("127.0.0.1", port))
.map_err(|e| format!("net.serve bind {port}: {e}"))?,
"http",
),
Some(cfg) => {
let ssl = tiny_http::SslConfig {
certificate: cfg.cert,
private_key: cfg.key,
};
(
tiny_http::Server::https(("127.0.0.1", port), ssl)
.map_err(|e| format!("net.serve_tls bind {port}: {e}"))?,
"https",
)
}
};
eprintln!("net.serve: listening on {scheme}://127.0.0.1:{port}");
for req in server.incoming_requests() {
let program = Arc::clone(&program);
let policy = policy.clone();
let handler_name = handler_name.clone();
std::thread::spawn(move || handle_request(req, program, policy, handler_name));
}
Ok(Value::Unit)
}
/// Serves one HTTP request by calling `handler_name` in a fresh VM.
///
/// The VM gets its own `DefaultHandler` built from `policy` with `program`
/// attached. A handler result is turned into a response by `unpack_response`.
/// A VM error becomes a `500` whose body is `internal error: <error>`. Failures
/// to send the response are ignored.
fn handle_request(
    mut req: tiny_http::Request,
    program: Arc<Program>,
    policy: Policy,
    handler_name: String,
) {
    let lex_req = build_request_value(&mut req);
    let effects = DefaultHandler::new(policy).with_program(Arc::clone(&program));
    let mut vm = Vm::with_handler(&program, Box::new(effects));
    let (status, body) = match vm.call(&handler_name, vec![lex_req]) {
        Ok(resp) => unpack_response(&resp),
        Err(e) => (500, format!("internal error: {e}")),
    };
    let response = tiny_http::Response::from_string(body).with_status_code(status);
    let _ = req.respond(response);
}
/// Converts an incoming tiny_http request into the Lex record
/// `{ method, path, query, body }`.
///
/// * `method` is the method as sent (`GET`, `PROPFIND`, ...) via its `Display`
///   form. The previous `Debug`-plus-uppercase form garbled non-standard
///   methods into `NONSTANDARD("...")`.
/// * `query` is everything after the first `?`, or empty.
/// * `body` is decoded as UTF-8 with invalid sequences replaced by U+FFFD,
///   instead of silently becoming empty. A read error keeps whatever bytes
///   arrived before it.
fn build_request_value(req: &mut tiny_http::Request) -> Value {
    let method = req.method().to_string();
    let (path, query) = match req.url().split_once('?') {
        Some((p, q)) => (p.to_string(), q.to_string()),
        None => (req.url().to_string(), String::new()),
    };
    let mut raw = Vec::new();
    // Best effort: a broken client connection should not abort the handler call.
    let _ = req.as_reader().read_to_end(&mut raw);
    let body = String::from_utf8_lossy(&raw).into_owned();
    let mut rec = indexmap::IndexMap::new();
    rec.insert("method".into(), Value::Str(method));
    rec.insert("path".into(), Value::Str(path));
    rec.insert("query".into(), Value::Str(query));
    rec.insert("body".into(), Value::Str(body));
    Value::Record(rec)
}
/// Extracts `(status, body)` from the value a Lex HTTP handler returned.
///
/// A missing or non-`Int` `status` defaults to `200`. A missing or non-`Str`
/// `body` defaults to empty. A status outside the three-digit HTTP range
/// `100..=999` yields `500` with an explanatory body, instead of being silently
/// truncated (`70000 as u16` would become `4464`, and `-1` would become `65535`).
/// Non-record values also produce a `500`.
fn unpack_response(v: &Value) -> (u16, String) {
    let Value::Record(rec) = v else {
        return (500, format!("handler returned non-record: {v:?}"));
    };
    let body = rec
        .get("body")
        .and_then(|b| match b {
            Value::Str(s) => Some(s.clone()),
            _ => None,
        })
        .unwrap_or_default();
    let status = match rec.get("status") {
        Some(Value::Int(n)) => match u16::try_from(*n) {
            Ok(code) if (100..=999).contains(&code) => code,
            _ => return (500, format!("handler returned invalid status code: {n}")),
        },
        _ => 200,
    };
    (status, body)
}
fn http_request(method: &str, url: &str, body: Option<&str>) -> Value {
use std::time::Duration;
let agent: ureq::Agent = ureq::Agent::config_builder()
.timeout_connect(Some(Duration::from_secs(10)))
.timeout_recv_body(Some(Duration::from_secs(30)))
.timeout_send_body(Some(Duration::from_secs(10)))
.http_status_as_error(false)
.build()
.into();
let resp = match (method, body) {
("GET", _) => agent.get(url).call(),
("POST", Some(b)) => agent.post(url).send(b),
("POST", None) => agent.post(url).send(""),
(m, _) => return err_value(format!("unsupported method: {m}")),
};
match resp {
Ok(mut r) => {
let status = r.status().as_u16();
let body = r.body_mut().read_to_string().unwrap_or_default();
if (200..300).contains(&status) {
Value::Variant { name: "Ok".into(), args: vec![Value::Str(body)] }
} else {
err_value(format!("status {status}: {body}"))
}
}
Err(e) => err_value(format!("transport: {e}")),
}
}
/// Builds the `ureq` agent used for outbound HTTP requests.
///
/// Timeouts are 10 s for connect, 30 s for receiving the body and 10 s for
/// sending it. When `timeout_ms` is given it also caps the whole request.
/// Non-2xx statuses come back as ordinary responses rather than errors.
fn http_agent(timeout_ms: Option<u64>) -> ureq::Agent {
    use std::time::Duration;
    let overall = timeout_ms.map(Duration::from_millis);
    let base = ureq::Agent::config_builder()
        .timeout_connect(Some(Duration::from_secs(10)))
        .timeout_recv_body(Some(Duration::from_secs(30)))
        .timeout_send_body(Some(Duration::from_secs(10)))
        .http_status_as_error(false);
    let configured = match overall {
        Some(limit) => base.timeout_global(Some(limit)),
        None => base,
    };
    configured.build().into()
}
fn http_error_value(e: ureq::Error) -> Value {
let (ctor, payload): (&str, Option<String>) = match &e {
ureq::Error::Timeout(_) => ("TimeoutError", None),
ureq::Error::Tls(s) => ("TlsError", Some((*s).into())),
ureq::Error::Pem(p) => ("TlsError", Some(format!("{p}"))),
ureq::Error::Rustls(r) => ("TlsError", Some(format!("{r}"))),
_ => ("NetworkError", Some(format!("{e}"))),
};
let args = match payload { Some(s) => vec![Value::Str(s)], None => vec![] };
let inner = Value::Variant { name: ctor.into(), args };
Value::Variant { name: "Err".into(), args: vec![inner] }
}
/// Wraps `msg` as the result value `Err(DecodeError(msg))`.
fn http_decode_err(msg: String) -> Value {
    let payload = vec![Value::Str(msg)];
    let decode_error = Value::Variant { name: "DecodeError".into(), args: payload };
    Value::Variant { name: "Err".into(), args: vec![decode_error] }
}
/// Convenience form of `http_send_full` that sends no extra request headers.
fn http_send_simple(
    method: &str,
    url: &str,
    body: Option<Vec<u8>>,
    content_type: &str,
    timeout_ms: Option<u64>,
) -> Value {
    const NO_EXTRA_HEADERS: &[(String, String)] = &[];
    http_send_full(method, url, body, content_type, NO_EXTRA_HEADERS, timeout_ms)
}
/// Sends a `GET` or `POST` request and packages the response as a Lex value.
///
/// * `body`: the request payload. It is only used for `POST`; an absent body
///   sends zero bytes.
/// * `content_type`: sent as the `content-type` header when non-empty.
/// * `headers`: extra request headers, applied after `content_type`.
/// * `timeout_ms`: optional overall timeout (see `http_agent`).
///
/// Returns `Ok({status, headers, body})` for any HTTP status, because
/// `http_agent` disables status-as-error. It returns `Err(DecodeError(..))`
/// for an unsupported method or a failed body read, including bodies over
/// 10 MiB. Transport failures are mapped by `http_error_value`.
fn http_send_full(
    method: &str,
    url: &str,
    body: Option<Vec<u8>>,
    content_type: &str,
    headers: &[(String, String)],
    timeout_ms: Option<u64>,
) -> Value {
    let agent = http_agent(timeout_ms);
    let outcome = match method {
        "GET" => {
            let mut builder = agent.get(url);
            if !content_type.is_empty() {
                builder = builder.header("content-type", content_type);
            }
            for (name, value) in headers {
                builder = builder.header(name.as_str(), value.as_str());
            }
            builder.call()
        }
        "POST" => {
            let payload = body.unwrap_or_default();
            let mut builder = agent.post(url);
            if !content_type.is_empty() {
                builder = builder.header("content-type", content_type);
            }
            for (name, value) in headers {
                builder = builder.header(name.as_str(), value.as_str());
            }
            builder.send(&payload[..])
        }
        unsupported => {
            return http_decode_err(format!("unsupported method: {unsupported}"));
        }
    };

    let mut response = match outcome {
        Ok(response) => response,
        Err(e) => return http_error_value(e),
    };
    let status = i64::from(response.status().as_u16());
    let response_headers = collect_response_headers(response.headers());
    // Cap the buffered body so one huge response cannot exhaust memory.
    let body_bytes = match response
        .body_mut()
        .with_config()
        .limit(10 * 1024 * 1024)
        .read_to_vec()
    {
        Ok(bytes) => bytes,
        Err(e) => return http_decode_err(format!("body read: {e}")),
    };

    let mut record = indexmap::IndexMap::new();
    record.insert("status".into(), Value::Int(status));
    record.insert("headers".into(), Value::Map(response_headers));
    record.insert("body".into(), Value::Bytes(body_bytes));
    Value::Variant { name: "Ok".into(), args: vec![Value::Record(record)] }
}
/// Converts response headers into a Lex `Map[Str, Str]` keyed by header name.
///
/// Values are decoded lossily. `HeaderValue::to_str` rejects any
/// non-visible-ASCII byte (e.g. Latin-1 obs-text), which previously turned
/// the whole value into `""`. Lossy decoding keeps the value and substitutes
/// U+FFFD only for invalid UTF-8 sequences.
///
/// When a header name repeats, the last occurrence wins.
fn collect_response_headers(
    headers: &ureq::http::HeaderMap,
) -> std::collections::BTreeMap<lex_bytecode::MapKey, Value> {
    let mut out = std::collections::BTreeMap::new();
    for (name, value) in headers.iter() {
        let text = String::from_utf8_lossy(value.as_bytes()).into_owned();
        out.insert(lex_bytecode::MapKey::Str(name.as_str().to_string()), Value::Str(text));
    }
    out
}
/// Performs the request described by a Lex `HttpRequest` record.
///
/// Expected fields:
/// * `method: Str`
/// * `url: Str`
/// * `body: Option[Bytes]`
/// * `timeout_ms: Option[Int]`, which must be non-negative
/// * `headers: Map[Str, Str]`
///
/// The URL is passed to `DefaultHandler::ensure_host_allowed` before the
/// remaining fields are decoded. No `content-type` is added implicitly; set
/// it through `headers`.
///
/// Returns `Err(DecodeError(msg))` for a malformed record or a rejected host.
/// Otherwise it returns whatever `http_send_full` produces.
fn http_send_record(handler: &DefaultHandler, req: &indexmap::IndexMap<String, Value>) -> Value {
    let method = match req.get("method") {
        Some(Value::Str(s)) => s.clone(),
        _ => return http_decode_err("HttpRequest.method must be Str".into()),
    };
    let url = match req.get("url") {
        Some(Value::Str(s)) => s.clone(),
        _ => return http_decode_err("HttpRequest.url must be Str".into()),
    };
    if let Err(e) = handler.ensure_host_allowed(&url) {
        return http_decode_err(e);
    }
    let body = match req.get("body") {
        Some(Value::Variant { name, .. }) if name == "None" => None,
        Some(Value::Variant { name, args }) if name == "Some" => match args.as_slice() {
            [Value::Bytes(b)] => Some(b.clone()),
            _ => return http_decode_err("HttpRequest.body Some payload must be Bytes".into()),
        },
        _ => return http_decode_err("HttpRequest.body must be Option[Bytes]".into()),
    };
    let timeout_ms = match req.get("timeout_ms") {
        Some(Value::Variant { name, .. }) if name == "None" => None,
        Some(Value::Variant { name, args }) if name == "Some" => match args.as_slice() {
            [Value::Int(n)] if *n >= 0 => Some(*n as u64),
            _ => return http_decode_err(
                "HttpRequest.timeout_ms Some payload must be a non-negative Int".into()),
        },
        _ => return http_decode_err("HttpRequest.timeout_ms must be Option[Int]".into()),
    };
    let headers: Vec<(String, String)> = match req.get("headers") {
        Some(Value::Map(m)) => {
            let mut pairs = Vec::new();
            for (k, v) in m.iter() {
                match (k, v) {
                    (lex_bytecode::MapKey::Str(name), Value::Str(value)) => {
                        pairs.push((name.clone(), value.clone()));
                    }
                    // Reject rather than silently drop. Otherwise the request
                    // would go out without headers the caller asked for.
                    _ => return http_decode_err("HttpRequest.headers must be Map[Str, Str]".into()),
                }
            }
            pairs
        }
        _ => return http_decode_err("HttpRequest.headers must be Map[Str, Str]".into()),
    };
    http_send_full(&method, &url, body, "", &headers, timeout_ms)
}
/// Borrows the field map of a `Record` argument.
///
/// Errors with a descriptive message when the argument is absent or is not
/// a record.
fn expect_record(v: Option<&Value>) -> Result<&indexmap::IndexMap<String, Value>, String> {
    let Some(value) = v else {
        return Err("missing Record argument".into());
    };
    if let Value::Record(fields) = value {
        Ok(fields)
    } else {
        Err(format!("expected Record, got {value:?}"))
    }
}
/// Returns `Err(msg)` with a plain string payload.
fn err_value(msg: String) -> Value {
    let payload = Value::Str(msg);
    Value::Variant { name: "Err".into(), args: vec![payload] }
}
/// Borrows the text of a `Str` argument, or explains why it can't.
fn expect_str(v: Option<&Value>) -> Result<&str, String> {
    let Some(value) = v else {
        return Err("missing argument".into());
    };
    match value {
        Value::Str(s) => Ok(s.as_str()),
        other => Err(format!("expected Str arg, got {other:?}")),
    }
}
/// Extracts the integer from an `Int` argument, or explains why it can't.
fn expect_int(v: Option<&Value>) -> Result<i64, String> {
    let value = v.ok_or_else(|| String::from("missing argument"))?;
    if let Value::Int(n) = value {
        Ok(*n)
    } else {
        Err(format!("expected Int arg, got {value:?}"))
    }
}
/// Wraps `v` as `Ok(v)`.
fn ok(v: Value) -> Value {
    let args = vec![v];
    Value::Variant { name: "Ok".into(), args }
}
/// Wraps `v` as `Err(v)`.
fn err(v: Value) -> Value {
    let args = vec![v];
    Value::Variant { name: "Err".into(), args }
}
/// Implements `agent.call_mcp(server, tool, args_json)`.
///
/// Every argument must be a `Str`, and `args_json` must parse as JSON. The
/// function spawns an MCP client for `server`, invokes `tool` with the parsed
/// arguments, and returns `Ok(result_json)`. If the result cannot be
/// serialized, the payload is `"null"`. Any failure is returned as
/// `Err(Str(message))`.
fn dispatch_call_mcp(args: Vec<Value>) -> Value {
    // Fetches positional argument `idx` as an owned string, or builds the
    // usage error that names the offending parameter.
    let str_arg = |idx: usize, label: &str| -> Result<String, Value> {
        match args.get(idx) {
            Some(Value::Str(s)) => Ok(s.clone()),
            _ => Err(err(Value::Str(format!(
                "agent.call_mcp(server, tool, args_json): {label} must be Str")))),
        }
    };
    let run = || -> Result<Value, Value> {
        let server = str_arg(0, "server")?;
        let tool = str_arg(1, "tool")?;
        let args_json = str_arg(2, "args_json")?;
        let parsed: serde_json::Value = serde_json::from_str(&args_json).map_err(|e| {
            err(Value::Str(format!("agent.call_mcp: args_json is not valid JSON: {e}")))
        })?;
        let mut client = crate::mcp_client::McpClient::spawn(&server)
            .map_err(|e| err(Value::Str(e)))?;
        let result = client.call_tool(&tool, parsed).map_err(|e| err(Value::Str(e)))?;
        Ok(ok(Value::Str(
            serde_json::to_string(&result).unwrap_or_else(|_| "null".into()),
        )))
    };
    run().unwrap_or_else(|failure| failure)
}
/// Wraps `v` as `Some(v)`.
fn some(v: Value) -> Value {
    let args = vec![v];
    Value::Variant { name: "Some".into(), args }
}
/// Returns the payload-free `None` variant.
fn none() -> Value {
    Value::Variant { name: "None".into(), args: Vec::new() }
}
/// Borrows the buffer of a `Bytes` argument, or explains why it can't.
fn expect_bytes(v: Option<&Value>) -> Result<&Vec<u8>, String> {
    let Some(value) = v else {
        return Err("missing argument".into());
    };
    if let Value::Bytes(bytes) = value {
        Ok(bytes)
    } else {
        Err(format!("expected Bytes arg, got {value:?}"))
    }
}
/// Reads a KV-store handle, which must be a non-negative `Int`.
fn expect_kv_handle(v: Option<&Value>) -> Result<u64, String> {
    let value = v.ok_or_else(|| String::from("missing Kv argument"))?;
    let mismatch = || format!("expected Kv handle (Int), got {value:?}");
    match value {
        // `try_from` fails exactly for negative integers.
        Value::Int(n) => u64::try_from(*n).map_err(|_| mismatch()),
        _ => Err(mismatch()),
    }
}
/// Reads a SQL database handle, which must be a non-negative `Int`.
fn expect_sql_handle(v: Option<&Value>) -> Result<u64, String> {
    let Some(value) = v else {
        return Err("missing Db argument".into());
    };
    match value {
        Value::Int(n) if *n >= 0 => Ok(*n as u64),
        _ => Err(format!("expected Db handle (Int), got {value:?}")),
    }
}
/// Copies a `List[Str]` argument into owned strings.
///
/// Fails on the first element that is not a `Str`, or when the argument is
/// missing or not a list.
fn expect_str_list(v: Option<&Value>) -> Result<Vec<String>, String> {
    let items = match v {
        Some(Value::List(items)) => items,
        Some(other) => return Err(format!("expected List[Str], got {other:?}")),
        None => return Err("missing List[Str] argument".into()),
    };
    let mut strings = Vec::new();
    for item in items.iter() {
        let Value::Str(s) = item else {
            return Err(format!("expected List[Str] element, got {item:?}"));
        };
        strings.push(s.clone());
    }
    Ok(strings)
}
fn sql_run_query(
conn: &rusqlite::Connection,
stmt_str: &str,
params: &[String],
) -> Value {
let mut stmt = match conn.prepare(stmt_str) {
Ok(s) => s,
Err(e) => return err(Value::Str(format!("sql.query: {e}"))),
};
let column_count = stmt.column_count();
let column_names: Vec<String> = (0..column_count)
.map(|i| stmt.column_name(i).unwrap_or("").to_string())
.collect();
let bind: Vec<&dyn rusqlite::ToSql> = params.iter()
.map(|s| s as &dyn rusqlite::ToSql)
.collect();
let mut rows = match stmt.query(rusqlite::params_from_iter(bind.iter())) {
Ok(r) => r,
Err(e) => return err(Value::Str(format!("sql.query: {e}"))),
};
let mut out: Vec<Value> = Vec::new();
loop {
let row = match rows.next() {
Ok(Some(r)) => r,
Ok(None) => break,
Err(e) => return err(Value::Str(format!("sql.query: {e}"))),
};
let mut rec = indexmap::IndexMap::new();
for (i, name) in column_names.iter().enumerate() {
let cell = match row.get_ref(i) {
Ok(c) => sql_value_ref_to_lex(c),
Err(e) => return err(Value::Str(format!("sql.query: column {i}: {e}"))),
};
rec.insert(name.clone(), cell);
}
out.push(Value::Record(rec));
}
ok(Value::List(out))
}
/// Converts one SQLite cell into a Lex value.
///
/// `NULL` becomes `Unit`. `TEXT` is decoded as UTF-8, with invalid sequences
/// replaced by U+FFFD. `BLOB`s are copied.
fn sql_value_ref_to_lex(v: rusqlite::types::ValueRef<'_>) -> Value {
    match v {
        rusqlite::types::ValueRef::Null => Value::Unit,
        rusqlite::types::ValueRef::Integer(i) => Value::Int(i),
        rusqlite::types::ValueRef::Real(x) => Value::Float(x),
        rusqlite::types::ValueRef::Text(bytes) => {
            Value::Str(String::from(String::from_utf8_lossy(bytes)))
        }
        rusqlite::types::ValueRef::Blob(bytes) => Value::Bytes(Vec::from(bytes)),
    }
}
/// Severity of a log record.
///
/// Variant order defines the derived ordering used for level filtering
/// (`Debug < Info < Warn < Error`), so it must stay sorted.
#[derive(Clone, Copy, PartialEq, PartialOrd)]
enum LogLevel {
    Debug,
    Info,
    Warn,
    Error,
}

/// Encoding used for each emitted log line.
#[derive(Clone, Copy, PartialEq)]
enum LogFormat {
    Text,
    Json,
}

/// Destination for emitted log lines.
#[derive(Clone)]
enum LogSink {
    /// The process's standard error stream.
    Stderr,
    /// A mutex-guarded file handle shared by all clones of the sink.
    File(std::sync::Arc<Mutex<std::fs::File>>),
}

/// Logger configuration: minimum level, line format, and destination.
struct LogState {
    level: LogLevel,
    format: LogFormat,
    sink: LogSink,
}
/// Returns the process-wide logger configuration.
///
/// It is created lazily on first use with the defaults: `Info` level,
/// text format, stderr sink.
fn log_state() -> &'static Mutex<LogState> {
    static LOGGER: OnceLock<Mutex<LogState>> = OnceLock::new();
    LOGGER.get_or_init(|| {
        let defaults = LogState {
            level: LogLevel::Info,
            format: LogFormat::Text,
            sink: LogSink::Stderr,
        };
        Mutex::new(defaults)
    })
}
fn parse_log_level(s: &str) -> Option<LogLevel> {
match s {
"debug" => Some(LogLevel::Debug),
"info" => Some(LogLevel::Info),
"warn" => Some(LogLevel::Warn),
"error" => Some(LogLevel::Error),
_ => None,
}
}
/// Lowercase name of `l`, as printed in log lines. This is the inverse of
/// `parse_log_level`.
fn level_label(l: LogLevel) -> &'static str {
    let label = match l {
        LogLevel::Error => "error",
        LogLevel::Warn => "warn",
        LogLevel::Info => "info",
        LogLevel::Debug => "debug",
    };
    label
}
/// Escapes `s` for embedding inside a double-quoted JSON string.
///
/// RFC 8259 §7 requires `"`, `\` and every control character U+0000–U+001F
/// to be escaped. The common ones get their short escape; the rest use
/// `\u00XX`.
fn escape_json_log_str(s: &str) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            '\u{08}' => out.push_str("\\b"),
            '\u{0c}' => out.push_str("\\f"),
            c if u32::from(c) < 0x20 => {
                // Formatting into a `String` cannot fail.
                let _ = write!(out, "\\u{:04x}", u32::from(c));
            }
            c => out.push(c),
        }
    }
    out
}

/// Writes `msg` to the configured sink if `level` is at or above the
/// configured minimum level.
///
/// Text lines look like `[<rfc3339 ts>] <level>: <msg>`. JSON lines are a
/// single object with `ts`, `level` and `msg` keys, and `msg` is fully
/// escaped so every line is valid JSON. Write errors are ignored because
/// logging is best-effort. A poisoned mutex (a panic while another thread
/// held the lock) is recovered rather than propagated, so logging keeps
/// working.
fn emit_log(level: LogLevel, msg: &str) {
    let state = log_state()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    if level < state.level {
        return;
    }
    let ts = chrono::Utc::now().to_rfc3339();
    let line = match state.format {
        LogFormat::Text => format!("[{}] {}: {}\n", ts, level_label(level), msg),
        LogFormat::Json => format!(
            "{{\"ts\":\"{ts}\",\"level\":\"{}\",\"msg\":\"{}\"}}\n",
            level_label(level),
            escape_json_log_str(msg),
        ),
    };
    // Release the config lock before doing I/O so it is not held while writing.
    let sink = state.sink.clone();
    drop(state);
    match sink {
        LogSink::Stderr => {
            use std::io::Write;
            let _ = std::io::stderr().write_all(line.as_bytes());
        }
        LogSink::File(f) => {
            use std::io::Write;
            let mut file = f.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
            let _ = file.write_all(line.as_bytes());
        }
    }
}
/// A child process tracked by the process registry, together with buffered
/// readers over its output pipes when those are attached.
pub(crate) struct ProcessState {
    /// The spawned child process.
    child: std::process::Child,
    /// Buffered reader over the child's stdout; `None` when no reader is attached.
    stdout: Option<std::io::BufReader<std::process::ChildStdout>>,
    /// Buffered reader over the child's stderr; `None` when no reader is attached.
    stderr: Option<std::io::BufReader<std::process::ChildStderr>>,
}
/// Returns the process-wide registry of spawned children.
///
/// It is created lazily on first use and bounded to `MAX_PROCESS_HANDLES`
/// entries.
fn process_registry() -> &'static Mutex<ProcessRegistry> {
    static PROCESSES: OnceLock<Mutex<ProcessRegistry>> = OnceLock::new();
    PROCESSES.get_or_init(|| {
        let registry = ProcessRegistry::with_capacity(MAX_PROCESS_HANDLES);
        Mutex::new(registry)
    })
}
/// Maximum number of child processes tracked at once. Inserting beyond this
/// evicts the least-recently-used entry.
const MAX_PROCESS_HANDLES: usize = 256;

/// A registry entry. Callers receive a clone of the `Arc`, so an entry stays
/// usable while in flight even if it is removed or evicted meanwhile.
type SharedProcessState = Arc<Mutex<ProcessState>>;

/// Bounded LRU map from integer handles to tracked child processes.
pub(crate) struct ProcessRegistry {
    /// Entries in recency order: index 0 is least recently used, the last is most.
    entries: indexmap::IndexMap<u64, SharedProcessState>,
    /// Maximum number of entries retained.
    cap: usize,
}
impl ProcessRegistry {
    /// Creates an empty registry holding at most `cap` entries.
    ///
    /// NOTE(review): `cap == 0` is not rejected. `insert` still stores the
    /// newest entry, so the effective minimum capacity is 1.
    pub(crate) fn with_capacity(cap: usize) -> Self {
        Self { entries: indexmap::IndexMap::new(), cap }
    }

    /// Registers `state` under `handle` as the most-recently-used entry.
    ///
    /// When the registry is full, the least-recently-used entry is evicted.
    /// Callers that still hold its `Arc` keep the state alive. Re-inserting an
    /// existing `handle` replaces its state without evicting any other entry.
    pub(crate) fn insert(&mut self, handle: u64, state: ProcessState) {
        // Drop any previous entry for this handle first. Otherwise a re-insert
        // at capacity would evict an unrelated live entry even though the map
        // does not grow, and the replaced entry would keep its stale LRU slot.
        self.entries.shift_remove(&handle);
        if self.entries.len() >= self.cap {
            self.entries.shift_remove_index(0);
        }
        self.entries.insert(handle, Arc::new(Mutex::new(state)));
    }

    /// Looks up `handle`, marks it most-recently-used, and returns a shared
    /// reference to its state.
    pub(crate) fn touch_get(&mut self, handle: u64) -> Option<SharedProcessState> {
        let idx = self.entries.get_index_of(&handle)?;
        self.entries.move_index(idx, self.entries.len() - 1);
        self.entries.get(&handle).cloned()
    }

    /// Forgets `handle`; unknown handles are ignored.
    pub(crate) fn remove(&mut self, handle: u64) {
        self.entries.shift_remove(&handle);
    }

    #[cfg(test)]
    pub(crate) fn len(&self) -> usize { self.entries.len() }
}
/// Allocates a fresh process handle.
///
/// Handles start at 1 and increase by one per call across all threads.
fn next_process_handle() -> u64 {
    static NEXT_HANDLE: AtomicU64 = AtomicU64::new(1);
    NEXT_HANDLE.fetch_add(1, Ordering::SeqCst)
}
// Unit tests for `ProcessRegistry`'s bounded-LRU bookkeeping. These tests are
// unix-only because `fresh_state` spawns the `true` binary to obtain a real
// `std::process::Child`.
// NOTE(review): the spawned children are never `wait`ed on.
#[cfg(all(test, unix))]
mod process_registry_tests {
use super::{ProcessRegistry, ProcessState};
// Builds a `ProcessState` around a freshly spawned `true` process with its
// output discarded and no stdout/stderr readers attached.
fn fresh_state() -> ProcessState {
let child = std::process::Command::new("true")
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.spawn()
.expect("spawn `true`");
ProcessState { child, stdout: None, stderr: None }
}
// An inserted handle is retrievable; an unknown one is not.
#[test]
fn insert_and_get_round_trip() {
let mut r = ProcessRegistry::with_capacity(4);
r.insert(1, fresh_state());
assert!(r.touch_get(1).is_some());
assert!(r.touch_get(2).is_none());
}
// Each handle owns its own shared state; lookups never alias across handles.
#[test]
fn touch_get_returns_distinct_arcs_for_distinct_handles() {
let mut r = ProcessRegistry::with_capacity(4);
r.insert(1, fresh_state());
r.insert(2, fresh_state());
let a = r.touch_get(1).unwrap();
let b = r.touch_get(2).unwrap();
assert!(!std::sync::Arc::ptr_eq(&a, &b));
}
// `touch_get` refreshes recency, so the untouched entry is evicted on overflow.
#[test]
fn cap_evicts_lru_on_overflow() {
let mut r = ProcessRegistry::with_capacity(2);
r.insert(1, fresh_state());
r.insert(2, fresh_state());
let _ = r.touch_get(1);
r.insert(3, fresh_state());
assert!(r.touch_get(1).is_some(), "1 was MRU, should survive");
assert!(r.touch_get(2).is_none(), "2 was LRU, should be evicted");
assert!(r.touch_get(3).is_some(), "3 just inserted, should survive");
assert_eq!(r.len(), 2);
}
// Without lookups, eviction follows insertion order (oldest first).
#[test]
fn cap_with_no_touches_evicts_in_insertion_order() {
let mut r = ProcessRegistry::with_capacity(2);
r.insert(10, fresh_state());
r.insert(20, fresh_state());
r.insert(30, fresh_state());
assert!(r.touch_get(10).is_none());
assert!(r.touch_get(20).is_some());
assert!(r.touch_get(30).is_some());
}
// `remove` deletes the entry and shrinks the registry.
#[test]
fn remove_drops_entry() {
let mut r = ProcessRegistry::with_capacity(4);
r.insert(1, fresh_state());
r.remove(1);
assert!(r.touch_get(1).is_none());
assert_eq!(r.len(), 0);
}
// The length never exceeds the cap, and it saturates at exactly the cap.
#[test]
fn many_inserts_stay_bounded_at_cap() {
let cap = 8;
let mut r = ProcessRegistry::with_capacity(cap);
for i in 0..(cap as u64 * 3) {
r.insert(i, fresh_state());
assert!(r.len() <= cap);
}
assert_eq!(r.len(), cap);
}
// An `Arc` obtained before `remove` stays valid and lockable afterwards.
#[test]
fn outstanding_arc_outlives_remove() {
let mut r = ProcessRegistry::with_capacity(4);
r.insert(1, fresh_state());
let arc = r.touch_get(1).expect("entry exists");
r.remove(1);
assert!(r.touch_get(1).is_none());
let _state = arc.lock().unwrap();
}
}
/// Reads a process handle, which must be a non-negative `Int`.
fn expect_process_handle(v: Option<&Value>) -> Result<u64, String> {
    let Some(value) = v else {
        return Err("missing ProcessHandle argument".into());
    };
    if let Value::Int(n) = value {
        if let Ok(handle) = u64::try_from(*n) {
            return Ok(handle);
        }
    }
    Err(format!("expected ProcessHandle (Int), got {value:?}"))
}
/// Returns the process-wide registry of open KV stores.
///
/// It is created lazily on first use and bounded to `MAX_KV_HANDLES` entries.
fn kv_registry() -> &'static Mutex<KvRegistry> {
    static STORES: OnceLock<Mutex<KvRegistry>> = OnceLock::new();
    STORES.get_or_init(|| {
        let registry = KvRegistry::with_capacity(MAX_KV_HANDLES);
        Mutex::new(registry)
    })
}
/// Maximum number of KV stores tracked at once. Inserting beyond this evicts
/// the least-recently-used entry.
const MAX_KV_HANDLES: usize = 256;

/// Bounded LRU map from integer handles to `sled` databases.
pub(crate) struct KvRegistry {
    /// Entries in recency order: index 0 is least recently used, the last is most.
    entries: indexmap::IndexMap<u64, sled::Db>,
    /// Maximum number of entries retained.
    cap: usize,
}
impl KvRegistry {
    /// Creates an empty registry holding at most `cap` entries.
    ///
    /// NOTE(review): `cap == 0` is not rejected. `insert` still stores the
    /// newest entry, so the effective minimum capacity is 1.
    pub(crate) fn with_capacity(cap: usize) -> Self {
        Self { entries: indexmap::IndexMap::new(), cap }
    }

    /// Registers `db` under `handle` as the most-recently-used entry.
    ///
    /// When the registry is full, the least-recently-used database is evicted
    /// and dropped. Re-inserting an existing `handle` replaces its database
    /// without evicting any other entry.
    pub(crate) fn insert(&mut self, handle: u64, db: sled::Db) {
        // Drop any previous entry for this handle first. Otherwise a re-insert
        // at capacity would evict an unrelated live store even though the map
        // does not grow, and the replaced entry would keep its stale LRU slot.
        self.entries.shift_remove(&handle);
        if self.entries.len() >= self.cap {
            self.entries.shift_remove_index(0);
        }
        self.entries.insert(handle, db);
    }

    /// Looks up `handle`, marks it most-recently-used, and borrows its database.
    pub(crate) fn touch_get(&mut self, handle: u64) -> Option<&sled::Db> {
        let idx = self.entries.get_index_of(&handle)?;
        self.entries.move_index(idx, self.entries.len() - 1);
        self.entries.get(&handle)
    }

    /// Forgets `handle`; unknown handles are ignored.
    pub(crate) fn remove(&mut self, handle: u64) {
        self.entries.shift_remove(&handle);
    }

    #[cfg(test)]
    pub(crate) fn len(&self) -> usize { self.entries.len() }
}
/// Allocates a fresh KV-store handle.
///
/// Handles start at 1 and increase by one per call across all threads.
fn next_kv_handle() -> u64 {
    static NEXT_HANDLE: AtomicU64 = AtomicU64::new(1);
    NEXT_HANDLE.fetch_add(1, Ordering::SeqCst)
}
/// Returns the process-wide registry of open SQLite connections.
///
/// It is created lazily on first use and bounded to `MAX_SQL_HANDLES` entries.
fn sql_registry() -> &'static Mutex<SqlRegistry> {
    static CONNECTIONS: OnceLock<Mutex<SqlRegistry>> = OnceLock::new();
    CONNECTIONS.get_or_init(|| {
        let registry = SqlRegistry::with_capacity(MAX_SQL_HANDLES);
        Mutex::new(registry)
    })
}
/// Maximum number of SQLite connections tracked at once. Inserting beyond
/// this evicts the least-recently-used entry.
const MAX_SQL_HANDLES: usize = 256;

/// A registry entry. Callers receive a clone of the `Arc`, so a connection
/// stays usable while in flight even if it is removed or evicted meanwhile.
type SharedConn = Arc<Mutex<rusqlite::Connection>>;

/// Bounded LRU map from integer handles to SQLite connections.
pub(crate) struct SqlRegistry {
    /// Entries in recency order: index 0 is least recently used, the last is most.
    entries: indexmap::IndexMap<u64, SharedConn>,
    /// Maximum number of entries retained.
    cap: usize,
}
impl SqlRegistry {
    /// Creates an empty registry holding at most `cap` entries.
    ///
    /// NOTE(review): `cap == 0` is not rejected. `insert` still stores the
    /// newest entry, so the effective minimum capacity is 1.
    pub(crate) fn with_capacity(cap: usize) -> Self {
        Self { entries: indexmap::IndexMap::new(), cap }
    }

    /// Registers `conn` under `handle` as the most-recently-used entry.
    ///
    /// When the registry is full, the least-recently-used connection is
    /// evicted. Callers that still hold its `Arc` keep it alive. Re-inserting
    /// an existing `handle` replaces its connection without evicting any
    /// other entry.
    pub(crate) fn insert(&mut self, handle: u64, conn: rusqlite::Connection) {
        // Drop any previous entry for this handle first. Otherwise a re-insert
        // at capacity would evict an unrelated live connection even though the
        // map does not grow, and the replaced entry would keep its stale LRU slot.
        self.entries.shift_remove(&handle);
        if self.entries.len() >= self.cap {
            self.entries.shift_remove_index(0);
        }
        self.entries.insert(handle, Arc::new(Mutex::new(conn)));
    }

    /// Looks up `handle`, marks it most-recently-used, and returns a shared
    /// reference to its connection.
    pub(crate) fn touch_get(&mut self, handle: u64) -> Option<SharedConn> {
        let idx = self.entries.get_index_of(&handle)?;
        self.entries.move_index(idx, self.entries.len() - 1);
        self.entries.get(&handle).cloned()
    }

    /// Forgets `handle`; unknown handles are ignored.
    pub(crate) fn remove(&mut self, handle: u64) {
        self.entries.shift_remove(&handle);
    }

    #[cfg(test)]
    pub(crate) fn len(&self) -> usize { self.entries.len() }
}
/// Allocates a fresh SQL connection handle.
///
/// Handles start at 1 and increase by one per call across all threads.
fn next_sql_handle() -> u64 {
    static NEXT_HANDLE: AtomicU64 = AtomicU64::new(1);
    NEXT_HANDLE.fetch_add(1, Ordering::SeqCst)
}
// Unit tests for `SqlRegistry`'s bounded-LRU bookkeeping, using in-memory
// SQLite connections so no files are touched.
#[cfg(test)]
mod sql_registry_tests {
use super::SqlRegistry;
// Opens a throwaway in-memory SQLite connection.
fn fresh() -> rusqlite::Connection {
rusqlite::Connection::open_in_memory().expect("open in-memory sqlite")
}
// An inserted handle is retrievable; an unknown one is not.
#[test]
fn insert_and_get_round_trip() {
let mut r = SqlRegistry::with_capacity(4);
r.insert(1, fresh());
assert!(r.touch_get(1).is_some());
assert!(r.touch_get(2).is_none());
}
// `touch_get` refreshes recency, so the untouched entry is evicted on overflow.
#[test]
fn cap_evicts_lru_on_overflow() {
let mut r = SqlRegistry::with_capacity(2);
r.insert(1, fresh());
r.insert(2, fresh());
let _ = r.touch_get(1);
r.insert(3, fresh());
assert!(r.touch_get(1).is_some(), "1 was MRU, should survive");
assert!(r.touch_get(2).is_none(), "2 was LRU, should be evicted");
assert!(r.touch_get(3).is_some(), "3 just inserted");
assert_eq!(r.len(), 2);
}
// `remove` deletes the entry and shrinks the registry.
#[test]
fn remove_drops_entry() {
let mut r = SqlRegistry::with_capacity(4);
r.insert(1, fresh());
r.remove(1);
assert!(r.touch_get(1).is_none());
assert_eq!(r.len(), 0);
}
// The length never exceeds the cap, and it saturates at exactly the cap.
#[test]
fn many_inserts_stay_bounded_at_cap() {
let cap = 8;
let mut r = SqlRegistry::with_capacity(cap);
for i in 0..(cap as u64 * 3) {
r.insert(i, fresh());
assert!(r.len() <= cap);
}
assert_eq!(r.len(), cap);
}
}
// Unit tests for `KvRegistry`'s bounded-LRU bookkeeping, backed by real
// on-disk `sled` databases under the system temp directory.
#[cfg(test)]
mod kv_registry_tests {
use super::KvRegistry;
// Opens a `sled` database in a temp directory whose name combines the pid,
// `tag`, and the current time in nanoseconds, so runs and tests don't collide.
// NOTE(review): the directories are never removed after the tests finish.
fn fresh_db(tag: &str) -> sled::Db {
let dir = std::env::temp_dir().join(format!(
"lex-kv-reg-{}-{}-{}",
std::process::id(),
tag,
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_nanos()
));
sled::open(&dir).expect("sled open")
}
// An inserted handle is retrievable; an unknown one is not.
#[test]
fn insert_and_get_round_trip() {
let mut r = KvRegistry::with_capacity(4);
r.insert(1, fresh_db("a"));
assert!(r.touch_get(1).is_some());
assert!(r.touch_get(2).is_none());
}
// `touch_get` refreshes recency, so the untouched entry is evicted on overflow.
#[test]
fn cap_evicts_lru_on_overflow() {
let mut r = KvRegistry::with_capacity(2);
r.insert(1, fresh_db("c1"));
r.insert(2, fresh_db("c2"));
let _ = r.touch_get(1);
r.insert(3, fresh_db("c3"));
assert!(r.touch_get(1).is_some(), "1 was MRU, should survive");
assert!(r.touch_get(2).is_none(), "2 was LRU, should be evicted");
assert!(r.touch_get(3).is_some(), "3 just inserted, should survive");
assert_eq!(r.len(), 2);
}
// Without lookups, eviction follows insertion order (oldest first).
#[test]
fn cap_with_no_touches_evicts_in_insertion_order() {
let mut r = KvRegistry::with_capacity(2);
r.insert(10, fresh_db("f1"));
r.insert(20, fresh_db("f2"));
r.insert(30, fresh_db("f3"));
assert!(r.touch_get(10).is_none());
assert!(r.touch_get(20).is_some());
assert!(r.touch_get(30).is_some());
}
// `remove` deletes the entry and shrinks the registry.
#[test]
fn remove_drops_entry() {
let mut r = KvRegistry::with_capacity(4);
r.insert(1, fresh_db("r1"));
r.remove(1);
assert!(r.touch_get(1).is_none());
assert_eq!(r.len(), 0);
}
// Removing a handle that was never inserted leaves existing entries intact.
#[test]
fn remove_unknown_handle_is_noop() {
let mut r = KvRegistry::with_capacity(4);
r.insert(1, fresh_db("u1"));
r.remove(999);
assert!(r.touch_get(1).is_some());
}
// The length never exceeds the cap, and it saturates at exactly the cap.
#[test]
fn many_inserts_stay_bounded_at_cap() {
let cap = 8;
let mut r = KvRegistry::with_capacity(cap);
for i in 0..(cap as u64 * 3) {
r.insert(i, fresh_db(&format!("b{i}")));
assert!(r.len() <= cap);
}
assert_eq!(r.len(), cap);
}
}