1#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use crate::workspace::Workspace;
17use std::collections::HashMap;
18use std::fmt::Write;
19use std::path::Path;
20
/// One conflicted file, as gathered for the conflict-resolution prompt.
///
/// `collect_conflict_info_with_workspace` fills `current_content` with the
/// file as read from the workspace. It fills `conflict_content` with the hunks
/// extracted from that text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileConflict {
    /// Only the conflict hunks (`<<<<<<<` … `>>>>>>>`), separated by blank
    /// lines; empty when the file contained no conflict markers.
    pub conflict_content: String,
    /// The full current contents of the file, conflict markers included.
    pub current_content: String,
}
29
/// Renders the conflict-resolution prompt from the templates compiled into
/// the binary. Only built for tests.
///
/// The template receives two variables:
/// - `CONTEXT`: the task/plan context.
/// - `CONFLICTS`: the per-file conflict listing.
///
/// If the main template fails to render, a warning goes to stderr and the
/// built-in fallback template is tried. If that also fails, a minimal prompt
/// containing only the conflict listing is returned. Errors are never
/// propagated to the caller.
#[cfg(test)]
pub fn build_conflict_resolution_prompt(
    conflicts: &HashMap<String, FileConflict>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    let primary = Template::new(include_str!("templates/conflict_resolution.txt"));

    let variables = HashMap::from([
        ("CONTEXT", format_context_section(prompt_md_content, plan_content)),
        ("CONFLICTS", format_conflicts_section(conflicts)),
    ]);

    let rendered = primary.render(&variables);
    match rendered {
        Ok(prompt) => prompt,
        Err(err) => {
            eprintln!("Warning: Failed to render conflict resolution template: {err}");
            let fallback =
                Template::new(include_str!("templates/conflict_resolution_fallback.txt"));
            fallback.render(&variables).unwrap_or_else(|err| {
                eprintln!("Critical: Failed to render fallback template: {err}");
                // Last resort: hand the model nothing but the raw conflict listing.
                format!(
                    "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
                    variables["CONFLICTS"]
                )
            })
        }
    }
}
77
78pub fn build_conflict_resolution_prompt_with_context(
90 context: &TemplateContext,
91 conflicts: &HashMap<String, FileConflict>,
92 prompt_md_content: Option<&str>,
93 plan_content: Option<&str>,
94) -> String {
95 let template_content = context
96 .registry()
97 .get_template("conflict_resolution")
98 .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
99 let template = Template::new(&template_content);
100
101 let ctx_section = format_context_section(prompt_md_content, plan_content);
102 let conflicts_section = format_conflicts_section(conflicts);
103
104 let variables = HashMap::from([
105 ("CONTEXT", ctx_section),
106 ("CONFLICTS", conflicts_section.clone()),
107 ]);
108
109 template.render(&variables).unwrap_or_else(|e| {
110 eprintln!("Warning: Failed to render conflict resolution template: {e}");
111 let fallback_template_content = context
113 .registry()
114 .get_template("conflict_resolution_fallback")
115 .unwrap_or_else(|_| {
116 include_str!("templates/conflict_resolution_fallback.txt").to_string()
117 });
118 let fallback_template = Template::new(&fallback_template_content);
119 fallback_template.render(&variables).unwrap_or_else(|e| {
120 eprintln!("Critical: Failed to render fallback template: {e}");
121 format!(
123 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
124 &conflicts_section
125 )
126 })
127 })
128}
129
/// Builds the optional context part of the conflict-resolution prompt.
///
/// Each provided input (`prompt_md_content` first, then `plan_content`) is
/// rendered as three pieces:
/// - a Markdown `##` heading;
/// - a one-line explanation;
/// - the raw content inside a fenced code block.
///
/// Absent inputs contribute nothing, so `(None, None)` yields `""`.
///
/// The fence is one backtick longer than the longest backtick run in the
/// content, and never shorter than three. Content that itself contains
/// Markdown code fences (common in `PROMPT.md` and plan files) therefore
/// cannot terminate the wrapper early. Content without a run of three or more
/// backticks still gets the plain ```` ``` ```` fence.
fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
    /// Returns a backtick fence (at least three long) that cannot be closed by
    /// any backtick run occurring inside `content`.
    fn fence_for(content: &str) -> String {
        let longest_run = content
            .split(|c: char| c != '`')
            .map(str::len)
            .max()
            .unwrap_or(0);
        "`".repeat(longest_run.max(2) + 1)
    }

    let parts = [
        (
            prompt_md_content,
            "## Task Context\n\n",
            "The user was working on the following task:\n\n",
        ),
        (
            plan_content,
            "## Implementation Plan\n\n",
            "The following plan was being implemented:\n\n",
        ),
    ];

    let mut context = String::new();
    for (content, heading, intro) in parts {
        let Some(body) = content else { continue };
        let fence = fence_for(body);
        context.push_str(heading);
        context.push_str(intro);
        context.push_str(&fence);
        context.push('\n');
        context.push_str(body);
        context.push('\n');
        context.push_str(&fence);
        context.push_str("\n\n");
    }

    context
}
157
158fn format_conflicts_section(conflicts: &HashMap<String, FileConflict>) -> String {
163 let mut section = String::new();
164
165 for (path, conflict) in conflicts {
166 writeln!(section, "### {path}\n\n").unwrap();
167 section.push_str("Current state (with conflict markers):\n\n");
168 section.push_str("```");
169 section.push_str(&get_language_marker(path));
170 section.push('\n');
171 section.push_str(&conflict.current_content);
172 section.push_str("\n```\n\n");
173
174 if !conflict.conflict_content.is_empty() {
175 section.push_str("Conflict sections:\n\n");
176 section.push_str("```\n");
177 section.push_str(&conflict.conflict_content);
178 section.push_str("\n```\n\n");
179 }
180 }
181
182 section
183}
184
/// Maps a file path to the info string used to tag a Markdown code fence.
///
/// The extension is compared case-insensitively, so `README.MD` and
/// `main.RS` are recognised like their lowercase forms. A path with no
/// extension, a non-UTF-8 extension, or an unrecognised one yields `""`,
/// which produces an untagged fence.
fn get_language_marker(path: &str) -> String {
    let ext = Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();

    let lang = match ext.as_str() {
        "rs" => "rust",
        "py" => "python",
        "js" | "jsx" => "javascript",
        "ts" | "tsx" => "typescript",
        "go" => "go",
        "java" => "java",
        "c" | "h" => "c",
        "cpp" | "hpp" | "cc" | "cxx" => "cpp",
        "cs" => "csharp",
        "php" => "php",
        "rb" => "ruby",
        "swift" => "swift",
        "kt" => "kotlin",
        "scala" => "scala",
        "sh" | "bash" | "zsh" => "bash",
        "fish" => "fish",
        "yaml" | "yml" => "yaml",
        "json" => "json",
        "toml" => "toml",
        "md" | "markdown" => "markdown",
        "txt" => "text",
        "html" => "html",
        "css" | "scss" | "less" => "css",
        "xml" => "xml",
        "sql" => "sql",
        _ => "",
    };

    lang.to_string()
}
222
/// Git branch context shown in the enhanced conflict-resolution prompt.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BranchInfo {
    /// Name of the checked-out branch, as reported by
    /// `git rev-parse --abbrev-ref HEAD`.
    pub current_branch: String,
    /// The branch given to `collect_branch_info`; labelled
    /// "Target branch" in the prompt.
    pub upstream_branch: String,
    /// One-line commit summaries from `HEAD`, most recent first.
    pub current_commits: Vec<String>,
    /// One-line commit summaries from the upstream branch, most recent first.
    pub upstream_commits: Vec<String>,
    /// Total number of commits on either side of `HEAD...upstream_branch`.
    pub diverging_count: usize,
}
237
238pub fn build_enhanced_conflict_resolution_prompt(
251 context: &TemplateContext,
252 conflicts: &HashMap<String, FileConflict>,
253 branch_info: Option<&BranchInfo>,
254 prompt_md_content: Option<&str>,
255 plan_content: Option<&str>,
256) -> String {
257 let template_content = context
258 .registry()
259 .get_template("conflict_resolution")
260 .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
261 let template = Template::new(&template_content);
262
263 let mut ctx_section = format_context_section(prompt_md_content, plan_content);
264
265 if let Some(info) = branch_info {
267 ctx_section.push_str(&format_branch_info_section(info));
268 }
269
270 let conflicts_section = format_conflicts_section(conflicts);
271
272 let variables = HashMap::from([
273 ("CONTEXT", ctx_section),
274 ("CONFLICTS", conflicts_section.clone()),
275 ]);
276
277 template.render(&variables).unwrap_or_else(|e| {
278 eprintln!("Warning: Failed to render conflict resolution template: {e}");
279 let fallback_template_content = context
281 .registry()
282 .get_template("conflict_resolution_fallback")
283 .unwrap_or_else(|_| {
284 include_str!("templates/conflict_resolution_fallback.txt").to_string()
285 });
286 let fallback_template = Template::new(&fallback_template_content);
287 fallback_template.render(&variables).unwrap_or_else(|e| {
288 eprintln!("Critical: Failed to render fallback template: {e}");
289 format!(
291 "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
292 &conflicts_section
293 )
294 })
295 })
296}
297
298fn format_branch_info_section(info: &BranchInfo) -> String {
303 let mut section = String::new();
304
305 section.push_str("## Branch Information\n\n");
306 section.push_str(&format!(
307 "- **Current branch**: `{}`\n",
308 info.current_branch
309 ));
310 section.push_str(&format!(
311 "- **Target branch**: `{}`\n",
312 info.upstream_branch
313 ));
314 section.push_str(&format!(
315 "- **Diverging commits**: {}\n\n",
316 info.diverging_count
317 ));
318
319 if !info.current_commits.is_empty() {
320 section.push_str("### Recent commits on current branch:\n\n");
321 for (i, msg) in info.current_commits.iter().enumerate().take(5) {
322 section.push_str(&format!("{}. {}\n", i + 1, msg));
323 }
324 section.push('\n');
325 }
326
327 if !info.upstream_commits.is_empty() {
328 section.push_str("### Recent commits on target branch:\n\n");
329 for (i, msg) in info.upstream_commits.iter().enumerate().take(5) {
330 section.push_str(&format!("{}. {}\n", i + 1, msg));
331 }
332 section.push('\n');
333 }
334
335 section
336}
337
338pub fn collect_branch_info(
351 upstream_branch: &str,
352 executor: &dyn crate::executor::ProcessExecutor,
353) -> std::io::Result<BranchInfo> {
354 let current_branch =
356 executor.execute("git", &["rev-parse", "--abbrev-ref", "HEAD"], &[], None)?;
357
358 let current_branch = current_branch.stdout.trim().to_string();
359
360 let current_log = executor.execute("git", &["log", "--oneline", "-10", "HEAD"], &[], None)?;
362
363 let current_commits: Vec<String> = current_log.stdout.lines().map(|s| s.to_string()).collect();
364
365 let upstream_log = executor.execute(
367 "git",
368 &["log", "--oneline", "-10", upstream_branch],
369 &[],
370 None,
371 )?;
372
373 let upstream_commits: Vec<String> =
374 upstream_log.stdout.lines().map(|s| s.to_string()).collect();
375
376 let diverging = executor.execute(
378 "git",
379 &[
380 "rev-list",
381 "--count",
382 "--left-right",
383 &format!("HEAD...{upstream_branch}"),
384 ],
385 &[],
386 None,
387 )?;
388
389 let diverging_count = diverging
390 .stdout
391 .split_whitespace()
392 .map(|s| s.parse::<usize>().unwrap_or(0))
393 .sum::<usize>();
394
395 Ok(BranchInfo {
396 current_branch,
397 upstream_branch: upstream_branch.to_string(),
398 current_commits,
399 upstream_commits,
400 diverging_count,
401 })
402}
403
404pub fn collect_conflict_info_with_workspace(
418 workspace: &dyn Workspace,
419 conflicted_paths: &[String],
420) -> std::io::Result<HashMap<String, FileConflict>> {
421 let mut conflicts = HashMap::new();
422
423 for path in conflicted_paths {
424 let current_content = workspace.read(Path::new(path))?;
425 let conflict_content = extract_conflict_sections_from_content(¤t_content);
426
427 conflicts.insert(
428 path.clone(),
429 FileConflict {
430 conflict_content,
431 current_content,
432 },
433 );
434 }
435
436 Ok(conflicts)
437}
438
/// Extracts every conflict hunk from `content` and joins the hunks with a
/// blank line. Returns `""` when no hunk is found.
///
/// A hunk starts at a line that begins (ignoring leading whitespace) with
/// `<<<<<<<`. It runs through the next `=======` line and ends at the next
/// `>>>>>>>` line, all inclusive and with the original lines untouched.
/// Inside a hunk, marker lines seen out of that order are kept as ordinary
/// content. Outside a hunk, stray `=======` / `>>>>>>>` lines are ignored. A
/// hunk still open at end of input is returned up to the last line.
fn extract_conflict_sections_from_content(content: &str) -> String {
    /// Position of the scanner relative to a conflict hunk.
    #[derive(Clone, Copy)]
    enum Phase {
        /// Between hunks, waiting for an opening `<<<<<<<` marker.
        Outside,
        /// In the first side of a hunk, waiting for `=======`.
        BeforeSeparator,
        /// In the second side of a hunk, waiting for `>>>>>>>`.
        AfterSeparator,
    }

    let mut hunks: Vec<String> = Vec::new();
    let mut hunk: Vec<&str> = Vec::new();
    let mut phase = Phase::Outside;

    for line in content.lines() {
        let marker = line.trim_start();
        match phase {
            Phase::Outside => {
                if marker.starts_with("<<<<<<<") {
                    hunk.push(line);
                    phase = Phase::BeforeSeparator;
                }
            }
            Phase::BeforeSeparator => {
                hunk.push(line);
                if marker.starts_with("=======") {
                    phase = Phase::AfterSeparator;
                }
            }
            Phase::AfterSeparator => {
                hunk.push(line);
                if marker.starts_with(">>>>>>>") {
                    hunks.push(hunk.join("\n"));
                    hunk.clear();
                    phase = Phase::Outside;
                }
            }
        }
    }

    // Input ended while a hunk was still open: keep what was collected.
    if !matches!(phase, Phase::Outside) {
        hunks.push(hunk.join("\n"));
    }

    hunks.join("\n\n")
}
482
483#[cfg(test)]
484mod tests;