use std::io::Write;
use std::sync::{Arc, Mutex};
use git_lfs_creds::{Credentials, Helper, Query};
use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderName, HeaderValue};
use reqwest::{Method, RequestBuilder, Response};
use serde::Serialize;
use serde::de::DeserializeOwned;
use url::Url;
use crate::auth::Auth;
use crate::error::ApiError;
use crate::ssh::{SharedSshResolver, SshAuth, SshOperation};
/// Media type used for LFS API traffic in this module: sent as the `Accept`
/// header on every request and as the `Content-Type` of JSON POST bodies.
pub(crate) const LFS_MEDIA_TYPE: &str = "application/vnd.git-lfs+json";
/// HTTP client for an LFS API endpoint.
///
/// Cloning is supported; the `auth` and `filled` state live behind `Arc<Mutex<_>>`,
/// so clones observe each other's credential updates.
#[derive(Clone)]
pub struct Client {
/// Base URL that request paths are joined onto when no SSH-provided href overrides it.
pub(crate) endpoint: Url,
/// Underlying `reqwest` client used to issue requests.
pub(crate) http: reqwest::Client,
/// Authentication applied to each outgoing request. Replaced with `Auth::Basic`
/// when the credential helper supplies credentials and reset to `Auth::None`
/// when those credentials are rejected.
pub(crate) auth: Arc<Mutex<Auth>>,
/// Optional credential helper consulted when the server answers 401.
pub(crate) credentials: Option<Arc<dyn Helper>>,
/// The query/credentials pair most recently obtained from the helper, kept so it
/// can later be approved (after a successful response) or rejected (after a 401).
pub(crate) filled: Arc<Mutex<Option<(Query, Credentials)>>>,
/// When `false`, the path component is stripped from credential queries.
pub(crate) use_http_path: bool,
/// URL used to build credential queries; falls back to `endpoint` when unset.
pub(crate) cred_url: Option<Url>,
/// Optional resolver that yields per-operation SSH auth (href and headers).
pub(crate) ssh_resolver: Option<SharedSshResolver>,
/// Header name/value pairs echoed in the `GIT_CURL_VERBOSE` trace of POST requests.
/// In the code visible in this file they are only printed, not attached to requests.
pub(crate) extra_headers: Vec<(String, String)>,
}
/// Debug output shows only the endpoint, the auth state and whether a credential
/// helper is configured; the remaining fields are left out.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_helper = self.credentials.is_some();
        let mut out = f.debug_struct("Client");
        out.field("endpoint", &self.endpoint);
        out.field("auth", &self.auth);
        out.field("has_credential_helper", &has_helper);
        out.finish()
    }
}
impl Client {
/// Creates a client for `endpoint` using `auth` and a freshly constructed
/// `reqwest::Client`.
pub fn new(endpoint: Url, auth: Auth) -> Self {
    let http = reqwest::Client::new();
    Self::with_http_client(endpoint, auth, http)
}
/// Creates a client for `endpoint` that sends requests through the supplied
/// `reqwest::Client`.
///
/// The result starts with no credential helper, no SSH resolver, no credential
/// URL override, no extra headers, and with `use_http_path` turned off.
pub fn with_http_client(endpoint: Url, auth: Auth, http: reqwest::Client) -> Self {
    let shared_auth = Arc::new(Mutex::new(auth));
    let nothing_filled = Arc::new(Mutex::new(None));
    Self {
        endpoint,
        http,
        auth: shared_auth,
        filled: nothing_filled,
        credentials: None,
        cred_url: None,
        ssh_resolver: None,
        use_http_path: false,
        extra_headers: Vec::new(),
    }
}
#[must_use]
pub fn with_extra_headers_for_verbose(mut self, headers: Vec<(String, String)>) -> Self {
self.extra_headers = headers;
self
}
#[must_use]
pub fn with_ssh_resolver(mut self, resolver: SharedSshResolver) -> Self {
self.ssh_resolver = Some(resolver);
self
}
#[must_use]
pub fn with_cred_url(mut self, url: Url) -> Self {
self.cred_url = Some(url);
self
}
#[must_use]
pub fn with_credential_helper(mut self, helper: Arc<dyn Helper>) -> Self {
self.credentials = Some(helper);
self
}
#[must_use]
pub fn with_use_http_path(mut self, on: bool) -> Self {
self.use_http_path = on;
self
}
pub fn endpoint(&self) -> &Url {
&self.endpoint
}
/// Reports whether the client's current auth is `Auth::Basic`.
///
/// # Panics
///
/// Panics if the auth mutex is poisoned.
pub fn used_basic_auth(&self) -> bool {
    let current = self.auth.lock().unwrap();
    matches!(&*current, Auth::Basic { .. })
}
/// Joins `path` onto `base`, treating `base` as a directory.
///
/// A trailing `/` is appended to the base path first when it is missing, so the
/// join does not replace the last path segment of `base`.
///
/// # Errors
///
/// Returns the converted URL parse error when `Url::join` fails.
pub(crate) fn join(base: &Url, path: &str) -> Result<Url, ApiError> {
    if base.path().ends_with('/') {
        return Ok(base.join(path)?);
    }
    let mut dir = base.clone();
    let with_slash = format!("{}/", base.path());
    dir.set_path(&with_slash);
    Ok(dir.join(path)?)
}
/// Determines the base URL and SSH auth to use for `operation`.
///
/// Without a resolver the endpoint and a default `SshAuth` are returned. With a
/// resolver, its `href` (when non-empty) becomes the base URL after runs of `/`
/// in its path are collapsed; an empty `href` falls back to the endpoint.
///
/// # Errors
///
/// Propagates resolver failures, and returns `ApiError::Decode` when the
/// resolver's `href` is not a valid URL.
pub(crate) fn resolve_ssh(&self, operation: SshOperation) -> Result<(Url, SshAuth), ApiError> {
    let Some(resolver) = &self.ssh_resolver else {
        return Ok((self.endpoint.clone(), SshAuth::default()));
    };
    let auth = resolver.resolve(operation)?;
    if auth.href.is_empty() {
        return Ok((self.endpoint.clone(), auth));
    }
    let mut base = Url::parse(&auth.href)
        .map_err(|e| ApiError::Decode(format!("ssh href {:?}: {e}", auth.href)))?;
    let tidy = collapse_slashes(base.path());
    // Only touch the URL when collapsing actually changed something.
    if tidy != base.path() {
        base.set_path(&tidy);
    }
    Ok((base, auth))
}
/// Starts a request for `method`/`url` carrying the LFS `Accept` header, the
/// client's current auth, and any headers supplied by `ssh`.
///
/// SSH headers whose name or value is not a valid HTTP header are skipped.
///
/// # Panics
///
/// Panics if the auth mutex is poisoned.
pub(crate) fn request_with_headers(
    &self,
    method: Method,
    url: Url,
    ssh: &SshAuth,
) -> RequestBuilder {
    // Snapshot the auth so the lock is released before the builder is assembled.
    let current_auth = self.auth.lock().unwrap().clone();
    let mut accept = HeaderMap::new();
    accept.insert(ACCEPT, HeaderValue::from_static(LFS_MEDIA_TYPE));
    let mut builder = current_auth.apply(self.http.request(method, url).headers(accept));
    for (key, val) in &ssh.headers {
        let Ok(name) = HeaderName::try_from(key.as_str()) else {
            continue;
        };
        let Ok(value) = HeaderValue::try_from(val.as_str()) else {
            continue;
        };
        builder = builder.header(name, value);
    }
    builder
}
/// Builds the credential-helper query from `cred_url` (or the endpoint when no
/// override is set), dropping the path unless `use_http_path` is enabled.
fn cred_query(&self) -> Query {
    let source = match &self.cred_url {
        Some(overridden) => overridden,
        None => &self.endpoint,
    };
    let query = Query::from_url(source);
    if !self.use_http_path {
        return query.without_path();
    }
    query
}
/// Returns the string form of the URL used for credential lookups: `cred_url`
/// when set, otherwise the endpoint.
fn cred_url_string(&self) -> String {
    match &self.cred_url {
        Some(overridden) => overridden.to_string(),
        None => self.endpoint.to_string(),
    }
}
/// POSTs `body` as JSON to `path` (relative to the resolved base URL) and
/// decodes the JSON response into `R`.
///
/// When the `GIT_CURL_VERBOSE` environment variable is set to anything other
/// than an empty string or `0`, the request line, content type, the configured
/// extra headers and the serialized body are written to stderr first; failures
/// of those writes are ignored.
///
/// # Errors
///
/// Fails when SSH resolution, URL joining or body serialization fails, and with
/// whatever `send_with_auth_retry` reports for the request itself.
pub(crate) async fn post_json<B, R>(
    &self,
    path: &str,
    body: &B,
    op: SshOperation,
) -> Result<R, ApiError>
where
    B: Serialize + ?Sized,
    R: DeserializeOwned,
{
    let (base, ssh) = self.resolve_ssh(op)?;
    let target = Self::join(&base, path)?;
    let payload = serde_json::to_vec(body)
        .map_err(|e| ApiError::Decode(format!("serializing request body: {e}")))?;
    let verbose = match std::env::var_os("GIT_CURL_VERBOSE") {
        Some(flag) => !(flag.is_empty() || flag == "0"),
        None => false,
    };
    if verbose {
        // Lock stderr once for the whole dump; the lock is released before any await.
        let mut trace = std::io::stderr().lock();
        let _ = writeln!(trace, "> POST {target}");
        let _ = writeln!(trace, "> Content-Type: {LFS_MEDIA_TYPE}");
        for (name, value) in &self.extra_headers {
            let _ = writeln!(trace, "> {name}: {value}");
        }
        let _ = writeln!(trace);
        let _ = trace.write_all(&payload);
        let _ = writeln!(trace);
    }
    // The builder may run twice (initial attempt + auth retry), hence the clones.
    let build = || {
        self.request_with_headers(Method::POST, target.clone(), &ssh)
            .header(CONTENT_TYPE, LFS_MEDIA_TYPE)
            .body(payload.clone())
    };
    self.send_with_auth_retry(build).await
}
/// GETs `path` (relative to the resolved base URL) with `query` serialized as
/// the URL query string, and decodes the JSON response into `R`.
///
/// An empty serialized query leaves the URL's query untouched.
///
/// # Errors
///
/// Fails when SSH resolution, URL joining or query serialization fails, and with
/// whatever `send_with_auth_retry` reports for the request itself.
pub(crate) async fn get_json<Q, R>(
    &self,
    path: &str,
    query: &Q,
    op: SshOperation,
) -> Result<R, ApiError>
where
    Q: Serialize + ?Sized,
    R: DeserializeOwned,
{
    let (base, ssh) = self.resolve_ssh(op)?;
    let mut target = Self::join(&base, path)?;
    let encoded = serde_urlencoded::to_string(query)
        .map_err(|e| ApiError::Decode(format!("serializing query: {e}")))?;
    // The final URL is the same for every attempt, so build it once up front.
    if !encoded.is_empty() {
        target.set_query(Some(&encoded));
    }
    let build = || self.request_with_headers(Method::GET, target.clone(), &ssh);
    self.send_with_auth_retry(build).await
}
/// Sends the request produced by `build`, consulting the credential helper and
/// retrying once when the server answers 401.
///
/// Flow:
/// 1. If an earlier call already recorded helper-supplied credentials, the helper
///    is asked again before sending; when it answers, the result replaces the
///    current auth and the recorded credentials. Failures here are ignored.
/// 2. First attempt. A successful response approves the recorded credentials
///    (best effort) and is returned. Any status other than 401, or a 401 when no
///    helper is configured, is returned untouched.
/// 3. On 401 with a helper: the recorded credentials are rejected, the helper is
///    asked for new ones, they are installed as `Auth::Basic`, and the request is
///    rebuilt and sent a second time.
/// 4. Second attempt. Success approves the new credentials; 401/403 forgets and
///    rejects them; every other status is returned untouched.
///
/// `build` is invoked once per attempt, after the auth state has been updated.
///
/// # Errors
///
/// - the converted transport error when sending fails;
/// - `ApiError::CredentialsNotFound` when, after a 401, the helper has no
///   credentials or fails to fill;
/// - `ApiError::Decode` when the helper's blocking task cannot be joined, or when
///   approve/reject after the second attempt fails. Note that a failing approve
///   is reported even though the second response itself was successful.
///
/// # Panics
///
/// Panics if the `auth` or `filled` mutex is poisoned.
pub(crate) async fn send_with_auth_retry_response<F>(
    &self,
    build: F,
) -> Result<Response, ApiError>
where
    F: Fn() -> RequestBuilder,
{
    // Refresh from the helper only when a previous call already went through it.
    let filled_already = self.filled.lock().unwrap().is_some();
    if filled_already && let Some(helper) = self.credentials.clone() {
        let query = self.cred_query();
        // A helper error or a failed join is treated like "no credentials".
        if let Ok(Some(c)) = tokio::task::spawn_blocking(move || helper.fill(&query))
            .await
            .unwrap_or(Ok(None))
        {
            *self.auth.lock().unwrap() = Auth::Basic {
                username: c.username.clone(),
                password: c.password.clone(),
            };
            *self.filled.lock().unwrap() = Some((self.cred_query(), c));
        }
    }
    let resp = build().send().await?;
    if resp.status().is_success() {
        self.approve_filled().await;
        return Ok(resp);
    }
    if resp.status().as_u16() != 401 {
        return Ok(resp);
    }
    let Some(helper) = self.credentials.clone() else {
        return Ok(resp);
    };
    let query = self.cred_query();
    // Whatever was recorded before just failed; tell the helper and drop it.
    self.reject_filled().await;
    let cred_url_str = self.cred_url_string();
    let creds = match fill_for_endpoint(helper.clone(), query.clone(), &cred_url_str).await? {
        Some(c) => c,
        None => {
            return Err(ApiError::CredentialsNotFound {
                url: cred_url_str,
                detail: None,
            });
        }
    };
    // Each guard lives in its own block so no lock is held across the await below.
    {
        let mut auth = self.auth.lock().unwrap();
        *auth = Auth::Basic {
            username: creds.username.clone(),
            password: creds.password.clone(),
        };
    }
    {
        let mut filled = self.filled.lock().unwrap();
        *filled = Some((query.clone(), creds.clone()));
    }
    let resp2 = build().send().await?;
    if resp2.status().is_success() {
        approve_blocking(helper, query, creds).await?;
    } else if matches!(resp2.status().as_u16(), 401 | 403) {
        // Forget the refused credentials BEFORE calling the helper: if the
        // reject call fails and we return early, the client must not keep
        // sending (or later approving) credentials the server just refused.
        *self.filled.lock().unwrap() = None;
        *self.auth.lock().unwrap() = Auth::None;
        reject_blocking(helper, query, creds).await?;
    }
    Ok(resp2)
}
/// Runs `send_with_auth_retry_response` and decodes the resulting response
/// into `R` with `decode`.
///
/// # Errors
///
/// Returns whatever either step reports.
async fn send_with_auth_retry<F, R>(&self, build: F) -> Result<R, ApiError>
where
    F: Fn() -> RequestBuilder,
    R: DeserializeOwned,
{
    let response = self.send_with_auth_retry_response(build).await?;
    let decoded: R = decode(response).await?;
    Ok(decoded)
}
/// Tells the credential helper that the recorded credentials worked.
///
/// The recorded pair is copied, not removed. Does nothing when nothing is
/// recorded or no helper is configured; helper errors are ignored.
async fn approve_filled(&self) {
    // Clone under the lock, then release it before awaiting the helper.
    let recorded = self.filled.lock().unwrap().clone();
    let Some((query, creds)) = recorded else {
        return;
    };
    let Some(helper) = self.credentials.clone() else {
        return;
    };
    let _ = approve_blocking(helper, query, creds).await;
}
/// Removes the recorded credentials and, when a helper is configured and
/// something was recorded, asks it to reject them and resets auth to `Auth::None`.
///
/// The recorded pair is always taken, even without a helper. Helper errors are
/// ignored.
async fn reject_filled(&self) {
    let pending = self.filled.lock().unwrap().take();
    let (Some(helper), Some((query, creds))) = (self.credentials.clone(), pending) else {
        return;
    };
    let _ = reject_blocking(helper, query, creds).await;
    *self.auth.lock().unwrap() = Auth::None;
}
}
/// Returns `path` with every run of consecutive `/` characters reduced to a
/// single `/`. All other characters are copied through unchanged.
fn collapse_slashes(path: &str) -> String {
    let mut collapsed = String::with_capacity(path.len());
    for ch in path.chars() {
        // A slash directly after an emitted slash is part of a run: drop it.
        if ch == '/' && collapsed.ends_with('/') {
            continue;
        }
        collapsed.push(ch);
    }
    collapsed
}
/// Turns an HTTP response into `R` or an `ApiError`.
///
/// A successful status has its body parsed as JSON into `R`. Any other status
/// becomes `ApiError::Status` carrying the status code, the request URL, the
/// `LFS-Authenticate` header (if readable as text), the parsed `Retry-After`
/// value, and the body when it parses as JSON.
///
/// # Errors
///
/// Besides `ApiError::Status`, returns the converted transport error when a
/// successful response's body cannot be read, and `ApiError::Decode` when that
/// body is not valid JSON for `R`.
pub(crate) async fn decode<R: DeserializeOwned>(resp: Response) -> Result<R, ApiError> {
    let status = resp.status();
    if !status.is_success() {
        // Read everything needed from headers/URL before the body consumes `resp`.
        let headers = resp.headers();
        let lfs_authenticate = headers
            .get("LFS-Authenticate")
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned);
        let retry_after = headers
            .get(reqwest::header::RETRY_AFTER)
            .and_then(|v| v.to_str().ok())
            .and_then(crate::error::parse_retry_after);
        let url = Some(resp.url().to_string());
        // An unreadable error body is treated as empty rather than as a failure.
        let raw = resp.bytes().await.unwrap_or_default();
        let body = serde_json::from_slice(&raw).ok();
        return Err(ApiError::Status {
            status: status.as_u16(),
            url,
            lfs_authenticate,
            body,
            retry_after,
        });
    }
    let raw = resp.bytes().await?;
    serde_json::from_slice(&raw).map_err(|e| ApiError::Decode(e.to_string()))
}
/// Runs `helper.fill(&query)` on the blocking thread pool.
///
/// # Errors
///
/// Returns `ApiError::Decode` when the blocking task cannot be joined, and
/// `ApiError::CredentialsNotFound` (tagged with `endpoint` and the helper's
/// message) when the helper itself fails.
async fn fill_for_endpoint(
    helper: Arc<dyn Helper>,
    query: Query,
    endpoint: &str,
) -> Result<Option<Credentials>, ApiError> {
    let url = endpoint.to_owned();
    let joined = tokio::task::spawn_blocking(move || helper.fill(&query)).await;
    match joined {
        Err(join_err) => Err(ApiError::Decode(format!(
            "credential helper join: {join_err}"
        ))),
        Ok(Err(fill_err)) => Err(ApiError::CredentialsNotFound {
            url,
            detail: Some(fill_err.to_string()),
        }),
        Ok(Ok(found)) => Ok(found),
    }
}
/// Runs `helper.approve(&query, &creds)` on the blocking thread pool.
///
/// # Errors
///
/// Returns `ApiError::Decode` when the blocking task cannot be joined or when
/// the helper reports a failure.
async fn approve_blocking(
    helper: Arc<dyn Helper>,
    query: Query,
    creds: Credentials,
) -> Result<(), ApiError> {
    let joined = tokio::task::spawn_blocking(move || helper.approve(&query, &creds)).await;
    match joined {
        Err(join_err) => Err(ApiError::Decode(format!(
            "credential helper join: {join_err}"
        ))),
        Ok(Err(helper_err)) => Err(ApiError::Decode(format!(
            "credential helper approve: {helper_err}"
        ))),
        Ok(Ok(())) => Ok(()),
    }
}
/// Runs `helper.reject(&query, &creds)` on the blocking thread pool.
///
/// # Errors
///
/// Returns `ApiError::Decode` when the blocking task cannot be joined or when
/// the helper reports a failure.
async fn reject_blocking(
    helper: Arc<dyn Helper>,
    query: Query,
    creds: Credentials,
) -> Result<(), ApiError> {
    let joined = tokio::task::spawn_blocking(move || helper.reject(&query, &creds)).await;
    match joined {
        Err(join_err) => Err(ApiError::Decode(format!(
            "credential helper join: {join_err}"
        ))),
        Ok(Err(helper_err)) => Err(ApiError::Decode(format!(
            "credential helper reject: {helper_err}"
        ))),
        Ok(Ok(())) => Ok(()),
    }
}