use crate::format::model_card::ModelCard;
use std::path::PathBuf;
use std::sync::Arc;
/// Encodes `data` as standard base64 (RFC 4648 alphabet, `+` and `/`),
/// padding the final group with `=` so the output length is always a
/// multiple of four.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Pick the 6-bit group of `word` that starts `shift` bits from the right.
    let sextet = |word: u32, shift: u32| ALPHABET[((word >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);
    for group in data.chunks(3) {
        // Missing trailing bytes of a short final group are treated as zero.
        let first = group[0];
        let second = group.get(1).copied().unwrap_or(0);
        let third = group.get(2).copied().unwrap_or(0);
        let word = (u32::from(first) << 16) | (u32::from(second) << 8) | u32::from(third);

        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        encoded.push(if group.len() >= 2 { sextet(word, 6) } else { '=' });
        encoded.push(if group.len() == 3 { sextet(word, 0) } else { '=' });
    }
    encoded
}
/// Errors produced by the Hugging Face Hub client.
#[derive(Debug)]
pub enum HfHubError {
    /// No access token was available; displayed as "HF_TOKEN environment variable not set".
    MissingToken,
    /// A network failure, carrying a textual description.
    NetworkError(String),
    /// The named repository could not be found.
    RepoNotFound(String),
    /// The named file could not be found.
    FileNotFound(String),
    /// A repository id that is not of the form `org/name`.
    InvalidRepoId(String),
    /// A wrapped I/O error (also produced via `From<std::io::Error>`).
    IoError(std::io::Error),
    /// A failure related to the model card, carrying a textual description.
    ModelCardError(String),
}

impl std::fmt::Display for HfHubError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use HfHubError::{
            FileNotFound, InvalidRepoId, IoError, MissingToken, ModelCardError, NetworkError,
            RepoNotFound,
        };
        match self {
            MissingToken => f.write_str("HF_TOKEN environment variable not set"),
            NetworkError(detail) => write!(f, "Network error: {}", detail),
            RepoNotFound(repo_id) => write!(f, "Repository not found: {}", repo_id),
            FileNotFound(path) => write!(f, "File not found: {}", path),
            InvalidRepoId(raw) => write!(f, "Invalid repo ID (expected 'org/name'): {}", raw),
            IoError(cause) => write!(f, "IO error: {}", cause),
            ModelCardError(detail) => write!(f, "Model card error: {}", detail),
        }
    }
}

// The I/O cause is already rendered by `Display`, so `source()` is left at its
// default (`None`) to avoid reporters printing the same message twice.
impl std::error::Error for HfHubError {}

impl From<std::io::Error> for HfHubError {
    fn from(err: std::io::Error) -> Self {
        HfHubError::IoError(err)
    }
}

/// Convenience result type for Hub operations.
pub type Result<T> = std::result::Result<T, HfHubError>;
/// A snapshot of upload progress, passed to a [`ProgressCallback`].
#[derive(Debug, Clone)]
pub struct UploadProgress {
    /// Bytes sent so far.
    pub bytes_sent: u64,
    /// Total bytes expected for the upload.
    pub total_bytes: u64,
    /// Name of the file currently being uploaded.
    pub current_file: String,
    /// Number of files already finished.
    pub files_completed: usize,
    /// Total number of files in the upload.
    pub total_files: usize,
}

impl UploadProgress {
    /// Returns `bytes_sent / total_bytes * 100`.
    ///
    /// When `total_bytes` is zero, this returns `100.0`. The result is not
    /// clamped, so it exceeds 100 if `bytes_sent > total_bytes`.
    #[must_use]
    pub fn percentage(&self) -> f64 {
        match self.total_bytes {
            0 => 100.0,
            total => {
                let fraction = self.bytes_sent as f64 / total as f64;
                fraction * 100.0
            }
        }
    }
}

/// Shareable, thread-safe callback that receives [`UploadProgress`] snapshots.
pub type ProgressCallback = Arc<dyn Fn(UploadProgress) + Send + Sync>;
/// Summary of a completed upload.
///
/// Field meanings below are inferred from names; the producer is not visible in
/// this file — confirm against the upload implementation.
#[derive(Debug, Clone)]
pub struct UploadResult {
/// URL of the target repository (presumably its Hub web URL — TODO confirm).
pub repo_url: String,
/// Identifier of the commit created by the upload (presumably a commit hash).
pub commit_sha: String,
/// Names of the files that were uploaded.
pub files_uploaded: Vec<String>,
/// Total number of bytes transferred.
pub bytes_transferred: u64,
}
/// Options controlling how a model is pushed to the Hub.
///
/// Build with [`PushOptions::new`] (or `Default::default()`) and chain the
/// `with_*` methods. Defaults: no commit message, no model card,
/// `create_repo = true`, `private = false`, `filename = "model.apr"`,
/// no progress callback, `max_retries = 3`, `initial_backoff_ms = 1000`.
#[derive(Clone)]
pub struct PushOptions {
    /// Commit message; `None` presumably lets the uploader pick a default (TODO confirm).
    pub commit_message: Option<String>,
    /// Model card to publish with the model, if any.
    pub model_card: Option<ModelCard>,
    /// Whether the repository should be created (presumably only when missing — TODO confirm).
    pub create_repo: bool,
    /// Whether a created repository should be private.
    pub private: bool,
    /// Name under which the model file is uploaded.
    pub filename: String,
    /// Optional callback receiving progress updates.
    pub progress_callback: Option<ProgressCallback>,
    /// Maximum number of retries for a failed operation.
    pub max_retries: usize,
    /// Starting retry delay in milliseconds (how it grows is up to the uploader — TODO confirm).
    pub initial_backoff_ms: u64,
}

impl std::fmt::Debug for PushOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `dyn Fn` is not `Debug`, so only the callback's presence is shown.
        f.debug_struct("PushOptions")
            .field("commit_message", &self.commit_message)
            .field("model_card", &self.model_card)
            .field("create_repo", &self.create_repo)
            .field("private", &self.private)
            .field("filename", &self.filename)
            .field("progress_callback", &self.progress_callback.is_some())
            .field("max_retries", &self.max_retries)
            .field("initial_backoff_ms", &self.initial_backoff_ms)
            .finish()
    }
}

impl Default for PushOptions {
    fn default() -> Self {
        Self::new()
    }
}

impl PushOptions {
    /// Creates options with the defaults listed on [`PushOptions`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            commit_message: None,
            model_card: None,
            create_repo: true,
            private: false,
            filename: "model.apr".to_string(),
            progress_callback: None,
            max_retries: 3,
            initial_backoff_ms: 1000,
        }
    }

    /// Sets the commit message.
    #[must_use]
    pub fn with_commit_message(mut self, msg: impl Into<String>) -> Self {
        self.commit_message = Some(msg.into());
        self
    }

    /// Attaches a model card.
    #[must_use]
    pub fn with_model_card(mut self, card: ModelCard) -> Self {
        self.model_card = Some(card);
        self
    }

    /// Sets whether the repository should be created.
    #[must_use]
    pub fn with_create_repo(mut self, create: bool) -> Self {
        self.create_repo = create;
        self
    }

    /// Sets whether a created repository is private.
    #[must_use]
    pub fn with_private(mut self, private: bool) -> Self {
        self.private = private;
        self
    }

    /// Sets the uploaded file name.
    #[must_use]
    pub fn with_filename(mut self, filename: impl Into<String>) -> Self {
        self.filename = filename.into();
        self
    }

    /// Installs a progress callback.
    #[must_use]
    pub fn with_progress_callback(mut self, callback: ProgressCallback) -> Self {
        self.progress_callback = Some(callback);
        self
    }

    /// Sets the maximum number of retries.
    #[must_use]
    pub fn with_max_retries(mut self, retries: usize) -> Self {
        self.max_retries = retries;
        self
    }

    /// Sets the initial retry backoff in milliseconds.
    ///
    /// This was the only configurable field without a builder method, so
    /// chained construction could not change it.
    #[must_use]
    pub fn with_initial_backoff_ms(mut self, backoff_ms: u64) -> Self {
        self.initial_backoff_ms = backoff_ms;
        self
    }
}
/// Client state for talking to the Hugging Face Hub.
///
/// `Debug` is implemented by hand rather than derived so that the access
/// token is never written to logs or panic messages; only its presence is shown.
pub struct HfHubClient {
    /// Access token, if any. It is secret, so `Debug` redacts it.
    pub(crate) token: Option<String>,
    /// Local cache directory.
    pub(crate) cache_dir: PathBuf,
    /// Base URL used for API requests.
    pub(crate) api_base: String,
}

impl std::fmt::Debug for HfHubClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show `Some("<redacted>")` / `None` so callers can still see whether a
        // token is configured without exposing its value.
        f.debug_struct("HfHubClient")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("cache_dir", &self.cache_dir)
            .field("api_base", &self.api_base)
            .finish()
    }
}
mod client_default;
mod client_modules;
#[cfg(test)]
mod tests;