Skip to main content

adk_code/
harness.rs

1//! Harness template and source validation for Rust code execution.
2//!
3//! This module contains the harness template that wraps user code, source
4//! validation functions, and output extraction utilities. These are shared
5//! between `RustSandboxExecutor` (legacy) and `RustExecutor` (new).
6
7use crate::ExecutionError;
8
/// Patterns that are rejected in user code because they conflict with the
/// harness or fall outside the phase 1 source model.
///
/// Each entry is `(pattern, human-readable reason)`. [`validate_rust_source`]
/// searches for each pattern after removing comments with [`strip_comments`];
/// string literals are *not* removed. Stripping string literals as well would
/// be more robust, but for phase 1 this lightweight scan is sufficient: `fn main`
/// inside a string literal is an unlikely false positive, and the compile step
/// would catch any real conflict that slipped through anyway.
pub const REJECTED_PATTERNS: &[(&str, &str)] = &[
    ("fn main", "user code must not define `fn main()` — the harness provides it"),
    ("#![", "crate-level attributes (`#![...]`) are not supported in the harness body"),
];
21
22/// The harness template that wraps user code.
23///
24/// The user provides `fn run(input: serde_json::Value) -> serde_json::Value`.
25/// The harness reads JSON from stdin, calls `run()`, and prints JSON to stdout.
26///
27/// ## Available to User Code
28///
29/// - `serde_json::Value` (imported at top level)
30/// - All public items from `serde_json` (e.g., `serde_json::json!`, `serde_json::Map`)
31/// - The full Rust standard library
32///
33/// ## Not Available
34///
35/// - External crates other than `serde_json`
36/// - `fn main()` (provided by the harness)
37/// - Crate-level attributes (`#![...]`)
38/// - Multi-file modules
39pub const HARNESS_TEMPLATE: &str = r#"use serde_json::Value;
40
41{user_code}
42
43fn main() {
44    let input: Value = serde_json::from_reader(std::io::stdin()).unwrap_or(Value::Null);
45    let output = run(input);
46    println!("{}", serde_json::to_string(&output).unwrap());
47}
48"#;
49
50/// Validate that user source code fits the phase 1 bounded source model.
51///
52/// The phase 1 model requires self-contained snippets that provide
53/// `fn run(input: Value) -> Value`. The harness supplies `fn main()`,
54/// `use serde_json::Value;`, and links `serde_json`. User code must not
55/// redefine `main` or use crate-level attributes.
56///
57/// Returns `Ok(())` if the code passes validation, or
58/// `Err(ExecutionError::InvalidRequest(...))` with a descriptive message.
59///
60/// # Example
61///
62/// ```rust
63/// use adk_code::validate_rust_source;
64///
65/// // Valid: provides the run() contract
66/// assert!(validate_rust_source(r#"
67///     fn run(input: serde_json::Value) -> serde_json::Value {
68///         input
69///     }
70/// "#).is_ok());
71///
72/// // Invalid: defines fn main()
73/// assert!(validate_rust_source(r#"
74///     fn main() { println!("hello"); }
75/// "#).is_err());
76/// ```
77pub fn validate_rust_source(code: &str) -> Result<(), ExecutionError> {
78    // Strip single-line comments and block comments to reduce false positives.
79    let stripped = strip_comments(code);
80
81    for &(pattern, reason) in REJECTED_PATTERNS {
82        if stripped.contains(pattern) {
83            return Err(ExecutionError::InvalidRequest(reason.to_string()));
84        }
85    }
86
87    Ok(())
88}
89
/// Remove `//` line comments and `/* */` block comments from Rust source.
///
/// - A line comment is dropped from the `//` up to the terminating newline.
///   The newline itself is kept. A line comment on the final line with no
///   newline is simply dropped.
/// - A block comment, including any nested `/* */` pairs inside it, is
///   replaced by a single space so the tokens on either side stay separated.
///   An unterminated block comment swallows the rest of the input.
///
/// This is a best-effort heuristic for phase 1 validation. String and
/// character literals are not recognized, so a comment-like sequence inside
/// one (e.g. `"http://example.com"`) is treated as a real comment. That is
/// acceptable for the patterns we check: `fn main` inside a string literal is
/// unlikely, and a real conflict would be caught at compile time anyway.
pub fn strip_comments(code: &str) -> String {
    let mut out = String::with_capacity(code.len());
    let mut rest = code.chars().peekable();

    while let Some(ch) = rest.next() {
        if ch != '/' {
            out.push(ch);
            continue;
        }

        match rest.peek().copied() {
            Some('/') => {
                rest.next();
                // `any` stops right after consuming the first newline (if any).
                if rest.any(|c| c == '\n') {
                    out.push('\n');
                }
            }
            Some('*') => {
                rest.next();
                let mut open: u32 = 1;
                while open > 0 {
                    let cur = match rest.next() {
                        Some(c) => c,
                        None => break,
                    };
                    match (cur, rest.peek().copied()) {
                        ('/', Some('*')) => {
                            rest.next();
                            open += 1;
                        }
                        ('*', Some('/')) => {
                            rest.next();
                            open -= 1;
                        }
                        _ => {}
                    }
                }
                out.push(' ');
            }
            _ => out.push(ch),
        }
    }

    out
}
142
143/// Extract structured JSON output from stdout.
144///
145/// The harness prints the JSON output as the last line of stdout. This function
146/// tries to parse the last non-empty line as JSON. If successful, it returns
147/// the parsed value and the remaining stdout (everything before the last line).
148/// If parsing fails, it returns `None` and the full stdout.
149pub fn extract_structured_output(stdout: &str) -> (Option<serde_json::Value>, String) {
150    let trimmed = stdout.trim_end();
151    if trimmed.is_empty() {
152        return (None, String::new());
153    }
154
155    // Find the last line.
156    if let Some(last_newline_pos) = trimmed.rfind('\n') {
157        let last_line = &trimmed[last_newline_pos + 1..];
158        let before = &trimmed[..last_newline_pos];
159
160        if let Ok(value) = serde_json::from_str::<serde_json::Value>(last_line) {
161            return (Some(value), before.to_string());
162        }
163    } else {
164        // Only one line — try to parse it as JSON.
165        if let Ok(value) = serde_json::from_str::<serde_json::Value>(trimmed) {
166            return (Some(value), String::new());
167        }
168    }
169
170    (None, stdout.to_string())
171}
172
/// Truncate `output` so that it is at most `max_bytes` bytes long.
///
/// If the limit falls inside a multi-byte UTF-8 character, the cut moves back
/// to the start of that character. The result is therefore always valid UTF-8
/// and never longer than `max_bytes`, though it may be up to 3 bytes shorter.
/// The string is shortened in place, so no new allocation is made.
///
/// Returns the (possibly truncated) string and `true` if any bytes were removed.
pub fn truncate_output(mut output: String, max_bytes: usize) -> (String, bool) {
    if output.len() <= max_bytes {
        return (output, false);
    }

    // `max_bytes < output.len()` here, so every index probed is in range.
    // Byte 0 is always a char boundary, so this loop terminates.
    let mut end = max_bytes;
    while !output.is_char_boundary(end) {
        end -= 1;
    }
    output.truncate(end);
    (output, true)
}

#[cfg(test)]
mod tests {
    use super::*;

    // ── Truncation tests ───────────────────────────────────────────

    #[test]
    fn truncate_output_no_truncation() {
        let (result, truncated) = truncate_output("hello".to_string(), 100);
        assert_eq!(result, "hello");
        assert!(!truncated);
    }

    #[test]
    fn truncate_output_at_limit() {
        let (result, truncated) = truncate_output("hello".to_string(), 5);
        assert_eq!(result, "hello");
        assert!(!truncated);
    }

    #[test]
    fn truncate_output_over_limit() {
        let (result, truncated) = truncate_output("hello world".to_string(), 5);
        assert_eq!(result, "hello");
        assert!(truncated);
    }

    #[test]
    fn truncate_output_respects_char_boundaries() {
        // Multi-byte character: "é" is 2 bytes in UTF-8.
        // "café" = 5 bytes: c(0), a(1), f(2), é(3..5).
        // With limit 4 the cut would split "é", so it moves back to byte 3.
        let (result, truncated) = truncate_output("café".to_string(), 4);
        assert_eq!(result, "caf");
        assert!(truncated);

        // With limit 3 the cut lands exactly on the boundary before "é".
        let (result, truncated) = truncate_output("café".to_string(), 3);
        assert_eq!(result, "caf");
        assert!(truncated);
    }

    #[test]
    fn truncate_output_never_exceeds_limit() {
        // Every character here is 3 bytes long in UTF-8.
        let text = "日本語テキスト".to_string();
        for limit in 0..=text.len() {
            let (result, _) = truncate_output(text.clone(), limit);
            assert!(result.len() <= limit);
            assert!(text.starts_with(&result));
        }
    }

    #[test]
    fn truncate_output_zero_limit() {
        let (result, truncated) = truncate_output("abc".to_string(), 0);
        assert_eq!(result, "");
        assert!(truncated);
    }

    // ── Structured output extraction tests ─────────────────────────

    #[test]
    fn extract_structured_output_single_json_line() {
        let (output, display) = extract_structured_output(r#"{"answer":42}"#);
        assert_eq!(output, Some(serde_json::json!({"answer": 42})));
        assert_eq!(display, "");
    }

    #[test]
    fn extract_structured_output_with_preceding_text() {
        let stdout = "some debug output\n{\"answer\":42}";
        let (output, display) = extract_structured_output(stdout);
        assert_eq!(output, Some(serde_json::json!({"answer": 42})));
        assert_eq!(display, "some debug output");
    }

    #[test]
    fn extract_structured_output_no_json() {
        let stdout = "just plain text\nmore text";
        let (output, display) = extract_structured_output(stdout);
        assert!(output.is_none());
        assert_eq!(display, stdout);
    }

    #[test]
    fn extract_structured_output_empty() {
        let (output, display) = extract_structured_output("");
        assert!(output.is_none());
        assert_eq!(display, "");
    }

    // ── Source model validation tests ──────────────────────────────

    #[test]
    fn validate_accepts_valid_run_function() {
        let code = r#"
fn run(input: serde_json::Value) -> serde_json::Value {
    let v = input["x"].as_i64().unwrap_or(0);
    serde_json::json!({ "result": v * 2 })
}
"#;
        assert!(validate_rust_source(code).is_ok());
    }

    #[test]
    fn validate_accepts_helper_functions() {
        let code = r#"
fn helper(x: i64) -> i64 { x + 1 }

fn run(input: serde_json::Value) -> serde_json::Value {
    let v = input["x"].as_i64().unwrap_or(0);
    serde_json::json!({ "result": helper(v) })
}
"#;
        assert!(validate_rust_source(code).is_ok());
    }

    #[test]
    fn validate_rejects_fn_main() {
        let code = r#"
fn main() {
    println!("hello");
}
"#;
        let err = validate_rust_source(code).unwrap_err();
        assert!(matches!(err, ExecutionError::InvalidRequest(_)));
        assert!(err.to_string().contains("fn main()"));
    }

    #[test]
    fn validate_rejects_crate_level_attributes() {
        let code = r#"
#![allow(unused)]
fn run(input: serde_json::Value) -> serde_json::Value { input }
"#;
        let err = validate_rust_source(code).unwrap_err();
        assert!(matches!(err, ExecutionError::InvalidRequest(_)));
        assert!(err.to_string().contains("crate-level attributes"));
    }

    #[test]
    fn validate_ignores_fn_main_in_comments() {
        let code = r#"
// fn main() is provided by the harness
fn run(input: serde_json::Value) -> serde_json::Value { input }
"#;
        assert!(validate_rust_source(code).is_ok());
    }

    #[test]
    fn validate_ignores_fn_main_in_block_comments() {
        let code = r#"
/* fn main() { } */
fn run(input: serde_json::Value) -> serde_json::Value { input }
"#;
        assert!(validate_rust_source(code).is_ok());
    }

    #[test]
    fn validate_ignores_crate_attr_in_comments() {
        let code = r#"
// #![allow(unused)]
fn run(input: serde_json::Value) -> serde_json::Value { input }
"#;
        assert!(validate_rust_source(code).is_ok());
    }

    #[test]
    fn validate_accepts_item_level_attributes() {
        let code = r#"
#[derive(Debug)]
struct Foo { x: i64 }

fn run(input: serde_json::Value) -> serde_json::Value { input }
"#;
        assert!(validate_rust_source(code).is_ok());
    }

    #[test]
    fn validate_accepts_empty_code() {
        assert!(validate_rust_source("").is_ok());
    }

    // ── Comment stripping tests ────────────────────────────────────

    #[test]
    fn strip_comments_removes_single_line() {
        let code = "let x = 1; // this is a comment\nlet y = 2;";
        let stripped = strip_comments(code);
        assert!(!stripped.contains("this is a comment"));
        assert!(stripped.contains("let x = 1;"));
        assert!(stripped.contains("let y = 2;"));
    }

    #[test]
    fn strip_comments_removes_block_comment() {
        let code = "let x = /* hidden */ 1;";
        let stripped = strip_comments(code);
        assert!(!stripped.contains("hidden"));
        assert!(stripped.contains("let x ="));
        assert!(stripped.contains("1;"));
    }

    #[test]
    fn strip_comments_handles_nested_block_comments() {
        let code = "before /* outer /* inner */ still outer */ after";
        let stripped = strip_comments(code);
        assert!(!stripped.contains("outer"));
        assert!(!stripped.contains("inner"));
        assert!(stripped.contains("before"));
        assert!(stripped.contains("after"));
    }

    #[test]
    fn strip_comments_keeps_newline_after_line_comment() {
        assert_eq!(strip_comments("a // c\nb"), "a \nb");
        assert_eq!(strip_comments("x // tail"), "x ");
    }
}