1use std::sync::Arc;
4
5use git_lfs_api::{
6 BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
7};
8use git_lfs_store::Store;
9use tokio::sync::Semaphore;
10use tokio::sync::mpsc::UnboundedSender;
11use tokio::task::JoinSet;
12
13use crate::basic;
14use crate::config::TransferConfig;
15use crate::error::{Report, TransferError};
16use crate::event::Event;
17
/// Direction of a transfer run.
///
/// Local, `Copy`-able mirror of the two [`Operation`] variants used here;
/// converted into the API type via the `From<Dir> for Operation` impl below.
#[derive(Debug, Clone, Copy)]
enum Dir {
    Download,
    Upload,
}
25
26impl From<Dir> for Operation {
27 fn from(d: Dir) -> Self {
28 match d {
29 Dir::Download => Operation::Download,
30 Dir::Upload => Operation::Upload,
31 }
32 }
33}
34
/// Transfer queue: negotiates batches of LFS objects with the server's batch
/// API, then moves object content between the local store and the returned
/// action URLs via the `basic` transfer functions.
#[derive(Clone)]
pub struct Transfer {
    // Client for the LFS batch API.
    api: ApiClient,
    // Local object store, shared (via `Arc`) with every spawned transfer task.
    store: Arc<Store>,
    // HTTP client used for the actual content downloads/uploads.
    http: reqwest::Client,
    // Batch-size, concurrency, retry/backoff, and URL-rewrite settings.
    config: TransferConfig,
}
44
impl Transfer {
    /// Creates a transfer queue with a freshly constructed default
    /// `reqwest::Client`.
    pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
        Self::with_http_client(api, store, config, reqwest::Client::new())
    }

    /// Creates a transfer queue using a caller-supplied HTTP client
    /// (e.g. one preconfigured with proxy or TLS settings).
    pub fn with_http_client(
        api: ApiClient,
        store: Store,
        config: TransferConfig,
        http: reqwest::Client,
    ) -> Self {
        Self {
            api,
            store: Arc::new(store),
            http,
            config,
        }
    }

    /// Downloads `objects` into the store.
    ///
    /// Per-object failures are collected in the returned [`Report`]; only
    /// batch-level problems (API errors, an unsupported hash algorithm,
    /// a panicked worker task) surface as `Err`. Progress is reported on
    /// `events` when a sender is provided.
    pub async fn download(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Download, objects, r#ref, events).await
    }

    /// Uploads `objects` from the store.
    ///
    /// Same reporting contract as [`Transfer::download`]: per-object failures
    /// land in the [`Report`], batch-level problems in `Err`.
    pub async fn upload(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Upload, objects, r#ref, events).await
    }

    /// Shared driver for `download`/`upload`: issues the batch API request,
    /// then processes every returned object concurrently and folds the
    /// outcomes into a [`Report`].
    async fn run(
        &self,
        dir: Dir,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        if objects.is_empty() {
            return Ok(Report::default());
        }
        // Recursively split oversized requests into `batch_size` chunks and
        // merge the per-chunk reports. `max(1)` guards against a zero batch
        // size; `Box::pin` is required because `run` awaits itself.
        let batch_size = self.config.batch_size.max(1);
        if objects.len() > batch_size {
            let mut report = Report::default();
            for chunk in objects.chunks(batch_size) {
                let chunk_report =
                    Box::pin(self.run(dir, chunk.to_vec(), r#ref.clone(), events.clone())).await?;
                report.succeeded.extend(chunk_report.succeeded);
                report.failed.extend(chunk_report.failed);
            }
            return Ok(report);
        }

        // Remember each requested oid's size so it can be backfilled below
        // when the batch response reports a size of 0 for that object.
        let req_sizes: std::collections::HashMap<String, u64> =
            objects.iter().map(|o| (o.oid.clone(), o.size)).collect();

        // Send the largest objects first (descending size) — presumably so
        // the long transfers start early; confirm intent with the original
        // author before relying on this ordering.
        let mut objects = objects;
        objects.sort_by_key(|o| std::cmp::Reverse(o.size));

        let mut req = BatchRequest::new(dir.into(), objects);
        if let Some(r) = r#ref {
            req = req.with_ref(r);
        }
        // Follow git's GIT_TRACE convention: any non-empty value other than
        // "0" enables trace output on stderr.
        if std::env::var_os("GIT_TRACE").is_some_and(|v| !v.is_empty() && v != "0") {
            eprintln!("tq: sending batch of size {}", req.objects.len());
        }
        let resp: BatchResponse = self.api.batch(&req).await?;

        // Only SHA-256 content addressing is supported. A missing or empty
        // `hash_algo` is accepted (treated as the sha256 default).
        if let Some(h) = resp.hash_algo.as_deref()
            && !h.is_empty()
            && !h.eq_ignore_ascii_case("sha256")
        {
            return Err(TransferError::UnsupportedHashAlgo(h.to_owned()));
        }

        // Cap the number of in-flight transfers; each spawned task holds a
        // permit for its whole lifetime.
        let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
        let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();

        for mut obj in resp.objects {
            // Backfill a zero size from the original request (see req_sizes).
            if obj.size == 0
                && let Some(s) = req_sizes.get(&obj.oid)
            {
                obj.size = *s;
            }
            // Apply the configured URL rewriter to every action href the
            // server returned (download, upload, and verify alike).
            if let Some(rewriter) = &self.config.url_rewriter
                && let Some(actions) = obj.actions.as_mut()
            {
                for action in [
                    actions.download.as_mut(),
                    actions.upload.as_mut(),
                    actions.verify.as_mut(),
                ]
                .into_iter()
                .flatten()
                {
                    action.href = rewriter(&action.href);
                }
            }
            // Clone the shared handles each task needs; the task itself
            // acquires its concurrency permit before doing any work.
            let permit_src = limit.clone();
            let http = self.http.clone();
            let store = self.store.clone();
            let config = self.config.clone();
            let events = events.clone();
            join.spawn(async move {
                // The semaphore is never closed, so acquire only fails on a
                // broken invariant.
                let _permit = permit_src.acquire_owned().await.expect("semaphore live");
                let oid = obj.oid.clone();
                let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
                (oid, result)
            });
        }

        // Collect results as tasks finish. A panicked or cancelled task is
        // surfaced as an I/O error and aborts the whole run; ordinary
        // per-object failures are recorded and reported via events.
        let mut report = Report::default();
        while let Some(joined) = join.join_next().await {
            let (oid, result) =
                joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
            match result {
                Ok(()) => {
                    if let Some(s) = &events {
                        let _ = s.send(Event::Completed { oid: oid.clone() });
                    }
                    report.succeeded.push(oid);
                }
                Err(err) => {
                    if let Some(s) = &events {
                        let _ = s.send(Event::Failed {
                            oid: oid.clone(),
                            error: err.to_string(),
                        });
                    }
                    report.failed.push((oid, err));
                }
            }
        }
        Ok(report)
    }
}
218
219async fn process_object(
223 dir: Dir,
224 http: &reqwest::Client,
225 store: Arc<Store>,
226 config: &TransferConfig,
227 obj: ObjectResult,
228 events: Option<&UnboundedSender<Event>>,
229) -> Result<(), TransferError> {
230 if let Some(err) = obj.error {
231 return Err(TransferError::ServerObject(err));
232 }
233
234 if let Some(s) = events {
235 let _ = s.send(Event::Started {
236 oid: obj.oid.clone(),
237 size: obj.size,
238 });
239 }
240
241 match (dir, &obj.actions) {
242 (Dir::Download, Some(actions)) => {
243 let action = actions
244 .download
245 .as_ref()
246 .ok_or(TransferError::NoDownloadAction)?;
247 with_retry(config, || async {
248 basic::download(http, store.clone(), &obj.oid, action, events)
249 .await
250 .map(|_| ())
251 })
252 .await
253 }
254 (Dir::Download, None) => Err(TransferError::NoDownloadAction),
255 (Dir::Upload, Some(actions)) => {
256 with_retry(config, || async {
257 basic::upload(http, store.clone(), &obj.oid, obj.size, actions, events).await
258 })
259 .await
260 }
261 (Dir::Upload, None) => {
262 Ok(())
264 }
265 }
266}
267
268async fn with_retry<F, Fut>(config: &TransferConfig, mut op: F) -> Result<(), TransferError>
271where
272 F: FnMut() -> Fut,
273 Fut: std::future::Future<Output = Result<(), TransferError>>,
274{
275 let mut backoff = config.initial_backoff;
276 let mut last_err: Option<TransferError> = None;
277 for attempt in 0..config.max_attempts {
278 match op().await {
279 Ok(()) => return Ok(()),
280 Err(e) => {
281 let retry = e.is_retryable() && attempt + 1 < config.max_attempts;
282 last_err = Some(e);
283 if !retry {
284 break;
285 }
286 tokio::time::sleep(backoff).await;
287 backoff = (backoff * 2).min(config.backoff_max);
288 }
289 }
290 }
291 Err(last_err.expect("loop ran at least once"))
292}