use std::collections::HashSet;
use std::num::NonZeroU32;
use std::path::Path;
use std::sync::{Arc, Mutex};
use gix_hash::ObjectId;
use tokio::sync::Semaphore;
use tokio::task::{JoinError, JoinSet};
use tracing::debug;
use crate::git::{self, GitError, RefName, RefNameError, Sha, ShaError};
use crate::keys;
use crate::object_store::{GetOpts, ObjectStore, ObjectStoreError};
/// Maximum number of fetch tasks allowed to run at the same time within a
/// single [`fetch_batch`] call. Used as the permit count of the per-batch
/// semaphore that every spawned task must acquire before doing any work.
pub(crate) const MAX_FETCH_CONCURRENCY: usize = 8;
/// Errors that can be produced while parsing or executing fetch commands.
///
/// Most variants wrap an underlying error through `#[from]`, so `?` can be
/// used directly on the corresponding result types inside this module.
#[derive(Debug, thiserror::Error)]
pub enum FetchError {
/// The command line did not have the `<sha> <ref>` shape.
#[error("invalid fetch command {line:?}: expected `<sha> <ref>`")]
Parse {
/// The offending command text, kept verbatim for diagnostics.
line: String,
},
/// The SHA portion of the command was rejected by `Sha::from_hex`.
#[error("invalid SHA in fetch command: {0}")]
Sha(#[from] ShaError),
/// The ref portion of the command was rejected by `RefName::new`.
#[error("invalid ref in fetch command: {0}")]
Ref(#[from] RefNameError),
/// The object store reported a failure.
#[error("object-store error during fetch: {0}")]
Store(#[from] ObjectStoreError),
/// A local I/O operation failed (for example creating the temporary directory).
#[error("local I/O error during fetch: {0}")]
Io(#[from] std::io::Error),
/// A git-level operation failed.
#[error("git error during fetch: {0}")]
Git(#[from] GitError),
/// A spawned task (async or blocking) could not be joined.
#[error("fetch task join failed: {0}")]
Join(#[from] JoinError),
/// The packchain engine reported a failure.
#[error("packchain engine error during fetch: {0}")]
Packchain(#[from] crate::packchain::PackchainError),
}
/// Set of SHAs whose bundles have already been fetched and unbundled.
///
/// The set lives behind `Arc<Mutex<..>>`, so cloning a `FetchedRefs` yields
/// another handle to the *same* set; this is how concurrent fetch tasks share it.
#[derive(Clone, Default)]
pub(crate) struct FetchedRefs {
// Shared, mutex-protected storage; every clone points at this same allocation.
inner: Arc<Mutex<HashSet<Sha>>>,
}
impl FetchedRefs {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn contains(&self, sha: &Sha) -> bool {
self.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.contains(sha)
}
pub(crate) fn insert(&self, sha: Sha) {
self.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.insert(sha);
}
#[cfg(test)]
pub(crate) fn snapshot(&self) -> HashSet<Sha> {
self.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.clone()
}
}
/// Accumulator for shallow-boundary object ids gathered by concurrent fetch tasks.
///
/// Like [`FetchedRefs`], clones share one underlying set through `Arc<Mutex<..>>`.
/// Because the storage is a `HashSet`, an id reported by several tasks is kept once.
#[derive(Clone, Default)]
pub(crate) struct ShallowBoundaries {
// Shared, mutex-protected storage; every clone points at this same allocation.
inner: Arc<Mutex<HashSet<ObjectId>>>,
}
impl ShallowBoundaries {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn extend(&self, ids: impl IntoIterator<Item = ObjectId>) {
let mut guard = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
guard.extend(ids);
}
pub(crate) fn drain(&self) -> Vec<ObjectId> {
let mut guard = self
.inner
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
guard.drain().collect()
}
}
/// Executes a batch of fetch commands concurrently.
///
/// Each entry of `cmds` is expected to be a `<sha> <ref>` pair (see
/// [`parse_fetch_args`]). One task is spawned per command; a task must hold a
/// permit from a semaphore sized [`MAX_FETCH_CONCURRENCY`] before it parses
/// its command and runs the fetch, which bounds how many run at once.
///
/// All tasks are always awaited. The first error seen (in completion order)
/// is returned; any later errors are only logged at debug level. A task that
/// cannot be joined is reported as [`FetchError::Join`].
///
/// When every task succeeded and `depth` is set, the shallow boundaries
/// collected by the tasks are written with `git::write_shallow_file` on the
/// blocking thread pool.
///
/// An empty `cmds` returns `Ok(())` immediately without touching the store.
///
/// # Errors
///
/// Returns the first [`FetchError`] produced by any task, or an error from
/// writing the shallow file.
pub(crate) async fn fetch_batch(
    ctx: &super::BatchCtx,
    cmds: Vec<String>,
    fetched_refs: FetchedRefs,
    depth: Option<NonZeroU32>,
) -> Result<(), FetchError> {
    if cmds.is_empty() {
        return Ok(());
    }
    debug!(
        count = cmds.len(),
        depth = ?depth,
        "fetching bundles in parallel"
    );

    let permits = Arc::new(Semaphore::new(MAX_FETCH_CONCURRENCY));
    let boundaries = ShallowBoundaries::new();
    let mut workers: JoinSet<Result<(), FetchError>> = JoinSet::new();

    for cmd in cmds {
        // Each task needs its own owned handles because it must be `'static`.
        let permits = Arc::clone(&permits);
        let store = Arc::clone(&ctx.store);
        let repo_dir = Arc::clone(&ctx.repo_dir);
        let prefix = ctx.prefix.clone();
        let fetched_refs = fetched_refs.clone();
        let boundaries = boundaries.clone();

        workers.spawn(async move {
            // Held for the whole task so at most MAX_FETCH_CONCURRENCY run at once.
            let _permit = permits
                .acquire_owned()
                .await
                .expect("fetch semaphore is owned by this batch and never closed");
            let (sha, ref_name) = parse_fetch_args(&cmd)?;
            fetch_one(FetchOneCtx {
                store: store.as_ref(),
                prefix: prefix.as_deref(),
                repo_dir: repo_dir.as_path(),
                sha,
                ref_name: &ref_name,
                fetched_refs: &fetched_refs,
                depth,
                boundaries: &boundaries,
            })
            .await
        });
    }

    // Drain every task; remember only the first failure.
    let mut first_err: Option<FetchError> = None;
    while let Some(joined) = workers.join_next().await {
        let outcome: Result<(), FetchError> = match joined {
            Ok(task_result) => task_result,
            Err(join_err) => Err(FetchError::from(join_err)),
        };
        if let Err(err) = outcome {
            if first_err.is_some() {
                debug!(error = %err, "additional bundle fetch task error (first error already captured)");
            } else {
                first_err = Some(err);
            }
        }
    }

    if let Some(err) = first_err {
        return Err(err);
    }

    // Only a fully successful shallow fetch updates the shallow file.
    if depth.is_some() {
        let collected = boundaries.drain();
        let repo_dir = ctx.repo_dir.as_path().to_path_buf();
        let write_result =
            tokio::task::spawn_blocking(move || git::write_shallow_file(&repo_dir, &collected))
                .await?;
        write_result?;
    }
    Ok(())
}
/// Borrowed inputs for a single [`fetch_one`] call, bundled into one struct
/// so the function takes a single argument.
struct FetchOneCtx<'a> {
/// Object store the bundle is downloaded from.
store: &'a dyn ObjectStore,
/// Optional key prefix passed to `keys::bundle_key`.
prefix: Option<&'a str>,
/// Local repository directory the bundle is unbundled into.
repo_dir: &'a Path,
/// SHA named by the fetch command.
sha: Sha,
/// Ref named by the fetch command.
ref_name: &'a RefName,
/// Shared record of SHAs already fetched; consulted and updated by the fetch.
fetched_refs: &'a FetchedRefs,
/// Requested shallow depth, if any.
depth: Option<NonZeroU32>,
/// Shared accumulator that receives the computed shallow boundaries.
boundaries: &'a ShallowBoundaries,
}
/// Fetches the bundle for one `<sha> <ref>` pair and, for shallow fetches,
/// records its shallow boundaries.
///
/// If `sha` is not yet in `fetched_refs`, the bundle stored under
/// `keys::bundle_key(prefix, ref_name, sha)` is downloaded to
/// `<tempdir>/<sha>.bundle`, handed to `git::unbundle_at`, and `sha` is then
/// added to `fetched_refs`. Otherwise the download is skipped.
///
/// Independently of whether a download happened, when `depth` is set the
/// repository is opened on the blocking thread pool, `git::shallow_boundaries`
/// is evaluated for `sha`, and the resulting ids are added to `boundaries`.
///
/// # Errors
///
/// Propagates temp-dir creation, object-store, git and task-join failures as
/// the matching [`FetchError`] variant.
async fn fetch_one(ctx: FetchOneCtx<'_>) -> Result<(), FetchError> {
    let FetchOneCtx {
        store,
        prefix,
        repo_dir,
        sha,
        ref_name,
        fetched_refs,
        depth,
        boundaries,
    } = ctx;

    if !fetched_refs.contains(&sha) {
        let key = keys::bundle_key(prefix, ref_name, sha);
        // The scratch directory (and the bundle inside it) is removed when
        // `scratch` drops at the end of this branch.
        let scratch = tempfile::Builder::new()
            .prefix("git_remote_object_store_fetch_")
            .tempdir()?;
        let bundle_path = scratch.path().join(format!("{sha}.bundle"));
        debug!(%sha, ref_name = %ref_name, key = %key, "downloading bundle");
        store
            .get_to_file(&key, &bundle_path, GetOpts::default())
            .await?;
        git::unbundle_at(repo_dir, scratch.path(), sha).await?;
        // Recorded only after a successful unbundle.
        fetched_refs.insert(sha);
    } else {
        debug!(%sha, ref_name = %ref_name, "skipping fetch: already fetched in this session");
    }

    // Full (non-shallow) fetches have nothing more to do.
    let Some(depth) = depth else {
        return Ok(());
    };

    let repo_dir = repo_dir.to_path_buf();
    let boundary_walk = tokio::task::spawn_blocking(move || {
        let repo = gix::open(&repo_dir).map_err(GitError::from)?;
        git::shallow_boundaries(&repo, sha, depth)
    });
    // First `?` covers the join failure, second the walk's own error.
    let ids = boundary_walk.await??;
    boundaries.extend(ids);
    Ok(())
}
/// Parses the argument text of a fetch command, `<sha> <ref>`.
///
/// The text is split at the first space. Both halves must be non-empty and
/// the ref half must not contain a further space; the halves are then
/// validated with `Sha::from_hex` and `RefName::new` respectively.
///
/// # Errors
///
/// * [`FetchError::Parse`] when there is no space, either half is empty, or
///   the ref half contains another space.
/// * [`FetchError::Sha`] when the SHA half is rejected.
/// * [`FetchError::Ref`] when the ref half is rejected.
pub(crate) fn parse_fetch_args(args: &str) -> Result<(Sha, RefName), FetchError> {
    let (sha_hex, ref_str) = match args.split_once(' ') {
        Some((lhs, rhs)) if !(lhs.is_empty() || rhs.is_empty() || rhs.contains(' ')) => {
            (lhs, rhs)
        }
        _ => {
            return Err(FetchError::Parse {
                line: args.to_owned(),
            })
        }
    };
    // The SHA is validated before the ref, so a line with both invalid
    // reports the SHA error.
    let sha = Sha::from_hex(sha_hex)?;
    let ref_name = RefName::new(ref_str)?;
    Ok((sha, ref_name))
}
// Unit tests for fetch-command parsing, `FetchedRefs`, and the empty-batch
// fast path of `fetch_batch`.
#[cfg(test)]
mod tests {
use super::*;
// A syntactically valid 40-hex-digit SHA used as the fixture throughout.
const SHA: &str = "0123456789abcdef0123456789abcdef01234567";
// Happy path: `<sha> <ref>` round-trips into the typed pair.
#[test]
fn parse_fetch_args_accepts_canonical_form() {
let (sha, ref_name) = parse_fetch_args(&format!("{SHA} refs/heads/main")).unwrap();
assert_eq!(sha.to_string(), SHA);
assert_eq!(ref_name.as_str(), "refs/heads/main");
}
// No space at all means there is no ref half: shape error.
#[test]
fn parse_fetch_args_rejects_missing_ref() {
assert!(matches!(
parse_fetch_args(SHA),
Err(FetchError::Parse { .. })
));
}
// A trailing space yields an empty ref half: shape error.
#[test]
fn parse_fetch_args_rejects_empty_ref() {
assert!(matches!(
parse_fetch_args(&format!("{SHA} ")),
Err(FetchError::Parse { .. })
));
}
// Well-shaped line whose first half is not a valid SHA.
#[test]
fn parse_fetch_args_rejects_invalid_sha() {
assert!(matches!(
parse_fetch_args("notahex refs/heads/main"),
Err(FetchError::Sha(_))
));
}
// Well-shaped line whose second half is not an acceptable ref name.
#[test]
fn parse_fetch_args_rejects_invalid_ref() {
assert!(matches!(
parse_fetch_args(&format!("{SHA} refs/heads/.bad")),
Err(FetchError::Ref(_))
));
}
// A second space in the ref half is a shape error, reported before any
// ref-name validation runs.
#[test]
fn parse_fetch_args_rejects_extra_whitespace() {
assert!(matches!(
parse_fetch_args(&format!("{SHA} refs/heads/main extra")),
Err(FetchError::Parse { .. })
));
}
// Inserting the same SHA twice leaves exactly one entry.
#[test]
fn fetched_refs_dedupes_repeated_inserts() {
let refs = FetchedRefs::new();
let sha = Sha::from_hex(SHA).unwrap();
assert!(!refs.contains(&sha));
refs.insert(sha);
refs.insert(sha);
assert!(refs.contains(&sha));
assert_eq!(refs.snapshot().len(), 1);
}
// An empty command list must return `Ok(())` without touching the store.
#[tokio::test]
async fn fetch_batch_empty_cmds_short_circuits() {
use crate::object_store::mock::{Fault, MockStore};
use crate::protocol::BatchCtx;
let mock = Arc::new(MockStore::new());
// Arm one fault up front; the final assertion checks it is still pending.
// NOTE(review): presumably the fault is consumed by a list call on the
// mock, so "still pending" means no such call happened; confirm against
// `MockStore`.
mock.arm(Fault::AccessDeniedOnAnyList);
let repo_dir = tempfile::tempdir().expect("tempdir");
let ctx = BatchCtx {
store: Arc::clone(&mock) as Arc<dyn ObjectStore>,
prefix: Some("repo".into()),
repo_dir: Arc::new(repo_dir.path().to_path_buf()),
};
let result = fetch_batch(&ctx, Vec::new(), FetchedRefs::new(), None).await;
assert!(matches!(result, Ok(())));
assert_eq!(
mock.pending_faults(),
1,
"fetch_batch with empty cmds must make zero list calls",
);
}
}