1use std::sync::Arc;
4
5use git_lfs_api::{
6 BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
7};
8use git_lfs_store::Store;
9use tokio::sync::Semaphore;
10use tokio::sync::mpsc::UnboundedSender;
11use tokio::task::JoinSet;
12
13use crate::basic;
14use crate::config::TransferConfig;
15use crate::error::{Report, TransferError};
16use crate::event::Event;
17
/// Direction of a transfer run.
///
/// Converted into the wire-level [`Operation`] (via the `From` impl below)
/// when building the batch request.
#[derive(Debug, Clone, Copy)]
enum Dir {
    Download,
    Upload,
}
25
26impl From<Dir> for Operation {
27 fn from(d: Dir) -> Self {
28 match d {
29 Dir::Download => Operation::Download,
30 Dir::Upload => Operation::Upload,
31 }
32 }
33}
34
/// Transfer queue: resolves objects through the LFS batch API and moves
/// their content between the local store and the server over HTTP.
///
/// Cloning is cheap: the store is shared behind an `Arc`, and
/// `reqwest::Client` is internally reference-counted.
#[derive(Clone)]
pub struct Transfer {
    // LFS batch API client used for the initial negotiation call.
    api: ApiClient,
    // Local object store, shared across spawned transfer tasks.
    store: Arc<Store>,
    // HTTP client used for the actual content upload/download.
    http: reqwest::Client,
    // Concurrency limit, retry/backoff policy, optional URL rewriter.
    config: TransferConfig,
}
44
impl Transfer {
    /// Build a `Transfer` with a default `reqwest::Client`.
    pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
        Self::with_http_client(api, store, config, reqwest::Client::new())
    }

    /// Build a `Transfer` with a caller-supplied HTTP client (e.g. one
    /// configured with custom TLS, proxy, or timeout settings).
    pub fn with_http_client(
        api: ApiClient,
        store: Store,
        config: TransferConfig,
        http: reqwest::Client,
    ) -> Self {
        Self {
            api,
            store: Arc::new(store),
            http,
            config,
        }
    }

    /// Download `objects` from the LFS server into the store.
    ///
    /// `r#ref` is attached to the batch request when present; progress
    /// events are emitted on `events` if a sender is supplied.
    pub async fn download(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Download, objects, r#ref, events).await
    }

    /// Upload `objects` from the store to the LFS server.
    ///
    /// Same semantics as [`Transfer::download`], opposite direction.
    pub async fn upload(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Upload, objects, r#ref, events).await
    }

    /// Shared driver for both directions: one batch API call, then one
    /// concurrency-limited task per returned object.
    ///
    /// Returns a [`Report`] of per-oid outcomes; an individual object
    /// failure does not abort the run. An early `Err` is returned only for
    /// the batch call itself or for a panicked/cancelled worker task.
    async fn run(
        &self,
        dir: Dir,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        if objects.is_empty() {
            return Ok(Report::default());
        }

        // Remember the sizes we requested: the batch response may omit
        // `size`, and the per-object processing below needs it.
        let req_sizes: std::collections::HashMap<String, u64> =
            objects.iter().map(|o| (o.oid.clone(), o.size)).collect();

        let mut req = BatchRequest::new(dir.into(), objects);
        if let Some(r) = r#ref {
            req = req.with_ref(r);
        }
        // Mirror git's tracing convention: GIT_TRACE set to any non-empty,
        // non-"0" value enables diagnostics on stderr.
        if std::env::var_os("GIT_TRACE").is_some_and(|v| !v.is_empty() && v != "0") {
            eprintln!("tq: sending batch of size {}", req.objects.len());
        }
        let resp: BatchResponse = self.api.batch(&req).await?;

        // Bound in-flight transfers; `max(1)` guards a zero-concurrency config.
        let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
        let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();

        for mut obj in resp.objects {
            // Backfill a size the server left as zero from the request.
            if obj.size == 0
                && let Some(s) = req_sizes.get(&obj.oid)
            {
                obj.size = *s;
            }
            // Apply the configured URL rewriter to every action href
            // (download/upload/verify), e.g. to route through a mirror.
            if let Some(rewriter) = &self.config.url_rewriter
                && let Some(actions) = obj.actions.as_mut()
            {
                for action in [
                    actions.download.as_mut(),
                    actions.upload.as_mut(),
                    actions.verify.as_mut(),
                ]
                .into_iter()
                .flatten()
                {
                    action.href = rewriter(&action.href);
                }
            }
            let permit_src = limit.clone();
            let http = self.http.clone();
            let store = self.store.clone();
            let config = self.config.clone();
            let events = events.clone();
            join.spawn(async move {
                // Acquire inside the task: spawning stays cheap while the
                // semaphore still bounds concurrent work. The semaphore is
                // never closed, so acquire can only fail on a broken runtime.
                let _permit = permit_src.acquire_owned().await.expect("semaphore live");
                let oid = obj.oid.clone();
                let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
                (oid, result)
            });
        }

        // Drain completed tasks, emitting a terminal event per object.
        let mut report = Report::default();
        while let Some(joined) = join.join_next().await {
            // A join error means the task panicked or was cancelled;
            // surface it as an I/O-flavored error and abort the run.
            let (oid, result) =
                joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
            match result {
                Ok(()) => {
                    if let Some(s) = &events {
                        let _ = s.send(Event::Completed { oid: oid.clone() });
                    }
                    report.succeeded.push(oid);
                }
                Err(err) => {
                    if let Some(s) = &events {
                        let _ = s.send(Event::Failed {
                            oid: oid.clone(),
                            error: err.to_string(),
                        });
                    }
                    report.failed.push((oid, err));
                }
            }
        }
        Ok(report)
    }
}
182
183async fn process_object(
187 dir: Dir,
188 http: &reqwest::Client,
189 store: Arc<Store>,
190 config: &TransferConfig,
191 obj: ObjectResult,
192 events: Option<&UnboundedSender<Event>>,
193) -> Result<(), TransferError> {
194 if let Some(err) = obj.error {
195 return Err(TransferError::ServerObject(err));
196 }
197
198 if let Some(s) = events {
199 let _ = s.send(Event::Started {
200 oid: obj.oid.clone(),
201 size: obj.size,
202 });
203 }
204
205 match (dir, &obj.actions) {
206 (Dir::Download, Some(actions)) => {
207 let action = actions
208 .download
209 .as_ref()
210 .ok_or(TransferError::NoDownloadAction)?;
211 with_retry(config, || async {
212 basic::download(http, store.clone(), &obj.oid, action, events)
213 .await
214 .map(|_| ())
215 })
216 .await
217 }
218 (Dir::Download, None) => Err(TransferError::NoDownloadAction),
219 (Dir::Upload, Some(actions)) => {
220 with_retry(config, || async {
221 basic::upload(http, store.clone(), &obj.oid, obj.size, actions, events).await
222 })
223 .await
224 }
225 (Dir::Upload, None) => {
226 Ok(())
228 }
229 }
230}
231
232async fn with_retry<F, Fut>(config: &TransferConfig, mut op: F) -> Result<(), TransferError>
235where
236 F: FnMut() -> Fut,
237 Fut: std::future::Future<Output = Result<(), TransferError>>,
238{
239 let mut backoff = config.initial_backoff;
240 let mut last_err: Option<TransferError> = None;
241 for attempt in 0..config.max_attempts {
242 match op().await {
243 Ok(()) => return Ok(()),
244 Err(e) => {
245 let retry = e.is_retryable() && attempt + 1 < config.max_attempts;
246 last_err = Some(e);
247 if !retry {
248 break;
249 }
250 tokio::time::sleep(backoff).await;
251 backoff = (backoff * 2).min(config.backoff_max);
252 }
253 }
254 }
255 Err(last_err.expect("loop ran at least once"))
256}