1use std::sync::Arc;
4
5use git_lfs_api::{
6 BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
7};
8use git_lfs_store::Store;
9use tokio::sync::Semaphore;
10use tokio::sync::mpsc::UnboundedSender;
11use tokio::task::JoinSet;
12
13use crate::basic;
14use crate::config::TransferConfig;
15use crate::error::{Report, TransferError};
16use crate::event::Event;
17
/// Direction of a batch transfer.
///
/// This is the crate-private counterpart of the batch API's [`Operation`]; it is
/// `Copy` so it can be moved into every spawned per-object task.
#[derive(Clone, Copy, Debug)]
enum Dir {
    /// Fetch object content into the local store.
    Download,
    /// Send object content from the local store.
    Upload,
}
25
26impl From<Dir> for Operation {
27 fn from(d: Dir) -> Self {
28 match d {
29 Dir::Download => Operation::Download,
30 Dir::Upload => Operation::Upload,
31 }
32 }
33}
34
35#[derive(Clone)]
38pub struct Transfer {
39 api: ApiClient,
40 store: Arc<Store>,
41 http: reqwest::Client,
42 config: TransferConfig,
43}
44
45impl Transfer {
46 pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
50 Self::with_http_client(api, store, config, reqwest::Client::new())
51 }
52
53 pub fn with_http_client(
54 api: ApiClient,
55 store: Store,
56 config: TransferConfig,
57 http: reqwest::Client,
58 ) -> Self {
59 Self {
60 api,
61 store: Arc::new(store),
62 http,
63 config,
64 }
65 }
66
67 pub async fn download(
71 &self,
72 objects: Vec<ObjectSpec>,
73 r#ref: Option<Ref>,
74 events: Option<UnboundedSender<Event>>,
75 ) -> Result<Report, TransferError> {
76 self.run(Dir::Download, objects, r#ref, events).await
77 }
78
79 pub async fn upload(
83 &self,
84 objects: Vec<ObjectSpec>,
85 r#ref: Option<Ref>,
86 events: Option<UnboundedSender<Event>>,
87 ) -> Result<Report, TransferError> {
88 self.run(Dir::Upload, objects, r#ref, events).await
89 }
90
91 async fn run(
92 &self,
93 dir: Dir,
94 objects: Vec<ObjectSpec>,
95 r#ref: Option<Ref>,
96 events: Option<UnboundedSender<Event>>,
97 ) -> Result<Report, TransferError> {
98 if objects.is_empty() {
99 return Ok(Report::default());
100 }
101
102 let req_sizes: std::collections::HashMap<String, u64> =
106 objects.iter().map(|o| (o.oid.clone(), o.size)).collect();
107
108 let mut req = BatchRequest::new(dir.into(), objects);
109 if let Some(r) = r#ref {
110 req = req.with_ref(r);
111 }
112 let resp: BatchResponse = self.api.batch(&req).await?;
113
114 let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
115 let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();
116
117 for mut obj in resp.objects {
118 if obj.size == 0 {
119 if let Some(s) = req_sizes.get(&obj.oid) {
120 obj.size = *s;
121 }
122 }
123 let permit_src = limit.clone();
124 let http = self.http.clone();
125 let store = self.store.clone();
126 let config = self.config.clone();
127 let events = events.clone();
128 join.spawn(async move {
129 let _permit = permit_src.acquire_owned().await.expect("semaphore live");
130 let oid = obj.oid.clone();
131 let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
132 (oid, result)
133 });
134 }
135
136 let mut report = Report::default();
137 while let Some(joined) = join.join_next().await {
138 let (oid, result) = joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
139 match result {
140 Ok(()) => {
141 if let Some(s) = &events {
142 let _ = s.send(Event::Completed { oid: oid.clone() });
143 }
144 report.succeeded.push(oid);
145 }
146 Err(err) => {
147 if let Some(s) = &events {
148 let _ = s.send(Event::Failed {
149 oid: oid.clone(),
150 error: err.to_string(),
151 });
152 }
153 report.failed.push((oid, err));
154 }
155 }
156 }
157 Ok(report)
158 }
159}
160
161async fn process_object(
165 dir: Dir,
166 http: &reqwest::Client,
167 store: Arc<Store>,
168 config: &TransferConfig,
169 obj: ObjectResult,
170 events: Option<&UnboundedSender<Event>>,
171) -> Result<(), TransferError> {
172 if let Some(err) = obj.error {
173 return Err(TransferError::ServerObject(err));
174 }
175
176 if let Some(s) = events {
177 let _ = s.send(Event::Started {
178 oid: obj.oid.clone(),
179 size: obj.size,
180 });
181 }
182
183 match (dir, &obj.actions) {
184 (Dir::Download, Some(actions)) => {
185 let action = actions
186 .download
187 .as_ref()
188 .ok_or(TransferError::NoDownloadAction)?;
189 with_retry(config, || async {
190 basic::download(http, store.clone(), &obj.oid, action, events).await.map(|_| ())
191 })
192 .await
193 }
194 (Dir::Download, None) => Err(TransferError::NoDownloadAction),
195 (Dir::Upload, Some(actions)) => {
196 with_retry(config, || async {
197 basic::upload(http, store.clone(), &obj.oid, obj.size, actions, events).await
198 })
199 .await
200 }
201 (Dir::Upload, None) => {
202 Ok(())
204 }
205 }
206}
207
208async fn with_retry<F, Fut>(config: &TransferConfig, mut op: F) -> Result<(), TransferError>
211where
212 F: FnMut() -> Fut,
213 Fut: std::future::Future<Output = Result<(), TransferError>>,
214{
215 let mut backoff = config.initial_backoff;
216 let mut last_err: Option<TransferError> = None;
217 for attempt in 0..config.max_attempts {
218 match op().await {
219 Ok(()) => return Ok(()),
220 Err(e) => {
221 let retry = e.is_retryable() && attempt + 1 < config.max_attempts;
222 last_err = Some(e);
223 if !retry {
224 break;
225 }
226 tokio::time::sleep(backoff).await;
227 backoff = (backoff * 2).min(config.backoff_max);
228 }
229 }
230 }
231 Err(last_err.expect("loop ran at least once"))
232}