1use std::collections::HashMap;
12use std::fmt::Write as _;
13use std::io::Read;
14use std::time::Duration;
15
16use serde::{Deserialize, Serialize};
17use thiserror::Error;
18
19use crate::creds::CredentialProvider;
20use crate::store::Store;
21
/// Media type sent as both `Accept` and `Content-Type` on batch and verify requests.
const MEDIA_TYPE: &str = "application/vnd.git-lfs+json";
/// Timeout configured on the blocking HTTP client built in `BatchClient::new`.
const HTTP_TIMEOUT: Duration = Duration::from_mins(2);
24
25pub struct ObjectSpec {
26 pub oid: [u8; 32],
27 pub size: u64,
28}
29
30impl ObjectSpec {
31 fn oid_hex(&self) -> String {
32 oid_hex(&self.oid)
33 }
34}
35
/// Per-object outcome of a `BatchClient::download` call.
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so reports can be logged and
/// compared in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DownloadReport {
    /// Object ids that were fetched and handed to the local store.
    pub succeeded: Vec<[u8; 32]>,
    /// `(object id string, reason)` for every object that could not be downloaded.
    pub failed: Vec<(String, String)>,
}
41
/// Per-object outcome of a `BatchClient::upload` call.
///
/// `Debug`, `Clone`, `PartialEq` and `Eq` are derived so reports can be logged and
/// compared in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct UploadReport {
    /// Object ids that were uploaded, or for which the server requested no upload.
    pub succeeded: Vec<[u8; 32]>,
    /// `(object id string, reason)` for every object that could not be uploaded.
    pub failed: Vec<(String, String)>,
}
47
/// Errors produced while talking to an LFS batch endpoint.
#[derive(Debug, Error)]
pub enum BatchError {
    /// The URL could not be used; carries the offending URL and the reason.
    #[error("invalid remote url {0}: {1}")]
    InvalidUrl(String, String),
    /// Building the HTTP client or sending a request failed. Also used by
    /// `put_and_verify` when the object is missing from the local store.
    #[error("http error: {0}")]
    Http(String),
    /// The batch response body could not be decoded as JSON of the expected shape.
    #[error("malformed batch response: {0}")]
    MalformedResponse(String),
    /// The server selected a transfer adapter other than `basic`.
    #[error("unsupported transfer adapter: {0} (maw-lfs only supports 'basic')")]
    UnsupportedTransfer(String),
    /// The batch endpoint answered 401/403 again after credentials were rejected
    /// and re-requested once; carries the host.
    #[error("authentication failed for {0}")]
    AuthFailed(String),
    /// A request returned a non-success status. For object GET/PUT/verify requests
    /// `body` holds a description of the request rather than the response body.
    #[error("server error {status}: {body}")]
    Server { status: u16, body: String },
    /// The credential provider returned an error for this host.
    #[error("no credentials for {0}")]
    NoCreds(String),
    /// An error from the local object store, converted via `From`.
    #[error("store error: {0}")]
    Store(#[from] crate::store::StoreError),
}
67
/// Blocking client for one remote's LFS batch endpoint.
pub struct BatchClient {
    /// Full URL of the batch endpoint (`<remote>/info/lfs/objects/batch`).
    endpoint: String,
    /// Host extracted from `endpoint`; used as the key for credential lookups.
    host: String,
    /// HTTP client; its timeout and user agent are configured in `new`.
    http: reqwest::blocking::Client,
    /// Supplies the username/password used for basic auth on batch requests.
    creds: CredentialProvider,
}
74
75impl BatchClient {
76 pub fn new(remote_url: &str, creds: CredentialProvider) -> Result<Self, BatchError> {
83 let base = derive_lfs_base(remote_url)?;
84 let endpoint = format!("{base}/objects/batch");
85 let host = extract_host(&endpoint)?;
86 let http = reqwest::blocking::Client::builder()
87 .timeout(HTTP_TIMEOUT)
88 .user_agent(concat!("maw-lfs/", env!("CARGO_PKG_VERSION")))
89 .build()
90 .map_err(|e| BatchError::Http(e.to_string()))?;
91 Ok(Self {
92 endpoint,
93 host,
94 http,
95 creds,
96 })
97 }
98
99 pub fn download(
105 &mut self,
106 objects: &[ObjectSpec],
107 store: &Store,
108 ) -> Result<DownloadReport, BatchError> {
109 if objects.is_empty() {
110 return Ok(DownloadReport::default());
111 }
112 let resp = self.batch("download", objects)?;
113 let mut report = DownloadReport::default();
114 for obj in resp.objects {
115 let Ok(oid_bytes) = hex_to_oid(&obj.oid) else {
116 report
117 .failed
118 .push((obj.oid.clone(), "bad oid hex".to_owned()));
119 continue;
120 };
121 if let Some(err) = obj.error {
122 report.failed.push((obj.oid, err.message));
123 continue;
124 }
125 let Some(actions) = obj.actions else {
126 report
129 .failed
130 .push((obj.oid, "server returned no download action".to_owned()));
131 continue;
132 };
133 let Some(dl) = actions.download else {
134 report
135 .failed
136 .push((obj.oid, "no download action".to_owned()));
137 continue;
138 };
139 match self.fetch_and_store(&dl, &oid_bytes, obj.size, store) {
140 Ok(()) => report.succeeded.push(oid_bytes),
141 Err(e) => report.failed.push((obj.oid, e.to_string())),
142 }
143 }
144 Ok(report)
145 }
146
147 pub fn upload(
153 &mut self,
154 objects: &[ObjectSpec],
155 store: &Store,
156 ) -> Result<UploadReport, BatchError> {
157 if objects.is_empty() {
158 return Ok(UploadReport::default());
159 }
160 let resp = self.batch("upload", objects)?;
161 let mut report = UploadReport::default();
162 for obj in resp.objects {
163 let Ok(oid_bytes) = hex_to_oid(&obj.oid) else {
164 report
165 .failed
166 .push((obj.oid.clone(), "bad oid hex".to_owned()));
167 continue;
168 };
169 if let Some(err) = obj.error {
170 report.failed.push((obj.oid, err.message));
171 continue;
172 }
173 let Some(actions) = obj.actions else {
174 report.succeeded.push(oid_bytes);
176 continue;
177 };
178 let Some(up) = actions.upload else {
179 report.succeeded.push(oid_bytes);
181 continue;
182 };
183 match self.put_and_verify(&up, actions.verify.as_ref(), &oid_bytes, obj.size, store) {
184 Ok(()) => report.succeeded.push(oid_bytes),
185 Err(e) => report.failed.push((obj.oid, e.to_string())),
186 }
187 }
188 Ok(report)
189 }
190
191 fn batch(
192 &mut self,
193 operation: &str,
194 objects: &[ObjectSpec],
195 ) -> Result<BatchResponse, BatchError> {
196 let body = BatchRequest {
197 operation: operation.to_owned(),
198 transfers: vec!["basic".to_owned()],
199 hash_algo: "sha256".to_owned(),
200 objects: objects
201 .iter()
202 .map(|o| BatchObjectReq {
203 oid: o.oid_hex(),
204 size: o.size,
205 })
206 .collect(),
207 };
208
209 for attempt in 0..2 {
211 let creds = self
212 .creds
213 .get(&self.host)
214 .map_err(|_| BatchError::NoCreds(self.host.clone()))?;
215 let resp = self
216 .http
217 .post(&self.endpoint)
218 .header("Accept", MEDIA_TYPE)
219 .header("Content-Type", MEDIA_TYPE)
220 .basic_auth(&creds.username, Some(&creds.password))
221 .json(&body)
222 .send()
223 .map_err(|e| BatchError::Http(e.to_string()))?;
224 let status = resp.status();
225 if status.as_u16() == 401 || status.as_u16() == 403 {
226 if attempt == 0 {
227 self.creds.reject(&self.host);
228 continue;
229 }
230 return Err(BatchError::AuthFailed(self.host.clone()));
231 }
232 if !status.is_success() {
233 let body = resp.text().unwrap_or_default();
234 return Err(BatchError::Server {
235 status: status.as_u16(),
236 body,
237 });
238 }
239 let parsed: BatchResponse = resp
240 .json()
241 .map_err(|e| BatchError::MalformedResponse(e.to_string()))?;
242 if parsed.transfer.as_deref().unwrap_or("basic") != "basic" {
243 return Err(BatchError::UnsupportedTransfer(
244 parsed.transfer.unwrap_or_default(),
245 ));
246 }
247 return Ok(parsed);
248 }
249 unreachable!("loop always returns")
250 }
251
252 fn fetch_and_store(
253 &self,
254 action: &ActionLink,
255 oid: &[u8; 32],
256 size: u64,
257 store: &Store,
258 ) -> Result<(), BatchError> {
259 let mut req = self.http.get(&action.href);
260 for (k, v) in action.header.iter().flatten() {
261 req = req.header(k, v);
262 }
263 let resp = req.send().map_err(|e| BatchError::Http(e.to_string()))?;
264 if !resp.status().is_success() {
265 return Err(BatchError::Server {
266 status: resp.status().as_u16(),
267 body: format!("GET {}", action.href),
268 });
269 }
270 let reader = resp;
271 store.insert_from_stream(oid, size, reader)?;
272 Ok(())
273 }
274
275 fn put_and_verify(
276 &self,
277 upload: &ActionLink,
278 verify: Option<&ActionLink>,
279 oid: &[u8; 32],
280 size: u64,
281 store: &Store,
282 ) -> Result<(), BatchError> {
283 let reader = store
284 .open_object(oid)?
285 .ok_or_else(|| BatchError::Http("object missing from local store".to_string()))?;
286 let mut req = self.http.put(&upload.href);
287 for (k, v) in upload.header.iter().flatten() {
288 req = req.header(k, v);
289 }
290 let body = reqwest::blocking::Body::sized(ReaderBody(reader), size);
291 let resp = req
292 .body(body)
293 .send()
294 .map_err(|e| BatchError::Http(e.to_string()))?;
295 if !resp.status().is_success() {
296 return Err(BatchError::Server {
297 status: resp.status().as_u16(),
298 body: format!("PUT {}", upload.href),
299 });
300 }
301 if let Some(v) = verify {
302 let mut vreq = self
303 .http
304 .post(&v.href)
305 .header("Accept", MEDIA_TYPE)
306 .header("Content-Type", MEDIA_TYPE);
307 for (k, val) in v.header.iter().flatten() {
308 vreq = vreq.header(k, val);
309 }
310 let oid_hex = oid_hex(oid);
311 let vresp = vreq
312 .json(&VerifyBody { oid: oid_hex, size })
313 .send()
314 .map_err(|e| BatchError::Http(e.to_string()))?;
315 if !vresp.status().is_success() {
316 return Err(BatchError::Server {
317 status: vresp.status().as_u16(),
318 body: format!("verify {}", v.href),
319 });
320 }
321 }
322 Ok(())
323 }
324}
325
/// Wraps a boxed reader so it can be passed to `reqwest::blocking::Body::sized`.
struct ReaderBody(Box<dyn Read + Send>);

impl Read for ReaderBody {
    /// Forwards the read to the wrapped reader unchanged.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let Self(inner) = self;
        inner.read(buf)
    }
}
333
334fn derive_lfs_base(remote_url: &str) -> Result<String, BatchError> {
337 let trimmed = remote_url.trim_end_matches('/');
341 if !(trimmed.starts_with("https://") || trimmed.starts_with("http://")) {
343 return Err(BatchError::InvalidUrl(
344 remote_url.to_owned(),
345 "only http(s):// remotes supported".to_owned(),
346 ));
347 }
348 Ok(format!("{trimmed}/info/lfs"))
349}
350
351fn extract_host(url: &str) -> Result<String, BatchError> {
352 let without_scheme = url
354 .split_once("://")
355 .map(|(_, r)| r)
356 .ok_or_else(|| BatchError::InvalidUrl(url.to_owned(), "no scheme".to_owned()))?;
357 let host = without_scheme.split('/').next().unwrap_or("");
358 let host = host.split(':').next().unwrap_or(host);
360 Ok(host.to_owned())
361}
362
/// Parses a 64-character hex string into a 32-byte object id.
///
/// Upper- and lowercase hex digits are accepted. The input comes from server
/// responses, so it is decoded byte-wise: non-ASCII input is rejected instead of
/// panicking on a `str` slice that falls inside a multi-byte character, and sign
/// characters such as `+` (which `u8::from_str_radix` tolerates) are rejected too.
///
/// # Errors
///
/// Returns `Err(())` when the input is not exactly 64 bytes long or contains a byte
/// that is not an ASCII hex digit.
fn hex_to_oid(hex: &str) -> Result<[u8; 32], ()> {
    /// Value of a single ASCII hex digit.
    fn nibble(b: u8) -> Result<u8, ()> {
        match b {
            b'0'..=b'9' => Ok(b - b'0'),
            b'a'..=b'f' => Ok(b - b'a' + 10),
            b'A'..=b'F' => Ok(b - b'A' + 10),
            _ => Err(()),
        }
    }

    let bytes = hex.as_bytes();
    if bytes.len() != 64 {
        return Err(());
    }
    let mut out = [0u8; 32];
    for (dst, pair) in out.iter_mut().zip(bytes.chunks_exact(2)) {
        *dst = (nibble(pair[0])? << 4) | nibble(pair[1])?;
    }
    Ok(out)
}
373
/// Formats a 32-byte object id as 64 lowercase hex characters.
fn oid_hex(oid: &[u8; 32]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut out = String::with_capacity(64);
    for &byte in oid {
        // High nibble first, matching `{:02x}` formatting.
        out.push(char::from(DIGITS[usize::from(byte >> 4)]));
        out.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    out
}
381
/// JSON body of a batch request, as built by `BatchClient::batch`.
#[derive(Serialize)]
struct BatchRequest {
    /// Requested operation; this file passes `"download"` or `"upload"`.
    operation: String,
    /// Transfer adapters offered to the server; this file only sends `["basic"]`.
    transfers: Vec<String>,
    // The rename equals the field name, so it is currently a no-op; it keeps the
    // serialized key fixed at `hash_algo` even if the field is renamed later.
    #[serde(rename = "hash_algo")]
    hash_algo: String,
    /// One entry per requested object.
    objects: Vec<BatchObjectReq>,
}
392
/// One object entry inside a `BatchRequest`.
#[derive(Serialize)]
struct BatchObjectReq {
    /// Object id as a hex string.
    oid: String,
    /// Object size, copied from the corresponding `ObjectSpec`.
    size: u64,
}
398
/// Decoded body of a successful batch response.
#[derive(Deserialize)]
struct BatchResponse {
    /// Transfer adapter chosen by the server; `batch` treats a missing value as
    /// `basic` and refuses anything else.
    #[serde(default)]
    transfer: Option<String>,
    /// Per-object results.
    objects: Vec<BatchObjectResp>,
}
405
/// One object entry inside a `BatchResponse`.
#[derive(Deserialize)]
struct BatchObjectResp {
    /// Object id as reported by the server; parsed with `hex_to_oid` before use.
    oid: String,
    /// Size reported by the server; passed on to the transfer helpers.
    size: u64,
    /// Actions the client may perform for this object, if any.
    #[serde(default)]
    actions: Option<Actions>,
    /// Per-object error; when present the object is reported as failed.
    #[serde(default)]
    error: Option<ObjectError>,
}
415
/// Set of optional action links attached to one object in a batch response.
#[derive(Deserialize)]
struct Actions {
    /// Where to GET the object from (used by `download`).
    #[serde(default)]
    download: Option<ActionLink>,
    /// Where to PUT the object to (used by `upload`).
    #[serde(default)]
    upload: Option<ActionLink>,
    /// Optional endpoint that is POSTed to after a successful upload.
    #[serde(default)]
    verify: Option<ActionLink>,
}
425
/// A single action: the URL to contact plus extra headers to send with the request.
#[derive(Deserialize)]
struct ActionLink {
    /// Target URL of the action.
    href: String,
    /// Extra request headers supplied by the server; each is added to the request.
    #[serde(default)]
    header: Option<HashMap<String, String>>,
    // Deserialized but never read in this file, hence the `dead_code` allowance.
    #[serde(default)]
    #[allow(dead_code)]
    expires_at: Option<String>,
}
435
/// Per-object error reported by the server inside a batch response.
#[derive(Deserialize)]
struct ObjectError {
    // Deserialized but never read in this file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    code: i64,
    /// Message copied into the `failed` list of the download/upload report.
    message: String,
}
442
/// JSON body POSTed to a verify action after a successful upload.
#[derive(Serialize)]
struct VerifyBody {
    /// Object id as a hex string.
    oid: String,
    /// Size that was passed to `put_and_verify`.
    size: u64,
}
448
// Unit tests for the URL helpers, the oid hex codec, the serde wire types and
// client construction. None of them performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    // An https remote gets `/info/lfs` appended.
    #[test]
    fn derive_lfs_base_https() {
        assert_eq!(
            derive_lfs_base("https://github.com/bob/repo.git").expect("operation should succeed"),
            "https://github.com/bob/repo.git/info/lfs"
        );
    }

    // A trailing slash on the remote is removed before `/info/lfs` is appended.
    #[test]
    fn derive_lfs_base_trailing_slash() {
        assert_eq!(
            derive_lfs_base("https://example.com/repo/").expect("operation should succeed"),
            "https://example.com/repo/info/lfs"
        );
    }

    // scp-style and ssh:// remotes are rejected.
    #[test]
    fn derive_lfs_base_rejects_ssh() {
        assert!(derive_lfs_base("git@github.com:bob/repo.git").is_err());
        assert!(derive_lfs_base("ssh://github.com/bob/repo.git").is_err());
    }

    // The port, when present, is not part of the extracted host.
    #[test]
    fn extract_host_parses_port() {
        assert_eq!(
            extract_host("https://git.example.com:8443/x").expect("operation should succeed"),
            "git.example.com"
        );
        assert_eq!(
            extract_host("https://github.com/x/y.git").expect("operation should succeed"),
            "github.com"
        );
    }

    // Decoding and re-encoding an oid yields the original hex string.
    #[test]
    fn hex_to_oid_round_trip() {
        let hex = "4d7a214614ab2935c943f9e0ff69d22eadbb8f32b1258daaa5e2ca24d17e2393";
        let oid = hex_to_oid(hex).expect("operation should succeed");
        let back = oid_hex(&oid);
        assert_eq!(back, hex);
    }

    // Anything that is not exactly 64 characters long is rejected.
    #[test]
    fn hex_to_oid_rejects_bad_length() {
        assert!(hex_to_oid("deadbeef").is_err());
    }

    // The serialized request uses the expected JSON keys.
    #[test]
    fn batch_request_body_shape() {
        let body = BatchRequest {
            operation: "download".to_owned(),
            transfers: vec!["basic".to_owned()],
            hash_algo: "sha256".to_owned(),
            objects: vec![BatchObjectReq {
                oid: "abc".to_owned(),
                size: 12,
            }],
        };
        let json = serde_json::to_value(&body).expect("operation should succeed");
        assert_eq!(json["operation"], "download");
        assert_eq!(json["transfers"][0], "basic");
        assert_eq!(json["hash_algo"], "sha256");
        assert_eq!(json["objects"][0]["oid"], "abc");
        assert_eq!(json["objects"][0]["size"], 12);
    }

    // A response with one actionable object and one per-object error deserializes.
    #[test]
    fn batch_response_parses() {
        let body = r#"{
            "transfer": "basic",
            "objects": [
                {
                    "oid": "deadbeef",
                    "size": 10,
                    "actions": {
                        "download": {
                            "href": "https://cdn.example/file",
                            "header": {"Authorization": "Bearer xyz"}
                        }
                    }
                },
                {
                    "oid": "cafebabe",
                    "size": 0,
                    "error": { "code": 404, "message": "not found" }
                }
            ]
        }"#;
        let parsed: BatchResponse = serde_json::from_str(body).expect("operation should succeed");
        assert_eq!(parsed.transfer.as_deref(), Some("basic"));
        assert_eq!(parsed.objects.len(), 2);
        assert_eq!(parsed.objects[0].oid, "deadbeef");
        assert!(parsed.objects[0].actions.is_some());
        assert!(parsed.objects[0].error.is_none());
        assert!(parsed.objects[1].error.is_some());
        assert_eq!(
            parsed.objects[1]
                .error
                .as_ref()
                .expect("operation should succeed")
                .message,
            "not found"
        );
    }

    // Constructing a client derives both the batch endpoint and the host.
    #[test]
    fn client_construction() {
        let creds = CredentialProvider::empty();
        let client = BatchClient::new("https://github.com/bob/repo.git", creds)
            .expect("operation should succeed");
        assert!(client.endpoint.ends_with("/info/lfs/objects/batch"));
        assert_eq!(client.host, "github.com");
    }
}