1use std::io::Write;
2use std::sync::{Arc, Mutex};
3
4use git_lfs_creds::{Credentials, Helper, Query};
5use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderValue};
6use reqwest::{Method, RequestBuilder, Response};
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use url::Url;
10
11use crate::auth::Auth;
12use crate::error::ApiError;
13
/// Media type used for JSON exchanged with the API: sent as the `Accept`
/// header by `Client::request` and as the `Content-Type` of bodies posted by
/// `Client::post_json`.
pub(crate) const LFS_MEDIA_TYPE: &str = "application/vnd.git-lfs+json";
19
/// HTTP client bound to a single API endpoint.
///
/// The authentication state and the cached helper credentials live behind
/// `Arc`s, so clones of a `Client` share (and observe updates to) both.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are joined onto.
    pub(crate) endpoint: Url,
    /// Underlying `reqwest` client used to send every request.
    pub(crate) http: reqwest::Client,
    /// Authentication applied to each outgoing request; replaced when the
    /// credential helper supplies credentials and reset when they are rejected.
    pub(crate) auth: Arc<Mutex<Auth>>,
    /// Optional credential helper consulted after a `401` response.
    pub(crate) credentials: Option<Arc<dyn Helper>>,
    /// Query/credentials pair most recently obtained from the helper, kept so
    /// it can later be approved or rejected.
    pub(crate) filled: Arc<Mutex<Option<(Query, Credentials)>>>,
}
47
48impl std::fmt::Debug for Client {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 f.debug_struct("Client")
51 .field("endpoint", &self.endpoint)
52 .field("auth", &self.auth)
53 .field("has_credential_helper", &self.credentials.is_some())
54 .finish()
55 }
56}
57
58impl Client {
59 pub fn new(endpoint: Url, auth: Auth) -> Self {
65 Self::with_http_client(endpoint, auth, reqwest::Client::new())
66 }
67
68 pub fn with_http_client(endpoint: Url, auth: Auth, http: reqwest::Client) -> Self {
71 Self {
72 endpoint,
73 http,
74 auth: Arc::new(Mutex::new(auth)),
75 credentials: None,
76 filled: Arc::new(Mutex::new(None)),
77 }
78 }
79
80 #[must_use]
84 pub fn with_credential_helper(mut self, helper: Arc<dyn Helper>) -> Self {
85 self.credentials = Some(helper);
86 self
87 }
88
89 pub fn endpoint(&self) -> &Url {
93 &self.endpoint
94 }
95
96 pub fn used_basic_auth(&self) -> bool {
101 matches!(*self.auth.lock().unwrap(), Auth::Basic { .. })
102 }
103
104 pub(crate) fn url(&self, path: &str) -> Result<Url, ApiError> {
110 let mut base = self.endpoint.clone();
111 if !base.path().ends_with('/') {
112 let p = format!("{}/", base.path());
113 base.set_path(&p);
114 }
115 Ok(base.join(path)?)
116 }
117
118 pub(crate) fn request(&self, method: Method, url: Url) -> RequestBuilder {
120 let auth = self.auth.lock().unwrap().clone();
121 let mut headers = HeaderMap::new();
122 headers.insert(ACCEPT, HeaderValue::from_static(LFS_MEDIA_TYPE));
123 let req = self.http.request(method, url).headers(headers);
124 auth.apply(req)
125 }
126
127 fn cred_query(&self) -> Query {
131 Query::from_url(&self.endpoint).without_path()
132 }
133
134 pub(crate) async fn post_json<B, R>(&self, path: &str, body: &B) -> Result<R, ApiError>
137 where
138 B: Serialize + ?Sized,
139 R: DeserializeOwned,
140 {
141 let url = self.url(path)?;
142 let body_bytes = serde_json::to_vec(body)
143 .map_err(|e| ApiError::Decode(format!("serializing request body: {e}")))?;
144 if std::env::var_os("GIT_CURL_VERBOSE").is_some_and(|v| !v.is_empty() && v != "0") {
150 let mut err = std::io::stderr().lock();
151 let _ = writeln!(err, "> POST {url}");
152 let _ = writeln!(err, "> Content-Type: {LFS_MEDIA_TYPE}");
153 let _ = writeln!(err);
154 let _ = err.write_all(&body_bytes);
155 let _ = writeln!(err);
156 }
157 self.send_with_auth_retry(|| {
158 self.request(Method::POST, url.clone())
159 .header(CONTENT_TYPE, LFS_MEDIA_TYPE)
160 .body(body_bytes.clone())
161 })
162 .await
163 }
164
165 pub(crate) async fn get_json<Q, R>(&self, path: &str, query: &Q) -> Result<R, ApiError>
168 where
169 Q: Serialize + ?Sized,
170 R: DeserializeOwned,
171 {
172 let url = self.url(path)?;
173 let qs = serde_urlencoded::to_string(query)
177 .map_err(|e| ApiError::Decode(format!("serializing query: {e}")))?;
178 self.send_with_auth_retry(|| {
179 let mut u = url.clone();
180 if !qs.is_empty() {
181 u.set_query(Some(&qs));
182 }
183 self.request(Method::GET, u)
184 })
185 .await
186 }
187
188 pub(crate) async fn send_with_auth_retry_response<F>(
205 &self,
206 build: F,
207 ) -> Result<Response, ApiError>
208 where
209 F: Fn() -> RequestBuilder,
210 {
211 let resp = build().send().await?;
212 if resp.status().is_success() {
213 self.approve_filled().await;
214 return Ok(resp);
215 }
216 if resp.status().as_u16() != 401 {
217 return Ok(resp);
218 }
219 let Some(helper) = self.credentials.clone() else {
221 return Ok(resp);
222 };
223 let query = self.cred_query();
224 self.reject_filled().await;
225 let creds = match fill_blocking(helper.clone(), query.clone()).await? {
226 Some(c) => c,
227 None => return Ok(resp),
228 };
229 {
230 let mut auth = self.auth.lock().unwrap();
231 *auth = Auth::Basic {
232 username: creds.username.clone(),
233 password: creds.password.clone(),
234 };
235 }
236 {
237 let mut filled = self.filled.lock().unwrap();
238 *filled = Some((query.clone(), creds.clone()));
239 }
240 let resp2 = build().send().await?;
241 if resp2.status().is_success() {
242 approve_blocking(helper, query, creds).await?;
243 } else if resp2.status().as_u16() == 401 {
244 reject_blocking(helper, query, creds).await?;
245 *self.filled.lock().unwrap() = None;
246 *self.auth.lock().unwrap() = Auth::None;
247 }
248 Ok(resp2)
249 }
250
251 async fn send_with_auth_retry<F, R>(&self, build: F) -> Result<R, ApiError>
254 where
255 F: Fn() -> RequestBuilder,
256 R: DeserializeOwned,
257 {
258 let resp = self.send_with_auth_retry_response(build).await?;
259 decode::<R>(resp).await
260 }
261
262 async fn approve_filled(&self) {
263 let snapshot = self.filled.lock().unwrap().clone();
264 if let (Some(helper), Some((q, c))) = (self.credentials.clone(), snapshot) {
265 let _ = approve_blocking(helper, q, c).await;
268 }
269 }
270
271 async fn reject_filled(&self) {
272 let snapshot = self.filled.lock().unwrap().take();
273 if let (Some(helper), Some((q, c))) = (self.credentials.clone(), snapshot) {
274 let _ = reject_blocking(helper, q, c).await;
275 *self.auth.lock().unwrap() = Auth::None;
276 }
277 }
278}
279
280pub(crate) async fn decode<R: DeserializeOwned>(resp: Response) -> Result<R, ApiError> {
282 let status = resp.status();
283 if status.is_success() {
284 let bytes = resp.bytes().await?;
285 return serde_json::from_slice(&bytes).map_err(|e| ApiError::Decode(e.to_string()));
286 }
287
288 let lfs_authenticate = resp
289 .headers()
290 .get("LFS-Authenticate")
291 .and_then(|v| v.to_str().ok())
292 .map(str::to_owned);
293 let bytes = resp.bytes().await.unwrap_or_default();
294
295 Err(ApiError::Status {
296 status: status.as_u16(),
297 lfs_authenticate,
298 body: serde_json::from_slice(&bytes).ok(),
299 })
300}
301
302async fn fill_blocking(
305 helper: Arc<dyn Helper>,
306 query: Query,
307) -> Result<Option<Credentials>, ApiError> {
308 tokio::task::spawn_blocking(move || helper.fill(&query))
309 .await
310 .map_err(|e| ApiError::Decode(format!("credential helper join: {e}")))?
311 .map_err(|e| ApiError::Decode(format!("credential helper: {e}")))
312}
313
314async fn approve_blocking(
315 helper: Arc<dyn Helper>,
316 query: Query,
317 creds: Credentials,
318) -> Result<(), ApiError> {
319 tokio::task::spawn_blocking(move || helper.approve(&query, &creds))
320 .await
321 .map_err(|e| ApiError::Decode(format!("credential helper join: {e}")))?
322 .map_err(|e| ApiError::Decode(format!("credential helper approve: {e}")))
323}
324
325async fn reject_blocking(
326 helper: Arc<dyn Helper>,
327 query: Query,
328 creds: Credentials,
329) -> Result<(), ApiError> {
330 tokio::task::spawn_blocking(move || helper.reject(&query, &creds))
331 .await
332 .map_err(|e| ApiError::Decode(format!("credential helper join: {e}")))?
333 .map_err(|e| ApiError::Decode(format!("credential helper reject: {e}")))
334}