// git_lfs_transfer/transfer.rs

1//! Top-level transfer orchestrator: batch + concurrent per-object transfer.
2
3use std::sync::Arc;
4
5use git_lfs_api::{
6    BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
7};
8use git_lfs_store::Store;
9use tokio::sync::Semaphore;
10use tokio::sync::mpsc::UnboundedSender;
11use tokio::task::JoinSet;
12
13use crate::basic;
14use crate::config::TransferConfig;
15use crate::error::{Report, TransferError};
16use crate::event::Event;
17
/// Direction of a single transfer batch — used internally to share the
/// fan-out machinery between [`Transfer::download`] and [`Transfer::upload`].
#[derive(Debug, Clone, Copy)]
enum Dir {
    /// Fetch objects from the server into the local store.
    Download,
    /// Push objects from the local store to the server.
    Upload,
}
25
26impl From<Dir> for Operation {
27    fn from(d: Dir) -> Self {
28        match d {
29            Dir::Download => Operation::Download,
30            Dir::Upload => Operation::Upload,
31        }
32    }
33}
34
/// Concurrent transfer queue. One [`Transfer`] is bound to one LFS endpoint
/// (the `api` client) and one local store; create more if you need more.
#[derive(Clone)]
pub struct Transfer {
    /// Batch-API client bound to a single LFS endpoint.
    api: ApiClient,
    /// Local object store, shared (via `Arc`) with every spawned transfer task.
    store: Arc<Store>,
    /// HTTP client used for the per-object action-URL byte transfers.
    http: reqwest::Client,
    /// Concurrency limit and retry/backoff settings (cloned into each task).
    config: TransferConfig,
}
44
45impl Transfer {
46    /// Build a transfer queue. The `reqwest::Client` used for the action-URL
47    /// transfers is created fresh; if you need to share a connection pool,
48    /// use [`with_http_client`](Self::with_http_client).
49    pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
50        Self::with_http_client(api, store, config, reqwest::Client::new())
51    }
52
53    pub fn with_http_client(
54        api: ApiClient,
55        store: Store,
56        config: TransferConfig,
57        http: reqwest::Client,
58    ) -> Self {
59        Self {
60            api,
61            store: Arc::new(store),
62            http,
63            config,
64        }
65    }
66
67    /// Download the given objects into the local store. Each object is
68    /// hash-verified by the store before being committed; corrupt downloads
69    /// are surfaced in [`Report::failed`].
70    pub async fn download(
71        &self,
72        objects: Vec<ObjectSpec>,
73        r#ref: Option<Ref>,
74        events: Option<UnboundedSender<Event>>,
75    ) -> Result<Report, TransferError> {
76        self.run(Dir::Download, objects, r#ref, events).await
77    }
78
79    /// Upload the given objects from the local store. Objects the server
80    /// already has are reported in [`Report::succeeded`] without any byte
81    /// transfer.
82    pub async fn upload(
83        &self,
84        objects: Vec<ObjectSpec>,
85        r#ref: Option<Ref>,
86        events: Option<UnboundedSender<Event>>,
87    ) -> Result<Report, TransferError> {
88        self.run(Dir::Upload, objects, r#ref, events).await
89    }
90
91    async fn run(
92        &self,
93        dir: Dir,
94        objects: Vec<ObjectSpec>,
95        r#ref: Option<Ref>,
96        events: Option<UnboundedSender<Event>>,
97    ) -> Result<Report, TransferError> {
98        if objects.is_empty() {
99            return Ok(Report::default());
100        }
101
102        // Index the request's sizes by oid so we can fill them back in
103        // for servers that omit `size` from the response (the upstream
104        // test fixture, plus at least one production server, drop it).
105        let req_sizes: std::collections::HashMap<String, u64> =
106            objects.iter().map(|o| (o.oid.clone(), o.size)).collect();
107
108        let mut req = BatchRequest::new(dir.into(), objects);
109        if let Some(r) = r#ref {
110            req = req.with_ref(r);
111        }
112        let resp: BatchResponse = self.api.batch(&req).await?;
113
114        let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
115        let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();
116
117        for mut obj in resp.objects {
118            if obj.size == 0 {
119                if let Some(s) = req_sizes.get(&obj.oid) {
120                    obj.size = *s;
121                }
122            }
123            let permit_src = limit.clone();
124            let http = self.http.clone();
125            let store = self.store.clone();
126            let config = self.config.clone();
127            let events = events.clone();
128            join.spawn(async move {
129                let _permit = permit_src.acquire_owned().await.expect("semaphore live");
130                let oid = obj.oid.clone();
131                let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
132                (oid, result)
133            });
134        }
135
136        let mut report = Report::default();
137        while let Some(joined) = join.join_next().await {
138            let (oid, result) = joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
139            match result {
140                Ok(()) => {
141                    if let Some(s) = &events {
142                        let _ = s.send(Event::Completed { oid: oid.clone() });
143                    }
144                    report.succeeded.push(oid);
145                }
146                Err(err) => {
147                    if let Some(s) = &events {
148                        let _ = s.send(Event::Failed {
149                            oid: oid.clone(),
150                            error: err.to_string(),
151                        });
152                    }
153                    report.failed.push((oid, err));
154                }
155            }
156        }
157        Ok(report)
158    }
159}
160
161/// Handle one [`ObjectResult`]: emit Started, run with retry, return final
162/// outcome. Completed/Failed events are emitted by the caller so we can
163/// move the error into the Report without cloning.
164async fn process_object(
165    dir: Dir,
166    http: &reqwest::Client,
167    store: Arc<Store>,
168    config: &TransferConfig,
169    obj: ObjectResult,
170    events: Option<&UnboundedSender<Event>>,
171) -> Result<(), TransferError> {
172    if let Some(err) = obj.error {
173        return Err(TransferError::ServerObject(err));
174    }
175
176    if let Some(s) = events {
177        let _ = s.send(Event::Started {
178            oid: obj.oid.clone(),
179            size: obj.size,
180        });
181    }
182
183    match (dir, &obj.actions) {
184        (Dir::Download, Some(actions)) => {
185            let action = actions
186                .download
187                .as_ref()
188                .ok_or(TransferError::NoDownloadAction)?;
189            with_retry(config, || async {
190                basic::download(http, store.clone(), &obj.oid, action, events).await.map(|_| ())
191            })
192            .await
193        }
194        (Dir::Download, None) => Err(TransferError::NoDownloadAction),
195        (Dir::Upload, Some(actions)) => {
196            with_retry(config, || async {
197                basic::upload(http, store.clone(), &obj.oid, obj.size, actions, events).await
198            })
199            .await
200        }
201        (Dir::Upload, None) => {
202            // Server already has it — no actions means no-op, treated as success.
203            Ok(())
204        }
205    }
206}
207
208/// Run `op` with exponential-backoff retry. Stops on non-retryable errors
209/// or when `max_attempts` is reached.
210async fn with_retry<F, Fut>(config: &TransferConfig, mut op: F) -> Result<(), TransferError>
211where
212    F: FnMut() -> Fut,
213    Fut: std::future::Future<Output = Result<(), TransferError>>,
214{
215    let mut backoff = config.initial_backoff;
216    let mut last_err: Option<TransferError> = None;
217    for attempt in 0..config.max_attempts {
218        match op().await {
219            Ok(()) => return Ok(()),
220            Err(e) => {
221                let retry = e.is_retryable() && attempt + 1 < config.max_attempts;
222                last_err = Some(e);
223                if !retry {
224                    break;
225                }
226                tokio::time::sleep(backoff).await;
227                backoff = (backoff * 2).min(config.backoff_max);
228            }
229        }
230    }
231    Err(last_err.expect("loop ran at least once"))
232}