1#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use std::collections::HashMap;
17use std::fmt::Write;
18use std::fs;
19use std::path::Path;
20
/// Conflict information captured for a single conflicted file.
///
/// Built by `collect_conflict_info` and rendered into the prompt by
/// `format_conflicts_section`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FileConflict {
    /// Conflict-marker excerpt for the file (as produced by
    /// `git_helpers::get_conflict_markers_for_file`); when empty, the prompt
    /// omits the "Conflict sections" block for this file.
    pub conflict_content: String,
    /// Full current contents of the file on disk, conflict markers included.
    pub current_content: String,
}
29
/// Builds the merge-conflict resolution prompt from the templates compiled into the binary.
///
/// The prompt is produced by the first step that succeeds:
/// 1. render `templates/conflict_resolution.txt`;
/// 2. render `templates/conflict_resolution_fallback.txt` (a warning is printed to stderr);
/// 3. a minimal hard-coded header followed by the formatted conflicts (a critical message is
///    printed to stderr).
///
/// Both templates receive `CONTEXT` (task/plan text) and `CONFLICTS` (per-file sections).
/// Only compiled in test builds.
#[cfg(test)]
pub fn build_conflict_resolution_prompt(
    conflicts: &HashMap<String, FileConflict>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    const PRIMARY_TEMPLATE: &str = include_str!("templates/conflict_resolution.txt");
    const FALLBACK_TEMPLATE: &str = include_str!("templates/conflict_resolution_fallback.txt");

    let context_section = format_context_section(prompt_md_content, plan_content);
    let conflicts_section = format_conflicts_section(conflicts);

    let mut variables = HashMap::new();
    variables.insert("CONTEXT", context_section);
    // Keep our own copy: the last-resort output below still needs the conflicts text.
    variables.insert("CONFLICTS", conflicts_section.clone());

    match Template::new(PRIMARY_TEMPLATE).render(&variables) {
        Ok(prompt) => prompt,
        Err(primary_err) => {
            eprintln!("Warning: Failed to render conflict resolution template: {primary_err}");
            match Template::new(FALLBACK_TEMPLATE).render(&variables) {
                Ok(prompt) => prompt,
                Err(fallback_err) => {
                    eprintln!("Critical: Failed to render fallback template: {fallback_err}");
                    format!(
                        "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{conflicts_section}"
                    )
                }
            }
        }
    }
}
77
78pub fn build_conflict_resolution_prompt_with_context(
90 context: &TemplateContext,
91 conflicts: &HashMap<String, FileConflict>,
92 prompt_md_content: Option<&str>,
93 plan_content: Option<&str>,
94) -> String {
95 let template_content = context
96 .registry()
97 .get_template("conflict_resolution")
98 .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
99 let template = Template::new(&template_content);
100
101 let ctx_section = format_context_section(prompt_md_content, plan_content);
102 let conflicts_section = format_conflicts_section(conflicts);
103
104 let variables = HashMap::from([
105 ("CONTEXT", ctx_section),
106 ("CONFLICTS", conflicts_section.clone()),
107 ]);
108
109 template.render(&variables).unwrap_or_else(|e| {
110 eprintln!("Warning: Failed to render conflict resolution template: {e}");
111 let fallback_template_content = context
113 .registry()
114 .get_template("conflict_resolution_fallback")
115 .unwrap_or_else(|_| {
116 include_str!("templates/conflict_resolution_fallback.txt").to_string()
117 });
118 let fallback_template = Template::new(&fallback_template_content);
119 fallback_template.render(&variables).unwrap_or_else(|e| {
120 eprintln!("Critical: Failed to render fallback template: {e}");
121 format!(
123 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
124 &conflicts_section
125 )
126 })
127 })
128}
129
/// Renders the optional task prompt text and implementation plan as markdown sections.
///
/// Each provided input becomes a `##` heading, a one-line lead-in, and the text itself inside a
/// fenced code block. Absent inputs are skipped, so `(None, None)` yields an empty string.
///
/// The fence is three backticks unless the text contains a run of three or more backticks
/// (for example an embedded code block). In that case the fence is one backtick longer than the
/// longest run, so the embedded fences cannot terminate the block early.
fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
    /// Shortest backtick fence (minimum three) that no backtick run inside `body` can close.
    fn fence_for(body: &str) -> String {
        let mut longest = 0usize;
        let mut run = 0usize;
        for ch in body.chars() {
            if ch == '`' {
                run += 1;
                longest = longest.max(run);
            } else {
                run = 0;
            }
        }
        "`".repeat(longest.max(2) + 1)
    }

    /// Appends `heading`, `intro`, then `body` wrapped in a safe fence plus a trailing blank line.
    fn push_block(out: &mut String, heading: &str, intro: &str, body: &str) {
        let fence = fence_for(body);
        out.push_str(heading);
        out.push_str(intro);
        out.push_str(&fence);
        out.push('\n');
        out.push_str(body);
        out.push('\n');
        out.push_str(&fence);
        out.push_str("\n\n");
    }

    let mut context = String::new();

    if let Some(prompt_md) = prompt_md_content {
        push_block(
            &mut context,
            "## Task Context\n\n",
            "The user was working on the following task:\n\n",
            prompt_md,
        );
    }

    if let Some(plan) = plan_content {
        push_block(
            &mut context,
            "## Implementation Plan\n\n",
            "The following plan was being implemented:\n\n",
            plan,
        );
    }

    context
}
157
158fn format_conflicts_section(conflicts: &HashMap<String, FileConflict>) -> String {
163 let mut section = String::new();
164
165 for (path, conflict) in conflicts {
166 writeln!(section, "### {path}\n\n").unwrap();
167 section.push_str("Current state (with conflict markers):\n\n");
168 section.push_str("```");
169 section.push_str(&get_language_marker(path));
170 section.push('\n');
171 section.push_str(&conflict.current_content);
172 section.push_str("\n```\n\n");
173
174 if !conflict.conflict_content.is_empty() {
175 section.push_str("Conflict sections:\n\n");
176 section.push_str("```\n");
177 section.push_str(&conflict.conflict_content);
178 section.push_str("\n```\n\n");
179 }
180 }
181
182 section
183}
184
/// Maps a file path's extension to a markdown code-fence language tag.
///
/// The extension is whatever [`Path::extension`] reports. Matching is exact and case-sensitive,
/// so `"x.RS"` is not recognised. Paths with no extension, a non-UTF-8 extension, or an
/// unrecognised one yield an empty string, i.e. an untagged fence.
fn get_language_marker(path: &str) -> String {
    // (extensions, fence tag) pairs; extensions are disjoint across rows.
    const FENCE_TAGS: &[(&[&str], &str)] = &[
        (&["rs"], "rust"),
        (&["py"], "python"),
        (&["js", "jsx"], "javascript"),
        (&["ts", "tsx"], "typescript"),
        (&["go"], "go"),
        (&["java"], "java"),
        (&["c", "h"], "c"),
        (&["cpp", "hpp", "cc", "cxx"], "cpp"),
        (&["cs"], "csharp"),
        (&["php"], "php"),
        (&["rb"], "ruby"),
        (&["swift"], "swift"),
        (&["kt"], "kotlin"),
        (&["scala"], "scala"),
        (&["sh", "bash", "zsh"], "bash"),
        (&["fish"], "fish"),
        (&["yaml", "yml"], "yaml"),
        (&["json"], "json"),
        (&["toml"], "toml"),
        (&["md", "markdown"], "markdown"),
        (&["txt"], "text"),
        (&["html"], "html"),
        (&["css", "scss", "less"], "css"),
        (&["xml"], "xml"),
        (&["sql"], "sql"),
    ];

    let extension = Path::new(path)
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or("");

    FENCE_TAGS
        .iter()
        .find(|(extensions, _)| extensions.contains(&extension))
        .map(|&(_, tag)| tag.to_owned())
        .unwrap_or_default()
}
222
/// Branch metadata used to enrich the conflict-resolution prompt.
///
/// Populated by `collect_branch_info` and rendered by `format_branch_info_section`.
/// Only compiled for tests or with the `test-utils` feature.
#[cfg(any(test, feature = "test-utils"))]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BranchInfo {
    /// Name of the checked-out branch (`git rev-parse --abbrev-ref HEAD`).
    pub current_branch: String,
    /// Branch being merged against.
    pub upstream_branch: String,
    /// One-line commit summaries reachable from `HEAD`, newest first.
    pub current_commits: Vec<String>,
    /// One-line commit summaries reachable from `upstream_branch`, newest first.
    pub upstream_commits: Vec<String>,
    /// Total of the left/right commit counts from `git rev-list --count --left-right`.
    pub diverging_count: usize,
}
238
/// Builds the conflict-resolution prompt with optional branch metadata appended to the context.
///
/// Behaves like `build_conflict_resolution_prompt_with_context`, except that when `branch_info`
/// is provided its rendered "Branch Information" section is appended after the task/plan text
/// in the `CONTEXT` variable.
///
/// Templates come from `context`'s registry (`conflict_resolution`, then
/// `conflict_resolution_fallback`), with the bundled copies used when a lookup fails. If neither
/// template renders, a minimal hard-coded header plus the conflicts is returned. Each failed
/// rendering step is reported on stderr.
///
/// Only compiled for tests or with the `test-utils` feature.
#[cfg(any(test, feature = "test-utils"))]
pub fn build_enhanced_conflict_resolution_prompt(
    context: &TemplateContext,
    conflicts: &HashMap<String, FileConflict>,
    branch_info: Option<&BranchInfo>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    // Registry override first; the bundled text is the safety net.
    let load_template = |name: &'static str, bundled: &'static str| -> String {
        context
            .registry()
            .get_template(name)
            .unwrap_or_else(|_| bundled.to_string())
    };

    let primary_text = load_template(
        "conflict_resolution",
        include_str!("templates/conflict_resolution.txt"),
    );

    // Task/plan context first, then (if any) the branch summary.
    let context_section = format_context_section(prompt_md_content, plan_content)
        + &branch_info
            .map(format_branch_info_section)
            .unwrap_or_default();
    let conflicts_section = format_conflicts_section(conflicts);

    let mut variables = HashMap::new();
    variables.insert("CONTEXT", context_section);
    // Keep our own copy: the last-resort output below still needs the conflicts text.
    variables.insert("CONFLICTS", conflicts_section.clone());

    match Template::new(&primary_text).render(&variables) {
        Ok(prompt) => prompt,
        Err(primary_err) => {
            eprintln!("Warning: Failed to render conflict resolution template: {primary_err}");
            let fallback_text = load_template(
                "conflict_resolution_fallback",
                include_str!("templates/conflict_resolution_fallback.txt"),
            );
            match Template::new(&fallback_text).render(&variables) {
                Ok(prompt) => prompt,
                Err(fallback_err) => {
                    eprintln!("Critical: Failed to render fallback template: {fallback_err}");
                    format!(
                        "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{conflicts_section}"
                    )
                }
            }
        }
    }
}
299
/// Renders branch metadata as a markdown "Branch Information" section.
///
/// The section lists the current branch, the target branch, and the diverging-commit count.
/// Each non-empty commit list is followed by a numbered list of at most its first five entries.
/// Empty commit lists are omitted entirely. Only compiled for tests or with the `test-utils`
/// feature.
#[cfg(any(test, feature = "test-utils"))]
fn format_branch_info_section(info: &BranchInfo) -> String {
    /// Appends `heading` and up to five numbered commits; does nothing for an empty list.
    fn push_commit_list(out: &mut String, heading: &str, commits: &[String]) {
        if commits.is_empty() {
            return;
        }
        out.push_str(heading);
        for (index, message) in commits.iter().take(5).enumerate() {
            // Writing into a String cannot fail.
            let _ = writeln!(out, "{}. {}", index + 1, message);
        }
        out.push('\n');
    }

    let mut section = String::from("## Branch Information\n\n");
    let _ = writeln!(section, "- **Current branch**: `{}`", info.current_branch);
    let _ = writeln!(section, "- **Target branch**: `{}`", info.upstream_branch);
    let _ = writeln!(section, "- **Diverging commits**: {}\n", info.diverging_count);

    push_commit_list(
        &mut section,
        "### Recent commits on current branch:\n\n",
        &info.current_commits,
    );
    push_commit_list(
        &mut section,
        "### Recent commits on target branch:\n\n",
        &info.upstream_commits,
    );

    section
}
340
/// Gathers branch metadata for the conflict-resolution prompt by querying git.
///
/// Commands are issued through `executor`, in this order:
/// 1. `git rev-parse --abbrev-ref HEAD` — current branch name (trimmed);
/// 2. `git log --oneline -10 HEAD` — up to ten recent commits on the current branch;
/// 3. `git log --oneline -10 <upstream_branch>` — up to ten recent commits on the target;
/// 4. `git rev-list --count --left-right HEAD...<upstream_branch>` — the whitespace-separated
///    counts are summed into `diverging_count`, with unparsable tokens counted as 0.
///
/// Only each command's stdout is read. NOTE(review): the exit status is not inspected here, so a
/// failing git command (e.g. an unknown `upstream_branch`) presumably yields empty data rather
/// than an error — confirm against `ProcessExecutor::execute`.
///
/// # Errors
/// Propagates any error returned by `executor.execute`.
///
/// Only compiled for tests or with the `test-utils` feature.
#[cfg(any(test, feature = "test-utils"))]
pub fn collect_branch_info(
    upstream_branch: &str,
    executor: &dyn crate::executor::ProcessExecutor,
) -> std::io::Result<BranchInfo> {
    /// Splits command output into one owned string per line.
    fn lines_of(text: &str) -> Vec<String> {
        text.lines().map(str::to_owned).collect()
    }

    // Every query is a bare `git` invocation: no extra environment, no working-dir override.
    let git = |args: &[&str]| executor.execute("git", args, &[], None);

    let head = git(&["rev-parse", "--abbrev-ref", "HEAD"])?;
    let current_log = git(&["log", "--oneline", "-10", "HEAD"])?;
    let upstream_log = git(&["log", "--oneline", "-10", upstream_branch])?;
    let symmetric_range = format!("HEAD...{upstream_branch}");
    let counts = git(&[
        "rev-list",
        "--count",
        "--left-right",
        symmetric_range.as_str(),
    ])?;

    let diverging_count: usize = counts
        .stdout
        .split_whitespace()
        .map(|token| token.parse::<usize>().unwrap_or(0))
        .sum();

    Ok(BranchInfo {
        current_branch: head.stdout.trim().to_owned(),
        upstream_branch: upstream_branch.to_owned(),
        current_commits: lines_of(&current_log.stdout),
        upstream_commits: lines_of(&upstream_log.stdout),
        diverging_count,
    })
}
407
408pub fn collect_conflict_info(
422 conflicted_paths: &[String],
423) -> std::io::Result<HashMap<String, FileConflict>> {
424 let mut conflicts = HashMap::new();
425
426 for path in conflicted_paths {
427 let current_content = fs::read_to_string(path)?;
429
430 let conflict_content = crate::git_helpers::get_conflict_markers_for_file(Path::new(path))?;
432
433 conflicts.insert(
434 path.clone(),
435 FileConflict {
436 conflict_content,
437 current_content,
438 },
439 );
440 }
441
442 Ok(conflicts)
443}
444
#[cfg(test)]
mod tests {
    //! Unit tests for conflict-resolution prompt construction.
    //!
    //! Prompts built from the bundled templates must describe a merge conflict without ever
    //! mentioning rebasing, and the helper formatters must emit the markdown fragments checked
    //! below.

    use super::*;

    /// Conflict text shared by several tests: a single `ours`/`theirs` hunk.
    const FOO_BAR_CONFLICT: &str =
        "<<<<<<< ours\nfn foo() {}\n=======\nfn bar() {}\n>>>>>>> theirs";

    /// Builds a one-entry conflict map for `path`.
    fn single_conflict(
        path: &str,
        conflict: &str,
        current: &str,
    ) -> HashMap<String, FileConflict> {
        HashMap::from([(
            path.to_string(),
            FileConflict {
                conflict_content: conflict.to_string(),
                current_content: current.to_string(),
            },
        )])
    }

    #[test]
    fn test_build_conflict_resolution_prompt_no_mentions_rebase() {
        let conflicts = HashMap::new();
        let prompt = build_conflict_resolution_prompt(&conflicts, None, None).to_lowercase();

        assert!(!prompt.contains("rebase"));
        assert!(!prompt.contains("rebasing"));
        assert!(prompt.contains("merge conflict"));
    }

    #[test]
    fn test_build_conflict_resolution_prompt_with_context() {
        let conflicts = single_conflict("test.rs", FOO_BAR_CONFLICT, FOO_BAR_CONFLICT);
        let prompt_md = "Add a new feature";
        let plan = "1. Create foo function\n2. Create bar function";

        let prompt = build_conflict_resolution_prompt(&conflicts, Some(prompt_md), Some(plan));

        assert!(prompt.contains("Add a new feature"));
        assert!(prompt.contains("Create foo function"));
        assert!(prompt.contains("Create bar function"));
        assert!(prompt.contains("test.rs"));
        assert!(!prompt.to_lowercase().contains("rebase"));
    }

    #[test]
    fn test_get_language_marker() {
        assert_eq!(get_language_marker("file.rs"), "rust");
        assert_eq!(get_language_marker("file.py"), "python");
        assert_eq!(get_language_marker("file.js"), "javascript");
        assert_eq!(get_language_marker("file.ts"), "typescript");
        assert_eq!(get_language_marker("file.go"), "go");
        assert_eq!(get_language_marker("file.java"), "java");
        assert_eq!(get_language_marker("file.cpp"), "cpp");
        assert_eq!(get_language_marker("file.md"), "markdown");
        assert_eq!(get_language_marker("file.yaml"), "yaml");
        assert_eq!(get_language_marker("file.unknown"), "");
        assert_eq!(get_language_marker("Makefile"), "");
    }

    #[test]
    fn test_format_context_section_with_both() {
        let context = format_context_section(Some("Test prompt"), Some("Test plan"));

        assert!(context.contains("## Task Context"));
        assert!(context.contains("Test prompt"));
        assert!(context.contains("## Implementation Plan"));
        assert!(context.contains("Test plan"));

        // The task context is always rendered before the plan.
        let task_at = context.find("## Task Context").unwrap();
        let plan_at = context.find("## Implementation Plan").unwrap();
        assert!(task_at < plan_at);
    }

    #[test]
    fn test_format_context_section_with_prompt_only() {
        let context = format_context_section(Some("Test prompt"), None);

        assert!(context.contains("## Task Context"));
        assert!(context.contains("Test prompt"));
        assert!(!context.contains("## Implementation Plan"));
    }

    #[test]
    fn test_format_context_section_with_plan_only() {
        let context = format_context_section(None, Some("Test plan"));

        assert!(!context.contains("## Task Context"));
        assert!(context.contains("## Implementation Plan"));
        assert!(context.contains("Test plan"));
    }

    #[test]
    fn test_format_context_section_empty() {
        assert!(format_context_section(None, None).is_empty());
    }

    #[test]
    fn test_format_conflicts_section() {
        let markers = "<<<<<<< ours\nx\n=======\ny\n>>>>>>> theirs";
        let conflicts = single_conflict("src/test.rs", markers, markers);

        let section = format_conflicts_section(&conflicts);

        assert!(section.contains("### src/test.rs\n\n"));
        assert!(section.contains("Current state (with conflict markers)"));
        assert!(section.contains("```rust"));
        assert!(section.contains("<<<<<<< ours"));
        assert!(section.contains("Conflict sections"));
    }

    #[test]
    fn test_format_conflicts_section_omits_empty_conflict_content() {
        let conflicts = single_conflict("notes.txt", "", "plain");

        let section = format_conflicts_section(&conflicts);

        assert!(section.contains("```text\nplain\n```"));
        assert!(!section.contains("Conflict sections"));
    }

    #[test]
    fn test_template_is_used() {
        let conflicts = HashMap::new();
        let prompt = build_conflict_resolution_prompt(&conflicts, None, None);

        assert!(prompt.contains("# MERGE CONFLICT RESOLUTION"));
        assert!(prompt.contains("## Conflict Resolution Instructions"));
        assert!(prompt.contains("## Optional JSON Output Format"));
        assert!(prompt.contains("resolved_files"));
    }

    #[test]
    fn test_build_conflict_resolution_prompt_with_registry_context() {
        let context = TemplateContext::default();
        let conflicts = HashMap::new();
        let prompt =
            build_conflict_resolution_prompt_with_context(&context, &conflicts, None, None)
                .to_lowercase();

        assert!(!prompt.contains("rebase"));
        assert!(!prompt.contains("rebasing"));
        assert!(prompt.contains("merge conflict"));
    }

    #[test]
    fn test_build_conflict_resolution_prompt_with_registry_context_and_content() {
        let context = TemplateContext::default();
        let conflicts = single_conflict("test.rs", FOO_BAR_CONFLICT, FOO_BAR_CONFLICT);
        let prompt_md = "Add a new feature";
        let plan = "1. Create foo function\n2. Create bar function";

        let prompt = build_conflict_resolution_prompt_with_context(
            &context,
            &conflicts,
            Some(prompt_md),
            Some(plan),
        );

        assert!(prompt.contains("Add a new feature"));
        assert!(prompt.contains("Create foo function"));
        assert!(prompt.contains("Create bar function"));
        assert!(prompt.contains("test.rs"));
        assert!(!prompt.to_lowercase().contains("rebase"));
    }

    #[test]
    fn test_registry_context_based_matches_regular() {
        let context = TemplateContext::default();
        let conflicts = single_conflict("test.rs", "conflict", "current");

        let regular = build_conflict_resolution_prompt(&conflicts, Some("prompt"), Some("plan"));
        let with_context = build_conflict_resolution_prompt_with_context(
            &context,
            &conflicts,
            Some("prompt"),
            Some("plan"),
        );

        assert_eq!(regular, with_context);
    }

    #[test]
    fn test_branch_info_struct_exists() {
        let info = BranchInfo {
            current_branch: "feature".to_string(),
            upstream_branch: "main".to_string(),
            current_commits: vec!["abc123 feat: add thing".to_string()],
            upstream_commits: vec!["def456 fix: bug".to_string()],
            diverging_count: 5,
        };

        assert_eq!(info.current_branch, "feature");
        assert_eq!(info.diverging_count, 5);
    }

    #[test]
    fn test_format_branch_info_section() {
        let info = BranchInfo {
            current_branch: "feature".to_string(),
            upstream_branch: "main".to_string(),
            current_commits: vec!["abc123 feat: add thing".to_string()],
            upstream_commits: vec!["def456 fix: bug".to_string()],
            diverging_count: 5,
        };

        let section = format_branch_info_section(&info);

        assert!(section.starts_with("## Branch Information\n\n"));
        assert!(section.contains("- **Current branch**: `feature`"));
        assert!(section.contains("- **Target branch**: `main`"));
        assert!(section.contains("- **Diverging commits**: 5"));
        assert!(section.contains("1. abc123 feat: add thing"));
        assert!(section.contains("1. def456 fix: bug"));
    }

    #[test]
    fn test_format_branch_info_section_lists_at_most_five_commits() {
        let info = BranchInfo {
            current_branch: "feature".to_string(),
            upstream_branch: "main".to_string(),
            current_commits: (1..=7).map(|i| format!("c{i} change")).collect(),
            upstream_commits: Vec::new(),
            diverging_count: 7,
        };

        let section = format_branch_info_section(&info);

        assert!(section.contains("5. c5 change"));
        assert!(!section.contains("6. c6 change"));
        assert!(!section.contains("Recent commits on target branch"));
    }

    #[test]
    fn test_enhanced_prompt_with_branch_info() {
        let context = TemplateContext::default();
        let conflicts = single_conflict("test.rs", "conflict", "current");
        let branch_info = BranchInfo {
            current_branch: "feature".to_string(),
            upstream_branch: "main".to_string(),
            current_commits: vec!["abc123 my change".to_string()],
            upstream_commits: vec!["def456 their change".to_string()],
            diverging_count: 3,
        };

        let prompt = build_enhanced_conflict_resolution_prompt(
            &context,
            &conflicts,
            Some(&branch_info),
            None,
            None,
        );

        assert!(prompt.contains("Branch Information"));
        assert!(prompt.contains("- **Current branch**: `feature`"));
        assert!(prompt.contains("- **Target branch**: `main`"));
        assert!(prompt.contains("- **Diverging commits**: 3"));
        assert!(!prompt.to_lowercase().contains("rebase"));
    }
}