1use std::sync::Arc;
11use std::time::Duration;
12
13use async_trait::async_trait;
14use serde_json::{Value, json};
15use tracing::{debug, instrument};
16
17use adk_core::ToolContext;
18
19use crate::error::CodeError;
20use crate::rust_executor::RustExecutor;
21
/// Execution timeout, in seconds, used when the caller omits `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;

/// Lower bound for `timeout_secs`; smaller requests are clamped up to this.
const MIN_TIMEOUT_SECS: u64 = 1;

/// Upper bound for `timeout_secs` (five minutes); larger requests are clamped down.
const MAX_TIMEOUT_SECS: u64 = 5 * 60;

/// Scopes reported by `Tool::required_scopes` for the `code_exec` tool.
const REQUIRED_SCOPES: &[&str] = &["code:execute", "code:execute:rust"];
33
34pub struct CodeTool {
59 executor: RustExecutor,
60}
61
62impl CodeTool {
63 pub fn new(executor: RustExecutor) -> Self {
65 Self { executor }
66 }
67}
68
69fn code_error_to_json(err: &CodeError) -> Value {
74 match err {
75 CodeError::CompileError { diagnostics, stderr } => {
76 let diag_json: Vec<Value> = diagnostics
77 .iter()
78 .map(|d| {
79 json!({
80 "level": d.level,
81 "message": d.message,
82 "spans": d.spans.iter().map(|s| json!({
83 "file_name": s.file_name,
84 "line_start": s.line_start,
85 "line_end": s.line_end,
86 "column_start": s.column_start,
87 "column_end": s.column_end,
88 })).collect::<Vec<_>>(),
89 "code": d.code,
90 })
91 })
92 .collect();
93 json!({
94 "status": "compile_error",
95 "diagnostics": diag_json,
96 "stderr": stderr,
97 })
98 }
99 CodeError::DependencyNotFound { name, searched } => json!({
100 "status": "error",
101 "stderr": format!("dependency not found: {name} (searched: {searched:?})"),
102 }),
103 CodeError::Sandbox(sandbox_err) => {
104 use adk_sandbox::SandboxError;
105 match sandbox_err {
106 SandboxError::Timeout { timeout } => json!({
107 "status": "timeout",
108 "stderr": format!("execution timed out after {timeout:?}"),
109 "duration_ms": timeout.as_millis() as u64,
110 }),
111 SandboxError::MemoryExceeded { limit_mb } => json!({
112 "status": "memory_exceeded",
113 "stderr": format!("memory limit exceeded: {limit_mb} MB"),
114 }),
115 SandboxError::ExecutionFailed(msg) => json!({
116 "status": "error",
117 "stderr": msg,
118 }),
119 SandboxError::InvalidRequest(msg) => json!({
120 "status": "error",
121 "stderr": msg,
122 }),
123 SandboxError::BackendUnavailable(msg) => json!({
124 "status": "error",
125 "stderr": msg,
126 }),
127 }
128 }
129 CodeError::InvalidCode(msg) => json!({
130 "status": "error",
131 "stderr": msg,
132 }),
133 }
134}
135
136#[async_trait]
137impl adk_core::Tool for CodeTool {
138 fn name(&self) -> &str {
139 "code_exec"
140 }
141
142 fn description(&self) -> &str {
143 "Execute Rust code through a check → build → execute pipeline. \
144 The code must provide a `fn run(input: serde_json::Value) -> serde_json::Value` \
145 entry point. Compile errors are returned as structured diagnostics."
146 }
147
148 fn required_scopes(&self) -> &[&str] {
149 REQUIRED_SCOPES
150 }
151
152 fn parameters_schema(&self) -> Option<Value> {
153 Some(json!({
154 "type": "object",
155 "properties": {
156 "language": {
157 "type": "string",
158 "enum": ["rust"],
159 "description": "The programming language. Currently only \"rust\" is supported.",
160 "default": "rust"
161 },
162 "code": {
163 "type": "string",
164 "description": "The Rust source code to execute. Must provide `fn run(input: serde_json::Value) -> serde_json::Value`."
165 },
166 "input": {
167 "type": "object",
168 "description": "Optional JSON input passed to the `run()` function via stdin."
169 },
170 "timeout_secs": {
171 "type": "integer",
172 "description": "Maximum execution time in seconds.",
173 "default": DEFAULT_TIMEOUT_SECS,
174 "minimum": MIN_TIMEOUT_SECS,
175 "maximum": MAX_TIMEOUT_SECS
176 }
177 },
178 "required": ["code"]
179 }))
180 }
181
182 #[instrument(skip_all, fields(tool = "code_exec"))]
183 async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> adk_core::Result<Value> {
184 let language = args.get("language").and_then(|v| v.as_str()).unwrap_or("rust");
186
187 if language != "rust" {
188 return Ok(json!({
189 "status": "error",
190 "stderr": format!(
191 "unsupported language \"{language}\". Only \"rust\" is currently supported."
192 ),
193 }));
194 }
195
196 let code = match args.get("code").and_then(|v| v.as_str()) {
198 Some(c) => c,
199 None => {
200 return Ok(json!({
201 "status": "error",
202 "stderr": "missing required field \"code\"",
203 }));
204 }
205 };
206
207 let input = args.get("input").cloned();
209
210 let timeout_secs = args
212 .get("timeout_secs")
213 .and_then(|v| v.as_u64())
214 .unwrap_or(DEFAULT_TIMEOUT_SECS)
215 .clamp(MIN_TIMEOUT_SECS, MAX_TIMEOUT_SECS);
216
217 let timeout = Duration::from_secs(timeout_secs);
218
219 debug!(language, timeout_secs, has_input = input.is_some(), "dispatching to RustExecutor");
220
221 match self.executor.execute(code, input.as_ref(), timeout).await {
222 Ok(result) => Ok(json!({
223 "status": "success",
224 "stdout": result.display_stdout,
225 "stderr": result.exec_result.stderr,
226 "exit_code": result.exec_result.exit_code,
227 "duration_ms": result.exec_result.duration.as_millis() as u64,
228 "output": result.output,
229 "diagnostics": result.diagnostics.iter().map(|d| json!({
230 "level": d.level,
231 "message": d.message,
232 "spans": d.spans.iter().map(|s| json!({
233 "file_name": s.file_name,
234 "line_start": s.line_start,
235 "line_end": s.line_end,
236 "column_start": s.column_start,
237 "column_end": s.column_end,
238 })).collect::<Vec<_>>(),
239 "code": d.code,
240 })).collect::<Vec<_>>(),
241 })),
242 Err(err) => Ok(code_error_to_json(&err)),
243 }
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diagnostics::RustDiagnostic;
    use crate::rust_executor::RustExecutorConfig;
    use adk_core::{CallbackContext, Content, EventActions, ReadonlyContext, Tool};
    use adk_sandbox::SandboxBackend;
    use adk_sandbox::backend::{BackendCapabilities, EnforcedLimits};
    use adk_sandbox::error::SandboxError;
    use adk_sandbox::types::{ExecRequest, ExecResult, Language};
    use std::sync::Mutex;
    use std::time::Duration;

    /// Sandbox backend that hands out a single canned response, then reports
    /// `ExecutionFailed("no canned response")` on every later call.
    struct MockBackend {
        response: Mutex<Option<Result<ExecResult, SandboxError>>>,
    }

    impl MockBackend {
        /// Backend whose first execution succeeds with `stdout` and exit code 0.
        fn success(stdout: &str) -> Self {
            Self {
                response: Mutex::new(Some(Ok(ExecResult {
                    stdout: stdout.to_string(),
                    stderr: String::new(),
                    exit_code: 0,
                    duration: Duration::from_millis(10),
                }))),
            }
        }
    }

    #[async_trait]
    impl SandboxBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn capabilities(&self) -> BackendCapabilities {
            BackendCapabilities {
                supported_languages: vec![Language::Command],
                isolation_class: "mock".to_string(),
                enforced_limits: EnforcedLimits {
                    timeout: true,
                    memory: false,
                    network_isolation: false,
                    filesystem_isolation: false,
                    environment_isolation: false,
                },
            }
        }

        async fn execute(&self, _request: ExecRequest) -> Result<ExecResult, SandboxError> {
            // Build the fallback error lazily; it is only needed once the canned
            // response has been consumed.
            self.response.lock().unwrap().take().unwrap_or_else(|| {
                Err(SandboxError::ExecutionFailed("no canned response".to_string()))
            })
        }
    }

    /// Minimal `ToolContext` with fixed identifiers and in-memory actions.
    struct MockToolContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl MockToolContext {
        fn new() -> Self {
            Self { content: Content::new("user"), actions: Mutex::new(EventActions::default()) }
        }
    }

    #[async_trait]
    impl ReadonlyContext for MockToolContext {
        fn invocation_id(&self) -> &str {
            "inv-1"
        }
        fn agent_name(&self) -> &str {
            "test-agent"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl CallbackContext for MockToolContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for MockToolContext {
        fn function_call_id(&self) -> &str {
            "call-1"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(
            &self,
            _query: &str,
        ) -> adk_core::Result<Vec<adk_core::MemoryEntry>> {
            Ok(vec![])
        }
    }

    fn ctx() -> Arc<dyn ToolContext> {
        Arc::new(MockToolContext::new())
    }

    /// A `CodeTool` backed by a mock sandbox whose first run prints nothing.
    fn make_tool() -> CodeTool {
        let backend = Arc::new(MockBackend::success(""));
        let executor = RustExecutor::new(backend, RustExecutorConfig::default());
        CodeTool::new(executor)
    }

    #[test]
    fn test_name() {
        let tool = make_tool();
        assert_eq!(tool.name(), "code_exec");
    }

    #[test]
    fn test_description_is_nonempty() {
        let tool = make_tool();
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn test_required_scopes() {
        let tool = make_tool();
        assert_eq!(tool.required_scopes(), &["code:execute", "code:execute:rust"]);
    }

    #[test]
    fn test_parameters_schema_is_valid() {
        let tool = make_tool();
        let schema = tool.parameters_schema().expect("schema should be Some");
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["language"].is_object());
        assert!(schema["properties"]["code"].is_object());
        assert!(schema["properties"]["input"].is_object());
        assert!(schema["properties"]["timeout_secs"].is_object());

        let required = schema["required"].as_array().unwrap();
        let required_strs: Vec<&str> = required.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(required_strs.contains(&"code"));
        assert!(!required_strs.contains(&"language"));
    }

    #[test]
    fn test_parameters_schema_advertises_bounds_and_languages() {
        let schema = make_tool().parameters_schema().expect("schema should be Some");
        let timeout = &schema["properties"]["timeout_secs"];
        assert_eq!(timeout["default"], DEFAULT_TIMEOUT_SECS);
        assert_eq!(timeout["minimum"], MIN_TIMEOUT_SECS);
        assert_eq!(timeout["maximum"], MAX_TIMEOUT_SECS);
        assert_eq!(schema["properties"]["language"]["enum"], json!(["rust"]));
    }

    #[tokio::test]
    async fn test_missing_code_field() {
        let tool = make_tool();
        let args = json!({ "language": "rust" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
    }

    #[tokio::test]
    async fn test_non_string_code_field_is_rejected() {
        let tool = make_tool();
        let result = tool.execute(ctx(), json!({ "code": 42 })).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("code"));
    }

    #[tokio::test]
    async fn test_unsupported_language() {
        let tool = make_tool();
        let args = json!({ "language": "python", "code": "print('hi')" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert_eq!(result["status"], "error");
        assert!(result["stderr"].as_str().unwrap().contains("python"));
        assert!(result["stderr"].as_str().unwrap().contains("unsupported"));
    }

    #[tokio::test]
    async fn test_missing_language_defaults_to_rust() {
        let tool = make_tool();
        let args =
            json!({ "code": "fn run(input: serde_json::Value) -> serde_json::Value { input }" });
        let result = tool.execute(ctx(), args).await.unwrap();
        assert!(result["status"].is_string(), "status must always be present: {result}");
        // What happens after the language gate depends on the executor and the
        // mock backend's canned response. This test only checks that the
        // request was not rejected as an unsupported language.
        let stderr = result["stderr"].as_str().unwrap_or("");
        assert!(!stderr.contains("unsupported language"), "unexpected rejection: {result}");
    }

    #[test]
    fn test_code_error_to_json_compile_error() {
        let err = CodeError::CompileError {
            diagnostics: vec![RustDiagnostic {
                level: "error".to_string(),
                message: "expected `;`".to_string(),
                spans: vec![],
                code: Some("E0308".to_string()),
            }],
            stderr: "error: expected `;`".to_string(),
        };
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "compile_error");
        assert!(payload["diagnostics"].is_array());
        assert_eq!(payload["diagnostics"][0]["level"], "error");
        assert_eq!(payload["diagnostics"][0]["message"], "expected `;`");
        assert_eq!(payload["diagnostics"][0]["code"], "E0308");
        assert!(payload["diagnostics"][0]["spans"].as_array().unwrap().is_empty());
        assert_eq!(payload["stderr"], "error: expected `;`");
    }

    #[test]
    fn test_code_error_to_json_dependency_not_found() {
        let err = CodeError::DependencyNotFound {
            name: "serde_json".to_string(),
            searched: vec!["config: /fake/path".to_string()],
        };
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "error");
        assert!(payload["stderr"].as_str().unwrap().contains("serde_json"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_timeout() {
        let err = CodeError::Sandbox(SandboxError::Timeout { timeout: Duration::from_secs(5) });
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "timeout");
        assert!(payload["stderr"].as_str().unwrap().contains("timed out"));
        assert_eq!(payload["duration_ms"], 5000u64);
    }

    #[test]
    fn test_code_error_to_json_invalid_code() {
        let err = CodeError::InvalidCode("missing `fn run()` entry point".to_string());
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "error");
        assert!(payload["stderr"].as_str().unwrap().contains("fn run()"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_memory() {
        let err = CodeError::Sandbox(SandboxError::MemoryExceeded { limit_mb: 128 });
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "memory_exceeded");
        assert!(payload["stderr"].as_str().unwrap().contains("128 MB"));
    }

    #[test]
    fn test_code_error_to_json_sandbox_execution_failed() {
        let err = CodeError::Sandbox(SandboxError::ExecutionFailed("boom".into()));
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "error");
        assert_eq!(payload["stderr"], "boom");
    }

    #[test]
    fn test_code_error_to_json_sandbox_invalid_request() {
        let err = CodeError::Sandbox(SandboxError::InvalidRequest("bad request".into()));
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "error");
        assert_eq!(payload["stderr"], "bad request");
    }

    #[test]
    fn test_code_error_to_json_sandbox_backend_unavailable() {
        let err = CodeError::Sandbox(SandboxError::BackendUnavailable("offline".into()));
        let payload = code_error_to_json(&err);
        assert_eq!(payload["status"], "error");
        assert_eq!(payload["stderr"], "offline");
    }
}