Skip to main content

adk_sandbox/
tool.rs

1//! [`SandboxTool`] — an [`adk_core::Tool`] implementation that delegates
2//! execution to a configured [`SandboxBackend`].
3//!
4//! The tool exposes code execution to LLM agents via the standard Tool trait.
5//! Errors from the backend are converted to structured JSON responses (never
6//! propagated as `ToolError`), so the agent always receives actionable
7//! information about what happened.
8
9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use serde_json::{Value, json};
15
16use adk_core::ToolContext;
17
18use crate::backend::SandboxBackend;
19use crate::error::SandboxError;
20use crate::types::{ExecRequest, Language};
21
22/// A tool that executes code in an isolated sandbox.
23///
24/// `SandboxTool` wraps a [`SandboxBackend`] and implements [`adk_core::Tool`],
25/// making sandbox execution available to LLM agents. The tool accepts
26/// `language`, `code`, optional `stdin`, and optional `timeout_secs` parameters.
27///
28/// # Error Handling
29///
30/// Backend errors are **never** propagated as `ToolError`. Instead, they are
31/// converted to structured JSON with a `"status"` field (`"timeout"`,
32/// `"memory_exceeded"`, or `"error"`). This lets the agent reason about
33/// failures without triggering exception handling.
34///
35/// # Example
36///
37/// ```rust,ignore
38/// use adk_sandbox::{SandboxTool, ProcessBackend};
39/// use std::sync::Arc;
40///
41/// let backend = Arc::new(ProcessBackend::default());
42/// let tool = SandboxTool::new(backend);
43/// assert_eq!(tool.name(), "sandbox_exec");
44/// ```
45pub struct SandboxTool {
46    backend: Arc<dyn SandboxBackend>,
47}
48
49impl SandboxTool {
50    /// Creates a new `SandboxTool` wrapping the given backend.
51    pub fn new(backend: Arc<dyn SandboxBackend>) -> Self {
52        Self { backend }
53    }
54}
55
/// Authorization scopes a caller must hold before this tool may run.
const REQUIRED_SCOPES: &[&str] = &["code:execute"];

/// Timeout, in seconds, applied when the caller omits `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
61
62/// Parses a `Language` from a JSON string value.
63fn parse_language(value: &Value) -> Result<Language, String> {
64    let s = value.as_str().ok_or_else(|| "\"language\" must be a string".to_string())?;
65    match s {
66        "rust" => Ok(Language::Rust),
67        "python" => Ok(Language::Python),
68        "javascript" => Ok(Language::JavaScript),
69        "typescript" => Ok(Language::TypeScript),
70        "wasm" => Ok(Language::Wasm),
71        "command" => Ok(Language::Command),
72        other => Err(format!(
73            "unsupported language \"{other}\". Expected one of: rust, python, javascript, typescript, wasm, command"
74        )),
75    }
76}
77
78/// Converts a [`SandboxError`] into a structured JSON value.
79///
80/// The returned JSON always contains a `"status"` field so the agent can
81/// distinguish between different failure modes.
82fn sandbox_error_to_json(err: &SandboxError) -> Value {
83    match err {
84        SandboxError::Timeout { timeout } => json!({
85            "status": "timeout",
86            "stderr": format!("execution timed out after {timeout:?}"),
87            "duration_ms": timeout.as_millis() as u64,
88        }),
89        SandboxError::MemoryExceeded { limit_mb } => json!({
90            "status": "memory_exceeded",
91            "stderr": format!("memory limit exceeded: {limit_mb} MB"),
92        }),
93        SandboxError::ExecutionFailed(msg) => json!({
94            "status": "error",
95            "stderr": msg,
96        }),
97        SandboxError::InvalidRequest(msg) => json!({
98            "status": "error",
99            "stderr": msg,
100        }),
101        SandboxError::BackendUnavailable(msg) => json!({
102            "status": "error",
103            "stderr": msg,
104        }),
105        SandboxError::EnforcerFailed { enforcer, message } => json!({
106            "status": "error",
107            "stderr": format!("enforcer '{enforcer}' failed: {message}"),
108        }),
109        SandboxError::EnforcerUnavailable { enforcer, message } => json!({
110            "status": "error",
111            "stderr": format!("enforcer '{enforcer}' unavailable: {message}"),
112        }),
113        SandboxError::PolicyViolation(msg) => json!({
114            "status": "error",
115            "stderr": format!("policy violation: {msg}"),
116        }),
117        #[cfg(feature = "workspace")]
118        SandboxError::ProvisionFailed { resource, reason, suggestion } => json!({
119            "status": "error",
120            "stderr": format!("provisioning failed for '{resource}': {reason}. {suggestion}"),
121        }),
122        #[cfg(feature = "workspace")]
123        SandboxError::SessionNotFound { handle } => json!({
124            "status": "error",
125            "stderr": format!("session '{handle}' not found. It may have been stopped or expired."),
126        }),
127        #[cfg(feature = "workspace")]
128        SandboxError::SnapshotNotFound { id } => json!({
129            "status": "error",
130            "stderr": format!("snapshot '{id}' not found. It may have been deleted or expired."),
131        }),
132        #[cfg(feature = "workspace")]
133        SandboxError::PathTraversal { path } => json!({
134            "status": "path_traversal",
135            "stderr": format!("path traversal rejected: '{path}' escapes workspace root. Use relative paths only."),
136        }),
137        #[cfg(feature = "workspace")]
138        SandboxError::DockerUnavailable { reason } => json!({
139            "status": "error",
140            "stderr": format!("Docker unavailable: {reason}. Ensure Docker daemon is running and accessible."),
141        }),
142        #[cfg(feature = "workspace")]
143        SandboxError::SessionTimeout { timeout } => json!({
144            "status": "timeout",
145            "stderr": format!("session timed out after {timeout:?}. Consider increasing session_timeout in SandboxConfig."),
146            "duration_ms": timeout.as_millis() as u64,
147        }),
148    }
149}
150
151#[async_trait]
152impl adk_core::Tool for SandboxTool {
153    fn name(&self) -> &str {
154        "sandbox_exec"
155    }
156
157    fn description(&self) -> &str {
158        "Execute code in an isolated sandbox. Supports multiple languages \
159         including rust, python, javascript, typescript, wasm, and shell commands."
160    }
161
162    fn required_scopes(&self) -> &[&str] {
163        REQUIRED_SCOPES
164    }
165
166    fn parameters_schema(&self) -> Option<Value> {
167        Some(json!({
168            "type": "object",
169            "properties": {
170                "language": {
171                    "type": "string",
172                    "enum": ["rust", "python", "javascript", "typescript", "wasm", "command"],
173                    "description": "The programming language of the code to execute."
174                },
175                "code": {
176                    "type": "string",
177                    "description": "The source code or command to execute."
178                },
179                "stdin": {
180                    "type": "string",
181                    "description": "Optional standard input to feed to the process."
182                },
183                "timeout_secs": {
184                    "type": "integer",
185                    "description": "Maximum execution time in seconds.",
186                    "default": DEFAULT_TIMEOUT_SECS,
187                    "minimum": 1,
188                    "maximum": 300
189                }
190            },
191            "required": ["language", "code"]
192        }))
193    }
194
195    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> adk_core::Result<Value> {
196        // Parse language (required)
197        let language = match args.get("language") {
198            Some(v) => match parse_language(v) {
199                Ok(lang) => lang,
200                Err(msg) => {
201                    return Ok(json!({ "status": "error", "stderr": msg }));
202                }
203            },
204            None => {
205                return Ok(
206                    json!({ "status": "error", "stderr": "missing required field \"language\"" }),
207                );
208            }
209        };
210
211        // Parse code (required)
212        let code = match args.get("code").and_then(|v| v.as_str()) {
213            Some(c) => c.to_string(),
214            None => {
215                return Ok(
216                    json!({ "status": "error", "stderr": "missing required field \"code\"" }),
217                );
218            }
219        };
220
221        // Parse stdin (optional)
222        let stdin = args.get("stdin").and_then(|v| v.as_str()).map(String::from);
223
224        // Parse timeout_secs (optional, default 30)
225        let timeout_secs =
226            args.get("timeout_secs").and_then(|v| v.as_u64()).unwrap_or(DEFAULT_TIMEOUT_SECS);
227
228        let request = ExecRequest {
229            language,
230            code,
231            stdin,
232            timeout: Duration::from_secs(timeout_secs),
233            memory_limit_mb: None,
234            env: HashMap::new(),
235        };
236
237        match self.backend.execute(request).await {
238            Ok(result) => Ok(json!({
239                "status": "success",
240                "stdout": result.stdout,
241                "stderr": result.stderr,
242                "exit_code": result.exit_code,
243                "duration_ms": result.duration.as_millis() as u64,
244            })),
245            Err(err) => Ok(sandbox_error_to_json(&err)),
246        }
247    }
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::{BackendCapabilities, EnforcedLimits};
    use crate::types::ExecResult;
    use adk_core::{CallbackContext, Content, EventActions, ReadonlyContext, Tool};
    use std::sync::Mutex;
    use std::time::Duration;

    // -- Mock backend ----------------------------------------------------------

    /// The fields of an [`ExecRequest`] that tests inspect after a call.
    #[derive(Debug, PartialEq)]
    struct CapturedRequest {
        language: Language,
        code: String,
        stdin: Option<String>,
        timeout: Duration,
    }

    /// A configurable mock backend for testing `SandboxTool`.
    ///
    /// It returns either a canned result or a canned error. It also records
    /// the most recent request, so tests can check how tool arguments were
    /// translated into an [`ExecRequest`].
    struct MockBackend {
        /// When `Some`, `execute()` returns this error.
        error: Option<SandboxError>,
        /// When `error` is `None`, `execute()` returns this result.
        result: ExecResult,
        /// The request most recently passed to `execute()`, if any.
        last_request: Mutex<Option<CapturedRequest>>,
    }

    impl MockBackend {
        fn success(stdout: &str, exit_code: i32) -> Self {
            Self {
                error: None,
                result: ExecResult {
                    stdout: stdout.to_string(),
                    stderr: String::new(),
                    exit_code,
                    duration: Duration::from_millis(42),
                },
                last_request: Mutex::new(None),
            }
        }

        fn failing(err: SandboxError) -> Self {
            Self {
                error: Some(err),
                result: ExecResult {
                    stdout: String::new(),
                    stderr: String::new(),
                    exit_code: 0,
                    duration: Duration::ZERO,
                },
                last_request: Mutex::new(None),
            }
        }

        /// Removes and returns the request recorded by the last `execute()` call.
        fn take_request(&self) -> Option<CapturedRequest> {
            self.last_request.lock().unwrap().take()
        }
    }

    #[async_trait]
    impl SandboxBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> BackendCapabilities {
            BackendCapabilities {
                supported_languages: vec![Language::Python],
                isolation_class: "mock".to_string(),
                enforced_limits: EnforcedLimits {
                    timeout: true,
                    memory: false,
                    network_isolation: false,
                    filesystem_isolation: false,
                    environment_isolation: false,
                },
            }
        }

        async fn execute(&self, request: ExecRequest) -> Result<ExecResult, SandboxError> {
            // Record what the tool sent before producing the canned outcome.
            // The guard is a statement temporary, so it is released immediately.
            *self.last_request.lock().unwrap() = Some(CapturedRequest {
                language: request.language,
                code: request.code,
                stdin: request.stdin,
                timeout: request.timeout,
            });
            match &self.error {
                Some(err) => Err(err.clone()),
                None => Ok(self.result.clone()),
            }
        }
    }

    /// Builds a tool around `backend` while the caller keeps a handle to it.
    fn tool_with(backend: &Arc<MockBackend>) -> SandboxTool {
        SandboxTool::new(backend.clone())
    }

    // -- Mock ToolContext -------------------------------------------------------

    struct MockToolContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl MockToolContext {
        fn new() -> Self {
            Self { content: Content::new("user"), actions: Mutex::new(EventActions::default()) }
        }
    }

    #[async_trait]
    impl ReadonlyContext for MockToolContext {
        fn invocation_id(&self) -> &str {
            "inv-1"
        }
        fn agent_name(&self) -> &str {
            "test-agent"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for MockToolContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for MockToolContext {
        fn function_call_id(&self) -> &str {
            "call-1"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(
            &self,
            _query: &str,
        ) -> adk_core::Result<Vec<adk_core::MemoryEntry>> {
            Ok(vec![])
        }
    }

    fn ctx() -> Arc<dyn ToolContext> {
        Arc::new(MockToolContext::new())
    }

    // -- Tests -----------------------------------------------------------------

    #[test]
    fn test_name() {
        let tool = SandboxTool::new(Arc::new(MockBackend::success("", 0)));
        assert_eq!(tool.name(), "sandbox_exec");
    }

    #[test]
    fn test_required_scopes() {
        let tool = SandboxTool::new(Arc::new(MockBackend::success("", 0)));
        assert_eq!(tool.required_scopes(), &["code:execute"]);
    }

    #[test]
    fn test_parameters_schema_is_valid() {
        let tool = SandboxTool::new(Arc::new(MockBackend::success("", 0)));
        let schema = tool.parameters_schema().expect("schema should be Some");
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["language"].is_object());
        assert!(schema["properties"]["code"].is_object());
        assert!(schema["properties"]["stdin"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());
        assert_eq!(schema["properties"]["timeout_secs"]["default"], DEFAULT_TIMEOUT_SECS);
        assert_eq!(schema["properties"]["timeout_secs"]["minimum"], 1);
        assert_eq!(schema["properties"]["timeout_secs"]["maximum"], 300);

        let required = schema["required"].as_array().unwrap();
        let required_strs: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(required_strs.contains(&"language"));
        assert!(required_strs.contains(&"code"));
        assert!(!required_strs.contains(&"stdin"));
        assert!(!required_strs.contains(&"timeout_secs"));
    }

    #[tokio::test]
    async fn test_successful_execution() {
        let backend = Arc::new(MockBackend::success("hello\n", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python", "code": "print('hello')" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "success");
        assert_eq!(result["stdout"], "hello\n");
        assert_eq!(result["exit_code"], 0);
        assert_eq!(result["duration_ms"], 42);

        // The arguments must reach the backend unchanged.
        let request = backend.take_request().expect("backend should be called");
        assert_eq!(request.language, Language::Python);
        assert_eq!(request.code, "print('hello')");
        assert_eq!(request.stdin, None);
    }

    #[tokio::test]
    async fn test_timeout_error_as_information() {
        let backend = Arc::new(MockBackend::failing(SandboxError::Timeout {
            timeout: Duration::from_secs(5),
        }));
        let tool = SandboxTool::new(backend);
        let args = json!({ "language": "python", "code": "import time; time.sleep(100)" });

        let result = tool.execute(ctx(), args).await;

        // Must be Ok, not Err
        assert!(result.is_ok());
        let val = result.unwrap();
        assert_eq!(val["status"], "timeout");
        assert!(val["stderr"].as_str().unwrap().contains("timed out"));
        assert_eq!(val["duration_ms"], 5000);
    }

    #[tokio::test]
    async fn test_memory_exceeded_error_as_information() {
        let backend = Arc::new(MockBackend::failing(SandboxError::MemoryExceeded { limit_mb: 64 }));
        let tool = SandboxTool::new(backend);
        let args = json!({ "language": "wasm", "code": "(module)" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "memory_exceeded");
        assert!(result["stderr"].as_str().unwrap().contains("64 MB"));
    }

    #[tokio::test]
    async fn test_execution_failed_error_as_information() {
        let backend = Arc::new(MockBackend::failing(SandboxError::ExecutionFailed(
            "spawn failed".to_string(),
        )));
        let tool = SandboxTool::new(backend);
        let args = json!({ "language": "python", "code": "x" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert_eq!(result["stderr"], "spawn failed");
    }

    #[tokio::test]
    async fn test_missing_language_field() {
        let backend = Arc::new(MockBackend::success("", 0));
        let tool = tool_with(&backend);
        let args = json!({ "code": "print('hi')" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("language"));
        assert!(backend.take_request().is_none(), "invalid args must not reach the backend");
    }

    #[tokio::test]
    async fn test_missing_code_field() {
        let backend = Arc::new(MockBackend::success("", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
        assert!(backend.take_request().is_none(), "invalid args must not reach the backend");
    }

    #[tokio::test]
    async fn test_unsupported_language() {
        let backend = Arc::new(MockBackend::success("", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "cobol", "code": "DISPLAY 'HI'" });

        let result = tool.execute(ctx(), args).await.unwrap();

        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("cobol"));
        assert!(backend.take_request().is_none(), "invalid args must not reach the backend");
    }

    #[tokio::test]
    async fn test_default_timeout() {
        let backend = Arc::new(MockBackend::success("ok", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python", "code": "print('ok')" });

        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "success");

        let request = backend.take_request().expect("backend should be called");
        assert_eq!(request.timeout, Duration::from_secs(DEFAULT_TIMEOUT_SECS));
    }

    #[tokio::test]
    async fn test_custom_timeout() {
        let backend = Arc::new(MockBackend::success("ok", 0));
        let tool = tool_with(&backend);
        let args = json!({ "language": "python", "code": "print('ok')", "timeout_secs": 60 });

        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "success");

        let request = backend.take_request().expect("backend should be called");
        assert_eq!(request.timeout, Duration::from_secs(60));
    }

    #[tokio::test]
    async fn test_stdin_passed_through() {
        let backend = Arc::new(MockBackend::success("echo", 0));
        let tool = tool_with(&backend);
        let args = json!({
            "language": "python",
            "code": "import sys; print(sys.stdin.read())",
            "stdin": "hello"
        });

        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "success");

        let request = backend.take_request().expect("backend should be called");
        assert_eq!(request.stdin.as_deref(), Some("hello"));
    }

    #[test]
    fn test_parse_language_all_variants() {
        assert_eq!(parse_language(&json!("rust")).unwrap(), Language::Rust);
        assert_eq!(parse_language(&json!("python")).unwrap(), Language::Python);
        assert_eq!(parse_language(&json!("javascript")).unwrap(), Language::JavaScript);
        assert_eq!(parse_language(&json!("typescript")).unwrap(), Language::TypeScript);
        assert_eq!(parse_language(&json!("wasm")).unwrap(), Language::Wasm);
        assert_eq!(parse_language(&json!("command")).unwrap(), Language::Command);
        assert!(parse_language(&json!("ruby")).is_err());
        assert!(parse_language(&json!(42)).is_err());
    }

    #[test]
    fn test_sandbox_error_to_json_variants() {
        let timeout_json =
            sandbox_error_to_json(&SandboxError::Timeout { timeout: Duration::from_secs(10) });
        assert_eq!(timeout_json["status"], "timeout");
        assert_eq!(timeout_json["duration_ms"], 10_000);

        let mem_json = sandbox_error_to_json(&SandboxError::MemoryExceeded { limit_mb: 128 });
        assert_eq!(mem_json["status"], "memory_exceeded");

        let exec_json = sandbox_error_to_json(&SandboxError::ExecutionFailed("boom".into()));
        assert_eq!(exec_json["status"], "error");
        assert_eq!(exec_json["stderr"], "boom");

        let invalid_json = sandbox_error_to_json(&SandboxError::InvalidRequest("bad".into()));
        assert_eq!(invalid_json["status"], "error");
        assert_eq!(invalid_json["stderr"], "bad");

        let unavail_json = sandbox_error_to_json(&SandboxError::BackendUnavailable("gone".into()));
        assert_eq!(unavail_json["status"], "error");
        assert_eq!(unavail_json["stderr"], "gone");
    }
}