1use crate::config::HubConfig;
2use crate::{Error, Result};
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6use std::time::Duration;
7
/// Base URL of the Hugging Face Hub HTTP API; endpoint paths are appended to it.
const HF_API_URL: &str = "https://huggingface.co/api";
9
10pub struct HubClient {
11 client: Client,
12 token: String,
13}
14
15#[derive(Serialize)]
16struct CreateRepoRequest {
17 #[serde(rename = "type")]
18 repo_type: String,
19 name: String,
20 private: bool,
21}
22
23#[derive(Deserialize)]
24struct RepoInfo {
25 pub name: String,
26}
27
28fn resolve_token(token: Option<String>) -> Result<String> {
29 token
30 .or_else(|| std::env::var("HF_TOKEN").ok())
31 .or_else(|| std::env::var("HUGGING_FACE_HUB_TOKEN").ok())
32 .or_else(read_cached_token)
33 .ok_or_else(|| Error::Config(
34 "HF token not found. Set via config, HF_TOKEN env var, or run `huggingface-cli login`".into()
35 ))
36}
37
/// Reads the access token cached on disk, if there is one.
///
/// Candidate locations are tried in order and the first file that contains a
/// non-blank token wins:
/// 1. the file named by `HF_TOKEN_PATH`,
/// 2. `$HF_HOME/token`,
/// 3. `$XDG_CACHE_HOME/huggingface/token`,
/// 4. `~/.cache/huggingface/token`, where `~` is `HOME` or, failing that,
///    `USERPROFILE`.
///
/// Unset or empty environment variables are skipped. The token is returned
/// with surrounding whitespace removed; unreadable or blank files are ignored,
/// so `None` means no usable cached token was found.
fn read_cached_token() -> Option<String> {
    use std::path::PathBuf;

    // An empty variable is treated like an unset one.
    let env_path = |key: &str| {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(PathBuf::from)
    };

    let candidates = [
        env_path("HF_TOKEN_PATH"),
        env_path("HF_HOME").map(|dir| dir.join("token")),
        env_path("XDG_CACHE_HOME").map(|dir| dir.join("huggingface").join("token")),
        env_path("HOME")
            .or_else(|| env_path("USERPROFILE"))
            .map(|dir| dir.join(".cache").join("huggingface").join("token")),
    ];

    candidates.iter().flatten().find_map(|path| {
        let contents = std::fs::read_to_string(path).ok()?;
        let token = contents.trim();
        if token.is_empty() {
            None
        } else {
            Some(token.to_string())
        }
    })
}
43
44impl HubClient {
45 pub fn new(token: Option<String>) -> Result<Self> {
46 Ok(Self {
47 client: Client::new(),
48 token: resolve_token(token)?,
49 })
50 }
51
52 pub fn from_config(config: &HubConfig) -> Result<Self> {
53 Self::new(config.token.clone())
54 }
55
56 pub async fn create_dataset_repo(&self, repo_id: &str, private: bool) -> Result<String> {
57 let repo_name = repo_id.split('/').last().unwrap_or(repo_id);
58
59 let req = CreateRepoRequest {
60 repo_type: "dataset".to_string(),
61 name: repo_name.to_string(),
62 private,
63 };
64
65 let resp = self
66 .client
67 .post(&format!("{}/repos/create", HF_API_URL))
68 .header("Authorization", format!("Bearer {}", self.token))
69 .json(&req)
70 .send()
71 .await
72 .map_err(|e| Error::Http(e))?;
73
74 if resp.status() == 409 {
75 if repo_id.contains('/') {
76 return Ok(repo_id.to_string());
77 }
78 return Err(Error::Config(format!(
79 "Dataset '{}' already exists. Use 'username/{}' in hub.repo",
80 repo_name, repo_name
81 )));
82 }
83
84 if !resp.status().is_success() {
85 let status = resp.status();
86 let body = resp.text().await.unwrap_or_default();
87 return Err(Error::Provider(format!("Failed to create repo: {} - {}", status, body)));
88 }
89
90 let info: RepoInfo = resp.json().await.map_err(|e| Error::Provider(e.to_string()))?;
91 tokio::time::sleep(Duration::from_secs(2)).await;
93 Ok(info.name)
94 }
95
96 pub async fn upload_file(
97 &self,
98 repo_id: &str,
99 path_in_repo: &str,
100 content: &[u8],
101 commit_message: Option<&str>,
102 ) -> Result<String> {
103 let commit_msg = commit_message.unwrap_or("Upload file");
104
105 use base64::Engine;
106 let content_b64 = base64::engine::general_purpose::STANDARD.encode(content);
107
108 let header_line = serde_json::json!({
110 "key": "header",
111 "value": { "summary": commit_msg }
112 });
113 let file_line = serde_json::json!({
114 "key": "file",
115 "value": {
116 "path": path_in_repo,
117 "content": content_b64,
118 "encoding": "base64"
119 }
120 });
121 let body = format!("{}\n{}", header_line, file_line);
122
123 let url = format!("{}/datasets/{}/commit/main", HF_API_URL, repo_id);
124
125 let mut resp = self
126 .client
127 .post(&url)
128 .header("Authorization", format!("Bearer {}", self.token))
129 .header("Content-Type", "application/x-ndjson")
130 .body(body.clone())
131 .send()
132 .await
133 .map_err(Error::Http)?;
134
135 let mut attempts = 0;
137 while resp.status() == 404 && attempts < 5 {
138 attempts += 1;
139 tokio::time::sleep(Duration::from_secs(1)).await;
140 resp = self
141 .client
142 .post(&url)
143 .header("Authorization", format!("Bearer {}", self.token))
144 .header("Content-Type", "application/x-ndjson")
145 .body(body.clone())
146 .send()
147 .await
148 .map_err(Error::Http)?;
149 }
150
151 if !resp.status().is_success() {
152 let status = resp.status();
153 let body = resp.text().await.unwrap_or_default();
154 return Err(Error::Provider(format!("Upload failed: {} - {}", status, body)));
155 }
156
157 Ok(format!(
158 "https://huggingface.co/datasets/{}/blob/main/{}",
159 repo_id, path_in_repo
160 ))
161 }
162
163 pub async fn upload_file_from_path(
164 &self,
165 repo_id: &str,
166 local_path: &Path,
167 path_in_repo: &str,
168 commit_message: Option<&str>,
169 ) -> Result<String> {
170 let content = std::fs::read(local_path)
171 .map_err(|e| Error::Io(e))?;
172 self.upload_file(repo_id, path_in_repo, &content, commit_message).await
173 }
174}
175
176pub struct DatasetUploader {
177 client: HubClient,
178 repo_id: String,
179}
180
181impl DatasetUploader {
182 pub async fn new(repo_name: &str, private: bool, token: Option<String>) -> Result<Self> {
183 let client = HubClient::new(token)?;
184 let repo_id = client.create_dataset_repo(repo_name, private).await?;
185 Ok(Self { client, repo_id })
186 }
187
188 pub async fn from_config(config: &HubConfig) -> Result<Self> {
189 let repo = config.repo.as_ref()
190 .ok_or_else(|| Error::Config("hub.repo is required".into()))?;
191 Self::new(repo, config.private, config.token.clone()).await
192 }
193
194 pub fn repo_id(&self) -> &str {
195 &self.repo_id
196 }
197
198 pub fn repo_url(&self) -> String {
199 format!("https://huggingface.co/datasets/{}", self.repo_id)
200 }
201
202 pub async fn upload(&self, path_in_repo: &str, content: &[u8], commit_message: Option<&str>) -> Result<String> {
203 self.client.upload_file(&self.repo_id, path_in_repo, content, commit_message).await
204 }
205
206 pub async fn upload_file(&self, local_path: &Path, path_in_repo: &str, commit_message: Option<&str>) -> Result<String> {
207 self.client.upload_file_from_path(&self.repo_id, local_path, path_in_repo, commit_message).await
208 }
209
210 pub async fn upload_jsonl(&self, data: &[serde_json::Value], filename: &str) -> Result<String> {
211 let content: String = data
212 .iter()
213 .map(|v| serde_json::to_string(v).unwrap())
214 .collect::<Vec<_>>()
215 .join("\n");
216
217 self.upload(filename, content.as_bytes(), Some(&format!("Upload {}", filename))).await
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Overrides an environment variable for the guard's lifetime and puts the
    /// previous state (value or absence) back on drop, including when the test
    /// panics.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    /// An explicit token is enough to construct a client.
    #[test]
    fn test_hub_client_with_token() {
        let result = HubClient::new(Some("test-token".to_string()));
        assert!(result.is_ok());
    }

    /// An explicit token takes precedence over `HF_TOKEN`.
    #[test]
    fn test_resolve_token_priority() {
        let _env = EnvVarGuard::set("HF_TOKEN", "env-token");
        let token = resolve_token(Some("direct-token".to_string())).unwrap();
        assert_eq!(token, "direct-token");
    }
}