1use std::sync::{Arc, Mutex};
2
3use git_lfs_creds::{Credentials, Helper, Query};
4use reqwest::header::{ACCEPT, CONTENT_TYPE, HeaderMap, HeaderValue};
5use reqwest::{Method, RequestBuilder, Response};
6use serde::Serialize;
7use serde::de::DeserializeOwned;
8use url::Url;
9
10use crate::auth::Auth;
11use crate::error::ApiError;
12
/// Git LFS JSON media type.
///
/// Sent as the `Accept` header on every request built by `Client::request`, and as
/// the `Content-Type` of the JSON bodies sent by `Client::post_json`.
pub(crate) const LFS_MEDIA_TYPE: &str = "application/vnd.git-lfs+json";
18
/// Client for a Git LFS API endpoint.
///
/// Clones share authentication state: `auth` and `filled` live behind
/// `Arc<Mutex<_>>`, so credentials obtained through one clone are applied to
/// requests made by every other clone.
#[derive(Clone)]
pub struct Client {
    /// Base URL that API paths are resolved against (see `Client::url`).
    pub(crate) endpoint: Url,
    /// Underlying HTTP client used to send every request.
    pub(crate) http: reqwest::Client,
    /// Authentication applied to each outgoing request. Replaced with
    /// `Auth::Basic` after a credential-helper fill, and reset to `Auth::None`
    /// when filled credentials are rejected.
    pub(crate) auth: Arc<Mutex<Auth>>,
    /// Optional credential helper consulted when the server answers `401`.
    pub(crate) credentials: Option<Arc<dyn Helper>>,
    /// Credentials most recently obtained from the helper, together with the
    /// query used to obtain them. They are kept so they can later be approved
    /// or rejected.
    pub(crate) filled: Arc<Mutex<Option<(Query, Credentials)>>>,
}
46
47impl std::fmt::Debug for Client {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 f.debug_struct("Client")
50 .field("endpoint", &self.endpoint)
51 .field("auth", &self.auth)
52 .field("has_credential_helper", &self.credentials.is_some())
53 .finish()
54 }
55}
56
57impl Client {
58 pub fn new(endpoint: Url, auth: Auth) -> Self {
64 Self::with_http_client(endpoint, auth, reqwest::Client::new())
65 }
66
67 pub fn with_http_client(endpoint: Url, auth: Auth, http: reqwest::Client) -> Self {
70 Self {
71 endpoint,
72 http,
73 auth: Arc::new(Mutex::new(auth)),
74 credentials: None,
75 filled: Arc::new(Mutex::new(None)),
76 }
77 }
78
79 #[must_use]
83 pub fn with_credential_helper(mut self, helper: Arc<dyn Helper>) -> Self {
84 self.credentials = Some(helper);
85 self
86 }
87
88 pub(crate) fn url(&self, path: &str) -> Result<Url, ApiError> {
94 let mut base = self.endpoint.clone();
95 if !base.path().ends_with('/') {
96 let p = format!("{}/", base.path());
97 base.set_path(&p);
98 }
99 Ok(base.join(path)?)
100 }
101
102 pub(crate) fn request(&self, method: Method, url: Url) -> RequestBuilder {
104 let auth = self.auth.lock().unwrap().clone();
105 let mut headers = HeaderMap::new();
106 headers.insert(ACCEPT, HeaderValue::from_static(LFS_MEDIA_TYPE));
107 let req = self.http.request(method, url).headers(headers);
108 auth.apply(req)
109 }
110
111 fn cred_query(&self) -> Query {
115 Query::from_url(&self.endpoint).without_path()
116 }
117
118 pub(crate) async fn post_json<B, R>(&self, path: &str, body: &B) -> Result<R, ApiError>
121 where
122 B: Serialize + ?Sized,
123 R: DeserializeOwned,
124 {
125 let url = self.url(path)?;
126 let body_bytes = serde_json::to_vec(body)
127 .map_err(|e| ApiError::Decode(format!("serializing request body: {e}")))?;
128 self.send_with_auth_retry(|| {
129 self.request(Method::POST, url.clone())
130 .header(CONTENT_TYPE, LFS_MEDIA_TYPE)
131 .body(body_bytes.clone())
132 })
133 .await
134 }
135
136 pub(crate) async fn get_json<Q, R>(&self, path: &str, query: &Q) -> Result<R, ApiError>
139 where
140 Q: Serialize + ?Sized,
141 R: DeserializeOwned,
142 {
143 let url = self.url(path)?;
144 let qs = serde_urlencoded::to_string(query)
148 .map_err(|e| ApiError::Decode(format!("serializing query: {e}")))?;
149 self.send_with_auth_retry(|| {
150 let mut u = url.clone();
151 if !qs.is_empty() {
152 u.set_query(Some(&qs));
153 }
154 self.request(Method::GET, u)
155 })
156 .await
157 }
158
159 pub(crate) async fn send_with_auth_retry_response<F>(
176 &self,
177 build: F,
178 ) -> Result<Response, ApiError>
179 where
180 F: Fn() -> RequestBuilder,
181 {
182 let resp = build().send().await?;
183 if resp.status().is_success() {
184 self.approve_filled().await;
185 return Ok(resp);
186 }
187 if resp.status().as_u16() != 401 {
188 return Ok(resp);
189 }
190 let Some(helper) = self.credentials.clone() else {
192 return Ok(resp);
193 };
194 let query = self.cred_query();
195 self.reject_filled().await;
196 let creds = match fill_blocking(helper.clone(), query.clone()).await? {
197 Some(c) => c,
198 None => return Ok(resp),
199 };
200 {
201 let mut auth = self.auth.lock().unwrap();
202 *auth = Auth::Basic {
203 username: creds.username.clone(),
204 password: creds.password.clone(),
205 };
206 }
207 {
208 let mut filled = self.filled.lock().unwrap();
209 *filled = Some((query.clone(), creds.clone()));
210 }
211 let resp2 = build().send().await?;
212 if resp2.status().is_success() {
213 approve_blocking(helper, query, creds).await?;
214 } else if resp2.status().as_u16() == 401 {
215 reject_blocking(helper, query, creds).await?;
216 *self.filled.lock().unwrap() = None;
217 *self.auth.lock().unwrap() = Auth::None;
218 }
219 Ok(resp2)
220 }
221
222 async fn send_with_auth_retry<F, R>(&self, build: F) -> Result<R, ApiError>
225 where
226 F: Fn() -> RequestBuilder,
227 R: DeserializeOwned,
228 {
229 let resp = self.send_with_auth_retry_response(build).await?;
230 decode::<R>(resp).await
231 }
232
233 async fn approve_filled(&self) {
234 let snapshot = self.filled.lock().unwrap().clone();
235 if let (Some(helper), Some((q, c))) = (self.credentials.clone(), snapshot) {
236 let _ = approve_blocking(helper, q, c).await;
239 }
240 }
241
242 async fn reject_filled(&self) {
243 let snapshot = self.filled.lock().unwrap().take();
244 if let (Some(helper), Some((q, c))) = (self.credentials.clone(), snapshot) {
245 let _ = reject_blocking(helper, q, c).await;
246 *self.auth.lock().unwrap() = Auth::None;
247 }
248 }
249}
250
251pub(crate) async fn decode<R: DeserializeOwned>(resp: Response) -> Result<R, ApiError> {
253 let status = resp.status();
254 if status.is_success() {
255 let bytes = resp.bytes().await?;
256 return serde_json::from_slice(&bytes).map_err(|e| ApiError::Decode(e.to_string()));
257 }
258
259 let lfs_authenticate = resp
260 .headers()
261 .get("LFS-Authenticate")
262 .and_then(|v| v.to_str().ok())
263 .map(str::to_owned);
264 let bytes = resp.bytes().await.unwrap_or_default();
265
266 Err(ApiError::Status {
267 status: status.as_u16(),
268 lfs_authenticate,
269 body: serde_json::from_slice(&bytes).ok(),
270 })
271}
272
273async fn fill_blocking(
276 helper: Arc<dyn Helper>,
277 query: Query,
278) -> Result<Option<Credentials>, ApiError> {
279 tokio::task::spawn_blocking(move || helper.fill(&query))
280 .await
281 .map_err(|e| ApiError::Decode(format!("credential helper join: {e}")))?
282 .map_err(|e| ApiError::Decode(format!("credential helper: {e}")))
283}
284
285async fn approve_blocking(
286 helper: Arc<dyn Helper>,
287 query: Query,
288 creds: Credentials,
289) -> Result<(), ApiError> {
290 tokio::task::spawn_blocking(move || helper.approve(&query, &creds))
291 .await
292 .map_err(|e| ApiError::Decode(format!("credential helper join: {e}")))?
293 .map_err(|e| ApiError::Decode(format!("credential helper approve: {e}")))
294}
295
296async fn reject_blocking(
297 helper: Arc<dyn Helper>,
298 query: Query,
299 creds: Credentials,
300) -> Result<(), ApiError> {
301 tokio::task::spawn_blocking(move || helper.reject(&query, &creds))
302 .await
303 .map_err(|e| ApiError::Decode(format!("credential helper join: {e}")))?
304 .map_err(|e| ApiError::Decode(format!("credential helper reject: {e}")))
305}