use crate::plugins::error::PluginError;
/// Returns `true` when `s` is spelled like a filesystem path.
///
/// Accepted spellings are the bare `.` and `..` entries, anything beginning
/// with `./` or `../`, and anything beginning with `/`.
pub fn is_local_path(s: &str) -> bool {
    if matches!(s, "." | "..") {
        return true;
    }
    ["./", "../", "/"].iter().any(|prefix| s.starts_with(*prefix))
}
/// Returns `true` when `s` begins with one of the prefixes this module treats
/// as a git remote: `git+`, `git@`, `ssh://`, `https://` or `http://`.
pub fn is_git_url(s: &str) -> bool {
    const GIT_URL_PREFIXES: [&str; 5] = ["git+", "git@", "ssh://", "https://", "http://"];
    GIT_URL_PREFIXES.iter().any(|prefix| s.starts_with(*prefix))
}
/// Rewrites a git remote into the `git+<scheme>://...` spelling.
///
/// * `git@host:path` shorthand becomes `git+ssh://git@host/path` (only the
///   first `:` is turned into `/`).
/// * Input already starting with `git+` is returned unchanged.
/// * Anything else simply gets `git+` prepended.
fn normalize_git_url(url: &str) -> String {
    if let Some(scp_target) = url.strip_prefix("git@") {
        let path_form = scp_target.replacen(':', "/", 1);
        format!("git+ssh://git@{path_form}")
    } else if url.starts_with("git+") {
        url.to_string()
    } else {
        format!("git+{url}")
    }
}
/// Removes a trailing `@<ref>` (branch, tag or commit) from a package string.
///
/// Returns a sub-slice of `url`; nothing is allocated.
///
/// * `git@host:path[@ref]` shorthand: the leading `git@` is skipped and the
///   last remaining `@`, if any, starts the ref.
/// * `scheme://...` and `git+scheme://...` URLs: only an `@` located in the
///   path (after the first `/` that follows `://`) starts a ref. An `@` in the
///   authority, as in `ssh://git@host/...`, is userinfo and is never stripped.
/// * `git+...` input without `://` is returned unchanged.
/// * Anything else is cut at the first `@`.
fn strip_git_ref(url: &str) -> &str {
    if let Some(rest) = url.strip_prefix("git@") {
        return match rest.rfind('@') {
            Some(pos) if pos > 0 => &url[.."git@".len() + pos],
            _ => url,
        };
    }

    // `git+` contains no "://", so searching the whole string finds the same
    // scheme separator for both prefixed and unprefixed URLs.
    if let Some(scheme_end) = url.find("://") {
        let authority_start = scheme_end + "://".len();
        if let Some(slash) = url[authority_start..].find('/') {
            let path_start = authority_start + slash;
            // Restricting the search to the path keeps `user@host` intact when
            // the URL carries no ref at all.
            if let Some(at) = url[path_start..].rfind('@') {
                return &url[..path_start + at];
            }
        }
        return url;
    }

    if url.starts_with("git+") {
        return url;
    }

    url.split('@').next().unwrap_or(url)
}
/// Derives the package name from a package specifier.
///
/// Any trailing `@<ref>` is removed first via `strip_git_ref`. Then:
///
/// * Git remotes (per `is_git_url`): the name is the last `/`-separated
///   segment with `.git` trimmed from its end. `git@host:org/repo` shorthand
///   has its first `:` turned into `/` before splitting. Trailing slashes are
///   ignored, so `https://host/org/repo/` yields `repo`.
/// * Anything containing `/` or `\`, or recognized by `is_local_path`: the name
///   is read from the `pyproject.toml` inside that directory.
/// * Otherwise the specifier itself is the name.
///
/// # Errors
///
/// Returns `PluginError::PackageSpec` when a local directory has no readable
/// name in its `pyproject.toml`, or when a git URL leaves no name to extract.
pub fn extract_package_name(package: &str) -> Result<String, PluginError> {
    let extraction_failed = || {
        PluginError::PackageSpec(format!("Failed to extract package name from {}", package))
    };

    let pkg = strip_git_ref(package);
    if is_git_url(pkg) {
        let normalized = match pkg.strip_prefix("git@") {
            Some(rest) => rest.replacen(':', "/", 1),
            None => pkg.to_string(),
        };
        // Drop trailing slashes first; otherwise the last segment is empty and
        // the caller would receive `Ok("")`.
        let name = normalized
            .trim_end_matches('/')
            .rsplit('/')
            .next()
            .unwrap_or(pkg)
            .trim_end_matches(".git");
        if name.is_empty() {
            return Err(extraction_failed());
        }
        Ok(name.to_string())
    } else if pkg.contains('/') || pkg.contains('\\') || is_local_path(pkg) {
        extract_name_from_pyproject(pkg).ok_or_else(extraction_failed)
    } else {
        Ok(pkg.to_string())
    }
}
/// Builds the install specifier for `package`, applying optional git flags.
///
/// A leading `~` is expanded first. The result then depends on the kind of
/// input:
///
/// * Local path: returned as-is; any git flag is an error.
/// * Git remote: normalized to the `git+` form and suffixed with the chosen ref.
/// * Input containing `/` (and no `\`) together with at least one git flag:
///   expanded to `git+https://<host>/<package>`, where the host defaults to
///   `github.com`, and suffixed with the chosen ref.
/// * Anything else: returned unchanged; any git flag is an error.
///
/// # Errors
///
/// Returns `PluginError::PackageSpec` when git flags (`host`, `branch`, `tag`,
/// `commit`) are combined with a local path or with a plain package name.
pub fn build_package_spec(
    package: &str,
    host: Option<String>,
    branch: Option<String>,
    tag: Option<String>,
    commit: Option<String>,
) -> Result<String, PluginError> {
    let has_git_flags = host.is_some() || branch.is_some() || tag.is_some() || commit.is_some();
    let expanded = expand_tilde(package);
    let spec = expanded.as_str();

    if is_local_path(spec) {
        return if has_git_flags {
            Err(PluginError::PackageSpec(
                "Cannot use git flags with local paths".to_string(),
            ))
        } else {
            Ok(spec.to_string())
        };
    }

    if is_git_url(spec) {
        let normalized = normalize_git_url(spec);
        return Ok(add_git_ref(&normalized, branch, tag, commit));
    }

    // Without flags there is nothing to attach, so the specifier passes through.
    if !has_git_flags {
        return Ok(spec.to_string());
    }

    let looks_like_org_repo = spec.contains('/') && !spec.contains('\\');
    if !looks_like_org_repo {
        return Err(PluginError::PackageSpec(
            "Cannot use git flags with PyPI package name".to_string(),
        ));
    }

    let git_host = host.as_deref().unwrap_or("github.com");
    let remote = format!("git+https://{git_host}/{spec}");
    Ok(add_git_ref(&remote, branch, tag, commit))
}
/// Appends `@<ref>` to `url`, picking the first value present in the order
/// `branch`, `tag`, `commit`. Returns `url` unchanged when all three are `None`.
fn add_git_ref(
    url: &str,
    branch: Option<String>,
    tag: Option<String>,
    commit: Option<String>,
) -> String {
    match branch.or(tag).or(commit) {
        Some(git_ref) => format!("{url}@{git_ref}"),
        None => url.to_string(),
    }
}
/// Replaces a leading `~` with the home directory reported by `dirs::home_dir`.
///
/// Only a bare `~` and the `~/...` form are expanded. Any other input,
/// including `~name/...`, is returned unchanged, as is every input when no
/// home directory is available.
fn expand_tilde(path: &str) -> String {
    let after_tilde = match path.strip_prefix('~') {
        Some(rest) if rest.is_empty() || rest.starts_with('/') => rest,
        _ => return path.to_string(),
    };
    dirs::home_dir().map_or_else(
        || path.to_string(),
        |home| format!("{}{}", home.to_string_lossy(), after_tilde),
    )
}
/// Reads `<path>/pyproject.toml` and returns the first quoted value assigned
/// to a key that is exactly `name`.
///
/// This is a line-based scan, not a TOML parser. Each line is split at its
/// first `=`, and the left side must be `name` after trimming, so keys such as
/// `namespace` are skipped. The value must open with `"` or `'`, and the name
/// is the text up to the next occurrence of that same quote. Anything after
/// the closing quote, such as a trailing comment, is ignored.
///
/// Returns `None` when the file is missing or unreadable, or when no line
/// matches.
fn extract_name_from_pyproject(path: &str) -> Option<String> {
    use std::fs;
    use std::path::Path;

    let pyproject_path = Path::new(path).join("pyproject.toml");
    let content = fs::read_to_string(&pyproject_path).ok()?;

    content.lines().find_map(|line| {
        let (key, value) = line.split_once('=')?;
        if key.trim() != "name" {
            return None;
        }
        let value = value.trim_start();
        let quote = value.chars().next().filter(|c| matches!(*c, '"' | '\''))?;
        let body = &value[quote.len_utf8()..];
        // Match the closing quote to the opening one so quotes inside a
        // trailing comment cannot extend or replace the value.
        let end = body.find(quote)?;
        Some(body[..end].to_string())
    })
}
/// Unit tests for git-URL detection and normalization, package-name
/// extraction, spec building and tilde expansion.
#[cfg(test)]
mod tests {
    use crate::plugins::package_spec::*;

    // --- is_git_url ---------------------------------------------------------

    #[test]
    fn test_is_git_url_ssh_shorthand() {
        assert!(is_git_url("git@github.com:org/repo.git"));
    }

    #[test]
    fn test_is_git_url_ssh_full() {
        assert!(is_git_url("ssh://git@github.com/org/repo.git"));
    }

    #[test]
    fn test_is_git_url_https() {
        assert!(is_git_url("https://github.com/org/repo.git"));
    }

    #[test]
    fn test_is_git_url_https_no_git_suffix() {
        assert!(is_git_url("https://github.com/org/repo"));
    }

    #[test]
    fn test_is_git_url_git_plus() {
        assert!(is_git_url("git+https://github.com/org/repo.git"));
    }

    #[test]
    fn test_is_git_url_pypi_is_not() {
        assert!(!is_git_url("r2x-reeds"));
    }

    #[test]
    fn test_is_git_url_local_path_is_not() {
        assert!(!is_git_url("./packages/r2x-reeds"));
    }

    // --- normalize_git_url --------------------------------------------------

    #[test]
    fn test_normalize_ssh_shorthand() {
        assert_eq!(
            normalize_git_url("git@github.com:NatLabRockies/R2X.git"),
            "git+ssh://git@github.com/NatLabRockies/R2X.git"
        );
    }

    #[test]
    fn test_normalize_ssh_full() {
        assert_eq!(
            normalize_git_url("ssh://git@github.com/org/repo.git"),
            "git+ssh://git@github.com/org/repo.git"
        );
    }

    #[test]
    fn test_normalize_https() {
        assert_eq!(
            normalize_git_url("https://github.com/org/repo.git"),
            "git+https://github.com/org/repo.git"
        );
    }

    #[test]
    fn test_normalize_https_no_git_suffix() {
        assert_eq!(
            normalize_git_url("https://github.com/org/repo"),
            "git+https://github.com/org/repo"
        );
    }

    #[test]
    fn test_normalize_already_prefixed() {
        let url = "git+https://github.com/org/repo.git";
        assert_eq!(normalize_git_url(url), url);
    }

    // --- extract_package_name -----------------------------------------------

    #[test]
    fn test_extract_name_pypi() {
        assert!(extract_package_name("r2x-reeds").is_ok_and(|s| s == "r2x-reeds"));
    }

    #[test]
    fn test_extract_name_git_plus_https() {
        assert!(
            extract_package_name("git+https://github.com/nrel/r2x-reeds@main")
                .is_ok_and(|s| s == "r2x-reeds")
        );
    }

    #[test]
    fn test_extract_name_https() {
        assert!(
            extract_package_name("https://github.com/NatLabRockies/R2X.git")
                .is_ok_and(|s| s == "R2X")
        );
    }

    #[test]
    fn test_extract_name_https_with_ref() {
        assert!(
            extract_package_name("https://github.com/NatLabRockies/R2X.git@v2.0.0")
                .is_ok_and(|s| s == "R2X")
        );
    }

    #[test]
    fn test_extract_name_ssh_shorthand() {
        assert!(
            extract_package_name("git@github.com:NatLabRockies/R2X.git").is_ok_and(|s| s == "R2X")
        );
    }

    #[test]
    fn test_extract_name_ssh_shorthand_with_ref() {
        assert!(
            extract_package_name("git@github.com:NatLabRockies/R2X.git@main")
                .is_ok_and(|s| s == "R2X")
        );
    }

    #[test]
    fn test_extract_name_ssh_full() {
        assert!(extract_package_name("ssh://git@github.com/org/R2X.git").is_ok_and(|s| s == "R2X"));
    }

    /// A local directory resolves to the name declared in its `pyproject.toml`.
    /// A throwaway project is created under the system temp directory so the
    /// outcome does not depend on the working directory.
    #[test]
    fn test_extract_name_local_path() {
        let dir = std::env::temp_dir().join(format!(
            "r2x_package_spec_local_path_{}",
            std::process::id()
        ));
        assert!(std::fs::create_dir_all(&dir).is_ok());
        assert!(
            std::fs::write(
                dir.join("pyproject.toml"),
                "[project]\nname = \"r2x-reeds\"\n"
            )
            .is_ok()
        );
        let result = extract_package_name(&dir.to_string_lossy());
        // Best-effort cleanup; a leftover temp directory must not fail the test.
        std::fs::remove_dir_all(&dir).ok();
        assert!(result.is_ok_and(|s| s == "r2x-reeds"));
    }

    /// A local path without a `pyproject.toml` is reported as an error.
    #[test]
    fn test_extract_name_local_path_missing_project() {
        assert!(extract_package_name("./__r2x_missing_plugin_dir__/nested").is_err());
    }

    // --- build_package_spec -------------------------------------------------

    #[test]
    fn test_spec_pypi() {
        let result = build_package_spec("r2x-reeds", None, None, None, None);
        assert!(result.is_ok_and(|s| s == "r2x-reeds"));
    }

    #[test]
    fn test_spec_local_path() {
        let result = build_package_spec("./packages/r2x-reeds", None, None, None, None);
        assert!(result.is_ok_and(|s| s == "./packages/r2x-reeds"));
    }

    #[test]
    fn test_spec_dot() {
        assert!(build_package_spec(".", None, None, None, None).is_ok_and(|s| s == "."));
    }

    #[test]
    fn test_spec_dotdot() {
        assert!(build_package_spec("..", None, None, None, None).is_ok_and(|s| s == ".."));
    }

    #[test]
    fn test_spec_dot_rejects_git_flags() {
        assert!(build_package_spec(".", None, Some("main".to_string()), None, None).is_err());
    }

    #[test]
    fn test_spec_org_repo_with_branch() {
        let result = build_package_spec(
            "nrel/r2x-reeds",
            None,
            Some("develop".to_string()),
            None,
            None,
        );
        assert!(result.is_ok_and(|s| s == "git+https://github.com/nrel/r2x-reeds@develop"));
    }

    #[test]
    fn test_spec_rejects_git_flags_with_pypi() {
        assert!(
            build_package_spec("r2x-reeds", None, Some("main".to_string()), None, None).is_err()
        );
    }

    #[test]
    fn test_spec_ssh_shorthand_with_branch() {
        let result = build_package_spec(
            "git@github.com:NatLabRockies/R2X.git",
            None,
            Some("v2.0.0".to_string()),
            None,
            None,
        );
        assert!(result.is_ok_and(|s| s == "git+ssh://git@github.com/NatLabRockies/R2X.git@v2.0.0"));
    }

    #[test]
    fn test_spec_ssh_shorthand_no_branch() {
        let result = build_package_spec(
            "git@github.com:NatLabRockies/R2X.git",
            None,
            None,
            None,
            None,
        );
        assert!(result.is_ok_and(|s| s == "git+ssh://git@github.com/NatLabRockies/R2X.git"));
    }

    #[test]
    fn test_spec_https_with_branch() {
        let result = build_package_spec(
            "https://github.com/NatLabRockies/R2X.git",
            None,
            Some("v2.0.0".to_string()),
            None,
            None,
        );
        assert!(result.is_ok_and(|s| s == "git+https://github.com/NatLabRockies/R2X.git@v2.0.0"));
    }

    #[test]
    fn test_spec_https_no_git_suffix() {
        let result = build_package_spec(
            "https://github.com/NatLabRockies/R2X",
            None,
            Some("v2.0.0".to_string()),
            None,
            None,
        );
        assert!(result.is_ok_and(|s| s == "git+https://github.com/NatLabRockies/R2X@v2.0.0"));
    }

    #[test]
    fn test_spec_ssh_full_with_branch() {
        let result = build_package_spec(
            "ssh://git@github.com/org/repo.git",
            None,
            Some("main".to_string()),
            None,
            None,
        );
        assert!(result.is_ok_and(|s| s == "git+ssh://git@github.com/org/repo.git@main"));
    }

    #[test]
    fn test_spec_git_plus_passthrough() {
        let result = build_package_spec(
            "git+https://github.com/org/repo.git",
            None,
            Some("main".to_string()),
            None,
            None,
        );
        assert!(result.is_ok_and(|s| s == "git+https://github.com/org/repo.git@main"));
    }

    // --- expand_tilde -------------------------------------------------------

    #[test]
    fn test_expand_tilde_with_slash() {
        let expanded = expand_tilde("~/dev/r2x-reeds");
        assert!(expanded.contains("dev/r2x-reeds") || expanded.contains("dev\\r2x-reeds"));
        assert!(!expanded.starts_with('~'));
    }

    #[test]
    fn test_expand_tilde_home_only() {
        let expanded = expand_tilde("~");
        assert!(!expanded.starts_with('~'));
        assert!(!expanded.is_empty());
    }

    #[test]
    fn test_expand_tilde_non_tilde_path() {
        let path = "/absolute/path";
        assert_eq!(expand_tilde(path), path);
    }

    #[test]
    fn test_spec_with_tilde_path() {
        let result = build_package_spec("~/some/local/path", None, None, None, None);
        assert!(result.is_ok_and(|s| !s.starts_with('~')));
    }
}