Skip to main content

adk_sandbox/
tool.rs

//! [`SandboxTool`] — an [`adk_core::Tool`] implementation that delegates
//! execution to a configured [`SandboxBackend`].
//!
//! The tool exposes code execution to LLM agents via the standard `Tool` trait.
//! Backend errors and malformed arguments are converted to structured JSON
//! responses (never propagated as `ToolError`), so the agent always receives
//! actionable information about what happened.

9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use serde_json::{Value, json};
15
16use adk_core::ToolContext;
17
18use crate::backend::SandboxBackend;
19use crate::error::SandboxError;
20use crate::types::{ExecRequest, Language};
21
22/// A tool that executes code in an isolated sandbox.
23///
24/// `SandboxTool` wraps a [`SandboxBackend`] and implements [`adk_core::Tool`],
25/// making sandbox execution available to LLM agents. The tool accepts
26/// `language`, `code`, optional `stdin`, and optional `timeout_secs` parameters.
27///
28/// # Error Handling
29///
30/// Backend errors are **never** propagated as `ToolError`. Instead, they are
31/// converted to structured JSON with a `"status"` field (`"timeout"`,
32/// `"memory_exceeded"`, or `"error"`). This lets the agent reason about
33/// failures without triggering exception handling.
34///
35/// # Example
36///
37/// ```rust,ignore
38/// use adk_sandbox::{SandboxTool, ProcessBackend};
39/// use std::sync::Arc;
40///
41/// let backend = Arc::new(ProcessBackend::default());
42/// let tool = SandboxTool::new(backend);
43/// assert_eq!(tool.name(), "sandbox_exec");
44/// ```
45pub struct SandboxTool {
46    backend: Arc<dyn SandboxBackend>,
47}
48
49impl SandboxTool {
50    /// Creates a new `SandboxTool` wrapping the given backend.
51    pub fn new(backend: Arc<dyn SandboxBackend>) -> Self {
52        Self { backend }
53    }
54}
55
56/// Default timeout in seconds when `timeout_secs` is not provided.
57const DEFAULT_TIMEOUT_SECS: u64 = 30;
58
59/// Scopes required to execute this tool.
60const REQUIRED_SCOPES: &[&str] = &["code:execute"];
61
62/// Parses a `Language` from a JSON string value.
63fn parse_language(value: &Value) -> Result<Language, String> {
64    let s = value.as_str().ok_or_else(|| "\"language\" must be a string".to_string())?;
65    match s {
66        "rust" => Ok(Language::Rust),
67        "python" => Ok(Language::Python),
68        "javascript" => Ok(Language::JavaScript),
69        "typescript" => Ok(Language::TypeScript),
70        "wasm" => Ok(Language::Wasm),
71        "command" => Ok(Language::Command),
72        other => Err(format!(
73            "unsupported language \"{other}\". Expected one of: rust, python, javascript, typescript, wasm, command"
74        )),
75    }
76}
77
78/// Converts a [`SandboxError`] into a structured JSON value.
79///
80/// The returned JSON always contains a `"status"` field so the agent can
81/// distinguish between different failure modes.
82fn sandbox_error_to_json(err: &SandboxError) -> Value {
83    match err {
84        SandboxError::Timeout { timeout } => json!({
85            "status": "timeout",
86            "stderr": format!("execution timed out after {timeout:?}"),
87            "duration_ms": timeout.as_millis() as u64,
88        }),
89        SandboxError::MemoryExceeded { limit_mb } => json!({
90            "status": "memory_exceeded",
91            "stderr": format!("memory limit exceeded: {limit_mb} MB"),
92        }),
93        SandboxError::ExecutionFailed(msg) => json!({
94            "status": "error",
95            "stderr": msg,
96        }),
97        SandboxError::InvalidRequest(msg) => json!({
98            "status": "error",
99            "stderr": msg,
100        }),
101        SandboxError::BackendUnavailable(msg) => json!({
102            "status": "error",
103            "stderr": msg,
104        }),
105        SandboxError::EnforcerFailed { enforcer, message } => json!({
106            "status": "error",
107            "stderr": format!("enforcer '{enforcer}' failed: {message}"),
108        }),
109        SandboxError::EnforcerUnavailable { enforcer, message } => json!({
110            "status": "error",
111            "stderr": format!("enforcer '{enforcer}' unavailable: {message}"),
112        }),
113        SandboxError::PolicyViolation(msg) => json!({
114            "status": "error",
115            "stderr": format!("policy violation: {msg}"),
116        }),
117    }
118}
119
120#[async_trait]
121impl adk_core::Tool for SandboxTool {
122    fn name(&self) -> &str {
123        "sandbox_exec"
124    }
125
126    fn description(&self) -> &str {
127        "Execute code in an isolated sandbox. Supports multiple languages \
128         including rust, python, javascript, typescript, wasm, and shell commands."
129    }
130
131    fn required_scopes(&self) -> &[&str] {
132        REQUIRED_SCOPES
133    }
134
135    fn parameters_schema(&self) -> Option<Value> {
136        Some(json!({
137            "type": "object",
138            "properties": {
139                "language": {
140                    "type": "string",
141                    "enum": ["rust", "python", "javascript", "typescript", "wasm", "command"],
142                    "description": "The programming language of the code to execute."
143                },
144                "code": {
145                    "type": "string",
146                    "description": "The source code or command to execute."
147                },
148                "stdin": {
149                    "type": "string",
150                    "description": "Optional standard input to feed to the process."
151                },
152                "timeout_secs": {
153                    "type": "integer",
154                    "description": "Maximum execution time in seconds.",
155                    "default": DEFAULT_TIMEOUT_SECS,
156                    "minimum": 1,
157                    "maximum": 300
158                }
159            },
160            "required": ["language", "code"]
161        }))
162    }
163
164    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> adk_core::Result<Value> {
165        // Parse language (required)
166        let language = match args.get("language") {
167            Some(v) => match parse_language(v) {
168                Ok(lang) => lang,
169                Err(msg) => {
170                    return Ok(json!({ "status": "error", "stderr": msg }));
171                }
172            },
173            None => {
174                return Ok(
175                    json!({ "status": "error", "stderr": "missing required field \"language\"" }),
176                );
177            }
178        };
179
180        // Parse code (required)
181        let code = match args.get("code").and_then(|v| v.as_str()) {
182            Some(c) => c.to_string(),
183            None => {
184                return Ok(
185                    json!({ "status": "error", "stderr": "missing required field \"code\"" }),
186                );
187            }
188        };
189
190        // Parse stdin (optional)
191        let stdin = args.get("stdin").and_then(|v| v.as_str()).map(String::from);
192
193        // Parse timeout_secs (optional, default 30)
194        let timeout_secs =
195            args.get("timeout_secs").and_then(|v| v.as_u64()).unwrap_or(DEFAULT_TIMEOUT_SECS);
196
197        let request = ExecRequest {
198            language,
199            code,
200            stdin,
201            timeout: Duration::from_secs(timeout_secs),
202            memory_limit_mb: None,
203            env: HashMap::new(),
204        };
205
206        match self.backend.execute(request).await {
207            Ok(result) => Ok(json!({
208                "status": "success",
209                "stdout": result.stdout,
210                "stderr": result.stderr,
211                "exit_code": result.exit_code,
212                "duration_ms": result.duration.as_millis() as u64,
213            })),
214            Err(err) => Ok(sandbox_error_to_json(&err)),
215        }
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{BackendCapabilities, EnforcedLimits};
    use crate::types::ExecResult;
    use adk_core::{CallbackContext, Content, EventActions, ReadonlyContext, Tool};
    use std::sync::Mutex;
    use std::time::Duration;

    // -- Mock backend ----------------------------------------------------------

    /// The fields of an [`ExecRequest`] that the tests inspect after a call.
    #[derive(Debug, Clone)]
    struct SeenRequest {
        code: String,
        stdin: Option<String>,
        timeout: Duration,
    }

    /// A configurable mock backend for testing `SandboxTool`.
    ///
    /// Besides returning a canned result or error, it records the last request
    /// it received. Tests use that record to verify how `SandboxTool`
    /// translated its JSON arguments into an [`ExecRequest`].
    struct MockBackend {
        /// When `Some`, `execute()` returns this error.
        error: Option<SandboxError>,
        /// When `error` is `None`, `execute()` returns this result.
        result: ExecResult,
        /// The most recent request passed to `execute()`, if any.
        recorded: Mutex<Option<SeenRequest>>,
    }

    impl MockBackend {
        fn success(stdout: &str, exit_code: i32) -> Self {
            Self {
                error: None,
                result: ExecResult {
                    stdout: stdout.to_string(),
                    stderr: String::new(),
                    exit_code,
                    duration: Duration::from_millis(42),
                },
                recorded: Mutex::new(None),
            }
        }

        fn failing(err: SandboxError) -> Self {
            Self {
                error: Some(err),
                result: ExecResult {
                    stdout: String::new(),
                    stderr: String::new(),
                    exit_code: 0,
                    duration: Duration::ZERO,
                },
                recorded: Mutex::new(None),
            }
        }

        /// Returns what the most recent `execute()` call received, or `None`
        /// if the backend was never invoked.
        fn last_request(&self) -> Option<SeenRequest> {
            self.recorded.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl SandboxBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> BackendCapabilities {
            BackendCapabilities {
                supported_languages: vec![Language::Python],
                isolation_class: "mock".to_string(),
                enforced_limits: EnforcedLimits {
                    timeout: true,
                    memory: false,
                    network_isolation: false,
                    filesystem_isolation: false,
                    environment_isolation: false,
                },
            }
        }

        async fn execute(&self, request: ExecRequest) -> Result<ExecResult, SandboxError> {
            *self.recorded.lock().unwrap() = Some(SeenRequest {
                code: request.code,
                stdin: request.stdin,
                timeout: request.timeout,
            });
            match &self.error {
                Some(err) => Err(err.clone()),
                None => Ok(self.result.clone()),
            }
        }
    }

    /// Builds a tool around `backend` while the caller keeps a handle for inspection.
    fn tool_with(backend: &Arc<MockBackend>) -> SandboxTool {
        SandboxTool::new(backend.clone())
    }

    // -- Mock ToolContext -------------------------------------------------------

    struct MockToolContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl MockToolContext {
        fn new() -> Self {
            Self { content: Content::new("user"), actions: Mutex::new(EventActions::default()) }
        }
    }

    #[async_trait]
    impl ReadonlyContext for MockToolContext {
        fn invocation_id(&self) -> &str {
            "inv-1"
        }
        fn agent_name(&self) -> &str {
            "test-agent"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for MockToolContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for MockToolContext {
        fn function_call_id(&self) -> &str {
            "call-1"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(
            &self,
            _query: &str,
        ) -> adk_core::Result<Vec<adk_core::MemoryEntry>> {
            Ok(vec![])
        }
    }

    fn ctx() -> Arc<dyn ToolContext> {
        Arc::new(MockToolContext::new())
    }

    // -- Tests -----------------------------------------------------------------

    #[test]
    fn test_name() {
        let tool = SandboxTool::new(Arc::new(MockBackend::success("", 0)));
        assert_eq!(tool.name(), "sandbox_exec");
    }

    #[test]
    fn test_required_scopes() {
        let tool = SandboxTool::new(Arc::new(MockBackend::success("", 0)));
        assert_eq!(tool.required_scopes(), &["code:execute"]);
    }

    #[test]
    fn test_parameters_schema_is_valid() {
        let tool = SandboxTool::new(Arc::new(MockBackend::success("", 0)));
        let schema = tool.parameters_schema().expect("schema should be Some");
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["language"].is_object());
        assert!(schema["properties"]["code"].is_object());
        assert!(schema["properties"]["stdin"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());

        let required = schema["required"].as_array().unwrap();
        let required_strs: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(required_strs.contains(&"language"));
        assert!(required_strs.contains(&"code"));
        assert!(!required_strs.contains(&"stdin"));
        assert!(!required_strs.contains(&"timeout_secs"));
    }

    #[tokio::test]
    async fn test_successful_execution() {
        let backend = Arc::new(MockBackend::success("hello\n", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python", "code": "print('hello')" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "success");
        assert_eq!(result["stdout"], "hello\n");
        assert_eq!(result["exit_code"], 0);
        assert_eq!(result["duration_ms"], 42);

        let seen = backend.last_request().expect("backend should have been called");
        assert_eq!(seen.code, "print('hello')");
    }

    #[tokio::test]
    async fn test_timeout_error_as_information() {
        let backend = Arc::new(MockBackend::failing(SandboxError::Timeout {
            timeout: Duration::from_secs(5),
        }));
        let tool = SandboxTool::new(backend);
        let args = json!({ "language": "python", "code": "import time; time.sleep(100)" });

        let result = tool.execute(ctx(), args).await;

        // Must be Ok, not Err
        assert!(result.is_ok());
        let val = result.unwrap();
        assert_eq!(val["status"], "timeout");
        assert!(val["stderr"].as_str().unwrap().contains("timed out"));
        assert!(val["duration_ms"].is_number());
    }

    #[tokio::test]
    async fn test_memory_exceeded_error_as_information() {
        let backend = Arc::new(MockBackend::failing(SandboxError::MemoryExceeded { limit_mb: 64 }));
        let tool = SandboxTool::new(backend);
        let args = json!({ "language": "wasm", "code": "(module)" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "memory_exceeded");
        assert!(result["stderr"].as_str().unwrap().contains("64 MB"));
    }

    #[tokio::test]
    async fn test_execution_failed_error_as_information() {
        let backend = Arc::new(MockBackend::failing(SandboxError::ExecutionFailed(
            "spawn failed".to_string(),
        )));
        let tool = SandboxTool::new(backend);
        let args = json!({ "language": "python", "code": "x" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert_eq!(result["stderr"], "spawn failed");
    }

    #[tokio::test]
    async fn test_missing_language_field() {
        let backend = Arc::new(MockBackend::success("", 0));
        let tool = tool_with(&backend);
        let args = json!({ "code": "print('hi')" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("language"));
        // Validation failures must short-circuit before reaching the backend.
        assert!(backend.last_request().is_none());
    }

    #[tokio::test]
    async fn test_missing_code_field() {
        let backend = Arc::new(MockBackend::success("", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
        assert!(backend.last_request().is_none());
    }

    #[tokio::test]
    async fn test_non_string_code_field() {
        let backend = Arc::new(MockBackend::success("", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python", "code": 42 });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
        assert!(backend.last_request().is_none());
    }

    #[tokio::test]
    async fn test_unsupported_language() {
        let backend = Arc::new(MockBackend::success("", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "cobol", "code": "DISPLAY 'HI'" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("cobol"));
        assert!(backend.last_request().is_none());
    }

    #[tokio::test]
    async fn test_default_timeout() {
        let backend = Arc::new(MockBackend::success("ok", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python", "code": "print('ok')" });

        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "success");

        let seen = backend.last_request().expect("backend should have been called");
        assert_eq!(seen.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
    }

    #[tokio::test]
    async fn test_custom_timeout() {
        let backend = Arc::new(MockBackend::success("ok", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python", "code": "print('ok')", "timeout_secs": 60 });

        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "success");

        let seen = backend.last_request().expect("backend should have been called");
        assert_eq!(seen.timeout, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_stdin_passed_through() {
        let backend = Arc::new(MockBackend::success("echo", 0));
        let tool = tool_with(&backend);
        let args = json!({
            "language": "python",
            "code": "import sys; print(sys.stdin.read())",
            "stdin": "hello"
        });

        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "success");

        let seen = backend.last_request().expect("backend should have been called");
        assert_eq!(seen.stdin.as_deref(), Some("hello"));
    }

    #[test]
    fn test_parse_language_all_variants() {
        assert_eq!(parse_language(&json!("rust")).unwrap(), Language::Rust);
        assert_eq!(parse_language(&json!("python")).unwrap(), Language::Python);
        assert_eq!(parse_language(&json!("javascript")).unwrap(), Language::JavaScript);
        assert_eq!(parse_language(&json!("typescript")).unwrap(), Language::TypeScript);
        assert_eq!(parse_language(&json!("wasm")).unwrap(), Language::Wasm);
        assert_eq!(parse_language(&json!("command")).unwrap(), Language::Command);
        assert!(parse_language(&json!("ruby")).is_err());
        assert!(parse_language(&json!(42)).is_err());
    }

    #[test]
    fn test_sandbox_error_to_json_variants() {
        let timeout_json =
            sandbox_error_to_json(&SandboxError::Timeout { timeout: Duration::from_secs(10) });
        assert_eq!(timeout_json["status"], "timeout");
        assert_eq!(timeout_json["duration_ms"], 10_000);

        let mem_json = sandbox_error_to_json(&SandboxError::MemoryExceeded { limit_mb: 128 });
        assert_eq!(mem_json["status"], "memory_exceeded");
        assert_eq!(mem_json["stderr"], "memory limit exceeded: 128 MB");

        let exec_json = sandbox_error_to_json(&SandboxError::ExecutionFailed("boom".into()));
        assert_eq!(exec_json["status"], "error");
        assert_eq!(exec_json["stderr"], "boom");

        let invalid_json = sandbox_error_to_json(&SandboxError::InvalidRequest("bad".into()));
        assert_eq!(invalid_json["status"], "error");
        assert_eq!(invalid_json["stderr"], "bad");

        let unavail_json = sandbox_error_to_json(&SandboxError::BackendUnavailable("gone".into()));
        assert_eq!(unavail_json["status"], "error");
        assert_eq!(unavail_json["stderr"], "gone");

        let enforcer_failed_json = sandbox_error_to_json(&SandboxError::EnforcerFailed {
            enforcer: "seccomp".into(),
            message: "denied".into(),
        });
        assert_eq!(enforcer_failed_json["status"], "error");
        assert_eq!(enforcer_failed_json["stderr"], "enforcer 'seccomp' failed: denied");

        let enforcer_unavail_json = sandbox_error_to_json(&SandboxError::EnforcerUnavailable {
            enforcer: "seccomp".into(),
            message: "not installed".into(),
        });
        assert_eq!(enforcer_unavail_json["status"], "error");
        assert_eq!(
            enforcer_unavail_json["stderr"],
            "enforcer 'seccomp' unavailable: not installed"
        );

        let policy_json =
            sandbox_error_to_json(&SandboxError::PolicyViolation("network access".into()));
        assert_eq!(policy_json["status"], "error");
        assert_eq!(policy_json["stderr"], "policy violation: network access");
    }
}