1#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use std::collections::HashMap;
17use std::fmt::Write;
18use std::fs;
19use std::path::Path;
20
/// One conflicted file, as rendered into the conflict resolution prompt.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FileConflict {
    /// Conflict-marker excerpt for the file (filled by [`collect_conflict_info`]
    /// from `crate::git_helpers::get_conflict_markers_for_file`). When empty,
    /// the prompt omits the "Conflict sections" block for this file.
    pub conflict_content: String,
    /// Current contents of the file as read from disk, conflict markers included.
    pub current_content: String,
}
29
/// Builds the merge-conflict resolution prompt from the bundled
/// `templates/conflict_resolution.txt` template (test builds only).
///
/// `CONTEXT` is filled from the optional task prompt and plan, and `CONFLICTS`
/// from `conflicts`. If rendering fails, the bundled fallback template is tried.
/// If that also fails, a minimal hard-coded prompt containing only the
/// conflicts section is returned, so a prompt is always produced.
#[cfg(test)]
pub fn build_conflict_resolution_prompt(
    conflicts: &HashMap<String, FileConflict>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    let template = Template::new(include_str!("templates/conflict_resolution.txt"));

    // The conflicts section can hold whole file bodies. It is moved into the map
    // and read back from there on the last-resort path instead of being cloned.
    let variables = HashMap::from([
        ("CONTEXT", format_context_section(prompt_md_content, plan_content)),
        ("CONFLICTS", format_conflicts_section(conflicts)),
    ]);

    template.render(&variables).unwrap_or_else(|e| {
        eprintln!("Warning: Failed to render conflict resolution template: {e}");
        let fallback_template =
            Template::new(include_str!("templates/conflict_resolution_fallback.txt"));
        fallback_template.render(&variables).unwrap_or_else(|e| {
            eprintln!("Critical: Failed to render fallback template: {e}");
            format!(
                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
                variables["CONFLICTS"]
            )
        })
    })
}
77
78pub fn build_conflict_resolution_prompt_with_context(
90 context: &TemplateContext,
91 conflicts: &HashMap<String, FileConflict>,
92 prompt_md_content: Option<&str>,
93 plan_content: Option<&str>,
94) -> String {
95 let template_content = context
96 .registry()
97 .get_template("conflict_resolution")
98 .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
99 let template = Template::new(&template_content);
100
101 let ctx_section = format_context_section(prompt_md_content, plan_content);
102 let conflicts_section = format_conflicts_section(conflicts);
103
104 let variables = HashMap::from([
105 ("CONTEXT", ctx_section),
106 ("CONFLICTS", conflicts_section.clone()),
107 ]);
108
109 template.render(&variables).unwrap_or_else(|e| {
110 eprintln!("Warning: Failed to render conflict resolution template: {e}");
111 let fallback_template_content = context
113 .registry()
114 .get_template("conflict_resolution_fallback")
115 .unwrap_or_else(|_| {
116 include_str!("templates/conflict_resolution_fallback.txt").to_string()
117 });
118 let fallback_template = Template::new(&fallback_template_content);
119 fallback_template.render(&variables).unwrap_or_else(|e| {
120 eprintln!("Critical: Failed to render fallback template: {e}");
121 format!(
123 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
124 &conflicts_section
125 )
126 })
127 })
128}
129
/// Renders the optional task prompt and implementation plan as Markdown
/// sections for the prompt's `CONTEXT` variable.
///
/// Each present input becomes a heading, an introductory sentence, and the raw
/// text inside a code fence. Returns an empty string when both are `None`.
/// The fence is always longer than any backtick run in the wrapped text, so a
/// prompt or plan that itself contains ```` ``` ```` blocks cannot terminate it
/// early. Plain text still gets the usual three-backtick fence.
fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
    /// Returns a backtick fence (at least three long) that is strictly longer
    /// than the longest run of backticks in `content`.
    fn fence_for(content: &str) -> String {
        let longest_run = content
            .split(|c: char| c != '`')
            .map(str::len)
            .max()
            .unwrap_or(0);
        "`".repeat(longest_run.max(2) + 1)
    }

    /// Appends `heading`, `intro`, and `body` wrapped in a safe code fence.
    fn push_fenced_section(out: &mut String, heading: &str, intro: &str, body: &str) {
        let fence = fence_for(body);
        out.push_str(heading);
        out.push_str(intro);
        out.push_str(&fence);
        out.push('\n');
        out.push_str(body);
        out.push('\n');
        out.push_str(&fence);
        out.push_str("\n\n");
    }

    let mut context = String::new();

    if let Some(prompt_md) = prompt_md_content {
        push_fenced_section(
            &mut context,
            "## Task Context\n\n",
            "The user was working on the following task:\n\n",
            prompt_md,
        );
    }

    if let Some(plan) = plan_content {
        push_fenced_section(
            &mut context,
            "## Implementation Plan\n\n",
            "The following plan was being implemented:\n\n",
            plan,
        );
    }

    context
}
157
158fn format_conflicts_section(conflicts: &HashMap<String, FileConflict>) -> String {
163 let mut section = String::new();
164
165 for (path, conflict) in conflicts {
166 writeln!(section, "### {path}\n\n").unwrap();
167 section.push_str("Current state (with conflict markers):\n\n");
168 section.push_str("```");
169 section.push_str(&get_language_marker(path));
170 section.push('\n');
171 section.push_str(&conflict.current_content);
172 section.push_str("\n```\n\n");
173
174 if !conflict.conflict_content.is_empty() {
175 section.push_str("Conflict sections:\n\n");
176 section.push_str("```\n");
177 section.push_str(&conflict.conflict_content);
178 section.push_str("\n```\n\n");
179 }
180 }
181
182 section
183}
184
/// Maps a file path's extension to a Markdown code-fence language tag.
///
/// Matching is ASCII case-insensitive (`MAIN.RS` → `rust`). Returns an empty
/// string for unknown extensions, for paths without an extension, and for
/// non-UTF-8 extensions. An empty tag yields an untagged fence.
fn get_language_marker(path: &str) -> String {
    let ext = Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();

    let marker = match ext.as_str() {
        "rs" => "rust",
        "py" | "pyi" => "python",
        "js" | "jsx" | "mjs" | "cjs" => "javascript",
        "ts" | "tsx" | "mts" | "cts" => "typescript",
        "go" => "go",
        "java" => "java",
        "c" | "h" => "c",
        "cpp" | "hpp" | "cc" | "cxx" | "hh" | "hxx" => "cpp",
        "cs" => "csharp",
        "php" => "php",
        "rb" => "ruby",
        "swift" => "swift",
        "kt" | "kts" => "kotlin",
        "scala" => "scala",
        "sh" | "bash" | "zsh" => "bash",
        "fish" => "fish",
        "yaml" | "yml" => "yaml",
        "json" => "json",
        "toml" => "toml",
        "md" | "markdown" => "markdown",
        "txt" => "text",
        "html" | "htm" => "html",
        "css" | "scss" | "less" => "css",
        "xml" => "xml",
        "sql" => "sql",
        _ => "",
    };

    marker.to_string()
}
222
/// How the current branch relates to the branch it is being reconciled with.
/// Used to enrich the conflict resolution prompt.
#[cfg(any(test, feature = "test-utils"))]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BranchInfo {
    /// Name of the checked-out branch (from `git rev-parse --abbrev-ref HEAD`
    /// in [`collect_branch_info`]).
    pub current_branch: String,
    /// The other branch involved; rendered as the "Target branch" in the prompt.
    pub upstream_branch: String,
    /// One-line commit summaries on the current branch (`git log --oneline`
    /// output lines when built by [`collect_branch_info`]).
    pub current_commits: Vec<String>,
    /// One-line commit summaries on the upstream branch (`git log --oneline`
    /// output lines when built by [`collect_branch_info`]).
    pub upstream_commits: Vec<String>,
    /// Sum of both sides of `git rev-list --count --left-right HEAD...upstream`
    /// when built by [`collect_branch_info`].
    pub diverging_count: usize,
}
238
/// Builds the merge-conflict resolution prompt from `context`'s registry
/// templates, optionally adding branch details.
///
/// When `branch_info` is given, a "Branch Information" section is appended
/// after the task/plan context. Template resolution and fallbacks are as
/// follows:
/// - `conflict_resolution` is used, or the bundled file if the registry lookup
///   fails;
/// - then `conflict_resolution_fallback`, resolved the same way;
/// - then a minimal hard-coded prompt containing only the conflicts section.
#[cfg(any(test, feature = "test-utils"))]
pub fn build_enhanced_conflict_resolution_prompt(
    context: &TemplateContext,
    conflicts: &HashMap<String, FileConflict>,
    branch_info: Option<&BranchInfo>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    let template_content = context
        .registry()
        .get_template("conflict_resolution")
        .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
    let template = Template::new(&template_content);

    let mut ctx_section = format_context_section(prompt_md_content, plan_content);
    if let Some(info) = branch_info {
        ctx_section.push_str(&format_branch_info_section(info));
    }

    // The conflicts section can hold whole file bodies. It is moved into the map
    // and read back from there on the last-resort path instead of being cloned.
    let variables = HashMap::from([
        ("CONTEXT", ctx_section),
        ("CONFLICTS", format_conflicts_section(conflicts)),
    ]);

    template.render(&variables).unwrap_or_else(|e| {
        eprintln!("Warning: Failed to render conflict resolution template: {e}");
        let fallback_template_content = context
            .registry()
            .get_template("conflict_resolution_fallback")
            .unwrap_or_else(|_| {
                include_str!("templates/conflict_resolution_fallback.txt").to_string()
            });
        let fallback_template = Template::new(&fallback_template_content);
        fallback_template.render(&variables).unwrap_or_else(|e| {
            eprintln!("Critical: Failed to render fallback template: {e}");
            format!(
                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
                variables["CONFLICTS"]
            )
        })
    })
}
299
/// Renders `info` as a Markdown "Branch Information" section for the prompt
/// context.
///
/// The section lists both branch names, the diverging commit count, and up to
/// five recent commit summaries per side. Each commit list is omitted when it
/// is empty.
#[cfg(any(test, feature = "test-utils"))]
fn format_branch_info_section(info: &BranchInfo) -> String {
    // Keeps the prompt compact; `collect_branch_info` gathers up to ten per side.
    const MAX_LISTED_COMMITS: usize = 5;

    let mut section = String::from("## Branch Information\n\n");
    writeln!(
        section,
        "- **Current branch**: `{}`\n- **Target branch**: `{}`\n- **Diverging commits**: {}\n",
        info.current_branch, info.upstream_branch, info.diverging_count
    )
    .expect("writing to a String cannot fail");

    for (heading, commits) in [
        ("### Recent commits on current branch:\n\n", &info.current_commits),
        ("### Recent commits on target branch:\n\n", &info.upstream_commits),
    ] {
        if commits.is_empty() {
            continue;
        }
        section.push_str(heading);
        for (index, message) in commits.iter().take(MAX_LISTED_COMMITS).enumerate() {
            writeln!(section, "{}. {}", index + 1, message)
                .expect("writing to a String cannot fail");
        }
        section.push('\n');
    }

    section
}
340
/// Gathers branch metadata for the conflict prompt by shelling out to `git` in
/// the current working directory.
///
/// Collects:
/// - the current branch name;
/// - up to ten `--oneline` commit summaries from `HEAD` and from
///   `upstream_branch`;
/// - the total number of commits on either side of `HEAD...upstream_branch`.
///
/// # Errors
///
/// Returns an error in these cases:
/// - `InvalidInput` if `upstream_branch` is empty or starts with `-`. Git
///   refuses such branch names, and git would otherwise parse the value as an
///   option.
/// - Any git invocation cannot be spawned or exits unsuccessfully (the message
///   includes git's stderr).
/// - `InvalidData` if `git rev-list --count` prints something other than
///   integers.
#[cfg(any(test, feature = "test-utils"))]
pub fn collect_branch_info(upstream_branch: &str) -> std::io::Result<BranchInfo> {
    /// Runs `git <args>` and returns its stdout decoded lossily as UTF-8.
    /// `what` names the command in error messages.
    fn run_git(args: &[&str], what: &str) -> std::io::Result<String> {
        let output = std::process::Command::new("git")
            .args(args)
            .output()
            .map_err(|e| std::io::Error::other(format!("{what} failed: {e}")))?;
        if !output.status.success() {
            return Err(std::io::Error::other(format!(
                "{what} failed ({}): {}",
                output.status,
                String::from_utf8_lossy(&output.stderr).trim()
            )));
        }
        Ok(String::from_utf8_lossy(&output.stdout).into_owned())
    }

    if upstream_branch.is_empty() || upstream_branch.starts_with('-') {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("invalid upstream branch name: {upstream_branch:?}"),
        ));
    }

    let current_branch = run_git(&["rev-parse", "--abbrev-ref", "HEAD"], "git rev-parse")?
        .trim()
        .to_string();

    // The trailing `--` tells git the preceding argument is a revision, never a path.
    let current_commits: Vec<String> =
        run_git(&["log", "--oneline", "-10", "HEAD", "--"], "git log")?
            .lines()
            .map(str::to_owned)
            .collect();

    let upstream_commits: Vec<String> =
        run_git(&["log", "--oneline", "-10", upstream_branch, "--"], "git log")?
            .lines()
            .map(str::to_owned)
            .collect();

    // `--left-right --count` prints the commit counts unique to each side of the
    // symmetric difference; their sum is the total divergence.
    let range = format!("HEAD...{upstream_branch}");
    let diverging_count = run_git(
        &["rev-list", "--count", "--left-right", range.as_str(), "--"],
        "git rev-list",
    )?
    .split_whitespace()
    .map(|count| {
        count.parse::<usize>().map_err(|e| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("unexpected git rev-list output {count:?}: {e}"),
            )
        })
    })
    .sum::<std::io::Result<usize>>()?;

    Ok(BranchInfo {
        current_branch,
        upstream_branch: upstream_branch.to_string(),
        current_commits,
        upstream_commits,
        diverging_count,
    })
}
412
413pub fn collect_conflict_info(
427 conflicted_paths: &[String],
428) -> std::io::Result<HashMap<String, FileConflict>> {
429 let mut conflicts = HashMap::new();
430
431 for path in conflicted_paths {
432 let current_content = fs::read_to_string(path)?;
434
435 let conflict_content = crate::git_helpers::get_conflict_markers_for_file(Path::new(path))?;
437
438 conflicts.insert(
439 path.clone(),
440 FileConflict {
441 conflict_content,
442 current_content,
443 },
444 );
445 }
446
447 Ok(conflicts)
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Conflict text shared by several prompt tests.
    const FOO_BAR_CONFLICT: &str =
        "<<<<<<< ours\nfn foo() {}\n=======\nfn bar() {}\n>>>>>>> theirs";

    /// Builds a conflict map containing exactly one file.
    fn single_conflict(
        path: &str,
        conflict_content: &str,
        current_content: &str,
    ) -> HashMap<String, FileConflict> {
        HashMap::from([(
            path.to_string(),
            FileConflict {
                conflict_content: conflict_content.to_string(),
                current_content: current_content.to_string(),
            },
        )])
    }

    /// Branch metadata shared by the branch-info tests.
    fn sample_branch_info(diverging_count: usize) -> BranchInfo {
        BranchInfo {
            current_branch: "feature".to_string(),
            upstream_branch: "main".to_string(),
            current_commits: vec!["abc123 feat: add thing".to_string()],
            upstream_commits: vec!["def456 fix: bug".to_string()],
            diverging_count,
        }
    }

    #[test]
    fn test_build_conflict_resolution_prompt_no_mentions_rebase() {
        let prompt = build_conflict_resolution_prompt(&HashMap::new(), None, None).to_lowercase();

        assert!(!prompt.contains("rebase"));
        assert!(!prompt.contains("rebasing"));
        assert!(prompt.contains("merge conflict"));
    }

    #[test]
    fn test_build_conflict_resolution_prompt_with_context() {
        let conflicts = single_conflict("test.rs", FOO_BAR_CONFLICT, FOO_BAR_CONFLICT);

        let prompt = build_conflict_resolution_prompt(
            &conflicts,
            Some("Add a new feature"),
            Some("1. Create foo function\n2. Create bar function"),
        );

        assert!(prompt.contains("Add a new feature"));
        assert!(prompt.contains("Create foo function"));
        assert!(prompt.contains("Create bar function"));
        assert!(prompt.contains("test.rs"));
        assert!(!prompt.to_lowercase().contains("rebase"));
    }

    #[test]
    fn test_get_language_marker() {
        assert_eq!(get_language_marker("file.rs"), "rust");
        assert_eq!(get_language_marker("file.py"), "python");
        assert_eq!(get_language_marker("file.js"), "javascript");
        assert_eq!(get_language_marker("file.ts"), "typescript");
        assert_eq!(get_language_marker("file.go"), "go");
        assert_eq!(get_language_marker("file.java"), "java");
        assert_eq!(get_language_marker("file.cpp"), "cpp");
        assert_eq!(get_language_marker("file.md"), "markdown");
        assert_eq!(get_language_marker("file.yaml"), "yaml");
        assert_eq!(get_language_marker("file.unknown"), "");
    }

    #[test]
    fn test_format_context_section_with_both() {
        let context = format_context_section(Some("Test prompt"), Some("Test plan"));

        assert!(context.contains("## Task Context"));
        assert!(context.contains("Test prompt"));
        assert!(context.contains("## Implementation Plan"));
        assert!(context.contains("Test plan"));
    }

    #[test]
    fn test_format_context_section_with_prompt_only() {
        let context = format_context_section(Some("Test prompt"), None);

        assert!(context.contains("## Task Context"));
        assert!(context.contains("Test prompt"));
        assert!(!context.contains("## Implementation Plan"));
    }

    #[test]
    fn test_format_context_section_with_plan_only() {
        let context = format_context_section(None, Some("Test plan"));

        assert!(!context.contains("## Task Context"));
        assert!(context.contains("## Implementation Plan"));
        assert!(context.contains("Test plan"));
    }

    #[test]
    fn test_format_context_section_empty() {
        assert!(format_context_section(None, None).is_empty());
    }

    #[test]
    fn test_format_conflicts_section() {
        let marked = "<<<<<<< ours\nx\n=======\ny\n>>>>>>> theirs";
        let conflicts = single_conflict("src/test.rs", marked, marked);

        let section = format_conflicts_section(&conflicts);

        assert!(section.contains("### src/test.rs"));
        assert!(section.contains("Current state (with conflict markers)"));
        assert!(section.contains("```rust"));
        assert!(section.contains("<<<<<<< ours"));
        assert!(section.contains("Conflict sections"));
    }

    #[test]
    fn test_format_conflicts_section_omits_empty_conflict_content() {
        let conflicts = single_conflict("notes.txt", "", "current");

        let section = format_conflicts_section(&conflicts);

        assert!(section.contains("```text\ncurrent\n```"));
        assert!(!section.contains("Conflict sections"));
    }

    #[test]
    fn test_template_is_used() {
        let prompt = build_conflict_resolution_prompt(&HashMap::new(), None, None);

        assert!(prompt.contains("# MERGE CONFLICT RESOLUTION"));
        assert!(prompt.contains("## Conflict Resolution Instructions"));
        assert!(prompt.contains("## Optional JSON Output Format"));
        assert!(prompt.contains("resolved_files"));
    }

    #[test]
    fn test_build_conflict_resolution_prompt_with_registry_context() {
        let context = TemplateContext::default();

        let prompt =
            build_conflict_resolution_prompt_with_context(&context, &HashMap::new(), None, None)
                .to_lowercase();

        assert!(!prompt.contains("rebase"));
        assert!(!prompt.contains("rebasing"));
        assert!(prompt.contains("merge conflict"));
    }

    #[test]
    fn test_build_conflict_resolution_prompt_with_registry_context_and_content() {
        let context = TemplateContext::default();
        let conflicts = single_conflict("test.rs", FOO_BAR_CONFLICT, FOO_BAR_CONFLICT);

        let prompt = build_conflict_resolution_prompt_with_context(
            &context,
            &conflicts,
            Some("Add a new feature"),
            Some("1. Create foo function\n2. Create bar function"),
        );

        assert!(prompt.contains("Add a new feature"));
        assert!(prompt.contains("Create foo function"));
        assert!(prompt.contains("Create bar function"));
        assert!(prompt.contains("test.rs"));
        assert!(!prompt.to_lowercase().contains("rebase"));
    }

    #[test]
    fn test_registry_context_based_matches_regular() {
        let context = TemplateContext::default();
        let conflicts = single_conflict("test.rs", "conflict", "current");

        let regular = build_conflict_resolution_prompt(&conflicts, Some("prompt"), Some("plan"));
        let with_context = build_conflict_resolution_prompt_with_context(
            &context,
            &conflicts,
            Some("prompt"),
            Some("plan"),
        );

        assert_eq!(regular, with_context);
    }

    #[test]
    fn test_branch_info_struct_exists() {
        let info = sample_branch_info(5);

        assert_eq!(info.current_branch, "feature");
        assert_eq!(info.diverging_count, 5);
    }

    #[test]
    fn test_format_branch_info_section() {
        let section = format_branch_info_section(&sample_branch_info(5));

        assert!(section.starts_with("## Branch Information\n\n"));
        assert!(section.contains("- **Current branch**: `feature`\n"));
        assert!(section.contains("- **Target branch**: `main`\n"));
        assert!(section.contains("- **Diverging commits**: 5\n"));
        assert!(section
            .contains("### Recent commits on current branch:\n\n1. abc123 feat: add thing\n"));
        assert!(section.contains("### Recent commits on target branch:\n\n1. def456 fix: bug\n"));
    }

    #[test]
    fn test_format_branch_info_section_lists_at_most_five_commits() {
        let info = BranchInfo {
            current_commits: (1..=7).map(|i| format!("c{i}")).collect(),
            upstream_commits: Vec::new(),
            ..sample_branch_info(7)
        };

        let section = format_branch_info_section(&info);

        assert!(section.contains("5. c5\n"));
        assert!(!section.contains("c6"));
        assert!(!section.contains("Recent commits on target branch"));
    }

    #[test]
    fn test_enhanced_prompt_with_branch_info() {
        let context = TemplateContext::default();
        let conflicts = single_conflict("test.rs", "conflict", "current");
        let branch_info = sample_branch_info(3);

        let prompt = build_enhanced_conflict_resolution_prompt(
            &context,
            &conflicts,
            Some(&branch_info),
            None,
            None,
        );

        assert!(prompt.contains("Branch Information"));
        assert!(prompt.contains("- **Current branch**: `feature`"));
        assert!(prompt.contains("- **Target branch**: `main`"));
        assert!(prompt.contains("- **Diverging commits**: 3"));
        assert!(!prompt.to_lowercase().contains("rebase"));
    }
}