use anyhow::{Context, Result};
use clap::Parser;
use tracing::{debug, error};
use super::info::InfoCommand;
// Command-line options for creating (or updating) a GitHub pull request whose
// title and description are generated by the AI client.
#[derive(Parser)]
pub struct CreatePrCommand {
    /// Base branch for the PR (e.g. main or origin/main); defaults to the primary remote's main branch
    #[arg(long, value_name = "BRANCH")]
    pub base: Option<String>,
    /// AI model used to generate the PR title and description
    #[arg(long)]
    pub model: Option<String>,
    /// Skip the interactive review and immediately create the PR (or update an existing one)
    #[arg(long)]
    pub auto_apply: bool,
    /// Only write the generated PR details as YAML to FILE; do not create a PR
    #[arg(long, value_name = "FILE")]
    pub save_only: Option<String>,
    /// Create the PR as ready for review instead of as a draft
    #[arg(long, conflicts_with = "draft")]
    pub ready: bool,
    /// Create the PR as a draft (default unless OMNI_DEV_DEFAULT_DRAFT_PR says otherwise)
    #[arg(long, conflicts_with = "ready")]
    pub draft: bool,
    /// Directory to read project guidance files (pr-guidelines.md, scopes.yaml) from
    #[arg(long)]
    pub context_dir: Option<std::path::PathBuf>,
    /// Do not push the branch to the remote before creating the PR
    #[arg(long)]
    pub no_push: bool,
}
/// What the user (or `--auto-apply`) chose to do with the generated PR details.
#[derive(PartialEq, Debug)]
enum PrAction {
    /// Open a brand-new pull request for the branch.
    CreateNew,
    /// Edit the title/description of the branch's existing pull request.
    UpdateExisting,
    /// Abort without touching GitHub.
    Cancel,
}
/// Title and body of a pull request, as exchanged with the AI and stored in the
/// YAML file the user can review/edit. Field order is also the serialized order.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct PrContent {
    /// One-line PR title.
    pub title: String,
    /// Full PR description (markdown body).
    pub description: String,
}
impl CreatePrCommand {
fn should_create_as_draft(&self) -> bool {
use crate::utils::settings::get_env_var;
if self.ready {
return false;
}
if self.draft {
return true;
}
get_env_var("OMNI_DEV_DEFAULT_DRAFT_PR")
.ok()
.and_then(|val| parse_bool_string(&val))
.unwrap_or(true) }
pub async fn execute(self) -> Result<()> {
let ai_info = crate::utils::check_pr_command_prerequisites(self.model.as_deref())?;
println!(
"✓ {} credentials verified (model: {})",
ai_info.provider, ai_info.model
);
println!("✓ GitHub CLI verified");
println!("🔄 Starting pull request creation process...");
let repo_view = self.generate_repository_view()?;
self.validate_branch_state(&repo_view)?;
use crate::claude::context::ProjectDiscovery;
let repo_root = std::path::PathBuf::from(".");
let context_dir = crate::claude::context::resolve_context_dir(self.context_dir.as_deref());
let discovery = ProjectDiscovery::new(repo_root, context_dir);
let project_context = discovery.discover().unwrap_or_default();
self.show_guidance_files_status(&project_context)?;
let claude_client = crate::claude::create_default_claude_client(self.model.clone(), None)?;
self.show_model_info_from_client(&claude_client)?;
self.show_commit_range_info(&repo_view)?;
let context = {
use crate::claude::context::{BranchAnalyzer, FileAnalyzer, WorkPatternAnalyzer};
use crate::data::context::CommitContext;
let mut context = CommitContext::new();
context.project = project_context;
if let Some(branch_info) = &repo_view.branch_info {
context.branch = BranchAnalyzer::analyze(&branch_info.branch).unwrap_or_default();
}
if !repo_view.commits.is_empty() {
context.range = WorkPatternAnalyzer::analyze_commit_range(&repo_view.commits);
context.files = FileAnalyzer::analyze_commits(&repo_view.commits);
}
context
};
self.show_context_summary(&context)?;
debug!("About to generate PR content from AI");
let (pr_content, _claude_client) = self
.generate_pr_content_with_client_internal(&repo_view, claude_client)
.await?;
self.show_context_information(&repo_view).await?;
debug!(
generated_title = %pr_content.title,
generated_description_length = pr_content.description.len(),
generated_description_preview = %pr_content.description.lines().take(3).collect::<Vec<_>>().join("\\n"),
"Generated PR content from AI"
);
if let Some(save_path) = self.save_only {
let pr_yaml = crate::data::to_yaml(&pr_content)
.context("Failed to serialize PR content to YAML")?;
std::fs::write(&save_path, &pr_yaml).context("Failed to save PR details to file")?;
println!("💾 PR details saved to: {save_path}");
return Ok(());
}
debug!("About to serialize PR content to YAML");
let temp_dir = tempfile::tempdir()?;
let pr_file = temp_dir.path().join("pr-details.yaml");
debug!(
pre_serialize_title = %pr_content.title,
pre_serialize_description_length = pr_content.description.len(),
pre_serialize_description_preview = %pr_content.description.lines().take(3).collect::<Vec<_>>().join("\\n"),
"About to serialize PR content with to_yaml"
);
let pr_yaml =
crate::data::to_yaml(&pr_content).context("Failed to serialize PR content to YAML")?;
debug!(
file_path = %pr_file.display(),
yaml_content_length = pr_yaml.len(),
yaml_content = %pr_yaml,
original_title = %pr_content.title,
original_description_length = pr_content.description.len(),
"Writing PR details to temporary YAML file"
);
std::fs::write(&pr_file, &pr_yaml)?;
let pr_action = if self.auto_apply {
if repo_view
.branch_prs
.as_ref()
.is_some_and(|prs| !prs.is_empty())
{
PrAction::UpdateExisting
} else {
PrAction::CreateNew
}
} else {
self.handle_pr_file(&pr_file, &repo_view)?
};
if pr_action == PrAction::Cancel {
println!("❌ PR operation cancelled by user");
return Ok(());
}
let final_pr_yaml =
std::fs::read_to_string(&pr_file).context("Failed to read PR details file")?;
debug!(
yaml_length = final_pr_yaml.len(),
yaml_content = %final_pr_yaml,
"Read PR details YAML from file"
);
let final_pr_content: PrContent = serde_yaml::from_str(&final_pr_yaml)
.context("Failed to parse PR details YAML. Please check the file format.")?;
debug!(
title = %final_pr_content.title,
description_length = final_pr_content.description.len(),
description_preview = %final_pr_content.description.lines().take(3).collect::<Vec<_>>().join("\\n"),
"Parsed PR content from YAML"
);
let is_draft = self.should_create_as_draft();
match pr_action {
PrAction::CreateNew => {
self.create_github_pr(
&repo_view,
&final_pr_content.title,
&final_pr_content.description,
is_draft,
self.base.as_deref(),
)?;
println!("✅ Pull request created successfully!");
}
PrAction::UpdateExisting => {
self.update_github_pr(
&repo_view,
&final_pr_content.title,
&final_pr_content.description,
self.base.as_deref(),
)?;
println!("✅ Pull request updated successfully!");
}
PrAction::Cancel => unreachable!(), }
Ok(())
}
fn generate_repository_view(&self) -> Result<crate::data::RepositoryView> {
use crate::data::{
AiInfo, BranchInfo, FieldExplanation, FileStatusInfo, RepositoryView, VersionInfo,
WorkingDirectoryInfo,
};
use crate::git::{GitRepository, RemoteInfo};
use crate::utils::ai_scratch;
let repo = GitRepository::open()
.context("Failed to open git repository. Make sure you're in a git repository.")?;
let current_branch = repo.get_current_branch().context(
"Failed to get current branch. Make sure you're not in detached HEAD state.",
)?;
let remotes = RemoteInfo::get_all_remotes(repo.repository())?;
let primary_remote = remotes
.iter()
.find(|r| r.name == "origin")
.or_else(|| remotes.first())
.ok_or_else(|| anyhow::anyhow!("No remotes found in repository"))?;
let base_branch = if let Some(branch) = self.base.as_ref() {
let remote_ref = format!("refs/remotes/{branch}");
if repo.repository().find_reference(&remote_ref).is_ok() {
branch.clone()
} else {
let with_remote = format!("{}/{}", primary_remote.name, branch);
let remote_ref = format!("refs/remotes/{with_remote}");
if repo.repository().find_reference(&remote_ref).is_ok() {
with_remote
} else {
anyhow::bail!(
"Remote branch '{branch}' does not exist (also tried '{with_remote}')"
);
}
}
} else {
let main_branch = &primary_remote.main_branch;
if main_branch == "unknown" {
let remote_name = &primary_remote.name;
anyhow::bail!("Could not determine main branch for remote '{remote_name}'");
}
let remote_main = format!("{}/{}", primary_remote.name, main_branch);
let remote_ref = format!("refs/remotes/{remote_main}");
if repo.repository().find_reference(&remote_ref).is_err() {
anyhow::bail!(
"Remote main branch '{remote_main}' does not exist. Try running 'git fetch' first."
);
}
remote_main
};
let commit_range = format!("{base_branch}..HEAD");
let wd_status = repo.get_working_directory_status()?;
let working_directory = WorkingDirectoryInfo {
clean: wd_status.clean,
untracked_changes: wd_status
.untracked_changes
.into_iter()
.map(|fs| FileStatusInfo {
status: fs.status,
file: fs.file,
})
.collect(),
};
let remotes = RemoteInfo::get_all_remotes(repo.repository())?;
let commits = repo.get_commits_in_range(&commit_range)?;
let pr_template_result = InfoCommand::read_pr_template().ok();
let (pr_template, pr_template_location) = match pr_template_result {
Some((content, location)) => (Some(content), Some(location)),
None => (None, None),
};
let branch_prs = InfoCommand::get_branch_prs(¤t_branch)
.ok()
.filter(|prs| !prs.is_empty());
let versions = Some(VersionInfo {
omni_dev: env!("CARGO_PKG_VERSION").to_string(),
});
let ai_scratch_path =
ai_scratch::get_ai_scratch_dir().context("Failed to determine AI scratch directory")?;
let ai_info = AiInfo {
scratch: ai_scratch_path.to_string_lossy().to_string(),
};
let mut repo_view = RepositoryView {
versions,
explanation: FieldExplanation::default(),
working_directory,
remotes,
ai: ai_info,
branch_info: Some(BranchInfo {
branch: current_branch,
}),
pr_template,
pr_template_location,
branch_prs,
commits,
};
repo_view.update_field_presence();
Ok(repo_view)
}
fn validate_branch_state(&self, repo_view: &crate::data::RepositoryView) -> Result<()> {
if !repo_view.working_directory.clean {
anyhow::bail!(
"Working directory has uncommitted changes. Please commit or stash your changes before creating a PR."
);
}
if !repo_view.working_directory.untracked_changes.is_empty() {
let file_list: Vec<&str> = repo_view
.working_directory
.untracked_changes
.iter()
.map(|f| f.file.as_str())
.collect();
anyhow::bail!(
"Working directory has untracked changes: {}. Please commit or stash your changes before creating a PR.",
file_list.join(", ")
);
}
if repo_view.commits.is_empty() {
anyhow::bail!("No commits found to create PR from. Make sure you have commits that are not in the base branch.");
}
if let Some(existing_prs) = &repo_view.branch_prs {
if !existing_prs.is_empty() {
let pr_info: Vec<String> = existing_prs
.iter()
.map(|pr| format!("#{} ({})", pr.number, pr.state))
.collect();
println!(
"📋 Existing PR(s) found for this branch: {}",
pr_info.join(", ")
);
}
}
Ok(())
}
async fn show_context_information(
&self,
_repo_view: &crate::data::RepositoryView,
) -> Result<()> {
Ok(())
}
fn show_commit_range_info(&self, repo_view: &crate::data::RepositoryView) -> Result<()> {
let base_branch = match self.base.as_ref() {
Some(branch) => {
let primary_remote_name = repo_view
.remotes
.iter()
.find(|r| r.name == "origin")
.or_else(|| repo_view.remotes.first())
.map_or("origin", |r| r.name.as_str());
if branch.starts_with(&format!("{primary_remote_name}/")) {
branch.clone()
} else {
format!("{primary_remote_name}/{branch}")
}
}
None => {
repo_view
.remotes
.iter()
.find(|r| r.name == "origin")
.or_else(|| repo_view.remotes.first())
.map_or_else(
|| "unknown".to_string(),
|r| format!("{}/{}", r.name, r.main_branch),
)
}
};
let commit_range = format!("{base_branch}..HEAD");
let commit_count = repo_view.commits.len();
let current_branch = repo_view
.branch_info
.as_ref()
.map_or("unknown", |bi| bi.branch.as_str());
println!("📊 Branch Analysis:");
println!(" 🌿 Current branch: {current_branch}");
println!(" 📏 Commit range: {commit_range}");
println!(" 📝 Commits found: {commit_count} commits");
println!();
Ok(())
}
async fn collect_context(
&self,
repo_view: &crate::data::RepositoryView,
) -> Result<crate::data::context::CommitContext> {
use crate::claude::context::{
BranchAnalyzer, FileAnalyzer, ProjectDiscovery, WorkPatternAnalyzer,
};
use crate::data::context::{CommitContext, ProjectContext};
use crate::git::GitRepository;
let mut context = CommitContext::new();
let context_dir = crate::claude::context::resolve_context_dir(self.context_dir.as_deref());
let repo_root = std::path::PathBuf::from(".");
let discovery = ProjectDiscovery::new(repo_root, context_dir);
match discovery.discover() {
Ok(project_context) => {
context.project = project_context;
}
Err(_e) => {
context.project = ProjectContext::default();
}
}
let repo = GitRepository::open()?;
let current_branch = repo
.get_current_branch()
.unwrap_or_else(|_| "HEAD".to_string());
context.branch = BranchAnalyzer::analyze(¤t_branch).unwrap_or_default();
if !repo_view.commits.is_empty() {
context.range = WorkPatternAnalyzer::analyze_commit_range(&repo_view.commits);
}
if !repo_view.commits.is_empty() {
context.files = FileAnalyzer::analyze_commits(&repo_view.commits);
}
Ok(context)
}
fn show_guidance_files_status(
&self,
project_context: &crate::data::context::ProjectContext,
) -> Result<()> {
use crate::claude::context::{
config_source_label, resolve_context_dir_with_source, ConfigSourceLabel,
};
let (context_dir, dir_source) =
resolve_context_dir_with_source(self.context_dir.as_deref());
println!("📋 Project guidance files status:");
println!(" 📂 Config dir: {} ({dir_source})", context_dir.display());
let pr_guidelines_source = if project_context.pr_guidelines.is_some() {
match config_source_label(&context_dir, "pr-guidelines.md") {
ConfigSourceLabel::NotFound => "✅ (source unknown)".to_string(),
label => format!("✅ {label}"),
}
} else {
"❌ None found".to_string()
};
println!(" 🔀 PR guidelines: {pr_guidelines_source}");
let scopes_count = project_context.valid_scopes.len();
let scopes_source = if scopes_count > 0 {
match config_source_label(&context_dir, "scopes.yaml") {
ConfigSourceLabel::NotFound => {
format!("✅ (source unknown + ecosystem defaults) ({scopes_count} scopes)")
}
label => format!("✅ {label} ({scopes_count} scopes)"),
}
} else {
"❌ None found".to_string()
};
println!(" 🎯 Valid scopes: {scopes_source}");
let pr_template_path = std::path::Path::new(".github/pull_request_template.md");
let pr_template_status = if pr_template_path.exists() {
format!("✅ Project: {}", pr_template_path.display())
} else {
"❌ None found".to_string()
};
println!(" 📋 PR template: {pr_template_status}");
println!();
Ok(())
}
fn show_context_summary(&self, context: &crate::data::context::CommitContext) -> Result<()> {
use crate::data::context::{VerbosityLevel, WorkPattern};
println!("🔍 Context Analysis:");
if !context.project.valid_scopes.is_empty() {
let scope_names: Vec<&str> = context
.project
.valid_scopes
.iter()
.map(|s| s.name.as_str())
.collect();
println!(" 📁 Valid scopes: {}", scope_names.join(", "));
}
if context.branch.is_feature_branch {
println!(
" 🌿 Branch: {} ({})",
context.branch.description, context.branch.work_type
);
if let Some(ref ticket) = context.branch.ticket_id {
println!(" 🎫 Ticket: {ticket}");
}
}
match context.range.work_pattern {
WorkPattern::Sequential => println!(" 🔄 Pattern: Sequential development"),
WorkPattern::Refactoring => println!(" 🧹 Pattern: Refactoring work"),
WorkPattern::BugHunt => println!(" 🐛 Pattern: Bug investigation"),
WorkPattern::Documentation => println!(" 📖 Pattern: Documentation updates"),
WorkPattern::Configuration => println!(" ⚙️ Pattern: Configuration changes"),
WorkPattern::Unknown => {}
}
if let Some(label) = super::formatting::format_file_analysis(&context.files) {
println!(" {label}");
}
match context.suggested_verbosity() {
VerbosityLevel::Comprehensive => {
println!(" 📝 Detail level: Comprehensive (significant changes detected)");
}
VerbosityLevel::Detailed => println!(" 📝 Detail level: Detailed"),
VerbosityLevel::Concise => println!(" 📝 Detail level: Concise"),
}
println!();
Ok(())
}
async fn generate_pr_content_with_client_internal(
&self,
repo_view: &crate::data::RepositoryView,
claude_client: crate::claude::client::ClaudeClient,
) -> Result<(PrContent, crate::claude::client::ClaudeClient)> {
use tracing::debug;
let pr_template = match &repo_view.pr_template {
Some(template) => template.clone(),
None => self.get_default_pr_template(),
};
debug!(
pr_template_length = pr_template.len(),
pr_template_preview = %pr_template.lines().take(5).collect::<Vec<_>>().join("\\n"),
"Using PR template for generation"
);
println!("🤖 Generating AI-powered PR description...");
debug!("Collecting context for PR generation");
let context = self.collect_context(repo_view).await?;
debug!("Context collection completed");
debug!("About to call Claude AI for PR content generation");
match claude_client
.generate_pr_content_with_context(repo_view, &pr_template, &context)
.await
{
Ok(pr_content) => {
debug!(
ai_generated_title = %pr_content.title,
ai_generated_description_length = pr_content.description.len(),
ai_generated_description_preview = %pr_content.description.lines().take(3).collect::<Vec<_>>().join("\\n"),
"AI successfully generated PR content"
);
Ok((pr_content, claude_client))
}
Err(e) => {
debug!(error = %e, "AI PR generation failed, falling back to basic description");
let mut description = pr_template;
self.enhance_description_with_commits(&mut description, repo_view)?;
let title = self.generate_title_from_commits(repo_view);
debug!(
fallback_title = %title,
fallback_description_length = description.len(),
"Created fallback PR content"
);
Ok((PrContent { title, description }, claude_client))
}
}
}
fn get_default_pr_template(&self) -> String {
r#"# Pull Request
## Description
<!-- Provide a brief description of what this PR does -->
## Type of Change
<!-- Mark the relevant option with an "x" -->
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update
- [ ] Refactoring (no functional changes)
- [ ] Performance improvement
- [ ] Test coverage improvement
## Changes Made
<!-- List the specific changes made in this PR -->
-
-
-
## Testing
- [ ] All existing tests pass
- [ ] New tests added for new functionality
- [ ] Manual testing performed
## Additional Notes
<!-- Add any additional notes for reviewers -->
"#.to_string()
}
fn enhance_description_with_commits(
&self,
description: &mut String,
repo_view: &crate::data::RepositoryView,
) -> Result<()> {
if repo_view.commits.is_empty() {
return Ok(());
}
description.push_str("\n---\n");
description.push_str("## 📝 Commit Summary\n");
description
.push_str("*This section was automatically generated based on commit analysis*\n\n");
let mut types_found = std::collections::HashSet::new();
let mut scopes_found = std::collections::HashSet::new();
let mut has_breaking_changes = false;
for commit in &repo_view.commits {
let detected_type = &commit.analysis.detected_type;
types_found.insert(detected_type.clone());
if is_breaking_change(detected_type, &commit.original_message) {
has_breaking_changes = true;
}
let detected_scope = &commit.analysis.detected_scope;
if !detected_scope.is_empty() {
scopes_found.insert(detected_scope.clone());
}
}
if types_found.contains("feat") {
check_checkbox(description, "- [ ] New feature");
}
if types_found.contains("fix") {
check_checkbox(description, "- [ ] Bug fix");
}
if types_found.contains("docs") {
check_checkbox(description, "- [ ] Documentation update");
}
if types_found.contains("refactor") {
check_checkbox(description, "- [ ] Refactoring");
}
if has_breaking_changes {
check_checkbox(description, "- [ ] Breaking change");
}
let scopes_list: Vec<_> = scopes_found.into_iter().collect();
let scopes_section = format_scopes_section(&scopes_list);
if !scopes_section.is_empty() {
description.push_str(&scopes_section);
}
let commit_entries: Vec<(&str, &str)> = repo_view
.commits
.iter()
.map(|c| {
let short = &c.hash[..crate::git::SHORT_HASH_LEN];
let first = extract_first_line(&c.original_message);
(short, first)
})
.collect();
description.push_str(&format_commit_list(&commit_entries));
let total_files: usize = repo_view
.commits
.iter()
.map(|c| c.analysis.file_changes.total_files)
.sum();
if total_files > 0 {
description.push_str(&format!("\n**Files changed:** {total_files} files\n"));
}
Ok(())
}
fn handle_pr_file(
&self,
pr_file: &std::path::Path,
repo_view: &crate::data::RepositoryView,
) -> Result<PrAction> {
use std::io::{self, Write};
println!("\n📝 PR details generated.");
println!("💾 Details saved to: {}", pr_file.display());
let is_draft = self.should_create_as_draft();
let (status_icon, status_text) = format_draft_status(is_draft);
println!("{status_icon} PR will be created as: {status_text}");
println!();
let has_existing_prs = repo_view
.branch_prs
.as_ref()
.is_some_and(|prs| !prs.is_empty());
loop {
if has_existing_prs {
print!("❓ [U]pdate existing PR, [N]ew PR anyway, [S]how file, [E]dit file, or [Q]uit? [U/n/s/e/q] ");
} else {
print!(
"❓ [A]ccept and create PR, [S]how file, [E]dit file, or [Q]uit? [A/s/e/q] "
);
}
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
match input.trim().to_lowercase().as_str() {
"u" | "update" if has_existing_prs => return Ok(PrAction::UpdateExisting),
"n" | "new" if has_existing_prs => return Ok(PrAction::CreateNew),
"a" | "accept" | "" if !has_existing_prs => return Ok(PrAction::CreateNew),
"s" | "show" => {
self.show_pr_file(pr_file)?;
println!();
}
"e" | "edit" => {
self.edit_pr_file(pr_file)?;
println!();
}
"q" | "quit" => return Ok(PrAction::Cancel),
_ => {
if has_existing_prs {
println!("Invalid choice. Please enter 'u' to update existing PR, 'n' for new PR, 's' to show, 'e' to edit, or 'q' to quit.");
} else {
println!("Invalid choice. Please enter 'a' to accept, 's' to show, 'e' to edit, or 'q' to quit.");
}
}
}
}
}
fn show_pr_file(&self, pr_file: &std::path::Path) -> Result<()> {
use std::fs;
println!("\n📄 PR details file contents:");
println!("─────────────────────────────");
let contents = fs::read_to_string(pr_file).context("Failed to read PR details file")?;
println!("{contents}");
println!("─────────────────────────────");
Ok(())
}
fn edit_pr_file(&self, pr_file: &std::path::Path) -> Result<()> {
use std::env;
use std::io::{self, Write};
use std::process::Command;
let editor = if let Ok(e) = env::var("OMNI_DEV_EDITOR").or_else(|_| env::var("EDITOR")) {
e
} else {
println!("🔧 Neither OMNI_DEV_EDITOR nor EDITOR environment variables are defined.");
print!("Please enter the command to use as your editor: ");
io::stdout().flush().context("Failed to flush stdout")?;
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.context("Failed to read user input")?;
input.trim().to_string()
};
if editor.is_empty() {
println!("❌ No editor specified. Returning to menu.");
return Ok(());
}
println!("📝 Opening PR details file in editor: {editor}");
let (editor_cmd, args) = super::formatting::parse_editor_command(&editor);
let mut command = Command::new(editor_cmd);
command.args(args);
command.arg(pr_file.to_string_lossy().as_ref());
match command.status() {
Ok(status) => {
if status.success() {
println!("✅ Editor session completed.");
} else {
println!(
"⚠️ Editor exited with non-zero status: {:?}",
status.code()
);
}
}
Err(e) => {
println!("❌ Failed to execute editor '{editor}': {e}");
println!(" Please check that the editor command is correct and available in your PATH.");
}
}
Ok(())
}
fn generate_title_from_commits(&self, repo_view: &crate::data::RepositoryView) -> String {
if repo_view.commits.is_empty() {
return "Pull Request".to_string();
}
if repo_view.commits.len() == 1 {
let first = extract_first_line(&repo_view.commits[0].original_message);
let trimmed = first.trim();
return if trimmed.is_empty() {
"Pull Request".to_string()
} else {
trimmed.to_string()
};
}
let branch_name = repo_view
.branch_info
.as_ref()
.map_or("feature", |bi| bi.branch.as_str());
format!("feat: {}", clean_branch_name(branch_name))
}
fn create_github_pr(
&self,
repo_view: &crate::data::RepositoryView,
title: &str,
description: &str,
is_draft: bool,
new_base: Option<&str>,
) -> Result<()> {
use std::process::Command;
let branch_name = repo_view
.branch_info
.as_ref()
.map(|bi| &bi.branch)
.context("Branch info not available")?;
let pr_status = if is_draft {
"draft"
} else {
"ready for review"
};
println!("🚀 Creating pull request ({pr_status})...");
println!(" 📋 Title: {title}");
println!(" 🌿 Branch: {branch_name}");
if let Some(base) = new_base {
println!(" 🎯 Base: {base}");
}
let push_action = if self.no_push {
determine_push_action(true, false)
} else {
debug!("Opening git repository to check branch status");
let git_repo =
crate::git::GitRepository::open().context("Failed to open git repository")?;
debug!(
"Checking if branch '{}' exists on remote 'origin'",
branch_name
);
let branch_on_remote = git_repo.branch_exists_on_remote(branch_name, "origin")?;
let action = determine_push_action(false, branch_on_remote);
debug!("Push action for branch '{}': {:?}", branch_name, action);
println!("📤 Pushing branch to remote...");
git_repo
.push_branch(branch_name, "origin")
.context("Failed to push branch to remote")?;
action
};
if push_action == PushAction::Skip {
debug!("Skipping push (--no-push flag set)");
}
debug!("Creating PR with gh CLI - title: '{}'", title);
debug!("PR description length: {} characters", description.len());
debug!("PR draft status: {}", is_draft);
if let Some(base) = new_base {
debug!("PR base branch: {}", base);
}
let mut args = vec![
"pr",
"create",
"--head",
branch_name,
"--title",
title,
"--body",
description,
];
if let Some(base) = new_base {
args.push("--base");
args.push(base);
}
if is_draft {
args.push("--draft");
}
let pr_result = Command::new("gh")
.args(&args)
.output()
.context("Failed to create pull request")?;
if pr_result.status.success() {
let pr_url = String::from_utf8_lossy(&pr_result.stdout);
let pr_url = pr_url.trim();
debug!("PR created successfully with URL: {}", pr_url);
println!("🎉 Pull request created: {pr_url}");
} else {
let error_msg = String::from_utf8_lossy(&pr_result.stderr);
error!("gh CLI failed to create PR: {}", error_msg);
anyhow::bail!("Failed to create pull request: {error_msg}");
}
Ok(())
}
fn update_github_pr(
&self,
repo_view: &crate::data::RepositoryView,
title: &str,
description: &str,
new_base: Option<&str>,
) -> Result<()> {
use std::io::{self, Write};
use std::process::Command;
let existing_pr = repo_view
.branch_prs
.as_ref()
.and_then(|prs| prs.first())
.context("No existing PR found to update")?;
let pr_number = existing_pr.number;
let current_base = &existing_pr.base;
println!("🚀 Updating pull request #{pr_number}...");
println!(" 📋 Title: {title}");
let change_base = if let Some(base) = new_base {
if !current_base.is_empty() && current_base != base {
print!(" 🎯 Current base: {current_base} → New base: {base}. Change? [y/N]: ");
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
let response = input.trim().to_lowercase();
response == "y" || response == "yes"
} else {
false
}
} else {
false
};
debug!(
pr_number = pr_number,
title = %title,
description_length = description.len(),
description_preview = %description.lines().take(3).collect::<Vec<_>>().join("\\n"),
change_base = change_base,
"Updating GitHub PR with title and description"
);
let pr_number_str = pr_number.to_string();
let mut gh_args = vec![
"pr",
"edit",
&pr_number_str,
"--title",
title,
"--body",
description,
];
if change_base {
if let Some(base) = new_base {
gh_args.push("--base");
gh_args.push(base);
}
}
debug!(
args = ?gh_args,
"Executing gh command to update PR"
);
let pr_result = Command::new("gh")
.args(&gh_args)
.output()
.context("Failed to update pull request")?;
if pr_result.status.success() {
println!("🎉 Pull request updated: {}", existing_pr.url);
if change_base {
if let Some(base) = new_base {
println!(" 🎯 Base branch changed to: {base}");
}
}
} else {
let error_msg = String::from_utf8_lossy(&pr_result.stderr);
anyhow::bail!("Failed to update pull request: {error_msg}");
}
Ok(())
}
fn show_model_info_from_client(
&self,
client: &crate::claude::client::ClaudeClient,
) -> Result<()> {
use crate::claude::model_config::get_model_registry;
println!("🤖 AI Model Configuration:");
let metadata = client.get_ai_client_metadata();
let registry = get_model_registry();
if let Some(spec) = registry.get_model_spec(&metadata.model) {
if metadata.model != spec.api_identifier {
println!(
" 📡 Model: {} → \x1b[33m{}\x1b[0m",
metadata.model, spec.api_identifier
);
} else {
println!(" 📡 Model: \x1b[33m{}\x1b[0m", metadata.model);
}
println!(" 🏷️ Provider: {}", spec.provider);
println!(" 📊 Generation: {}", spec.generation);
println!(" ⭐ Tier: {} ({})", spec.tier, {
if let Some(tier_info) = registry.get_tier_info(&spec.provider, &spec.tier) {
&tier_info.description
} else {
"No description available"
}
});
println!(" 📤 Max output tokens: {}", metadata.max_response_length);
println!(" 📥 Input context: {}", metadata.max_context_length);
if let Some((ref key, ref value)) = metadata.active_beta {
println!(" 🔬 Beta header: {key}: {value}");
}
if spec.legacy {
println!(" ⚠️ Legacy model (consider upgrading to newer version)");
}
} else {
println!(" 📡 Model: \x1b[33m{}\x1b[0m", metadata.model);
println!(" 🏷️ Provider: {}", metadata.provider);
println!(" ⚠️ Model not found in registry, using client metadata:");
println!(" 📤 Max output tokens: {}", metadata.max_response_length);
println!(" 📥 Input context: {}", metadata.max_context_length);
}
println!();
Ok(())
}
}
/// How the branch is handled before `gh pr create` runs.
#[derive(PartialEq, Debug)]
enum PushAction {
    /// `--no-push` was given; the branch is left as-is.
    Skip,
    /// The branch already exists on the remote and is pushed again.
    SyncExisting,
    /// The branch is not on the remote yet and is pushed for the first time.
    PushNew,
}
/// Maps the `--no-push` flag and remote-branch presence to a [`PushAction`].
///
/// `no_push` always wins; otherwise the result depends on whether the branch is
/// already on the remote.
fn determine_push_action(no_push: bool, branch_on_remote: bool) -> PushAction {
    match (no_push, branch_on_remote) {
        (true, _) => PushAction::Skip,
        (false, true) => PushAction::SyncExisting,
        (false, false) => PushAction::PushNew,
    }
}
/// Parses a loosely formatted boolean setting value.
///
/// Matching is case-insensitive and ignores surrounding whitespace (e.g. from a
/// hand-edited env file). `true`/`1`/`yes`/`on` map to `Some(true)` and
/// `false`/`0`/`no`/`off` map to `Some(false)`. Anything else yields `None`, so the
/// caller can apply its own default.
fn parse_bool_string(val: &str) -> Option<bool> {
    match val.trim().to_lowercase().as_str() {
        "true" | "1" | "yes" | "on" => Some(true),
        "false" | "0" | "no" | "off" => Some(false),
        _ => None,
    }
}
/// Reports whether a commit signals a breaking change.
///
/// A commit is breaking when any of these holds:
/// - the detected type contains `BREAKING`;
/// - the message contains a `BREAKING CHANGE` or `BREAKING-CHANGE` footer (the
///   Conventional Commits spec treats the two spellings as synonyms);
/// - the subject's type prefix ends in `!`, as in `feat!: ...` or `feat(api)!: ...`.
fn is_breaking_change(detected_type: &str, original_message: &str) -> bool {
    if detected_type.contains("BREAKING")
        || original_message.contains("BREAKING CHANGE")
        || original_message.contains("BREAKING-CHANGE")
    {
        return true;
    }
    // Only look at the `type(scope)!` prefix of the first line. A prefix containing a
    // space is ordinary prose rather than a conventional-commit type.
    original_message
        .lines()
        .next()
        .and_then(|subject| subject.split_once(':'))
        .is_some_and(|(prefix, _)| !prefix.contains(' ') && prefix.ends_with('!'))
}
/// Ticks the first markdown checkbox in `description` that matches `search_text`.
///
/// `search_text` must start with an unchecked box (`- [ ]`). Only that 5-byte
/// marker is rewritten to `- [x]`; the rest of the line is untouched. Calls with any
/// other `search_text`, or with text that is not found, leave `description` unchanged.
/// Previously such calls would overwrite 5 arbitrary bytes.
fn check_checkbox(description: &mut String, search_text: &str) {
    const UNCHECKED: &str = "- [ ]";
    const CHECKED: &str = "- [x]";
    if !search_text.starts_with(UNCHECKED) {
        return;
    }
    if let Some(pos) = description.find(search_text) {
        description.replace_range(pos..pos + UNCHECKED.len(), CHECKED);
    }
}
/// Renders the "Affected areas" line followed by a blank line, or an empty string
/// when there are no scopes.
fn format_scopes_section(scopes: &[String]) -> String {
    match scopes {
        [] => String::new(),
        listed => format!("**Affected areas:** {}\n\n", listed.join(", ")),
    }
}
/// Renders a markdown list of commits, one ``- `hash` message`` line per entry,
/// under a "Commits in this PR" heading.
fn format_commit_list(entries: &[(&str, &str)]) -> String {
    let heading = String::from("### Commits in this PR:\n");
    entries.iter().fold(heading, |mut rendered, (hash, message)| {
        rendered.push_str(&format!("- `{hash}` {message}\n"));
        rendered
    })
}
/// Turns a branch name into words by replacing `/`, `-` and `_` with spaces.
fn clean_branch_name(branch: &str) -> String {
    branch
        .chars()
        .map(|ch| if matches!(ch, '/' | '-' | '_') { ' ' } else { ch })
        .collect()
}
/// Returns the first line of `text` with surrounding whitespace trimmed, or `""`
/// for empty input.
fn extract_first_line(text: &str) -> &str {
    text.lines().next().map_or("", str::trim)
}
/// Returns the (icon, label) pair describing how the PR will be created.
fn format_draft_status(is_draft: bool) -> (&'static str, &'static str) {
    match is_draft {
        true => ("\u{1f4cb}", "draft"),
        false => ("\u{2705}", "ready for review"),
    }
}
/// Result of the non-interactive PR generation entry points: the generated title and
/// description plus their YAML serialization.
#[derive(Clone, Debug)]
pub struct CreatePrOutcome {
    /// Generated PR title.
    pub title: String,
    /// Generated PR description.
    pub description: String,
    /// `title` and `description` serialized as YAML.
    pub pr_yaml: String,
}
/// Non-interactive PR content generation.
///
/// Optionally switches into `repo_path` for the duration of the call (restoring it
/// when the guard is dropped — presumably via `CwdGuard`'s `Drop`; confirm). It then
/// checks prerequisites, builds the repository view and context, and delegates to
/// `run_create_pr_with_client`. No PR is created and nothing is pushed.
///
/// # Errors
/// Fails when the directory cannot be entered, prerequisites fail, or the repository
/// view/context/client cannot be built.
pub async fn run_create_pr(
    model: Option<String>,
    base_branch: Option<&str>,
    repo_path: Option<&std::path::Path>,
) -> Result<CreatePrOutcome> {
    // Keep the guard bound (not `_`) so it lives until the end of the function.
    let _cwd_guard = if let Some(dir) = repo_path {
        Some(super::CwdGuard::enter(dir).await?)
    } else {
        None
    };

    crate::utils::check_pr_command_prerequisites(model.as_deref())?;

    let command = CreatePrCommand {
        base: base_branch.map(str::to_string),
        model: model.clone(),
        auto_apply: true,
        save_only: None,
        ready: false,
        draft: false,
        context_dir: None,
        no_push: true,
    };
    let view = command.generate_repository_view()?;
    let commit_context = command.collect_context(&view).await?;
    let client = crate::claude::create_default_claude_client(model, None)?;
    run_create_pr_with_client(&command, &view, &commit_context, &client).await
}
pub(crate) async fn run_create_pr_with_client(
cmd: &CreatePrCommand,
repo_view: &crate::data::RepositoryView,
context: &crate::data::context::CommitContext,
claude_client: &crate::claude::client::ClaudeClient,
) -> Result<CreatePrOutcome> {
let pr_template = match &repo_view.pr_template {
Some(template) => template.clone(),
None => cmd.get_default_pr_template(),
};
let pr_content = match claude_client
.generate_pr_content_with_context(repo_view, &pr_template, context)
.await
{
Ok(content) => content,
Err(_e) => {
let mut description = pr_template;
cmd.enhance_description_with_commits(&mut description, repo_view)?;
let title = cmd.generate_title_from_commits(repo_view);
PrContent { title, description }
}
};
let pr_yaml = crate::data::to_yaml(&pr_content).context("Failed to serialise PrContent")?;
Ok(CreatePrOutcome {
title: pr_content.title,
description: pr_content.description,
pr_yaml,
})
}
// Tests for the headless `run_create_pr` entry point and for
// `run_create_pr_with_client`, its core that takes an injected AI client.
#[cfg(test)]
// Tests unwrap/expect freely: a panic is the intended failure signal here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod run_create_pr_tests {
use super::*;
use crate::claude::client::ClaudeClient;
use crate::claude::test_utils::ConfigurableMockAiClient;
use crate::data::context::CommitContext;
use crate::data::{
AiInfo, BranchInfo, FieldExplanation, RepositoryView, VersionInfo, WorkingDirectoryInfo,
};
use crate::git::commit::FileChanges;
use crate::git::{CommitAnalysis, CommitInfo};
/// A nonexistent `repo_path` must fail while entering the directory. That
/// step runs before the prerequisite check and before any AI client exists.
#[tokio::test]
async fn run_create_pr_invalid_repo_path_errors_before_ai() {
let err = run_create_pr(
None,
None,
Some(std::path::Path::new("/no/such/path/exists")),
)
.await
.unwrap_err();
// `{:#}` includes the full context chain. Accept several spellings
// because the OS/std error text for a missing directory varies by platform.
let msg = format!("{err:#}");
assert!(
msg.to_lowercase().contains("set_current_dir")
|| msg.to_lowercase().contains("no such")
|| msg.to_lowercase().contains("directory"),
"expected cwd-related error, got: {msg}"
);
}
/// Command with the same field values that `run_create_pr` builds.
fn fresh_cmd() -> CreatePrCommand {
CreatePrCommand {
base: None,
model: None,
auto_apply: true,
save_only: None,
ready: false,
draft: false,
context_dir: None,
no_push: true,
}
}
/// Build a minimal `feat` commit whose `diff_file` points at a fresh temp file.
///
/// The `NamedTempFile` is returned so the caller keeps it alive; it is
/// deleted when dropped, after which `diff_file` would dangle.
fn sample_commit(hash: &str, message: &str) -> (CommitInfo, tempfile::NamedTempFile) {
let tmp = tempfile::NamedTempFile::new().unwrap();
let commit = CommitInfo {
hash: hash.to_string(),
author: "Test <test@test.com>".to_string(),
date: chrono::Utc::now().fixed_offset(),
original_message: message.to_string(),
in_main_branches: vec![],
analysis: CommitAnalysis {
detected_type: "feat".to_string(),
detected_scope: String::new(),
proposed_message: message.to_string(),
file_changes: FileChanges {
total_files: 0,
files_added: 0,
files_deleted: 0,
file_list: vec![],
},
diff_summary: String::new(),
diff_file: tmp.path().to_string_lossy().to_string(),
file_diffs: Vec::new(),
},
};
(commit, tmp)
}
/// Minimal repository view: branch `feature/test`, clean working directory,
/// no remotes or existing PRs, with the given commits and optional template.
fn sample_repo_view(commits: Vec<CommitInfo>, pr_template: Option<String>) -> RepositoryView {
RepositoryView {
versions: Some(VersionInfo {
omni_dev: "0.0.0".to_string(),
}),
explanation: FieldExplanation::default(),
working_directory: WorkingDirectoryInfo {
clean: true,
untracked_changes: vec![],
},
remotes: vec![],
ai: AiInfo {
scratch: String::new(),
},
branch_info: Some(BranchInfo {
branch: "feature/test".to_string(),
}),
pr_template,
pr_template_location: None,
branch_prs: None,
commits,
}
}
/// A successful AI response (YAML with `title`/`description`) is returned
/// as-is, and `pr_yaml` holds its re-serialised form.
#[tokio::test]
async fn run_create_pr_with_client_ai_success_returns_content() {
let (c1, _tmp) = sample_commit("abcdef00", "feat: work");
let repo_view = sample_repo_view(vec![c1], None);
let context = CommitContext::new();
let cmd = fresh_cmd();
let yaml = "title: My PR\ndescription: |\n Body text\n".to_string();
let mock = ConfigurableMockAiClient::new(vec![Ok(yaml)]);
let client = ClaudeClient::new(Box::new(mock));
let outcome = run_create_pr_with_client(&cmd, &repo_view, &context, &client)
.await
.unwrap();
assert_eq!(outcome.title, "My PR");
assert!(outcome.description.contains("Body text"));
assert!(outcome.pr_yaml.contains("title:"));
}
/// With no queued mock responses (assumed to make the AI call fail), the
/// function still succeeds using the commit-derived fallback title.
#[tokio::test]
async fn run_create_pr_with_client_ai_failure_falls_back_to_commit_summary() {
let (c1, _tmp) = sample_commit("abcdef00", "feat: single commit subject");
let repo_view = sample_repo_view(vec![c1], None);
let context = CommitContext::new();
let cmd = fresh_cmd();
let mock = ConfigurableMockAiClient::new(vec![]);
let client = ClaudeClient::new(Box::new(mock));
let outcome = run_create_pr_with_client(&cmd, &repo_view, &context, &client)
.await
.unwrap();
// The exact fallback title format is owned by `generate_title_from_commits`,
// so any of its plausible shapes is accepted.
assert!(
outcome.title.contains("feat: single commit subject")
|| outcome.title.contains("Pull Request")
|| outcome.title.contains("feature/test"),
"fallback title unexpected: {}",
outcome.title
);
}
/// On the fallback path, a repository-provided PR template is used as the
/// base of the description instead of the default template.
#[tokio::test]
async fn run_create_pr_with_client_uses_repo_template_when_present() {
let (c1, _tmp) = sample_commit("abcdef00", "feat: x");
let repo_view = sample_repo_view(vec![c1], Some("# Custom template\n".to_string()));
let context = CommitContext::new();
let cmd = fresh_cmd();
let mock = ConfigurableMockAiClient::new(vec![]);
let client = ClaudeClient::new(Box::new(mock));
let outcome = run_create_pr_with_client(&cmd, &repo_view, &context, &client)
.await
.unwrap();
assert!(
outcome.description.contains("# Custom template"),
"fallback description should include repo template: {}",
outcome.description
);
}
/// `CreatePrOutcome`'s derived `Clone` yields a value with identical `Debug` output.
#[test]
fn create_pr_outcome_clone_and_debug() {
let outcome = CreatePrOutcome {
title: "t".to_string(),
description: "d".to_string(),
pr_yaml: "y".to_string(),
};
let cloned = outcome.clone();
assert_eq!(format!("{outcome:?}"), format!("{cloned:?}"));
}
}
/// Unit tests for the pure helper functions used by the PR command.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- parse_bool_string ----------------------------------------------

    #[test]
    fn parse_bool_true_variants() {
        for input in ["true", "1", "yes"] {
            assert_eq!(parse_bool_string(input), Some(true), "input {input:?}");
        }
    }

    #[test]
    fn parse_bool_false_variants() {
        for input in ["false", "0", "no"] {
            assert_eq!(parse_bool_string(input), Some(false), "input {input:?}");
        }
    }

    #[test]
    fn parse_bool_invalid() {
        for input in ["maybe", ""] {
            assert_eq!(parse_bool_string(input), None, "input {input:?}");
        }
    }

    #[test]
    fn parse_bool_case_insensitive() {
        let cases = [("TRUE", true), ("Yes", true), ("FALSE", false), ("No", false)];
        for (input, expected) in cases {
            assert_eq!(parse_bool_string(input), Some(expected), "input {input:?}");
        }
    }

    // ---- is_breaking_change ---------------------------------------------

    #[test]
    fn breaking_change_type_contains() {
        let commit_type = "BREAKING";
        assert!(is_breaking_change(commit_type, "normal message"));
    }

    #[test]
    fn breaking_change_message_contains() {
        let message = "BREAKING CHANGE: removed API";
        assert!(is_breaking_change("feat", message));
    }

    #[test]
    fn breaking_change_none() {
        let flagged = is_breaking_change("feat", "add new feature");
        assert!(!flagged);
    }

    // ---- check_checkbox -------------------------------------------------

    #[test]
    fn check_checkbox_found() {
        let mut desc = String::from("- [ ] New feature\n- [ ] Bug fix");
        check_checkbox(&mut desc, "- [ ] New feature");
        // Only the matching box is ticked; the other stays unchecked.
        assert!(desc.contains("- [x] New feature"));
        assert!(desc.contains("- [ ] Bug fix"));
    }

    #[test]
    fn check_checkbox_not_found() {
        const UNTOUCHED: &str = "- [ ] Bug fix";
        let mut desc = String::from(UNTOUCHED);
        check_checkbox(&mut desc, "- [ ] New feature");
        assert_eq!(desc, UNTOUCHED);
    }

    // ---- format_scopes_section ------------------------------------------

    #[test]
    fn scopes_section_single() {
        let scopes = vec![String::from("cli")];
        let section = format_scopes_section(&scopes);
        assert_eq!(section, "**Affected areas:** cli\n\n");
    }

    #[test]
    fn scopes_section_multiple() {
        let scopes = vec![String::from("cli"), String::from("git")];
        let section = format_scopes_section(&scopes);
        assert!(section.starts_with("**Affected areas:**"));
        for scope in &scopes {
            assert!(section.contains(scope.as_str()), "missing scope {scope:?}");
        }
    }

    #[test]
    fn scopes_section_empty() {
        assert!(format_scopes_section(&[]).is_empty());
    }

    // ---- format_commit_list ---------------------------------------------

    #[test]
    fn commit_list_formatting() {
        let entries = vec![
            ("abc12345", "feat: add feature"),
            ("def67890", "fix: resolve bug"),
        ];
        let rendered = format_commit_list(&entries);
        assert!(rendered.contains("### Commits in this PR:"));
        for expected in [
            "- `abc12345` feat: add feature",
            "- `def67890` fix: resolve bug",
        ] {
            assert!(rendered.contains(expected), "missing line {expected:?}");
        }
    }

    // ---- clean_branch_name ----------------------------------------------

    #[test]
    fn clean_branch_simple() {
        let cleaned = clean_branch_name("feat/add-login");
        assert_eq!(cleaned, "feat add login");
    }

    #[test]
    fn clean_branch_underscores() {
        let cleaned = clean_branch_name("user_name/fix_bug");
        assert_eq!(cleaned, "user name fix bug");
    }

    // ---- extract_first_line ---------------------------------------------

    #[test]
    fn first_line_multiline() {
        let first = extract_first_line("first\nsecond\nthird");
        assert_eq!(first, "first");
    }

    #[test]
    fn first_line_single() {
        let first = extract_first_line("only line");
        assert_eq!(first, "only line");
    }

    #[test]
    fn first_line_empty() {
        let first = extract_first_line("");
        assert_eq!(first, "");
    }

    // ---- format_draft_status --------------------------------------------

    #[test]
    fn draft_status_true() {
        let (icon, label) = format_draft_status(true);
        assert_eq!(label, "draft");
        assert!(!icon.is_empty());
    }

    #[test]
    fn draft_status_false() {
        let (icon, label) = format_draft_status(false);
        assert_eq!(label, "ready for review");
        assert!(!icon.is_empty());
    }

    // ---- determine_push_action ------------------------------------------

    #[test]
    fn push_action_skip_when_no_push() {
        // With the first flag (no-push) set, the result is Skip whatever the
        // second flag says.
        for branch_exists in [false, true] {
            assert_eq!(determine_push_action(true, branch_exists), PushAction::Skip);
        }
    }

    #[test]
    fn push_action_sync_existing_branch() {
        let action = determine_push_action(false, true);
        assert_eq!(action, PushAction::SyncExisting);
    }

    #[test]
    fn push_action_push_new_branch() {
        let action = determine_push_action(false, false);
        assert_eq!(action, PushAction::PushNew);
    }
}