1#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use crate::workspace::Workspace;
17use std::collections::HashMap;
18use std::fmt::Write;
19use std::path::Path;
20
/// One conflicted file, as gathered for the conflict-resolution prompt.
///
/// Built by `collect_conflict_info_with_workspace` and rendered into the prompt by
/// `format_conflicts_section`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileConflict {
    /// Only the conflict-marker regions (`<<<<<<<` … `>>>>>>>`) extracted from the file,
    /// separated by blank lines; empty when no markers were found.
    pub conflict_content: String,
    /// The complete current contents of the file, conflict markers included.
    pub current_content: String,
}
29
/// Builds the conflict-resolution prompt from the templates embedded in the binary.
///
/// Test-only counterpart of `build_conflict_resolution_prompt_with_context` that does not
/// consult a template registry. Rendering degrades in three steps: the main template, then
/// the embedded fallback template, then a minimal hard-coded prompt listing only the
/// conflicts. Each rendering failure is reported on stderr; a prompt is always returned.
#[cfg(test)]
pub fn build_conflict_resolution_prompt(
    conflicts: &HashMap<String, FileConflict>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    let primary = Template::new(include_str!("templates/conflict_resolution.txt"));

    let variables = HashMap::from([
        ("CONTEXT", format_context_section(prompt_md_content, plan_content)),
        ("CONFLICTS", format_conflicts_section(conflicts)),
    ]);

    // Happy path: the primary template rendered fine.
    let first_attempt = primary.render(&variables);
    let primary_err = match first_attempt {
        Ok(prompt) => return prompt,
        Err(err) => err,
    };
    eprintln!("Warning: Failed to render conflict resolution template: {primary_err}");

    let fallback = Template::new(include_str!("templates/conflict_resolution_fallback.txt"));
    let second_attempt = fallback.render(&variables);
    match second_attempt {
        Ok(prompt) => prompt,
        Err(fallback_err) => {
            eprintln!("Critical: Failed to render fallback template: {fallback_err}");
            // Last resort: no template at all, just the rendered conflicts.
            format!(
                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
                variables["CONFLICTS"]
            )
        }
    }
}
77
78#[must_use]
90pub fn build_conflict_resolution_prompt_with_context<S: std::hash::BuildHasher>(
91 context: &TemplateContext,
92 conflicts: &HashMap<String, FileConflict, S>,
93 prompt_md_content: Option<&str>,
94 plan_content: Option<&str>,
95) -> String {
96 let template_content = context
97 .registry()
98 .get_template("conflict_resolution")
99 .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
100 let template = Template::new(&template_content);
101
102 let ctx_section = format_context_section(prompt_md_content, plan_content);
103 let conflicts_section = format_conflicts_section(conflicts);
104
105 let variables = HashMap::from([
106 ("CONTEXT", ctx_section),
107 ("CONFLICTS", conflicts_section.clone()),
108 ]);
109
110 template.render(&variables).unwrap_or_else(|e| {
111 eprintln!("Warning: Failed to render conflict resolution template: {e}");
112 let fallback_template_content = context
114 .registry()
115 .get_template("conflict_resolution_fallback")
116 .unwrap_or_else(|_| {
117 include_str!("templates/conflict_resolution_fallback.txt").to_string()
118 });
119 let fallback_template = Template::new(&fallback_template_content);
120 fallback_template.render(&variables).unwrap_or_else(|e| {
121 eprintln!("Critical: Failed to render fallback template: {e}");
122 format!(
124 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
125 &conflicts_section
126 )
127 })
128 })
129}
130
/// Builds the `CONTEXT` template section from the optional task prompt and plan.
///
/// Each present input is rendered under its own `##` heading inside a fenced code block;
/// absent inputs are skipped, so passing `None` for both yields an empty string.
///
/// The fence is made longer than any backtick run inside the embedded text. Both inputs
/// are typically markdown that may contain its own ``` code blocks, which would otherwise
/// close a plain three-backtick fence early and corrupt the rest of the prompt.
fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
    /// Smallest backtick fence (at least three) that `body` cannot terminate.
    fn fence_for(body: &str) -> String {
        let longest_run = body
            .split(|c: char| c != '`')
            .map(str::len)
            .max()
            .unwrap_or(0);
        "`".repeat(longest_run.max(2) + 1)
    }

    /// Appends a heading, an intro line and `body` wrapped in a safe fence.
    fn push_fenced(out: &mut String, heading: &str, intro: &str, body: &str) {
        let fence = fence_for(body);
        out.push_str(heading);
        out.push_str(intro);
        out.push_str(&fence);
        out.push('\n');
        out.push_str(body);
        out.push('\n');
        out.push_str(&fence);
        out.push_str("\n\n");
    }

    let mut context = String::new();

    if let Some(prompt_md) = prompt_md_content {
        push_fenced(
            &mut context,
            "## Task Context\n\n",
            "The user was working on the following task:\n\n",
            prompt_md,
        );
    }

    if let Some(plan) = plan_content {
        push_fenced(
            &mut context,
            "## Implementation Plan\n\n",
            "The following plan was being implemented:\n\n",
            plan,
        );
    }

    context
}
158
159fn format_conflicts_section<S: std::hash::BuildHasher>(
164 conflicts: &HashMap<String, FileConflict, S>,
165) -> String {
166 let mut section = String::new();
167
168 for (path, conflict) in conflicts {
169 writeln!(section, "### {path}\n\n").unwrap();
170 section.push_str("Current state (with conflict markers):\n\n");
171 section.push_str("```");
172 section.push_str(&get_language_marker(path));
173 section.push('\n');
174 section.push_str(&conflict.current_content);
175 section.push_str("\n```\n\n");
176
177 if !conflict.conflict_content.is_empty() {
178 section.push_str("Conflict sections:\n\n");
179 section.push_str("```\n");
180 section.push_str(&conflict.conflict_content);
181 section.push_str("\n```\n\n");
182 }
183 }
184
185 section
186}
187
/// Maps a file path's extension to a markdown code-fence language tag.
///
/// The extension is compared ASCII case-insensitively, so `README.MD` and `Main.RS` are
/// tagged like their lowercase forms. Returns an empty string for unknown or missing
/// extensions, which produces an untagged fence.
fn get_language_marker(path: &str) -> String {
    let ext = Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();

    let marker = match ext.as_str() {
        "rs" => "rust",
        "py" => "python",
        "js" | "jsx" | "mjs" | "cjs" => "javascript",
        "ts" | "tsx" | "mts" | "cts" => "typescript",
        "go" => "go",
        "java" => "java",
        "c" | "h" => "c",
        "cpp" | "hpp" | "cc" | "cxx" | "hh" | "hxx" => "cpp",
        "cs" => "csharp",
        "php" => "php",
        "rb" => "ruby",
        "swift" => "swift",
        "kt" | "kts" => "kotlin",
        "scala" => "scala",
        "sh" | "bash" | "zsh" => "bash",
        "fish" => "fish",
        "yaml" | "yml" => "yaml",
        "json" => "json",
        "toml" => "toml",
        "md" | "markdown" => "markdown",
        "txt" => "text",
        "html" | "htm" => "html",
        "css" | "scss" | "less" => "css",
        "xml" => "xml",
        "sql" => "sql",
        _ => "",
    };

    marker.to_string()
}
225
/// Summary of the current and target branches, used to give the conflict resolver
/// extra context.
///
/// Populated by `collect_branch_info` and rendered by `format_branch_info_section`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BranchInfo {
    /// Name of the checked-out branch (`collect_branch_info` uses
    /// `git rev-parse --abbrev-ref HEAD`).
    pub current_branch: String,
    /// The target branch; rendered as "Target branch" in the prompt.
    pub upstream_branch: String,
    /// Recent `git log --oneline` lines for `HEAD` (`collect_branch_info` keeps up to 10).
    pub current_commits: Vec<String>,
    /// Recent `git log --oneline` lines for the target branch (up to 10 when collected).
    pub upstream_commits: Vec<String>,
    /// Commits unique to either side: both counts from
    /// `git rev-list --count --left-right HEAD...<target>` summed.
    pub diverging_count: usize,
}
240
241#[must_use]
254pub fn build_enhanced_conflict_resolution_prompt<S: std::hash::BuildHasher>(
255 context: &TemplateContext,
256 conflicts: &HashMap<String, FileConflict, S>,
257 branch_info: Option<&BranchInfo>,
258 prompt_md_content: Option<&str>,
259 plan_content: Option<&str>,
260) -> String {
261 let template_content = context
262 .registry()
263 .get_template("conflict_resolution")
264 .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
265 let template = Template::new(&template_content);
266
267 let mut ctx_section = format_context_section(prompt_md_content, plan_content);
268
269 if let Some(info) = branch_info {
271 ctx_section.push_str(&format_branch_info_section(info));
272 }
273
274 let conflicts_section = format_conflicts_section(conflicts);
275
276 let variables = HashMap::from([
277 ("CONTEXT", ctx_section),
278 ("CONFLICTS", conflicts_section.clone()),
279 ]);
280
281 template.render(&variables).unwrap_or_else(|e| {
282 eprintln!("Warning: Failed to render conflict resolution template: {e}");
283 let fallback_template_content = context
285 .registry()
286 .get_template("conflict_resolution_fallback")
287 .unwrap_or_else(|_| {
288 include_str!("templates/conflict_resolution_fallback.txt").to_string()
289 });
290 let fallback_template = Template::new(&fallback_template_content);
291 fallback_template.render(&variables).unwrap_or_else(|e| {
292 eprintln!("Critical: Failed to render fallback template: {e}");
293 format!(
295 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
296 &conflicts_section
297 )
298 })
299 })
300}
301
302fn format_branch_info_section(info: &BranchInfo) -> String {
307 let mut section = String::new();
308
309 section.push_str("## Branch Information\n\n");
310 writeln!(section, "- **Current branch**: `{}`", info.current_branch).unwrap();
311 writeln!(section, "- **Target branch**: `{}`", info.upstream_branch).unwrap();
312 writeln!(
313 section,
314 "- **Diverging commits**: {}\n",
315 info.diverging_count
316 )
317 .unwrap();
318
319 if !info.current_commits.is_empty() {
320 section.push_str("### Recent commits on current branch:\n\n");
321 for (i, msg) in info.current_commits.iter().enumerate().take(5) {
322 writeln!(section, "{}. {}", i + 1, msg).unwrap();
323 }
324 section.push('\n');
325 }
326
327 if !info.upstream_commits.is_empty() {
328 section.push_str("### Recent commits on target branch:\n\n");
329 for (i, msg) in info.upstream_commits.iter().enumerate().take(5) {
330 writeln!(section, "{}. {}", i + 1, msg).unwrap();
331 }
332 section.push('\n');
333 }
334
335 section
336}
337
338pub fn collect_branch_info(
355 upstream_branch: &str,
356 executor: &dyn crate::executor::ProcessExecutor,
357) -> std::io::Result<BranchInfo> {
358 let current_branch =
360 executor.execute("git", &["rev-parse", "--abbrev-ref", "HEAD"], &[], None)?;
361
362 let current_branch = current_branch.stdout.trim().to_string();
363
364 let current_log = executor.execute("git", &["log", "--oneline", "-10", "HEAD"], &[], None)?;
366
367 let current_commits: Vec<String> = current_log
368 .stdout
369 .lines()
370 .map(std::string::ToString::to_string)
371 .collect();
372
373 let upstream_log = executor.execute(
375 "git",
376 &["log", "--oneline", "-10", upstream_branch],
377 &[],
378 None,
379 )?;
380
381 let upstream_commits: Vec<String> = upstream_log
382 .stdout
383 .lines()
384 .map(std::string::ToString::to_string)
385 .collect();
386
387 let diverging = executor.execute(
389 "git",
390 &[
391 "rev-list",
392 "--count",
393 "--left-right",
394 &format!("HEAD...{upstream_branch}"),
395 ],
396 &[],
397 None,
398 )?;
399
400 let diverging_count = diverging
401 .stdout
402 .split_whitespace()
403 .map(|s| s.parse::<usize>().unwrap_or(0))
404 .sum::<usize>();
405
406 Ok(BranchInfo {
407 current_branch,
408 upstream_branch: upstream_branch.to_string(),
409 current_commits,
410 upstream_commits,
411 diverging_count,
412 })
413}
414
415pub fn collect_conflict_info_with_workspace(
433 workspace: &dyn Workspace,
434 conflicted_paths: &[String],
435) -> std::io::Result<HashMap<String, FileConflict>> {
436 let mut conflicts = HashMap::new();
437
438 for path in conflicted_paths {
439 let current_content = workspace.read(Path::new(path))?;
440 let conflict_content = extract_conflict_sections_from_content(¤t_content);
441
442 conflicts.insert(
443 path.clone(),
444 FileConflict {
445 conflict_content,
446 current_content,
447 },
448 );
449 }
450
451 Ok(conflicts)
452}
453
/// Pulls every conflict-marker region out of `content`.
///
/// A region starts at a line that begins with `<<<<<<<` once leading whitespace is
/// ignored, and runs through the first following `=======` line. It then continues up to
/// and including the next `>>>>>>>` line. A region whose markers are missing simply
/// extends to the end of the input.
///
/// Each region keeps its original lines joined with `\n`; regions are separated by a
/// blank line. The result is empty when no region is found.
fn extract_conflict_sections_from_content(content: &str) -> String {
    let is_marker = |line: &str, marker: &str| line.trim_start().starts_with(marker);

    let mut regions: Vec<String> = Vec::new();
    let mut remaining = content.lines();

    while let Some(line) = remaining.next() {
        if !is_marker(line, "<<<<<<<") {
            continue;
        }

        let mut region = vec![line];

        // Lines up to and including the `=======` separator.
        for inner in remaining.by_ref() {
            region.push(inner);
            if is_marker(inner, "=======") {
                break;
            }
        }

        // Lines up to and including the closing `>>>>>>>` marker.
        for inner in remaining.by_ref() {
            region.push(inner);
            if is_marker(inner, ">>>>>>>") {
                break;
            }
        }

        regions.push(region.join("\n"));
    }

    regions.join("\n\n")
}
497
498#[cfg(test)]
499mod tests;