git_remote_object_store/protocol/
fetch.rs1use std::collections::HashSet;
20use std::num::NonZeroU32;
21use std::path::Path;
22use std::sync::{Arc, Mutex};
23
24use gix_hash::ObjectId;
25use tokio::sync::Semaphore;
26use tokio::task::{JoinError, JoinSet};
27use tracing::debug;
28
29use crate::git::{self, GitError, RefName, RefNameError, Sha, ShaError};
30use crate::keys;
31use crate::object_store::{GetOpts, ObjectStore, ObjectStoreError};
32
/// Maximum number of `fetch` commands from one `fetch_batch` call that may be
/// downloading/unbundling at the same time (each task holds one semaphore
/// permit for its whole `fetch_one` run).
pub(crate) const MAX_FETCH_CONCURRENCY: usize = 8;
35
36#[derive(Debug, thiserror::Error)]
38pub enum FetchError {
39 #[error("invalid fetch command {line:?}: expected `<sha> <ref>`")]
41 Parse {
42 line: String,
44 },
45
46 #[error("invalid SHA in fetch command: {0}")]
48 Sha(#[from] ShaError),
49
50 #[error("invalid ref in fetch command: {0}")]
52 Ref(#[from] RefNameError),
53
54 #[error("object-store error during fetch: {0}")]
56 Store(#[from] ObjectStoreError),
57
58 #[error("local I/O error during fetch: {0}")]
60 Io(#[from] std::io::Error),
61
62 #[error("git error during fetch: {0}")]
64 Git(#[from] GitError),
65
66 #[error("fetch task join failed: {0}")]
68 Join(#[from] JoinError),
69
70 #[error("packchain engine error during fetch: {0}")]
75 Packchain(#[from] crate::packchain::PackchainError),
76}
77
78#[derive(Clone, Default)]
84pub(crate) struct FetchedRefs {
85 inner: Arc<Mutex<HashSet<Sha>>>,
86}
87
88impl FetchedRefs {
89 pub(crate) fn new() -> Self {
90 Self::default()
91 }
92
93 pub(crate) fn contains(&self, sha: &Sha) -> bool {
94 self.inner
100 .lock()
101 .unwrap_or_else(std::sync::PoisonError::into_inner)
102 .contains(sha)
103 }
104
105 pub(crate) fn insert(&self, sha: Sha) {
106 self.inner
107 .lock()
108 .unwrap_or_else(std::sync::PoisonError::into_inner)
109 .insert(sha);
110 }
111
112 #[cfg(test)]
114 pub(crate) fn snapshot(&self) -> HashSet<Sha> {
115 self.inner
116 .lock()
117 .unwrap_or_else(std::sync::PoisonError::into_inner)
118 .clone()
119 }
120}
121
122#[derive(Clone, Default)]
129pub(crate) struct ShallowBoundaries {
130 inner: Arc<Mutex<HashSet<ObjectId>>>,
131}
132
133impl ShallowBoundaries {
134 pub(crate) fn new() -> Self {
135 Self::default()
136 }
137
138 pub(crate) fn extend(&self, ids: impl IntoIterator<Item = ObjectId>) {
139 let mut guard = self
140 .inner
141 .lock()
142 .unwrap_or_else(std::sync::PoisonError::into_inner);
143 guard.extend(ids);
144 }
145
146 pub(crate) fn drain(&self) -> Vec<ObjectId> {
147 let mut guard = self
148 .inner
149 .lock()
150 .unwrap_or_else(std::sync::PoisonError::into_inner);
151 guard.drain().collect()
152 }
153}
154
155pub(crate) async fn fetch_batch(
169 ctx: &super::BatchCtx,
170 cmds: Vec<String>,
171 fetched_refs: FetchedRefs,
172 depth: Option<NonZeroU32>,
173) -> Result<(), FetchError> {
174 if cmds.is_empty() {
175 return Ok(());
176 }
177 debug!(
178 count = cmds.len(),
179 depth = ?depth,
180 "fetching bundles in parallel"
181 );
182
183 let semaphore = Arc::new(Semaphore::new(MAX_FETCH_CONCURRENCY));
184 let mut tasks: JoinSet<Result<(), FetchError>> = JoinSet::new();
185 let prefix = ctx.prefix.clone();
187 let boundaries = ShallowBoundaries::new();
188
189 for cmd in cmds {
190 let store = Arc::clone(&ctx.store);
191 let semaphore = Arc::clone(&semaphore);
192 let prefix = prefix.clone();
193 let repo_dir = Arc::clone(&ctx.repo_dir);
194 let fetched_refs = fetched_refs.clone();
195 let boundaries = boundaries.clone();
196 tasks.spawn(async move {
197 let _permit = semaphore
198 .acquire_owned()
199 .await
200 .expect("fetch semaphore is owned by this batch and never closed");
201 let (sha, ref_name) = parse_fetch_args(&cmd)?;
202 fetch_one(FetchOneCtx {
203 store: store.as_ref(),
204 prefix: prefix.as_deref(),
205 repo_dir: repo_dir.as_path(),
206 sha,
207 ref_name: &ref_name,
208 fetched_refs: &fetched_refs,
209 depth,
210 boundaries: &boundaries,
211 })
212 .await
213 });
214 }
215
216 let mut first_err: Option<FetchError> = None;
222 while let Some(joined) = tasks.join_next().await {
223 let res: Result<(), FetchError> = joined.unwrap_or_else(|je| Err(je.into()));
227 if let Err(err) = res {
228 if first_err.is_none() {
229 first_err = Some(err);
230 } else {
231 debug!(error = %err, "additional bundle fetch task error (first error already captured)");
232 }
233 }
234 }
235
236 if first_err.is_none() && depth.is_some() {
247 let collected = boundaries.drain();
248 let repo_dir = ctx.repo_dir.as_path().to_path_buf();
249 tokio::task::spawn_blocking(move || git::write_shallow_file(&repo_dir, &collected))
250 .await
251 .map_err(FetchError::from)?
252 .map_err(FetchError::from)?;
253 }
254
255 first_err.map_or(Ok(()), Err)
256}
257
258struct FetchOneCtx<'a> {
263 store: &'a dyn ObjectStore,
264 prefix: Option<&'a str>,
265 repo_dir: &'a Path,
266 sha: Sha,
267 ref_name: &'a RefName,
268 fetched_refs: &'a FetchedRefs,
269 depth: Option<NonZeroU32>,
270 boundaries: &'a ShallowBoundaries,
271}
272
273async fn fetch_one(ctx: FetchOneCtx<'_>) -> Result<(), FetchError> {
274 let FetchOneCtx {
275 store,
276 prefix,
277 repo_dir,
278 sha,
279 ref_name,
280 fetched_refs,
281 depth,
282 boundaries,
283 } = ctx;
284
285 if fetched_refs.contains(&sha) {
286 debug!(%sha, ref_name = %ref_name, "skipping fetch: already fetched in this session");
287 } else {
288 let key = keys::bundle_key(prefix, ref_name, sha);
289 let temp_dir = tempfile::Builder::new()
290 .prefix("git_remote_object_store_fetch_")
291 .tempdir()?;
292 let bundle_path = temp_dir.path().join(format!("{sha}.bundle"));
293 debug!(%sha, ref_name = %ref_name, key = %key, "downloading bundle");
294 store
295 .get_to_file(&key, &bundle_path, GetOpts::default())
296 .await?;
297 git::unbundle_at(repo_dir, temp_dir.path(), sha).await?;
298 fetched_refs.insert(sha);
299 }
300
301 if let Some(depth) = depth {
307 let repo_dir = repo_dir.to_path_buf();
308 let ids = tokio::task::spawn_blocking(move || {
309 let repo = gix::open(&repo_dir).map_err(GitError::from)?;
310 git::shallow_boundaries(&repo, sha, depth)
311 })
312 .await
313 .map_err(FetchError::from)?
314 .map_err(FetchError::from)?;
315 boundaries.extend(ids);
316 }
317 Ok(())
318}
319
320pub(crate) fn parse_fetch_args(args: &str) -> Result<(Sha, RefName), FetchError> {
323 let parse_err = || FetchError::Parse {
324 line: args.to_owned(),
325 };
326 let (sha, ref_name) = args.split_once(' ').ok_or_else(parse_err)?;
327 if sha.is_empty() || ref_name.is_empty() || ref_name.contains(' ') {
328 return Err(parse_err());
329 }
330 Ok((Sha::from_hex(sha)?, RefName::new(ref_name)?))
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    const SHA: &str = "0123456789abcdef0123456789abcdef01234567";
    const OTHER_SHA: &str = "fedcba9876543210fedcba9876543210fedcba98";

    /// Builds a gix `ObjectId` from a 40-char hex literal for boundary tests.
    fn oid(hex: &str) -> ObjectId {
        ObjectId::from_hex(hex.as_bytes()).expect("valid hex object id")
    }

    #[test]
    fn parse_fetch_args_accepts_canonical_form() {
        let (sha, ref_name) = parse_fetch_args(&format!("{SHA} refs/heads/main")).unwrap();
        assert_eq!(sha.to_string(), SHA);
        assert_eq!(ref_name.as_str(), "refs/heads/main");
    }

    #[test]
    fn parse_fetch_args_rejects_missing_ref() {
        assert!(matches!(
            parse_fetch_args(SHA),
            Err(FetchError::Parse { .. })
        ));
    }

    #[test]
    fn parse_fetch_args_rejects_empty_ref() {
        assert!(matches!(
            parse_fetch_args(&format!("{SHA} ")),
            Err(FetchError::Parse { .. })
        ));
    }

    #[test]
    fn parse_fetch_args_rejects_invalid_sha() {
        assert!(matches!(
            parse_fetch_args("notahex refs/heads/main"),
            Err(FetchError::Sha(_))
        ));
    }

    #[test]
    fn parse_fetch_args_rejects_invalid_ref() {
        assert!(matches!(
            parse_fetch_args(&format!("{SHA} refs/heads/.bad")),
            Err(FetchError::Ref(_))
        ));
    }

    #[test]
    fn parse_fetch_args_rejects_extra_whitespace() {
        assert!(matches!(
            parse_fetch_args(&format!("{SHA} refs/heads/main extra")),
            Err(FetchError::Parse { .. })
        ));
    }

    #[test]
    fn parse_fetch_args_rejects_empty_line() {
        assert!(matches!(parse_fetch_args(""), Err(FetchError::Parse { .. })));
    }

    #[test]
    fn parse_fetch_args_rejects_leading_space() {
        assert!(matches!(
            parse_fetch_args(&format!(" {SHA} refs/heads/main")),
            Err(FetchError::Parse { .. })
        ));
    }

    #[test]
    fn parse_fetch_args_rejects_double_space_separator() {
        assert!(matches!(
            parse_fetch_args(&format!("{SHA}  refs/heads/main")),
            Err(FetchError::Parse { .. })
        ));
    }

    #[test]
    fn parse_fetch_args_parse_error_keeps_original_line() {
        match parse_fetch_args("garbage") {
            Err(FetchError::Parse { line }) => assert_eq!(line, "garbage"),
            _ => panic!("expected FetchError::Parse"),
        }
    }

    #[test]
    fn fetched_refs_dedupes_repeated_inserts() {
        let refs = FetchedRefs::new();
        let sha = Sha::from_hex(SHA).unwrap();
        assert!(!refs.contains(&sha));
        refs.insert(sha);
        refs.insert(sha);
        assert!(refs.contains(&sha));
        assert_eq!(refs.snapshot().len(), 1);
    }

    #[test]
    fn fetched_refs_clones_share_state() {
        let original = FetchedRefs::new();
        let sha = Sha::from_hex(SHA).unwrap();
        original.clone().insert(sha);
        assert!(original.contains(&sha));
    }

    #[test]
    fn shallow_boundaries_dedupe_and_drain_empties_the_set() {
        let boundaries = ShallowBoundaries::new();
        boundaries.extend([oid(SHA), oid(OTHER_SHA)]);
        boundaries.extend([oid(SHA)]);
        let drained = boundaries.drain();
        assert_eq!(drained.len(), 2);
        assert!(drained.contains(&oid(SHA)));
        assert!(drained.contains(&oid(OTHER_SHA)));
        assert!(boundaries.drain().is_empty());
    }

    #[test]
    fn shallow_boundaries_clones_share_state() {
        let boundaries = ShallowBoundaries::new();
        boundaries.clone().extend([oid(SHA)]);
        assert_eq!(boundaries.drain(), vec![oid(SHA)]);
    }

    #[tokio::test]
    async fn fetch_batch_empty_cmds_short_circuits() {
        use crate::object_store::mock::{Fault, MockStore};
        use crate::protocol::BatchCtx;
        let mock = Arc::new(MockStore::new());
        // NOTE(review): as written, nothing in this module lists the store, so
        // this armed list fault is a tripwire for future changes rather than a
        // check the current code could trip.
        mock.arm(Fault::AccessDeniedOnAnyList);
        let repo_dir = tempfile::tempdir().expect("tempdir");
        let ctx = BatchCtx {
            store: Arc::clone(&mock) as Arc<dyn ObjectStore>,
            prefix: Some("repo".into()),
            repo_dir: Arc::new(repo_dir.path().to_path_buf()),
        };
        let result = fetch_batch(&ctx, Vec::new(), FetchedRefs::new(), None).await;
        assert!(matches!(result, Ok(())));
        assert_eq!(
            mock.pending_faults(),
            1,
            "fetch_batch with empty cmds must make zero list calls",
        );
    }

    #[tokio::test]
    async fn fetch_batch_rejects_malformed_command() {
        use crate::object_store::mock::MockStore;
        use crate::protocol::BatchCtx;
        let repo_dir = tempfile::tempdir().expect("tempdir");
        let ctx = BatchCtx {
            store: Arc::new(MockStore::new()) as Arc<dyn ObjectStore>,
            prefix: Some("repo".into()),
            repo_dir: Arc::new(repo_dir.path().to_path_buf()),
        };
        let result = fetch_batch(
            &ctx,
            vec!["not-a-fetch-line".to_owned()],
            FetchedRefs::new(),
            None,
        )
        .await;
        assert!(matches!(result, Err(FetchError::Parse { .. })));
    }
}