use std::sync::{Arc, Mutex};
use git_lfs_creds::{Credentials, Helper, Query};
use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderValue};
use reqwest::{Method, RequestBuilder, Response};
use serde::Serialize;
use serde::de::DeserializeOwned;
use url::Url;
use crate::auth::Auth;
use crate::error::ApiError;
/// Media type string used in this file as the `Accept` header of every
/// request built by `Client::request` and as the `Content-Type` of the JSON
/// bodies sent by `Client::post_json`.
pub(crate) const LFS_MEDIA_TYPE: &str = "application/vnd.git-lfs+json";
/// HTTP client bound to a single endpoint URL.
///
/// Cloning is cheap with respect to the authentication state: `auth` and
/// `filled` live behind `Arc<Mutex<..>>`, so every clone observes and updates
/// the same values.
#[derive(Clone)]
pub struct Client {
/// Base URL that request paths are joined onto (see `Client::url`); also the
/// source of the credential-helper query (see `Client::cred_query`).
pub(crate) endpoint: Url,
/// Underlying `reqwest` client used to build every request.
pub(crate) http: reqwest::Client,
/// Authentication currently applied to outgoing requests. Replaced with
/// `Auth::Basic` after a successful helper fill and reset to `Auth::None`
/// when filled credentials are rejected.
pub(crate) auth: Arc<Mutex<Auth>>,
/// Optional credential helper consulted when a request gets a 401 response.
pub(crate) credentials: Option<Arc<dyn Helper>>,
/// The query/credentials pair most recently obtained from the helper, kept so
/// it can be approved after a successful request or rejected after a 401.
pub(crate) filled: Arc<Mutex<Option<(Query, Credentials)>>>,
}
/// Debug output shows the endpoint, the current auth state and whether a
/// credential helper is configured; the HTTP client and the cached filled
/// credentials are not printed.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_helper = self.credentials.is_some();
        let mut out = f.debug_struct("Client");
        out.field("endpoint", &self.endpoint);
        out.field("auth", &self.auth);
        out.field("has_credential_helper", &has_helper);
        out.finish()
    }
}
impl Client {
/// Creates a client for `endpoint` with the given initial `auth`, backed by a
/// freshly constructed `reqwest::Client`.
pub fn new(endpoint: Url, auth: Auth) -> Self {
    let http = reqwest::Client::new();
    Self::with_http_client(endpoint, auth, http)
}
/// Creates a client for `endpoint` that sends requests through the supplied
/// `http` client.
///
/// The client starts without a credential helper and with no cached filled
/// credentials.
pub fn with_http_client(endpoint: Url, auth: Auth, http: reqwest::Client) -> Self {
    let auth = Arc::new(Mutex::new(auth));
    let filled = Arc::new(Mutex::new(None));
    Self {
        endpoint,
        http,
        auth,
        credentials: None,
        filled,
    }
}
#[must_use]
pub fn with_credential_helper(mut self, helper: Arc<dyn Helper>) -> Self {
self.credentials = Some(helper);
self
}
/// Resolves `path` against the endpoint URL.
///
/// A trailing `/` is appended to the endpoint's path first when it is
/// missing, so that joining keeps the endpoint's last path segment instead of
/// replacing it.
///
/// # Errors
/// Returns an error when `path` cannot be joined onto the endpoint.
pub(crate) fn url(&self, path: &str) -> Result<Url, ApiError> {
    let mut root = self.endpoint.clone();
    let needs_slash = !root.path().ends_with('/');
    if needs_slash {
        let with_slash = format!("{}/", root.path());
        root.set_path(&with_slash);
    }
    let joined = root.join(path)?;
    Ok(joined)
}
/// Starts a request for `method` and `url` with the `Accept` header set to
/// `LFS_MEDIA_TYPE` and the client's current auth applied.
///
/// # Panics
/// Panics if the auth mutex is poisoned.
pub(crate) fn request(&self, method: Method, url: Url) -> RequestBuilder {
    // Snapshot the auth state; the guard is released at the end of this
    // statement, before the builder is assembled.
    let current = self.auth.lock().unwrap().clone();
    let accept = {
        let mut map = HeaderMap::new();
        map.insert(ACCEPT, HeaderValue::from_static(LFS_MEDIA_TYPE));
        map
    };
    let builder = self.http.request(method, url).headers(accept);
    current.apply(builder)
}
/// Builds the credential-helper query for this client: derived from the
/// endpoint URL, with the path component dropped.
fn cred_query(&self) -> Query {
    let from_endpoint = Query::from_url(&self.endpoint);
    from_endpoint.without_path()
}
/// POSTs `body` as JSON to `path` (resolved against the endpoint) and decodes
/// the JSON response into `R`.
///
/// The body is serialized once up front; the request itself goes through the
/// auth-retry flow, so it may be sent twice.
///
/// # Errors
/// Returns `ApiError::Decode` when `body` cannot be serialized, and otherwise
/// whatever URL resolution, sending or decoding reports.
pub(crate) async fn post_json<B, R>(&self, path: &str, body: &B) -> Result<R, ApiError>
where
    B: Serialize + ?Sized,
    R: DeserializeOwned,
{
    let target = self.url(path)?;
    let payload = match serde_json::to_vec(body) {
        Ok(bytes) => bytes,
        Err(e) => return Err(ApiError::Decode(format!("serializing request body: {e}"))),
    };
    // Called once per attempt, so URL and payload are cloned for each send.
    let build = || {
        self.request(Method::POST, target.clone())
            .header(CONTENT_TYPE, LFS_MEDIA_TYPE)
            .body(payload.clone())
    };
    self.send_with_auth_retry(build).await
}
/// GETs `path` (resolved against the endpoint) with `query` encoded as the
/// URL query string, and decodes the JSON response into `R`.
///
/// An empty encoded query leaves the URL's query component untouched.
///
/// # Errors
/// Returns `ApiError::Decode` when `query` cannot be encoded, and otherwise
/// whatever URL resolution, sending or decoding reports.
pub(crate) async fn get_json<Q, R>(&self, path: &str, query: &Q) -> Result<R, ApiError>
where
    Q: Serialize + ?Sized,
    R: DeserializeOwned,
{
    let mut target = self.url(path)?;
    let encoded = match serde_urlencoded::to_string(query) {
        Ok(s) => s,
        Err(e) => return Err(ApiError::Decode(format!("serializing query: {e}"))),
    };
    // The final URL is the same for every attempt, so it is assembled once
    // here and only cloned inside the builder closure.
    if !encoded.is_empty() {
        target.set_query(Some(&encoded));
    }
    let build = || self.request(Method::GET, target.clone());
    self.send_with_auth_retry(build).await
}
/// Sends the request produced by `build`, retrying once with credentials from
/// the configured helper when the server answers 401.
///
/// Flow:
/// 1. Send the request. On a success status, approve any previously filled
///    credentials (best effort) and return the response.
/// 2. A non-401 failure, or a 401 when no credential helper is configured, is
///    returned to the caller unchanged.
/// 3. On 401, reject any previously filled credentials and ask the helper to
///    fill new ones. If it has none, return the original 401 response.
/// 4. Install the new credentials as `Auth::Basic`, remember them in
///    `filled`, and send the request again.
/// 5. Approve the credentials if the retry succeeds. If the retry is again a
///    401, forget them and reject them with the helper.
///
/// `build` is invoked once per attempt and must return a fresh builder each
/// time.
///
/// # Errors
/// Returns an error if sending either request fails, or if the helper's
/// fill, approve or reject call fails on the retry path.
///
/// # Panics
/// Panics if the `auth` or `filled` mutex is poisoned.
pub(crate) async fn send_with_auth_retry_response<F>(
    &self,
    build: F,
) -> Result<Response, ApiError>
where
    F: Fn() -> RequestBuilder,
{
    let resp = build().send().await?;
    if resp.status().is_success() {
        self.approve_filled().await;
        return Ok(resp);
    }
    if resp.status().as_u16() != 401 {
        return Ok(resp);
    }
    let Some(helper) = self.credentials.clone() else {
        return Ok(resp);
    };
    let query = self.cred_query();
    // Whatever was filled earlier has just been refused by the server.
    self.reject_filled().await;
    let Some(creds) = fill_blocking(helper.clone(), query.clone()).await? else {
        return Ok(resp);
    };
    // Each guard is a statement-scoped temporary, so no std mutex is held
    // across an await point.
    *self.auth.lock().unwrap() = Auth::Basic {
        username: creds.username.clone(),
        password: creds.password.clone(),
    };
    *self.filled.lock().unwrap() = Some((query.clone(), creds.clone()));
    let resp2 = build().send().await?;
    if resp2.status().is_success() {
        approve_blocking(helper, query, creds).await?;
    } else if resp2.status().as_u16() == 401 {
        // Drop the refused credentials from in-memory state BEFORE calling
        // the helper: if the helper's reject fails and we return early, later
        // requests must not keep sending credentials the server has already
        // rejected.
        *self.filled.lock().unwrap() = None;
        *self.auth.lock().unwrap() = Auth::None;
        reject_blocking(helper, query, creds).await?;
    }
    Ok(resp2)
}
/// Sends the request via `send_with_auth_retry_response` and decodes the
/// resulting response into `R` with `decode`.
async fn send_with_auth_retry<F, R>(&self, build: F) -> Result<R, ApiError>
where
    F: Fn() -> RequestBuilder,
    R: DeserializeOwned,
{
    match self.send_with_auth_retry_response(build).await {
        Ok(resp) => decode::<R>(resp).await,
        Err(err) => Err(err),
    }
}
/// Best-effort approval of the cached filled credentials.
///
/// Does nothing unless both a credential helper and cached credentials are
/// present. The cache is left in place and any helper error is ignored.
async fn approve_filled(&self) {
    // Clone out of the mutex so the guard is gone before awaiting.
    let cached = self.filled.lock().unwrap().clone();
    let (Some(helper), Some((query, creds))) = (self.credentials.clone(), cached) else {
        return;
    };
    let _ = approve_blocking(helper, query, creds).await;
}
/// Best-effort rejection of the cached filled credentials.
///
/// The cache is always emptied. When both a credential helper and cached
/// credentials were present, the helper is asked to reject them (errors are
/// ignored) and the client's auth is reset to `Auth::None`.
async fn reject_filled(&self) {
    // `take` empties the cache; the guard is gone before awaiting.
    let taken = self.filled.lock().unwrap().take();
    let (Some(helper), Some((query, creds))) = (self.credentials.clone(), taken) else {
        return;
    };
    let _ = reject_blocking(helper, query, creds).await;
    *self.auth.lock().unwrap() = Auth::None;
}
}
/// Turns an HTTP response into `R` or an `ApiError`.
///
/// * Success status: the body is read and parsed as JSON into `R`; a parse
///   failure becomes `ApiError::Decode`.
/// * Any other status: returns `ApiError::Status` carrying the numeric status,
///   the `LFS-Authenticate` header value (when present and readable as a
///   string), and the body parsed as JSON when that works. A body that cannot
///   be read is treated as empty.
pub(crate) async fn decode<R: DeserializeOwned>(resp: Response) -> Result<R, ApiError> {
    let status = resp.status();
    if !status.is_success() {
        // The header must be copied out before `bytes()` consumes the response.
        let lfs_authenticate = resp
            .headers()
            .get("LFS-Authenticate")
            .and_then(|value| value.to_str().ok())
            .map(String::from);
        let raw = resp.bytes().await.unwrap_or_default();
        return Err(ApiError::Status {
            status: status.as_u16(),
            lfs_authenticate,
            body: serde_json::from_slice(&raw).ok(),
        });
    }
    let raw = resp.bytes().await?;
    match serde_json::from_slice(&raw) {
        Ok(value) => Ok(value),
        Err(e) => Err(ApiError::Decode(e.to_string())),
    }
}
/// Runs `helper.fill(&query)` through `tokio::task::spawn_blocking` and
/// awaits the result.
///
/// # Errors
/// Both a failed join of the spawned task and an error from the helper are
/// reported as `ApiError::Decode`.
async fn fill_blocking(
    helper: Arc<dyn Helper>,
    query: Query,
) -> Result<Option<Credentials>, ApiError> {
    let task = tokio::task::spawn_blocking(move || helper.fill(&query));
    match task.await {
        Ok(outcome) => outcome.map_err(|e| ApiError::Decode(format!("credential helper: {e}"))),
        Err(e) => Err(ApiError::Decode(format!("credential helper join: {e}"))),
    }
}
/// Runs `helper.approve(&query, &creds)` through
/// `tokio::task::spawn_blocking` and awaits the result.
///
/// # Errors
/// Both a failed join of the spawned task and an error from the helper are
/// reported as `ApiError::Decode`.
async fn approve_blocking(
    helper: Arc<dyn Helper>,
    query: Query,
    creds: Credentials,
) -> Result<(), ApiError> {
    let task = tokio::task::spawn_blocking(move || helper.approve(&query, &creds));
    match task.await {
        Ok(outcome) => {
            outcome.map_err(|e| ApiError::Decode(format!("credential helper approve: {e}")))
        }
        Err(e) => Err(ApiError::Decode(format!("credential helper join: {e}"))),
    }
}
/// Runs `helper.reject(&query, &creds)` through
/// `tokio::task::spawn_blocking` and awaits the result.
///
/// # Errors
/// Both a failed join of the spawned task and an error from the helper are
/// reported as `ApiError::Decode`.
async fn reject_blocking(
    helper: Arc<dyn Helper>,
    query: Query,
    creds: Credentials,
) -> Result<(), ApiError> {
    let task = tokio::task::spawn_blocking(move || helper.reject(&query, &creds));
    match task.await {
        Ok(outcome) => {
            outcome.map_err(|e| ApiError::Decode(format!("credential helper reject: {e}")))
        }
        Err(e) => Err(ApiError::Decode(format!("credential helper join: {e}"))),
    }
}