1#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use std::collections::HashMap;
17use std::fmt::Write;
18use std::fs;
19use std::path::Path;
20
21#[derive(Debug, Clone)]
23pub struct FileConflict {
24 pub conflict_content: String,
26 pub current_content: String,
28}
29
30#[cfg(test)]
46pub fn build_conflict_resolution_prompt(
47 conflicts: &HashMap<String, FileConflict>,
48 prompt_md_content: Option<&str>,
49 plan_content: Option<&str>,
50) -> String {
51 let template_content = include_str!("templates/conflict_resolution.txt");
52 let template = Template::new(template_content);
53
54 let context = format_context_section(prompt_md_content, plan_content);
55 let conflicts_section = format_conflicts_section(conflicts);
56
57 let variables = HashMap::from([
58 ("CONTEXT", context),
59 ("CONFLICTS", conflicts_section.clone()),
60 ]);
61
62 template.render(&variables).unwrap_or_else(|e| {
63 eprintln!("Warning: Failed to render conflict resolution template: {e}");
64 let fallback_template_content = include_str!("templates/conflict_resolution_fallback.txt");
66 let fallback_template = Template::new(fallback_template_content);
67 fallback_template.render(&variables).unwrap_or_else(|e| {
68 eprintln!("Critical: Failed to render fallback template: {e}");
69 format!(
71 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
72 &conflicts_section
73 )
74 })
75 })
76}
77
78pub fn build_conflict_resolution_prompt_with_context(
90 context: &TemplateContext,
91 conflicts: &HashMap<String, FileConflict>,
92 prompt_md_content: Option<&str>,
93 plan_content: Option<&str>,
94) -> String {
95 let template_content = context
96 .registry()
97 .get_template("conflict_resolution")
98 .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
99 let template = Template::new(&template_content);
100
101 let ctx_section = format_context_section(prompt_md_content, plan_content);
102 let conflicts_section = format_conflicts_section(conflicts);
103
104 let variables = HashMap::from([
105 ("CONTEXT", ctx_section),
106 ("CONFLICTS", conflicts_section.clone()),
107 ]);
108
109 template.render(&variables).unwrap_or_else(|e| {
110 eprintln!("Warning: Failed to render conflict resolution template: {e}");
111 let fallback_template_content = context
113 .registry()
114 .get_template("conflict_resolution_fallback")
115 .unwrap_or_else(|_| {
116 include_str!("templates/conflict_resolution_fallback.txt").to_string()
117 });
118 let fallback_template = Template::new(&fallback_template_content);
119 fallback_template.render(&variables).unwrap_or_else(|e| {
120 eprintln!("Critical: Failed to render fallback template: {e}");
121 format!(
123 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
124 &conflicts_section
125 )
126 })
127 })
128}
129
130fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
135 let mut context = String::new();
136
137 if let Some(prompt_md) = prompt_md_content {
139 context.push_str("## Task Context\n\n");
140 context.push_str("The user was working on the following task:\n\n");
141 context.push_str("```\n");
142 context.push_str(prompt_md);
143 context.push_str("\n```\n\n");
144 }
145
146 if let Some(plan) = plan_content {
148 context.push_str("## Implementation Plan\n\n");
149 context.push_str("The following plan was being implemented:\n\n");
150 context.push_str("```\n");
151 context.push_str(plan);
152 context.push_str("\n```\n\n");
153 }
154
155 context
156}
157
158fn format_conflicts_section(conflicts: &HashMap<String, FileConflict>) -> String {
163 let mut section = String::new();
164
165 for (path, conflict) in conflicts {
166 writeln!(section, "### {path}\n\n").unwrap();
167 section.push_str("Current state (with conflict markers):\n\n");
168 section.push_str("```");
169 section.push_str(&get_language_marker(path));
170 section.push('\n');
171 section.push_str(&conflict.current_content);
172 section.push_str("\n```\n\n");
173
174 if !conflict.conflict_content.is_empty() {
175 section.push_str("Conflict sections:\n\n");
176 section.push_str("```\n");
177 section.push_str(&conflict.conflict_content);
178 section.push_str("\n```\n\n");
179 }
180 }
181
182 section
183}
184
185fn get_language_marker(path: &str) -> String {
187 let ext = Path::new(path)
188 .extension()
189 .and_then(|e| e.to_str())
190 .unwrap_or("");
191
192 match ext {
193 "rs" => "rust",
194 "py" => "python",
195 "js" | "jsx" => "javascript",
196 "ts" | "tsx" => "typescript",
197 "go" => "go",
198 "java" => "java",
199 "c" | "h" => "c",
200 "cpp" | "hpp" | "cc" | "cxx" => "cpp",
201 "cs" => "csharp",
202 "php" => "php",
203 "rb" => "ruby",
204 "swift" => "swift",
205 "kt" => "kotlin",
206 "scala" => "scala",
207 "sh" | "bash" | "zsh" => "bash",
208 "fish" => "fish",
209 "yaml" | "yml" => "yaml",
210 "json" => "json",
211 "toml" => "toml",
212 "md" | "markdown" => "markdown",
213 "txt" => "text",
214 "html" => "html",
215 "css" | "scss" | "less" => "css",
216 "xml" => "xml",
217 "sql" => "sql",
218 _ => "",
219 }
220 .to_string()
221}
222
223pub fn collect_conflict_info(
237 conflicted_paths: &[String],
238) -> std::io::Result<HashMap<String, FileConflict>> {
239 let mut conflicts = HashMap::new();
240
241 for path in conflicted_paths {
242 let current_content = fs::read_to_string(path)?;
244
245 let conflict_content = crate::git_helpers::get_conflict_markers_for_file(Path::new(path))?;
247
248 conflicts.insert(
249 path.clone(),
250 FileConflict {
251 conflict_content,
252 current_content,
253 },
254 );
255 }
256
257 Ok(conflicts)
258}
259
260#[cfg(test)]
261mod tests {
262 use super::*;
263
264 #[test]
265 fn test_build_conflict_resolution_prompt_no_mentions_rebase() {
266 let conflicts = HashMap::new();
267 let prompt = build_conflict_resolution_prompt(&conflicts, None, None);
268
269 assert!(!prompt.to_lowercase().contains("rebase"));
271 assert!(!prompt.to_lowercase().contains("rebasing"));
272
273 assert!(prompt.to_lowercase().contains("merge conflict"));
275 }
276
277 #[test]
278 fn test_build_conflict_resolution_prompt_with_context() {
279 let mut conflicts = HashMap::new();
280 conflicts.insert(
281 "test.rs".to_string(),
282 FileConflict {
283 conflict_content: "<<<<<<< ours\nfn foo() {}\n=======\nfn bar() {}\n>>>>>>> theirs"
284 .to_string(),
285 current_content: "<<<<<<< ours\nfn foo() {}\n=======\nfn bar() {}\n>>>>>>> theirs"
286 .to_string(),
287 },
288 );
289
290 let prompt_md = "Add a new feature";
291 let plan = "1. Create foo function\n2. Create bar function";
292
293 let prompt = build_conflict_resolution_prompt(&conflicts, Some(prompt_md), Some(plan));
294
295 assert!(prompt.contains("Add a new feature"));
297
298 assert!(prompt.contains("Create foo function"));
300 assert!(prompt.contains("Create bar function"));
301
302 assert!(prompt.contains("test.rs"));
304
305 assert!(!prompt.to_lowercase().contains("rebase"));
307 }
308
309 #[test]
310 fn test_get_language_marker() {
311 assert_eq!(get_language_marker("file.rs"), "rust");
312 assert_eq!(get_language_marker("file.py"), "python");
313 assert_eq!(get_language_marker("file.js"), "javascript");
314 assert_eq!(get_language_marker("file.ts"), "typescript");
315 assert_eq!(get_language_marker("file.go"), "go");
316 assert_eq!(get_language_marker("file.java"), "java");
317 assert_eq!(get_language_marker("file.cpp"), "cpp");
318 assert_eq!(get_language_marker("file.md"), "markdown");
319 assert_eq!(get_language_marker("file.yaml"), "yaml");
320 assert_eq!(get_language_marker("file.unknown"), "");
321 }
322
323 #[test]
324 fn test_format_context_section_with_both() {
325 let prompt_md = "Test prompt";
326 let plan = "Test plan";
327 let context = format_context_section(Some(prompt_md), Some(plan));
328
329 assert!(context.contains("## Task Context"));
330 assert!(context.contains("Test prompt"));
331 assert!(context.contains("## Implementation Plan"));
332 assert!(context.contains("Test plan"));
333 }
334
335 #[test]
336 fn test_format_context_section_with_prompt_only() {
337 let prompt_md = "Test prompt";
338 let context = format_context_section(Some(prompt_md), None);
339
340 assert!(context.contains("## Task Context"));
341 assert!(context.contains("Test prompt"));
342 assert!(!context.contains("## Implementation Plan"));
343 }
344
345 #[test]
346 fn test_format_context_section_with_plan_only() {
347 let plan = "Test plan";
348 let context = format_context_section(None, Some(plan));
349
350 assert!(!context.contains("## Task Context"));
351 assert!(context.contains("## Implementation Plan"));
352 assert!(context.contains("Test plan"));
353 }
354
355 #[test]
356 fn test_format_context_section_empty() {
357 let context = format_context_section(None, None);
358 assert!(context.is_empty());
359 }
360
361 #[test]
362 fn test_format_conflicts_section() {
363 let mut conflicts = HashMap::new();
364 conflicts.insert(
365 "src/test.rs".to_string(),
366 FileConflict {
367 conflict_content: "<<<<<<< ours\nx\n=======\ny\n>>>>>>> theirs".to_string(),
368 current_content: "<<<<<<< ours\nx\n=======\ny\n>>>>>>> theirs".to_string(),
369 },
370 );
371
372 let section = format_conflicts_section(&conflicts);
373
374 assert!(section.contains("### src/test.rs"));
375 assert!(section.contains("Current state (with conflict markers)"));
376 assert!(section.contains("```rust"));
377 assert!(section.contains("<<<<<<< ours"));
378 assert!(section.contains("Conflict sections"));
379 }
380
381 #[test]
382 fn test_template_is_used() {
383 let conflicts = HashMap::new();
385 let prompt = build_conflict_resolution_prompt(&conflicts, None, None);
386
387 assert!(prompt.contains("# MERGE CONFLICT RESOLUTION"));
389 assert!(prompt.contains("## Conflict Resolution Instructions"));
390 assert!(prompt.contains("## Optional JSON Output Format"));
391 assert!(prompt.contains("resolved_files"));
392 }
393
394 #[test]
395 fn test_build_conflict_resolution_prompt_with_registry_context() {
396 let context = TemplateContext::default();
397 let conflicts = HashMap::new();
398 let prompt =
399 build_conflict_resolution_prompt_with_context(&context, &conflicts, None, None);
400
401 assert!(!prompt.to_lowercase().contains("rebase"));
403 assert!(!prompt.to_lowercase().contains("rebasing"));
404
405 assert!(prompt.to_lowercase().contains("merge conflict"));
407 }
408
409 #[test]
410 fn test_build_conflict_resolution_prompt_with_registry_context_and_content() {
411 let context = TemplateContext::default();
412 let mut conflicts = HashMap::new();
413 conflicts.insert(
414 "test.rs".to_string(),
415 FileConflict {
416 conflict_content: "<<<<<<< ours\nfn foo() {}\n=======\nfn bar() {}\n>>>>>>> theirs"
417 .to_string(),
418 current_content: "<<<<<<< ours\nfn foo() {}\n=======\nfn bar() {}\n>>>>>>> theirs"
419 .to_string(),
420 },
421 );
422
423 let prompt_md = "Add a new feature";
424 let plan = "1. Create foo function\n2. Create bar function";
425
426 let prompt = build_conflict_resolution_prompt_with_context(
427 &context,
428 &conflicts,
429 Some(prompt_md),
430 Some(plan),
431 );
432
433 assert!(prompt.contains("Add a new feature"));
435
436 assert!(prompt.contains("Create foo function"));
438 assert!(prompt.contains("Create bar function"));
439
440 assert!(prompt.contains("test.rs"));
442
443 assert!(!prompt.to_lowercase().contains("rebase"));
445 }
446
447 #[test]
448 fn test_registry_context_based_matches_regular() {
449 let context = TemplateContext::default();
450 let mut conflicts = HashMap::new();
451 conflicts.insert(
452 "test.rs".to_string(),
453 FileConflict {
454 conflict_content: "conflict".to_string(),
455 current_content: "current".to_string(),
456 },
457 );
458
459 let regular = build_conflict_resolution_prompt(&conflicts, Some("prompt"), Some("plan"));
460 let with_context = build_conflict_resolution_prompt_with_context(
461 &context,
462 &conflicts,
463 Some("prompt"),
464 Some("plan"),
465 );
466 assert_eq!(regular, with_context);
468 }
469}