Skip to main content

entrenar/hf_pipeline/publish/
publisher.rs

//! HuggingFace Hub publisher
//!
//! Uploads model files and model cards to HuggingFace Hub repositories
//! using the HF REST API.
5
6use std::path::Path;
7
8use super::config::PublishConfig;
9use super::model_card::ModelCard;
10use super::result::{PublishError, PublishResult};
11use crate::hf_pipeline::HfModelFetcher;
12
/// Root of the HuggingFace Hub REST API; the repo-creation and upload
/// endpoint URLs in this module are built on top of it.
const HF_API_BASE: &str = "https://huggingface.co/api";
14
15/// HuggingFace Hub publisher
16pub struct HfPublisher {
17    config: PublishConfig,
18    client: reqwest::blocking::Client,
19    token: String,
20}
21
22impl HfPublisher {
23    /// Create a new publisher with config
24    pub fn new(config: PublishConfig) -> Result<Self, PublishError> {
25        let token = config
26            .token
27            .clone()
28            .or_else(HfModelFetcher::resolve_token)
29            .ok_or(PublishError::AuthRequired)?;
30
31        if config.repo_id.is_empty() || !config.repo_id.contains('/') {
32            return Err(PublishError::InvalidRepoId { repo_id: config.repo_id.clone() });
33        }
34
35        let client =
36            reqwest::blocking::Client::builder().user_agent("entrenar/0.5").build().map_err(
37                |e| PublishError::Http { message: format!("Failed to create HTTP client: {e}") },
38            )?;
39
40        Ok(Self { config, client, token })
41    }
42
43    /// Create the HuggingFace repository
44    ///
45    /// POST <https://huggingface.co/api/repos/create>
46    pub fn create_repo(&self) -> Result<String, PublishError> {
47        let url = format!("{HF_API_BASE}/repos/create");
48
49        let mut body = serde_json::json!({
50            "name": self.repo_name(),
51            "type": self.config.repo_type.to_string(),
52            "private": self.config.private,
53        });
54
55        // Add organization if repo_id contains one
56        if let Some(org) = self.repo_org() {
57            body["organization"] = serde_json::Value::String(org);
58        }
59
60        let response =
61            self.client.post(&url).bearer_auth(&self.token).json(&body).send().map_err(|e| {
62                PublishError::Http { message: format!("Create repo request failed: {e}") }
63            })?;
64
65        if response.status().is_success() || response.status().as_u16() == 409 {
66            // 409 = already exists, which is fine
67            let repo_url = format!(
68                "https://huggingface.co/{}/{}",
69                self.config.repo_type.api_path(),
70                self.config.repo_id
71            );
72            Ok(repo_url)
73        } else {
74            let status = response.status();
75            let body = response.text().unwrap_or_default();
76            Err(PublishError::RepoCreationFailed {
77                repo_id: self.config.repo_id.clone(),
78                message: format!("HTTP {status}: {body}"),
79            })
80        }
81    }
82
83    /// Upload a local file to the repository
84    ///
85    /// PUT <https://huggingface.co/api/{type}s/{repo_id}/upload/{path}>
86    pub fn upload_file(&self, local_path: &Path, path_in_repo: &str) -> Result<(), PublishError> {
87        let content = std::fs::read(local_path).map_err(PublishError::Io)?;
88        self.upload_bytes(&content, path_in_repo)
89    }
90
91    /// Upload bytes directly to the repository
92    pub fn upload_bytes(&self, content: &[u8], path_in_repo: &str) -> Result<(), PublishError> {
93        let url = format!(
94            "{HF_API_BASE}/{}/{}/upload/main/{}",
95            self.config.repo_type.api_path(),
96            self.config.repo_id,
97            path_in_repo
98        );
99
100        let response = self
101            .client
102            .put(&url)
103            .bearer_auth(&self.token)
104            .header("Content-Type", "application/octet-stream")
105            .body(content.to_vec())
106            .send()
107            .map_err(|e| PublishError::UploadFailed {
108                path: path_in_repo.to_string(),
109                message: format!("Upload request failed: {e}"),
110            })?;
111
112        if response.status().is_success() {
113            Ok(())
114        } else {
115            let status = response.status();
116            let body = response.text().unwrap_or_default();
117            Err(PublishError::UploadFailed {
118                path: path_in_repo.to_string(),
119                message: format!("HTTP {status}: {body}"),
120            })
121        }
122    }
123
124    /// Full publish flow: create repo → upload files → upload model card
125    pub fn publish(
126        &self,
127        files: &[(&Path, &str)],
128        model_card: Option<&ModelCard>,
129    ) -> Result<PublishResult, PublishError> {
130        let repo_url = self.create_repo()?;
131
132        let mut files_uploaded = 0;
133
134        // Upload all files
135        for (local_path, remote_path) in files {
136            self.upload_file(local_path, remote_path)?;
137            files_uploaded += 1;
138        }
139
140        // Upload model card
141        let model_card_generated = if let Some(card) = model_card {
142            let markdown = card.to_markdown();
143            self.upload_bytes(markdown.as_bytes(), "README.md")?;
144            true
145        } else {
146            false
147        };
148
149        Ok(PublishResult {
150            repo_url,
151            repo_id: self.config.repo_id.clone(),
152            files_uploaded,
153            model_card_generated,
154        })
155    }
156
157    /// Extract the repository name (part after the last '/')
158    fn repo_name(&self) -> &str {
159        self.config.repo_id.rsplit('/').next().unwrap_or(&self.config.repo_id)
160    }
161
162    /// Extract the organization (part before '/')
163    fn repo_org(&self) -> Option<String> {
164        let parts: Vec<&str> = self.config.repo_id.splitn(2, '/').collect();
165        if parts.len() == 2 {
166            Some(parts[0].to_string())
167        } else {
168            None
169        }
170    }
171}
172
173impl std::fmt::Debug for HfPublisher {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("HfPublisher")
176            .field("repo_id", &self.config.repo_id)
177            .field("repo_type", &self.config.repo_type)
178            .field("private", &self.config.private)
179            .finish_non_exhaustive()
180    }
181}