1use std::sync::Arc;
4use std::time::Duration;
5
6use git_lfs_api::{
7 BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
8};
9use git_lfs_store::Store;
10use tokio::sync::Semaphore;
11use tokio::sync::mpsc::UnboundedSender;
12use tokio::task::JoinSet;
13
14use crate::basic;
15use crate::config::TransferConfig;
16use crate::error::{Report, TransferError};
17use crate::event::Event;
18
/// Direction of a transfer: pulling objects from the server into the local
/// store, or pushing local objects up to the server.
///
/// Converted into the wire-level batch `operation` via `From<Dir> for
/// Operation` below.
#[derive(Debug, Clone, Copy)]
enum Dir {
    Download,
    Upload,
}
26
27impl From<Dir> for Operation {
28 fn from(d: Dir) -> Self {
29 match d {
30 Dir::Download => Operation::Download,
31 Dir::Upload => Operation::Upload,
32 }
33 }
34}
35
/// LFS transfer queue: resolves objects through the batch API, then moves
/// their content over HTTP(S), bounded by the configured concurrency.
///
/// Cheap to clone: the store is behind an `Arc` and `reqwest::Client` is
/// internally reference-counted.
#[derive(Clone)]
pub struct Transfer {
    // Client for the LFS batch endpoint.
    api: ApiClient,
    // Local object store, shared with spawned transfer tasks.
    store: Arc<Store>,
    // HTTP client used for the actual content uploads/downloads.
    http: reqwest::Client,
    // Tuning knobs: batch size, concurrency, retry/backoff, URL rewriter.
    config: TransferConfig,
}
45
46impl Transfer {
47 pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
51 Self::with_http_client(api, store, config, reqwest::Client::new())
52 }
53
54 pub fn with_http_client(
55 api: ApiClient,
56 store: Store,
57 config: TransferConfig,
58 http: reqwest::Client,
59 ) -> Self {
60 Self {
61 api,
62 store: Arc::new(store),
63 http,
64 config,
65 }
66 }
67
68 pub async fn download(
72 &self,
73 objects: Vec<ObjectSpec>,
74 r#ref: Option<Ref>,
75 events: Option<UnboundedSender<Event>>,
76 ) -> Result<Report, TransferError> {
77 self.run(Dir::Download, objects, r#ref, events).await
78 }
79
80 pub async fn upload(
84 &self,
85 objects: Vec<ObjectSpec>,
86 r#ref: Option<Ref>,
87 events: Option<UnboundedSender<Event>>,
88 ) -> Result<Report, TransferError> {
89 self.run(Dir::Upload, objects, r#ref, events).await
90 }
91
92 async fn run(
93 &self,
94 dir: Dir,
95 objects: Vec<ObjectSpec>,
96 r#ref: Option<Ref>,
97 events: Option<UnboundedSender<Event>>,
98 ) -> Result<Report, TransferError> {
99 if objects.is_empty() {
100 return Ok(Report::default());
101 }
102 let batch_size = self.config.batch_size.max(1);
107 if objects.len() > batch_size {
108 let mut report = Report::default();
109 for chunk in objects.chunks(batch_size) {
110 let chunk_report =
111 Box::pin(self.run(dir, chunk.to_vec(), r#ref.clone(), events.clone())).await?;
112 report.succeeded.extend(chunk_report.succeeded);
113 report.failed.extend(chunk_report.failed);
114 }
115 return Ok(report);
116 }
117
118 let req_sizes: std::collections::HashMap<String, u64> =
122 objects.iter().map(|o| (o.oid.clone(), o.size)).collect();
123
124 let mut objects = objects;
131 objects.sort_by_key(|o| std::cmp::Reverse(o.size));
132
133 let mut req = BatchRequest::new(dir.into(), objects);
134 if let Some(r) = r#ref {
135 req = req.with_ref(r);
136 }
137 let resp: BatchResponse = self.batch_with_retry(&req).await?;
138
139 if let Some(h) = resp.hash_algo.as_deref()
145 && !h.is_empty()
146 && !h.eq_ignore_ascii_case("sha256")
147 {
148 return Err(TransferError::UnsupportedHashAlgo(h.to_owned()));
149 }
150
151 let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
152 let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();
153
154 for mut obj in resp.objects {
155 if obj.size == 0
156 && let Some(s) = req_sizes.get(&obj.oid)
157 {
158 obj.size = *s;
159 }
160 if let Some(rewriter) = &self.config.url_rewriter
161 && let Some(actions) = obj.actions.as_mut()
162 {
163 for action in [
164 actions.download.as_mut(),
165 actions.upload.as_mut(),
166 actions.verify.as_mut(),
167 ]
168 .into_iter()
169 .flatten()
170 {
171 action.href = rewriter(&action.href);
172 }
173 }
174 let permit_src = limit.clone();
175 let http = self.http.clone();
176 let store = self.store.clone();
177 let config = self.config.clone();
178 let events = events.clone();
179 join.spawn(async move {
180 let _permit = permit_src.acquire_owned().await.expect("semaphore live");
181 let oid = obj.oid.clone();
182 let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
183 (oid, result)
184 });
185 }
186
187 let mut report = Report::default();
188 while let Some(joined) = join.join_next().await {
189 let (oid, result) =
190 joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
191 match result {
192 Ok(()) => {
193 if let Some(s) = &events {
194 let _ = s.send(Event::Completed { oid: oid.clone() });
195 }
196 report.succeeded.push(oid);
197 }
198 Err(err) => {
199 if let Some(s) = &events {
200 let _ = s.send(Event::Failed {
201 oid: oid.clone(),
202 error: err.to_string(),
203 });
204 }
205 report.failed.push((oid, err));
206 }
207 }
208 }
209 Ok(report)
210 }
211
212 async fn batch_with_retry(&self, req: &BatchRequest) -> Result<BatchResponse, TransferError> {
223 let mut backoff = self.config.initial_backoff;
224 let mut retry_count: u32 = 0;
225 let mut last_err: Option<git_lfs_api::ApiError> = None;
226 for attempt in 0..self.config.max_attempts {
227 if trace_enabled() {
228 eprintln!("tq: sending batch of size {}", req.objects.len());
229 }
230 match self.api.batch(req).await {
231 Ok(resp) => return Ok(resp),
232 Err(e) => {
233 let retry = e.is_retryable() && attempt + 1 < self.config.max_attempts;
234 if !retry {
235 return Err(TransferError::BatchResponse(Box::new(e)));
236 }
237 let server_delay = e.retry_after();
238 let delay = server_delay.unwrap_or(backoff);
239 retry_count += 1;
240 if trace_enabled() {
241 let secs = delay.as_secs_f64();
242 for obj in &req.objects {
243 eprintln!(
248 "tq: enqueue retry #{retry_count} after {secs:.2}s for {:?} (size: {}): {e}",
249 obj.oid, obj.size
250 );
251 }
252 }
253 last_err = Some(e);
254 tokio::time::sleep(delay).await;
255 if server_delay.is_none() {
256 backoff = (backoff * 2).min(self.config.backoff_max);
257 }
258 }
259 }
260 }
261 Err(TransferError::BatchResponse(Box::new(
262 last_err.expect("loop ran at least once"),
263 )))
264 }
265}
266
/// Mirrors git's convention: tracing is on when `GIT_TRACE` is set to any
/// non-empty value other than "0".
fn trace_enabled() -> bool {
    match std::env::var_os("GIT_TRACE") {
        Some(v) => !v.is_empty() && v != "0",
        None => false,
    }
}
272
273async fn process_object(
277 dir: Dir,
278 http: &reqwest::Client,
279 store: Arc<Store>,
280 config: &TransferConfig,
281 obj: ObjectResult,
282 events: Option<&UnboundedSender<Event>>,
283) -> Result<(), TransferError> {
284 if let Some(err) = obj.error {
285 return Err(TransferError::ServerObject(err));
286 }
287
288 if let Some(s) = events {
289 let _ = s.send(Event::Started {
290 oid: obj.oid.clone(),
291 size: obj.size,
292 });
293 }
294
295 match (dir, &obj.actions) {
296 (Dir::Download, Some(actions)) => {
297 let action = actions
298 .download
299 .as_ref()
300 .ok_or(TransferError::NoDownloadAction)?;
301 with_retry(config, &obj.oid, obj.size, || async {
302 basic::download(http, store.clone(), &obj.oid, obj.size, action, events)
303 .await
304 .map(|_| ())
305 })
306 .await
307 }
308 (Dir::Download, None) => Err(TransferError::NoDownloadAction),
309 (Dir::Upload, Some(actions)) => {
310 with_retry(config, &obj.oid, obj.size, || async {
311 basic::upload(
312 http,
313 store.clone(),
314 &obj.oid,
315 obj.size,
316 actions,
317 config.detect_content_type,
318 events,
319 )
320 .await
321 })
322 .await
323 }
324 (Dir::Upload, None) => {
325 Ok(())
327 }
328 }
329}
330
331async fn with_retry<F, Fut>(
344 config: &TransferConfig,
345 oid: &str,
346 size: u64,
347 mut op: F,
348) -> Result<(), TransferError>
349where
350 F: FnMut() -> Fut,
351 Fut: std::future::Future<Output = Result<(), TransferError>>,
352{
353 let mut backoff = config.initial_backoff;
354 let mut retry_count: u32 = 0;
355 let mut last_err: Option<TransferError> = None;
356 for attempt in 0..config.max_attempts {
357 match op().await {
358 Ok(()) => return Ok(()),
359 Err(e) => {
360 let retry = e.is_retryable() && attempt + 1 < config.max_attempts;
361 if !retry {
362 last_err = Some(e);
363 break;
364 }
365 let delay = e.retry_after().unwrap_or(backoff);
366 retry_count += 1;
367 emit_retry_trace(oid, size, retry_count, delay, &e);
368 last_err = Some(e);
369 tokio::time::sleep(delay).await;
370 if last_err
374 .as_ref()
375 .and_then(TransferError::retry_after)
376 .is_none()
377 {
378 backoff = (backoff * 2).min(config.backoff_max);
379 }
380 }
381 }
382 }
383 Err(last_err.expect("loop ran at least once"))
384}
385
386fn emit_retry_trace(oid: &str, size: u64, count: u32, delay: Duration, err: &TransferError) {
398 if !trace_enabled() {
399 return;
400 }
401 let secs = delay.as_secs_f64();
402 if err.retry_after().is_some() {
403 eprintln!("tq: retrying object {oid} after {secs:.2}s");
404 } else {
405 eprintln!("tq: retrying object {oid}: {err}");
406 }
407 eprintln!("tq: enqueue retry #{count} after {secs:.2}s for {oid:?} (size: {size}): {err}");
410}