use anyhow::{Context, Result};
use stkd_core::Repository;
use stkd_provider_api::{Provider, RepoId, RepositoryProvider};
/// The git hosting services this tool can talk to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProviderType {
    /// GitHub.
    GitHub,
    /// GitLab, either gitlab.com or a self-hosted instance.
    GitLab,
}

impl std::fmt::Display for ProviderType {
    /// Writes the provider's human-readable name (`GitHub` or `GitLab`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::GitHub => "GitHub",
            Self::GitLab => "GitLab",
        };
        f.write_str(name)
    }
}
/// An authenticated connection to the hosting provider of a repository.
///
/// Bundles the repository identifier parsed from the git remote, which provider
/// hosts it, and a client for that provider.
pub struct ProviderContext {
    /// Repository identifier parsed from the remote URL.
    pub repo_id: RepoId,
    /// Which hosting service the remote points at.
    pub provider_type: ProviderType,
    /// Authenticated client, exposed read-only via [`ProviderContext::provider`].
    provider: Box<dyn Provider>,
}

impl ProviderContext {
    /// Builds a context from the repository's git remote.
    ///
    /// Steps:
    /// 1. Pick a remote URL (`origin` first, then any other remote).
    /// 2. Detect the provider and host from that URL.
    /// 3. Load the stored credentials for that provider (per host for GitLab).
    /// 4. Construct the client.
    /// 5. Parse the repository identifier from the remote URL.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - no remote is configured, or the provider cannot be detected;
    /// - loading credentials fails, or none are stored;
    /// - the client cannot be constructed;
    /// - the provider cannot parse a repository from the remote URL.
    pub async fn from_repo(repo: &Repository) -> Result<Self> {
        let remote_url = get_remote_url(repo)?;
        let (provider_type, host) = detect_provider(&remote_url)?;

        let (provider, repo_id): (Box<dyn Provider>, RepoId) = match provider_type {
            ProviderType::GitHub => {
                let auth_token = stkd_github::auth::load_credentials()
                    .context("Failed to load GitHub credentials")?
                    .ok_or_else(|| {
                        anyhow::anyhow!(
                            "Not authenticated with GitHub. Run 'gt auth --github' first."
                        )
                    })?;
                let provider = stkd_github::GitHubProvider::with_auth(auth_token.to_auth())
                    .context("Failed to create GitHub client")?;
                let repo_id = provider.parse_remote_url(&remote_url).ok_or_else(|| {
                    anyhow::anyhow!("Could not parse GitHub repository from remote URL")
                })?;
                (Box::new(provider), repo_id)
            }
            ProviderType::GitLab => {
                // Credentials are stored per host, so self-hosted instances need their own.
                let auth_token = stkd_gitlab::auth::load_credentials(&host)
                    .context("Failed to load GitLab credentials")?
                    .ok_or_else(|| {
                        anyhow::anyhow!(
                            "Not authenticated with GitLab. Run 'gt auth --gitlab' first."
                        )
                    })?;
                // Host names are case-insensitive, so compare accordingly.
                let provider = if host.eq_ignore_ascii_case("gitlab.com") {
                    stkd_gitlab::GitLabProvider::new(auth_token.token)
                        .context("Failed to create GitLab client")?
                } else {
                    stkd_gitlab::GitLabProvider::with_host(auth_token.token, &host)
                        .context("Failed to create GitLab client")?
                };
                let repo_id = provider.parse_remote_url(&remote_url).ok_or_else(|| {
                    anyhow::anyhow!("Could not parse GitLab repository from remote URL")
                })?;
                (Box::new(provider), repo_id)
            }
        };

        Ok(Self {
            repo_id,
            provider_type,
            provider,
        })
    }

    /// Returns the authenticated provider client.
    pub fn provider(&self) -> &dyn Provider {
        self.provider.as_ref()
    }

    /// Returns the repository's full name, as produced by `RepoId::full_name`.
    pub fn full_name(&self) -> String {
        self.repo_id.full_name()
    }
}
/// Returns the URL of the repository's git remote.
///
/// `origin` is preferred. Otherwise the first remote (in the order returned by
/// `remotes()`) that has a URL is used; remote names that `iter()` yields as
/// `None` are skipped.
///
/// # Errors
///
/// Fails if the remotes cannot be listed or no remote has a URL.
fn get_remote_url(repo: &Repository) -> Result<String> {
    let git = repo.git();
    // URL of the named remote, or `None` if it cannot be found or has no URL.
    let url_of = |name: &str| git.find_remote(name).ok()?.url().map(String::from);

    if let Some(url) = url_of("origin") {
        return Ok(url);
    }

    let names = git.remotes().context("Failed to list remotes")?;
    names.iter().flatten().find_map(url_of).ok_or_else(|| {
        anyhow::anyhow!("No git remote found. Add a remote with 'git remote add origin <url>'")
    })
}
/// Classifies a git remote URL as GitHub or GitLab and returns the remote's host.
///
/// Two URL forms are accepted:
/// - standard URL syntax (`https://host/path`, `ssh://user@host:port/path`);
/// - git's scp-like syntax (`[user@]host:path`).
///
/// Only the host is examined, never the path, so a path such as
/// `/mirrors/github.com/repo.git` cannot cause a misdetection. A host containing
/// `gitlab` (case-insensitive) is GitLab. Every other host falls back to GitHub,
/// which covers GitHub Enterprise hosts and SSH host aliases, but also applies to
/// hosts that are neither provider.
///
/// The returned host is lowercased. Host names are case-insensitive, and this
/// matches what the `url` crate already does for URL-syntax remotes.
///
/// # Errors
///
/// Returns an error if no host can be extracted from `url`.
fn detect_provider(url: &str) -> Result<(ProviderType, String)> {
    let host = remote_host(url)
        .ok_or_else(|| {
            anyhow::anyhow!(
                "Could not detect provider from remote URL: {}\n\
                 Supported providers: GitHub, GitLab",
                url
            )
        })?
        .to_lowercase();

    let provider_type = if host.contains("gitlab") {
        ProviderType::GitLab
    } else {
        ProviderType::GitHub
    };
    Ok((provider_type, host))
}

/// Extracts the host component of a git remote URL, or `None` if there is none.
fn remote_host(url: &str) -> Option<String> {
    if url.contains("://") {
        // Standard URL syntax; the `url` crate handles userinfo, ports and IPv6 brackets.
        return url::Url::parse(url)
            .ok()?
            .host_str()
            .filter(|host| !host.is_empty())
            .map(str::to_string);
    }

    // git's scp-like syntax `[user@]host:path` applies only when the colon comes
    // before any slash; otherwise the string is a local path.
    let (authority, _path) = url.split_once(':')?;
    if authority.contains('/') {
        return None;
    }
    let host = authority
        .rsplit_once('@')
        .map_or(authority, |(_user, host)| host);
    // A single character before the colon is a Windows drive letter (`C:\repo`), not a host.
    if host.chars().count() > 1 {
        Some(host.to_string())
    } else {
        None
    }
}

#[cfg(test)]
mod detect_provider_tests {
    use super::*;

    #[test]
    fn path_segments_do_not_affect_detection() {
        let (pt, host) =
            detect_provider("https://gitlab.mycompany.com/mirrors/github.com/repo.git").unwrap();
        assert_eq!(pt, ProviderType::GitLab);
        assert_eq!(host, "gitlab.mycompany.com");

        let (pt, host) = detect_provider("https://git.example.org/gitlab.com/repo.git").unwrap();
        assert_eq!(pt, ProviderType::GitHub);
        assert_eq!(host, "git.example.org");
    }

    #[test]
    fn scp_like_urls_with_any_user() {
        let (pt, host) = detect_provider("deploy@gitlab.example.org:group/project.git").unwrap();
        assert_eq!(pt, ProviderType::GitLab);
        assert_eq!(host, "gitlab.example.org");
    }

    #[test]
    fn scp_host_is_lowercased() {
        let (pt, host) = detect_provider("git@GitLab.com:group/project.git").unwrap();
        assert_eq!(pt, ProviderType::GitLab);
        assert_eq!(host, "gitlab.com");
    }

    #[test]
    fn windows_paths_are_rejected() {
        assert!(detect_provider(r"C:\repos\project").is_err());
    }
}
/// Reports whether any provider credentials are stored.
///
/// GitHub credentials are checked first. GitLab credentials are consulted only
/// when GitHub has none, and only for `gitlab.com`; self-hosted GitLab hosts are
/// not checked. Errors while loading are treated as "no credentials".
// `dead_code` is allowed because this helper may currently have no callers.
#[allow(dead_code)]
pub fn has_any_credentials() -> bool {
    let github_stored = || matches!(stkd_github::auth::load_credentials(), Ok(Some(_)));
    let gitlab_stored =
        || matches!(stkd_gitlab::auth::load_credentials("gitlab.com"), Ok(Some(_)));
    github_stored() || gitlab_stored()
}
/// Detects which provider hosts `repo`, based on its git remote URL.
///
/// # Errors
///
/// Fails when the repository has no usable remote or the provider cannot be
/// detected from the remote URL.
pub fn detect_provider_type(repo: &Repository) -> Result<ProviderType> {
    let url = get_remote_url(repo)?;
    detect_provider(&url).map(|(kind, _host)| kind)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `url` is classified as `expected_type` with host `expected_host`.
    fn assert_detected(url: &str, expected_type: ProviderType, expected_host: &str) {
        let (provider_type, host) = detect_provider(url)
            .unwrap_or_else(|err| panic!("detect_provider({:?}) failed: {}", url, err));
        assert_eq!(provider_type, expected_type, "provider type for {:?}", url);
        assert_eq!(host, expected_host, "host for {:?}", url);
    }

    #[test]
    fn test_detect_github() {
        assert_detected(
            "https://github.com/owner/repo.git",
            ProviderType::GitHub,
            "github.com",
        );
        assert_detected(
            "git@github.com:owner/repo.git",
            ProviderType::GitHub,
            "github.com",
        );
    }

    #[test]
    fn test_detect_gitlab() {
        assert_detected(
            "https://gitlab.com/group/project.git",
            ProviderType::GitLab,
            "gitlab.com",
        );
        assert_detected(
            "git@gitlab.com:group/project.git",
            ProviderType::GitLab,
            "gitlab.com",
        );
    }

    #[test]
    fn test_detect_self_hosted_gitlab() {
        assert_detected(
            "https://gitlab.mycompany.com/group/project.git",
            ProviderType::GitLab,
            "gitlab.mycompany.com",
        );
        assert_detected(
            "git@gitlab.mycompany.com:group/project.git",
            ProviderType::GitLab,
            "gitlab.mycompany.com",
        );
        // The port is not part of the returned host.
        assert_detected(
            "https://gitlab.mycompany.com:8443/group/project.git",
            ProviderType::GitLab,
            "gitlab.mycompany.com",
        );
    }

    #[test]
    fn test_detect_ssh_scheme_urls() {
        assert_detected(
            "ssh://git@github.com/owner/repo.git",
            ProviderType::GitHub,
            "github.com",
        );
        assert_detected(
            "ssh://git@gitlab.com:2222/group/project.git",
            ProviderType::GitLab,
            "gitlab.com",
        );
    }

    #[test]
    fn test_unrecognized_hosts_fall_back_to_github() {
        assert_detected(
            "https://github.mycompany.com/owner/repo.git",
            ProviderType::GitHub,
            "github.mycompany.com",
        );
        // SSH config host alias.
        assert_detected(
            "git@github-work:owner/repo.git",
            ProviderType::GitHub,
            "github-work",
        );
    }

    #[test]
    fn test_rejects_unparseable_url() {
        assert!(detect_provider("not a url").is_err());
    }
}