Skip to main content

ralph_workflow/prompts/
rebase.rs

1//! Rebase conflict resolution prompts.
2//!
3//! This module provides prompts for AI agents to resolve merge conflicts
4//! that occur during rebase operations.
5//!
6//! # Design Note
7//!
8//! Per project requirements, AI agents should NOT know that we are in the
9//! middle of a rebase. The prompt frames conflicts as "merge conflicts between
10//! two versions" without mentioning rebase or rebasing.
11
12#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use crate::workspace::Workspace;
17use std::collections::HashMap;
18use std::fmt::Write;
19use std::path::Path;
20
/// A single conflicted file, as captured for the conflict resolution prompt.
///
/// Derives `PartialEq`/`Eq` so captured conflicts can be compared directly
/// (e.g. when asserting on the output of conflict collection).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileConflict {
    /// The conflict regions of the file (each from its `<<<<<<<` line through
    /// its `>>>>>>>` line, markers included); empty when none were found
    pub conflict_content: String,
    /// The complete current file content, conflict markers included
    pub current_content: String,
}
29
/// Render the conflict resolution prompt from the templates bundled with the crate.
///
/// The prompt frames the task purely as resolving merge conflicts between two
/// versions; it never mentions "rebase".
///
/// # Arguments
///
/// * `conflicts` - Conflicted file paths mapped to their conflict details
/// * `prompt_md_content` - Optional PROMPT.md text, shown as task context
/// * `plan_content` - Optional PLAN.md text, shown as additional context
///
/// # Returns
///
/// The rendered prompt. If the primary template fails to render, the bundled
/// fallback template is tried; if that fails as well, a minimal prompt that
/// only lists the conflicts is returned. Each failure is reported on stderr.
#[cfg(test)]
pub fn build_conflict_resolution_prompt(
    conflicts: &HashMap<String, FileConflict>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    let primary = Template::new(include_str!("templates/conflict_resolution.txt"));

    let variables = HashMap::from([
        ("CONTEXT", format_context_section(prompt_md_content, plan_content)),
        ("CONFLICTS", format_conflicts_section(conflicts)),
    ]);

    match primary.render(&variables) {
        Ok(prompt) => prompt,
        Err(primary_error) => {
            eprintln!("Warning: Failed to render conflict resolution template: {primary_error}");
            let fallback =
                Template::new(include_str!("templates/conflict_resolution_fallback.txt"));
            match fallback.render(&variables) {
                Ok(prompt) => prompt,
                Err(fallback_error) => {
                    eprintln!("Critical: Failed to render fallback template: {fallback_error}");
                    // Last resort: a minimal prompt that still carries every conflict.
                    format!(
                        "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
                        variables["CONFLICTS"]
                    )
                }
            }
        }
    }
}
77
78/// Build a conflict resolution prompt using template registry.
79///
80/// This version uses the template registry which supports user template overrides.
81/// It's the recommended way to generate prompts going forward.
82///
83/// # Arguments
84///
85/// * `context` - Template context containing the template registry
86/// * `conflicts` - Map of file paths to their conflict information
87/// * `prompt_md_content` - Optional content from PROMPT.md for task context
88/// * `plan_content` - Optional content from PLAN.md for additional context
89pub fn build_conflict_resolution_prompt_with_context(
90    context: &TemplateContext,
91    conflicts: &HashMap<String, FileConflict>,
92    prompt_md_content: Option<&str>,
93    plan_content: Option<&str>,
94) -> String {
95    let template_content = context
96        .registry()
97        .get_template("conflict_resolution")
98        .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
99    let template = Template::new(&template_content);
100
101    let ctx_section = format_context_section(prompt_md_content, plan_content);
102    let conflicts_section = format_conflicts_section(conflicts);
103
104    let variables = HashMap::from([
105        ("CONTEXT", ctx_section),
106        ("CONFLICTS", conflicts_section.clone()),
107    ]);
108
109    template.render(&variables).unwrap_or_else(|e| {
110        eprintln!("Warning: Failed to render conflict resolution template: {e}");
111        // Use fallback template
112        let fallback_template_content = context
113            .registry()
114            .get_template("conflict_resolution_fallback")
115            .unwrap_or_else(|_| {
116                include_str!("templates/conflict_resolution_fallback.txt").to_string()
117            });
118        let fallback_template = Template::new(&fallback_template_content);
119        fallback_template.render(&variables).unwrap_or_else(|e| {
120            eprintln!("Critical: Failed to render fallback template: {e}");
121            // Last resort: minimal emergency prompt - conflicts_section is captured from closure
122            format!(
123                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
124                &conflicts_section
125            )
126        })
127    })
128}
129
130/// Format the context section with PROMPT.md and PLAN.md content.
131///
132/// This helper builds the context section that gets injected into the
133/// {{CONTEXT}} template variable.
134fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
135    let mut context = String::new();
136
137    // Add task context from PROMPT.md if available
138    if let Some(prompt_md) = prompt_md_content {
139        context.push_str("## Task Context\n\n");
140        context.push_str("The user was working on the following task:\n\n");
141        context.push_str("```\n");
142        context.push_str(prompt_md);
143        context.push_str("\n```\n\n");
144    }
145
146    // Add plan context from PLAN.md if available
147    if let Some(plan) = plan_content {
148        context.push_str("## Implementation Plan\n\n");
149        context.push_str("The following plan was being implemented:\n\n");
150        context.push_str("```\n");
151        context.push_str(plan);
152        context.push_str("\n```\n\n");
153    }
154
155    context
156}
157
158/// Format the conflicts section for all conflicted files.
159///
160/// This helper builds the conflicts section that gets injected into the
161/// {{CONFLICTS}} template variable.
162fn format_conflicts_section(conflicts: &HashMap<String, FileConflict>) -> String {
163    let mut section = String::new();
164
165    for (path, conflict) in conflicts {
166        writeln!(section, "### {path}\n\n").unwrap();
167        section.push_str("Current state (with conflict markers):\n\n");
168        section.push_str("```");
169        section.push_str(&get_language_marker(path));
170        section.push('\n');
171        section.push_str(&conflict.current_content);
172        section.push_str("\n```\n\n");
173
174        if !conflict.conflict_content.is_empty() {
175            section.push_str("Conflict sections:\n\n");
176            section.push_str("```\n");
177            section.push_str(&conflict.conflict_content);
178            section.push_str("\n```\n\n");
179        }
180    }
181
182    section
183}
184
/// Get a language marker for syntax highlighting based on file extension.
///
/// Returns the Markdown code-fence language tag for `path`. The extension is
/// compared ASCII case-insensitively, so `Main.RS` and `main.rs` both map to
/// `rust`. A path with no extension, a non-UTF-8 extension, or an unrecognised
/// one yields an empty string, which produces an untagged fence.
fn get_language_marker(path: &str) -> String {
    let extension = Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();

    let marker = match extension.as_str() {
        "rs" => "rust",
        "py" | "pyi" => "python",
        "js" | "jsx" | "mjs" | "cjs" => "javascript",
        "ts" | "tsx" | "mts" | "cts" => "typescript",
        "go" => "go",
        "java" => "java",
        "c" | "h" => "c",
        "cpp" | "hpp" | "cc" | "cxx" | "hh" | "hxx" => "cpp",
        "cs" => "csharp",
        "php" => "php",
        "rb" => "ruby",
        "swift" => "swift",
        "kt" | "kts" => "kotlin",
        "scala" => "scala",
        "sh" | "bash" | "zsh" => "bash",
        "fish" => "fish",
        "yaml" | "yml" => "yaml",
        "json" => "json",
        "toml" => "toml",
        "md" | "markdown" => "markdown",
        "txt" => "text",
        "html" | "htm" => "html",
        "css" | "scss" | "less" => "css",
        "xml" => "xml",
        "sql" => "sql",
        _ => "",
    };

    marker.to_string()
}
222
/// Information about divergent branches for enhanced conflict resolution.
///
/// Derives `PartialEq`/`Eq` (like `FileConflict`) so gathered branch data can
/// be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BranchInfo {
    /// The current branch name
    pub current_branch: String,
    /// The upstream/target branch name
    pub upstream_branch: String,
    /// Recent commit summaries from the current branch, one per entry
    pub current_commits: Vec<String>,
    /// Recent commit summaries from the upstream branch, one per entry
    pub upstream_commits: Vec<String>,
    /// Number of commits on either side of the divergence between the branches
    pub diverging_count: usize,
}
237
238/// Build a conflict resolution prompt with enhanced branch context.
239///
240/// This version provides richer context about the branches involved in the conflict,
241/// including recent commit history and divergence information.
242///
243/// # Arguments
244///
245/// * `context` - Template context containing the template registry
246/// * `conflicts` - Map of file paths to their conflict information
247/// * `branch_info` - Optional branch information for enhanced context
248/// * `prompt_md_content` - Optional content from PROMPT.md for task context
249/// * `plan_content` - Optional content from PLAN.md for additional context
250pub fn build_enhanced_conflict_resolution_prompt(
251    context: &TemplateContext,
252    conflicts: &HashMap<String, FileConflict>,
253    branch_info: Option<&BranchInfo>,
254    prompt_md_content: Option<&str>,
255    plan_content: Option<&str>,
256) -> String {
257    let template_content = context
258        .registry()
259        .get_template("conflict_resolution")
260        .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
261    let template = Template::new(&template_content);
262
263    let mut ctx_section = format_context_section(prompt_md_content, plan_content);
264
265    // Add branch information if available
266    if let Some(info) = branch_info {
267        ctx_section.push_str(&format_branch_info_section(info));
268    }
269
270    let conflicts_section = format_conflicts_section(conflicts);
271
272    let variables = HashMap::from([
273        ("CONTEXT", ctx_section),
274        ("CONFLICTS", conflicts_section.clone()),
275    ]);
276
277    template.render(&variables).unwrap_or_else(|e| {
278        eprintln!("Warning: Failed to render conflict resolution template: {e}");
279        // Use fallback template
280        let fallback_template_content = context
281            .registry()
282            .get_template("conflict_resolution_fallback")
283            .unwrap_or_else(|_| {
284                include_str!("templates/conflict_resolution_fallback.txt").to_string()
285            });
286        let fallback_template = Template::new(&fallback_template_content);
287        fallback_template.render(&variables).unwrap_or_else(|e| {
288            eprintln!("Critical: Failed to render fallback template: {e}");
289            // Last resort: minimal emergency prompt - conflicts_section is captured from closure
290            format!(
291                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
292                &conflicts_section
293            )
294        })
295    })
296}
297
298/// Format branch information for context section.
299///
300/// This helper builds a branch information section that gets injected
301/// into the context for AI conflict resolution.
302fn format_branch_info_section(info: &BranchInfo) -> String {
303    let mut section = String::new();
304
305    section.push_str("## Branch Information\n\n");
306    section.push_str(&format!(
307        "- **Current branch**: `{}`\n",
308        info.current_branch
309    ));
310    section.push_str(&format!(
311        "- **Target branch**: `{}`\n",
312        info.upstream_branch
313    ));
314    section.push_str(&format!(
315        "- **Diverging commits**: {}\n\n",
316        info.diverging_count
317    ));
318
319    if !info.current_commits.is_empty() {
320        section.push_str("### Recent commits on current branch:\n\n");
321        for (i, msg) in info.current_commits.iter().enumerate().take(5) {
322            section.push_str(&format!("{}. {}\n", i + 1, msg));
323        }
324        section.push('\n');
325    }
326
327    if !info.upstream_commits.is_empty() {
328        section.push_str("### Recent commits on target branch:\n\n");
329        for (i, msg) in info.upstream_commits.iter().enumerate().take(5) {
330            section.push_str(&format!("{}. {}\n", i + 1, msg));
331        }
332        section.push('\n');
333    }
334
335    section
336}
337
338/// Collect branch information for conflict resolution.
339///
340/// Queries git to gather information about the branches involved in the conflict.
341///
342/// # Arguments
343///
344/// * `upstream_branch` - The name of the upstream/target branch
345/// * `executor` - Process executor for external process execution
346///
347/// # Returns
348///
349/// Returns `Ok(BranchInfo)` with the gathered information, or an error if git operations fail.
350pub fn collect_branch_info(
351    upstream_branch: &str,
352    executor: &dyn crate::executor::ProcessExecutor,
353) -> std::io::Result<BranchInfo> {
354    // Get current branch name
355    let current_branch =
356        executor.execute("git", &["rev-parse", "--abbrev-ref", "HEAD"], &[], None)?;
357
358    let current_branch = current_branch.stdout.trim().to_string();
359
360    // Get recent commits from current branch
361    let current_log = executor.execute("git", &["log", "--oneline", "-10", "HEAD"], &[], None)?;
362
363    let current_commits: Vec<String> = current_log.stdout.lines().map(|s| s.to_string()).collect();
364
365    // Get recent commits from upstream branch
366    let upstream_log = executor.execute(
367        "git",
368        &["log", "--oneline", "-10", upstream_branch],
369        &[],
370        None,
371    )?;
372
373    let upstream_commits: Vec<String> =
374        upstream_log.stdout.lines().map(|s| s.to_string()).collect();
375
376    // Count diverging commits
377    let diverging = executor.execute(
378        "git",
379        &[
380            "rev-list",
381            "--count",
382            "--left-right",
383            &format!("HEAD...{upstream_branch}"),
384        ],
385        &[],
386        None,
387    )?;
388
389    let diverging_count = diverging
390        .stdout
391        .split_whitespace()
392        .map(|s| s.parse::<usize>().unwrap_or(0))
393        .sum::<usize>();
394
395    Ok(BranchInfo {
396        current_branch,
397        upstream_branch: upstream_branch.to_string(),
398        current_commits,
399        upstream_commits,
400        diverging_count,
401    })
402}
403
404/// Collect conflict information from all conflicted files.
405///
406/// This function reads all conflicted files and builds a map of
407/// file paths to their conflict information.
408///
409/// # Arguments
410///
411/// * `conflicted_paths` - List of paths to conflicted files
412///
413/// # Returns
414///
415/// Returns `Ok(HashMap)` mapping file paths to conflict information,
416/// or an error if a file cannot be read.
417pub fn collect_conflict_info_with_workspace(
418    workspace: &dyn Workspace,
419    conflicted_paths: &[String],
420) -> std::io::Result<HashMap<String, FileConflict>> {
421    let mut conflicts = HashMap::new();
422
423    for path in conflicted_paths {
424        let current_content = workspace.read(Path::new(path))?;
425        let conflict_content = extract_conflict_sections_from_content(&current_content);
426
427        conflicts.insert(
428            path.clone(),
429            FileConflict {
430                conflict_content,
431                current_content,
432            },
433        );
434    }
435
436    Ok(conflicts)
437}
438
/// Pull every conflict region out of `content`, marker lines included.
///
/// A region opens on a line whose text, after leading whitespace, starts with
/// `<<<<<<<`. It continues through the first following line starting with
/// `=======` and then closes on the next line starting with `>>>>>>>`. Marker
/// lines that appear out of that order are treated as ordinary content of the
/// region.
///
/// The lines of a region are re-joined with `\n`, and separate regions are
/// separated by a blank line. A region still open at the end of input is kept
/// as far as it got. Returns an empty string when no region is found.
fn extract_conflict_sections_from_content(content: &str) -> String {
    /// Where the scanner is relative to a conflict region.
    #[derive(Clone, Copy, PartialEq, Eq)]
    enum Phase {
        Outside,
        BeforeSeparator,
        AfterSeparator,
    }

    let has_marker = |line: &str, marker: &str| line.trim_start().starts_with(marker);

    let mut regions: Vec<String> = Vec::new();
    let mut current: Vec<&str> = Vec::new();
    let mut phase = Phase::Outside;

    for line in content.lines() {
        match phase {
            Phase::Outside => {
                if has_marker(line, "<<<<<<<") {
                    current.push(line);
                    phase = Phase::BeforeSeparator;
                }
            }
            Phase::BeforeSeparator => {
                current.push(line);
                if has_marker(line, "=======") {
                    phase = Phase::AfterSeparator;
                }
            }
            Phase::AfterSeparator => {
                current.push(line);
                if has_marker(line, ">>>>>>>") {
                    regions.push(current.join("\n"));
                    current.clear();
                    phase = Phase::Outside;
                }
            }
        }
    }

    // A region left open at end of input is still reported.
    if phase != Phase::Outside {
        regions.push(current.join("\n"));
    }

    regions.join("\n\n")
}
482
483#[cfg(test)]
484mod tests;