ricecoder_generation/
code_generator.rs

1//! AI-based code generation with streaming support
2//!
3//! Provides CodeGenerator for calling AI providers with built prompts,
4//! handling streaming responses, parsing generated code, and extracting
5//! multiple files from single responses.
6
7use crate::error::GenerationError;
8use crate::models::GeneratedFile;
9use crate::prompt_builder::GeneratedPrompt;
10use ricecoder_providers::models::{ChatRequest, Message};
11use ricecoder_providers::provider::Provider;
12use std::time::Duration;
13use tokio::time::sleep;
14
/// Configuration for code generation
///
/// Controls the retry/backoff policy used by `CodeGenerator::generate`.
#[derive(Debug, Clone)]
pub struct CodeGeneratorConfig {
    /// Maximum number of retries on transient failures, in addition to the
    /// initial attempt (so up to `max_retries + 1` provider calls are made)
    pub max_retries: usize,
    /// Initial backoff duration for retries; grows by `backoff_multiplier`
    /// after each failed attempt
    pub initial_backoff: Duration,
    /// Maximum backoff duration for retries (cap on the exponential growth)
    pub max_backoff: Duration,
    /// Backoff multiplier for exponential backoff
    pub backoff_multiplier: f64,
}
27
28impl Default for CodeGeneratorConfig {
29    fn default() -> Self {
30        Self {
31            max_retries: 3,
32            initial_backoff: Duration::from_millis(100),
33            max_backoff: Duration::from_secs(10),
34            backoff_multiplier: 2.0,
35        }
36    }
37}
38
/// Generates code using AI providers
///
/// Holds only the retry/backoff configuration; the provider, prompt, and
/// model are supplied per call to `generate` / `generate_streaming`.
#[derive(Debug, Clone)]
pub struct CodeGenerator {
    /// Configuration for code generation (retry/backoff policy)
    config: CodeGeneratorConfig,
}
45
46impl CodeGenerator {
47    /// Creates a new CodeGenerator with default configuration
48    pub fn new() -> Self {
49        Self {
50            config: CodeGeneratorConfig::default(),
51        }
52    }
53
54    /// Creates a new CodeGenerator with custom configuration
55    pub fn with_config(config: CodeGeneratorConfig) -> Self {
56        Self { config }
57    }
58
59    /// Generates code from a prompt using the provided provider
60    ///
61    /// # Arguments
62    /// * `provider` - The AI provider to use for generation
63    /// * `prompt` - The generated prompt containing system and user messages
64    /// * `model` - The model to use for generation
65    /// * `temperature` - Temperature for sampling (0.0 to 2.0)
66    /// * `max_tokens` - Maximum tokens to generate
67    ///
68    /// # Returns
69    /// A vector of generated files
70    ///
71    /// # Errors
72    /// Returns `GenerationError` if generation fails after all retries
73    pub async fn generate(
74        &self,
75        provider: &dyn Provider,
76        prompt: &GeneratedPrompt,
77        model: &str,
78        temperature: f32,
79        max_tokens: usize,
80    ) -> Result<Vec<GeneratedFile>, GenerationError> {
81        let mut backoff = self.config.initial_backoff;
82        let mut last_error = None;
83
84        for attempt in 0..=self.config.max_retries {
85            match self
86                .generate_internal(provider, prompt, model, temperature, max_tokens)
87                .await
88            {
89                Ok(files) => return Ok(files),
90                Err(e) => {
91                    last_error = Some(e);
92
93                    // Don't retry on the last attempt
94                    if attempt < self.config.max_retries {
95                        sleep(backoff).await;
96                        backoff = Duration::from_secs_f64(
97                            (backoff.as_secs_f64() * self.config.backoff_multiplier)
98                                .min(self.config.max_backoff.as_secs_f64()),
99                        );
100                    }
101                }
102            }
103        }
104
105        Err(last_error
106            .unwrap_or_else(|| GenerationError::GenerationFailed("Unknown error".to_string())))
107    }
108
109    /// Internal implementation of code generation
110    async fn generate_internal(
111        &self,
112        provider: &dyn Provider,
113        prompt: &GeneratedPrompt,
114        model: &str,
115        temperature: f32,
116        max_tokens: usize,
117    ) -> Result<Vec<GeneratedFile>, GenerationError> {
118        // Build the chat request
119        let mut messages = vec![Message {
120            role: "system".to_string(),
121            content: prompt.system_prompt.clone(),
122        }];
123
124        messages.push(Message {
125            role: "user".to_string(),
126            content: prompt.user_prompt.clone(),
127        });
128
129        let request = ChatRequest {
130            model: model.to_string(),
131            messages,
132            temperature: Some(temperature),
133            max_tokens: Some(max_tokens),
134            stream: false,
135        };
136
137        // Call the provider
138        let response = provider
139            .chat(request)
140            .await
141            .map_err(|e| GenerationError::GenerationFailed(e.to_string()))?;
142
143        // Parse the response into files
144        self.parse_generated_code(&response.content)
145    }
146
147    /// Parses generated code into individual files
148    ///
149    /// Supports multiple file formats:
150    /// - Markdown code blocks with file paths: ```rust\n// file: src/main.rs\n...```
151    /// - JSON format with file list
152    /// - Plain code (single file)
153    fn parse_generated_code(&self, content: &str) -> Result<Vec<GeneratedFile>, GenerationError> {
154        let mut files = Vec::new();
155
156        // Try to parse as markdown code blocks first
157        if let Ok(parsed_files) = self.parse_markdown_blocks(content) {
158            if !parsed_files.is_empty() {
159                return Ok(parsed_files);
160            }
161        }
162
163        // Try to parse as JSON
164        if let Ok(parsed_files) = self.parse_json_files(content) {
165            if !parsed_files.is_empty() {
166                return Ok(parsed_files);
167            }
168        }
169
170        // Fall back to treating entire content as a single file
171        files.push(GeneratedFile {
172            path: "generated.rs".to_string(),
173            content: content.to_string(),
174            language: "rust".to_string(),
175        });
176
177        Ok(files)
178    }
179
180    /// Parses markdown code blocks with file paths
181    fn parse_markdown_blocks(&self, content: &str) -> Result<Vec<GeneratedFile>, GenerationError> {
182        let mut files = Vec::new();
183        let mut current_file: Option<GeneratedFile> = None;
184        let mut in_code_block = false;
185        let mut code_buffer = String::new();
186
187        for line in content.lines() {
188            if line.starts_with("```") {
189                if in_code_block {
190                    // End of code block
191                    if let Some(mut file) = current_file.take() {
192                        file.content = code_buffer.trim().to_string();
193                        files.push(file);
194                    }
195                    code_buffer.clear();
196                    in_code_block = false;
197                } else {
198                    // Start of code block
199                    let header = line.trim_start_matches("```").trim();
200
201                    // Extract language and optional file path
202                    let parts: Vec<&str> = header.split_whitespace().collect();
203                    if !parts.is_empty() {
204                        let language = parts[0].to_string();
205
206                        // Check for file path in comment
207                        let file_path = if parts.len() > 1 && parts[1] == "file:" {
208                            parts.get(2).map(|s| s.to_string())
209                        } else {
210                            None
211                        };
212
213                        current_file = Some(GeneratedFile {
214                            path: file_path.unwrap_or_else(|| format!("generated.{}", language)),
215                            content: String::new(),
216                            language,
217                        });
218                    }
219
220                    in_code_block = true;
221                }
222            } else if in_code_block {
223                code_buffer.push_str(line);
224                code_buffer.push('\n');
225            }
226        }
227
228        Ok(files)
229    }
230
231    /// Parses JSON format with file list
232    fn parse_json_files(&self, content: &str) -> Result<Vec<GeneratedFile>, GenerationError> {
233        // Try to find JSON structure in the content
234        if let Ok(json) = serde_json::from_str::<serde_json::Value>(content) {
235            if let Some(files_array) = json.get("files").and_then(|v| v.as_array()) {
236                let mut files = Vec::new();
237
238                for file_obj in files_array {
239                    if let (Some(path), Some(file_content), Some(language)) = (
240                        file_obj.get("path").and_then(|v| v.as_str()),
241                        file_obj.get("content").and_then(|v| v.as_str()),
242                        file_obj.get("language").and_then(|v| v.as_str()),
243                    ) {
244                        files.push(GeneratedFile {
245                            path: path.to_string(),
246                            content: file_content.to_string(),
247                            language: language.to_string(),
248                        });
249                    }
250                }
251
252                if !files.is_empty() {
253                    return Ok(files);
254                }
255            }
256        }
257
258        Ok(Vec::new())
259    }
260
261    /// Generates code with streaming support
262    ///
263    /// # Arguments
264    /// * `provider` - The AI provider to use for generation
265    /// * `prompt` - The generated prompt containing system and user messages
266    /// * `model` - The model to use for generation
267    /// * `temperature` - Temperature for sampling (0.0 to 2.0)
268    /// * `max_tokens` - Maximum tokens to generate
269    ///
270    /// # Returns
271    /// A stream of generated content chunks
272    ///
273    /// # Errors
274    /// Returns `GenerationError` if streaming fails
275    pub async fn generate_streaming(
276        &self,
277        provider: &dyn Provider,
278        prompt: &GeneratedPrompt,
279        model: &str,
280        temperature: f32,
281        max_tokens: usize,
282    ) -> Result<String, GenerationError> {
283        // Build the chat request
284        let mut messages = vec![Message {
285            role: "system".to_string(),
286            content: prompt.system_prompt.clone(),
287        }];
288
289        messages.push(Message {
290            role: "user".to_string(),
291            content: prompt.user_prompt.clone(),
292        });
293
294        let request = ChatRequest {
295            model: model.to_string(),
296            messages,
297            temperature: Some(temperature),
298            max_tokens: Some(max_tokens),
299            stream: true,
300        };
301
302        // Call the provider with streaming
303        let mut stream = provider
304            .chat_stream(request)
305            .await
306            .map_err(|e| GenerationError::GenerationFailed(e.to_string()))?;
307
308        let mut full_content = String::new();
309
310        // Collect all streamed responses
311        use futures::StreamExt;
312        while let Some(result) = stream.next().await {
313            match result {
314                Ok(response) => {
315                    full_content.push_str(&response.content);
316                }
317                Err(e) => {
318                    return Err(GenerationError::GenerationFailed(e.to_string()));
319                }
320            }
321        }
322
323        Ok(full_content)
324    }
325}
326
327impl Default for CodeGenerator {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// A single fenced block with a language tag yields exactly one file.
    #[test]
    fn test_parse_markdown_blocks_single_file() {
        let cg = CodeGenerator::new();
        let input = r#"```rust
pub fn hello() {
    println!("Hello, world!");
}
```"#;

        let parsed = cg.parse_markdown_blocks(input).unwrap();
        assert_eq!(parsed.len(), 1);
        assert!(parsed[0].content.contains("hello"));
        assert_eq!(parsed[0].language, "rust");
    }

    /// `file:` tags in the fence headers set each file's path.
    #[test]
    fn test_parse_markdown_blocks_multiple_files() {
        let cg = CodeGenerator::new();
        let input = r#"```rust file: src/main.rs
pub fn main() {}
```

```typescript file: src/index.ts
export function main() {}
```"#;

        let parsed = cg.parse_markdown_blocks(input).unwrap();
        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[1].path, "src/index.ts");
        assert_eq!(parsed[0].path, "src/main.rs");
    }

    /// A JSON object with a `files` array is parsed into files.
    #[test]
    fn test_parse_json_files() {
        let cg = CodeGenerator::new();
        let input = r#"{
  "files": [
    {
      "path": "src/main.rs",
      "language": "rust",
      "content": "pub fn main() {}"
    }
  ]
}"#;

        let parsed = cg.parse_json_files(input).unwrap();
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].path, "src/main.rs");
    }

    /// Content that is neither markdown fences nor JSON falls back to a
    /// single `generated.rs` file.
    #[test]
    fn test_parse_generated_code_fallback() {
        let cg = CodeGenerator::new();
        let input = "pub fn hello() {}";

        let parsed = cg.parse_generated_code(input).unwrap();
        assert_eq!(parsed[0].path, "generated.rs");
        assert_eq!(parsed.len(), 1);
    }
}
396}