#[cfg(feature = "xet")]
use std::path::Path;
#[cfg(feature = "xet")]
use super::{HfHubError, Result};
/// Size cutoff in bytes: 5 GiB, i.e. 5 × 1024³.
///
/// [`should_use_xet`] returns `true` only for sizes strictly above this value.
pub const HF_XET_THRESHOLD_BYTES: u64 = 5 * 1024_u64.pow(3);
/// Reports whether `file_size_bytes` is strictly greater than
/// [`HF_XET_THRESHOLD_BYTES`].
///
/// The comparison is exclusive: a file whose size equals the threshold
/// yields `false`.
#[must_use]
pub const fn should_use_xet(file_size_bytes: u64) -> bool {
    HF_XET_THRESHOLD_BYTES < file_size_bytes
}
/// Assembles the URL `{api_base}/api/models/{repo_id}/xet-write-token/{revision}`.
///
/// Every trailing `/` on `api_base` is removed before joining, so a base
/// given with or without a trailing slash produces the same URL.
///
/// `repo_id` and `revision` are inserted verbatim: no percent-encoding or
/// validation is performed here, and the `models` path segment is fixed.
#[must_use]
pub fn build_token_refresh_url(api_base: &str, repo_id: &str, revision: &str) -> String {
    let root = api_base.trim_end_matches('/');
    [root, "/api/models/", repo_id, "/xet-write-token/", revision].concat()
}
/// Borrowed connection parameters for uploading a file via `upload_file`.
///
/// Only compiled when the `xet` feature is enabled. All fields are string
/// slices borrowed for `'a`; the struct owns nothing.
#[cfg(feature = "xet")]
#[derive(Debug)]
pub struct XetUploader<'a> {
/// Base URL handed to `build_token_refresh_url`; trailing slashes are
/// stripped there.
pub api_base: &'a str,
/// Repository identifier placed in the token-refresh URL path.
pub repo_id: &'a str,
/// Revision placed as the last segment of the token-refresh URL.
pub revision: &'a str,
/// Credential sent as `Authorization: Bearer <token>` by `upload_file`.
/// NOTE(review): the derived `Debug` prints this value verbatim.
pub token: &'a str,
}
#[cfg(feature = "xet")]
impl XetUploader<'_> {
    /// Uploads the file at `local_path` through a Xet session, then commits it.
    ///
    /// Steps, in order:
    /// 1. derive the token-refresh URL from `api_base`, `repo_id` and `revision`;
    /// 2. build an `Authorization: Bearer <token>` header;
    /// 3. create a session and an upload commit configured with that URL and
    ///    header;
    /// 4. upload the file with `Sha256Policy::Compute`;
    /// 5. commit.
    ///
    /// `_commit_msg` is accepted for interface compatibility but is not used.
    ///
    /// # Errors
    ///
    /// * `HfHubError::XetUpload` if the header value cannot be built, or if
    ///   session creation, upload-commit creation/build, or the file upload
    ///   fails. The message names the failing step.
    /// * `HfHubError::PartialUpload` with `cas_success: true` and
    ///   `commit_success: false` if the upload call succeeded but the final
    ///   commit call failed.
    pub fn upload_file(&self, local_path: &Path, _commit_msg: &str) -> Result<()> {
        use xet::xet_session::{header, HeaderMap, HeaderValue, Sha256Policy, XetSessionBuilder};

        let refresh_url = build_token_refresh_url(self.api_base, self.repo_id, self.revision);

        // Validate the bearer value before touching the session so a bad
        // token is reported without any session being created.
        let bearer = HeaderValue::from_str(&format!("Bearer {}", self.token))
            .map_err(|e| HfHubError::XetUpload(format!("auth header build failed: {e}")))?;
        let mut auth_headers = HeaderMap::new();
        auth_headers.insert(header::AUTHORIZATION, bearer);

        let xet_session = XetSessionBuilder::new()
            .build()
            .map_err(|e| HfHubError::XetUpload(format!("session build failed: {e}")))?;

        let upload_commit = xet_session
            .new_upload_commit()
            .map_err(|e| HfHubError::XetUpload(format!("new_upload_commit failed: {e}")))?
            .with_token_refresh_url(refresh_url, auth_headers)
            .build_blocking()
            .map_err(|e| HfHubError::XetUpload(format!("commit build_blocking failed: {e}")))?;

        upload_commit
            .upload_from_path_blocking(local_path.to_owned(), Sha256Policy::Compute)
            .map_err(|e| HfHubError::XetUpload(format!("upload_from_path failed: {e}")))?;

        // From here on the upload call has already succeeded, so a failure is
        // reported as a partial upload and not as a plain upload error.
        upload_commit
            .commit_blocking()
            .map_err(|e| HfHubError::PartialUpload {
                cas_success: true,
                commit_success: false,
                detail: format!("commit_blocking failed: {e}"),
            })?;

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The gate is exclusive: sizes up to and including the threshold stay
    /// off the Xet path, anything larger goes on it.
    #[test]
    fn dispatch_gate_partitions_exactly_at_5_gib() {
        let not_above_threshold = [0, HF_XET_THRESHOLD_BYTES - 1, HF_XET_THRESHOLD_BYTES];
        for size in not_above_threshold {
            assert!(!should_use_xet(size), "{size} bytes must not select xet");
        }

        // NOTE(review): the three literal sizes are kept exactly as they were;
        // their provenance is not visible in this file.
        let above_threshold = [
            HF_XET_THRESHOLD_BYTES + 1,
            8_035_635_524,
            15_231_938_404,
            8_037_129_408,
        ];
        for size in above_threshold {
            assert!(should_use_xet(size), "{size} bytes must select xet");
        }
    }

    /// Base URL without a trailing slash joins cleanly with the fixed path.
    #[test]
    fn token_refresh_url_matches_hf_protocol_shape() {
        let expected = "https://huggingface.co/api/models/paiml/my-model/xet-write-token/main";
        let actual = build_token_refresh_url("https://huggingface.co", "paiml/my-model", "main");
        assert_eq!(actual, expected);
    }

    /// A trailing slash on the base must not produce a doubled separator.
    #[test]
    fn token_refresh_url_strips_trailing_slash() {
        let expected = "https://huggingface.co/api/models/org/repo/xet-write-token/main";
        let actual = build_token_refresh_url("https://huggingface.co/", "org/repo", "main");
        assert!(!actual.contains("co//api"));
        assert_eq!(actual, expected);
    }

    /// The revision argument is used as the final path segment as given.
    #[test]
    fn token_refresh_url_supports_non_main_revision() {
        let expected = "https://huggingface.co/api/models/org/repo/xet-write-token/release-v1";
        let actual = build_token_refresh_url("https://huggingface.co", "org/repo", "release-v1");
        assert_eq!(actual, expected);
    }
}