1use std::collections::HashMap;
41use std::io::Write;
42use std::process::{Command, Stdio};
43use std::sync::Mutex;
44use std::time::{Duration, SystemTime};
45
46use serde::Deserialize;
47
48use crate::trace::trace_enabled;
49
/// The kind of transfer a `git-lfs-authenticate` request asks credentials for.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SshOperation {
    /// Sent to the server as the word `upload`.
    Upload,
    /// Sent to the server as the word `download`.
    Download,
}

impl SshOperation {
    /// Returns the operation word placed on the remote
    /// `git-lfs-authenticate` command line.
    fn as_str(self) -> &'static str {
        match self {
            SshOperation::Download => "download",
            SshOperation::Upload => "upload",
        }
    }
}
69
/// Credentials decoded from a `git-lfs-authenticate` response.
#[derive(Clone, Debug)]
pub struct SshAuth {
    /// The response's `href`: the LFS API URL to use.
    pub href: String,
    /// Name/value pairs from the response's `header` object
    /// (e.g. `Authorization`); empty when the server sent none.
    pub header: HashMap<String, String>,
    /// When these credentials stop being valid; `None` when the response
    /// carried no usable expiry.
    pub expires_at: Option<SystemTime>,
}
82
83#[derive(Debug, thiserror::Error)]
85pub enum SshAuthError {
86 #[error("io error invoking ssh: {0}")]
88 Io(#[from] std::io::Error),
89 #[error("ssh git-lfs-authenticate failed: {0}")]
91 Failed(String),
92 #[error("ssh git-lfs-authenticate returned malformed JSON: {0}")]
94 Json(String),
95}
96
97#[derive(Debug)]
104pub struct SshAuthClient {
105 program: String,
106 cache: Mutex<HashMap<CacheKey, SshAuth>>,
107}
108
109#[derive(Debug, Clone, PartialEq, Eq, Hash)]
110struct CacheKey {
111 user_and_host: String,
112 port: String,
113 path: String,
114 operation: SshOperation,
115}
116
117#[derive(Debug, Default, Deserialize)]
118struct WireResponse {
119 #[serde(default)]
120 href: String,
121 #[serde(default)]
127 header: Option<HashMap<String, String>>,
128 #[serde(default)]
129 expires_at: Option<String>,
130 #[serde(default)]
131 expires_in: Option<i64>,
132}
133
134impl SshAuthClient {
135 pub fn new(program: impl Into<String>) -> Self {
141 Self {
142 program: program.into(),
143 cache: Mutex::new(HashMap::new()),
144 }
145 }
146
147 pub fn resolve(
151 &self,
152 user_and_host: &str,
153 port: Option<&str>,
154 path: &str,
155 operation: SshOperation,
156 ) -> Result<SshAuth, SshAuthError> {
157 let key = CacheKey {
158 user_and_host: user_and_host.to_owned(),
159 port: port.unwrap_or("").to_owned(),
160 path: path.to_owned(),
161 operation,
162 };
163
164 let cached = self.cache.lock().unwrap().get(&key).cloned();
167 if let Some(c) = cached {
168 if !is_expired_within(c.expires_at, Duration::from_secs(5)) {
169 trace(format_args!(
170 "ssh cache: {user_and_host} git-lfs-authenticate {path} {}",
171 operation.as_str()
172 ));
173 return Ok(c);
174 }
175 trace(format_args!(
176 "ssh cache expired: {user_and_host} git-lfs-authenticate {path} {}",
177 operation.as_str()
178 ));
179 }
180
181 let resolved = self.spawn(user_and_host, port, path, operation)?;
182 self.cache.lock().unwrap().insert(key, resolved.clone());
183 Ok(resolved)
184 }
185
186 fn spawn(
187 &self,
188 user_and_host: &str,
189 port: Option<&str>,
190 path: &str,
191 operation: SshOperation,
192 ) -> Result<SshAuth, SshAuthError> {
193 let mut parts = self.program.split_whitespace();
201 let prog = parts
202 .next()
203 .ok_or_else(|| SshAuthError::Failed("ssh program is empty".into()))?;
204 let mut argv: Vec<String> = parts.map(str::to_owned).collect();
205 if let Some(p) = port {
206 argv.push("-p".to_owned());
207 argv.push(p.to_owned());
208 }
209 argv.push(user_and_host.to_owned());
210 argv.push(format!(
211 "git-lfs-authenticate {path} {}",
212 operation.as_str()
213 ));
214
215 if trace_enabled() {
218 let mut e = std::io::stderr().lock();
219 let _ = write!(e, "exec: {prog}");
220 for a in &argv {
221 let _ = write!(e, " {a}");
222 }
223 let _ = writeln!(e);
224 }
225
226 let now = SystemTime::now();
227 let out = Command::new(prog)
228 .args(&argv)
229 .stdin(Stdio::null())
230 .stdout(Stdio::piped())
231 .stderr(Stdio::piped())
232 .output()?;
233 if !out.status.success() {
234 let stderr = String::from_utf8_lossy(&out.stderr).trim().to_owned();
235 return Err(SshAuthError::Failed(if stderr.is_empty() {
236 format!("ssh {prog:?} exited {}", out.status)
237 } else {
238 stderr
239 }));
240 }
241
242 let wire: WireResponse =
243 serde_json::from_slice(&out.stdout).map_err(|e| SshAuthError::Json(e.to_string()))?;
244
245 Ok(SshAuth {
246 href: wire.href,
247 header: wire.header.unwrap_or_default(),
248 expires_at: compute_expires_at(now, wire.expires_at.as_deref(), wire.expires_in),
249 })
250 }
251}
252
253fn compute_expires_at(
257 now: SystemTime,
258 expires_at: Option<&str>,
259 expires_in: Option<i64>,
260) -> Option<SystemTime> {
261 let mut earliest: Option<SystemTime> = None;
262 if let Some(s) = expires_at
263 && !s.is_empty()
264 && let Some(t) = parse_rfc3339(s)
265 {
266 earliest = Some(t);
267 }
268 if let Some(secs) = expires_in {
269 let t = if secs >= 0 {
270 now.checked_add(Duration::from_secs(secs as u64))
271 } else {
272 now.checked_sub(Duration::from_secs(secs.unsigned_abs()))
273 };
274 if let Some(t) = t {
275 earliest = Some(match earliest {
276 Some(e) => e.min(t),
277 None => t,
278 });
279 }
280 }
281 earliest
282}
283
/// Reports whether `expires_at` has already passed or is less than `buffer`
/// away from the current time. `None` means "never expires" and yields
/// `false`.
fn is_expired_within(expires_at: Option<SystemTime>, buffer: Duration) -> bool {
    expires_at.is_some_and(|deadline| {
        // `Err` here means the deadline is already in the past.
        let time_left = deadline.duration_since(SystemTime::now());
        !time_left.is_ok_and(|left| left >= buffer)
    })
}
292
/// Parses an RFC 3339 timestamp such as `2026-05-04T12:34:56Z` or
/// `2026-05-04T14:34:56.789+02:00` into a [`SystemTime`].
///
/// The accepted shape is `YYYY-MM-DDTHH:MM:SS`, then optional fractional
/// seconds, then `Z`, `z` or a `+HH:MM` / `-HH:MM` offset. Fractional seconds
/// are discarded.
///
/// Returns `None` for input not in that shape and for instants before the
/// Unix epoch, which includes Go's zero time `0001-01-01T00:00:00Z`. The
/// parser is lenient in three ways:
/// * field ranges (month, day, hour, ...) are not validated;
/// * the character between offset hours and minutes is not checked;
/// * anything after the zone designator is ignored.
fn parse_rfc3339(s: &str) -> Option<SystemTime> {
    const SEPARATORS: [(usize, u8); 5] =
        [(4, b'-'), (7, b'-'), (10, b'T'), (13, b':'), (16, b':')];

    let raw = s.as_bytes();
    if raw.len() < 20 {
        return None;
    }
    if SEPARATORS.iter().any(|&(at, want)| raw[at] != want) {
        return None;
    }

    // Fixed-width decimal fields of `YYYY-MM-DDTHH:MM:SS`.
    let year: i32 = s.get(0..4)?.parse().ok()?;
    let month: u32 = s.get(5..7)?.parse().ok()?;
    let day: u32 = s.get(8..10)?.parse().ok()?;
    let hour: u32 = s.get(11..13)?.parse().ok()?;
    let minute: u32 = s.get(14..16)?.parse().ok()?;
    let second: u32 = s.get(17..19)?.parse().ok()?;

    // Step over `.digits` if present.
    let mut pos = 19;
    if raw.get(pos) == Some(&b'.') {
        pos += 1;
        pos += raw[pos..].iter().take_while(|b| b.is_ascii_digit()).count();
    }

    let offset_secs: i64 = match raw.get(pos).copied() {
        Some(b'Z' | b'z') => 0,
        Some(sign @ (b'+' | b'-')) => {
            let off_hours: i64 = s.get(pos + 1..pos + 3)?.parse().ok()?;
            let off_minutes: i64 = s.get(pos + 4..pos + 6)?.parse().ok()?;
            let magnitude = off_hours * 3600 + off_minutes * 60;
            if sign == b'+' {
                magnitude
            } else {
                -magnitude
            }
        }
        _ => return None,
    };

    let seconds_into_day =
        i64::from(hour) * 3600 + i64::from(minute) * 60 + i64::from(second);
    // Local wall-clock time minus the zone offset gives UTC.
    let unix = days_from_civil(year, month, day) * 86_400 + seconds_into_day - offset_secs;
    u64::try_from(unix)
        .ok()
        .map(|secs| SystemTime::UNIX_EPOCH + Duration::from_secs(secs))
}

/// Returns the number of days from 1970-01-01 to the given proleptic
/// Gregorian date (negative for earlier dates).
///
/// This is Howard Hinnant's `days_from_civil`. The year is shifted to begin
/// in March so the leap day falls last, then split into 400-year eras of
/// 146 097 days.
fn days_from_civil(year: i32, month: u32, day: u32) -> i64 {
    // January and February count as months of the previous March-based year.
    let y = i64::from(if month <= 2 { year - 1 } else { year });
    let era = y.div_euclid(400);
    let year_of_era = y.rem_euclid(400); // 0..=399
    let m = i64::from(month);
    let month_from_march = if m > 2 { m - 3 } else { m + 9 }; // March = 0
    let day_of_year = (153 * month_from_march + 2) / 5 + i64::from(day) - 1;
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;
    // 719 468 = days from 0000-03-01 to 1970-01-01.
    era * 146_097 + day_of_era - 719_468
}
361
362fn trace(args: std::fmt::Arguments) {
363 if !trace_enabled() {
364 return;
365 }
366 let mut e = std::io::stderr().lock();
367 let _ = writeln!(e, "{args}");
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes an executable `/bin/sh` script named `fakessh` into `dir` and
    /// returns its path. The scripts used below ignore their arguments.
    #[cfg(unix)]
    fn write_fake_ssh(dir: &std::path::Path, script: &str) -> std::path::PathBuf {
        use std::os::unix::fs::PermissionsExt;

        let prog = dir.join("fakessh");
        std::fs::write(&prog, script).unwrap();
        let mut perms = std::fs::metadata(&prog).unwrap().permissions();
        perms.set_mode(0o755);
        std::fs::set_permissions(&prog, perms).unwrap();
        prog
    }

    /// Writes a fake ssh that appends one line to `dir/count` on every run and
    /// prints a response with the given `expires_in`. Returns the script path
    /// and the counter path.
    #[cfg(unix)]
    fn write_counting_fake_ssh(
        dir: &std::path::Path,
        expires_in: i64,
    ) -> (std::path::PathBuf, std::path::PathBuf) {
        let counter = dir.join("count");
        let script = format!(
            "#!/bin/sh\n\
             echo invoked >> '{counter}'\n\
             cat <<'EOF'\n\
             {{\"href\":\"https://lfs.example/repo.git/info/lfs\",\"expires_in\":{expires_in}}}\n\
             EOF\n",
            counter = counter.display(),
        );
        (write_fake_ssh(dir, &script), counter)
    }

    /// Builds a client whose ssh program is the script at `prog`.
    #[cfg(unix)]
    fn client_for(prog: &std::path::Path) -> SshAuthClient {
        SshAuthClient::new(prog.to_string_lossy().into_owned())
    }

    /// Number of times a counting fake ssh has run.
    #[cfg(unix)]
    fn spawn_count(counter: &std::path::Path) -> usize {
        std::fs::read_to_string(counter).unwrap().lines().count()
    }

    #[test]
    fn parse_rfc3339_z() {
        let t = parse_rfc3339("2026-05-04T12:34:56Z").unwrap();
        let unix = t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs();
        assert_eq!(unix, 1777898096);
    }

    #[test]
    fn parse_rfc3339_with_fraction() {
        let a = parse_rfc3339("2026-05-04T12:34:56.789Z").unwrap();
        let b = parse_rfc3339("2026-05-04T12:34:56Z").unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn parse_rfc3339_offset() {
        let utc = parse_rfc3339("2026-05-04T12:34:56Z").unwrap();
        let plus = parse_rfc3339("2026-05-04T14:34:56+02:00").unwrap();
        let minus = parse_rfc3339("2026-05-04T10:34:56-02:00").unwrap();
        assert_eq!(plus, utc);
        assert_eq!(minus, utc);
    }

    #[test]
    fn parse_rfc3339_zero_value_is_unset() {
        // Go's zero `time.Time` lies before the Unix epoch and must read as "no expiry".
        assert_eq!(parse_rfc3339("0001-01-01T00:00:00Z"), None);
    }

    #[test]
    fn parse_rfc3339_rejects_garbage() {
        assert!(parse_rfc3339("").is_none());
        assert!(parse_rfc3339("not a timestamp").is_none());
        assert!(parse_rfc3339("2026/05/04T12:34:56Z").is_none());
        // A zone designator is required.
        assert!(parse_rfc3339("2026-05-04T12:34:56").is_none());
    }

    #[test]
    fn parse_rfc3339_does_not_range_check_fields() {
        // The parser is deliberately lenient: out-of-range fields are accepted.
        assert!(parse_rfc3339("2026-13-99T00:00:00Z").is_some());
    }

    #[test]
    fn compute_expires_at_picks_earliest() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(1_000_000);
        let at = now + Duration::from_secs(30);
        let at_str = format_unix_for_test(at);

        // The absolute expiry (now + 30s) is earlier than now + 60s.
        assert_eq!(compute_expires_at(now, Some(&at_str), Some(60)), Some(at));
        // The relative expiry (now + 10s) is earlier than the absolute one.
        assert_eq!(
            compute_expires_at(now, Some(&at_str), Some(10)),
            Some(now + Duration::from_secs(10))
        );
    }

    #[test]
    fn compute_expires_at_handles_negative_in() {
        let now = SystemTime::UNIX_EPOCH + Duration::from_secs(100);
        let result = compute_expires_at(now, None, Some(-5)).unwrap();
        assert_eq!(result, SystemTime::UNIX_EPOCH + Duration::from_secs(95));
    }

    #[test]
    fn compute_expires_at_returns_none_when_unset() {
        let now = SystemTime::UNIX_EPOCH;
        assert!(compute_expires_at(now, None, None).is_none());
        assert!(compute_expires_at(now, Some(""), None).is_none());
    }

    #[test]
    fn is_expired_within_buffer() {
        let now = SystemTime::now();
        let buffer = Duration::from_secs(5);
        assert!(!is_expired_within(Some(now + Duration::from_secs(10)), buffer));
        assert!(is_expired_within(Some(now + Duration::from_secs(2)), buffer));
        assert!(is_expired_within(Some(now - Duration::from_secs(1)), buffer));
        assert!(!is_expired_within(None, buffer));
    }

    // The tests below run `/bin/sh` scripts as the ssh program, so they are
    // Unix-only.

    #[cfg(unix)]
    #[test]
    fn fill_invokes_ssh_and_parses_response() {
        let tmp = tempfile::TempDir::new().unwrap();
        let prog = write_fake_ssh(
            tmp.path(),
            "#!/bin/sh\n\
             cat <<'EOF'\n\
             {\"href\":\"https://lfs.example/repo.git/info/lfs\",\
             \"header\":{\"Authorization\":\"Bearer abc\"},\
             \"expires_in\":3600}\n\
             EOF\n",
        );

        let auth = client_for(&prog)
            .resolve("git@host", None, "/repo", SshOperation::Upload)
            .unwrap();
        assert_eq!(auth.href, "https://lfs.example/repo.git/info/lfs");
        assert_eq!(
            auth.header.get("Authorization").map(String::as_str),
            Some("Bearer abc")
        );
        assert!(auth.expires_at.is_some());
    }

    #[cfg(unix)]
    #[test]
    fn cache_returns_same_response_within_ttl() {
        let tmp = tempfile::TempDir::new().unwrap();
        let (prog, counter) = write_counting_fake_ssh(tmp.path(), 3600);

        let client = client_for(&prog);
        for _ in 0..2 {
            client
                .resolve("git@host", None, "/repo", SshOperation::Upload)
                .unwrap();
        }
        assert_eq!(spawn_count(&counter), 1, "expected exactly one ssh spawn");
    }

    #[cfg(unix)]
    #[test]
    fn cache_re_resolves_when_expired() {
        let tmp = tempfile::TempDir::new().unwrap();
        let (prog, counter) = write_counting_fake_ssh(tmp.path(), -5);

        let client = client_for(&prog);
        for _ in 0..2 {
            client
                .resolve("git@host", None, "/repo", SshOperation::Upload)
                .unwrap();
        }
        assert_eq!(spawn_count(&counter), 2, "expected ssh to re-spawn after expiry");
    }

    #[cfg(unix)]
    #[test]
    fn ssh_failure_surfaces_stderr() {
        let tmp = tempfile::TempDir::new().unwrap();
        let prog = write_fake_ssh(
            tmp.path(),
            "#!/bin/sh\necho 'permission denied' >&2\nexit 255\n",
        );

        let err = client_for(&prog)
            .resolve("git@host", None, "/repo", SshOperation::Download)
            .unwrap_err();
        match err {
            SshAuthError::Failed(msg) => assert!(msg.contains("permission denied"), "got {msg}"),
            other => panic!("unexpected: {other}"),
        }
    }

    /// Formats `t` (at or after the epoch, whole seconds) as `YYYY-MM-DDTHH:MM:SSZ`.
    fn format_unix_for_test(t: SystemTime) -> String {
        let secs = t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() as i64;
        let days = secs.div_euclid(86400);
        let sod = secs.rem_euclid(86400);
        let (y, m, d) = civil_from_days(days);
        let h = sod / 3600;
        let mi = (sod % 3600) / 60;
        let se = sod % 60;
        format!("{y:04}-{m:02}-{d:02}T{h:02}:{mi:02}:{se:02}Z")
    }

    /// Inverse of `days_from_civil` (Howard Hinnant's `civil_from_days`).
    fn civil_from_days(z: i64) -> (i32, u32, u32) {
        let z = z + 719468;
        let era = if z >= 0 { z } else { z - 146096 } / 146097;
        let doe = (z - era * 146097) as u64;
        let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365;
        let y = yoe as i64 + era * 400;
        let doy = doe - (365 * yoe + yoe / 4 - yoe / 100);
        let mp = (5 * doy + 2) / 153;
        let d = doy - (153 * mp + 2) / 5 + 1;
        let m = if mp < 10 { mp + 3 } else { mp - 9 };
        let y = if m <= 2 { y + 1 } else { y };
        (y as i32, m as u32, d as u32)
    }
}