Skip to main content

decy_llm/
llm_codegen.rs

1//! LLM-guided Rust code generation (DECY-099).
2//!
3//! Uses LLM to generate idiomatic Rust code guided by static analysis results.
4
5use crate::context_builder::AnalysisContext;
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9/// Errors that can occur during LLM code generation.
10#[derive(Debug, Error)]
11pub enum LlmError {
12    /// Failed to create prompt from context
13    #[error("Failed to create prompt: {0}")]
14    PromptCreation(String),
15    /// LLM API error
16    #[error("LLM API error: {0}")]
17    ApiError(String),
18    /// Failed to parse LLM response
19    #[error("Failed to parse response: {0}")]
20    ParseError(String),
21    /// Generated code is invalid
22    #[error("Generated code is invalid: {0}")]
23    InvalidCode(String),
24}
25
26/// Result of LLM code generation.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct GeneratedCode {
29    /// Generated Rust code
30    pub code: String,
31    /// Confidence score (0.0-1.0)
32    pub confidence: f64,
33    /// Reasoning for the generated code
34    pub reasoning: String,
35    /// Any warnings or suggestions
36    pub warnings: Vec<String>,
37}
38
39/// Prompt template for LLM code generation.
40#[derive(Debug, Clone)]
41pub struct CodegenPrompt {
42    /// C source code
43    pub c_source: String,
44    /// Analysis context (ownership, lifetimes, locks)
45    pub context: AnalysisContext,
46    /// Additional instructions
47    pub instructions: String,
48}
49
50impl CodegenPrompt {
51    /// Create a new codegen prompt.
52    pub fn new(c_source: &str, context: AnalysisContext) -> Self {
53        Self { c_source: c_source.to_string(), context, instructions: String::new() }
54    }
55
56    /// Set additional instructions.
57    pub fn with_instructions(mut self, instructions: &str) -> Self {
58        self.instructions = instructions.to_string();
59        self
60    }
61
62    /// Render the prompt as a string for LLM input.
63    pub fn render(&self) -> String {
64        let mut prompt = String::new();
65
66        prompt.push_str("# C to Rust Transpilation Task\n\n");
67        prompt.push_str("## Source C Code\n```c\n");
68        prompt.push_str(&self.c_source);
69        prompt.push_str("\n```\n\n");
70
71        // Add analysis context
72        prompt.push_str("## Static Analysis Context\n");
73        if let Ok(context_json) = serde_json::to_string_pretty(&self.context) {
74            prompt.push_str("```json\n");
75            prompt.push_str(&context_json);
76            prompt.push_str("\n```\n\n");
77        }
78
79        // Add ownership information summary
80        for func in &self.context.functions {
81            if !func.ownership.is_empty() {
82                prompt.push_str(&format!("### Function: {}\n", func.name));
83                prompt.push_str("Ownership analysis:\n");
84                for (var, info) in &func.ownership {
85                    prompt.push_str(&format!(
86                        "- `{}`: {} (confidence: {:.0}%)\n",
87                        var,
88                        info.kind,
89                        info.confidence * 100.0
90                    ));
91                }
92                prompt.push('\n');
93            }
94        }
95
96        if !self.instructions.is_empty() {
97            prompt.push_str("## Additional Instructions\n");
98            prompt.push_str(&self.instructions);
99            prompt.push_str("\n\n");
100        }
101
102        prompt.push_str("## Task\n");
103        prompt.push_str("Generate idiomatic, safe Rust code that is functionally equivalent to the C code above.\n");
104        prompt.push_str(
105            "Use the static analysis context to guide ownership and borrowing decisions.\n",
106        );
107
108        prompt
109    }
110}
111
/// Code generator that targets a single, named LLM model.
#[derive(Debug)]
pub struct LlmCodegen {
    /// Identifier of the model this generator addresses.
    model: String,
}
118
119impl LlmCodegen {
120    /// Create a new LLM code generator.
121    pub fn new(model: &str) -> Self {
122        Self { model: model.to_string() }
123    }
124
125    /// Generate Rust code from C source with analysis context.
126    ///
127    /// Note: This is a stub for research purposes. Actual LLM integration
128    /// would require API credentials and network access.
129    pub fn generate(&self, _prompt: &CodegenPrompt) -> Result<GeneratedCode, LlmError> {
130        // In a real implementation, this would call the LLM API
131        Err(LlmError::ApiError(format!("LLM API not configured for model: {}", self.model)))
132    }
133
134    /// Parse raw LLM response into generated code.
135    ///
136    /// Supports two formats:
137    /// 1. Markdown code blocks with ```rust ... ```
138    /// 2. JSON with { "code": "...", "confidence": ..., ... }
139    pub fn parse_response(&self, response: &str) -> Result<GeneratedCode, LlmError> {
140        contract_pre_parse!();
141        // Try JSON format first
142        if let Ok(generated) = serde_json::from_str::<GeneratedCode>(response.trim()) {
143            return Ok(generated);
144        }
145
146        // Try to extract code from markdown code blocks
147        if let Some(code) = Self::extract_rust_code_block(response) {
148            // Extract reasoning from text after code block
149            let reasoning = Self::extract_reasoning(response);
150
151            return Ok(GeneratedCode {
152                code,
153                confidence: 0.8, // Default confidence for markdown format
154                reasoning,
155                warnings: Vec::new(),
156            });
157        }
158
159        Err(LlmError::ParseError("No valid Rust code found in response".to_string()))
160    }
161
162    /// Extract Rust code from markdown code block.
163    fn extract_rust_code_block(response: &str) -> Option<String> {
164        // Look for ```rust or just ```
165        let markers = ["```rust", "```"];
166
167        for marker in markers {
168            if let Some(start) = response.find(marker) {
169                let code_start = start + marker.len();
170                // Skip newline after marker
171                let code_start = response[code_start..]
172                    .find('\n')
173                    .map(|i| code_start + i + 1)
174                    .unwrap_or(code_start);
175
176                // Find closing ```
177                if let Some(end) = response[code_start..].find("```") {
178                    let code = response[code_start..code_start + end].trim();
179                    if !code.is_empty() {
180                        return Some(code.to_string());
181                    }
182                }
183            }
184        }
185
186        None
187    }
188
189    /// Extract reasoning text from response (text after code block).
190    fn extract_reasoning(response: &str) -> String {
191        // Find last ``` and get text after it
192        if let Some(last_fence) = response.rfind("```") {
193            let after = &response[last_fence + 3..];
194            let reasoning = after.trim();
195            if !reasoning.is_empty() {
196                return reasoning.to_string();
197            }
198        }
199        "Generated from C source".to_string()
200    }
201
202    /// Validate generated code (basic syntax check).
203    ///
204    /// Performs a basic syntactic validation:
205    /// - Checks for balanced braces
206    /// - Checks for fn keyword
207    /// - Checks for basic syntax patterns
208    pub fn validate_code(&self, code: &str) -> Result<(), LlmError> {
209        // Check for balanced braces
210        let open_braces = code.matches('{').count();
211        let close_braces = code.matches('}').count();
212
213        if open_braces != close_braces {
214            return Err(LlmError::InvalidCode(format!(
215                "Unbalanced braces: {} open, {} close",
216                open_braces, close_braces
217            )));
218        }
219
220        // Check for balanced parentheses
221        let open_parens = code.matches('(').count();
222        let close_parens = code.matches(')').count();
223
224        if open_parens != close_parens {
225            return Err(LlmError::InvalidCode(format!(
226                "Unbalanced parentheses: {} open, {} close",
227                open_parens, close_parens
228            )));
229        }
230
231        // Check for basic function structure
232        if code.contains("fn ") {
233            // Looks like it has a function - basic check passed
234            return Ok(());
235        }
236
237        // Allow simple expressions/statements too
238        if !code.trim().is_empty() {
239            return Ok(());
240        }
241
242        Err(LlmError::InvalidCode("Empty code".to_string()))
243    }
244}
245
246impl Default for LlmCodegen {
247    fn default() -> Self {
248        Self::new("claude-3-sonnet")
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context_builder::{FunctionContext, OwnershipInfo};
    use std::collections::HashMap;

    // ========================================================================
    // Shared fixtures
    // ========================================================================

    /// Analysis context without any functions.
    fn empty_context() -> AnalysisContext {
        AnalysisContext { functions: Vec::new() }
    }

    /// Prompt over `c_source` with an empty analysis context.
    fn prompt_for(c_source: &str) -> CodegenPrompt {
        CodegenPrompt::new(c_source, empty_context())
    }

    /// Analysis context holding one function with the given ownership map.
    fn single_function_context(
        name: &str,
        c_signature: &str,
        ownership: HashMap<String, OwnershipInfo>,
    ) -> AnalysisContext {
        AnalysisContext {
            functions: vec![FunctionContext {
                name: name.to_string(),
                c_signature: c_signature.to_string(),
                ownership,
                lifetimes: vec![],
                lock_mappings: HashMap::new(),
            }],
        }
    }

    /// Generator used by tests that do not care about the model name.
    fn test_codegen() -> LlmCodegen {
        LlmCodegen::new("test")
    }

    #[test]
    fn test_extract_code_block() {
        let extracted = LlmCodegen::extract_rust_code_block(
            "Here's the code:\n```rust\nfn main() {}\n```\nDone!",
        );
        let code = extracted.expect("a rust block should be extracted");
        assert!(code.contains("fn main"));
    }

    // ========================================================================
    // CodegenPrompt tests
    // ========================================================================

    #[test]
    fn prompt_new_default_instructions_empty() {
        let prompt = prompt_for("int x = 5;");
        assert_eq!(prompt.c_source, "int x = 5;");
        assert!(prompt.instructions.is_empty());
    }

    #[test]
    fn prompt_with_instructions() {
        let prompt = prompt_for("int x;").with_instructions("Use safe Rust only");
        assert_eq!(prompt.instructions, "Use safe Rust only");
    }

    #[test]
    fn prompt_render_contains_c_source() {
        let rendered = prompt_for("int main() { return 0; }").render();
        for expected in
            ["int main() { return 0; }", "# C to Rust Transpilation Task", "## Source C Code"]
        {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn prompt_render_contains_instructions_when_set() {
        let rendered =
            prompt_for("void f();").with_instructions("Prefer Box over raw ptrs").render();
        assert!(rendered.contains("## Additional Instructions"));
        assert!(rendered.contains("Prefer Box over raw ptrs"));
    }

    #[test]
    fn prompt_render_no_instructions_section_when_empty() {
        let rendered = prompt_for("void f();").render();
        assert!(!rendered.contains("## Additional Instructions"));
    }

    #[test]
    fn prompt_render_includes_ownership_info() {
        let ownership: HashMap<_, _> = std::iter::once((
            "ptr".to_string(),
            OwnershipInfo {
                kind: "owning".to_string(),
                confidence: 0.95,
                reason: "malloc detected".to_string(),
            },
        ))
        .collect();

        let ctx = single_function_context("alloc_data", "void* alloc_data()", ownership);
        let rendered =
            CodegenPrompt::new("void* alloc_data() { return malloc(8); }", ctx).render();

        assert!(rendered.contains("### Function: alloc_data"));
        assert!(rendered.contains("`ptr`: owning"));
        assert!(rendered.contains("95%"));
    }

    #[test]
    fn prompt_render_skips_functions_with_no_ownership() {
        let ctx = single_function_context("simple", "int simple()", HashMap::new());
        let rendered = CodegenPrompt::new("int simple() { return 0; }", ctx).render();
        assert!(!rendered.contains("### Function: simple"));
    }

    #[test]
    fn prompt_render_contains_task_section() {
        let rendered = prompt_for("int x;").render();
        assert!(rendered.contains("## Task"));
        assert!(rendered.contains("Generate idiomatic, safe Rust code"));
    }

    // ========================================================================
    // LlmCodegen tests
    // ========================================================================

    #[test]
    fn llm_codegen_new() {
        let debug = format!("{:?}", LlmCodegen::new("test-model"));
        assert!(debug.contains("test-model"));
    }

    #[test]
    fn llm_codegen_default() {
        let debug = format!("{:?}", LlmCodegen::default());
        assert!(debug.contains("claude-3-sonnet"));
    }

    #[test]
    fn llm_codegen_generate_returns_api_error() {
        let prompt = prompt_for("int x;");
        let result = LlmCodegen::new("gpt-4").generate(&prompt);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(matches!(err, LlmError::ApiError(_)));
        assert!(err.to_string().contains("gpt-4"));
    }

    // ========================================================================
    // parse_response tests
    // ========================================================================

    #[test]
    fn parse_response_json_format() {
        let json = r#"{"code": "fn main() {}", "confidence": 0.95, "reasoning": "simple", "warnings": []}"#;
        let parsed = test_codegen().parse_response(json).unwrap();
        assert_eq!(parsed.code, "fn main() {}");
        assert!((parsed.confidence - 0.95).abs() < 0.01);
        assert_eq!(parsed.reasoning, "simple");
    }

    #[test]
    fn parse_response_markdown_rust_block() {
        let response = "Here is the code:\n```rust\nfn add(a: i32, b: i32) -> i32 { a + b }\n```\nThis adds two numbers.";
        let parsed = test_codegen().parse_response(response).unwrap();
        assert!(parsed.code.contains("fn add"));
        assert!((parsed.confidence - 0.8).abs() < 0.01);
        assert!(parsed.reasoning.contains("adds two numbers"));
    }

    #[test]
    fn parse_response_markdown_plain_block() {
        let parsed =
            test_codegen().parse_response("Code:\n```\nlet x: i32 = 42;\n```\n").unwrap();
        assert!(parsed.code.contains("let x: i32 = 42"));
    }

    #[test]
    fn parse_response_no_code_returns_error() {
        let result = test_codegen().parse_response("I cannot generate code for this.");
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::ParseError(_)));
    }

    #[test]
    fn parse_response_empty_code_block_returns_error() {
        let result = test_codegen().parse_response("```rust\n\n```");
        assert!(result.is_err());
    }

    #[test]
    fn parse_response_no_reasoning_after_block() {
        let parsed = test_codegen().parse_response("```rust\nfn main() {}\n```").unwrap();
        assert_eq!(parsed.reasoning, "Generated from C source");
    }

    // ========================================================================
    // validate_code tests
    // ========================================================================

    #[test]
    fn validate_code_balanced_with_fn() {
        assert!(test_codegen().validate_code("fn main() { let x = 1; }").is_ok());
    }

    #[test]
    fn validate_code_unbalanced_braces() {
        let result = test_codegen().validate_code("fn main() {");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("braces"));
    }

    #[test]
    fn validate_code_unbalanced_parens() {
        let result = test_codegen().validate_code("fn main(");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("parentheses"));
    }

    #[test]
    fn validate_code_empty() {
        let result = test_codegen().validate_code("");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Empty"));
    }

    #[test]
    fn validate_code_whitespace_only() {
        assert!(test_codegen().validate_code("   \n  \t  ").is_err());
    }

    #[test]
    fn validate_code_expression_no_fn() {
        // Balanced, non-empty code without an `fn` keyword is still accepted.
        assert!(test_codegen().validate_code("let x = 42;").is_ok());
    }

    // ========================================================================
    // extract_reasoning tests
    // ========================================================================

    #[test]
    fn extract_reasoning_with_text_after_fence() {
        let reasoning = LlmCodegen::extract_reasoning(
            "```rust\nfn main() {}\n```\nThis is a simple main function.",
        );
        assert!(reasoning.contains("simple main function"));
    }

    #[test]
    fn extract_reasoning_no_text_after_fence() {
        let reasoning = LlmCodegen::extract_reasoning("```rust\nfn main() {}\n```");
        assert_eq!(reasoning, "Generated from C source");
    }

    #[test]
    fn extract_reasoning_no_fences() {
        let reasoning = LlmCodegen::extract_reasoning("Just some text without code blocks.");
        assert_eq!(reasoning, "Generated from C source");
    }

    // ========================================================================
    // LlmError Display tests
    // ========================================================================

    #[test]
    fn llm_error_display_variants() {
        let cases = [
            (LlmError::PromptCreation("bad prompt".to_string()), "bad prompt"),
            (LlmError::ApiError("timeout".to_string()), "timeout"),
            (LlmError::ParseError("invalid json".to_string()), "invalid json"),
            (LlmError::InvalidCode("no braces".to_string()), "no braces"),
        ];

        for (error, detail) in cases {
            assert!(error.to_string().contains(detail));
        }
    }

    // ========================================================================
    // GeneratedCode serde tests
    // ========================================================================

    #[test]
    fn generated_code_serde_roundtrip() {
        let original = GeneratedCode {
            code: "fn main() {}".to_string(),
            confidence: 0.9,
            reasoning: "test".to_string(),
            warnings: vec!["warn1".to_string()],
        };

        let json = serde_json::to_string(&original).unwrap();
        let parsed: GeneratedCode = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.code, "fn main() {}");
        assert_eq!(parsed.warnings.len(), 1);
    }

    #[test]
    fn generated_code_clone() {
        let original = GeneratedCode {
            code: "let x = 5;".to_string(),
            confidence: 0.8,
            reasoning: "simple".to_string(),
            warnings: vec![],
        };

        let cloned = original.clone();

        assert_eq!(original.code, cloned.code);
        assert_eq!(original.confidence, cloned.confidence);
    }
}