use crate::context_builder::AnalysisContext;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors produced while building prompts, calling the LLM, or interpreting
/// and validating its output.
#[derive(Debug, Error)]
pub enum LlmError {
/// The prompt could not be constructed.
#[error("Failed to create prompt: {0}")]
PromptCreation(String),
/// The LLM call failed; `LlmCodegen::generate` currently always returns
/// this variant because no API backend is configured.
#[error("LLM API error: {0}")]
ApiError(String),
/// The response contained neither a usable JSON payload nor a Rust code block.
#[error("Failed to parse response: {0}")]
ParseError(String),
/// The generated code failed the lightweight checks in `LlmCodegen::validate_code`
/// (unbalanced delimiters or empty code).
#[error("Generated code is invalid: {0}")]
InvalidCode(String),
}
/// Rust code produced by the LLM together with metadata about the result.
///
/// `LlmCodegen::parse_response` deserializes this directly from JSON responses,
/// so the field names double as the expected JSON keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedCode {
/// The generated Rust source.
pub code: String,
/// Confidence score for the result; responses parsed from a Markdown code
/// fence are assigned a fixed 0.8.
/// NOTE(review): presumably meant to lie in [0.0, 1.0]; JSON input is not range-checked.
pub confidence: f64,
/// Free-text explanation accompanying the code.
pub reasoning: String,
/// Caveats about the generated code; empty for fenced-block responses.
pub warnings: Vec<String>,
}
/// Inputs for a single C-to-Rust code generation request; see `render`
/// for how they are turned into prompt text.
#[derive(Debug, Clone)]
pub struct CodegenPrompt {
/// C source code to translate, embedded verbatim in a ```c fence.
pub c_source: String,
/// Static analysis results, rendered both as pretty JSON and as
/// per-function ownership notes.
pub context: AnalysisContext,
/// Extra free-form instructions; the section is omitted when empty.
pub instructions: String,
}
impl CodegenPrompt {
    /// Builds a prompt for `c_source` using `context`, with no extra instructions.
    pub fn new(c_source: &str, context: AnalysisContext) -> Self {
        let c_source = String::from(c_source);
        Self {
            c_source,
            context,
            instructions: String::new(),
        }
    }

    /// Returns the prompt with its additional instructions replaced by `instructions`.
    pub fn with_instructions(self, instructions: &str) -> Self {
        Self {
            instructions: instructions.to_owned(),
            ..self
        }
    }

    /// Renders the prompt as Markdown.
    ///
    /// Sections, in order: the C source in a ```c fence, the analysis context as
    /// pretty-printed JSON (skipped silently if serialization fails), a per-function
    /// ownership summary for functions that have ownership data, the optional
    /// additional instructions, and finally the task statement.
    pub fn render(&self) -> String {
        let mut out = String::from("# C to Rust Transpilation Task\n\n");

        out += "## Source C Code\n```c\n";
        out += &self.c_source;
        out += "\n```\n\n";

        out += "## Static Analysis Context\n";
        if let Ok(json) = serde_json::to_string_pretty(&self.context) {
            out += "```json\n";
            out += &json;
            out += "\n```\n\n";
        }

        for func in &self.context.functions {
            // Functions without ownership facts get no summary section at all.
            if func.ownership.is_empty() {
                continue;
            }
            out += &format!("### Function: {}\n", func.name);
            out += "Ownership analysis:\n";
            for (var, info) in &func.ownership {
                let percent = info.confidence * 100.0;
                out += &format!(
                    "- `{}`: {} (confidence: {:.0}%)\n",
                    var, info.kind, percent
                );
            }
            out.push('\n');
        }

        if !self.instructions.is_empty() {
            out += "## Additional Instructions\n";
            out += &self.instructions;
            out += "\n\n";
        }

        out += "## Task\n";
        out += "Generate idiomatic, safe Rust code that is functionally equivalent to the C code above.\n";
        out += "Use the static analysis context to guide ownership and borrowing decisions.\n";
        out
    }
}
/// LLM-backed Rust code generator plus helpers for parsing and sanity-checking
/// model output.
///
/// `generate` is currently a stub that always returns `LlmError::ApiError`.
#[derive(Debug)]
pub struct LlmCodegen {
/// Model identifier; in the code shown here it is only used in `generate`'s
/// error message.
model: String,
}
impl LlmCodegen {
    /// Confidence assigned to code recovered from a Markdown fence, since such
    /// responses carry no confidence score of their own.
    const FENCED_BLOCK_CONFIDENCE: f64 = 0.8;

    /// Creates a generator targeting the given model identifier.
    pub fn new(model: &str) -> Self {
        Self {
            model: model.to_string(),
        }
    }

    /// Sends the prompt to the configured model and returns the generated code.
    ///
    /// # Errors
    ///
    /// No LLM backend is wired up yet, so this always returns
    /// [`LlmError::ApiError`] naming the configured model.
    pub fn generate(&self, _prompt: &CodegenPrompt) -> Result<GeneratedCode, LlmError> {
        Err(LlmError::ApiError(format!(
            "LLM API not configured for model: {}",
            self.model
        )))
    }

    /// Parses a raw LLM response into [`GeneratedCode`].
    ///
    /// Accepted shapes, tried in order:
    /// 1. the whole (trimmed) response is a JSON-serialized `GeneratedCode`;
    /// 2. a fenced block tagged `json` contains such JSON;
    /// 3. a fenced Rust block (tagged `rust`/`rs`, or untagged as a fallback).
    ///    Its reasoning is the text after the last fence, its confidence is a
    ///    fixed 0.8 and it has no warnings.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::ParseError`] when none of the shapes match.
    pub fn parse_response(&self, response: &str) -> Result<GeneratedCode, LlmError> {
        if let Some(generated) = Self::parse_json_response(response) {
            return Ok(generated);
        }
        if let Some(code) = Self::extract_rust_code_block(response) {
            return Ok(GeneratedCode {
                code,
                confidence: Self::FENCED_BLOCK_CONFIDENCE,
                reasoning: Self::extract_reasoning(response),
                warnings: Vec::new(),
            });
        }
        Err(LlmError::ParseError(
            "No valid Rust code found in response".to_string(),
        ))
    }

    /// Parses `response` as JSON `GeneratedCode`, either verbatim or from the
    /// first `json`-tagged fenced block that deserializes successfully.
    fn parse_json_response(response: &str) -> Option<GeneratedCode> {
        if let Ok(generated) = serde_json::from_str::<GeneratedCode>(response.trim()) {
            return Some(generated);
        }
        Self::fenced_blocks(response)
            .into_iter()
            .filter(|(lang, _)| lang.eq_ignore_ascii_case("json"))
            .find_map(|(_, body)| serde_json::from_str::<GeneratedCode>(body).ok())
    }

    /// Splits `response` into Markdown fenced blocks, returning
    /// `(language tag, trimmed body)` pairs in document order.
    ///
    /// The tag is the first word (split on whitespace or `,`) of the text that
    /// follows the opening fence on its line, or `""` when there is none. A block
    /// whose closing fence is on the same line as its opening fence uses its first
    /// word as the tag. A trailing fence without a closing partner is ignored.
    fn fenced_blocks(response: &str) -> Vec<(&str, &str)> {
        const FENCE: &str = "```";
        let mut blocks = Vec::new();
        let mut rest = response;
        while let Some(open) = rest.find(FENCE) {
            let after_open = &rest[open + FENCE.len()..];
            let close = match after_open.find(FENCE) {
                Some(close) => close,
                None => break,
            };
            let segment = &after_open[..close];
            rest = &after_open[close + FENCE.len()..];
            // Normally the info string occupies the opening line; a single-line
            // block has no newline, so its first word is taken as the tag instead.
            let (info, body) = match segment.split_once('\n') {
                Some(parts) => parts,
                None => segment
                    .split_once(char::is_whitespace)
                    .unwrap_or((segment, "")),
            };
            let lang = info
                .trim()
                .split(|c: char| c == ',' || c.is_whitespace())
                .next()
                .unwrap_or("");
            blocks.push((lang, body.trim()));
        }
        blocks
    }

    /// Returns the body of the first non-empty fenced block tagged `rust` or `rs`
    /// (case-insensitive), falling back to the first non-empty untagged block.
    ///
    /// Blocks tagged with any other language (e.g. `c`, `json`) are never
    /// returned, so echoed C source or JSON payloads are not mistaken for Rust.
    fn extract_rust_code_block(response: &str) -> Option<String> {
        let mut untagged: Option<&str> = None;
        for (lang, code) in Self::fenced_blocks(response) {
            if code.is_empty() {
                continue;
            }
            if lang.eq_ignore_ascii_case("rust") || lang.eq_ignore_ascii_case("rs") {
                return Some(code.to_string());
            }
            if lang.is_empty() && untagged.is_none() {
                untagged = Some(code);
            }
        }
        untagged.map(String::from)
    }

    /// Returns the trimmed text after the last code fence, or a generic
    /// placeholder when there is no fence or nothing follows it.
    fn extract_reasoning(response: &str) -> String {
        response
            .rsplit_once("```")
            .map(|(_, after)| after.trim())
            .filter(|after| !after.is_empty())
            .unwrap_or("Generated from C source")
            .to_string()
    }

    /// Performs cheap sanity checks on generated code.
    ///
    /// Verifies that `{}`, `()` and `[]` are balanced, ignoring occurrences inside
    /// comments, string/byte/raw-string literals and char literals, and that the
    /// code is not blank. This is a heuristic; it does not parse or compile the code.
    ///
    /// # Errors
    ///
    /// Returns [`LlmError::InvalidCode`] describing the first failed check.
    pub fn validate_code(&self, code: &str) -> Result<(), LlmError> {
        let [braces, parens, brackets] = Self::count_delimiters(code);
        for (label, (open, close)) in [
            ("braces", braces),
            ("parentheses", parens),
            ("brackets", brackets),
        ] {
            if open != close {
                return Err(LlmError::InvalidCode(format!(
                    "Unbalanced {}: {} open, {} close",
                    label, open, close
                )));
            }
        }
        if code.trim().is_empty() {
            return Err(LlmError::InvalidCode("Empty code".to_string()));
        }
        Ok(())
    }

    /// Counts `(open, close)` occurrences of `{}`, `()` and `[]`, in that order,
    /// skipping comments, string literals (including raw strings) and char
    /// literals so delimiters inside them do not affect the balance.
    fn count_delimiters(code: &str) -> [(usize, usize); 3] {
        let chars: Vec<char> = code.chars().collect();
        let mut counts = [(0usize, 0usize); 3];
        let mut i = 0;
        while i < chars.len() {
            let next = chars.get(i + 1).copied();
            match chars[i] {
                '/' if next == Some('/') => {
                    // Line comment: skip to the end of the line.
                    while i < chars.len() && chars[i] != '\n' {
                        i += 1;
                    }
                    continue;
                }
                '/' if next == Some('*') => {
                    // Block comment; Rust block comments nest.
                    let mut depth = 1usize;
                    i += 2;
                    while i < chars.len() && depth > 0 {
                        match (chars[i], chars.get(i + 1).copied()) {
                            ('/', Some('*')) => {
                                depth += 1;
                                i += 2;
                            }
                            ('*', Some('/')) => {
                                depth -= 1;
                                i += 2;
                            }
                            _ => i += 1,
                        }
                    }
                    continue;
                }
                'r' => {
                    if let Some(end) = Self::raw_string_end(&chars, i) {
                        i = end;
                        continue;
                    }
                }
                '"' => {
                    i += 1;
                    while i < chars.len() && chars[i] != '"' {
                        // Step over the escaped character so `\"` does not end the literal.
                        i += if chars[i] == '\\' { 2 } else { 1 };
                    }
                    i += 1;
                    continue;
                }
                '\'' => {
                    if next == Some('\\') {
                        // Escaped char literal such as '\n', '\'' or '\u{7b}'.
                        i += 3;
                        while i < chars.len() && chars[i] != '\'' {
                            i += 1;
                        }
                        i += 1;
                        continue;
                    }
                    if chars.get(i + 2) == Some(&'\'') {
                        // Plain char literal such as '{'.
                        i += 3;
                        continue;
                    }
                    // Otherwise a lifetime or loop label: nothing to skip.
                }
                '{' => counts[0].0 += 1,
                '}' => counts[0].1 += 1,
                '(' => counts[1].0 += 1,
                ')' => counts[1].1 += 1,
                '[' => counts[2].0 += 1,
                ']' => counts[2].1 += 1,
                _ => {}
            }
            i += 1;
        }
        counts
    }

    /// If a raw string literal (`r"…"`, `r#"…"#`, optionally `b`-prefixed) starts
    /// at the `r` in `chars[start]`, returns the index just past its closing
    /// delimiter, or `chars.len()` when the literal is unterminated.
    fn raw_string_end(chars: &[char], start: usize) -> Option<usize> {
        // An `r` that ends an identifier (e.g. `bar`) cannot start a literal;
        // a preceding `b` is allowed for byte raw strings (`br"…"`).
        if let Some(&prev) = start.checked_sub(1).and_then(|p| chars.get(p)) {
            if (prev.is_alphanumeric() || prev == '_') && prev != 'b' {
                return None;
            }
        }
        let mut pos = start + 1;
        let mut hashes = 0;
        while chars.get(pos) == Some(&'#') {
            hashes += 1;
            pos += 1;
        }
        if chars.get(pos) != Some(&'"') {
            return None;
        }
        pos += 1;
        while pos < chars.len() {
            if chars[pos] == '"' {
                if let Some(tail) = chars.get(pos + 1..pos + 1 + hashes) {
                    if tail.iter().all(|&c| c == '#') {
                        return Some(pos + 1 + hashes);
                    }
                }
            }
            pos += 1;
        }
        Some(chars.len())
    }
}
impl Default for LlmCodegen {
fn default() -> Self {
Self::new("claude-3-sonnet")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_code_block() {
        let response = "Here's the code:\n```rust\nfn main() {}\n```\nDone!";
        let code = LlmCodegen::extract_rust_code_block(response);
        assert!(code.is_some());
        assert!(code.unwrap().contains("fn main"));
    }

    #[test]
    fn test_extract_untagged_code_block() {
        let response = "```\nfn b() {}\n```";
        assert_eq!(
            LlmCodegen::extract_rust_code_block(response).as_deref(),
            Some("fn b() {}")
        );
    }

    #[test]
    fn test_extract_without_fence_returns_none() {
        assert_eq!(LlmCodegen::extract_rust_code_block("no code here"), None);
    }

    #[test]
    fn test_parse_response_json() {
        let codegen = LlmCodegen::default();
        let response =
            r#"{"code":"fn x() {}","confidence":0.5,"reasoning":"r","warnings":["w"]}"#;
        let generated = codegen
            .parse_response(response)
            .expect("a JSON GeneratedCode response should parse");
        assert_eq!(generated.code, "fn x() {}");
        assert!((generated.confidence - 0.5).abs() < 1e-9);
        assert_eq!(generated.reasoning, "r");
        assert_eq!(generated.warnings, vec!["w".to_string()]);
    }

    #[test]
    fn test_parse_response_fenced_block() {
        let codegen = LlmCodegen::default();
        let response = "Here:\n```rust\nfn c() {}\n```\nUses no allocation.";
        let generated = codegen
            .parse_response(response)
            .expect("a fenced Rust block should parse");
        assert_eq!(generated.code, "fn c() {}");
        assert!((generated.confidence - 0.8).abs() < 1e-9);
        assert_eq!(generated.reasoning, "Uses no allocation.");
        assert!(generated.warnings.is_empty());
    }

    #[test]
    fn test_parse_response_without_code_fails() {
        let codegen = LlmCodegen::default();
        assert!(matches!(
            codegen.parse_response("no code here"),
            Err(LlmError::ParseError(_))
        ));
    }

    #[test]
    fn test_validate_code() {
        let codegen = LlmCodegen::default();
        assert!(codegen.validate_code("fn main() {}").is_ok());
        assert!(matches!(
            codegen.validate_code("fn main() {"),
            Err(LlmError::InvalidCode(_))
        ));
        assert!(matches!(
            codegen.validate_code("fn f("),
            Err(LlmError::InvalidCode(_))
        ));
        assert!(matches!(
            codegen.validate_code("   "),
            Err(LlmError::InvalidCode(_))
        ));
    }
}