1use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use serde_json::{Value, json};
15use tracing::{debug, instrument};
16
17use adk_core::ToolContext;
18
19use crate::error::CodeError;
20use crate::rust_executor::RustExecutor;
21
/// Timeout, in seconds, used when the caller does not supply `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Smallest accepted `timeout_secs`; smaller requests are clamped up to this.
const MIN_TIMEOUT_SECS: u64 = 1;

/// Largest accepted `timeout_secs`; larger requests are clamped down to this.
const MAX_TIMEOUT_SECS: u64 = 300;

// `u64::clamp(MIN, MAX)` panics when `MIN > MAX`, and the parameter schema
// advertises `DEFAULT` alongside `minimum`/`maximum`, so it must lie in range.
// Checking at compile time keeps a future edit to these values from turning
// into a runtime panic or a self-contradictory schema.
const _: () = assert!(
    MIN_TIMEOUT_SECS <= DEFAULT_TIMEOUT_SECS && DEFAULT_TIMEOUT_SECS <= MAX_TIMEOUT_SECS
);

/// Scopes reported by `Tool::required_scopes` for the `code_exec` tool.
const REQUIRED_SCOPES: &[&str] = &["code:execute", "code:execute:rust"];
33
34pub struct CodeTool {
59 executor: RustExecutor,
60}
61
62impl CodeTool {
63 pub fn new(executor: RustExecutor) -> Self {
65 Self { executor }
66 }
67}
68
69fn code_error_to_json(err: &CodeError) -> Value {
74 match err {
75 CodeError::CompileError { diagnostics, stderr } => {
76 let diag_json: Vec<Value> = diagnostics
77 .iter()
78 .map(|d| {
79 json!({
80 "level": d.level,
81 "message": d.message,
82 "spans": d.spans.iter().map(|s| json!({
83 "file_name": s.file_name,
84 "line_start": s.line_start,
85 "line_end": s.line_end,
86 "column_start": s.column_start,
87 "column_end": s.column_end,
88 })).collect::<Vec<_>>(),
89 "code": d.code,
90 })
91 })
92 .collect();
93 json!({
94 "status": "compile_error",
95 "diagnostics": diag_json,
96 "stderr": stderr,
97 })
98 }
99 CodeError::DependencyNotFound { name, searched } => json!({
100 "status": "error",
101 "stderr": format!("dependency not found: {name} (searched: {searched:?})"),
102 }),
103 CodeError::Sandbox(sandbox_err) => {
104 use adk_sandbox::SandboxError;
105 match sandbox_err {
106 SandboxError::Timeout { timeout } => json!({
107 "status": "timeout",
108 "stderr": format!("execution timed out after {timeout:?}"),
109 "duration_ms": timeout.as_millis() as u64,
110 }),
111 SandboxError::MemoryExceeded { limit_mb } => json!({
112 "status": "memory_exceeded",
113 "stderr": format!("memory limit exceeded: {limit_mb} MB"),
114 }),
115 SandboxError::ExecutionFailed(msg) => json!({
116 "status": "error",
117 "stderr": msg,
118 }),
119 SandboxError::InvalidRequest(msg) => json!({
120 "status": "error",
121 "stderr": msg,
122 }),
123 SandboxError::BackendUnavailable(msg) => json!({
124 "status": "error",
125 "stderr": msg,
126 }),
127 SandboxError::EnforcerFailed { enforcer, message } => json!({
128 "status": "error",
129 "stderr": format!("sandbox enforcer '{enforcer}' failed: {message}"),
130 }),
131 SandboxError::EnforcerUnavailable { enforcer, message } => json!({
132 "status": "error",
133 "stderr": format!("sandbox enforcer '{enforcer}' unavailable: {message}"),
134 }),
135 SandboxError::PolicyViolation(msg) => json!({
136 "status": "error",
137 "stderr": msg,
138 }),
139 }
140 }
141 CodeError::InvalidCode(msg) => json!({
142 "status": "error",
143 "stderr": msg,
144 }),
145 }
146}
147
148#[async_trait]
149impl adk_core::Tool for CodeTool {
150 fn name(&self) -> &str {
151 "code_exec"
152 }
153
154 fn description(&self) -> &str {
155 "Execute Rust code through a check → build → execute pipeline. \
156 The code must provide a `fn run(input: serde_json::Value) -> serde_json::Value` \
157 entry point. Compile errors are returned as structured diagnostics."
158 }
159
160 fn required_scopes(&self) -> &[&str] {
161 REQUIRED_SCOPES
162 }
163
164 fn parameters_schema(&self) -> Option<Value> {
165 Some(json!({
166 "type": "object",
167 "properties": {
168 "language": {
169 "type": "string",
170 "enum": ["rust"],
171 "description": "The programming language. Currently only \"rust\" is supported.",
172 "default": "rust"
173 },
174 "code": {
175 "type": "string",
176 "description": "The Rust source code to execute. Must provide `fn run(input: serde_json::Value) -> serde_json::Value`."
177 },
178 "input": {
179 "type": "object",
180 "description": "Optional JSON input passed to the `run()` function via stdin."
181 },
182 "timeout_secs": {
183 "type": "integer",
184 "description": "Maximum execution time in seconds.",
185 "default": DEFAULT_TIMEOUT_SECS,
186 "minimum": MIN_TIMEOUT_SECS,
187 "maximum": MAX_TIMEOUT_SECS
188 }
189 },
190 "required": ["code"]
191 }))
192 }
193
194 #[instrument(skip_all, fields(tool = "code_exec"))]
195 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> adk_core::Result<Value> {
196 let language = args.get("language").and_then(|v| v.as_str()).unwrap_or("rust");
198
199 if language != "rust" {
200 return Ok(json!({
201 "status": "error",
202 "stderr": format!(
203 "unsupported language \"{language}\". Only \"rust\" is currently supported."
204 ),
205 }));
206 }
207
208 let code = match args.get("code").and_then(|v| v.as_str()) {
210 Some(c) => c,
211 None => {
212 return Ok(json!({
213 "status": "error",
214 "stderr": "missing required field \"code\"",
215 }));
216 }
217 };
218
219 let input = args.get("input").cloned();
221
222 let timeout_secs = args
224 .get("timeout_secs")
225 .and_then(|v| v.as_u64())
226 .unwrap_or(DEFAULT_TIMEOUT_SECS)
227 .clamp(MIN_TIMEOUT_SECS, MAX_TIMEOUT_SECS);
228
229 let timeout = Duration::from_secs(timeout_secs);
230
231 debug!(language, timeout_secs, has_input = input.is_some(), "dispatching to RustExecutor");
232
233 match self.executor.execute(code, input.as_ref(), timeout).await {
234 Ok(result) => Ok(json!({
235 "status": "success",
236 "stdout": result.display_stdout,
237 "stderr": result.exec_result.stderr,
238 "exit_code": result.exec_result.exit_code,
239 "duration_ms": result.exec_result.duration.as_millis() as u64,
240 "output": result.output,
241 "diagnostics": result.diagnostics.iter().map(|d| json!({
242 "level": d.level,
243 "message": d.message,
244 "spans": d.spans.iter().map(|s| json!({
245 "file_name": s.file_name,
246 "line_start": s.line_start,
247 "line_end": s.line_end,
248 "column_start": s.column_start,
249 "column_end": s.column_end,
250 })).collect::<Vec<_>>(),
251 "code": d.code,
252 })).collect::<Vec<_>>(),
253 })),
254 Err(err) => Ok(code_error_to_json(&err)),
255 }
256 }
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diagnostics::RustDiagnostic;
    use crate::rust_executor::RustExecutorConfig;
    use adk_core::{CallbackContext, Content, EventActions, ReadonlyContext, Tool};
    use adk_sandbox::SandboxBackend;
    use adk_sandbox::backend::{BackendCapabilities, EnforcedLimits};
    use adk_sandbox::error::SandboxError;
    use adk_sandbox::types::{ExecRequest, ExecResult, Language};
    use std::sync::Mutex;
    use std::time::Duration;

    /// Sandbox backend that returns one canned response. Every later call fails
    /// with `ExecutionFailed("no canned response")`.
    struct MockBackend {
        response: Mutex<Option<Result<ExecResult, SandboxError>>>,
    }

    impl MockBackend {
        /// Backend whose first `execute` succeeds with `stdout` and exit code 0.
        fn success(stdout: &str) -> Self {
            Self {
                response: Mutex::new(Some(Ok(ExecResult {
                    stdout: stdout.to_string(),
                    stderr: String::new(),
                    exit_code: 0,
                    duration: Duration::from_millis(10),
                }))),
            }
        }
    }

    #[async_trait]
    impl SandboxBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> BackendCapabilities {
            BackendCapabilities {
                supported_languages: vec![Language::Command],
                isolation_class: "mock".to_string(),
                enforced_limits: EnforcedLimits {
                    timeout: true,
                    memory: false,
                    network_isolation: false,
                    filesystem_isolation: false,
                    environment_isolation: false,
                },
            }
        }

        async fn execute(&self, _request: ExecRequest) -> Result<ExecResult, SandboxError> {
            self.response
                .lock()
                .unwrap()
                .take()
                .unwrap_or(Err(SandboxError::ExecutionFailed("no canned response".to_string())))
        }
    }

    /// Minimal `ToolContext` with fixed identifiers and in-memory actions.
    struct MockToolContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl MockToolContext {
        fn new() -> Self {
            Self { content: Content::new("user"), actions: Mutex::new(EventActions::default()) }
        }
    }

    #[async_trait]
    impl ReadonlyContext for MockToolContext {
        fn invocation_id(&self) -> &str {
            "inv-1"
        }
        fn agent_name(&self) -> &str {
            "test-agent"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for MockToolContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for MockToolContext {
        fn function_call_id(&self) -> &str {
            "call-1"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(
            &self,
            _query: &str,
        ) -> adk_core::Result<Vec<adk_core::MemoryEntry>> {
            Ok(vec![])
        }
    }

    fn ctx() -> Arc<dyn ToolContext> {
        Arc::new(MockToolContext::new())
    }

    fn make_tool() -> CodeTool {
        let backend = Arc::new(MockBackend::success(""));
        let executor = RustExecutor::new(backend, RustExecutorConfig::default());
        CodeTool::new(executor)
    }

    #[test]
    fn test_name() {
        let tool = make_tool();
        assert_eq!(tool.name(), "code_exec");
    }

    #[test]
    fn test_description_is_nonempty() {
        let tool = make_tool();
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn test_required_scopes() {
        let tool = make_tool();
        assert_eq!(tool.required_scopes(), &["code:execute", "code:execute:rust"]);
    }

    #[test]
    fn test_parameters_schema_is_valid() {
        let tool = make_tool();
        let schema = tool.parameters_schema().expect("schema should be Some");
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["language"].is_object());
        assert!(schema["properties"]["code"].is_object());
        assert!(schema["properties"]["input"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());

        // The advertised bounds must match the clamping constants used at runtime.
        let timeout = &schema["properties"]["timeout_secs"];
        assert_eq!(timeout["default"], DEFAULT_TIMEOUT_SECS);
        assert_eq!(timeout["minimum"], MIN_TIMEOUT_SECS);
        assert_eq!(timeout["maximum"], MAX_TIMEOUT_SECS);

        let required = schema["required"].as_array().unwrap();
        let required_strs: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(required_strs.contains(&"code"));
        assert!(!required_strs.contains(&"language"));
    }

    #[tokio::test]
    async fn test_missing_code_field() {
        let tool = make_tool();
        let args = json!({ "language": "rust" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
    }

    #[tokio::test]
    async fn test_non_string_code_field() {
        let tool = make_tool();
        let args = json!({ "code": 42 });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
    }

    #[tokio::test]
    async fn test_unsupported_language() {
        let tool = make_tool();
        let args = json!({ "language": "python", "code": "print('hi')" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("python"));
        assert!(result["stderr"].as_str().unwrap().contains("unsupported"));
    }

    #[tokio::test]
    async fn test_missing_language_defaults_to_rust() {
        let tool = make_tool();
        // Whether the executor manages to build this snippet depends on the test
        // environment, so only assert that the language gate did not reject it.
        let args =
            json!({ "code": "fn run(input: serde_json::Value) -> serde_json::Value { input }" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert!(result["status"].is_string(), "every response carries a status");
        let stderr = result["stderr"].as_str().unwrap_or("");
        assert!(
            !stderr.contains("unsupported language"),
            "missing language should default to rust, got stderr: {stderr}"
        );
    }

    #[test]
    fn test_code_error_to_json_compile_error() {
        let err = CodeError::CompileError {
            diagnostics: vec![RustDiagnostic {
                level: "error".to_string(),
                message: "expected `;`".to_string(),
                spans: vec![],
                code: Some("E0308".to_string()),
            }],
            stderr: "error: expected `;`".to_string(),
        };
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "compile_error");
        assert!(json["diagnostics"].is_array());
        assert_eq!(json["diagnostics"][0]["level"], "error");
        assert_eq!(json["diagnostics"][0]["message"], "expected `;`");
        assert_eq!(json["diagnostics"][0]["code"], "E0308");
        assert!(json["diagnostics"][0]["spans"].as_array().unwrap().is_empty());
        assert_eq!(json["stderr"], "error: expected `;`");
    }

    #[test]
    fn test_code_error_to_json_dependency_not_found() {
        let err = CodeError::DependencyNotFound {
            name: "serde_json".to_string(),
            searched: vec!["config: /fake/path".to_string()],
        };
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "error");
        let stderr = json["stderr"].as_str().unwrap();
        assert!(stderr.contains("serde_json"));
        assert!(stderr.contains("/fake/path"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_timeout() {
        let err = CodeError::Sandbox(SandboxError::Timeout { timeout: Duration::from_secs(5) });
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "timeout");
        assert!(json["stderr"].as_str().unwrap().contains("timed out"));
        assert_eq!(json["duration_ms"], 5000);
    }

    #[test]
    fn test_code_error_to_json_invalid_code() {
        let err = CodeError::InvalidCode("missing `fn run()` entry point".to_string());
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "error");
        assert!(json["stderr"].as_str().unwrap().contains("fn run()"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_memory() {
        let err = CodeError::Sandbox(SandboxError::MemoryExceeded { limit_mb: 128 });
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "memory_exceeded");
        assert!(json["stderr"].as_str().unwrap().contains("128"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_execution_failed() {
        let err = CodeError::Sandbox(SandboxError::ExecutionFailed("boom".into()));
        let json = code_error_to_json(&err);
        assert_eq!(json["status"], "error");
        assert_eq!(json["stderr"], "boom");
    }
}