use std::sync::Arc;
use git_lfs_api::{
BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
};
use git_lfs_store::Store;
use tokio::sync::Semaphore;
use tokio::sync::mpsc::UnboundedSender;
use tokio::task::JoinSet;
use crate::basic;
use crate::config::TransferConfig;
use crate::error::{Report, TransferError};
use crate::event::Event;
/// Direction of a transfer batch: fetching objects from the remote, or sending
/// objects to it.
#[derive(Clone, Copy, Debug)]
enum Dir {
    /// Fetch objects (converted to `Operation::Download` for the batch API).
    Download,
    /// Send objects (converted to `Operation::Upload` for the batch API).
    Upload,
}
impl From<Dir> for Operation {
    /// Maps a transfer direction onto the batch API operation of the same name.
    fn from(dir: Dir) -> Self {
        match dir {
            Dir::Upload => Self::Upload,
            Dir::Download => Self::Download,
        }
    }
}
/// Batch transfer client.
///
/// Asks the LFS batch API for per-object transfer actions, then executes those
/// actions over HTTP against the local object store.
#[derive(Clone)]
pub struct Transfer {
    /// Batch API client used to request transfer actions.
    api: ApiClient,
    /// Local object store; shared (via `Arc`) with every spawned transfer task.
    store: Arc<Store>,
    /// HTTP client used for the per-object transfer requests.
    http: reqwest::Client,
    /// Concurrency, retry/backoff and URL-rewriting settings.
    config: TransferConfig,
}
impl Transfer {
    /// Creates a transfer client that uses a default `reqwest::Client` for the
    /// per-object transfer requests.
    pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
        Self::with_http_client(api, store, config, reqwest::Client::new())
    }

    /// Creates a transfer client that uses the supplied `http` client for the
    /// per-object transfer requests. Batch API calls still go through `api`.
    pub fn with_http_client(
        api: ApiClient,
        store: Store,
        config: TransferConfig,
        http: reqwest::Client,
    ) -> Self {
        Self {
            api,
            store: Arc::new(store),
            http,
            config,
        }
    }

    /// Requests download actions for `objects` and runs them concurrently.
    ///
    /// `r#ref`, when given, is attached to the batch request. Progress and
    /// per-object outcomes are sent to `events` if a channel is supplied.
    ///
    /// # Errors
    ///
    /// Returns `Err` only if the batch request fails or a transfer task panics
    /// or is cancelled. Per-object failures are listed in [`Report::failed`].
    pub async fn download(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Download, objects, r#ref, events).await
    }

    /// Requests upload actions for `objects` and runs them concurrently.
    ///
    /// Objects for which the server returns no actions are counted as
    /// succeeded without any transfer.
    ///
    /// # Errors
    ///
    /// Same as [`Transfer::download`].
    pub async fn upload(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Upload, objects, r#ref, events).await
    }

    /// Shared driver for [`Transfer::download`] and [`Transfer::upload`].
    ///
    /// Sends one batch request covering all `objects`. It then transfers each
    /// returned object concurrently, with at most `config.concurrency` (minimum
    /// 1) in flight. Every requested object ends up in either
    /// `report.succeeded` or `report.failed`. Objects the server omitted from its
    /// response are recorded as failures rather than silently dropped.
    async fn run(
        &self,
        dir: Dir,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        if objects.is_empty() {
            return Ok(Report::default());
        }
        // Requested sizes, used to fill in response entries that report size 0.
        let req_sizes: std::collections::HashMap<String, u64> =
            objects.iter().map(|o| (o.oid.clone(), o.size)).collect();
        let mut req = BatchRequest::new(dir.into(), objects);
        if let Some(r) = r#ref {
            req = req.with_ref(r);
        }
        if std::env::var_os("GIT_TRACE").is_some_and(|v| !v.is_empty() && v != "0") {
            eprintln!("tq: sending batch of size {}", req.objects.len());
        }
        let resp: BatchResponse = self.api.batch(&req).await?;

        let mut report = Report::default();
        // Without this, requested objects absent from the response would appear
        // in neither `succeeded` nor `failed`. A caller checking `failed` could
        // then mistake an incomplete batch for a complete success.
        for oid in Self::missing_from_response(&req, &resp) {
            let err = TransferError::Io(std::io::Error::other(
                "object missing from batch response",
            ));
            Self::record_outcome(&mut report, events.as_ref(), oid, Err(err));
        }

        let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
        let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();
        for mut obj in resp.objects {
            if obj.size == 0
                && let Some(s) = req_sizes.get(&obj.oid)
            {
                obj.size = *s;
            }
            if let Some(rewriter) = &self.config.url_rewriter
                && let Some(actions) = obj.actions.as_mut()
            {
                for action in [
                    actions.download.as_mut(),
                    actions.upload.as_mut(),
                    actions.verify.as_mut(),
                ]
                .into_iter()
                .flatten()
                {
                    action.href = rewriter(&action.href);
                }
            }
            let permit_src = limit.clone();
            let http = self.http.clone();
            let store = self.store.clone();
            let config = self.config.clone();
            let events = events.clone();
            join.spawn(async move {
                // Tasks are spawned eagerly; the semaphore bounds how many
                // actually transfer at the same time.
                let _permit = permit_src.acquire_owned().await.expect("semaphore live");
                let oid = obj.oid.clone();
                let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
                (oid, result)
            });
        }
        while let Some(joined) = join.join_next().await {
            // A panicked or cancelled task aborts the whole run. Returning early
            // drops `join`, which aborts the tasks still in flight.
            let (oid, result) =
                joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
            Self::record_outcome(&mut report, events.as_ref(), oid, result);
        }
        Ok(report)
    }

    /// Returns the requested oids that do not appear anywhere in the batch
    /// response, in request order and without duplicates.
    fn missing_from_response(req: &BatchRequest, resp: &BatchResponse) -> Vec<String> {
        let returned: std::collections::HashSet<&str> =
            resp.objects.iter().map(|o| o.oid.as_str()).collect();
        let mut seen = std::collections::HashSet::new();
        req.objects
            .iter()
            .map(|o| o.oid.as_str())
            .filter(|oid| !returned.contains(oid) && seen.insert(*oid))
            .map(str::to_owned)
            .collect()
    }

    /// Records one object's outcome in `report`, and emits the matching
    /// `Completed`/`Failed` event if an event channel was supplied.
    fn record_outcome(
        report: &mut Report,
        events: Option<&UnboundedSender<Event>>,
        oid: String,
        result: Result<(), TransferError>,
    ) {
        match result {
            Ok(()) => {
                if let Some(s) = events {
                    // A send error only means the receiver is gone; events are best-effort.
                    let _ = s.send(Event::Completed { oid: oid.clone() });
                }
                report.succeeded.push(oid);
            }
            Err(err) => {
                if let Some(s) = events {
                    let _ = s.send(Event::Failed {
                        oid: oid.clone(),
                        error: err.to_string(),
                    });
                }
                report.failed.push((oid, err));
            }
        }
    }
}
/// Executes the batch-response entry `obj` in direction `dir`.
///
/// A per-object server error is returned immediately, before any event is sent.
/// Otherwise a `Started` event is emitted and the relevant action is run through
/// [`with_retry`].
///
/// # Errors
///
/// - `TransferError::ServerObject` when the server attached an error to `obj`.
/// - `TransferError::NoDownloadAction` for a download without a download action.
/// - Whatever the final transfer attempt returned.
async fn process_object(
    dir: Dir,
    http: &reqwest::Client,
    store: Arc<Store>,
    config: &TransferConfig,
    obj: ObjectResult,
    events: Option<&UnboundedSender<Event>>,
) -> Result<(), TransferError> {
    if let Some(server_err) = obj.error {
        return Err(TransferError::ServerObject(server_err));
    }
    if let Some(tx) = events {
        let _ = tx.send(Event::Started {
            oid: obj.oid.clone(),
            size: obj.size,
        });
    }

    let Some(actions) = obj.actions.as_ref() else {
        // With no actions, an upload is a no-op success (presumably the server
        // already has the object), but a download has nowhere to fetch from.
        return match dir {
            Dir::Download => Err(TransferError::NoDownloadAction),
            Dir::Upload => Ok(()),
        };
    };

    match dir {
        Dir::Download => {
            let download = actions
                .download
                .as_ref()
                .ok_or(TransferError::NoDownloadAction)?;
            with_retry(config, || async {
                basic::download(http, store.clone(), &obj.oid, download, events)
                    .await
                    .map(|_| ())
            })
            .await
        }
        Dir::Upload => {
            with_retry(config, || async {
                basic::upload(http, store.clone(), &obj.oid, obj.size, actions, events).await
            })
            .await
        }
    }
}
/// Runs `op` until it succeeds, fails with a non-retryable error, or the
/// attempt budget `config.max_attempts` is used up.
///
/// After each retryable failure it sleeps for `backoff`, which starts at
/// `config.initial_backoff` and doubles after every sleep. The doubling
/// saturates instead of overflowing and is capped at `config.backoff_max`.
///
/// `config.max_attempts` is treated as at least 1, so `op` always runs once.
///
/// # Errors
///
/// Returns the error produced by the last attempt made.
async fn with_retry<F, Fut>(config: &TransferConfig, mut op: F) -> Result<(), TransferError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<(), TransferError>>,
{
    // A zero budget would skip every attempt and leave no error to return.
    // Always make at least one attempt.
    let attempts = config.max_attempts.max(1);
    let mut backoff = config.initial_backoff;
    let mut attempt = 0;
    loop {
        attempt += 1;
        match op().await {
            Ok(()) => return Ok(()),
            Err(e) if e.is_retryable() && attempt < attempts => {
                tokio::time::sleep(backoff).await;
                // `Duration * 2` panics on overflow; saturate, then clamp.
                backoff = backoff.saturating_mul(2).min(config.backoff_max);
            }
            Err(e) => return Err(e),
        }
    }
}