Skip to main content

adk_code/
code_tool.rs

//! [`CodeTool`] — an [`adk_core::Tool`] implementation that dispatches
//! to language-specific executors.
//!
//! Currently supports Rust via [`RustExecutor`]. Requests for any other
//! language get a structured `"status": "error"` response naming the
//! unsupported language. Errors from the executor are converted to structured
//! JSON responses (never propagated as `ToolError`), following the same
//! error-as-information pattern as [`SandboxTool`](adk_sandbox::SandboxTool).

10use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use serde_json::{Value, json};
15use tracing::{debug, instrument};
16
17use adk_core::ToolContext;
18
19use crate::error::CodeError;
20use crate::rust_executor::RustExecutor;
21
/// Execution timeout, in seconds, used when the caller omits `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Smallest timeout, in seconds, a request may ask for; lower values are raised to this.
const MIN_TIMEOUT_SECS: u64 = 1;

/// Largest timeout, in seconds (five minutes), a request may ask for; higher values are
/// capped to this.
const MAX_TIMEOUT_SECS: u64 = 5 * 60;

/// Scopes declared through `Tool::required_scopes`: the generic code-execution
/// scope followed by the Rust-specific one.
const REQUIRED_SCOPES: &[&str] = &["code:execute", "code:execute:rust"];
33
34/// A tool that executes code through language-specific pipelines.
35///
36/// `CodeTool` wraps a [`RustExecutor`] and implements [`adk_core::Tool`],
37/// making code execution with compiler diagnostics available to LLM agents.
38/// Phase 1 supports Rust only; other languages return a descriptive error.
39///
40/// # Error Handling
41///
42/// Executor errors are **never** propagated as `ToolError`. Instead, they are
43/// converted to structured JSON with a `"status"` field. Compile errors include
44/// a `"diagnostics"` array with structured compiler output.
45///
46/// # Example
47///
48/// ```rust,ignore
49/// use adk_code::{CodeTool, RustExecutor, RustExecutorConfig};
50/// use adk_sandbox::ProcessBackend;
51/// use std::sync::Arc;
52///
53/// let backend = Arc::new(ProcessBackend::default());
54/// let executor = RustExecutor::new(backend, RustExecutorConfig::default());
55/// let tool = CodeTool::new(executor);
56/// assert_eq!(tool.name(), "code_exec");
57/// ```
58pub struct CodeTool {
59    executor: RustExecutor,
60}
61
62impl CodeTool {
63    /// Creates a new `CodeTool` wrapping the given Rust executor.
64    pub fn new(executor: RustExecutor) -> Self {
65        Self { executor }
66    }
67}
68
69/// Converts a [`CodeError`] into a structured JSON value.
70///
71/// The returned JSON always contains a `"status"` field so the agent can
72/// distinguish between different failure modes.
73fn code_error_to_json(err: &CodeError) -> Value {
74    match err {
75        CodeError::CompileError { diagnostics, stderr } => {
76            let diag_json: Vec<Value> = diagnostics
77                .iter()
78                .map(|d| {
79                    json!({
80                        "level": d.level,
81                        "message": d.message,
82                        "spans": d.spans.iter().map(|s| json!({
83                            "file_name": s.file_name,
84                            "line_start": s.line_start,
85                            "line_end": s.line_end,
86                            "column_start": s.column_start,
87                            "column_end": s.column_end,
88                        })).collect::<Vec<_>>(),
89                        "code": d.code,
90                    })
91                })
92                .collect();
93            json!({
94                "status": "compile_error",
95                "diagnostics": diag_json,
96                "stderr": stderr,
97            })
98        }
99        CodeError::DependencyNotFound { name, searched } => json!({
100            "status": "error",
101            "stderr": format!("dependency not found: {name} (searched: {searched:?})"),
102        }),
103        CodeError::Sandbox(sandbox_err) => {
104            use adk_sandbox::SandboxError;
105            match sandbox_err {
106                SandboxError::Timeout { timeout } => json!({
107                    "status": "timeout",
108                    "stderr": format!("execution timed out after {timeout:?}"),
109                    "duration_ms": timeout.as_millis() as u64,
110                }),
111                SandboxError::MemoryExceeded { limit_mb } => json!({
112                    "status": "memory_exceeded",
113                    "stderr": format!("memory limit exceeded: {limit_mb} MB"),
114                }),
115                SandboxError::ExecutionFailed(msg) => json!({
116                    "status": "error",
117                    "stderr": msg,
118                }),
119                SandboxError::InvalidRequest(msg) => json!({
120                    "status": "error",
121                    "stderr": msg,
122                }),
123                SandboxError::BackendUnavailable(msg) => json!({
124                    "status": "error",
125                    "stderr": msg,
126                }),
127                SandboxError::EnforcerFailed { enforcer, message } => json!({
128                    "status": "error",
129                    "stderr": format!("sandbox enforcer '{enforcer}' failed: {message}"),
130                }),
131                SandboxError::EnforcerUnavailable { enforcer, message } => json!({
132                    "status": "error",
133                    "stderr": format!("sandbox enforcer '{enforcer}' unavailable: {message}"),
134                }),
135                SandboxError::PolicyViolation(msg) => json!({
136                    "status": "error",
137                    "stderr": msg,
138                }),
139            }
140        }
141        CodeError::InvalidCode(msg) => json!({
142            "status": "error",
143            "stderr": msg,
144        }),
145    }
146}
147
148#[async_trait]
149impl adk_core::Tool for CodeTool {
150    fn name(&self) -> &str {
151        "code_exec"
152    }
153
154    fn description(&self) -> &str {
155        "Execute Rust code through a check → build → execute pipeline. \
156         The code must provide a `fn run(input: serde_json::Value) -> serde_json::Value` \
157         entry point. Compile errors are returned as structured diagnostics."
158    }
159
160    fn required_scopes(&self) -> &[&str] {
161        REQUIRED_SCOPES
162    }
163
164    fn parameters_schema(&self) -> Option<Value> {
165        Some(json!({
166            "type": "object",
167            "properties": {
168                "language": {
169                    "type": "string",
170                    "enum": ["rust"],
171                    "description": "The programming language. Currently only \"rust\" is supported.",
172                    "default": "rust"
173                },
174                "code": {
175                    "type": "string",
176                    "description": "The Rust source code to execute. Must provide `fn run(input: serde_json::Value) -> serde_json::Value`."
177                },
178                "input": {
179                    "type": "object",
180                    "description": "Optional JSON input passed to the `run()` function via stdin."
181                },
182                "timeout_secs": {
183                    "type": "integer",
184                    "description": "Maximum execution time in seconds.",
185                    "default": DEFAULT_TIMEOUT_SECS,
186                    "minimum": MIN_TIMEOUT_SECS,
187                    "maximum": MAX_TIMEOUT_SECS
188                }
189            },
190            "required": ["code"]
191        }))
192    }
193
194    #[instrument(skip_all, fields(tool = "code_exec"))]
195    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> adk_core::Result<Value> {
196        // Parse language (optional, defaults to "rust").
197        let language = args.get("language").and_then(|v| v.as_str()).unwrap_or("rust");
198
199        if language != "rust" {
200            return Ok(json!({
201                "status": "error",
202                "stderr": format!(
203                    "unsupported language \"{language}\". Only \"rust\" is currently supported."
204                ),
205            }));
206        }
207
208        // Parse code (required).
209        let code = match args.get("code").and_then(|v| v.as_str()) {
210            Some(c) => c,
211            None => {
212                return Ok(json!({
213                    "status": "error",
214                    "stderr": "missing required field \"code\"",
215                }));
216            }
217        };
218
219        // Parse input (optional JSON value).
220        let input = args.get("input").cloned();
221
222        // Parse timeout_secs (optional, default 30).
223        let timeout_secs = args
224            .get("timeout_secs")
225            .and_then(|v| v.as_u64())
226            .unwrap_or(DEFAULT_TIMEOUT_SECS)
227            .clamp(MIN_TIMEOUT_SECS, MAX_TIMEOUT_SECS);
228
229        let timeout = Duration::from_secs(timeout_secs);
230
231        debug!(language, timeout_secs, has_input = input.is_some(), "dispatching to RustExecutor");
232
233        match self.executor.execute(code, input.as_ref(), timeout).await {
234            Ok(result) => Ok(json!({
235                "status": "success",
236                "stdout": result.display_stdout,
237                "stderr": result.exec_result.stderr,
238                "exit_code": result.exec_result.exit_code,
239                "duration_ms": result.exec_result.duration.as_millis() as u64,
240                "output": result.output,
241                "diagnostics": result.diagnostics.iter().map(|d| json!({
242                    "level": d.level,
243                    "message": d.message,
244                    "spans": d.spans.iter().map(|s| json!({
245                        "file_name": s.file_name,
246                        "line_start": s.line_start,
247                        "line_end": s.line_end,
248                        "column_start": s.column_start,
249                        "column_end": s.column_end,
250                    })).collect::<Vec<_>>(),
251                    "code": d.code,
252                })).collect::<Vec<_>>(),
253            })),
254            Err(err) => Ok(code_error_to_json(&err)),
255        }
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diagnostics::RustDiagnostic;
    use crate::rust_executor::RustExecutorConfig;
    use adk_core::{CallbackContext, Content, EventActions, ReadonlyContext, Tool};
    use adk_sandbox::SandboxBackend;
    use adk_sandbox::backend::{BackendCapabilities, EnforcedLimits};
    use adk_sandbox::error::SandboxError;
    use adk_sandbox::types::{ExecRequest, ExecResult, Language};
    use std::sync::Mutex;
    use std::time::Duration;

    // -- Mock backend ----------------------------------------------------------

    /// Backend that returns one canned response, then fails every later call.
    struct MockBackend {
        response: Mutex<Option<Result<ExecResult, SandboxError>>>,
    }

    impl MockBackend {
        fn success(stdout: &str) -> Self {
            Self {
                response: Mutex::new(Some(Ok(ExecResult {
                    stdout: stdout.to_string(),
                    stderr: String::new(),
                    exit_code: 0,
                    duration: Duration::from_millis(10),
                }))),
            }
        }
    }

    #[async_trait]
    impl SandboxBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> BackendCapabilities {
            BackendCapabilities {
                supported_languages: vec![Language::Command],
                isolation_class: "mock".to_string(),
                enforced_limits: EnforcedLimits {
                    timeout: true,
                    memory: false,
                    network_isolation: false,
                    filesystem_isolation: false,
                    environment_isolation: false,
                },
            }
        }

        async fn execute(&self, _request: ExecRequest) -> Result<ExecResult, SandboxError> {
            self.response
                .lock()
                .unwrap()
                .take()
                .unwrap_or(Err(SandboxError::ExecutionFailed("no canned response".to_string())))
        }
    }

    // -- Mock ToolContext -------------------------------------------------------

    struct MockToolContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl MockToolContext {
        fn new() -> Self {
            Self { content: Content::new("user"), actions: Mutex::new(EventActions::default()) }
        }
    }

    #[async_trait]
    impl ReadonlyContext for MockToolContext {
        fn invocation_id(&self) -> &str {
            "inv-1"
        }
        fn agent_name(&self) -> &str {
            "test-agent"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for MockToolContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for MockToolContext {
        fn function_call_id(&self) -> &str {
            "call-1"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(
            &self,
            _query: &str,
        ) -> adk_core::Result<Vec<adk_core::MemoryEntry>> {
            Ok(vec![])
        }
    }

    fn ctx() -> Arc<dyn ToolContext> {
        Arc::new(MockToolContext::new())
    }

    fn make_tool() -> CodeTool {
        let backend = Arc::new(MockBackend::success(""));
        let executor = RustExecutor::new(backend, RustExecutorConfig::default());
        CodeTool::new(executor)
    }

    // -- Tests -----------------------------------------------------------------

    #[test]
    fn test_name() {
        let tool = make_tool();
        assert_eq!(tool.name(), "code_exec");
    }

    #[test]
    fn test_description_is_nonempty() {
        let tool = make_tool();
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn test_required_scopes() {
        let tool = make_tool();
        assert_eq!(tool.required_scopes(), &["code:execute", "code:execute:rust"]);
    }

    #[test]
    fn test_parameters_schema_is_valid() {
        let tool = make_tool();
        let schema = tool.parameters_schema().expect("schema should be Some");
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["language"].is_object());
        assert!(schema["properties"]["code"].is_object());
        assert!(schema["properties"]["input"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());

        let required = schema["required"].as_array().unwrap();
        let required_strs: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(required_strs.contains(&"code"));
        // language is optional (defaults to "rust")
        assert!(!required_strs.contains(&"language"));
    }

    #[test]
    fn test_parameters_schema_timeout_bounds_match_constants() {
        let tool = make_tool();
        let schema = tool.parameters_schema().expect("schema should be Some");
        let timeout = &schema["properties"]["timeout_secs"];
        assert_eq!(timeout["default"], DEFAULT_TIMEOUT_SECS);
        assert_eq!(timeout["minimum"], MIN_TIMEOUT_SECS);
        assert_eq!(timeout["maximum"], MAX_TIMEOUT_SECS);
    }

    #[tokio::test]
    async fn test_missing_code_field() {
        let tool = make_tool();
        let args = json!({ "language": "rust" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
    }

    #[tokio::test]
    async fn test_non_string_code_is_rejected() {
        let tool = make_tool();
        let args = json!({ "language": "rust", "code": 42 });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
    }

    #[tokio::test]
    async fn test_unsupported_language() {
        let tool = make_tool();
        let args = json!({ "language": "python", "code": "print('hi')" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("python"));
        assert!(result["stderr"].as_str().unwrap().contains("unsupported"));
    }

    #[tokio::test]
    async fn test_unsupported_language_is_reported_before_missing_code() {
        let tool = make_tool();
        let args = json!({ "language": "python" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("unsupported"));
    }

    #[tokio::test]
    async fn test_missing_language_defaults_to_rust() {
        // Without a `language` field the tool must take the Rust path. The mock
        // backend doesn't really compile anything, so the outcome may be success
        // or a compile/dependency error, but never the unsupported-language error.
        let tool = make_tool();
        let args =
            json!({ "code": "fn run(input: serde_json::Value) -> serde_json::Value { input }" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert!(result["status"].is_string(), "every response carries a string status");
        let stderr = result["stderr"].as_str().unwrap_or("");
        assert!(!stderr.contains("unsupported language"), "unexpected stderr: {stderr}");
    }

    #[test]
    fn test_code_error_to_json_compile_error() {
        let err = CodeError::CompileError {
            diagnostics: vec![RustDiagnostic {
                level: "error".to_string(),
                message: "expected `;`".to_string(),
                spans: vec![],
                code: Some("E0308".to_string()),
            }],
            stderr: "error: expected `;`".to_string(),
        };
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "compile_error");
        assert!(json["diagnostics"].is_array());
        assert_eq!(json["diagnostics"][0]["level"], "error");
        assert_eq!(json["diagnostics"][0]["message"], "expected `;`");
        assert_eq!(json["diagnostics"][0]["code"], "E0308");
        assert_eq!(json["stderr"], "error: expected `;`");
    }

    #[test]
    fn test_code_error_to_json_dependency_not_found() {
        let err = CodeError::DependencyNotFound {
            name: "serde_json".to_string(),
            searched: vec!["config: /fake/path".to_string()],
        };
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "error");
        assert!(json["stderr"].as_str().unwrap().contains("serde_json"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_timeout() {
        let err = CodeError::Sandbox(SandboxError::Timeout { timeout: Duration::from_secs(5) });
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "timeout");
        assert!(json["stderr"].as_str().unwrap().contains("timed out"));
        assert_eq!(json["duration_ms"], 5000u64);
    }

    #[test]
    fn test_code_error_to_json_invalid_code() {
        let err = CodeError::InvalidCode("missing `fn run()` entry point".to_string());
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "error");
        assert!(json["stderr"].as_str().unwrap().contains("fn run()"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_memory() {
        let err = CodeError::Sandbox(SandboxError::MemoryExceeded { limit_mb: 128 });
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "memory_exceeded");
        assert!(json["stderr"].as_str().unwrap().contains("128"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_execution_failed() {
        let err = CodeError::Sandbox(SandboxError::ExecutionFailed("boom".into()));
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "error");
        assert_eq!(json["stderr"], "boom");
    }

    #[test]
    fn test_code_error_to_json_sandbox_message_variants_pass_message_through() {
        let cases = [
            (SandboxError::InvalidRequest("bad request".into()), "bad request"),
            (SandboxError::BackendUnavailable("no backend".into()), "no backend"),
            (SandboxError::PolicyViolation("network denied".into()), "network denied"),
        ];
        for (sandbox_err, expected) in cases {
            let json = code_error_to_json(&CodeError::Sandbox(sandbox_err));
            assert_eq!(json["status"], "error");
            assert_eq!(json["stderr"], expected);
        }
    }
}