use std::sync::Arc;
use std::time::Duration;
use git_lfs_api::{
BatchRequest, BatchResponse, Client as ApiClient, ObjectResult, ObjectSpec, Operation, Ref,
};
use git_lfs_store::Store;
use tokio::sync::Semaphore;
use tokio::sync::mpsc::UnboundedSender;
use tokio::task::JoinSet;
use crate::basic;
use crate::config::TransferConfig;
use crate::error::{Report, TransferError};
use crate::event::Event;
/// Direction of a transfer batch: fetching objects from the server
/// (`Download`) or pushing local objects to it (`Upload`).
#[derive(Debug, Clone, Copy)]
enum Dir {
    Download,
    Upload,
}
impl From<Dir> for Operation {
    /// Maps the internal transfer direction onto the wire-level LFS batch
    /// operation.
    fn from(d: Dir) -> Self {
        match d {
            Dir::Download => Self::Download,
            Dir::Upload => Self::Upload,
        }
    }
}
/// Concurrent Git LFS transfer queue: resolves objects through the batch
/// API, then downloads to / uploads from the local object store.
#[derive(Clone)]
pub struct Transfer {
    // LFS batch API client.
    api: ApiClient,
    // Local object store, shared with spawned transfer tasks.
    store: Arc<Store>,
    // HTTP client used for the actual object transfers.
    http: reqwest::Client,
    // Batching, concurrency, retry, and URL-rewrite settings.
    config: TransferConfig,
}
impl Transfer {
    /// Builds a transfer queue with a default `reqwest` HTTP client.
    pub fn new(api: ApiClient, store: Store, config: TransferConfig) -> Self {
        Self::with_http_client(api, store, config, reqwest::Client::new())
    }

    /// Builds a transfer queue with a caller-supplied HTTP client (e.g. one
    /// preconfigured with proxies, TLS settings, or test interceptors).
    pub fn with_http_client(
        api: ApiClient,
        store: Store,
        config: TransferConfig,
        http: reqwest::Client,
    ) -> Self {
        Self {
            api,
            store: Arc::new(store),
            http,
            config,
        }
    }

    /// Downloads `objects` into the store, optionally scoped to `ref` and
    /// streaming progress `Event`s to `events`. See [`Transfer::run`].
    pub async fn download(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Download, objects, r#ref, events).await
    }

    /// Uploads `objects` from the store, optionally scoped to `ref` and
    /// streaming progress `Event`s to `events`. See [`Transfer::run`].
    pub async fn upload(
        &self,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        self.run(Dir::Upload, objects, r#ref, events).await
    }

    /// Resolves `objects` through the LFS batch API and transfers each one
    /// concurrently (bounded by `config.concurrency`), collecting per-object
    /// success/failure into a [`Report`].
    async fn run(
        &self,
        dir: Dir,
        objects: Vec<ObjectSpec>,
        r#ref: Option<Ref>,
        events: Option<UnboundedSender<Event>>,
    ) -> Result<Report, TransferError> {
        if objects.is_empty() {
            return Ok(Report::default());
        }
        // Oversized requests are split into `batch_size` chunks and processed
        // recursively; `Box::pin` breaks the infinitely-sized async recursion.
        let batch_size = self.config.batch_size.max(1);
        if objects.len() > batch_size {
            let mut report = Report::default();
            for chunk in objects.chunks(batch_size) {
                let chunk_report =
                    Box::pin(self.run(dir, chunk.to_vec(), r#ref.clone(), events.clone())).await?;
                report.succeeded.extend(chunk_report.succeeded);
                report.failed.extend(chunk_report.failed);
            }
            return Ok(report);
        }
        // Remember the sizes we asked for: the server may omit sizes in its
        // response (handled by the backfill below).
        let req_sizes: std::collections::HashMap<String, u64> =
            objects.iter().map(|o| (o.oid.clone(), o.size)).collect();
        // Largest objects first so long transfers start as early as possible.
        let mut objects = objects;
        objects.sort_by_key(|o| std::cmp::Reverse(o.size));
        let mut req = BatchRequest::new(dir.into(), objects);
        if let Some(r) = r#ref {
            req = req.with_ref(r);
        }
        let resp: BatchResponse = self.batch_with_retry(&req).await?;
        // Only SHA-256 object ids are supported; bail out early otherwise.
        if let Some(h) = resp.hash_algo.as_deref()
            && !h.is_empty()
            && !h.eq_ignore_ascii_case("sha256")
        {
            return Err(TransferError::UnsupportedHashAlgo(h.to_owned()));
        }
        let limit = Arc::new(Semaphore::new(self.config.concurrency.max(1)));
        let mut join: JoinSet<(String, Result<(), TransferError>)> = JoinSet::new();
        for mut obj in resp.objects {
            // Backfill a missing size from the request so progress reporting
            // stays accurate.
            if obj.size == 0
                && let Some(s) = req_sizes.get(&obj.oid)
            {
                obj.size = *s;
            }
            // Apply the configured URL rewriter to every action href.
            if let Some(rewriter) = &self.config.url_rewriter
                && let Some(actions) = obj.actions.as_mut()
            {
                for action in [
                    actions.download.as_mut(),
                    actions.upload.as_mut(),
                    actions.verify.as_mut(),
                ]
                .into_iter()
                .flatten()
                {
                    action.href = rewriter(&action.href);
                }
            }
            let permit_src = limit.clone();
            let http = self.http.clone();
            let store = self.store.clone();
            let config = self.config.clone();
            let events = events.clone();
            join.spawn(async move {
                // The semaphore permit bounds how many transfers run at once.
                let _permit = permit_src.acquire_owned().await.expect("semaphore live");
                let oid = obj.oid.clone();
                let result = process_object(dir, &http, store, &config, obj, events.as_ref()).await;
                (oid, result)
            });
        }
        let mut report = Report::default();
        while let Some(joined) = join.join_next().await {
            // A JoinError means a task panicked or was cancelled; surface it
            // as an I/O error and abort the whole run.
            let (oid, result) =
                joined.map_err(|e| TransferError::Io(std::io::Error::other(e.to_string())))?;
            match result {
                Ok(()) => {
                    if let Some(s) = &events {
                        let _ = s.send(Event::Completed { oid: oid.clone() });
                    }
                    report.succeeded.push(oid);
                }
                Err(err) => {
                    if let Some(s) = &events {
                        let _ = s.send(Event::Failed {
                            oid: oid.clone(),
                            error: err.to_string(),
                        });
                    }
                    report.failed.push((oid, err));
                }
            }
        }
        Ok(report)
    }

    /// Issues the batch request, retrying retryable API errors with
    /// exponential backoff capped at `config.backoff_max`, or with the
    /// server-provided `Retry-After` delay when present.
    async fn batch_with_retry(&self, req: &BatchRequest) -> Result<BatchResponse, TransferError> {
        // Treat a misconfigured `max_attempts == 0` as one attempt so the loop
        // below always runs and the trailing `expect` cannot panic.
        let max_attempts = self.config.max_attempts.max(1);
        let mut backoff = self.config.initial_backoff;
        let mut retry_count: u32 = 0;
        let mut last_err: Option<git_lfs_api::ApiError> = None;
        for attempt in 0..max_attempts {
            if trace_enabled() {
                eprintln!("tq: sending batch of size {}", req.objects.len());
            }
            match self.api.batch(req).await {
                Ok(resp) => return Ok(resp),
                Err(e) => {
                    let retry = e.is_retryable() && attempt + 1 < max_attempts;
                    if !retry {
                        return Err(TransferError::BatchResponse(Box::new(e)));
                    }
                    let server_delay = e.retry_after();
                    let delay = server_delay.unwrap_or(backoff);
                    retry_count += 1;
                    if trace_enabled() {
                        let secs = delay.as_secs_f64();
                        for obj in &req.objects {
                            eprintln!(
                                "tq: enqueue retry #{retry_count} after {secs:.2}s for {:?} (size: {}): {e}",
                                obj.oid, obj.size
                            );
                        }
                    }
                    last_err = Some(e);
                    tokio::time::sleep(delay).await;
                    // Only grow the backoff when the server didn't dictate the delay.
                    if server_delay.is_none() {
                        backoff = (backoff * 2).min(self.config.backoff_max);
                    }
                }
            }
        }
        Err(TransferError::BatchResponse(Box::new(
            last_err.expect("loop ran at least once"),
        )))
    }
}
/// True when `GIT_TRACE` is set to anything other than empty or "0",
/// matching git's own trace-enabling convention for this variable.
fn trace_enabled() -> bool {
    match std::env::var_os("GIT_TRACE") {
        Some(v) => !v.is_empty() && v != "0",
        None => false,
    }
}
/// Transfers one object in the given direction, applying the per-object
/// retry policy and emitting `Started` (completion/failure events are sent
/// by the caller).
async fn process_object(
    dir: Dir,
    http: &reqwest::Client,
    store: Arc<Store>,
    config: &TransferConfig,
    obj: ObjectResult,
    events: Option<&UnboundedSender<Event>>,
) -> Result<(), TransferError> {
    // A per-object error in the batch response means the server rejected it.
    if let Some(err) = obj.error {
        return Err(TransferError::ServerObject(err));
    }
    if let Some(sender) = events {
        let _ = sender.send(Event::Started {
            oid: obj.oid.clone(),
            size: obj.size,
        });
    }
    match dir {
        Dir::Download => {
            // A download needs an explicit "download" action from the server.
            let action = obj
                .actions
                .as_ref()
                .and_then(|a| a.download.as_ref())
                .ok_or(TransferError::NoDownloadAction)?;
            with_retry(config, &obj.oid, obj.size, || async {
                basic::download(http, store.clone(), &obj.oid, obj.size, action, events)
                    .await
                    .map(|_| ())
            })
            .await
        }
        Dir::Upload => {
            // No actions on an upload response: the server already has the
            // object, so there is nothing to send.
            let Some(actions) = obj.actions.as_ref() else {
                return Ok(());
            };
            with_retry(config, &obj.oid, obj.size, || async {
                basic::upload(
                    http,
                    store.clone(),
                    &obj.oid,
                    obj.size,
                    actions,
                    config.detect_content_type,
                    events,
                )
                .await
            })
            .await
        }
    }
}
/// Runs `op` up to `config.max_attempts` times, sleeping between retryable
/// failures with exponential backoff (capped at `config.backoff_max`) or a
/// server-provided retry-after delay, and returns the last error on
/// exhaustion. `oid`/`size` are only used for trace output.
async fn with_retry<F, Fut>(
    config: &TransferConfig,
    oid: &str,
    size: u64,
    mut op: F,
) -> Result<(), TransferError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<(), TransferError>>,
{
    // Treat a misconfigured `max_attempts == 0` as one attempt so the loop
    // always runs and the trailing `expect` cannot panic.
    let max_attempts = config.max_attempts.max(1);
    let mut backoff = config.initial_backoff;
    let mut retry_count: u32 = 0;
    let mut last_err: Option<TransferError> = None;
    for attempt in 0..max_attempts {
        match op().await {
            Ok(()) => return Ok(()),
            Err(e) => {
                let retry = e.is_retryable() && attempt + 1 < max_attempts;
                if !retry {
                    last_err = Some(e);
                    break;
                }
                let server_delay = e.retry_after();
                let delay = server_delay.unwrap_or(backoff);
                retry_count += 1;
                emit_retry_trace(oid, size, retry_count, delay, &e);
                last_err = Some(e);
                tokio::time::sleep(delay).await;
                // Only grow the backoff when the server didn't dictate the delay.
                if server_delay.is_none() {
                    backoff = (backoff * 2).min(config.backoff_max);
                }
            }
        }
    }
    Err(last_err.expect("loop ran at least once"))
}
/// Writes GIT_TRACE diagnostics for a single object retry; no-op when
/// tracing is disabled.
fn emit_retry_trace(oid: &str, size: u64, count: u32, delay: Duration, err: &TransferError) {
    if trace_enabled() {
        let secs = delay.as_secs_f64();
        match err.retry_after() {
            Some(_) => eprintln!("tq: retrying object {oid} after {secs:.2}s"),
            None => eprintln!("tq: retrying object {oid}: {err}"),
        }
        eprintln!(
            "tq: enqueue retry #{count} after {secs:.2}s for {oid:?} (size: {size}): {err}"
        );
    }
}