#![allow(clippy::duration_suboptimal_units)]
use core::fmt;
use core::time::Duration;
use crate::hash::{FromHexError, Hash, to_hex};
use crate::refs::Ref;
pub use crate::refs::RefWriteCondition;
/// Failures reported by transport operations.
///
/// Each variant's `Display` text is produced by `thiserror` from its
/// `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum TransportError {
    /// The requested pack is not present on the remote.
    #[error("pack not found on remote")]
    PackNotFound,
    /// The remote refused access.
    #[error("access denied by remote")]
    AccessDenied,
    /// A remote-side failure, carrying a descriptive message.
    #[error("remote error: {0}")]
    RemoteError(String),
    /// A compare-and-swap precondition on a ref did not hold.
    #[error("ref CAS precondition failed")]
    RefConflict,
    /// The given ref name was rejected; carries the offending name.
    #[error("invalid ref name: {0}")]
    InvalidRef(String),
    /// The connection to the remote could not be established or was lost.
    #[error("connection to remote failed")]
    ConnectionFailed,
    /// The server answered with a failure status code.
    #[error("server error (status {status})")]
    ServerError { status: u16 },
    /// The remote's response could not be understood.
    #[error("invalid response from remote")]
    InvalidResponse,
    /// The exchange with the remote violated the expected protocol.
    #[error("protocol error")]
    ProtocolError,
    /// A payload exceeded the allowed size; carries its size in bytes.
    #[error("payload too large: {0} bytes")]
    PayloadTooLarge(usize),
    /// A plain `http://` URL was used for a non-loopback host.
    #[error("insecure scheme: plain http:// is allowed only for loopback hosts")]
    InsecureScheme,
}

/// Result alias used by every transport operation.
pub type TransportResult<T> = core::result::Result<T, TransportError>;
/// Key under which a pack is stored on a remote.
///
/// A newtype over 32 raw bytes, interchangeable with [`Hash`] through
/// [`PackKey::from_hash`] / [`PackKey::into_hash`]. The inner array is public.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct PackKey(pub [u8; 32]);

impl PackKey {
    /// Builds a key from its 32 raw bytes.
    #[must_use]
    pub const fn new(bytes: [u8; 32]) -> Self {
        PackKey(bytes)
    }

    /// Borrows the 32 raw bytes of the key.
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }

    /// Hex-encodes the key using `crate::hash::to_hex`.
    #[must_use]
    pub fn to_hex(&self) -> String {
        to_hex(self.as_bytes())
    }

    /// Wraps an existing [`Hash`] as a pack key.
    #[must_use]
    pub const fn from_hash(h: Hash) -> Self {
        PackKey(h)
    }

    /// Unwraps the key back into a [`Hash`].
    #[must_use]
    pub const fn into_hash(self) -> Hash {
        let PackKey(inner) = self;
        inner
    }
}
impl From<Hash> for PackKey {
fn from(h: Hash) -> Self {
Self(h)
}
}
/// Lossless conversion back to a [`Hash`]; same as [`PackKey::into_hash`].
impl From<PackKey> for Hash {
    fn from(k: PackKey) -> Hash {
        k.into_hash()
    }
}
/// Formats the key as its hex encoding (the output of [`PackKey::to_hex`]).
impl fmt::Display for PackKey {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_hex())
    }
}
pub fn pack_key_from_hex(s: &str) -> Result<PackKey, FromHexError> {
let h = crate::hash::from_hex(s)?;
Ok(PackKey(h))
}
/// Reports whether an operation that failed with `err` should be retried.
///
/// Retryable: [`TransportError::ConnectionFailed`], and
/// [`TransportError::ServerError`] with a status of 500 or above or exactly
/// 429 (Too Many Requests). Every other variant is treated as permanent.
///
/// The match lists every variant on purpose (no `_` arm). Adding a
/// `TransportError` variant then fails to compile until its retry policy is
/// decided here, instead of silently defaulting to "not retryable".
#[must_use]
pub fn is_retryable(err: &TransportError) -> bool {
    match err {
        TransportError::ConnectionFailed => true,
        TransportError::ServerError { status } => *status >= 500 || *status == 429,
        TransportError::PackNotFound
        | TransportError::AccessDenied
        | TransportError::RemoteError(_)
        | TransportError::RefConflict
        | TransportError::InvalidRef(_)
        | TransportError::InvalidResponse
        | TransportError::ProtocolError
        | TransportError::PayloadTooLarge(_)
        | TransportError::InsecureScheme => false,
    }
}
/// Number of delays yielded by [`BackoffIterator::new`].
pub const BACKOFF_MAX_ATTEMPTS: u32 = 5;
/// First delay yielded by [`BackoffIterator::new`].
pub const BACKOFF_INITIAL: Duration = Duration::from_secs(1);
/// Upper bound on every delay yielded by [`BackoffIterator::new`].
pub const BACKOFF_CAP: Duration = Duration::from_secs(300);

/// Size limit, in bytes, for a pack body.
///
/// 4 GiB on 64-bit targets; see [`PACK_BODY_LIMIT_USIZE`] for the `usize` form.
#[cfg(target_pointer_width = "64")]
pub const PACK_BODY_LIMIT: u64 = 4 * 1024 * 1024 * 1024;
/// Size limit, in bytes, for a pack body.
///
/// On targets narrower than 64 bits this is `usize::MAX`, so the limit is
/// always representable as a `usize`.
#[cfg(not(target_pointer_width = "64"))]
pub const PACK_BODY_LIMIT: u64 = usize::MAX as u64;

/// [`PACK_BODY_LIMIT`] expressed as a `usize`.
// The cast cannot truncate: the compile-time assertion right below fails the
// build on any target where the round-trip back to `u64` would be lossy.
#[allow(clippy::cast_possible_truncation)]
pub const PACK_BODY_LIMIT_USIZE: usize = PACK_BODY_LIMIT as usize;
const _: () = assert!(
    (PACK_BODY_LIMIT_USIZE as u64) == PACK_BODY_LIMIT,
    "PACK_BODY_LIMIT does not fit in usize on this target",
);

/// Iterator over retry delays that grow exponentially up to a cap.
///
/// Yields at most `attempts` delays. The first is the initial delay and each
/// later one is double the previous. Every yielded delay is clamped to the
/// cap, including the initial delay when it exceeds the cap. Once exhausted
/// it keeps returning `None`.
#[derive(Debug, Clone)]
pub struct BackoffIterator {
    /// Delay to yield next, before clamping to `cap`.
    next_delay: Duration,
    /// Number of delays still to be yielded.
    attempts_remaining: u32,
    /// Upper bound applied to every yielded delay.
    cap: Duration,
}

impl BackoffIterator {
    /// Creates the default schedule: [`BACKOFF_MAX_ATTEMPTS`] delays starting at
    /// [`BACKOFF_INITIAL`], doubling each step, capped at [`BACKOFF_CAP`].
    #[must_use]
    pub const fn new() -> Self {
        Self::with(BACKOFF_INITIAL, BACKOFF_CAP, BACKOFF_MAX_ATTEMPTS)
    }

    /// Creates a custom schedule of `attempts` delays starting at `initial`,
    /// doubling each step, with no delay ever exceeding `cap`.
    #[must_use]
    pub const fn with(initial: Duration, cap: Duration, attempts: u32) -> Self {
        Self {
            next_delay: initial,
            attempts_remaining: attempts,
            cap,
        }
    }
}

impl Default for BackoffIterator {
    /// Same as [`BackoffIterator::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Iterator for BackoffIterator {
    type Item = Duration;

    fn next(&mut self) -> Option<Self::Item> {
        // `checked_sub` yields `None` once the attempt budget is spent.
        self.attempts_remaining = self.attempts_remaining.checked_sub(1)?;
        // Clamp the value being yielded, not just the doubled successor, so an
        // `initial` larger than `cap` cannot leak through as the first delay.
        let current = self.next_delay.min(self.cap);
        // `saturating_mul` avoids `Duration` overflow on long schedules.
        self.next_delay = current.saturating_mul(2).min(self.cap);
        Some(current)
    }
}

// `next` never increases `attempts_remaining`, so after the first `None`
// every later call also returns `None`.
impl core::iter::FusedIterator for BackoffIterator {}
/// Remote storage for packs and refs.
///
/// The `Send + Sync` supertraits allow one transport instance to be shared
/// between threads.
pub trait Transport: Send + Sync {
    /// Sends the pack contents `bytes` to the remote under `key`.
    fn upload_pack(&self, bytes: &[u8], key: &PackKey) -> TransportResult<()>;

    /// Retrieves the contents of the pack stored under `key`.
    fn download_pack(&self, key: &PackKey) -> TransportResult<Vec<u8>>;

    /// Checks whether the remote holds a pack under `key`.
    fn pack_exists(&self, key: &PackKey) -> TransportResult<bool>;

    /// Sets ref `name` to `hash` using [`RefWriteCondition::Any`].
    ///
    /// Provided method that forwards to [`Transport::update_ref`].
    /// NOTE(review): `Any` presumably means "no precondition"; confirm in `crate::refs`.
    fn write_ref(&self, name: &str, hash: &Hash) -> TransportResult<()> {
        let unconditional = RefWriteCondition::Any;
        Self::update_ref(self, name, unconditional, hash)
    }

    /// Sets ref `name` to `hash`, subject to `condition`.
    ///
    /// NOTE(review): implementations presumably report
    /// [`TransportError::RefConflict`] when `condition` does not hold; confirm.
    fn update_ref(
        &self,
        name: &str,
        condition: RefWriteCondition,
        hash: &Hash,
    ) -> TransportResult<()>;

    /// Reads the hash that ref `name` points to.
    ///
    /// NOTE(review): `None` presumably means the ref does not exist; confirm.
    fn read_ref(&self, name: &str) -> TransportResult<Option<Hash>>;

    /// Lists refs selected by `prefix`.
    ///
    /// NOTE(review): presumably refs whose name starts with `prefix`; confirm.
    fn list_refs(&self, prefix: &str) -> TransportResult<Vec<Ref>>;
}
/// Abstraction over an async runtime for callers that need a future's
/// result synchronously.
pub mod async_shim {
    /// Runs futures to completion on behalf of synchronous code.
    ///
    /// `Send + Sync` is required so an executor can be shared between threads.
    pub trait Executor: Send + Sync {
        /// Drives `fut` until it resolves, then returns its output.
        fn block_on<F: core::future::Future<Output = T> + Send, T: Send>(&self, fut: F) -> T;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A key's hex form is 64 characters and parses back to the same key.
    #[test]
    fn pack_key_hex_roundtrip() {
        let original = PackKey::new([0x42u8; 32]);
        let encoded = original.to_hex();
        assert_eq!(encoded.len(), 64);
        let decoded = pack_key_from_hex(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Connection failures, 5xx and 429 retry; other errors do not.
    #[test]
    fn is_retryable_matches_spec() {
        let transient = [
            TransportError::ConnectionFailed,
            TransportError::ServerError { status: 500 },
            TransportError::ServerError { status: 503 },
            TransportError::ServerError { status: 429 },
        ];
        for err in &transient {
            assert!(is_retryable(err), "expected retryable: {:?}", err);
        }

        let permanent = [
            TransportError::ServerError { status: 404 },
            TransportError::ServerError { status: 401 },
            TransportError::PackNotFound,
            TransportError::AccessDenied,
            TransportError::RefConflict,
        ];
        for err in &permanent {
            assert!(!is_retryable(err), "expected non-retryable: {:?}", err);
        }
    }

    /// The default schedule doubles from one second over five attempts.
    #[test]
    fn backoff_default_ladder_is_1_2_4_8_16() {
        let expected: Vec<Duration> = [1, 2, 4, 8, 16]
            .iter()
            .map(|&secs| Duration::from_secs(secs))
            .collect();
        let actual: Vec<Duration> = BackoffIterator::new().collect();
        assert_eq!(actual, expected);
    }

    /// After the first delay, no delay exceeds the configured cap.
    #[test]
    fn backoff_caps_at_max() {
        let cap = Duration::from_secs(10);
        let mut delays = BackoffIterator::with(Duration::from_secs(8), cap, 5);
        assert_eq!(delays.next(), Some(Duration::from_secs(8)));
        assert!(delays.all(|d| d <= cap));
    }
}