entrenar/hf_pipeline/publish/
publisher.rs1use std::path::Path;
7
8use super::config::PublishConfig;
9use super::model_card::ModelCard;
10use super::result::{PublishError, PublishResult};
11use crate::hf_pipeline::HfModelFetcher;
12
/// Root of the Hugging Face Hub REST API; every endpoint URL in this module is built from it
/// (no trailing slash, callers append `/...`).
const HF_API_BASE: &str = "https://huggingface.co/api";
14
15pub struct HfPublisher {
17 config: PublishConfig,
18 client: reqwest::blocking::Client,
19 token: String,
20}
21
22impl HfPublisher {
23 pub fn new(config: PublishConfig) -> Result<Self, PublishError> {
25 let token = config
26 .token
27 .clone()
28 .or_else(HfModelFetcher::resolve_token)
29 .ok_or(PublishError::AuthRequired)?;
30
31 if config.repo_id.is_empty() || !config.repo_id.contains('/') {
32 return Err(PublishError::InvalidRepoId { repo_id: config.repo_id.clone() });
33 }
34
35 let client =
36 reqwest::blocking::Client::builder().user_agent("entrenar/0.5").build().map_err(
37 |e| PublishError::Http { message: format!("Failed to create HTTP client: {e}") },
38 )?;
39
40 Ok(Self { config, client, token })
41 }
42
43 pub fn create_repo(&self) -> Result<String, PublishError> {
47 let url = format!("{HF_API_BASE}/repos/create");
48
49 let mut body = serde_json::json!({
50 "name": self.repo_name(),
51 "type": self.config.repo_type.to_string(),
52 "private": self.config.private,
53 });
54
55 if let Some(org) = self.repo_org() {
57 body["organization"] = serde_json::Value::String(org);
58 }
59
60 let response =
61 self.client.post(&url).bearer_auth(&self.token).json(&body).send().map_err(|e| {
62 PublishError::Http { message: format!("Create repo request failed: {e}") }
63 })?;
64
65 if response.status().is_success() || response.status().as_u16() == 409 {
66 let repo_url = format!(
68 "https://huggingface.co/{}/{}",
69 self.config.repo_type.api_path(),
70 self.config.repo_id
71 );
72 Ok(repo_url)
73 } else {
74 let status = response.status();
75 let body = response.text().unwrap_or_default();
76 Err(PublishError::RepoCreationFailed {
77 repo_id: self.config.repo_id.clone(),
78 message: format!("HTTP {status}: {body}"),
79 })
80 }
81 }
82
83 pub fn upload_file(&self, local_path: &Path, path_in_repo: &str) -> Result<(), PublishError> {
87 let content = std::fs::read(local_path).map_err(PublishError::Io)?;
88 self.upload_bytes(&content, path_in_repo)
89 }
90
91 pub fn upload_bytes(&self, content: &[u8], path_in_repo: &str) -> Result<(), PublishError> {
93 let url = format!(
94 "{HF_API_BASE}/{}/{}/upload/main/{}",
95 self.config.repo_type.api_path(),
96 self.config.repo_id,
97 path_in_repo
98 );
99
100 let response = self
101 .client
102 .put(&url)
103 .bearer_auth(&self.token)
104 .header("Content-Type", "application/octet-stream")
105 .body(content.to_vec())
106 .send()
107 .map_err(|e| PublishError::UploadFailed {
108 path: path_in_repo.to_string(),
109 message: format!("Upload request failed: {e}"),
110 })?;
111
112 if response.status().is_success() {
113 Ok(())
114 } else {
115 let status = response.status();
116 let body = response.text().unwrap_or_default();
117 Err(PublishError::UploadFailed {
118 path: path_in_repo.to_string(),
119 message: format!("HTTP {status}: {body}"),
120 })
121 }
122 }
123
124 pub fn publish(
126 &self,
127 files: &[(&Path, &str)],
128 model_card: Option<&ModelCard>,
129 ) -> Result<PublishResult, PublishError> {
130 let repo_url = self.create_repo()?;
131
132 let mut files_uploaded = 0;
133
134 for (local_path, remote_path) in files {
136 self.upload_file(local_path, remote_path)?;
137 files_uploaded += 1;
138 }
139
140 let model_card_generated = if let Some(card) = model_card {
142 let markdown = card.to_markdown();
143 self.upload_bytes(markdown.as_bytes(), "README.md")?;
144 true
145 } else {
146 false
147 };
148
149 Ok(PublishResult {
150 repo_url,
151 repo_id: self.config.repo_id.clone(),
152 files_uploaded,
153 model_card_generated,
154 })
155 }
156
157 fn repo_name(&self) -> &str {
159 self.config.repo_id.rsplit('/').next().unwrap_or(&self.config.repo_id)
160 }
161
162 fn repo_org(&self) -> Option<String> {
164 let parts: Vec<&str> = self.config.repo_id.splitn(2, '/').collect();
165 if parts.len() == 2 {
166 Some(parts[0].to_string())
167 } else {
168 None
169 }
170 }
171}
172
173impl std::fmt::Debug for HfPublisher {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("HfPublisher")
176 .field("repo_id", &self.config.repo_id)
177 .field("repo_type", &self.config.repo_type)
178 .field("private", &self.config.private)
179 .finish_non_exhaustive()
180 }
181}