Skip to main content

intent_implement/
validate.rs

1//! Output validation and generate-check-retry loop.
2//!
3//! Validates LLM-generated code for structural correctness (expected names
4//! present, balanced delimiters, no leftover stubs) and retries on failure.
5
6use intent_codegen::Language;
7use intent_gen::{ApiError, LlmClient, Message};
8use intent_parser::ast;
9
10use crate::ImplementOptions;
11use crate::context;
12use crate::prompt;
13
/// Errors from the implementation pipeline.
#[derive(Debug, thiserror::Error)]
pub enum ImplementError {
    /// A call to the LLM client failed.
    ///
    /// Built from [`ApiError`] through `#[from]`, so `?` can be applied
    /// directly to client results inside the pipeline.
    #[error("API error: {0}")]
    Api(#[from] ApiError),
    /// Every attempt produced output that failed validation.
    ///
    /// `retries` is the retry budget that was used up; `errors` carries the
    /// validation messages of the last attempt, one per line.
    #[error("validation failed after {retries} retries:\n{errors}")]
    ValidationFailed { retries: u32, errors: String },
}
22
23/// Run the implement-check-retry loop.
24pub fn implement_with_retry(
25    client: &LlmClient,
26    file: &ast::File,
27    options: &ImplementOptions,
28) -> Result<String, ImplementError> {
29    let ctx = context::build_context(file, options.language);
30    let system = prompt::system_prompt(options.language);
31    let user_msg = prompt::user_message(&ctx, options.language);
32
33    let mut messages = vec![
34        Message {
35            role: "system",
36            content: system,
37        },
38        Message {
39            role: "user",
40            content: user_msg,
41        },
42    ];
43
44    let mut last_errors = Vec::new();
45
46    for attempt in 0..=options.max_retries {
47        if attempt == 0 {
48            eprintln!("Generating implementation from LLM...");
49        } else {
50            eprintln!(
51                "Retry {}/{}: feeding errors back to LLM...",
52                attempt, options.max_retries
53            );
54        }
55
56        let raw = client.chat(&messages)?;
57        if options.debug {
58            eprintln!("--- RAW LLM RESPONSE ---");
59            eprintln!("{raw}");
60            eprintln!("--- END RAW RESPONSE ---");
61        }
62
63        let code = strip_fences(&raw);
64        eprintln!("Validating generated code...");
65
66        match validate_output(&code, file, options.language) {
67            Ok(()) => {
68                eprintln!("Validation passed.");
69                return Ok(code);
70            }
71            Err(errors) => {
72                for e in &errors {
73                    eprintln!("  {e}");
74                }
75                last_errors.clone_from(&errors);
76
77                if attempt < options.max_retries {
78                    messages.push(Message {
79                        role: "assistant",
80                        content: raw,
81                    });
82                    messages.push(Message {
83                        role: "user",
84                        content: prompt::retry_message(&code, &errors, options.language),
85                    });
86                }
87            }
88        }
89    }
90
91    Err(ImplementError::ValidationFailed {
92        retries: options.max_retries,
93        errors: last_errors.join("\n"),
94    })
95}
96
97/// Validate generated code for structural correctness.
98///
99/// Checks:
100/// 1. Expected entity/action names are present
101/// 2. Delimiters are balanced
102/// 3. No leftover `todo!()` / `throw "not implemented"` / `raise NotImplementedError` stubs
103/// 4. Contract test functions are present (if spec has test blocks)
104pub fn validate_output(code: &str, file: &ast::File, lang: Language) -> Result<(), Vec<String>> {
105    let mut errors = Vec::new();
106
107    // Check expected names
108    let expected = expected_names(file, lang);
109    for name in &expected {
110        if !code.contains(name.as_str()) {
111            errors.push(format!("missing expected identifier: {name}"));
112        }
113    }
114
115    // Check balanced delimiters
116    if let Err(e) = check_balanced(code, lang) {
117        errors.push(e);
118    }
119
120    // Check for leftover stubs
121    let stubs = leftover_stubs(code, lang);
122    for stub in stubs {
123        errors.push(format!("leftover stub found: {stub}"));
124    }
125
126    // Check that contract test functions are present
127    let test_names = intent_codegen::test_harness::expected_test_names(file);
128    for name in &test_names {
129        if !code.contains(name.as_str()) {
130            errors.push(format!("missing contract test: {name}"));
131        }
132    }
133
134    if errors.is_empty() {
135        Ok(())
136    } else {
137        Err(errors)
138    }
139}
140
141/// Extract the names we expect to find in the generated code.
142fn expected_names(file: &ast::File, lang: Language) -> Vec<String> {
143    let mut names = Vec::new();
144
145    for item in &file.items {
146        match item {
147            ast::TopLevelItem::Entity(e) => {
148                names.push(e.name.clone());
149            }
150            ast::TopLevelItem::Action(a) => {
151                // Actions become functions with language-appropriate naming
152                let fn_name = match lang {
153                    Language::Rust | Language::Python | Language::Go => {
154                        intent_codegen::to_snake_case(&a.name)
155                    }
156                    Language::TypeScript | Language::Java | Language::Swift => {
157                        intent_codegen::to_camel_case(&a.name)
158                    }
159                    Language::CSharp => a.name.clone(), // PascalCase
160                };
161                names.push(fn_name);
162            }
163            _ => {}
164        }
165    }
166
167    names
168}
169
170/// Check that delimiters are balanced in the code.
171fn check_balanced(code: &str, lang: Language) -> Result<(), String> {
172    let (braces, parens, brackets) = count_delimiters(code, lang);
173
174    if braces != 0 {
175        return Err(format!(
176            "unbalanced braces: {} more {} than {}",
177            braces.unsigned_abs(),
178            if braces > 0 { "opening" } else { "closing" },
179            if braces > 0 { "closing" } else { "opening" }
180        ));
181    }
182    if parens != 0 {
183        return Err(format!(
184            "unbalanced parentheses: {} more {} than {}",
185            parens.unsigned_abs(),
186            if parens > 0 { "opening" } else { "closing" },
187            if parens > 0 { "closing" } else { "opening" }
188        ));
189    }
190    if brackets != 0 {
191        return Err(format!(
192            "unbalanced brackets: {} more {} than {}",
193            brackets.unsigned_abs(),
194            if brackets > 0 { "opening" } else { "closing" },
195            if brackets > 0 { "closing" } else { "opening" }
196        ));
197    }
198    Ok(())
199}
200
201/// Count net delimiters in code, skipping strings and comments.
202fn count_delimiters(code: &str, lang: Language) -> (i32, i32, i32) {
203    let mut braces = 0i32;
204    let mut parens = 0i32;
205    let mut brackets = 0i32;
206
207    for line in code.lines() {
208        let line = strip_comment(line, lang);
209        let mut in_string = false;
210        let mut escape_next = false;
211
212        for ch in line.chars() {
213            if escape_next {
214                escape_next = false;
215                continue;
216            }
217            if ch == '\\' && in_string {
218                escape_next = true;
219                continue;
220            }
221            if ch == '"' {
222                in_string = !in_string;
223                continue;
224            }
225            // In Rust, ' is for lifetimes ('a, '_, 'static) — not strings.
226            // Treating it as a string toggle causes false positives on lines
227            // like `Formatter<'_>) -> Result {` where the trailing { is missed.
228            // In Go, ' is for rune literals which are short and self-closing.
229            // Only Python and TypeScript use ' as a string delimiter.
230            if ch == '\''
231                && matches!(
232                    lang,
233                    Language::Python | Language::TypeScript | Language::Swift
234                )
235            {
236                in_string = !in_string;
237                continue;
238            }
239            if in_string {
240                continue;
241            }
242
243            match ch {
244                '{' => braces += 1,
245                '}' => braces -= 1,
246                '(' => parens += 1,
247                ')' => parens -= 1,
248                '[' => brackets += 1,
249                ']' => brackets -= 1,
250                _ => {}
251            }
252        }
253    }
254
255    (braces, parens, brackets)
256}
257
258/// Strip single-line comments from a line.
259fn strip_comment(line: &str, lang: Language) -> &str {
260    match lang {
261        Language::Rust
262        | Language::TypeScript
263        | Language::Go
264        | Language::Java
265        | Language::CSharp
266        | Language::Swift => {
267            // Find // outside of strings
268            let mut in_string = false;
269            let mut prev = '\0';
270            for (i, ch) in line.char_indices() {
271                if ch == '"' && prev != '\\' {
272                    in_string = !in_string;
273                }
274                if !in_string && ch == '/' && prev == '/' {
275                    return &line[..i - 1];
276                }
277                prev = ch;
278            }
279            line
280        }
281        Language::Python => {
282            let mut in_string = false;
283            let mut prev = '\0';
284            for (i, ch) in line.char_indices() {
285                if (ch == '"' || ch == '\'') && prev != '\\' {
286                    in_string = !in_string;
287                }
288                if !in_string && ch == '#' {
289                    return &line[..i];
290                }
291                prev = ch;
292            }
293            line
294        }
295    }
296}
297
298/// Check for leftover implementation stubs.
299fn leftover_stubs(code: &str, lang: Language) -> Vec<String> {
300    let mut stubs = Vec::new();
301
302    match lang {
303        Language::Rust => {
304            if code.contains("todo!()") {
305                stubs.push("todo!()".to_string());
306            }
307            if code.contains("unimplemented!()") {
308                stubs.push("unimplemented!()".to_string());
309            }
310        }
311        Language::TypeScript => {
312            if code.contains("throw new Error(\"not implemented\")")
313                || code.contains("throw new Error(\"Not implemented\")")
314            {
315                stubs.push("throw new Error(\"not implemented\")".to_string());
316            }
317        }
318        Language::Python => {
319            if code.contains("raise NotImplementedError") {
320                stubs.push("raise NotImplementedError".to_string());
321            }
322        }
323        Language::Go => {
324            if code.contains("panic(\"not implemented\")") || code.contains("panic(\"TODO\")") {
325                stubs.push("panic(\"not implemented\")".to_string());
326            }
327        }
328        Language::Java => {
329            if code.contains("throw new UnsupportedOperationException") {
330                stubs.push("throw new UnsupportedOperationException".to_string());
331            }
332        }
333        Language::CSharp => {
334            if code.contains("throw new NotImplementedException") {
335                stubs.push("throw new NotImplementedException".to_string());
336            }
337        }
338        Language::Swift => {
339            if code.contains("fatalError(\"TODO") {
340                stubs.push("fatalError(\"TODO: ...\")".to_string());
341            }
342        }
343    }
344
345    stubs
346}
347
/// Remove a surrounding markdown code fence from LLM output, if present.
///
/// The input is trimmed first. When it starts with a ``` fence and ends with
/// one, the text between them is returned, trimmed, with the remainder of the
/// opening fence line (the optional language tag) dropped. In every other
/// case the trimmed input is returned unchanged.
pub fn strip_fences(s: &str) -> String {
    let text = s.trim();

    let unfenced = text
        .strip_prefix("```")
        .map(|after_open| match after_open.split_once('\n') {
            // Everything up to the first newline is the language tag.
            Some((_lang_tag, body)) => body,
            None => after_open,
        })
        .and_then(|body| body.strip_suffix("```"))
        .map(str::trim);

    unfenced.unwrap_or(text).to_string()
}
366
// Unit tests for the validation helpers. Specs are written inline in the
// intent language and parsed with `intent_parser`; the "generated code" is a
// hand-written string in the target language.
#[cfg(test)]
mod tests {
    use super::*;

    // Parse an inline spec, panicking if the spec itself is malformed.
    fn parse(src: &str) -> ast::File {
        intent_parser::parse_file(src).expect("parse failed")
    }

    // ── strip_fences ───────────────────────────────────────

    // Unfenced input comes back unchanged.
    #[test]
    fn test_strip_fences_no_fences() {
        let input = "fn main() {}";
        assert_eq!(strip_fences(input), input);
    }

    // The language tag on the opening fence line is dropped with the fence.
    #[test]
    fn test_strip_fences_with_lang() {
        let input = "```rust\nfn main() {}\n```";
        assert_eq!(strip_fences(input), "fn main() {}");
    }

    #[test]
    fn test_strip_fences_without_lang() {
        let input = "```\nfn main() {}\n```";
        assert_eq!(strip_fences(input), "fn main() {}");
    }

    // ── validate_output ────────────────────────────────────

    // Entity name and snake_case action name are both present: no errors.
    #[test]
    fn test_validate_valid_rust() {
        let src =
            "module Test\n\nentity Foo {\n  id: UUID\n}\n\naction CreateFoo {\n  name: String\n}\n";
        let ast = parse(src);
        let code = "struct Foo { id: String }\n\nfn create_foo(name: &str) -> Foo {\n    Foo { id: name.to_string() }\n}\n";
        assert!(validate_output(code, &ast, Language::Rust).is_ok());
    }

    // The action function is absent, so its expected name must be reported.
    #[test]
    fn test_validate_missing_name() {
        let src =
            "module Test\n\nentity Foo {\n  id: UUID\n}\n\naction CreateFoo {\n  name: String\n}\n";
        let ast = parse(src);
        let code = "struct Foo { id: String }\n// function not defined\n";
        let err = validate_output(code, &ast, Language::Rust).unwrap_err();
        assert!(err.iter().any(|e| e.contains("create_foo")));
    }

    // A `todo!()` body counts as a leftover stub.
    #[test]
    fn test_validate_leftover_todo() {
        let src =
            "module Test\n\nentity Foo {\n  id: UUID\n}\n\naction CreateFoo {\n  name: String\n}\n";
        let ast = parse(src);
        let code =
            "struct Foo { id: String }\n\nfn create_foo(name: &str) -> Foo {\n    todo!()\n}\n";
        let err = validate_output(code, &ast, Language::Rust).unwrap_err();
        assert!(err.iter().any(|e| e.contains("todo!()")));
    }

    // The struct's closing brace is missing.
    #[test]
    fn test_validate_unbalanced_braces() {
        let src = "module Test\n\nentity Foo {\n  id: UUID\n}\n";
        let ast = parse(src);
        let code = "struct Foo { id: String\n";
        let err = validate_output(code, &ast, Language::Rust).unwrap_err();
        assert!(err.iter().any(|e| e.contains("unbalanced")));
    }

    // ── TypeScript validation ──────────────────────────────

    // For TypeScript the action is expected in camelCase (`createFoo`).
    #[test]
    fn test_validate_valid_typescript() {
        let src =
            "module Test\n\nentity Foo {\n  id: UUID\n}\n\naction CreateFoo {\n  name: String\n}\n";
        let ast = parse(src);
        let code = "interface Foo { id: string; }\n\nfunction createFoo(name: string): Foo {\n    return { id: name };\n}\n";
        assert!(validate_output(code, &ast, Language::TypeScript).is_ok());
    }

    // ── Python validation ──────────────────────────────────

    // For Python the action is expected in snake_case (`create_foo`).
    #[test]
    fn test_validate_valid_python() {
        let src =
            "module Test\n\nentity Foo {\n  id: UUID\n}\n\naction CreateFoo {\n  name: String\n}\n";
        let ast = parse(src);
        let code = "from dataclasses import dataclass\n\n@dataclass\nclass Foo:\n    id: str\n\ndef create_foo(name: str) -> Foo:\n    return Foo(id=name)\n";
        assert!(validate_output(code, &ast, Language::Python).is_ok());
    }

    // `raise NotImplementedError` is the Python stub marker.
    #[test]
    fn test_validate_python_leftover_raise() {
        let src =
            "module Test\n\nentity Foo {\n  id: UUID\n}\n\naction CreateFoo {\n  name: String\n}\n";
        let ast = parse(src);
        let code = "class Foo:\n    id: str\n\ndef create_foo(name: str) -> Foo:\n    raise NotImplementedError\n";
        let err = validate_output(code, &ast, Language::Python).unwrap_err();
        assert!(err.iter().any(|e| e.contains("NotImplementedError")));
    }

    // ── expected_names ─────────────────────────────────────

    // Entity names are kept verbatim; action names follow the language casing.
    #[test]
    fn test_expected_names_rust() {
        let src = "module Test\n\nentity Account {\n  id: UUID\n}\n\naction FreezeAccount {\n  id: UUID\n}\n";
        let ast = parse(src);
        let names = expected_names(&ast, Language::Rust);
        assert!(names.contains(&"Account".to_string()));
        assert!(names.contains(&"freeze_account".to_string()));
    }

    #[test]
    fn test_expected_names_typescript() {
        let src = "module Test\n\nentity Account {\n  id: UUID\n}\n\naction FreezeAccount {\n  id: UUID\n}\n";
        let ast = parse(src);
        let names = expected_names(&ast, Language::TypeScript);
        assert!(names.contains(&"Account".to_string()));
        assert!(names.contains(&"freezeAccount".to_string()));
    }

    // ── delimiter counting ─────────────────────────────────

    #[test]
    fn test_balanced_delimiters() {
        let code = "fn foo() { let x = (1 + 2); let arr = [1, 2, 3]; }";
        assert!(check_balanced(code, Language::Rust).is_ok());
    }

    #[test]
    fn test_delimiters_in_strings_ignored() {
        let code = "let s = \"({[\"; let t = \"]})\";";
        // Strings contain unbalanced delimiters but they should be ignored
        let (b, p, br) = count_delimiters(code, Language::Rust);
        assert_eq!(b, 0);
        assert_eq!(p, 0);
        assert_eq!(br, 0);
    }

    #[test]
    fn test_rust_lifetimes_not_treated_as_strings() {
        // Rust lifetimes ('a, '_, 'static) must not toggle string mode,
        // otherwise the { at the end of lines like this gets skipped.
        let code = "impl Foo {\n    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n        Ok(())\n    }\n}";
        assert!(check_balanced(code, Language::Rust).is_ok());
    }

    // In Python, single quotes delimit strings, so their contents are skipped.
    #[test]
    fn test_python_single_quote_strings() {
        let code = "x = '({['\ny = ']})'\n";
        let (b, p, br) = count_delimiters(code, Language::Python);
        assert_eq!(b, 0);
        assert_eq!(p, 0);
        assert_eq!(br, 0);
    }

    // ── contract test validation ──────────────────────────────

    // The spec's `test "happy path"` block is expected as `test_happy_path`.
    #[test]
    fn test_validate_missing_contract_test() {
        let src = r#"module Test

entity Foo { id: UUID }

action Bar { x: Int }

test "happy path" {
  given { x = 42 }
  when Bar { x: x }
  then { x == 42 }
}
"#;
        let ast = parse(src);
        // Code has entity + action but missing the contract test function
        let code = "struct Foo { id: String }\n\nfn bar(x: i64) -> Result<(), String> { Ok(()) }\n";
        let err = validate_output(code, &ast, Language::Rust).unwrap_err();
        assert!(
            err.iter()
                .any(|e| e.contains("missing contract test: test_happy_path"))
        );
    }

    // Same spec, but the generated code includes the contract test function.
    #[test]
    fn test_validate_with_contract_test_present() {
        let src = r#"module Test

entity Foo { id: UUID }

action Bar { x: Int }

test "happy path" {
  given { x = 42 }
  when Bar { x: x }
  then { x == 42 }
}
"#;
        let ast = parse(src);
        let code = "struct Foo { id: String }\n\nfn bar(x: i64) -> Result<(), String> { Ok(()) }\n\n#[cfg(test)]\nmod contract_tests {\n    use super::*;\n    #[test]\n    fn test_happy_path() { assert!(true); }\n}\n";
        assert!(validate_output(code, &ast, Language::Rust).is_ok());
    }
}