Skip to main content

ralph_workflow/prompts/
rebase.rs

1//! Rebase conflict resolution prompts.
2//!
3//! This module provides prompts for AI agents to resolve merge conflicts
4//! that occur during rebase operations.
5//!
6//! # Design Note
7//!
8//! Per project requirements, AI agents should NOT know that we are in the
9//! middle of a rebase. The prompt frames conflicts as "merge conflicts between
10//! two versions" without mentioning rebase or rebasing.
11
12#![deny(unsafe_code)]
13
14use crate::prompts::template_context::TemplateContext;
15use crate::prompts::template_engine::Template;
16use crate::workspace::Workspace;
17use std::collections::HashMap;
18use std::path::Path;
19
/// Conflict details for a single file that could not be merged automatically.
///
/// Values of this type are built by [`collect_conflict_info_with_workspace`]
/// and rendered into the `{{CONFLICTS}}` prompt section.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileConflict {
    /// Only the conflicted regions of the file.
    ///
    /// Each region runs from a `<<<<<<<` line through its `>>>>>>>` line, or to
    /// the end of the file if unterminated. Regions are joined by blank lines.
    /// The value is empty when the file contains no conflict markers.
    pub conflict_content: String,
    /// The complete current content of the file, conflict markers included.
    pub current_content: String,
}
28
/// Build the merge-conflict resolution prompt from the bundled templates.
///
/// The prompt presents the work purely as reconciling two versions of each
/// file. It never mentions rebasing.
///
/// Rendering falls back in three steps:
/// 1. The bundled `conflict_resolution.txt` template.
/// 2. The bundled `conflict_resolution_fallback.txt` template.
/// 3. A hard-coded minimal prompt that contains only the conflicts section.
///
/// Each failed render is reported on stderr.
///
/// # Arguments
///
/// * `conflicts` - Map of file paths to their conflict information
/// * `prompt_md_content` - Optional PROMPT.md content, shown as task context
/// * `plan_content` - Optional PLAN.md content, shown as additional context
///
/// # Returns
///
/// The rendered prompt text.
#[cfg(test)]
#[expect(clippy::print_stderr, reason = "test-only error logging")]
pub fn build_conflict_resolution_prompt(
    conflicts: &HashMap<String, FileConflict>,
    prompt_md_content: Option<&str>,
    plan_content: Option<&str>,
) -> String {
    // Kept separately so the last-resort prompt can still embed it.
    let conflicts_section = format_conflicts_section(conflicts);
    let variables = HashMap::from([
        ("CONTEXT", format_context_section(prompt_md_content, plan_content)),
        ("CONFLICTS", conflicts_section.clone()),
    ]);

    let primary = Template::new(include_str!("templates/conflict_resolution.txt"));
    match primary.render(&variables) {
        Ok(prompt) => prompt,
        Err(e) => {
            eprintln!("Warning: Failed to render conflict resolution template: {e}");
            let fallback =
                Template::new(include_str!("templates/conflict_resolution_fallback.txt"));
            match fallback.render(&variables) {
                Ok(prompt) => prompt,
                Err(e) => {
                    eprintln!("Critical: Failed to render fallback template: {e}");
                    format!(
                        "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
                        &conflicts_section
                    )
                }
            }
        }
    }
}
75
76/// Build a conflict resolution prompt using template registry.
77///
78/// This version uses the template registry which supports user template overrides.
79/// It's the recommended way to generate prompts going forward.
80///
81/// # Arguments
82///
83/// * `context` - Template context containing the template registry
84/// * `conflicts` - Map of file paths to their conflict information
85/// * `prompt_md_content` - Optional content from PROMPT.md for task context
86/// * `plan_content` - Optional content from PLAN.md for additional context
87#[must_use]
88#[expect(
89    clippy::print_stderr,
90    reason = "error logging for template rendering failures"
91)]
92pub fn build_conflict_resolution_prompt_with_context<S: std::hash::BuildHasher>(
93    context: &TemplateContext,
94    conflicts: &HashMap<String, FileConflict, S>,
95    prompt_md_content: Option<&str>,
96    plan_content: Option<&str>,
97) -> String {
98    let template_content = context
99        .registry()
100        .get_template("conflict_resolution")
101        .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
102    let template = Template::new(&template_content);
103
104    let ctx_section = format_context_section(prompt_md_content, plan_content);
105    let conflicts_section = format_conflicts_section(conflicts);
106
107    let variables = HashMap::from([
108        ("CONTEXT", ctx_section),
109        ("CONFLICTS", conflicts_section.clone()),
110    ]);
111
112    template.render(&variables).unwrap_or_else(|e| {
113        eprintln!("Warning: Failed to render conflict resolution template: {e}");
114        // Use fallback template
115        let fallback_template_content = context
116            .registry()
117            .get_template("conflict_resolution_fallback")
118            .unwrap_or_else(|_| {
119                include_str!("templates/conflict_resolution_fallback.txt").to_string()
120            });
121        let fallback_template = Template::new(&fallback_template_content);
122        fallback_template.render(&variables).unwrap_or_else(|e| {
123            eprintln!("Critical: Failed to render fallback template: {e}");
124            // Last resort: minimal emergency prompt - conflicts_section is captured from closure
125            format!(
126                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
127                &conflicts_section
128            )
129        })
130    })
131}
132
/// Format the context section with PROMPT.md and PLAN.md content.
///
/// The result is injected into the `{{CONTEXT}}` template variable. Each
/// provided document becomes its own Markdown subsection, with the content
/// wrapped in a fenced code block. Absent documents are omitted, so the result
/// is empty when both are `None`.
///
/// Both documents are Markdown and frequently contain their own ``` fences.
/// The wrapping fence is therefore made longer than any backtick run inside
/// the content (CommonMark only closes a fence with one at least as long), so
/// an embedded fence cannot terminate the block early. Content without a run
/// of three or more backticks is wrapped in a plain ``` fence, as before.
fn format_context_section(prompt_md_content: Option<&str>, plan_content: Option<&str>) -> String {
    // Shortest backtick fence (minimum three) that is strictly longer than
    // every run of consecutive backticks in `content`.
    fn fence_for(content: &str) -> String {
        let longest_run = content
            .split(|c: char| c != '`')
            .map(str::len)
            .max()
            .unwrap_or(0);
        "`".repeat(longest_run.max(2) + 1)
    }

    let prompt_part = prompt_md_content.map(|prompt_md| {
        let fence = fence_for(prompt_md);
        format!(
            "## Task Context\n\nThe user was working on the following task:\n\n{fence}\n{prompt_md}\n{fence}\n\n"
        )
    });

    let plan_part = plan_content.map(|plan| {
        let fence = fence_for(plan);
        format!(
            "## Implementation Plan\n\nThe following plan was being implemented:\n\n{fence}\n{plan}\n{fence}\n\n"
        )
    });

    [prompt_part, plan_part].into_iter().flatten().collect()
}
157
158/// Format the conflicts section for all conflicted files.
159///
160/// This helper builds the conflicts section that gets injected into the
161/// {{CONFLICTS}} template variable.
162fn format_conflicts_section<S: std::hash::BuildHasher>(
163    conflicts: &HashMap<String, FileConflict, S>,
164) -> String {
165    let sections: Vec<String> = conflicts
166        .iter()
167        .map(|(path, conflict)| {
168            let header = format!("### {path}\n\n");
169            let current = format!(
170                "Current state (with conflict markers):\n\n```{}\n{}\n```\n\n",
171                get_language_marker(path),
172                conflict.current_content
173            );
174            let conflict_part = if conflict.conflict_content.is_empty() {
175                String::new()
176            } else {
177                format!(
178                    "Conflict sections:\n\n```{}\n{}\n```\n\n",
179                    get_language_marker(path),
180                    conflict.conflict_content
181                )
182            };
183            [header, current, conflict_part].join("")
184        })
185        .collect();
186
187    sections.join("")
188}
189
/// Get the Markdown code-fence language tag for `path`, based on its file
/// extension.
///
/// Extensions are compared case-insensitively, so `Main.JAVA` and `Main.java`
/// map to the same tag.
///
/// An empty string is returned, producing an untagged fence, when the path
/// has:
/// * no extension (e.g. `Makefile`, `.bashrc`);
/// * a non-UTF-8 extension;
/// * an unrecognized extension.
fn get_language_marker(path: &str) -> String {
    let ext = Path::new(path)
        .extension()
        .and_then(|e| e.to_str())
        .map(str::to_ascii_lowercase)
        .unwrap_or_default();

    let marker = match ext.as_str() {
        "rs" => "rust",
        "py" => "python",
        "js" | "jsx" | "mjs" | "cjs" => "javascript",
        "ts" | "tsx" | "mts" | "cts" => "typescript",
        "go" => "go",
        "java" => "java",
        "c" => "c",
        // `.h` is ambiguous between C and C++; it is tagged as C++ here.
        "cpp" | "cc" | "cxx" | "h" | "hpp" | "hh" | "hxx" => "cpp",
        "cs" => "csharp",
        "rb" => "ruby",
        "php" => "php",
        "swift" => "swift",
        "kt" | "kts" => "kotlin",
        "scala" => "scala",
        "sh" | "bash" | "zsh" => "bash",
        "yml" | "yaml" => "yaml",
        "json" => "json",
        "toml" => "toml",
        "xml" => "xml",
        "html" | "htm" => "html",
        "css" => "css",
        "scss" | "sass" => "scss",
        "sql" => "sql",
        "md" | "markdown" => "markdown",
        _ => "",
    };
    marker.to_string()
}
227
/// Branch context included in the enhanced conflict resolution prompt.
///
/// Populated by `collect_branch_info` and rendered by
/// `format_branch_info_section`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BranchInfo {
    /// Current branch name, as reported by `git rev-parse --abbrev-ref HEAD`
    /// in `collect_branch_info` (which yields `HEAD` when detached).
    pub current_branch: String,
    /// Upstream/target branch name.
    pub upstream_branch: String,
    /// Recent commits on the current branch, one `git log --oneline` line each.
    pub current_commits: Vec<String>,
    /// Recent commits on the upstream branch, one `git log --oneline` line each.
    pub upstream_commits: Vec<String>,
    /// Number of diverging commits.
    ///
    /// Counts commits reachable only from the current branch plus those
    /// reachable only from the upstream branch (the sum of both
    /// `git rev-list --left-right --count` sides).
    pub diverging_count: usize,
}
242
243/// Build an enhanced conflict resolution prompt with branch information.
244///
245/// This version includes additional context about the branches involved
246/// in the conflict for more informed resolution.
247///
248/// # Arguments
249///
250/// * `context` - Template context containing the template registry
251/// * `conflicts` - Map of file paths to their conflict information
252/// * `branch_info` - Optional branch information for enhanced context
253/// * `prompt_md_content` - Optional content from PROMPT.md for task context
254/// * `plan_content` - Optional content from PLAN.md for additional context
255#[must_use]
256#[expect(
257    clippy::print_stderr,
258    reason = "error logging for template rendering failures"
259)]
260pub fn build_enhanced_conflict_resolution_prompt<S: std::hash::BuildHasher>(
261    context: &TemplateContext,
262    conflicts: &HashMap<String, FileConflict, S>,
263    branch_info: Option<&BranchInfo>,
264    prompt_md_content: Option<&str>,
265    plan_content: Option<&str>,
266) -> String {
267    let template_content = context
268        .registry()
269        .get_template("conflict_resolution")
270        .unwrap_or_else(|_| include_str!("templates/conflict_resolution.txt").to_string());
271    let template = Template::new(&template_content);
272
273    let ctx_section = match branch_info {
274        Some(info) => {
275            format_context_section(prompt_md_content, plan_content)
276                + &format_branch_info_section(info)
277        }
278        None => format_context_section(prompt_md_content, plan_content),
279    };
280
281    let conflicts_section = format_conflicts_section(conflicts);
282
283    let variables = HashMap::from([
284        ("CONTEXT", ctx_section),
285        ("CONFLICTS", conflicts_section.clone()),
286    ]);
287
288    template.render(&variables).unwrap_or_else(|e| {
289        eprintln!("Warning: Failed to render conflict resolution template: {e}");
290        // Use fallback template
291        let fallback_template_content = context
292            .registry()
293            .get_template("conflict_resolution_fallback")
294            .unwrap_or_else(|_| {
295                include_str!("templates/conflict_resolution_fallback.txt").to_string()
296            });
297        let fallback_template = Template::new(&fallback_template_content);
298        fallback_template.render(&variables).unwrap_or_else(|e| {
299            eprintln!("Critical: Failed to render fallback template: {e}");
300            // Last resort: minimal emergency prompt - conflicts_section is captured from closure
301            format!(
302                "# MERGE CONFLICT RESOLUTION\n\nResolve these conflicts:\n\n{}",
303                &conflicts_section
304            )
305        })
306    })
307}
308
309/// Format branch information for context section.
310///
311/// This helper builds a branch information section that gets injected
312/// into the context for AI conflict resolution.
313fn format_branch_info_section(info: &BranchInfo) -> String {
314    let header = format!(
315        "## Branch Information\n\n- **Current branch**: `{}`\n- **Target branch**: `{}`\n- **Diverging commits**: {}\n\n",
316        info.current_branch, info.upstream_branch, info.diverging_count
317    );
318
319    let current_commits_section = if info.current_commits.is_empty() {
320        String::new()
321    } else {
322        let commits: Vec<String> = info
323            .current_commits
324            .iter()
325            .take(5)
326            .enumerate()
327            .map(|(i, msg)| format!("{}. {}", i + 1, msg))
328            .collect();
329        format!(
330            "### Recent commits on current branch:\n\n{}\n\n",
331            commits.join("\n")
332        )
333    };
334
335    let upstream_commits_section = if info.upstream_commits.is_empty() {
336        String::new()
337    } else {
338        let commits: Vec<String> = info
339            .upstream_commits
340            .iter()
341            .take(5)
342            .enumerate()
343            .map(|(i, msg)| format!("{}. {}", i + 1, msg))
344            .collect();
345        format!(
346            "### Recent commits on target branch:\n\n{}\n\n",
347            commits.join("\n")
348        )
349    };
350
351    [header, current_commits_section, upstream_commits_section]
352        .into_iter()
353        .filter(|s| !s.is_empty())
354        .collect()
355}
356
357/// Collect branch information for conflict resolution.
358///
359/// Queries git to gather information about the branches involved in the conflict.
360///
361/// # Arguments
362///
363/// * `upstream_branch` - The name of the upstream/target branch
364/// * `executor` - Process executor for external process execution
365///
366/// # Returns
367///
368/// Returns `Ok(BranchInfo)` with the gathered information, or an error if git operations fail.
369///
370/// # Errors
371///
372/// Returns error if the operation fails.
373pub fn collect_branch_info(
374    upstream_branch: &str,
375    executor: &dyn crate::executor::ProcessExecutor,
376) -> std::io::Result<BranchInfo> {
377    // Get current branch name
378    let current_branch =
379        executor.execute("git", &["rev-parse", "--abbrev-ref", "HEAD"], &[], None)?;
380
381    let current_branch = current_branch.stdout.trim().to_string();
382
383    // Get recent commits from current branch
384    let current_log = executor.execute("git", &["log", "--oneline", "-10", "HEAD"], &[], None)?;
385
386    let current_commits: Vec<String> = current_log
387        .stdout
388        .lines()
389        .map(std::string::ToString::to_string)
390        .collect();
391
392    // Get recent commits from upstream branch
393    let upstream_log = executor.execute(
394        "git",
395        &["log", "--oneline", "-10", upstream_branch],
396        &[],
397        None,
398    )?;
399
400    let upstream_commits: Vec<String> = upstream_log
401        .stdout
402        .lines()
403        .map(std::string::ToString::to_string)
404        .collect();
405
406    // Count diverging commits
407    let diverging = executor.execute(
408        "git",
409        &[
410            "rev-list",
411            "--count",
412            "--left-right",
413            &format!("HEAD...{upstream_branch}"),
414        ],
415        &[],
416        None,
417    )?;
418
419    let diverging_count = diverging
420        .stdout
421        .split_whitespace()
422        .map(|s| s.parse::<usize>().unwrap_or(0))
423        .sum::<usize>();
424
425    Ok(BranchInfo {
426        current_branch,
427        upstream_branch: upstream_branch.to_string(),
428        current_commits,
429        upstream_commits,
430        diverging_count,
431    })
432}
433
434/// Collect conflict information from all conflicted files.
435///
436/// This function reads all conflicted files and builds a map of
437/// file paths to their conflict information.
438///
439/// # Arguments
440///
441/// * `conflicted_paths` - List of paths to conflicted files
442///
443/// # Returns
444///
445/// Returns `Ok(HashMap)` mapping file paths to conflict information,
446/// or an error if a file cannot be read.
447///
448/// # Errors
449///
450/// Returns error if the operation fails.
451pub fn collect_conflict_info_with_workspace(
452    workspace: &dyn Workspace,
453    conflicted_paths: &[String],
454) -> std::io::Result<HashMap<String, FileConflict>> {
455    let conflicts: std::io::Result<Vec<(String, FileConflict)>> = conflicted_paths
456        .iter()
457        .map(|path| {
458            let current_content = workspace.read(Path::new(path))?;
459            let conflict_content = extract_conflict_sections_from_content(&current_content);
460            Ok((
461                path.clone(),
462                FileConflict {
463                    conflict_content,
464                    current_content,
465                },
466            ))
467        })
468        .collect();
469
470    let result: HashMap<String, FileConflict> = conflicts?.into_iter().collect();
471
472    Ok(result)
473}
474
/// Extract every conflict region from `content`.
///
/// Conflict markers are recognized even when indented (leading whitespace is
/// ignored).
///
/// A region starts at a line beginning with `<<<<<<<`. It ends at the first
/// `>>>>>>>` line that follows the first `=======` line after that start. If
/// either marker is missing, the region runs to the end of the content.
/// Every `<<<<<<<` line starts its own region.
///
/// Regions keep their marker lines and are joined with a blank line between
/// them. The result is empty when no `<<<<<<<` line exists.
fn extract_conflict_sections_from_content(content: &str) -> String {
    let lines: Vec<&str> = content.lines().collect();

    // Index of the first line at or after `from` that starts with `marker`.
    let find_marker = |from: usize, marker: &str| -> Option<usize> {
        lines
            .iter()
            .skip(from)
            .position(|line| line.trim_start().starts_with(marker))
            .map(|offset| from + offset)
    };

    let mut sections: Vec<String> = Vec::new();
    for (start, line) in lines.iter().enumerate() {
        if !line.trim_start().starts_with("<<<<<<<") {
            continue;
        }
        // Exclusive end of the region: just past the closing marker, or the
        // end of the content when the region is unterminated.
        let end = find_marker(start + 1, "=======")
            .and_then(|separator| find_marker(separator + 1, ">>>>>>>"))
            .map_or(lines.len(), |close| close + 1);
        sections.push(lines[start..end].join("\n"));
    }

    sections.join("\n\n")
}
512
513#[cfg(test)]
514mod tests;
515
516#[cfg(test)]
517mod io_tests;