1use intent_codegen::Language;
7use intent_gen::{ApiError, LlmClient, Message};
8use intent_parser::ast;
9
10use crate::ImplementOptions;
11use crate::context;
12use crate::prompt;
13
14#[derive(Debug, thiserror::Error)]
16pub enum ImplementError {
17 #[error("API error: {0}")]
18 Api(#[from] ApiError),
19 #[error("validation failed after {retries} retries:\n{errors}")]
20 ValidationFailed { retries: u32, errors: String },
21}
22
23pub fn implement_with_retry(
25 client: &LlmClient,
26 file: &ast::File,
27 options: &ImplementOptions,
28) -> Result<String, ImplementError> {
29 let ctx = context::build_context(file, options.language);
30 let system = prompt::system_prompt(options.language);
31 let user_msg = prompt::user_message(&ctx, options.language);
32
33 let mut messages = vec![
34 Message {
35 role: "system",
36 content: system,
37 },
38 Message {
39 role: "user",
40 content: user_msg,
41 },
42 ];
43
44 let mut last_errors = Vec::new();
45
46 for attempt in 0..=options.max_retries {
47 if attempt == 0 {
48 eprintln!("Generating implementation from LLM...");
49 } else {
50 eprintln!(
51 "Retry {}/{}: feeding errors back to LLM...",
52 attempt, options.max_retries
53 );
54 }
55
56 let raw = client.chat(&messages)?;
57 if options.debug {
58 eprintln!("--- RAW LLM RESPONSE ---");
59 eprintln!("{raw}");
60 eprintln!("--- END RAW RESPONSE ---");
61 }
62
63 let code = strip_fences(&raw);
64 eprintln!("Validating generated code...");
65
66 match validate_output(&code, file, options.language) {
67 Ok(()) => {
68 eprintln!("Validation passed.");
69 return Ok(code);
70 }
71 Err(errors) => {
72 for e in &errors {
73 eprintln!(" {e}");
74 }
75 last_errors.clone_from(&errors);
76
77 if attempt < options.max_retries {
78 messages.push(Message {
79 role: "assistant",
80 content: raw,
81 });
82 messages.push(Message {
83 role: "user",
84 content: prompt::retry_message(&code, &errors, options.language),
85 });
86 }
87 }
88 }
89 }
90
91 Err(ImplementError::ValidationFailed {
92 retries: options.max_retries,
93 errors: last_errors.join("\n"),
94 })
95}
96
97pub fn validate_output(code: &str, file: &ast::File, lang: Language) -> Result<(), Vec<String>> {
105 let mut errors = Vec::new();
106
107 let expected = expected_names(file, lang);
109 for name in &expected {
110 if !code.contains(name.as_str()) {
111 errors.push(format!("missing expected identifier: {name}"));
112 }
113 }
114
115 if let Err(e) = check_balanced(code, lang) {
117 errors.push(e);
118 }
119
120 let stubs = leftover_stubs(code, lang);
122 for stub in stubs {
123 errors.push(format!("leftover stub found: {stub}"));
124 }
125
126 let test_names = intent_codegen::test_harness::expected_test_names(file);
128 for name in &test_names {
129 if !code.contains(name.as_str()) {
130 errors.push(format!("missing contract test: {name}"));
131 }
132 }
133
134 if errors.is_empty() {
135 Ok(())
136 } else {
137 Err(errors)
138 }
139}
140
141fn expected_names(file: &ast::File, lang: Language) -> Vec<String> {
143 let mut names = Vec::new();
144
145 for item in &file.items {
146 match item {
147 ast::TopLevelItem::Entity(e) => {
148 names.push(e.name.clone());
149 }
150 ast::TopLevelItem::Action(a) => {
151 let fn_name = match lang {
153 Language::Rust | Language::Python | Language::Go => {
154 intent_codegen::to_snake_case(&a.name)
155 }
156 Language::TypeScript | Language::Java | Language::Swift => {
157 intent_codegen::to_camel_case(&a.name)
158 }
159 Language::CSharp => a.name.clone(), };
161 names.push(fn_name);
162 }
163 _ => {}
164 }
165 }
166
167 names
168}
169
170fn check_balanced(code: &str, lang: Language) -> Result<(), String> {
172 let (braces, parens, brackets) = count_delimiters(code, lang);
173
174 if braces != 0 {
175 return Err(format!(
176 "unbalanced braces: {} more {} than {}",
177 braces.unsigned_abs(),
178 if braces > 0 { "opening" } else { "closing" },
179 if braces > 0 { "closing" } else { "opening" }
180 ));
181 }
182 if parens != 0 {
183 return Err(format!(
184 "unbalanced parentheses: {} more {} than {}",
185 parens.unsigned_abs(),
186 if parens > 0 { "opening" } else { "closing" },
187 if parens > 0 { "closing" } else { "opening" }
188 ));
189 }
190 if brackets != 0 {
191 return Err(format!(
192 "unbalanced brackets: {} more {} than {}",
193 brackets.unsigned_abs(),
194 if brackets > 0 { "opening" } else { "closing" },
195 if brackets > 0 { "closing" } else { "opening" }
196 ));
197 }
198 Ok(())
199}
200
201fn count_delimiters(code: &str, lang: Language) -> (i32, i32, i32) {
203 let mut braces = 0i32;
204 let mut parens = 0i32;
205 let mut brackets = 0i32;
206
207 for line in code.lines() {
208 let line = strip_comment(line, lang);
209 let mut in_string = false;
210 let mut escape_next = false;
211
212 for ch in line.chars() {
213 if escape_next {
214 escape_next = false;
215 continue;
216 }
217 if ch == '\\' && in_string {
218 escape_next = true;
219 continue;
220 }
221 if ch == '"' {
222 in_string = !in_string;
223 continue;
224 }
225 if ch == '\''
231 && matches!(
232 lang,
233 Language::Python | Language::TypeScript | Language::Swift
234 )
235 {
236 in_string = !in_string;
237 continue;
238 }
239 if in_string {
240 continue;
241 }
242
243 match ch {
244 '{' => braces += 1,
245 '}' => braces -= 1,
246 '(' => parens += 1,
247 ')' => parens -= 1,
248 '[' => brackets += 1,
249 ']' => brackets -= 1,
250 _ => {}
251 }
252 }
253 }
254
255 (braces, parens, brackets)
256}
257
258fn strip_comment(line: &str, lang: Language) -> &str {
260 match lang {
261 Language::Rust
262 | Language::TypeScript
263 | Language::Go
264 | Language::Java
265 | Language::CSharp
266 | Language::Swift => {
267 let mut in_string = false;
269 let mut prev = '\0';
270 for (i, ch) in line.char_indices() {
271 if ch == '"' && prev != '\\' {
272 in_string = !in_string;
273 }
274 if !in_string && ch == '/' && prev == '/' {
275 return &line[..i - 1];
276 }
277 prev = ch;
278 }
279 line
280 }
281 Language::Python => {
282 let mut in_string = false;
283 let mut prev = '\0';
284 for (i, ch) in line.char_indices() {
285 if (ch == '"' || ch == '\'') && prev != '\\' {
286 in_string = !in_string;
287 }
288 if !in_string && ch == '#' {
289 return &line[..i];
290 }
291 prev = ch;
292 }
293 line
294 }
295 }
296}
297
298fn leftover_stubs(code: &str, lang: Language) -> Vec<String> {
300 let mut stubs = Vec::new();
301
302 match lang {
303 Language::Rust => {
304 if code.contains("todo!()") {
305 stubs.push("todo!()".to_string());
306 }
307 if code.contains("unimplemented!()") {
308 stubs.push("unimplemented!()".to_string());
309 }
310 }
311 Language::TypeScript => {
312 if code.contains("throw new Error(\"not implemented\")")
313 || code.contains("throw new Error(\"Not implemented\")")
314 {
315 stubs.push("throw new Error(\"not implemented\")".to_string());
316 }
317 }
318 Language::Python => {
319 if code.contains("raise NotImplementedError") {
320 stubs.push("raise NotImplementedError".to_string());
321 }
322 }
323 Language::Go => {
324 if code.contains("panic(\"not implemented\")") || code.contains("panic(\"TODO\")") {
325 stubs.push("panic(\"not implemented\")".to_string());
326 }
327 }
328 Language::Java => {
329 if code.contains("throw new UnsupportedOperationException") {
330 stubs.push("throw new UnsupportedOperationException".to_string());
331 }
332 }
333 Language::CSharp => {
334 if code.contains("throw new NotImplementedException") {
335 stubs.push("throw new NotImplementedException".to_string());
336 }
337 }
338 Language::Swift => {
339 if code.contains("fatalError(\"TODO") {
340 stubs.push("fatalError(\"TODO: ...\")".to_string());
341 }
342 }
343 }
344
345 stubs
346}
347
/// Extracts the code from an LLM reply that may be wrapped in a Markdown code
/// fence.
///
/// If the trimmed reply starts with a fence (three backticks), the opening
/// fence line, including any language tag such as `rust`, is dropped. The end
/// of the code is then found as follows:
/// 1. a fence at the very end of the reply is removed, leaving inner fences
///    (e.g. in doc comments) intact;
/// 2. otherwise, everything from the last line consisting solely of a fence
///    onwards is dropped, e.g. an explanation written after the code;
/// 3. otherwise (an unterminated fence), the rest of the reply is kept.
///
/// Replies that do not start with a fence are returned unchanged apart from
/// trimming. The result is always trimmed.
pub fn strip_fences(s: &str) -> String {
    let trimmed = s.trim();

    let after_open = match trimmed.strip_prefix("```") {
        Some(rest) => rest,
        None => return trimmed.to_string(),
    };

    // Drop the remainder of the opening fence line (the optional language tag).
    let body = match after_open.find('\n') {
        Some(idx) => &after_open[idx + 1..],
        None => after_open,
    };

    if let Some(content) = body.strip_suffix("```") {
        return content.trim().to_string();
    }

    // A closing fence followed by trailing prose: cut at the last fence line.
    let mut offset = 0;
    let mut closing_fence = None;
    for line in body.split_inclusive('\n') {
        if line.trim() == "```" {
            closing_fence = Some(offset);
        }
        offset += line.len();
    }

    match closing_fence {
        Some(end) => body[..end].trim().to_string(),
        None => body.trim().to_string(),
    }
}
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Spec with entity `Foo` and action `CreateFoo`.
    const FOO_SPEC: &str =
        "module Test\n\nentity Foo {\n id: UUID\n}\n\naction CreateFoo {\n name: String\n}\n";

    /// Spec with entity `Account` and action `FreezeAccount`.
    const ACCOUNT_SPEC: &str = "module Test\n\nentity Account {\n id: UUID\n}\n\naction FreezeAccount {\n id: UUID\n}\n";

    /// Spec declaring one contract test, "happy path", that drives action `Bar`.
    const CONTRACT_SPEC: &str = r#"module Test

entity Foo { id: UUID }

action Bar { x: Int }

test "happy path" {
 given { x = 42 }
 when Bar { x: x }
 then { x == 42 }
}
"#;

    fn parse(src: &str) -> ast::File {
        intent_parser::parse_file(src).expect("parse failed")
    }

    /// True when at least one validation message contains `needle`.
    fn mentions(errors: &[String], needle: &str) -> bool {
        errors.iter().any(|message| message.contains(needle))
    }

    // ---- strip_fences ----

    #[test]
    fn test_strip_fences_no_fences() {
        let input = "fn main() {}";
        assert_eq!(strip_fences(input), input);
    }

    #[test]
    fn test_strip_fences_with_lang() {
        assert_eq!(strip_fences("```rust\nfn main() {}\n```"), "fn main() {}");
    }

    #[test]
    fn test_strip_fences_without_lang() {
        assert_eq!(strip_fences("```\nfn main() {}\n```"), "fn main() {}");
    }

    // ---- validate_output ----

    #[test]
    fn test_validate_valid_rust() {
        let spec = parse(FOO_SPEC);
        let code = "struct Foo { id: String }\n\nfn create_foo(name: &str) -> Foo {\n Foo { id: name.to_string() }\n}\n";
        assert!(validate_output(code, &spec, Language::Rust).is_ok());
    }

    #[test]
    fn test_validate_missing_name() {
        let spec = parse(FOO_SPEC);
        let code = "struct Foo { id: String }\n// function not defined\n";
        let errors = validate_output(code, &spec, Language::Rust).unwrap_err();
        assert!(mentions(&errors, "create_foo"));
    }

    #[test]
    fn test_validate_leftover_todo() {
        let spec = parse(FOO_SPEC);
        let code =
            "struct Foo { id: String }\n\nfn create_foo(name: &str) -> Foo {\n todo!()\n}\n";
        let errors = validate_output(code, &spec, Language::Rust).unwrap_err();
        assert!(mentions(&errors, "todo!()"));
    }

    #[test]
    fn test_validate_unbalanced_braces() {
        let spec = parse("module Test\n\nentity Foo {\n id: UUID\n}\n");
        let errors =
            validate_output("struct Foo { id: String\n", &spec, Language::Rust).unwrap_err();
        assert!(mentions(&errors, "unbalanced"));
    }

    #[test]
    fn test_validate_valid_typescript() {
        let spec = parse(FOO_SPEC);
        let code = "interface Foo { id: string; }\n\nfunction createFoo(name: string): Foo {\n return { id: name };\n}\n";
        assert!(validate_output(code, &spec, Language::TypeScript).is_ok());
    }

    #[test]
    fn test_validate_valid_python() {
        let spec = parse(FOO_SPEC);
        let code = "from dataclasses import dataclass\n\n@dataclass\nclass Foo:\n id: str\n\ndef create_foo(name: str) -> Foo:\n return Foo(id=name)\n";
        assert!(validate_output(code, &spec, Language::Python).is_ok());
    }

    #[test]
    fn test_validate_python_leftover_raise() {
        let spec = parse(FOO_SPEC);
        let code = "class Foo:\n id: str\n\ndef create_foo(name: str) -> Foo:\n raise NotImplementedError\n";
        let errors = validate_output(code, &spec, Language::Python).unwrap_err();
        assert!(mentions(&errors, "NotImplementedError"));
    }

    // ---- expected_names ----

    #[test]
    fn test_expected_names_rust() {
        let names = expected_names(&parse(ACCOUNT_SPEC), Language::Rust);
        assert!(names.iter().any(|n| n == "Account"));
        assert!(names.iter().any(|n| n == "freeze_account"));
    }

    #[test]
    fn test_expected_names_typescript() {
        let names = expected_names(&parse(ACCOUNT_SPEC), Language::TypeScript);
        assert!(names.iter().any(|n| n == "Account"));
        assert!(names.iter().any(|n| n == "freezeAccount"));
    }

    // ---- delimiter balance ----

    #[test]
    fn test_balanced_delimiters() {
        let code = "fn foo() { let x = (1 + 2); let arr = [1, 2, 3]; }";
        assert!(check_balanced(code, Language::Rust).is_ok());
    }

    #[test]
    fn test_delimiters_in_strings_ignored() {
        let code = "let s = \"({[\"; let t = \"]})\";";
        assert_eq!(count_delimiters(code, Language::Rust), (0, 0, 0));
    }

    #[test]
    fn test_rust_lifetimes_not_treated_as_strings() {
        let code = "impl Foo {\n fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {\n Ok(())\n }\n}";
        assert!(check_balanced(code, Language::Rust).is_ok());
    }

    #[test]
    fn test_python_single_quote_strings() {
        let code = "x = '({['\ny = ']})'\n";
        assert_eq!(count_delimiters(code, Language::Python), (0, 0, 0));
    }

    // ---- contract tests ----

    #[test]
    fn test_validate_missing_contract_test() {
        let spec = parse(CONTRACT_SPEC);
        let code =
            "struct Foo { id: String }\n\nfn bar(x: i64) -> Result<(), String> { Ok(()) }\n";
        let errors = validate_output(code, &spec, Language::Rust).unwrap_err();
        assert!(mentions(&errors, "missing contract test: test_happy_path"));
    }

    #[test]
    fn test_validate_with_contract_test_present() {
        let spec = parse(CONTRACT_SPEC);
        let code = "struct Foo { id: String }\n\nfn bar(x: i64) -> Result<(), String> { Ok(()) }\n\n#[cfg(test)]\nmod contract_tests {\n use super::*;\n #[test]\n fn test_happy_path() { assert!(true); }\n}\n";
        assert!(validate_output(code, &spec, Language::Rust).is_ok());
    }
}