use crate::config::Config;
use anyhow::{anyhow, Context, Result};
use directories::ProjectDirs;
use git2::{build::RepoBuilder, Cred, FetchOptions, RemoteCallbacks, Repository, ResetType};
use hex;
use once_cell::sync::Lazy;
use regex::Regex;
use reqwest::blocking::Client;
use reqwest::header::{ACCEPT, AUTHORIZATION, USER_AGENT};
use serde::Deserialize;
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::collections::VecDeque;
use std::env;
use std::{
fs,
path::{Path, PathBuf},
};
use tempfile::Builder as TempDirBuilder;
/// Components extracted from a GitHub folder URL
/// (e.g. `https://github.com/user/repo/tree/branch/sub/dir`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedGitUrl {
    /// HTTPS clone URL ending in `.git`.
    pub clone_url: String,
    /// Branch/ref name; the sentinel `"HEAD"` means "use the default branch".
    pub branch: String,
    /// Subdirectory within the repository (`""` for the repository root).
    pub subdirectory: String,
}
/// One entry from a GitHub "repository contents" API response.
#[derive(Deserialize, Debug)]
struct ContentItem {
    // Repository-relative path of the entry.
    path: String,
    // Entry kind, e.g. "file" or "dir"; other kinds are ignored by callers.
    #[serde(rename = "type")]
    item_type: String,
    // Direct raw-content URL; may be absent, in which case the file is skipped.
    download_url: Option<String>,
}
/// Subset of the GitHub "get repository" response used to discover the
/// default branch.
#[derive(Deserialize, Debug)]
struct RepoInfo {
    default_branch: String,
}
/// Resolve the base directory under which cloned repositories are cached.
///
/// The `DIRCAT_TEST_CACHE_DIR` environment variable (used by the test
/// suite) takes precedence; otherwise the platform-specific project cache
/// directory is used, with a `repos` subdirectory appended.
fn get_base_cache_dir() -> Result<PathBuf> {
    match env::var("DIRCAT_TEST_CACHE_DIR") {
        Ok(overridden) => Ok(PathBuf::from(overridden)),
        Err(_) => {
            let dirs = ProjectDirs::from("com", "romelium", "dircat")
                .context("Could not determine project cache directory")?;
            Ok(dirs.cache_dir().join("repos"))
        }
    }
}
/// Map a repository URL to its deterministic cache subdirectory.
///
/// The directory name is the hex-encoded SHA-256 of the URL, so the same
/// URL always maps to the same path and distinct URLs do not collide.
fn get_repo_cache_path(base_cache_dir: &Path, url: &str) -> PathBuf {
    let digest = Sha256::digest(url.as_bytes());
    base_cache_dir.join(hex::encode(digest))
}
/// Build libgit2 remote callbacks: SSH credential resolution plus an
/// indicatif progress bar driven by transfer statistics.
fn create_remote_callbacks() -> RemoteCallbacks<'static> {
    let mut callbacks = RemoteCallbacks::new();
    // Credential handler: try the SSH agent first, then fall back to the
    // conventional ~/.ssh/id_rsa key (no passphrase is supplied).
    callbacks.credentials(|_url, username_from_url, _allowed_types| {
        let username = username_from_url.unwrap_or("git");
        log::debug!("Attempting SSH authentication for user: {}", username);
        if let Ok(cred) = Cred::ssh_key_from_agent(username) {
            log::debug!("Authenticated via SSH agent");
            return Ok(cred);
        }
        if let Ok(cred) = Cred::ssh_key(
            username,
            None,
            // HOME (Unix) or USERPROFILE (Windows); if neither is set this
            // degrades to a relative ".ssh/id_rsa" path.
            std::env::var("HOME")
                .or_else(|_| std::env::var("USERPROFILE"))
                .map(std::path::PathBuf::from)
                .ok()
                .as_deref()
                .unwrap_or_else(|| std::path::Path::new(""))
                .join(".ssh")
                .join("id_rsa")
                .as_path(),
            None,
        ) {
            log::debug!("Authenticated via default SSH key path");
            return Ok(cred);
        }
        log::warn!("SSH authentication failed: No agent or default keys found.");
        Err(git2::Error::from_str(
            "Authentication failed: could not connect with SSH agent or default keys",
        ))
    });
    // Bar starts with length 0; the real length is set once totals are known
    // inside the transfer_progress callback below.
    let progress_bar = {
        let pb = indicatif::ProgressBar::new(0);
        pb.set_style(
            indicatif::ProgressStyle::default_bar()
                .template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
                .unwrap()
                .progress_chars("#>-"),
        );
        pb
    };
    let pb_clone = progress_bar.clone();
    callbacks.transfer_progress(move |stats| {
        if stats.received_objects() == stats.total_objects() {
            // All objects received: switch the bar to delta resolution.
            pb_clone.set_length(stats.total_deltas() as u64);
            pb_clone.set_position(stats.indexed_deltas() as u64);
            pb_clone.set_message("Resolving deltas...");
        } else if stats.total_objects() > 0 {
            // Still receiving pack objects.
            pb_clone.set_length(stats.total_objects() as u64);
            pb_clone.set_position(stats.received_objects() as u64);
            pb_clone.set_message("Receiving objects...");
        }
        // Returning true tells libgit2 to continue the transfer.
        true
    });
    callbacks
}
/// Assemble `FetchOptions` for clones and fetches: auth/progress callbacks,
/// pruning of deleted remote refs, full tag download, and an optional
/// shallow-clone depth taken from `config.git_depth`.
fn create_fetch_options(config: &Config) -> FetchOptions<'static> {
    let mut opts = FetchOptions::new();
    opts.remote_callbacks(create_remote_callbacks());
    opts.download_tags(git2::AutotagOption::All);
    opts.prune(git2::FetchPrune::On);
    if let Some(depth) = config.git_depth {
        // NOTE(review): `as i32` silently truncates a very large depth —
        // confirm git_depth's range or use try_from.
        opts.depth(depth as i32);
        log::debug!("Set shallow clone depth to: {}", depth);
    }
    opts
}
/// Resolve the commit a checkout should land on, after a fetch.
///
/// Resolution order:
/// 1. If `config.git_branch` is set, try it as a remote branch
///    (`refs/remotes/origin/<name>`), then as a tag (`refs/tags/<name>`);
///    error if neither exists.
/// 2. Otherwise follow the symbolic `refs/remotes/origin/HEAD` to the
///    remote's default branch.
fn find_remote_commit<'a>(repo: &'a Repository, config: &Config) -> Result<git2::Commit<'a>> {
    if let Some(ref_name) = &config.git_branch {
        log::debug!("Using user-specified ref: {}", ref_name);
        let branch_ref_name = format!("refs/remotes/origin/{}", ref_name);
        if let Ok(reference) = repo.find_reference(&branch_ref_name) {
            log::debug!("Resolved '{}' as a remote branch.", ref_name);
            return repo
                .find_commit(
                    reference
                        .target()
                        .context("Remote branch reference has no target commit")?,
                )
                .context("Failed to find commit for branch reference");
        }
        let tag_ref_name = format!("refs/tags/{}", ref_name);
        if let Ok(reference) = repo.find_reference(&tag_ref_name) {
            log::debug!("Resolved '{}' as a tag.", ref_name);
            // Peel annotated tags down to the commit they ultimately point at.
            let object = reference.peel(git2::ObjectType::Commit)?;
            return object
                .into_commit()
                .map_err(|_| anyhow!("Tag '{}' does not point to a commit", ref_name));
        }
        return Err(anyhow!(
            "Could not find remote branch or tag named '{}' after fetch. Does this ref exist on the remote?",
            ref_name
        ));
    }
    log::debug!("Resolving remote's default branch via origin/HEAD");
    let remote_head = repo.find_reference("refs/remotes/origin/HEAD").context(
        "Could not find remote's HEAD. The repository might not have a default branch set, or it may be empty. Please specify a branch with --git-branch.",
    )?;
    // origin/HEAD is symbolic, e.g. "refs/remotes/origin/main".
    let remote_branch_ref_name = remote_head
        .symbolic_target()
        .context("Remote HEAD is not a symbolic reference; cannot determine default branch.")?
        .to_string();
    log::debug!("Targeting remote reference: {}", remote_branch_ref_name);
    let fetch_head = repo.find_reference(&remote_branch_ref_name).with_context(|| {
        format!(
            "Could not find remote branch reference '{}' after fetch. Does this branch exist on the remote?",
            remote_branch_ref_name
        )
    })?;
    repo.find_commit(
        fetch_head
            .target()
            .context("Remote branch reference has no target commit")?,
    )
    .context("Failed to find commit for default branch reference")
}
/// Bring an existing cached repository up to date with its remote.
///
/// Fetches from `origin` (using the remote's configured refspecs), then
/// detaches HEAD onto the target commit and hard-resets the work tree so
/// the checkout exactly matches the remote state.
fn update_repo(repo: &Repository, config: &Config) -> Result<()> {
    log::info!("Updating cached repository...");
    let mut origin = repo.find_remote("origin")?;
    let mut opts = create_fetch_options(config);
    // Empty refspec list: fetch whatever the remote's configuration says.
    let refspecs: &[&str] = &[];
    origin
        .fetch(refspecs, Some(&mut opts), None)
        .context("Failed to fetch from remote 'origin'")?;
    let commit = find_remote_commit(repo, config)?;
    repo.set_head_detached(commit.id())
        .context("Failed to detach HEAD in cached repository")?;
    repo.reset(commit.as_object(), ResetType::Hard, None)
        .context("Failed to perform hard reset on cached repository")?;
    log::info!("Cached repository updated successfully.");
    Ok(())
}
/// Return a local checkout for `url`, cloning into the cache on a miss or
/// fetching + hard-resetting on a hit.
///
/// A cached entry that fails to open as a git repository is treated as
/// corrupted: it is removed (whether a directory or a stray file) and the
/// clone is redone from scratch. Returns the cached checkout's path.
pub(crate) fn get_repo_with_base_cache(
    base_cache_dir: &Path,
    url: &str,
    config: &Config,
) -> Result<PathBuf> {
    let repo_path = get_repo_cache_path(base_cache_dir, url);
    if repo_path.exists() {
        log::info!(
            "Found cached repository for '{}' at '{}'. Checking for updates...",
            url,
            repo_path.display()
        );
        match Repository::open(&repo_path) {
            Ok(repo) => {
                update_repo(&repo, config)?;
                return Ok(repo_path);
            }
            Err(e) => {
                log::warn!(
                    "Cached repository at '{}' is corrupted or invalid: {}. Re-cloning...",
                    repo_path.display(),
                    e
                );
                // Clear out whatever occupies the slot so the clone below
                // starts clean; fall through to the clone path afterwards.
                if repo_path.is_dir() {
                    fs::remove_dir_all(&repo_path).with_context(|| {
                        format!(
                            "Failed to remove corrupted cache directory at '{}'",
                            repo_path.display()
                        )
                    })?;
                } else if repo_path.is_file() {
                    fs::remove_file(&repo_path).with_context(|| {
                        format!(
                            "Failed to remove corrupted cache file at '{}'",
                            repo_path.display()
                        )
                    })?;
                }
            }
        }
    }
    log::info!(
        "Cloning git repository from '{}' into cache at '{}'...",
        url,
        repo_path.display()
    );
    // repo_path is base/<hash>, so parent() is the base dir and always Some.
    fs::create_dir_all(repo_path.parent().unwrap()).context("Failed to create cache directory")?;
    let fetch_options = create_fetch_options(config);
    let mut repo_builder = RepoBuilder::new();
    repo_builder.fetch_options(fetch_options);
    if let Some(ref_name) = &config.git_branch {
        log::debug!(
            "Cloning default branch first, will check out '{}' after.",
            ref_name
        );
    }
    let repo = repo_builder
        .clone(url, &repo_path)
        .context("Failed to clone repository")?;
    log::info!("Successfully cloned repository into cache.");
    // A specific ref was requested: resolve it (branch or tag) and detach
    // HEAD onto it, same as the cache-hit update path does.
    if config.git_branch.is_some() {
        log::info!(
            "Checking out specified ref: {:?}",
            config.git_branch.as_ref().unwrap()
        );
        let target_commit = find_remote_commit(&repo, config)?;
        repo.set_head_detached(target_commit.id())
            .context("Failed to detach HEAD in newly cloned repository")?;
        repo.reset(target_commit.as_object(), ResetType::Hard, None)
            .context("Failed to perform hard reset on newly cloned repository")?;
        log::info!("Successfully checked out specified ref.");
    }
    Ok(repo_path)
}
/// Public entry point: clone or update `url` in the user-level cache
/// directory and return the path of the local checkout.
pub fn get_repo(url: &str, config: &Config) -> Result<PathBuf> {
    get_repo_with_base_cache(&get_base_cache_dir()?, url, config)
}
/// Heuristic check: does `path_str` look like a git URL (HTTP(S), SSH
/// `git@`, or `file://`) rather than a plain local path?
pub fn is_git_url(path_str: &str) -> bool {
    const URL_PREFIXES: [&str; 4] = ["https://", "http://", "git@", "file://"];
    URL_PREFIXES
        .iter()
        .any(|prefix| path_str.starts_with(prefix))
}
/// Matches canonical GitHub tree URLs:
/// `https://github.com/<user>/<repo>/tree/<ref>[/<subdir>]`.
/// The ref group is a single path segment, so refs containing `/`
/// (e.g. `feature/x`) will not match this pattern.
static GITHUB_TREE_URL_RE: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"https://github\.com/([^/]+)/([^/]+)/tree/([^/]+)(?:/(.*))?$").unwrap()
});
/// Parse a GitHub URL that points inside a repository.
///
/// Two forms are accepted:
/// - Canonical `.../tree/<ref>/<subdir>` URLs, which carry an explicit ref.
/// - "Sloppy" `https://github.com/<user>/<repo>/<subdir...>` URLs, which
///   get the `"HEAD"` sentinel branch (meaning: use the default branch).
///
/// Returns `None` for repo-root URLs, non-GitHub hosts, and URLs whose
/// first path segment after the repo is a reserved GitHub page name.
pub fn parse_github_folder_url(url: &str) -> Option<ParsedGitUrl> {
    // Canonical form with an explicit /tree/<ref> component.
    if let Some(caps) = GITHUB_TREE_URL_RE.captures(url) {
        let user = caps.get(1).unwrap().as_str();
        let repo = caps.get(2).unwrap().as_str();
        let branch = caps.get(3).unwrap().as_str();
        let subdir = caps
            .get(4)
            .map(|m| m.as_str())
            .unwrap_or("")
            .trim_end_matches('/');
        return Some(ParsedGitUrl {
            clone_url: format!("https://github.com/{}/{}.git", user, repo),
            branch: branch.to_string(),
            subdirectory: subdir.to_string(),
        });
    }
    // Sloppy form: everything after user/repo is taken as the subdirectory.
    let segments: Vec<&str> = url
        .strip_prefix("https://github.com/")?
        .split('/')
        .filter(|s| !s.is_empty())
        .collect();
    let (user, repo, rest) = match segments.as_slice() {
        [user, repo, rest @ ..] if !rest.is_empty() => (*user, *repo, rest),
        _ => return None,
    };
    // GitHub web-UI pages that can follow user/repo but are not directories.
    const RESERVED: [&str; 16] = [
        "releases", "tags", "pull", "issues", "actions", "projects", "wiki", "security",
        "pulse", "graphs", "settings", "blob", "tree", "commit", "blame", "find",
    ];
    if RESERVED.contains(&rest[0]) {
        return None;
    }
    Some(ParsedGitUrl {
        clone_url: format!(
            "https://github.com/{}/{}.git",
            user,
            repo.trim_end_matches(".git")
        ),
        branch: "HEAD".to_string(),
        subdirectory: rest.join("/"),
    })
}
/// Download a repository subdirectory via the GitHub contents API into a
/// fresh temporary directory, avoiding a full clone.
///
/// Branch priority: the `--git-branch` flag, then the branch parsed from
/// the URL, then the repository's default branch fetched from the API.
/// The temp directory is persisted via `keep()` (not auto-deleted), so the
/// caller owns its cleanup.
pub fn download_directory_via_api(url_parts: &ParsedGitUrl, config: &Config) -> Result<PathBuf> {
    let temp_dir = TempDirBuilder::new().prefix("dircat-git-api-").tempdir()?;
    let client = build_reqwest_client()?;
    let (owner, repo) = parse_clone_url(&url_parts.clone_url)?;
    let branch_to_use = if let Some(cli_branch) = &config.git_branch {
        log::debug!("Using branch from --git-branch flag: {}", cli_branch);
        cli_branch.clone()
    } else if url_parts.branch != "HEAD" {
        log::debug!("Using branch from URL: {}", url_parts.branch);
        url_parts.branch.clone()
    } else {
        // "HEAD" sentinel: ask the API which branch is the default.
        log::debug!("Fetching default branch for {}/{}", owner, repo);
        fetch_default_branch(&owner, &repo, &client)?
    };
    log::info!("Processing repository on branch: {}", branch_to_use);
    let files_to_download =
        list_all_files_recursively(&client, &owner, &repo, &branch_to_use, url_parts)?;
    if files_to_download.is_empty() {
        return Ok(temp_dir.keep());
    }
    use rayon::prelude::*;
    // Download files in parallel; collect() short-circuits on the first error.
    files_to_download
        .par_iter()
        .map(|file_item| download_and_write_file(&client, file_item, temp_dir.path()))
        .collect::<Result<()>>()?;
    Ok(temp_dir.keep())
}
/// Construct a blocking HTTP client preconfigured for the GitHub REST API:
/// v3 JSON `Accept` header, a fixed `User-Agent`, and — when the
/// `GITHUB_TOKEN` environment variable is set — a bearer token.
fn build_reqwest_client() -> Result<Client> {
    let mut default_headers = reqwest::header::HeaderMap::new();
    default_headers.insert(ACCEPT, "application/vnd.github.v3+json".parse()?);
    default_headers.insert(USER_AGENT, "dircat-rust-downloader".parse()?);
    if let Ok(token) = env::var("GITHUB_TOKEN") {
        log::debug!("Using GITHUB_TOKEN for authentication.");
        default_headers.insert(AUTHORIZATION, format!("Bearer {}", token).parse()?);
    }
    Ok(Client::builder().default_headers(default_headers).build()?)
}
/// Ask the GitHub API which branch is the repository's default.
fn fetch_default_branch(owner: &str, repo: &str, client: &Client) -> Result<String> {
    let endpoint = format!("https://api.github.com/repos/{}/{}", owner, repo);
    log::debug!("Fetching repo metadata from: {}", endpoint);
    let info: RepoInfo = client
        .get(&endpoint)
        .send()?
        .error_for_status()?
        .json()?;
    Ok(info.default_branch)
}
fn list_all_files_recursively(
client: &Client,
owner: &str,
repo: &str,
branch: &str,
url_parts: &ParsedGitUrl,
) -> Result<Vec<ContentItem>> {
let mut files = Vec::new();
let mut queue: VecDeque<String> = VecDeque::new();
queue.push_back(url_parts.subdirectory.clone());
while let Some(path) = queue.pop_front() {
let api_url = format!(
"https://api.github.com/repos/{}/{}/contents/{}?ref={}",
owner, repo, path, branch
);
log::debug!("Fetching directory contents from: {}", api_url);
let response = client.get(&api_url).send()?.error_for_status()?;
let response_text = response.text()?;
let json_value: Value = serde_json::from_str(&response_text)?;
let items: Vec<ContentItem> = if json_value.is_array() {
serde_json::from_value(json_value)?
} else if json_value.is_object() {
vec![serde_json::from_value(json_value)?]
} else {
vec![]
};
for item in items {
if item.item_type == "file" {
if item.download_url.is_some() {
files.push(item);
} else {
log::warn!("Skipping file with no download_url: {}", item.path);
}
} else if item.item_type == "dir" {
queue.push_back(item.path);
}
}
}
Ok(files)
}
fn download_and_write_file(
client: &Client,
file_item: &ContentItem,
base_dir: &Path,
) -> Result<()> {
let download_url = file_item.download_url.as_ref().unwrap(); log::debug!("Downloading file from: {}", download_url);
let response = client.get(download_url).send()?.error_for_status()?;
let content = response.bytes()?;
let local_path = base_dir.join(&file_item.path);
if let Some(parent_dir) = local_path.parent() {
fs::create_dir_all(parent_dir).with_context(|| {
format!(
"Failed to create directory structure for '{}'",
local_path.display()
)
})?;
}
fs::write(&local_path, content).with_context(|| {
format!(
"Failed to write downloaded content to '{}'",
local_path.display()
)
})
}
/// Extract `(owner, repo)` from a GitHub clone URL, accepting both HTTPS
/// (`github.com/owner/repo`) and SSH (`github.com:owner/repo`) forms and
/// tolerating an optional trailing `.git`.
fn parse_clone_url(clone_url: &str) -> Result<(String, String)> {
    static RE: Lazy<Regex> =
        Lazy::new(|| Regex::new(r"github\.com[/:]([^/]+)/([^/]+?)(?:\.git)?$").unwrap());
    let err = || anyhow!("Could not parse owner/repo from clone URL: {}", clone_url);
    let caps = RE.captures(clone_url).ok_or_else(err)?;
    match (caps.get(1), caps.get(2)) {
        (Some(owner), Some(repo)) => {
            Ok((owner.as_str().to_string(), repo.as_str().to_string()))
        }
        _ => Err(err()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use git2::{IndexTime, Signature};
    use std::fs::File;
    use std::io::Write;
    use tempfile::{tempdir, TempDir};

    /// Create a bare repository in a temp dir to stand in for a "remote".
    /// The TempDir is returned so it stays alive for the test's duration.
    fn setup_test_remote_repo() -> Result<(TempDir, Repository)> {
        let remote_dir = tempdir()?;
        let repo = Repository::init_bare(remote_dir.path())?;
        Ok((remote_dir, repo))
    }

    /// Commit `content` as `filename` directly through the bare repo's
    /// index and object database (no working tree), advancing HEAD.
    fn add_commit_to_repo(
        repo: &Repository,
        filename: &str,
        content: &str,
        message: &str,
    ) -> Result<()> {
        let mut index = repo.index()?;
        let oid = repo.blob(content.as_bytes())?;
        // Minimal index entry: zeroed timestamps/device ids, regular-file mode.
        let entry = git2::IndexEntry {
            ctime: IndexTime::new(0, 0),
            mtime: IndexTime::new(0, 0),
            dev: 0,
            ino: 0,
            mode: 0o100644,
            uid: 0,
            gid: 0,
            file_size: content.len() as u32,
            id: oid,
            flags: 0,
            flags_extended: 0,
            path: filename.as_bytes().to_vec(),
        };
        index.add(&entry)?;
        let tree_oid = index.write_tree()?;
        let tree = repo.find_tree(tree_oid)?;
        let signature = Signature::now("Test", "test@example.com")?;
        // The first commit has no parent; later commits chain onto HEAD.
        let parent_commit = repo.head().ok().and_then(|h| h.peel_to_commit().ok());
        let parents: Vec<&git2::Commit> = parent_commit.iter().collect();
        repo.commit(
            Some("HEAD"),
            &signature,
            &signature,
            message,
            &tree,
            &parents,
        )?;
        Ok(())
    }

    /// First call clones (cache miss); second call reuses the cache (hit)
    /// and returns the same path with unchanged content.
    #[test]
    fn test_cache_miss_and_hit() -> Result<()> {
        let (_remote_dir, remote_repo) = setup_test_remote_repo()?;
        add_commit_to_repo(&remote_repo, "file.txt", "content v1", "Initial")?;
        let remote_path_str = _remote_dir.path().to_str().unwrap();
        // file:// URLs need a leading slash after the scheme on Windows.
        #[cfg(windows)]
        let remote_url = format!("file:///{}", remote_path_str.replace('\\', "/"));
        #[cfg(not(windows))]
        let remote_url = format!("file://{}", remote_path_str);
        let cache_dir = tempdir()?;
        let config = Config::new_for_test();
        let cached_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert!(cached_path.exists());
        let content = fs::read_to_string(cached_path.join("file.txt"))?;
        assert_eq!(content, "content v1");
        let cached_path_2 = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert_eq!(cached_path, cached_path_2);
        let content_2 = fs::read_to_string(cached_path_2.join("file.txt"))?;
        assert_eq!(content_2, "content v1");
        Ok(())
    }

    /// A new commit on the remote is picked up by the cache-hit fetch path.
    #[test]
    fn test_cache_update() -> Result<()> {
        let (_remote_dir, remote_repo) = setup_test_remote_repo()?;
        add_commit_to_repo(&remote_repo, "file.txt", "content v1", "Initial")?;
        let remote_path_str = _remote_dir.path().to_str().unwrap();
        #[cfg(windows)]
        let remote_url = format!("file:///{}", remote_path_str.replace('\\', "/"));
        #[cfg(not(windows))]
        let remote_url = format!("file://{}", remote_path_str);
        let cache_dir = tempdir()?;
        let config = Config::new_for_test();
        let cached_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert_eq!(
            fs::read_to_string(cached_path.join("file.txt"))?,
            "content v1"
        );
        add_commit_to_repo(&remote_repo, "file.txt", "content v2", "Update")?;
        let updated_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert_eq!(cached_path, updated_path);
        assert_eq!(
            fs::read_to_string(updated_path.join("file.txt"))?,
            "content v2"
        );
        Ok(())
    }

    /// A garbage file occupying the cache slot is removed and re-cloned.
    #[test]
    fn test_corrupted_cache_recovery() -> Result<()> {
        let (_remote_dir, remote_repo) = setup_test_remote_repo()?;
        add_commit_to_repo(&remote_repo, "file.txt", "content", "Initial")?;
        let remote_path_str = _remote_dir.path().to_str().unwrap();
        #[cfg(windows)]
        let remote_url = format!("file:///{}", remote_path_str.replace('\\', "/"));
        #[cfg(not(windows))]
        let remote_url = format!("file://{}", remote_path_str);
        let cache_dir = tempdir()?;
        let config = Config::new_for_test();
        // Plant a plain file where the cached repo directory should be.
        let expected_cache_path = get_repo_cache_path(cache_dir.path(), &remote_url);
        fs::create_dir_all(expected_cache_path.parent().unwrap())?;
        File::create(&expected_cache_path)?.write_all(b"corruption")?;
        let cached_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert!(cached_path.is_dir());
        assert_eq!(fs::read_to_string(cached_path.join("file.txt"))?, "content");
        Ok(())
    }

    #[test]
    fn test_parse_github_folder_url_valid() {
        let url = "https://github.com/BurntSushi/ripgrep/tree/master/crates/ignore";
        let expected = Some(ParsedGitUrl {
            clone_url: "https://github.com/BurntSushi/ripgrep.git".to_string(),
            branch: "master".to_string(),
            subdirectory: "crates/ignore".to_string(),
        });
        assert_eq!(parse_github_folder_url(url), expected);
    }

    /// Without "/tree/", the first segment after the repo is treated as part
    /// of the subdirectory and the branch falls back to the HEAD sentinel.
    #[test]
    fn test_parse_github_sloppy_url_no_tree_assumes_default_branch() {
        let url = "https://github.com/BurntSushi/ripgrep/master/crates/ignore";
        let expected = Some(ParsedGitUrl {
            clone_url: "https://github.com/BurntSushi/ripgrep.git".to_string(),
            branch: "HEAD".to_string(),
            subdirectory: "master/crates/ignore".to_string(),
        });
        assert_eq!(parse_github_folder_url(url), expected);
    }

    #[test]
    fn test_parse_github_sloppy_url_no_branch() {
        let url = "https://github.com/BurntSushi/ripgrep/crates/ignore";
        let expected = Some(ParsedGitUrl {
            clone_url: "https://github.com/BurntSushi/ripgrep.git".to_string(),
            branch: "HEAD".to_string(),
            subdirectory: "crates/ignore".to_string(),
        });
        assert_eq!(parse_github_folder_url(url), expected);
    }

    /// A ".git" suffix on the repo segment is stripped from the clone URL.
    #[test]
    fn test_parse_github_sloppy_url_with_git_suffix() {
        let url = "https://github.com/BurntSushi/ripgrep.git/master/crates/ignore";
        let expected = Some(ParsedGitUrl {
            clone_url: "https://github.com/BurntSushi/ripgrep.git".to_string(),
            branch: "HEAD".to_string(),
            subdirectory: "master/crates/ignore".to_string(),
        });
        assert_eq!(parse_github_folder_url(url), expected);
    }

    /// Repo-root URLs (no subdirectory) are not folder URLs.
    #[test]
    fn test_parse_github_url_rejects_root() {
        assert_eq!(
            parse_github_folder_url("https://github.com/rust-lang/rust"),
            None
        );
        assert_eq!(
            parse_github_folder_url("https://github.com/rust-lang/rust.git"),
            None
        );
    }

    /// Reserved GitHub web-UI paths and non-GitHub hosts are rejected.
    #[test]
    fn test_parse_github_url_rejects_reserved_paths() {
        assert_eq!(
            parse_github_folder_url("https://github.com/user/repo/blob/master/file.txt"),
            None
        );
        assert_eq!(
            parse_github_folder_url("https://github.com/user/repo/issues/1"),
            None
        );
        assert_eq!(
            parse_github_folder_url("https://github.com/user/repo/pull/2"),
            None
        );
        assert_eq!(
            parse_github_folder_url("https://gitlab.com/user/repo/tree/master"),
            None
        );
    }
}