1use crate::context_builder::AnalysisContext;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
/// Errors produced while building prompts, calling the LLM backend, or handling its output.
#[derive(Debug, Error)]
pub enum LlmError {
    /// Building the prompt failed.
    #[error("Failed to create prompt: {0}")]
    PromptCreation(String),
    /// The LLM backend failed or is unavailable (e.g. no API configured for the model).
    #[error("LLM API error: {0}")]
    ApiError(String),
    /// The model's response could not be turned into code.
    #[error("Failed to parse response: {0}")]
    ParseError(String),
    /// Generated code failed the structural sanity checks of `LlmCodegen::validate_code`.
    #[error("Generated code is invalid: {0}")]
    InvalidCode(String),
}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct GeneratedCode {
29 pub code: String,
31 pub confidence: f64,
33 pub reasoning: String,
35 pub warnings: Vec<String>,
37}
38
/// Input for one C-to-Rust code generation request; rendered to markdown by `render`.
#[derive(Debug, Clone)]
pub struct CodegenPrompt {
    /// The C source to translate, embedded verbatim in the rendered prompt.
    pub c_source: String,
    /// Static-analysis results (including per-function ownership) used to guide the model.
    pub context: AnalysisContext,
    /// Extra free-form instructions; the section is omitted from the prompt when empty.
    pub instructions: String,
}
49
50impl CodegenPrompt {
51 pub fn new(c_source: &str, context: AnalysisContext) -> Self {
53 Self { c_source: c_source.to_string(), context, instructions: String::new() }
54 }
55
56 pub fn with_instructions(mut self, instructions: &str) -> Self {
58 self.instructions = instructions.to_string();
59 self
60 }
61
62 pub fn render(&self) -> String {
64 let mut prompt = String::new();
65
66 prompt.push_str("# C to Rust Transpilation Task\n\n");
67 prompt.push_str("## Source C Code\n```c\n");
68 prompt.push_str(&self.c_source);
69 prompt.push_str("\n```\n\n");
70
71 prompt.push_str("## Static Analysis Context\n");
73 if let Ok(context_json) = serde_json::to_string_pretty(&self.context) {
74 prompt.push_str("```json\n");
75 prompt.push_str(&context_json);
76 prompt.push_str("\n```\n\n");
77 }
78
79 for func in &self.context.functions {
81 if !func.ownership.is_empty() {
82 prompt.push_str(&format!("### Function: {}\n", func.name));
83 prompt.push_str("Ownership analysis:\n");
84 for (var, info) in &func.ownership {
85 prompt.push_str(&format!(
86 "- `{}`: {} (confidence: {:.0}%)\n",
87 var,
88 info.kind,
89 info.confidence * 100.0
90 ));
91 }
92 prompt.push('\n');
93 }
94 }
95
96 if !self.instructions.is_empty() {
97 prompt.push_str("## Additional Instructions\n");
98 prompt.push_str(&self.instructions);
99 prompt.push_str("\n\n");
100 }
101
102 prompt.push_str("## Task\n");
103 prompt.push_str("Generate idiomatic, safe Rust code that is functionally equivalent to the C code above.\n");
104 prompt.push_str(
105 "Use the static analysis context to guide ownership and borrowing decisions.\n",
106 );
107
108 prompt
109 }
110}
111
/// Front end for LLM-based code generation.
///
/// Holds the target model name and provides parsing and validation of model responses.
/// The `Default` impl targets `"claude-3-sonnet"`.
#[derive(Debug)]
pub struct LlmCodegen {
    /// Identifier of the model to query; also included in the "not configured" API error.
    model: String,
}
118
119impl LlmCodegen {
120 pub fn new(model: &str) -> Self {
122 Self { model: model.to_string() }
123 }
124
125 pub fn generate(&self, _prompt: &CodegenPrompt) -> Result<GeneratedCode, LlmError> {
130 Err(LlmError::ApiError(format!("LLM API not configured for model: {}", self.model)))
132 }
133
134 pub fn parse_response(&self, response: &str) -> Result<GeneratedCode, LlmError> {
140 contract_pre_parse!();
141 if let Ok(generated) = serde_json::from_str::<GeneratedCode>(response.trim()) {
143 return Ok(generated);
144 }
145
146 if let Some(code) = Self::extract_rust_code_block(response) {
148 let reasoning = Self::extract_reasoning(response);
150
151 return Ok(GeneratedCode {
152 code,
153 confidence: 0.8, reasoning,
155 warnings: Vec::new(),
156 });
157 }
158
159 Err(LlmError::ParseError("No valid Rust code found in response".to_string()))
160 }
161
162 fn extract_rust_code_block(response: &str) -> Option<String> {
164 let markers = ["```rust", "```"];
166
167 for marker in markers {
168 if let Some(start) = response.find(marker) {
169 let code_start = start + marker.len();
170 let code_start = response[code_start..]
172 .find('\n')
173 .map(|i| code_start + i + 1)
174 .unwrap_or(code_start);
175
176 if let Some(end) = response[code_start..].find("```") {
178 let code = response[code_start..code_start + end].trim();
179 if !code.is_empty() {
180 return Some(code.to_string());
181 }
182 }
183 }
184 }
185
186 None
187 }
188
189 fn extract_reasoning(response: &str) -> String {
191 if let Some(last_fence) = response.rfind("```") {
193 let after = &response[last_fence + 3..];
194 let reasoning = after.trim();
195 if !reasoning.is_empty() {
196 return reasoning.to_string();
197 }
198 }
199 "Generated from C source".to_string()
200 }
201
202 pub fn validate_code(&self, code: &str) -> Result<(), LlmError> {
209 let open_braces = code.matches('{').count();
211 let close_braces = code.matches('}').count();
212
213 if open_braces != close_braces {
214 return Err(LlmError::InvalidCode(format!(
215 "Unbalanced braces: {} open, {} close",
216 open_braces, close_braces
217 )));
218 }
219
220 let open_parens = code.matches('(').count();
222 let close_parens = code.matches(')').count();
223
224 if open_parens != close_parens {
225 return Err(LlmError::InvalidCode(format!(
226 "Unbalanced parentheses: {} open, {} close",
227 open_parens, close_parens
228 )));
229 }
230
231 if code.contains("fn ") {
233 return Ok(());
235 }
236
237 if !code.trim().is_empty() {
239 return Ok(());
240 }
241
242 Err(LlmError::InvalidCode("Empty code".to_string()))
243 }
244}
245
246impl Default for LlmCodegen {
247 fn default() -> Self {
248 Self::new("claude-3-sonnet")
249 }
250}
251
#[cfg(test)]
mod tests {
    //! Unit tests for prompt construction and rendering, response parsing, code validation,
    //! reasoning extraction, and the public error/result types.

    use super::*;

    // --- Code-block extraction -------------------------------------------------------------

    #[test]
    fn test_extract_code_block() {
        let response = "Here's the code:\n```rust\nfn main() {}\n```\nDone!";
        let code = LlmCodegen::extract_rust_code_block(response);
        assert!(code.is_some());
        assert!(code.unwrap().contains("fn main"));
    }

    // --- CodegenPrompt construction and rendering ------------------------------------------

    #[test]
    fn prompt_new_default_instructions_empty() {
        let ctx = AnalysisContext { functions: vec![] };
        let prompt = CodegenPrompt::new("int x = 5;", ctx);
        assert_eq!(prompt.c_source, "int x = 5;");
        assert!(prompt.instructions.is_empty());
    }

    #[test]
    fn prompt_with_instructions() {
        let ctx = AnalysisContext { functions: vec![] };
        let prompt = CodegenPrompt::new("int x;", ctx).with_instructions("Use safe Rust only");
        assert_eq!(prompt.instructions, "Use safe Rust only");
    }

    #[test]
    fn prompt_render_contains_c_source() {
        let ctx = AnalysisContext { functions: vec![] };
        let prompt = CodegenPrompt::new("int main() { return 0; }", ctx);
        let rendered = prompt.render();
        assert!(rendered.contains("int main() { return 0; }"));
        assert!(rendered.contains("# C to Rust Transpilation Task"));
        assert!(rendered.contains("## Source C Code"));
    }

    #[test]
    fn prompt_render_contains_instructions_when_set() {
        let ctx = AnalysisContext { functions: vec![] };
        let prompt =
            CodegenPrompt::new("void f();", ctx).with_instructions("Prefer Box over raw ptrs");
        let rendered = prompt.render();
        assert!(rendered.contains("## Additional Instructions"));
        assert!(rendered.contains("Prefer Box over raw ptrs"));
    }

    #[test]
    fn prompt_render_no_instructions_section_when_empty() {
        let ctx = AnalysisContext { functions: vec![] };
        let prompt = CodegenPrompt::new("void f();", ctx);
        let rendered = prompt.render();
        assert!(!rendered.contains("## Additional Instructions"));
    }

    #[test]
    fn prompt_render_includes_ownership_info() {
        use crate::context_builder::{FunctionContext, OwnershipInfo};
        use std::collections::HashMap;

        let mut ownership = HashMap::new();
        ownership.insert(
            "ptr".to_string(),
            OwnershipInfo {
                kind: "owning".to_string(),
                confidence: 0.95,
                reason: "malloc detected".to_string(),
            },
        );

        let ctx = AnalysisContext {
            functions: vec![FunctionContext {
                name: "alloc_data".to_string(),
                c_signature: "void* alloc_data()".to_string(),
                ownership,
                lifetimes: vec![],
                lock_mappings: HashMap::new(),
            }],
        };
        let prompt = CodegenPrompt::new("void* alloc_data() { return malloc(8); }", ctx);
        let rendered = prompt.render();
        assert!(rendered.contains("### Function: alloc_data"));
        assert!(rendered.contains("`ptr`: owning"));
        // 0.95 confidence is rendered as a whole-number percentage.
        assert!(rendered.contains("95%"));
    }

    #[test]
    fn prompt_render_skips_functions_with_no_ownership() {
        use crate::context_builder::FunctionContext;
        use std::collections::HashMap;

        let ctx = AnalysisContext {
            functions: vec![FunctionContext {
                name: "simple".to_string(),
                c_signature: "int simple()".to_string(),
                ownership: HashMap::new(),
                lifetimes: vec![],
                lock_mappings: HashMap::new(),
            }],
        };
        let prompt = CodegenPrompt::new("int simple() { return 0; }", ctx);
        let rendered = prompt.render();
        assert!(!rendered.contains("### Function: simple"));
    }

    #[test]
    fn prompt_render_contains_task_section() {
        let ctx = AnalysisContext { functions: vec![] };
        let prompt = CodegenPrompt::new("int x;", ctx);
        let rendered = prompt.render();
        assert!(rendered.contains("## Task"));
        assert!(rendered.contains("Generate idiomatic, safe Rust code"));
    }

    // --- LlmCodegen construction and generate() --------------------------------------------

    #[test]
    fn llm_codegen_new() {
        let codegen = LlmCodegen::new("test-model");
        let debug = format!("{:?}", codegen);
        assert!(debug.contains("test-model"));
    }

    #[test]
    fn llm_codegen_default() {
        let codegen = LlmCodegen::default();
        let debug = format!("{:?}", codegen);
        assert!(debug.contains("claude-3-sonnet"));
    }

    #[test]
    fn llm_codegen_generate_returns_api_error() {
        let codegen = LlmCodegen::new("gpt-4");
        let ctx = AnalysisContext { functions: vec![] };
        let prompt = CodegenPrompt::new("int x;", ctx);
        let result = codegen.generate(&prompt);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, LlmError::ApiError(_)));
        assert!(err.to_string().contains("gpt-4"));
    }

    // --- parse_response --------------------------------------------------------------------

    #[test]
    fn parse_response_json_format() {
        let codegen = LlmCodegen::new("test");
        let json = r#"{"code": "fn main() {}", "confidence": 0.95, "reasoning": "simple", "warnings": []}"#;
        let result = codegen.parse_response(json).unwrap();
        assert_eq!(result.code, "fn main() {}");
        assert!((result.confidence - 0.95).abs() < 0.01);
        assert_eq!(result.reasoning, "simple");
    }

    #[test]
    fn parse_response_markdown_rust_block() {
        let codegen = LlmCodegen::new("test");
        let response = "Here is the code:\n```rust\nfn add(a: i32, b: i32) -> i32 { a + b }\n```\nThis adds two numbers.";
        let result = codegen.parse_response(response).unwrap();
        assert!(result.code.contains("fn add"));
        // Markdown responses get the fixed fallback confidence.
        assert!((result.confidence - 0.8).abs() < 0.01);
        assert!(result.reasoning.contains("adds two numbers"));
    }

    #[test]
    fn parse_response_markdown_plain_block() {
        let codegen = LlmCodegen::new("test");
        let response = "Code:\n```\nlet x: i32 = 42;\n```\n";
        let result = codegen.parse_response(response).unwrap();
        assert!(result.code.contains("let x: i32 = 42"));
    }

    #[test]
    fn parse_response_no_code_returns_error() {
        let codegen = LlmCodegen::new("test");
        let response = "I cannot generate code for this.";
        let result = codegen.parse_response(response);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ParseError(_)));
    }

    #[test]
    fn parse_response_empty_code_block_returns_error() {
        let codegen = LlmCodegen::new("test");
        let response = "```rust\n\n```";
        let result = codegen.parse_response(response);
        assert!(result.is_err());
    }

    #[test]
    fn parse_response_no_reasoning_after_block() {
        let codegen = LlmCodegen::new("test");
        let response = "```rust\nfn main() {}\n```";
        let result = codegen.parse_response(response).unwrap();
        assert_eq!(result.reasoning, "Generated from C source");
    }

    // --- validate_code ---------------------------------------------------------------------

    #[test]
    fn validate_code_balanced_with_fn() {
        let codegen = LlmCodegen::new("test");
        assert!(codegen.validate_code("fn main() { let x = 1; }").is_ok());
    }

    #[test]
    fn validate_code_unbalanced_braces() {
        let codegen = LlmCodegen::new("test");
        let result = codegen.validate_code("fn main() {");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("braces"));
    }

    #[test]
    fn validate_code_unbalanced_parens() {
        let codegen = LlmCodegen::new("test");
        let result = codegen.validate_code("fn main(");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("parentheses"));
    }

    #[test]
    fn validate_code_empty() {
        let codegen = LlmCodegen::new("test");
        let result = codegen.validate_code("");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("Empty"));
    }

    #[test]
    fn validate_code_whitespace_only() {
        let codegen = LlmCodegen::new("test");
        let result = codegen.validate_code(" \n \t ");
        assert!(result.is_err());
    }

    #[test]
    fn validate_code_expression_no_fn() {
        let codegen = LlmCodegen::new("test");
        // Code without any `fn` item is still accepted as long as it is non-empty and balanced.
        assert!(codegen.validate_code("let x = 42;").is_ok());
    }

    // --- extract_reasoning -----------------------------------------------------------------

    #[test]
    fn extract_reasoning_with_text_after_fence() {
        let response = "```rust\nfn main() {}\n```\nThis is a simple main function.";
        let reasoning = LlmCodegen::extract_reasoning(response);
        assert!(reasoning.contains("simple main function"));
    }

    #[test]
    fn extract_reasoning_no_text_after_fence() {
        let response = "```rust\nfn main() {}\n```";
        let reasoning = LlmCodegen::extract_reasoning(response);
        assert_eq!(reasoning, "Generated from C source");
    }

    #[test]
    fn extract_reasoning_no_fences() {
        let response = "Just some text without code blocks.";
        let reasoning = LlmCodegen::extract_reasoning(response);
        assert_eq!(reasoning, "Generated from C source");
    }

    // --- LlmError and GeneratedCode --------------------------------------------------------

    #[test]
    fn llm_error_display_variants() {
        let e1 = LlmError::PromptCreation("bad prompt".to_string());
        assert!(e1.to_string().contains("bad prompt"));

        let e2 = LlmError::ApiError("timeout".to_string());
        assert!(e2.to_string().contains("timeout"));

        let e3 = LlmError::ParseError("invalid json".to_string());
        assert!(e3.to_string().contains("invalid json"));

        let e4 = LlmError::InvalidCode("no braces".to_string());
        assert!(e4.to_string().contains("no braces"));
    }

    #[test]
    fn generated_code_serde_roundtrip() {
        let code = GeneratedCode {
            code: "fn main() {}".to_string(),
            confidence: 0.9,
            reasoning: "test".to_string(),
            warnings: vec!["warn1".to_string()],
        };
        let json = serde_json::to_string(&code).unwrap();
        let parsed: GeneratedCode = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.code, "fn main() {}");
        assert_eq!(parsed.warnings.len(), 1);
    }

    #[test]
    fn generated_code_clone() {
        let code = GeneratedCode {
            code: "let x = 5;".to_string(),
            confidence: 0.8,
            reasoning: "simple".to_string(),
            warnings: vec![],
        };
        let cloned = code.clone();
        assert_eq!(code.code, cloned.code);
        assert_eq!(code.confidence, cloned.confidence);
    }
}