// ralph_workflow/prompts/rebase.rs

//! Rebase conflict resolution prompts.
//!
//! This module provides prompts for AI agents to resolve merge conflicts
//! that occur during rebase operations.
//!
//! # Design Note
//!
//! Per project requirements, AI agents must NOT be told that we are in the
//! middle of a rebase. The prompt frames conflicts as "merge conflicts between
//! two versions" without mentioning rebase or rebasing.
11
12#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use crate::workspace::Workspace;
17use std::collections::HashMap;
18use std::fmt::Write;
19use std::path::Path;
20
/// The conflict state of a single file, as gathered for the resolution prompt.
///
/// Values are built by `collect_conflict_info_with_workspace`, which reads the
/// file and extracts its conflict hunks.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FileConflict {
    /// Only the conflict hunks (`<<<<<<<` through `>>>>>>>`) extracted from the
    /// file, separated by blank lines; empty when no markers were found.
    pub conflict_content: String,
    /// The complete current file content, conflict markers included.
    pub current_content: String,
}
29
/// Render the conflict-resolution prompt from the built-in templates only.
///
/// Test-only variant: unlike the registry-aware builders it never consults user
/// template overrides. The prompt frames the work as resolving merge conflicts
/// between two versions and does not mention rebasing.
///
/// # Arguments
///
/// * `conflicts` - Conflicted file paths mapped to their conflict details
/// * `prompt_md_content` - Optional PROMPT.md text, shown as task context
/// * `plan_content` - Optional PLAN.md text, shown as the implementation plan
///
/// # Returns
///
/// The rendered prompt. If the primary template fails to render, the built-in
/// fallback template is tried; if that fails too, a minimal prompt containing
/// only the conflicts section is returned. Each failure is reported on stderr.
#[cfg(test)]
pub fn build_conflict_resolution_prompt(
    conflicts: &HashMap<String, FileConflict>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    let primary = Template::new(include_str!("templates/conflict_resolution.txt"));

    let variables = HashMap::from([
        ("CONTEXT", format_context_section(prompt_md_content, plan_content)),
        ("CONFLICTS", format_conflicts_section(conflicts)),
    ]);

    match primary.render(&variables) {
        Ok(prompt) => prompt,
        Err(primary_err) => {
            eprintln!("Warning: Failed to render conflict resolution template: {primary_err}");
            let fallback =
                Template::new(include_str!("templates/conflict_resolution_fallback.txt"));
            match fallback.render(&variables) {
                Ok(prompt) => prompt,
                Err(fallback_err) => {
                    eprintln!("Critical: Failed to render fallback template: {fallback_err}");
                    // Last resort: a minimal prompt built from the conflicts alone.
                    format!(
                        "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
                        variables["CONFLICTS"]
                    )
                }
            }
        }
    }
}
77
78/// Build a conflict resolution prompt using template registry.
79///
80/// This version uses the template registry which supports user template overrides.
81/// It's the recommended way to generate prompts going forward.
82///
83/// # Arguments
84///
85/// * `context` - Template context containing the template registry
86/// * `conflicts` - Map of file paths to their conflict information
87/// * `prompt_md_content` - Optional content from PROMPT.md for task context
88/// * `plan_content` - Optional content from PLAN.md for additional context
89#[must_use]
90pub fn build_conflict_resolution_prompt_with_context<S: std::hash::BuildHasher>(
91    context: &TemplateContext,
92    conflicts: &HashMap<String, FileConflict, S>,
93    prompt_md_content: Option<&str>,
94    plan_content: Option<&str>,
95) -> String {
96    let template_content = context
97        .registry()
98        .get_template("conflict_resolution")
99        .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
100    let template = Template::new(&template_content);
101
102    let ctx_section = format_context_section(prompt_md_content, plan_content);
103    let conflicts_section = format_conflicts_section(conflicts);
104
105    let variables = HashMap::from([
106        ("CONTEXT", ctx_section),
107        ("CONFLICTS", conflicts_section.clone()),
108    ]);
109
110    template.render(&variables).unwrap_or_else(|e| {
111        eprintln!("Warning: Failed to render conflict resolution template: {e}");
112        // Use fallback template
113        let fallback_template_content = context
114            .registry()
115            .get_template("conflict_resolution_fallback")
116            .unwrap_or_else(|_| {
117                include_str!("templates/conflict_resolution_fallback.txt").to_string()
118            });
119        let fallback_template = Template::new(&fallback_template_content);
120        fallback_template.render(&variables).unwrap_or_else(|e| {
121            eprintln!("Critical: Failed to render fallback template: {e}");
122            // Last resort: minimal emergency prompt - conflicts_section is captured from closure
123            format!(
124                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
125                &conflicts_section
126            )
127        })
128    })
129}
130
/// Format the context section with PROMPT.md and PLAN.md content.
///
/// This helper builds the text injected into the `{{CONTEXT}}` template
/// variable.
///
/// * The task section is emitted only when `prompt_md_content` is `Some`.
/// * The plan section is emitted only when `plan_content` is `Some`.
/// * With neither, the result is an empty string.
///
/// Each piece of content is wrapped in a backtick code fence. Both inputs are
/// Markdown and may contain their own code fences. The wrapper is therefore at
/// least three backticks and always longer than any backtick run inside the
/// content, so an embedded fence cannot close it early.
fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
    /// Return a backtick fence strictly longer than any backtick run in
    /// `content` (minimum three).
    fn fence_for(content: &str) -> String {
        let longest_run = content
            .split(|c: char| c != '`')
            .map(str::len)
            .max()
            .unwrap_or(0);
        "`".repeat(longest_run.max(2) + 1)
    }

    /// Append a heading, an introductory line and `body` wrapped in a fence.
    fn push_fenced_section(out: &mut String, heading: &str, intro: &str, body: &str) {
        let fence = fence_for(body);
        out.push_str(heading);
        out.push_str(intro);
        out.push_str(&fence);
        out.push('\n');
        out.push_str(body);
        out.push('\n');
        out.push_str(&fence);
        out.push_str("\n\n");
    }

    let mut context = String::new();

    // Task context from PROMPT.md, if available.
    if let Some(prompt_md) = prompt_md_content {
        push_fenced_section(
            &mut context,
            "## Task Context\n\n",
            "The user was working on the following task:\n\n",
            prompt_md,
        );
    }

    // Plan context from PLAN.md, if available.
    if let Some(plan) = plan_content {
        push_fenced_section(
            &mut context,
            "## Implementation Plan\n\n",
            "The following plan was being implemented:\n\n",
            plan,
        );
    }

    context
}
158
159/// Format the conflicts section for all conflicted files.
160///
161/// This helper builds the conflicts section that gets injected into the
162/// {{CONFLICTS}} template variable.
163fn format_conflicts_section<S: std::hash::BuildHasher>(
164    conflicts: &HashMap<String, FileConflict, S>,
165) -> String {
166    let mut section = String::new();
167
168    for (path, conflict) in conflicts {
169        writeln!(section, "### {path}\n\n").unwrap();
170        section.push_str("Current state (with conflict markers):\n\n");
171        section.push_str("```");
172        section.push_str(&get_language_marker(path));
173        section.push('\n');
174        section.push_str(&conflict.current_content);
175        section.push_str("\n```\n\n");
176
177        if !conflict.conflict_content.is_empty() {
178            section.push_str("Conflict sections:\n\n");
179            section.push_str("```\n");
180            section.push_str(&conflict.conflict_content);
181            section.push_str("\n```\n\n");
182        }
183    }
184
185    section
186}
187
/// Map a file path to the Markdown code-fence language tag for its extension.
///
/// The comparison is exact and case-sensitive on the text after the final dot
/// of the file name. Paths with no extension, a non-UTF-8 extension, or an
/// unknown extension yield an empty string (an untagged fence).
fn get_language_marker(path: &str) -> String {
    /// Known extensions, grouped by the language tag they map to.
    const LANGUAGE_BY_EXTENSION: &[(&[&str], &str)] = &[
        (&["rs"], "rust"),
        (&["py"], "python"),
        (&["js", "jsx"], "javascript"),
        (&["ts", "tsx"], "typescript"),
        (&["go"], "go"),
        (&["java"], "java"),
        (&["c", "h"], "c"),
        (&["cpp", "hpp", "cc", "cxx"], "cpp"),
        (&["cs"], "csharp"),
        (&["php"], "php"),
        (&["rb"], "ruby"),
        (&["swift"], "swift"),
        (&["kt"], "kotlin"),
        (&["scala"], "scala"),
        (&["sh", "bash", "zsh"], "bash"),
        (&["fish"], "fish"),
        (&["yaml", "yml"], "yaml"),
        (&["json"], "json"),
        (&["toml"], "toml"),
        (&["md", "markdown"], "markdown"),
        (&["txt"], "text"),
        (&["html"], "html"),
        (&["css", "scss", "less"], "css"),
        (&["xml"], "xml"),
        (&["sql"], "sql"),
    ];

    let extension = Path::new(path)
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .unwrap_or("");

    LANGUAGE_BY_EXTENSION
        .iter()
        .find(|(extensions, _)| extensions.contains(&extension))
        .map_or("", |&(_, language)| language)
        .to_owned()
}
225
/// Information about divergent branches for enhanced conflict resolution.
///
/// Typically populated by `collect_branch_info` and rendered into the prompt's
/// "Branch Information" section.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct BranchInfo {
    /// The current branch name (as printed by `git rev-parse --abbrev-ref HEAD`
    /// when collected by `collect_branch_info`).
    pub current_branch: String,
    /// The upstream/target branch name.
    pub upstream_branch: String,
    /// Recent `git log --oneline` lines from the current branch (`HEAD`).
    pub current_commits: Vec<String>,
    /// Recent `git log --oneline` lines from the upstream branch.
    pub upstream_commits: Vec<String>,
    /// Total number of commits on either side of `HEAD...upstream`.
    pub diverging_count: usize,
}
240
241/// Build a conflict resolution prompt with enhanced branch context.
242///
243/// This version provides richer context about the branches involved in the conflict,
244/// including recent commit history and divergence information.
245///
246/// # Arguments
247///
248/// * `context` - Template context containing the template registry
249/// * `conflicts` - Map of file paths to their conflict information
250/// * `branch_info` - Optional branch information for enhanced context
251/// * `prompt_md_content` - Optional content from PROMPT.md for task context
252/// * `plan_content` - Optional content from PLAN.md for additional context
253#[must_use]
254pub fn build_enhanced_conflict_resolution_prompt<S: std::hash::BuildHasher>(
255    context: &TemplateContext,
256    conflicts: &HashMap<String, FileConflict, S>,
257    branch_info: Option<&BranchInfo>,
258    prompt_md_content: Option<&str>,
259    plan_content: Option<&str>,
260) -> String {
261    let template_content = context
262        .registry()
263        .get_template("conflict_resolution")
264        .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
265    let template = Template::new(&template_content);
266
267    let mut ctx_section = format_context_section(prompt_md_content, plan_content);
268
269    // Add branch information if available
270    if let Some(info) = branch_info {
271        ctx_section.push_str(&format_branch_info_section(info));
272    }
273
274    let conflicts_section = format_conflicts_section(conflicts);
275
276    let variables = HashMap::from([
277        ("CONTEXT", ctx_section),
278        ("CONFLICTS", conflicts_section.clone()),
279    ]);
280
281    template.render(&variables).unwrap_or_else(|e| {
282        eprintln!("Warning: Failed to render conflict resolution template: {e}");
283        // Use fallback template
284        let fallback_template_content = context
285            .registry()
286            .get_template("conflict_resolution_fallback")
287            .unwrap_or_else(|_| {
288                include_str!("templates/conflict_resolution_fallback.txt").to_string()
289            });
290        let fallback_template = Template::new(&fallback_template_content);
291        fallback_template.render(&variables).unwrap_or_else(|e| {
292            eprintln!("Critical: Failed to render fallback template: {e}");
293            // Last resort: minimal emergency prompt - conflicts_section is captured from closure
294            format!(
295                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
296                &conflicts_section
297            )
298        })
299    })
300}
301
302/// Format branch information for context section.
303///
304/// This helper builds a branch information section that gets injected
305/// into the context for AI conflict resolution.
306fn format_branch_info_section(info: &BranchInfo) -> String {
307    let mut section = String::new();
308
309    section.push_str("## Branch Information\n\n");
310    writeln!(section, "- **Current branch**: `{}`", info.current_branch).unwrap();
311    writeln!(section, "- **Target branch**: `{}`", info.upstream_branch).unwrap();
312    writeln!(
313        section,
314        "- **Diverging commits**: {}\n",
315        info.diverging_count
316    )
317    .unwrap();
318
319    if !info.current_commits.is_empty() {
320        section.push_str("### Recent commits on current branch:\n\n");
321        for (i, msg) in info.current_commits.iter().enumerate().take(5) {
322            writeln!(section, "{}. {}", i + 1, msg).unwrap();
323        }
324        section.push('\n');
325    }
326
327    if !info.upstream_commits.is_empty() {
328        section.push_str("### Recent commits on target branch:\n\n");
329        for (i, msg) in info.upstream_commits.iter().enumerate().take(5) {
330            writeln!(section, "{}. {}", i + 1, msg).unwrap();
331        }
332        section.push('\n');
333    }
334
335    section
336}
337
338/// Collect branch information for conflict resolution.
339///
340/// Queries git to gather information about the branches involved in the conflict.
341///
342/// # Arguments
343///
344/// * `upstream_branch` - The name of the upstream/target branch
345/// * `executor` - Process executor for external process execution
346///
347/// # Returns
348///
349/// Returns `Ok(BranchInfo)` with the gathered information, or an error if git operations fail.
350///
351/// # Errors
352///
353/// Returns error if the operation fails.
354pub fn collect_branch_info(
355    upstream_branch: &str,
356    executor: &dyn crate::executor::ProcessExecutor,
357) -> std::io::Result<BranchInfo> {
358    // Get current branch name
359    let current_branch =
360        executor.execute("git", &["rev-parse", "--abbrev-ref", "HEAD"], &[], None)?;
361
362    let current_branch = current_branch.stdout.trim().to_string();
363
364    // Get recent commits from current branch
365    let current_log = executor.execute("git", &["log", "--oneline", "-10", "HEAD"], &[], None)?;
366
367    let current_commits: Vec<String> = current_log
368        .stdout
369        .lines()
370        .map(std::string::ToString::to_string)
371        .collect();
372
373    // Get recent commits from upstream branch
374    let upstream_log = executor.execute(
375        "git",
376        &["log", "--oneline", "-10", upstream_branch],
377        &[],
378        None,
379    )?;
380
381    let upstream_commits: Vec<String> = upstream_log
382        .stdout
383        .lines()
384        .map(std::string::ToString::to_string)
385        .collect();
386
387    // Count diverging commits
388    let diverging = executor.execute(
389        "git",
390        &[
391            "rev-list",
392            "--count",
393            "--left-right",
394            &format!("HEAD...{upstream_branch}"),
395        ],
396        &[],
397        None,
398    )?;
399
400    let diverging_count = diverging
401        .stdout
402        .split_whitespace()
403        .map(|s| s.parse::<usize>().unwrap_or(0))
404        .sum::<usize>();
405
406    Ok(BranchInfo {
407        current_branch,
408        upstream_branch: upstream_branch.to_string(),
409        current_commits,
410        upstream_commits,
411        diverging_count,
412    })
413}
414
415/// Collect conflict information from all conflicted files.
416///
417/// This function reads all conflicted files and builds a map of
418/// file paths to their conflict information.
419///
420/// # Arguments
421///
422/// * `conflicted_paths` - List of paths to conflicted files
423///
424/// # Returns
425///
426/// Returns `Ok(HashMap)` mapping file paths to conflict information,
427/// or an error if a file cannot be read.
428///
429/// # Errors
430///
431/// Returns error if the operation fails.
432pub fn collect_conflict_info_with_workspace(
433    workspace: &dyn Workspace,
434    conflicted_paths: &[String],
435) -> std::io::Result<HashMap<String, FileConflict>> {
436    let mut conflicts = HashMap::new();
437
438    for path in conflicted_paths {
439        let current_content = workspace.read(Path::new(path))?;
440        let conflict_content = extract_conflict_sections_from_content(&current_content);
441
442        conflicts.insert(
443            path.clone(),
444            FileConflict {
445                conflict_content,
446                current_content,
447            },
448        );
449    }
450
451    Ok(conflicts)
452}
453
/// Pull every conflict hunk out of `content`, joined by blank lines.
///
/// A hunk starts at a line whose left-trimmed text begins with `<<<<<<<`. It
/// runs through the next `=======` line and then through the next `>>>>>>>`
/// line, with all lines in between kept verbatim.
///
/// A hunk whose separator or closing marker is missing extends to the end of
/// the input. Returns an empty string when no opening marker is found.
fn extract_conflict_sections_from_content(content: &str) -> String {
    /// True when `line`, ignoring leading whitespace, starts with `marker`.
    fn starts_with_marker(line: &str, marker: &str) -> bool {
        line.trim_start().starts_with(marker)
    }

    let mut remaining = content.lines();
    let mut hunks: Vec<String> = Vec::new();

    while let Some(line) = remaining.next() {
        if !starts_with_marker(line, "<<<<<<<") {
            continue;
        }

        let mut hunk = vec![line];

        // First side, up to and including the separator.
        for inner in remaining.by_ref() {
            hunk.push(inner);
            if starts_with_marker(inner, "=======") {
                break;
            }
        }

        // Second side, up to and including the closing marker.
        for inner in remaining.by_ref() {
            hunk.push(inner);
            if starts_with_marker(inner, ">>>>>>>") {
                break;
            }
        }

        hunks.push(hunk.join("\n"));
    }

    hunks.join("\n\n")
}
497
498#[cfg(test)]
499mod tests;