1use crate::error::GenerationError;
8use crate::models::GeneratedFile;
9use crate::prompt_builder::GeneratedPrompt;
10use ricecoder_providers::models::{ChatRequest, Message};
11use ricecoder_providers::provider::Provider;
12use std::time::Duration;
13use tokio::time::sleep;
14
/// Retry and backoff policy applied by [`CodeGenerator::generate`].
#[derive(Debug, Clone)]
pub struct CodeGeneratorConfig {
    /// Extra attempts made after a failed one; `generate` tries at most `max_retries + 1` times.
    pub max_retries: usize,
    /// Delay slept before the first retry.
    pub initial_backoff: Duration,
    /// Cap applied when the delay is grown by `backoff_multiplier`.
    pub max_backoff: Duration,
    /// Factor the delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}
27
28impl Default for CodeGeneratorConfig {
29 fn default() -> Self {
30 Self {
31 max_retries: 3,
32 initial_backoff: Duration::from_millis(100),
33 max_backoff: Duration::from_secs(10),
34 backoff_multiplier: 2.0,
35 }
36 }
37}
38
39#[derive(Debug, Clone)]
41pub struct CodeGenerator {
42 config: CodeGeneratorConfig,
44}
45
46impl CodeGenerator {
47 pub fn new() -> Self {
49 Self {
50 config: CodeGeneratorConfig::default(),
51 }
52 }
53
54 pub fn with_config(config: CodeGeneratorConfig) -> Self {
56 Self { config }
57 }
58
59 pub async fn generate(
74 &self,
75 provider: &dyn Provider,
76 prompt: &GeneratedPrompt,
77 model: &str,
78 temperature: f32,
79 max_tokens: usize,
80 ) -> Result<Vec<GeneratedFile>, GenerationError> {
81 let mut backoff = self.config.initial_backoff;
82 let mut last_error = None;
83
84 for attempt in 0..=self.config.max_retries {
85 match self
86 .generate_internal(provider, prompt, model, temperature, max_tokens)
87 .await
88 {
89 Ok(files) => return Ok(files),
90 Err(e) => {
91 last_error = Some(e);
92
93 if attempt < self.config.max_retries {
95 sleep(backoff).await;
96 backoff = Duration::from_secs_f64(
97 (backoff.as_secs_f64() * self.config.backoff_multiplier)
98 .min(self.config.max_backoff.as_secs_f64()),
99 );
100 }
101 }
102 }
103 }
104
105 Err(last_error
106 .unwrap_or_else(|| GenerationError::GenerationFailed("Unknown error".to_string())))
107 }
108
109 async fn generate_internal(
111 &self,
112 provider: &dyn Provider,
113 prompt: &GeneratedPrompt,
114 model: &str,
115 temperature: f32,
116 max_tokens: usize,
117 ) -> Result<Vec<GeneratedFile>, GenerationError> {
118 let mut messages = vec![Message {
120 role: "system".to_string(),
121 content: prompt.system_prompt.clone(),
122 }];
123
124 messages.push(Message {
125 role: "user".to_string(),
126 content: prompt.user_prompt.clone(),
127 });
128
129 let request = ChatRequest {
130 model: model.to_string(),
131 messages,
132 temperature: Some(temperature),
133 max_tokens: Some(max_tokens),
134 stream: false,
135 };
136
137 let response = provider
139 .chat(request)
140 .await
141 .map_err(|e| GenerationError::GenerationFailed(e.to_string()))?;
142
143 self.parse_generated_code(&response.content)
145 }
146
147 fn parse_generated_code(&self, content: &str) -> Result<Vec<GeneratedFile>, GenerationError> {
154 let mut files = Vec::new();
155
156 if let Ok(parsed_files) = self.parse_markdown_blocks(content) {
158 if !parsed_files.is_empty() {
159 return Ok(parsed_files);
160 }
161 }
162
163 if let Ok(parsed_files) = self.parse_json_files(content) {
165 if !parsed_files.is_empty() {
166 return Ok(parsed_files);
167 }
168 }
169
170 files.push(GeneratedFile {
172 path: "generated.rs".to_string(),
173 content: content.to_string(),
174 language: "rust".to_string(),
175 });
176
177 Ok(files)
178 }
179
180 fn parse_markdown_blocks(&self, content: &str) -> Result<Vec<GeneratedFile>, GenerationError> {
182 let mut files = Vec::new();
183 let mut current_file: Option<GeneratedFile> = None;
184 let mut in_code_block = false;
185 let mut code_buffer = String::new();
186
187 for line in content.lines() {
188 if line.starts_with("```") {
189 if in_code_block {
190 if let Some(mut file) = current_file.take() {
192 file.content = code_buffer.trim().to_string();
193 files.push(file);
194 }
195 code_buffer.clear();
196 in_code_block = false;
197 } else {
198 let header = line.trim_start_matches("```").trim();
200
201 let parts: Vec<&str> = header.split_whitespace().collect();
203 if !parts.is_empty() {
204 let language = parts[0].to_string();
205
206 let file_path = if parts.len() > 1 && parts[1] == "file:" {
208 parts.get(2).map(|s| s.to_string())
209 } else {
210 None
211 };
212
213 current_file = Some(GeneratedFile {
214 path: file_path.unwrap_or_else(|| format!("generated.{}", language)),
215 content: String::new(),
216 language,
217 });
218 }
219
220 in_code_block = true;
221 }
222 } else if in_code_block {
223 code_buffer.push_str(line);
224 code_buffer.push('\n');
225 }
226 }
227
228 Ok(files)
229 }
230
231 fn parse_json_files(&self, content: &str) -> Result<Vec<GeneratedFile>, GenerationError> {
233 if let Ok(json) = serde_json::from_str::<serde_json::Value>(content) {
235 if let Some(files_array) = json.get("files").and_then(|v| v.as_array()) {
236 let mut files = Vec::new();
237
238 for file_obj in files_array {
239 if let (Some(path), Some(file_content), Some(language)) = (
240 file_obj.get("path").and_then(|v| v.as_str()),
241 file_obj.get("content").and_then(|v| v.as_str()),
242 file_obj.get("language").and_then(|v| v.as_str()),
243 ) {
244 files.push(GeneratedFile {
245 path: path.to_string(),
246 content: file_content.to_string(),
247 language: language.to_string(),
248 });
249 }
250 }
251
252 if !files.is_empty() {
253 return Ok(files);
254 }
255 }
256 }
257
258 Ok(Vec::new())
259 }
260
261 pub async fn generate_streaming(
276 &self,
277 provider: &dyn Provider,
278 prompt: &GeneratedPrompt,
279 model: &str,
280 temperature: f32,
281 max_tokens: usize,
282 ) -> Result<String, GenerationError> {
283 let mut messages = vec![Message {
285 role: "system".to_string(),
286 content: prompt.system_prompt.clone(),
287 }];
288
289 messages.push(Message {
290 role: "user".to_string(),
291 content: prompt.user_prompt.clone(),
292 });
293
294 let request = ChatRequest {
295 model: model.to_string(),
296 messages,
297 temperature: Some(temperature),
298 max_tokens: Some(max_tokens),
299 stream: true,
300 };
301
302 let mut stream = provider
304 .chat_stream(request)
305 .await
306 .map_err(|e| GenerationError::GenerationFailed(e.to_string()))?;
307
308 let mut full_content = String::new();
309
310 use futures::StreamExt;
312 while let Some(result) = stream.next().await {
313 match result {
314 Ok(response) => {
315 full_content.push_str(&response.content);
316 }
317 Err(e) => {
318 return Err(GenerationError::GenerationFailed(e.to_string()));
319 }
320 }
321 }
322
323 Ok(full_content)
324 }
325}
326
327impl Default for CodeGenerator {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh generator with default settings for each test.
    fn generator() -> CodeGenerator {
        CodeGenerator::new()
    }

    #[test]
    fn test_parse_markdown_blocks_single_file() {
        let content = "```rust\npub fn hello() {\n    println!(\"Hello, world!\");\n}\n```";

        let files = generator().parse_markdown_blocks(content).unwrap();
        assert_eq!(files.len(), 1);

        let file = &files[0];
        assert_eq!(file.language, "rust");
        assert!(file.content.contains("hello"));
    }

    #[test]
    fn test_parse_markdown_blocks_multiple_files() {
        let content = r#"```rust file: src/main.rs
pub fn main() {}
```

```typescript file: src/index.ts
export function main() {}
```"#;

        let paths: Vec<String> = generator()
            .parse_markdown_blocks(content)
            .unwrap()
            .into_iter()
            .map(|file| file.path)
            .collect();
        assert_eq!(paths, ["src/main.rs", "src/index.ts"]);
    }

    #[test]
    fn test_parse_json_files() {
        let content = r#"{
    "files": [
        {
            "path": "src/main.rs",
            "language": "rust",
            "content": "pub fn main() {}"
        }
    ]
}"#;

        let files = generator().parse_json_files(content).unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, "src/main.rs");
    }

    #[test]
    fn test_parse_generated_code_fallback() {
        let files = generator()
            .parse_generated_code("pub fn hello() {}")
            .unwrap();
        assert_eq!(files.len(), 1);
        assert_eq!(files[0].path, "generated.rs");
    }
}