use crate::error::{Result, ToriiError};
use serde::{Deserialize, Serialize};
pub(crate) use super::azure::pr::{parse_azure_url, split_azure_owner};
use super::azure::AzurePrClient;
use super::bitbucket::BitbucketPrClient;
pub use super::gitea::pr::{gitea_base_url, resolve_gitea_token};
use super::gitea::GiteaPrClient;
use super::github::GitHubPrClient;
use super::gitlab::GitLabPrClient;
use super::radicle::RadiclePrClient;
use super::sourcehut::SourcehutPrClient;
/// A pull/merge request as returned by [`PrClient`] methods (`create`, `list`, `get`).
///
/// Derives `Serialize`/`Deserialize`, so the field names below form its serde
/// representation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullRequest {
/// PR number — presumably the value passed as `number` to `PrClient::get`/`merge`/`close`/`update`; confirm per client.
pub number: u64,
pub title: String,
/// Description text; `None` when the PR has no body.
pub body: Option<String>,
/// Platform-reported state. NOTE(review): not normalised here — accepted values (e.g. "open"/"closed"/"merged") depend on the client; confirm.
pub state: String,
/// Source side of the PR — presumably the head branch name; confirm per client.
pub head: String,
/// Target side of the PR — presumably the base branch name.
pub base: String,
pub author: String,
/// URL of the PR. NOTE(review): confirm whether clients store the web (HTML) or API URL.
pub url: String,
pub draft: bool,
/// `None` when mergeability is not reported — TODO confirm semantics per platform.
pub mergeable: Option<bool>,
/// Creation timestamp as a string; the format is not normalised here — confirm per client.
pub created_at: String,
}
/// Parameters for [`PrClient::create`].
#[derive(Debug, Clone)]
pub struct CreatePrOptions {
pub title: String,
/// Optional description text.
pub body: Option<String>,
/// Source branch — presumably a plain branch name; confirm whether any client accepts `owner:branch`.
pub head: String,
/// Branch the PR targets (presumably a branch name).
pub base: String,
/// Presumably requests a draft PR where the platform supports drafts — confirm per client.
pub draft: bool,
}
/// Strategy used by [`PrClient::merge`] to integrate a pull request.
///
/// The `Display` form is the lowercase variant name (`merge`, `squash`,
/// `rebase`), which is the string handed to callers that need a textual label.
#[derive(Debug, Clone)]
pub enum MergeMethod {
    Merge,
    Squash,
    Rebase,
}

impl std::fmt::Display for MergeMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Merge => "merge",
            Self::Squash => "squash",
            Self::Rebase => "rebase",
        };
        f.write_str(label)
    }
}
/// Fields to change via [`PrClient::update`].
///
/// Every field is optional. NOTE(review): presumably `None` means "leave this
/// field unchanged" — confirm against each client's `update` implementation.
///
/// Derives `Debug` and `Clone` like [`CreatePrOptions`], so both option types
/// can be logged and copied the same way. `Default` yields an all-`None` value
/// that callers can fill in with struct-update syntax:
/// `UpdatePrOptions { title: Some(t), ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct UpdatePrOptions {
    /// New title, if it should change.
    pub title: Option<String>,
    /// New description, if it should change.
    pub body: Option<String>,
    /// New target branch, if it should change.
    pub base: Option<String>,
}
/// Platform-agnostic pull-request operations; [`get_pr_client`] returns a boxed
/// implementation for each supported platform. Implementors must be `Send`.
///
/// NOTE(review): `owner`/`repo` presumably come from
/// `detect_platform_from_remote*`, which for Azure puts `org/project` in
/// `owner` and for Radicle puts the repository id in `owner` with an empty
/// `repo` — confirm each client expects that shape.
// NOTE(review): the blanket `dead_code` allow presumably silences trait methods
// that no command calls yet; consider narrowing it once that is confirmed.
#[allow(dead_code)]
pub trait PrClient: Send {
/// Opens a new pull request described by `opts`.
fn create(&self, owner: &str, repo: &str, opts: CreatePrOptions) -> Result<PullRequest>;
/// Lists pull requests; `state` is a filter string whose accepted values are client-specific — confirm.
fn list(&self, owner: &str, repo: &str, state: &str) -> Result<Vec<PullRequest>>;
/// Fetches the pull request with the given `number`.
fn get(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest>;
/// Merges pull request `number` using `method`.
fn merge(&self, owner: &str, repo: &str, number: u64, method: MergeMethod) -> Result<()>;
/// Closes pull request `number` (presumably without merging — confirm per client).
fn close(&self, owner: &str, repo: &str, number: u64) -> Result<()>;
/// Applies the fields set in `opts` to pull request `number`.
fn update(&self, owner: &str, repo: &str, number: u64, opts: UpdatePrOptions) -> Result<()>;
/// Deletes `branch` on the hosting platform — presumably a PR's head branch after merging; confirm.
fn delete_branch(&self, owner: &str, repo: &str, branch: &str) -> Result<()>;
/// Returns the string used to check out `pr` locally — presumably a branch name or refspec; confirm per client.
fn checkout_branch(&self, pr: &PullRequest) -> String;
}
/// Builds the pull-request client for `platform`, matched case-insensitively.
///
/// # Errors
///
/// Propagates any error from the selected client's constructor, and returns
/// `ToriiError::Unsupported` when `platform` names no known platform.
pub fn get_pr_client(platform: &str) -> Result<Box<dyn PrClient>> {
    let key = platform.to_lowercase();
    let client: Box<dyn PrClient> = match key.as_str() {
        "github" => Box::new(GitHubPrClient::new()?),
        "gitlab" => Box::new(GitLabPrClient::new()?),
        "gitea" => Box::new(GiteaPrClient::new()?),
        "sourcehut" => Box::new(SourcehutPrClient::new()?),
        "radicle" => Box::new(RadiclePrClient::new()?),
        "bitbucket" => Box::new(BitbucketPrClient::new()?),
        "azure" => Box::new(AzurePrClient::new()?),
        unknown => {
            return Err(ToriiError::Unsupported(format!(
                "Unsupported platform: {}. Supported: github, gitlab, gitea, sourcehut, radicle, bitbucket, azure",
                unknown
            )));
        }
    };
    Ok(client)
}
/// Detects `(platform, owner, repo)` for the `origin` remote of the repository
/// discovered from `repo_path`.
///
/// Shorthand for [`detect_platform_from_remote_named`] with `"origin"`.
pub fn detect_platform_from_remote(repo_path: &str) -> Option<(String, String, String)> {
    const DEFAULT_REMOTE: &str = "origin";
    detect_platform_from_remote_named(repo_path, DEFAULT_REMOTE)
}
/// Like [`detect_platform_from_remote_named`], but also resolves the API base
/// URL for the detected platform, returning
/// `(platform, owner, repo, api_base_url)`.
///
/// `api_base_url` may be empty for platforms without a built-in default.
pub fn detect_platform_full(
    repo_path: &str,
    remote_name: &str,
) -> Option<(String, String, String, String)> {
    detect_platform_from_remote_named(repo_path, remote_name).map(|(platform, owner, repo)| {
        let api_base_url = resolve_api_base_url(repo_path, remote_name, &platform);
        (platform, owner, repo, api_base_url)
    })
}
/// Returns the REST API base URL for `platform` on remote `remote_name`.
///
/// If the remote's host is found in `crate::platforms_registry`, that entry's
/// `api_base_url` wins. Otherwise a built-in default is used for github,
/// gitlab, gitea (Codeberg) and bitbucket; any other platform yields an empty
/// string.
fn resolve_api_base_url(repo_path: &str, remote_name: &str, platform: &str) -> String {
    let registered = git2::Repository::discover(repo_path)
        .ok()
        .and_then(|repo| {
            let remote = repo.find_remote(remote_name).ok()?;
            let host = extract_host(remote.url()?)?;
            crate::platforms_registry::find_by_host(repo_path, &host)
                .map(|entry| entry.api_base_url)
        });
    if let Some(api_base_url) = registered {
        return api_base_url;
    }

    let fallback = match platform {
        "github" => "https://api.github.com",
        "gitlab" => "https://gitlab.com/api/v4",
        "gitea" => "https://codeberg.org/api/v1",
        "bitbucket" => "https://api.bitbucket.org/2.0",
        _ => "",
    };
    fallback.to_string()
}
/// Splits a git remote URL into `(host, path)`.
///
/// Recognises scheme URLs (`https://[user[:token]@]host[:port]/path`,
/// `ssh://…`, `git://…`, `file://…`) and scp-like URLs (`[user@]host:path`).
/// Userinfo and a `:port` suffix are removed from the host. As in git, the
/// scp-like form only applies when no `/` precedes the first `:`, so local
/// paths such as `/srv/repo.git` or `./a:b` yield `None`.
fn split_remote_url(url: &str) -> Option<(&str, &str)> {
    if let Some((scheme, rest)) = url.split_once("://") {
        // RFC 3986: scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
        let is_scheme = !scheme.is_empty()
            && scheme
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
        if is_scheme {
            // The authority ends at the first '/'. Credentials may contain ':',
            // so drop userinfo (up to the last '@') before stripping the port.
            let (authority, path) = rest.split_once('/').unwrap_or((rest, ""));
            let host_port = authority.rsplit_once('@').map_or(authority, |(_, h)| h);
            return Some((strip_port(host_port), path));
        }
    }
    let (host_part, path) = url.split_once(':')?;
    if host_part.contains('/') {
        return None;
    }
    let host = host_part.rsplit_once('@').map_or(host_part, |(_, h)| h);
    Some((host, path))
}

/// Removes a trailing `:port` from `host[:port]`, keeping bracketed IPv6
/// literals such as `[::1]` intact.
fn strip_port(host_port: &str) -> &str {
    if host_port.starts_with('[') {
        if let Some(end) = host_port.find(']') {
            return &host_port[..=end];
        }
    }
    host_port.split_once(':').map_or(host_port, |(host, _)| host)
}

/// Returns the host of a git remote URL, without userinfo or port.
///
/// `None` when `url` is not a recognisable remote URL or its host is empty
/// (e.g. `file:///…`).
fn extract_host(url: &str) -> Option<String> {
    split_remote_url(url)
        .map(|(host, _)| host)
        .filter(|host| !host.is_empty())
        .map(str::to_string)
}

/// Splits a remote URL's path into `(owner, repo)`: the last segment is the
/// repository and all preceding segments (joined with `/`) are the owner, so
/// nested groups like `group/sub/repo` give owner `group/sub`.
///
/// A trailing `/` and `.git` suffix are ignored. Input that is not a
/// recognisable remote URL is treated as a bare path. Returns `None` when
/// fewer than two path segments remain.
fn extract_owner_repo(url: &str) -> Option<(String, String)> {
    let path = split_remote_url(url).map_or(url, |(_, path)| path);
    let path = path.trim_end_matches('/');
    let path = path.strip_suffix(".git").unwrap_or(path);

    let mut segments: Vec<&str> = path.split('/').filter(|s| !s.is_empty()).collect();
    let repo = segments.pop()?;
    if segments.is_empty() {
        return None;
    }
    Some((segments.join("/"), repo.to_string()))
}
/// Detects the hosting platform of remote `remote_name` in the repository
/// discovered from `repo_path`, and extracts owner/repository from its URL.
///
/// Returns `Some((platform, owner, repo))`, where `platform` is one of
/// `github`, `gitlab`, `gitea`, `sourcehut`, `radicle`, `bitbucket`, `azure`.
/// Two platforms fill the tuple differently:
/// - `radicle`: `(platform, rid, "")` — the repository id, empty repo.
/// - `azure`: `(platform, "org/project", repo)`.
///
/// For the well-known hosts the *first* path segment is the owner and the
/// rest is the repo (GitLab `group/sub/repo` → owner `group`, repo `sub/repo`).
/// Other hosts are looked up in `crate::platforms_registry` and split with
/// `extract_owner_repo`.
///
/// Returns `None` when the repository, remote or URL cannot be read, or the
/// platform/path cannot be determined.
pub fn detect_platform_from_remote_named(
    repo_path: &str,
    remote_name: &str,
) -> Option<(String, String, String)> {
    let repo = git2::Repository::discover(repo_path).ok()?;
    let remote = repo.find_remote(remote_name).ok()?;
    let url = remote.url()?.to_string();

    // Substring matching also catches SSH config host aliases such as
    // `git@github.com-work:owner/repo`, at the cost of also matching a domain
    // that only appears in the path.
    let platform = if url.contains("github.com") {
        "github"
    } else if url.contains("gitlab.com") {
        "gitlab"
    } else if url.contains("codeberg.org") {
        "gitea"
    } else if url.contains("git.sr.ht") {
        "sourcehut"
    } else if url.starts_with("rad://") || url.starts_with("rad@") {
        "radicle"
    } else if url.contains("bitbucket.org") {
        "bitbucket"
    } else if url.contains("dev.azure.com") || url.contains(".visualstudio.com") {
        "azure"
    } else {
        return detect_registered_platform(repo_path, &url);
    };

    match platform {
        "radicle" => {
            let rid = url
                .trim_start_matches("rad://")
                .trim_start_matches("rad@")
                .rsplit('/')
                .next()?
                .trim_end_matches(".git")
                .to_string();
            Some((platform.to_string(), rid, String::new()))
        }
        "azure" => {
            let (org, project, repo_name) = parse_azure_url(&url)?;
            Some((platform.to_string(), format!("{}/{}", org, project), repo_name))
        }
        _ => {
            let (owner, repo_name) = split_forge_owner_repo(&url)?;
            Some((platform.to_string(), owner, repo_name))
        }
    }
}

/// Maps a remote whose host is registered in `crate::platforms_registry` to
/// `(platform, owner, repo)`. Returns `None` if the host is unregistered or its
/// kind has no PR client.
fn detect_registered_platform(repo_path: &str, url: &str) -> Option<(String, String, String)> {
    let host = extract_host(url)?;
    let entry = crate::platforms_registry::find_by_host(repo_path, &host)?;
    let platform = match entry.kind.as_str() {
        "gitlab" => "gitlab",
        "github" | "github_enterprise" => "github",
        "gitea" | "forgejo" | "codeberg" => "gitea",
        "bitbucket" | "bitbucket_data_center" => "bitbucket",
        _ => return None,
    };
    let (owner, repo) = extract_owner_repo(url)?;
    Some((platform.to_string(), owner, repo))
}

/// Splits a well-known forge remote URL into `(owner, repo)`, taking the first
/// path segment as the owner and everything after it as the repository.
///
/// Handles `https://[user[:token]@]host[:port]/owner/repo`,
/// `ssh://[user@]host[:port]/owner/repo` and scp-like
/// `[user@]host:owner/repo`; surrounding `/` and a `.git` suffix are dropped.
fn split_forge_owner_repo(url: &str) -> Option<(String, String)> {
    let path = match url.split_once("://") {
        // The path starts after the authority, so credentials or a port in
        // `user:token@host:port` never leak into owner/repo.
        Some((_, rest)) => rest.split_once('/')?.1,
        None => url.split_once(':')?.1,
    };
    let path = path.trim_matches('/');
    let path = path.strip_suffix(".git").unwrap_or(path);
    let (owner, repo) = path.split_once('/')?;
    if owner.is_empty() || repo.is_empty() {
        return None;
    }
    Some((owner.to_string(), repo.to_string()))
}