// xet-client 1.5.2
//
// Client library for communicating with Hugging Face Xet storage servers. Use through the hf-xet crate.
// Documentation
use core::fmt;
use std::cmp::min;
use std::collections::{HashMap, HashSet};
use std::marker::PhantomData;
use std::str::FromStr;

use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use thiserror::Error;
use xet_core_structures::merklehash::MerkleHash;

mod key;
pub use key::*;

/// Name of the HTTP header carrying a "session id" that clients can use to group
/// together related requests (e.g. all requests made to CAS to support a
/// user-triggered upload: xorbs + shards).
pub const SESSION_ID_HEADER: &str = "X-Xet-Session-Id";
/// Name of the HTTP header carrying the request id generated by CAS for a request.
pub const REQUEST_ID_HEADER: &str = "X-Request-Id";

/// Serializable response type for a xorb upload.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UploadXorbResponse {
    // NOTE(review): presumably `true` when the upload stored a new xorb and
    // `false` when it was already present — confirm against the server contract.
    pub was_inserted: bool,
}

/// Marker used as the `Kind` parameter of `ChunkRange`.
///
/// These marker types are defined to help differentiate the `Range<_, _>` type
/// aliases, so that they don't silently cast to each other without range
/// adjustments.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash, Copy)]
pub struct _C;
/// Marker used as the `Kind` parameter of `FileRange`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash, Copy)]
pub struct _F;
/// Marker used as the `Kind` parameter of `HttpRange`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Default, Hash, Copy)]
pub struct _H;

/// Start and exclusive-end range (`[start, end)`) for chunk content, with `u32` bounds.
pub type ChunkRange = Range<u32, _C>;
/// Start and exclusive-end range (`[start, end)`) for file content, with `u64` bounds.
pub type FileRange = Range<u64, _F>;
/// Start and inclusive-end range (`[start, end]`) for HTTP range content, with `u64` bounds.
pub type HttpRange = Range<u64, _H>;

impl FileRange {
    pub fn full() -> Self {
        Self::new(0, u64::MAX)
    }

    // consumes self and split the range into a segment of size `segment_size`
    // and a remainder.
    pub fn take_segment(self, segment_size: u64) -> (Self, Option<Self>) {
        let segment = FileRange {
            start: self.start,
            end: min(self.end, self.start + segment_size),
            _marker: PhantomData,
        };

        let remainder = if segment.end == self.end {
            None
        } else {
            Some(FileRange {
                start: segment.end,
                end: self.end,
                _marker: PhantomData,
            })
        };

        (segment, remainder)
    }

    pub fn length(&self) -> u64 {
        self.end - self.start
    }
}

impl From<HttpRange> for FileRange {
    /// Converts an inclusive-end HTTP range into an exclusive-end file range
    /// with the same start.
    fn from(value: HttpRange) -> Self {
        // right inclusive to right exclusive; saturate so that an inclusive end
        // of `u64::MAX` maps to `u64::MAX` instead of overflowing
        FileRange::new(value.start, value.end.saturating_add(1))
    }
}

impl HttpRange {
    /// Formats this range as an HTTP `Range` header value: `bytes=<start>-<end>`.
    pub fn range_header(&self) -> String {
        format!("bytes={}-{}", self.start, self.end)
    }

    /// Returns the number of positions covered by this inclusive-end range,
    /// i.e. `end - start + 1`.
    ///
    /// Assumes `start <= end`; the arithmetic is unchecked.
    pub fn length(&self) -> u64 {
        let span = self.end - self.start;
        span + 1
    }
}

impl From<FileRange> for HttpRange {
    /// Converts an exclusive-end file range into an inclusive-end HTTP range
    /// with the same start.
    ///
    /// The subtraction is unchecked, so `value.end` must be at least 1.
    fn from(value: FileRange) -> Self {
        // right exclusive to right inclusive
        let inclusive_end = value.end - 1;
        HttpRange::new(value.start, inclusive_end)
    }
}

/// A `start`/`end` pair tagged with a zero-sized `Kind` marker type.
///
/// The marker stores no data; it only makes ranges with different `Kind`s
/// distinct types. Whether `end` is inclusive or exclusive is not encoded in
/// this struct — it is a convention of each type alias built on top of it.
///
/// Note that the standard (derived) `PartialOrd`/`Ord` impls will first check
/// `start` then `end`.
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, PartialOrd, Ord, Default, Hash)]
pub struct Range<Idx, Kind> {
    /// Lower bound of the range.
    pub start: Idx,
    /// Upper bound of the range (inclusive or exclusive depending on the alias).
    pub end: Idx,
    /// Zero-sized type tag; skipped when serializing and deserializing.
    #[serde(skip)]
    pub _marker: PhantomData<Kind>,
}

/// Debug output shows only `start` and `end`; the marker field is omitted, so
/// the `Kind` parameter does not need to implement `Debug`.
impl<Idx, Kind> fmt::Debug for Range<Idx, Kind>
where
    Idx: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Range");
        out.field("start", &self.start);
        out.field("end", &self.end);
        out.finish()
    }
}

impl<Idx, Kind> Range<Idx, Kind> {
    pub fn new(start: Idx, end: Idx) -> Self {
        Self {
            start,
            end,
            _marker: PhantomData,
        }
    }
}

/// A range is `Copy` whenever both its index type and its marker type are.
impl<Idx, Kind> Copy for Range<Idx, Kind>
where
    Idx: Copy,
    Kind: Copy,
{
}

impl<Idx: fmt::Display, Kind> fmt::Display for Range<Idx, Kind> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}-{}", self.start, self.end)
    }
}

/// Error produced when parsing a `Range` from a `[start]-[end]` string fails.
#[derive(Error, Debug)]
pub enum RangeParseError<Idx: std::str::FromStr> {
    /// The input could not be split into two parts around a `-`.
    #[error("Invalid format, expect [start]-[end]")]
    InvalidFormat,
    /// One of the two parts failed to parse as `Idx`; wraps the parser's error.
    #[error("Incorrect number: {0}")]
    ParseError(Idx::Err),
}

impl<Idx: FromStr, Kind> TryFrom<&str> for Range<Idx, Kind> {
    type Error = RangeParseError<Idx>;

    /// Parses a range written as `<start>-<end>`.
    ///
    /// The input is split on the first `-` only; everything after it is parsed
    /// as `end`. Returns `InvalidFormat` when there is no `-` at all, and
    /// `ParseError` when either side fails to parse as `Idx`.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let (start_text, end_text) = match value.split_once('-') {
            Some(pieces) => pieces,
            None => return Err(RangeParseError::InvalidFormat),
        };

        let start: Idx = start_text.parse().map_err(RangeParseError::ParseError)?;
        let end: Idx = end_text.parse().map_err(RangeParseError::ParseError)?;

        Ok(Self {
            start,
            end,
            _marker: PhantomData,
        })
    }
}

/// Enables `"<start>-<end>".parse::<Range<_, _>>()`; parsing is delegated to
/// the `TryFrom<&str>` implementation.
impl<Idx, Kind> FromStr for Range<Idx, Kind>
where
    Idx: FromStr,
{
    type Err = RangeParseError<Idx>;

    fn from_str(value: &str) -> Result<Self, Self::Err> {
        <Self as TryFrom<&str>>::try_from(value)
    }
}

/// Describes a portion of a reconstructed file, namely the xorb and
/// a range of chunks within that xorb that are needed.
///
/// `unpacked_length` is used for validation: the result data of this term
/// should have that field's value as its length.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XorbReconstructionTerm {
    /// Hash identifying the xorb this term refers to.
    pub hash: HexMerkleHash,
    /// The resulting data from deserializing the range in this term
    /// should have a length equal to `unpacked_length`.
    pub unpacked_length: u32,
    /// Chunk index start and (exclusive) end in a xorb.
    pub range: ChunkRange,
}

/// To use a `XorbReconstructionFetchInfo`, all that's needed is an HTTP GET
/// request on `url` with the `Range` header formed directly from the
/// `url_range` values.
///
/// The `range` key describes the chunk range within the xorb that the
/// url is used to fetch.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, Hash)]
pub struct XorbReconstructionFetchInfo {
    /// Chunk index start and (exclusive) end in a xorb.
    pub range: ChunkRange,
    /// URL to issue the GET request against.
    pub url: String,
    /// Byte index start and (inclusive) end in a xorb, used exclusively for
    /// the `Range` header.
    pub url_range: HttpRange,
}

/// Response describing how to reconstruct a file (or a range of it) from xorbs.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct QueryReconstructionResponse {
    // For range query [a, b) into a file content, the location
    // of "a" into the first range.
    pub offset_into_first_range: u64,
    // Series of terms describing a xorb hash and chunk range to be retrieved
    // to reconstruct the file
    pub terms: Vec<XorbReconstructionTerm>,
    // information to fetch xorb ranges to reconstruct the file
    // each key is a hash that is present in the `terms` field reconstruction
    // terms, the values are information we will need to fetch ranges from
    // each xorb needed to reconstruct the file
    pub fetch_info: HashMap<HexMerkleHash, Vec<XorbReconstructionFetchInfo>>,
}

/// V2 reconstruction response - optimized for multi-range fetching.
/// May provide fewer signed URLs per xorb by combining multiple byte ranges
/// into a single URL where possible.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct QueryReconstructionResponseV2 {
    /// Same meaning as the V1 field of the same name; it is carried over
    /// unchanged when converting from `QueryReconstructionResponse`.
    pub offset_into_first_range: u64,
    /// Reconstruction terms; carried over unchanged when converting from
    /// `QueryReconstructionResponse`.
    pub terms: Vec<XorbReconstructionTerm>,
    /// Map from xorb hash -> list of multi-range fetch entries.
    /// Typically 1 entry per xorb. Multiple entries when the URL length limit
    /// (~8 KiB, roughly ~500 ranges) forces a split.
    pub xorbs: HashMap<HexMerkleHash, Vec<XorbMultiRangeFetch>>,
}

/// A signed multi-range fetch: one URL covering a subset of ranges for a xorb.
///
/// When built from a V1 response, each entry carries exactly one range.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XorbMultiRangeFetch {
    /// Signed URL with all byte ranges encoded. Client must send exactly the
    /// signed range value as the Range header.
    pub url: String,
    /// Byte ranges covered by this URL, sorted by chunk start.
    pub ranges: Vec<XorbRangeDescriptor>,
}

/// A single byte range within a xorb, mapping chunk indices to physical bytes.
///
/// Note the two fields use different end conventions: `chunks` is
/// exclusive-end while `bytes` is inclusive-end.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct XorbRangeDescriptor {
    /// Chunk index range [start, end) within the xorb.
    pub chunks: ChunkRange,
    /// Physical byte range [start, end] (inclusive end) for the HTTP Range header.
    pub bytes: HttpRange,
}

impl From<QueryReconstructionResponse> for QueryReconstructionResponseV2 {
    fn from(v1: QueryReconstructionResponse) -> Self {
        let xorbs = v1
            .fetch_info
            .into_iter()
            .map(|(hash, fetch_infos)| {
                let fetch = fetch_infos
                    .into_iter()
                    .map(|info| XorbMultiRangeFetch {
                        url: info.url,
                        ranges: vec![XorbRangeDescriptor {
                            chunks: info.range,
                            bytes: info.url_range,
                        }],
                    })
                    .collect();
                (hash, fetch)
            })
            .collect();

        QueryReconstructionResponseV2 {
            offset_into_first_range: v1.offset_into_first_range,
            terms: v1.terms,
            xorbs,
        }
    }
}

/// Request JSON body type representation for the POST /reconstructions endpoint,
/// used to get the reconstruction for multiple files at a time.
///
/// A listing of non-duplicate (enforced by `HashSet`) keys (file ids) to get
/// reconstructions for.
pub type BatchQueryReconstructionRequest = HashSet<HexKey>;

/// Response type for querying reconstruction for a batch of files.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct BatchQueryReconstructionResponse {
    // Map of FileID to series of terms describing a xorb hash and chunk range to be retrieved
    // to reconstruct the file
    pub files: HashMap<HexMerkleHash, Vec<XorbReconstructionTerm>>,
    // information to fetch xorb ranges to reconstruct the file
    // each key is a hash that is present in the `terms` field reconstruction
    // terms, the values are information we will need to fetch ranges from
    // each xorb needed to reconstruct the file
    pub fetch_info: HashMap<HexMerkleHash, Vec<XorbReconstructionFetchInfo>>,
}

/// Outcome code of a shard upload, (de)serialized as its `u8` discriminant.
#[derive(Debug, Serialize_repr, Deserialize_repr, Clone, Copy)]
#[repr(u8)]
pub enum UploadShardResponseType {
    // NOTE(review): presumably the shard was already present on the server — confirm.
    Exists = 0,
    // NOTE(review): presumably the shard was newly synced by this upload — confirm.
    SyncPerformed = 1,
}

/// Serializable response type for a shard upload.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UploadShardResponse {
    /// Outcome code of the upload.
    pub result: UploadShardResponseType,
}

/// Serializable response type for a chunk query.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct QueryChunkResponse {
    // NOTE(review): presumably the hash of the shard that contains the queried
    // chunk — confirm against the endpoint contract.
    pub shard: MerkleHash,
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Taking a segment off the full range leaves an open-ended remainder.
    #[test]
    fn test_file_range_segment() {
        let segment_size = 824820;

        let (head, tail) = FileRange::full().take_segment(segment_size);

        assert_eq!(head, FileRange::new(0, segment_size));
        assert_eq!(tail, Some(FileRange::new(segment_size, u64::MAX)));
    }

    /// Repeatedly taking segments ends with a short final segment and no remainder.
    #[test]
    fn test_file_range_segment_no_remainder() {
        let segment_size = 40;

        // First split: 40 units taken, 10 left over.
        let (first, rest) = FileRange::new(50, 100).take_segment(segment_size);
        assert_eq!(first, FileRange::new(50, 90));
        assert_eq!(rest, Some(FileRange::new(90, 100)));

        // Second split: the leftover is shorter than a segment, so nothing remains.
        let (second, rest) = rest.unwrap().take_segment(segment_size);
        assert_eq!(second, FileRange::new(90, 100));
        assert_eq!(rest, None);
    }

    /// Converting between the two range kinds adjusts the end bound by one.
    #[test]
    fn test_http_range_type_casting() {
        let as_http: HttpRange = FileRange::new(0, 10).into();
        assert_eq!(as_http, HttpRange::new(0, 9));

        let as_file: FileRange = HttpRange::new(0, 10).into();
        assert_eq!(as_file, FileRange::new(0, 11));
    }
}