use crate::config::Config;
use anyhow::{Context, Result};
use directories::ProjectDirs;
use git2::{build::RepoBuilder, Cred, FetchOptions, RemoteCallbacks, Repository, ResetType};
use hex;
use sha2::{Digest, Sha256};
use std::env;
use std::{
fs,
path::{Path, PathBuf},
};
/// Returns the root directory under which cached repositories are stored.
///
/// When the `DIRCAT_TEST_CACHE_DIR` environment variable is set (and is valid
/// Unicode) its value is used verbatim. Otherwise this is the platform cache
/// directory for `com.romelium.dircat`, with a `repos` subdirectory appended.
///
/// # Errors
/// Fails if the platform cache directory cannot be determined.
fn get_base_cache_dir() -> Result<PathBuf> {
    match env::var("DIRCAT_TEST_CACHE_DIR") {
        Ok(overridden) => Ok(PathBuf::from(overridden)),
        Err(_) => ProjectDirs::from("com", "romelium", "dircat")
            .map(|dirs| dirs.cache_dir().join("repos"))
            .context("Could not determine project cache directory"),
    }
}
/// Maps a repository URL to its cache location:
/// `base_cache_dir/<lowercase hex SHA-256 of url>`.
///
/// Hashing gives a fixed-length, filesystem-safe directory name no matter what
/// characters the URL contains.
fn get_repo_cache_path(base_cache_dir: &Path, url: &str) -> PathBuf {
    let digest = Sha256::digest(url.as_bytes());
    base_cache_dir.join(hex::encode(digest))
}
fn create_remote_callbacks() -> RemoteCallbacks<'static> {
let mut callbacks = RemoteCallbacks::new();
callbacks.credentials(|_url, username_from_url, _allowed_types| {
let username = username_from_url.unwrap_or("git");
log::debug!("Attempting SSH authentication for user: {}", username);
if let Ok(cred) = Cred::ssh_key_from_agent(username) {
log::debug!("Authenticated via SSH agent");
return Ok(cred);
}
if let Ok(cred) = Cred::ssh_key(
username,
None,
std::env::var("HOME")
.or_else(|_| std::env::var("USERPROFILE"))
.map(std::path::PathBuf::from)
.ok()
.as_deref()
.unwrap_or_else(|| std::path::Path::new(""))
.join(".ssh")
.join("id_rsa")
.as_path(),
None,
) {
log::debug!("Authenticated via default SSH key path");
return Ok(cred);
}
log::warn!("SSH authentication failed: No agent or default keys found.");
Err(git2::Error::from_str(
"Authentication failed: could not connect with SSH agent or default keys",
))
});
let progress_bar = {
let pb = indicatif::ProgressBar::new(0);
pb.set_style(
indicatif::ProgressStyle::default_bar()
.template("{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({percent}%) {msg}")
.unwrap()
.progress_chars("#>-"),
);
pb
};
let pb_clone = progress_bar.clone();
callbacks.transfer_progress(move |stats| {
if stats.received_objects() == stats.total_objects() {
pb_clone.set_length(stats.total_deltas() as u64);
pb_clone.set_position(stats.indexed_deltas() as u64);
pb_clone.set_message("Resolving deltas...");
} else if stats.total_objects() > 0 {
pb_clone.set_length(stats.total_objects() as u64);
pb_clone.set_position(stats.received_objects() as u64);
pb_clone.set_message("Receiving objects...");
}
true
});
callbacks
}
/// Builds the fetch options used for both the initial clone and later updates.
///
/// The options include:
/// - the credential/progress callbacks from `create_remote_callbacks`;
/// - pruning of remote-tracking refs whose branches were deleted upstream;
/// - a shallow-fetch depth, when `config.git_depth` is set.
fn create_fetch_options(config: &Config) -> FetchOptions<'static> {
    let mut fetch_options = FetchOptions::new();
    fetch_options.remote_callbacks(create_remote_callbacks());
    fetch_options.prune(git2::FetchPrune::On);
    if let Some(depth) = config.git_depth {
        // libgit2 takes an `i32`. Saturate rather than let an `as` cast wrap
        // an oversized depth into a negative (invalid) value. `i32::MAX` is
        // libgit2's "unshallow" value, i.e. effectively full history.
        let depth_i32 = i32::try_from(depth).unwrap_or(i32::MAX);
        fetch_options.depth(depth_i32);
        log::debug!("Set shallow clone depth to: {}", depth);
    }
    fetch_options
}
/// Resolves the commit the cached checkout should point at after a fetch.
///
/// When `config.git_branch` is set this is the tip of
/// `refs/remotes/origin/<branch>`. Otherwise it is the tip of whichever branch
/// the symbolic ref `refs/remotes/origin/HEAD` points to.
///
/// # Errors
/// Fails if:
/// - `origin/HEAD` is missing or not symbolic (default-branch case);
/// - the target remote ref does not exist;
/// - the target remote ref has no direct target;
/// - the target commit cannot be loaded.
fn find_remote_commit<'a>(repo: &'a Repository, config: &Config) -> Result<git2::Commit<'a>> {
    let target_ref = match &config.git_branch {
        Some(branch) => {
            log::debug!("Using user-specified branch: {}", branch);
            format!("refs/remotes/origin/{}", branch)
        }
        None => {
            log::debug!("Resolving remote's default branch via origin/HEAD");
            let origin_head = repo.find_reference("refs/remotes/origin/HEAD").context(
                "Could not find remote's HEAD. The repository might not have a default branch set, or it may be empty. Please specify a branch with --git-branch.",
            )?;
            origin_head
                .symbolic_target()
                .map(str::to_owned)
                .context("Remote HEAD is not a symbolic reference; cannot determine default branch.")?
        }
    };
    log::debug!("Targeting remote reference: {}", target_ref);

    let remote_ref = repo.find_reference(&target_ref).with_context(|| {
        format!(
            "Could not find remote branch reference '{}' after fetch. Does this branch exist on the remote?",
            target_ref
        )
    })?;
    let commit_id = remote_ref
        .target()
        .context("Remote branch reference has no target commit")?;
    Ok(repo.find_commit(commit_id)?)
}
/// Brings an existing cached clone up to date with its `origin` remote.
///
/// Fetches using the remote's configured refspecs (an empty refspec list).
/// It then resolves the target commit via `find_remote_commit`, detaches
/// `HEAD` at that commit, and hard-resets the index and working tree to match.
///
/// # Errors
/// Fails if `origin` is missing, the fetch fails, the target commit cannot be
/// resolved, or `HEAD` cannot be detached / reset.
fn update_repo(repo: &Repository, config: &Config) -> Result<()> {
    log::info!("Updating cached repository...");

    let mut origin = repo.find_remote("origin")?;
    let mut options = create_fetch_options(config);
    let configured_refspecs: &[&str] = &[];
    origin
        .fetch(configured_refspecs, Some(&mut options), None)
        .context("Failed to fetch from remote 'origin'")?;

    let target = find_remote_commit(repo, config)?;
    let target_id = target.id();
    repo.set_head_detached(target_id)
        .context("Failed to detach HEAD in cached repository")?;
    repo.reset(target.as_object(), ResetType::Hard, None)
        .context("Failed to perform hard reset on cached repository")?;

    log::info!("Cached repository updated successfully.");
    Ok(())
}
pub(crate) fn get_repo_with_base_cache(
base_cache_dir: &Path,
url: &str,
config: &Config,
) -> Result<PathBuf> {
let repo_path = get_repo_cache_path(base_cache_dir, url);
if repo_path.exists() {
log::info!(
"Found cached repository for '{}' at '{}'. Checking for updates...",
url,
repo_path.display()
);
match Repository::open(&repo_path) {
Ok(repo) => {
update_repo(&repo, config)?;
return Ok(repo_path);
}
Err(e) => {
log::warn!(
"Cached repository at '{}' is corrupted or invalid: {}. Re-cloning...",
repo_path.display(),
e
);
if repo_path.is_dir() {
fs::remove_dir_all(&repo_path).with_context(|| {
format!(
"Failed to remove corrupted cache directory at '{}'",
repo_path.display()
)
})?;
} else if repo_path.is_file() {
fs::remove_file(&repo_path).with_context(|| {
format!(
"Failed to remove corrupted cache file at '{}'",
repo_path.display()
)
})?;
}
}
}
}
log::info!(
"Cloning git repository from '{}' into cache at '{}'...",
url,
repo_path.display()
);
fs::create_dir_all(repo_path.parent().unwrap()).context("Failed to create cache directory")?;
let fetch_options = create_fetch_options(config);
let mut repo_builder = RepoBuilder::new();
repo_builder.fetch_options(fetch_options);
if let Some(branch) = &config.git_branch {
repo_builder.branch(branch);
}
repo_builder.clone(url, &repo_path)?;
log::info!("Successfully cloned repository into cache.");
Ok(repo_path)
}
/// Returns the path of an up-to-date local clone of `url`.
///
/// The clone lives inside the platform cache directory, or inside the
/// `DIRCAT_TEST_CACHE_DIR` override when that is set. See
/// `get_repo_with_base_cache` for how cache hits, updates and corrupted
/// entries are handled.
///
/// # Errors
/// Fails if the cache directory cannot be determined or if cloning/updating fails.
pub fn get_repo(url: &str, config: &Config) -> Result<PathBuf> {
    get_base_cache_dir().and_then(|cache_root| get_repo_with_base_cache(&cache_root, url, config))
}
/// Returns `true` if `path_str` looks like a remote git URL rather than a
/// local filesystem path.
///
/// Recognises the URL schemes libgit2 can clone from (`https://`, `http://`,
/// `ssh://`, `git://`, `file://`) plus the scp-like `git@host:path` form.
/// Detection is purely prefix-based; nothing is validated or contacted.
pub fn is_git_url(path_str: &str) -> bool {
    const GIT_URL_PREFIXES: [&str; 6] =
        ["https://", "http://", "ssh://", "git://", "file://", "git@"];
    GIT_URL_PREFIXES
        .iter()
        .any(|prefix| path_str.starts_with(*prefix))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use git2::{IndexTime, Signature};
    use std::fs::File;
    use std::io::Write;
    use tempfile::{tempdir, TempDir};

    /// Creates an empty bare repository in a fresh temporary directory to act
    /// as the remote. The returned `TempDir` must stay alive while the
    /// repository is used, because dropping it deletes the directory.
    fn setup_test_remote_repo() -> Result<(TempDir, Repository)> {
        let remote_dir = tempdir()?;
        let repo = Repository::init_bare(remote_dir.path())?;
        Ok((remote_dir, repo))
    }

    /// Builds the `file://` URL used to clone from a local directory. On
    /// Windows the drive-letter path needs a third `/` and forward slashes.
    fn file_url(dir: &Path) -> String {
        let path_str = dir
            .to_str()
            .expect("temporary directory path should be valid UTF-8");
        if cfg!(windows) {
            format!("file:///{}", path_str.replace('\\', "/"))
        } else {
            format!("file://{}", path_str)
        }
    }

    /// Commits `content` as `filename` on top of the current `HEAD` (if any).
    /// The blob and tree are written straight into the object database rather
    /// than through a working tree, which lets this run against the bare remote.
    fn add_commit_to_repo(
        repo: &Repository,
        filename: &str,
        content: &str,
        message: &str,
    ) -> Result<()> {
        let mut index = repo.index()?;
        let oid = repo.blob(content.as_bytes())?;
        let entry = git2::IndexEntry {
            ctime: IndexTime::new(0, 0),
            mtime: IndexTime::new(0, 0),
            dev: 0,
            ino: 0,
            mode: 0o100644,
            uid: 0,
            gid: 0,
            file_size: content.len() as u32,
            id: oid,
            flags: 0,
            flags_extended: 0,
            path: filename.as_bytes().to_vec(),
        };
        index.add(&entry)?;
        let tree_oid = index.write_tree()?;
        let tree = repo.find_tree(tree_oid)?;
        let signature = Signature::now("Test", "test@example.com")?;
        let parent_commit = repo.head().ok().and_then(|h| h.peel_to_commit().ok());
        let parents: Vec<&git2::Commit> = parent_commit.iter().collect();
        repo.commit(
            Some("HEAD"),
            &signature,
            &signature,
            message,
            &tree,
            &parents,
        )?;
        Ok(())
    }

    #[test]
    fn test_cache_miss_and_hit() -> Result<()> {
        let (remote_dir, remote_repo) = setup_test_remote_repo()?;
        add_commit_to_repo(&remote_repo, "file.txt", "content v1", "Initial")?;
        let remote_url = file_url(remote_dir.path());
        let cache_dir = tempdir()?;
        let config = Config::new_for_test();

        // First call: cache miss, clones.
        let cached_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert!(cached_path.exists());
        let content = fs::read_to_string(cached_path.join("file.txt"))?;
        assert_eq!(content, "content v1");

        // Second call: cache hit, same location and content.
        let cached_path_2 = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert_eq!(cached_path, cached_path_2);
        let content_2 = fs::read_to_string(cached_path_2.join("file.txt"))?;
        assert_eq!(content_2, "content v1");
        Ok(())
    }

    #[test]
    fn test_cache_update() -> Result<()> {
        let (remote_dir, remote_repo) = setup_test_remote_repo()?;
        add_commit_to_repo(&remote_repo, "file.txt", "content v1", "Initial")?;
        let remote_url = file_url(remote_dir.path());
        let cache_dir = tempdir()?;
        let config = Config::new_for_test();

        let cached_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert_eq!(
            fs::read_to_string(cached_path.join("file.txt"))?,
            "content v1"
        );

        // A new upstream commit must be picked up by the next call.
        add_commit_to_repo(&remote_repo, "file.txt", "content v2", "Update")?;
        let updated_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert_eq!(cached_path, updated_path);
        assert_eq!(
            fs::read_to_string(updated_path.join("file.txt"))?,
            "content v2"
        );
        Ok(())
    }

    #[test]
    fn test_corrupted_cache_recovery() -> Result<()> {
        let (remote_dir, remote_repo) = setup_test_remote_repo()?;
        add_commit_to_repo(&remote_repo, "file.txt", "content", "Initial")?;
        let remote_url = file_url(remote_dir.path());
        let cache_dir = tempdir()?;
        let config = Config::new_for_test();

        // Put a plain file where the cached repository directory should be.
        let expected_cache_path = get_repo_cache_path(cache_dir.path(), &remote_url);
        fs::create_dir_all(expected_cache_path.parent().unwrap())?;
        File::create(&expected_cache_path)?.write_all(b"corruption")?;

        let cached_path = get_repo_with_base_cache(cache_dir.path(), &remote_url, &config)?;
        assert!(cached_path.is_dir());
        assert_eq!(fs::read_to_string(cached_path.join("file.txt"))?, "content");
        Ok(())
    }

    #[test]
    fn test_is_git_url() {
        assert!(is_git_url("https://github.com/romelium/dircat.git"));
        assert!(is_git_url("http://example.com/repo.git"));
        assert!(is_git_url("git@github.com:romelium/dircat.git"));
        assert!(is_git_url("file:///tmp/repo"));
        assert!(!is_git_url("src/main.rs"));
        assert!(!is_git_url("./https/local-dir"));
        assert!(!is_git_url(""));
    }
}