use crate::fs::{FileSystem, FsError, LocalFileSystem};
use crate::ipc::{ErrorCode, Request, Response, read_message, write_message};
use crate::path_defense::{is_dangerous_system_path, is_path_inside, paths_for_write_check};
use crate::policy::SandboxPolicy;
use crate::policy_check::is_fully_denied;
use anyhow::Result;
use std::path::{Path, PathBuf};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::{debug, warn};
/// State shared by every request handled on one worker connection.
struct Context {
/// Filesystem backend that carries out read/write/edit/glob/grep/stat requests.
fs: LocalFileSystem,
/// Directory that writes are confined to. When `None`, `check_write_path`
/// applies only its unconditional deny checks and skips root confinement.
writable_root: Option<PathBuf>,
/// Sandbox policy; `check_write_path` reads its `fs` allow/deny lists and
/// `mandatory_deny_search_depth`.
policy: SandboxPolicy,
}
impl Default for Context {
fn default() -> Self {
Self {
fs: LocalFileSystem::new(),
writable_root: None,
policy: SandboxPolicy::default(),
}
}
}
impl Context {
fn with_policy(writable_root: PathBuf, policy: SandboxPolicy) -> Self {
Self {
fs: LocalFileSystem::new(),
writable_root: Some(writable_root),
policy,
}
}
}
/// Serves requests from `reader` and writes responses to `writer` using a
/// default [`Context`] (no writable root, default policy).
///
/// Returns when the peer closes the transport, a `Shutdown` request has been
/// acknowledged, or a transport error occurs.
pub async fn run<R, W>(reader: &mut R, writer: &mut W) -> Result<()>
where
    R: AsyncRead + Unpin + Send,
    W: AsyncWrite + Unpin + Send,
{
    run_with_ctx(&Context::default(), reader, writer).await
}
/// Serves requests like [`run`], but with writes confined to `writable_root`
/// and checked against `policy`.
pub async fn run_with_policy<R, W>(
    writable_root: PathBuf,
    policy: SandboxPolicy,
    reader: &mut R,
    writer: &mut W,
) -> Result<()>
where
    R: AsyncRead + Unpin + Send,
    W: AsyncWrite + Unpin + Send,
{
    run_with_ctx(&Context::with_policy(writable_root, policy), reader, writer).await
}
async fn run_with_ctx<R, W>(ctx: &Context, reader: &mut R, writer: &mut W) -> Result<()>
where
R: AsyncRead + Unpin + Send,
W: AsyncWrite + Unpin + Send,
{
loop {
let req = match read_message::<R, Request>(reader).await {
Ok(Some(r)) => r,
Ok(None) => {
debug!("worker: peer closed transport, exiting cleanly");
return Ok(());
}
Err(e) => {
warn!("worker: transport read error: {e}");
let _ = write_message(
writer,
&Response::Error {
code: ErrorCode::Protocol,
message: format!("read error: {e}"),
},
)
.await;
return Err(e.into());
}
};
let is_shutdown = matches!(req, Request::Shutdown);
let resp = handle(ctx, req).await;
write_message(writer, &resp).await?;
if is_shutdown {
debug!("worker: shutdown requested, exiting cleanly");
return Ok(());
}
}
}
/// Resolves `path` as far as the filesystem allows, for use in containment checks.
///
/// If `path` itself can be canonicalized, that result is returned. Otherwise the
/// function walks up to the nearest ancestor that can be canonicalized and
/// re-attaches the remaining (unresolvable) components to it lexically:
/// normal components are appended and `..` components remove the last
/// component of the accumulated path.
///
/// Handling `..` matters for callers that test whether the result lies inside
/// a root: dropping it would make `root/missing/../../x` look like
/// `root/missing/x` even though the path lexically escapes `root`.
///
/// If no ancestor can be canonicalized at all (for example a relative path
/// whose first component does not exist), `path` is returned unchanged.
fn canonicalize_for_check(path: &Path) -> PathBuf {
    use std::path::Component;

    // Trailing components that could not be resolved, innermost first.
    let mut tail: Vec<Component<'_>> = Vec::new();
    let mut dir: &Path = path;
    loop {
        if let Ok(base) = std::fs::canonicalize(dir) {
            let mut resolved = base;
            for comp in tail.into_iter().rev() {
                match comp {
                    Component::Normal(seg) => resolved.push(seg),
                    // `pop` is a no-op at the filesystem root, matching how
                    // `/..` resolves to `/`.
                    Component::ParentDir => {
                        resolved.pop();
                    }
                    // Root, prefix and `.` carry no information once a
                    // canonical base has been found.
                    _ => {}
                }
            }
            return resolved;
        }
        if let Some(last) = dir.components().next_back() {
            tail.push(last);
        }
        match dir.parent() {
            Some(parent) => dir = parent,
            None => return path.to_path_buf(),
        }
    }
}
fn check_write_path(ctx: &Context, path: &Path) -> Result<(), FsError> {
if is_fully_denied(path) {
return Err(FsError::PolicyDenied {
message: format!(
"write denied: '{}' is a fully-protected path",
path.display()
),
});
}
if is_dangerous_system_path(path) {
return Err(FsError::PolicyDenied {
message: format!(
"write denied: '{}' is a critical system path",
path.display()
),
});
}
let Some(ref root) = ctx.writable_root else {
return Ok(()); };
let chain = paths_for_write_check(path, ctx.policy.fs.mandatory_deny_search_depth);
for candidate in &chain {
if is_fully_denied(candidate) || is_dangerous_system_path(candidate) {
return Err(FsError::PolicyDenied {
message: format!(
"write denied: symlink chain reaches protected path '{}'",
candidate.display()
),
});
}
let canonical_candidate = canonicalize_for_check(candidate);
let canonical_root = std::fs::canonicalize(root).unwrap_or_else(|_| root.clone());
if !is_path_inside(&canonical_candidate, &canonical_root) {
let explicitly_allowed = ctx
.policy
.fs
.allow_write
.iter()
.any(|pat| is_path_inside(&canonical_candidate, pat.as_path()));
let explicitly_denied = ctx
.policy
.fs
.deny_write_within_allow
.iter()
.any(|pat| is_path_inside(&canonical_candidate, pat.as_path()));
if !explicitly_allowed || explicitly_denied {
return Err(FsError::PolicyDenied {
message: format!(
"write denied: '{}' is outside writable root '{}'",
path.display(),
root.display()
),
});
}
} else {
let carved_out = ctx
.policy
.fs
.deny_write_within_allow
.iter()
.any(|pat| is_path_inside(&canonical_candidate, pat.as_path()));
if carved_out {
return Err(FsError::PolicyDenied {
message: format!(
"write denied: '{}' is in a write-protected sub-path",
path.display()
),
});
}
}
}
Ok(())
}
async fn handle(ctx: &Context, req: Request) -> Response {
match req {
Request::Ping | Request::Shutdown => Response::Pong,
Request::Read { path, max_bytes } => match ctx.fs.read(&path, max_bytes).await {
Ok(content) => Response::Read { content },
Err(e) => fs_err_to_resp(e),
},
Request::Write { path, content } => {
if let Err(e) = check_write_path(ctx, &path) {
return fs_err_to_resp(e);
}
match ctx.fs.write(&path, &content).await {
Ok(bytes_written) => Response::Write { bytes_written },
Err(e) => fs_err_to_resp(e),
}
}
Request::Edit {
path,
old_string,
new_string,
} => {
if let Err(e) = check_write_path(ctx, &path) {
return fs_err_to_resp(e);
}
match ctx.fs.edit(&path, &old_string, &new_string, false).await {
Ok(replacements) => Response::Edit { replacements },
Err(e) => fs_err_to_resp(e),
}
}
Request::Glob { pattern, root } => match ctx.fs.glob(&pattern, &root).await {
Ok(paths) => Response::Glob { paths },
Err(e) => fs_err_to_resp(e),
},
Request::Grep {
pattern,
root,
include,
} => match ctx.fs.grep(&pattern, &root, include.as_deref()).await {
Ok(matches) => Response::Grep { matches },
Err(e) => fs_err_to_resp(e),
},
Request::Stat { path } => match ctx.fs.stat(&path).await {
Ok(m) => Response::Stat {
size: m.size,
is_dir: m.is_dir,
is_symlink: m.is_symlink,
},
Err(e) => fs_err_to_resp(e),
},
Request::GetEnv { names } => {
let values = names.into_iter().map(|n| std::env::var(&n).ok()).collect();
Response::GetEnv { values }
}
}
}
fn fs_err_to_resp(e: FsError) -> Response {
let (code, message) = match e {
FsError::Io(e) => (ErrorCode::Io, e.to_string()),
FsError::PolicyDenied { message } => (ErrorCode::PolicyDenied, message),
FsError::EditNotFound { path } => (
ErrorCode::Io,
format!("old_string not found in {}", path.display()),
),
FsError::InvalidPattern { message } => (ErrorCode::Protocol, message),
FsError::Transport { message } => (ErrorCode::Internal, message),
};
Response::Error { code, message }
}
/// Runs the worker over the process's standard input and output using the
/// default context.
pub async fn run_stdio() -> Result<()> {
    let mut input = tokio::io::stdin();
    let mut output = tokio::io::stdout();
    let served = run(&mut input, &mut output).await;
    served
}
/// Serves a single connection on a Unix socket bound at `path`, using the
/// default context.
#[cfg(unix)]
pub async fn run_unix_socket(path: &Path) -> Result<()> {
    let ctx = Context::default();
    unix_socket_serve(path, ctx).await
}
/// Serves a single connection on a Unix socket bound at `path`, with writes
/// confined to `writable_root` and checked against `policy`.
#[cfg(unix)]
pub async fn run_unix_socket_with_policy(
    path: &Path,
    writable_root: PathBuf,
    policy: SandboxPolicy,
) -> Result<()> {
    let ctx = Context::with_policy(writable_root, policy);
    unix_socket_serve(path, ctx).await
}
/// Binds a Unix socket at `path`, prints `ready` on stdout, accepts exactly
/// one connection, and serves it with `ctx` until the request loop ends.
///
/// The `ready` line is flushed before accepting — presumably so a parent
/// process can tell when it is safe to connect (NOTE(review): confirm against
/// the launcher).
#[cfg(unix)]
async fn unix_socket_serve(path: &Path, ctx: Context) -> Result<()> {
    use std::io::Write as _;
    use tokio::net::UnixListener;

    let listener = UnixListener::bind(path)?;

    println!("ready");
    std::io::stdout().flush()?;

    let (conn, _peer) = listener.accept().await?;
    let (mut rx, mut tx) = tokio::io::split(conn);
    run_with_ctx(&ctx, &mut rx, &mut tx).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ipc::{Request, Response};
    use tempfile::TempDir;
    use tokio::io::duplex;

    /// Join handle of a worker task spawned for a test.
    type WorkerHandle = tokio::task::JoinHandle<Result<()>>;

    /// Longest time a worker may take to exit after EOF or `Shutdown`.
    const EXIT_DEADLINE: std::time::Duration = std::time::Duration::from_secs(2);

    /// Spawns a worker with the default context over an in-memory duplex pipe
    /// and returns the host end together with the worker's join handle.
    fn spawn_worker() -> (tokio::io::DuplexStream, WorkerHandle) {
        let (host_end, worker_end) = duplex(65536);
        let join = tokio::spawn(async move {
            let (mut rd, mut wr) = tokio::io::split(worker_end);
            run(&mut rd, &mut wr).await
        });
        (host_end, join)
    }

    /// Like `spawn_worker`, but confines writes to `root` under the default policy.
    fn spawn_worker_with_root(root: PathBuf) -> (tokio::io::DuplexStream, WorkerHandle) {
        let (host_end, worker_end) = duplex(65536);
        let policy = crate::policy::SandboxPolicy::default();
        let join = tokio::spawn(async move {
            let (mut rd, mut wr) = tokio::io::split(worker_end);
            run_with_policy(root, policy, &mut rd, &mut wr).await
        });
        (host_end, join)
    }

    /// Sends one request and waits for the matching response.
    async fn exchange(host: &mut tokio::io::DuplexStream, req: Request) -> Response {
        write_message(&mut *host, &req).await.unwrap();
        read_message(&mut *host).await.unwrap().unwrap()
    }

    /// True when `resp` is an error carrying `ErrorCode::PolicyDenied`.
    fn is_policy_denied(resp: &Response) -> bool {
        matches!(
            resp,
            Response::Error {
                code: ErrorCode::PolicyDenied,
                ..
            }
        )
    }

    #[tokio::test]
    async fn ping_returns_pong() {
        let (mut host, _join) = spawn_worker();
        let resp = exchange(&mut host, Request::Ping).await;
        assert_eq!(resp, Response::Pong);
    }

    #[tokio::test]
    async fn shutdown_acks_then_worker_exits() {
        let (mut host, join) = spawn_worker();
        let resp = exchange(&mut host, Request::Shutdown).await;
        assert_eq!(resp, Response::Pong);
        drop(host);
        tokio::time::timeout(EXIT_DEADLINE, join)
            .await
            .expect("worker must exit within 2s of Shutdown")
            .expect("join error")
            .expect("worker returned Err");
    }

    #[tokio::test]
    async fn peer_eof_exits_loop_cleanly() {
        let (host, join) = spawn_worker();
        drop(host);
        let result = tokio::time::timeout(EXIT_DEADLINE, join)
            .await
            .expect("worker must exit within 2s of EOF")
            .expect("join error");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn worker_loops_for_multiple_requests() {
        let (mut host, _join) = spawn_worker();
        for _ in 0..5 {
            let resp = exchange(&mut host, Request::Ping).await;
            assert_eq!(resp, Response::Pong);
        }
    }

    #[tokio::test]
    async fn read_handler_returns_file_contents() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("hello.txt");
        std::fs::write(&path, b"hello from worker").unwrap();
        let (mut host, _join) = spawn_worker();
        let request = Request::Read {
            path,
            max_bytes: None,
        };
        let resp = exchange(&mut host, request).await;
        let expected = Response::Read {
            content: b"hello from worker".to_vec(),
        };
        assert_eq!(resp, expected);
    }

    #[tokio::test]
    async fn write_handler_creates_file() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("out.txt");
        let (mut host, _join) = spawn_worker();
        let request = Request::Write {
            path: path.clone(),
            content: b"written".to_vec(),
        };
        let resp = exchange(&mut host, request).await;
        assert_eq!(resp, Response::Write { bytes_written: 7 });
        assert_eq!(std::fs::read(&path).unwrap(), b"written");
    }

    #[tokio::test]
    async fn edit_handler_replaces_string() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("e.txt");
        std::fs::write(&path, b"foo bar baz").unwrap();
        let (mut host, _join) = spawn_worker();
        let request = Request::Edit {
            path: path.clone(),
            old_string: "bar".to_string(),
            new_string: "BAR".to_string(),
        };
        let resp = exchange(&mut host, request).await;
        assert_eq!(resp, Response::Edit { replacements: 1 });
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "foo BAR baz");
    }

    #[tokio::test]
    async fn stat_handler_reports_file_metadata() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("s.txt");
        std::fs::write(&path, b"123456").unwrap();
        let (mut host, _join) = spawn_worker();
        let resp = exchange(&mut host, Request::Stat { path }).await;
        let expected = Response::Stat {
            size: 6,
            is_dir: false,
            is_symlink: false,
        };
        assert_eq!(resp, expected);
    }

    #[tokio::test]
    async fn read_missing_file_returns_io_error() {
        let dir = TempDir::new().unwrap();
        let (mut host, _join) = spawn_worker();
        let request = Request::Read {
            path: dir.path().join("nope"),
            max_bytes: None,
        };
        let resp = exchange(&mut host, request).await;
        assert!(matches!(
            resp,
            Response::Error {
                code: ErrorCode::Io,
                ..
            }
        ));
    }

    #[tokio::test]
    async fn write_inside_root_is_allowed() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("allowed.txt");
        let (mut host, _join) = spawn_worker_with_root(dir.path().to_path_buf());
        let request = Request::Write {
            path: path.clone(),
            content: b"ok".to_vec(),
        };
        let resp = exchange(&mut host, request).await;
        assert!(
            matches!(resp, Response::Write { .. }),
            "expected Write ok, got {resp:?}"
        );
        assert_eq!(std::fs::read(&path).unwrap(), b"ok");
    }

    #[tokio::test]
    async fn write_outside_root_is_denied() {
        let root = TempDir::new().unwrap();
        let outside = TempDir::new().unwrap();
        let path = outside.path().join("escape.txt");
        let (mut host, _join) = spawn_worker_with_root(root.path().to_path_buf());
        let request = Request::Write {
            path,
            content: b"evil".to_vec(),
        };
        let resp = exchange(&mut host, request).await;
        assert!(
            is_policy_denied(&resp),
            "expected PolicyDenied, got {resp:?}"
        );
    }

    #[tokio::test]
    async fn dangerous_system_path_always_denied() {
        let dir = TempDir::new().unwrap();
        let (mut host, _join) = spawn_worker_with_root(dir.path().to_path_buf());
        let request = Request::Write {
            path: PathBuf::from("/etc"),
            content: b"evil".to_vec(),
        };
        let resp = exchange(&mut host, request).await;
        assert!(
            is_policy_denied(&resp),
            "expected PolicyDenied, got {resp:?}"
        );
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn symlink_escape_is_denied() {
        let root = TempDir::new().unwrap();
        let outside = TempDir::new().unwrap();
        // A link inside the root that points at a directory outside it.
        let link = root.path().join("escape_link");
        std::os::unix::fs::symlink(outside.path(), &link).unwrap();
        let through_link = link.join("secret.txt");
        let (mut host, _join) = spawn_worker_with_root(root.path().to_path_buf());
        let request = Request::Write {
            path: through_link,
            content: b"evil".to_vec(),
        };
        let resp = exchange(&mut host, request).await;
        assert!(
            is_policy_denied(&resp),
            "symlink escape should be denied, got {resp:?}"
        );
    }
}