1use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5
6use git_lfs_api::{
7 BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
8};
9use git_lfs_store::Store;
10use tokio::sync::Semaphore;
11use tokio::sync::mpsc::UnboundedSender;
12use tokio::task::JoinSet;
13
14use crate::basic;
15use crate::config::TransferConfig;
16use crate::error::{Report, TransferError};
17use crate::event::Event;
18
/// Direction of a transfer: fetching objects from the server
/// (`Download`) or pushing local objects to it (`Upload`).
#[derive(Debug, Clone, Copy)]
enum Dir {
    Download,
    Upload,
}
26
27impl From<Dir> for Operation {
28 fn from(d: Dir) -> Self {
29 match d {
30 Dir::Download => Operation::Download,
31 Dir::Upload => Operation::Upload,
32 }
33 }
34}
35
/// Queue that moves LFS objects between the local store and a remote
/// server, applying the batching, concurrency, retry, and URL-rewrite
/// settings from [`TransferConfig`].
#[derive(Clone)]
pub struct Transfer {
    api: ApiClient,         // client for the LFS batch API
    store: Arc<Store>,      // local object storage, shared across transfer tasks
    http: reqwest::Client,  // client for the actual content transfers
    config: TransferConfig, // batch size, concurrency, retry, and rewriter settings
}
45
46impl Transfer {
47 pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
51 Self::with_http_client(api, store, config, reqwest::Client::new())
52 }
53
54 pub fn with_http_client(
61 api: ApiClient,
62 store: Store,
63 config: TransferConfig,
64 http: reqwest::Client,
65 ) -> Self {
66 Self {
67 api,
68 store: Arc::new(store),
69 http,
70 config,
71 }
72 }
73
74 pub async fn download(
78 &self,
79 objects: Vec<ObjectSpec>,
80 r#ref: Option<Ref>,
81 events: Option<UnboundedSender<Event>>,
82 ) -> Result<Report, TransferError> {
83 self.run(Dir::Download, objects, r#ref, events).await
84 }
85
86 pub async fn upload(
90 &self,
91 objects: Vec<ObjectSpec>,
92 r#ref: Option<Ref>,
93 events: Option<UnboundedSender<Event>>,
94 ) -> Result<Report, TransferError> {
95 self.run(Dir::Upload, objects, r#ref, events).await
96 }
97
98 async fn run(
99 &self,
100 dir: Dir,
101 objects: Vec<ObjectSpec>,
102 r#ref: Option<Ref>,
103 events: Option<UnboundedSender<Event>>,
104 ) -> Result<Report, TransferError> {
105 if objects.is_empty() {
106 return Ok(Report::default());
107 }
108 let batch_size = self.config.batch_size.max(1);
113 if objects.len() > batch_size {
114 let mut report = Report::default();
115 for chunk in objects.chunks(batch_size) {
116 let chunk_report =
117 Box::pin(self.run(dir, chunk.to_vec(), r#ref.clone(), events.clone())).await?;
118 report.succeeded.extend(chunk_report.succeeded);
119 report.failed.extend(chunk_report.failed);
120 }
121 return Ok(report);
122 }
123
124 let req_sizes: std::collections::HashMap<String, u64> =
128 objects.iter().map(|o| (o.oid.clone(), o.size)).collect();
129
130 let mut objects = objects;
137 objects.sort_by_key(|o| std::cmp::Reverse(o.size));
138
139 let mut req = BatchRequest::new(dir.into(), objects);
140 if let Some(r) = r#ref {
141 req = req.with_ref(r);
142 }
143 let resp: BatchResponse = self.batch_with_retry(&req).await?;
144
145 if let Some(h) = resp.hash_algo.as_deref()
151 && !h.is_empty()
152 && !h.eq_ignore_ascii_case("sha256")
153 {
154 return Err(TransferError::UnsupportedHashAlgo(h.to_owned()));
155 }
156
157 let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
158 let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();
159
160 for mut obj in resp.objects {
161 if obj.size == 0
162 && let Some(s) = req_sizes.get(&obj.oid)
163 {
164 obj.size = *s;
165 }
166 if let Some(actions) = obj.actions.as_mut() {
167 if let Some(rewriter) = &self.config.url_rewriter
168 && let Some(d) = actions.download.as_mut()
169 {
170 d.href = rewriter(&d.href);
171 }
172 let up_rewriter = self
179 .config
180 .upload_url_rewriter
181 .as_ref()
182 .or(self.config.url_rewriter.as_ref());
183 if let Some(rewriter) = up_rewriter {
184 if let Some(u) = actions.upload.as_mut() {
185 u.href = rewriter(&u.href);
186 }
187 if let Some(v) = actions.verify.as_mut() {
188 v.href = rewriter(&v.href);
189 }
190 }
191 }
192 let permit_src = limit.clone();
193 let http = self.http.clone();
194 let store = self.store.clone();
195 let config = self.config.clone();
196 let events = events.clone();
197 join.spawn(async move {
198 let _permit = permit_src.acquire_owned().await.expect("semaphore live");
199 let oid = obj.oid.clone();
200 let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
201 (oid, result)
202 });
203 }
204
205 let mut report = Report::default();
206 while let Some(joined) = join.join_next().await {
207 let (oid, result) =
208 joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
209 match result {
210 Ok(()) => {
211 if let Some(s) = &events {
212 let _ = s.send(Event::Completed { oid: oid.clone() });
213 }
214 report.succeeded.push(oid);
215 }
216 Err(err) => {
217 if let Some(s) = &events {
218 let _ = s.send(Event::Failed {
219 oid: oid.clone(),
220 error: err.to_string(),
221 });
222 }
223 report.failed.push((oid, err));
224 }
225 }
226 }
227 Ok(report)
228 }
229
230 async fn batch_with_retry(&self, req: &BatchRequest) -> Result<BatchResponse, TransferError> {
241 let mut backoff = self.config.initial_backoff;
242 let mut retry_count: u32 = 0;
243 let mut last_err: Option<git_lfs_api::ApiError> = None;
244 for attempt in 0..self.config.max_attempts {
245 if trace_enabled() {
246 eprintln!("tq: sending batch of size {}", req.objects.len());
247 }
248 match self.api.batch(req).await {
249 Ok(resp) => return Ok(resp),
250 Err(e) => {
251 let retry = e.is_retryable() && attempt + 1 < self.config.max_attempts;
252 if !retry {
253 return Err(TransferError::BatchResponse(Box::new(e)));
254 }
255 let server_delay = e.retry_after();
256 let delay = server_delay.unwrap_or(backoff);
257 retry_count += 1;
258 if trace_enabled() {
259 let secs = delay.as_secs_f64();
260 for obj in &req.objects {
261 eprintln!(
266 "tq: enqueue retry #{retry_count} after {secs:.2}s for {:?} (size: {}): {e}",
267 obj.oid, obj.size
268 );
269 }
270 }
271 last_err = Some(e);
272 tokio::time::sleep(delay).await;
273 if server_delay.is_none() {
274 backoff = (backoff * 2).min(self.config.backoff_max);
275 }
276 }
277 }
278 }
279 Err(TransferError::BatchResponse(Box::new(
280 last_err.expect("loop ran at least once"),
281 )))
282 }
283}
284
/// Returns true when git-style tracing is requested: `GIT_TRACE` is set
/// to any non-empty value other than `"0"`.
fn trace_enabled() -> bool {
    match std::env::var_os("GIT_TRACE") {
        Some(value) => !value.is_empty() && value != "0",
        None => false,
    }
}
290
291async fn process_object(
295 dir: Dir,
296 http: &reqwest::Client,
297 store: Arc<Store>,
298 config: &TransferConfig,
299 obj: ObjectResult,
300 events: Option<&UnboundedSender<Event>>,
301) -> Result<(), TransferError> {
302 if let Some(err) = obj.error {
303 return Err(TransferError::ServerObject(err));
304 }
305
306 if let Some(s) = events {
307 let _ = s.send(Event::Started {
308 oid: obj.oid.clone(),
309 size: obj.size,
310 });
311 }
312
313 match (dir, &obj.actions) {
314 (Dir::Download, Some(actions)) => {
315 let action = actions
316 .download
317 .as_ref()
318 .ok_or(TransferError::NoDownloadAction)?;
319 check_not_expired("download", action)?;
320 with_retry(config, &obj.oid, obj.size, || async {
321 basic::download(http, store.clone(), &obj.oid, obj.size, action, events)
322 .await
323 .map(|_| ())
324 })
325 .await
326 }
327 (Dir::Download, None) => Err(TransferError::NoDownloadAction),
328 (Dir::Upload, Some(actions)) => {
329 if let Some(upload) = actions.upload.as_ref() {
330 check_not_expired("upload", upload)?;
331 }
332 if let Some(verify) = actions.verify.as_ref() {
333 check_not_expired("verify", verify)?;
334 }
335 with_retry(config, &obj.oid, obj.size, || async {
336 basic::upload(
337 http,
338 store.clone(),
339 &obj.oid,
340 obj.size,
341 actions,
342 config.detect_content_type,
343 events,
344 )
345 .await
346 })
347 .await
348 }
349 (Dir::Upload, None) => {
350 Ok(())
352 }
353 }
354}
355
356const ACTION_EXPIRATION_BUFFER: Duration = Duration::from_secs(5);
359
360fn check_not_expired(rel: &str, action: &git_lfs_api::Action) -> Result<(), TransferError> {
361 if action.is_expired_within(SystemTime::now(), ACTION_EXPIRATION_BUFFER) {
362 return Err(TransferError::ActionExpired {
363 rel: rel.to_owned(),
364 });
365 }
366 Ok(())
367}
368
369async fn with_retry<F, Fut>(
382 config: &TransferConfig,
383 oid: &str,
384 size: u64,
385 mut op: F,
386) -> Result<(), TransferError>
387where
388 F: FnMut() -> Fut,
389 Fut: std::future::Future<Output = Result<(), TransferError>>,
390{
391 let mut backoff = config.initial_backoff;
392 let mut retry_count: u32 = 0;
393 let mut last_err: Option<TransferError> = None;
394 for attempt in 0..config.max_attempts {
395 match op().await {
396 Ok(()) => return Ok(()),
397 Err(e) => {
398 let retry = e.is_retryable() && attempt + 1 < config.max_attempts;
399 if !retry {
400 last_err = Some(e);
401 break;
402 }
403 let delay = e.retry_after().unwrap_or(backoff);
404 retry_count += 1;
405 emit_retry_trace(oid, size, retry_count, delay, &e);
406 last_err = Some(e);
407 tokio::time::sleep(delay).await;
408 if last_err
412 .as_ref()
413 .and_then(TransferError::retry_after)
414 .is_none()
415 {
416 backoff = (backoff * 2).min(config.backoff_max);
417 }
418 }
419 }
420 }
421 Err(last_err.expect("loop ran at least once"))
422}
423
424fn emit_retry_trace(oid: &str, size: u64, count: u32, delay: Duration, err: &TransferError) {
436 if !trace_enabled() {
437 return;
438 }
439 let secs = delay.as_secs_f64();
440 if err.retry_after().is_some() {
441 eprintln!("tq: retrying object {oid} after {secs:.2}s");
442 } else {
443 eprintln!("tq: retrying object {oid}: {err}");
444 }
445 eprintln!("tq: enqueue retry #{count} after {secs:.2}s for {oid:?} (size: {size}): {err}");
448}