Skip to main content

synth_claw/hub/
mod.rs

1use crate::config::HubConfig;
2use crate::{Error, Result};
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use std::path::Path;
6use std::time::Duration;
7
8const HF_API_URL: &str = "https://huggingface.co/api";
9
10pub struct HubClient {
11    client: Client,
12    token: String,
13}
14
15#[derive(Serialize)]
16struct CreateRepoRequest {
17    #[serde(rename = "type")]
18    repo_type: String,
19    name: String,
20    private: bool,
21}
22
23#[derive(Deserialize)]
24struct RepoInfo {
25    pub name: String,
26}
27
28fn resolve_token(token: Option<String>) -> Result<String> {
29    token
30        .or_else(|| std::env::var("HF_TOKEN").ok())
31        .or_else(|| std::env::var("HUGGING_FACE_HUB_TOKEN").ok())
32        .or_else(read_cached_token)
33        .ok_or_else(|| Error::Config(
34            "HF token not found. Set via config, HF_TOKEN env var, or run `huggingface-cli login`".into()
35        ))
36}
37
/// Reads the access token from the on-disk token file, if one exists.
///
/// The file location is chosen from the first of these that is set to a
/// non-empty value:
/// 1. `HF_TOKEN_PATH` — used as the token file path directly,
/// 2. `HF_HOME` — token file is `$HF_HOME/token`,
/// 3. `XDG_CACHE_HOME` — token file is `$XDG_CACHE_HOME/huggingface/token`,
/// 4. `HOME`, then `USERPROFILE` — token file is `<home>/.cache/huggingface/token`.
///
/// Returns `None` when no location can be determined, the file cannot be
/// read, or its trimmed content is empty.
fn read_cached_token() -> Option<String> {
    /// Reads an environment variable as a path, ignoring unset/empty values.
    fn env_path(key: &str) -> Option<std::path::PathBuf> {
        std::env::var_os(key)
            .filter(|value| !value.is_empty())
            .map(std::path::PathBuf::from)
    }

    let token_path = env_path("HF_TOKEN_PATH")
        .or_else(|| env_path("HF_HOME").map(|hf_home| hf_home.join("token")))
        .or_else(|| {
            env_path("XDG_CACHE_HOME").map(|cache| cache.join("huggingface").join("token"))
        })
        .or_else(|| {
            env_path("HOME")
                .or_else(|| env_path("USERPROFILE"))
                .map(|home| home.join(".cache").join("huggingface").join("token"))
        })?;

    let raw = std::fs::read_to_string(token_path).ok()?;
    let token = raw.trim();
    // An empty file is not a usable credential; report it as "no token".
    if token.is_empty() {
        None
    } else {
        Some(token.to_string())
    }
}
43
44impl HubClient {
45    pub fn new(token: Option<String>) -> Result<Self> {
46        Ok(Self {
47            client: Client::new(),
48            token: resolve_token(token)?,
49        })
50    }
51
52    pub fn from_config(config: &HubConfig) -> Result<Self> {
53        Self::new(config.token.clone())
54    }
55
56    pub async fn create_dataset_repo(&self, repo_id: &str, private: bool) -> Result<String> {
57        let repo_name = repo_id.split('/').last().unwrap_or(repo_id);
58
59        let req = CreateRepoRequest {
60            repo_type: "dataset".to_string(),
61            name: repo_name.to_string(),
62            private,
63        };
64
65        let resp = self
66            .client
67            .post(&format!("{}/repos/create", HF_API_URL))
68            .header("Authorization", format!("Bearer {}", self.token))
69            .json(&req)
70            .send()
71            .await
72            .map_err(|e| Error::Http(e))?;
73
74        if resp.status() == 409 {
75            if repo_id.contains('/') {
76                return Ok(repo_id.to_string());
77            }
78            return Err(Error::Config(format!(
79                "Dataset '{}' already exists. Use 'username/{}' in hub.repo",
80                repo_name, repo_name
81            )));
82        }
83
84        if !resp.status().is_success() {
85            let status = resp.status();
86            let body = resp.text().await.unwrap_or_default();
87            return Err(Error::Provider(format!("Failed to create repo: {} - {}", status, body)));
88        }
89
90        let info: RepoInfo = resp.json().await.map_err(|e| Error::Provider(e.to_string()))?;
91        // HF needs time to propagate newly created repos
92        tokio::time::sleep(Duration::from_secs(2)).await;
93        Ok(info.name)
94    }
95
96    pub async fn upload_file(
97        &self,
98        repo_id: &str,
99        path_in_repo: &str,
100        content: &[u8],
101        commit_message: Option<&str>,
102    ) -> Result<String> {
103        let commit_msg = commit_message.unwrap_or("Upload file");
104
105        use base64::Engine;
106        let content_b64 = base64::engine::general_purpose::STANDARD.encode(content);
107
108        // HF commit API uses NDJSON (application/x-ndjson)
109        let header_line = serde_json::json!({
110            "key": "header",
111            "value": { "summary": commit_msg }
112        });
113        let file_line = serde_json::json!({
114            "key": "file",
115            "value": {
116                "path": path_in_repo,
117                "content": content_b64,
118                "encoding": "base64"
119            }
120        });
121        let body = format!("{}\n{}", header_line, file_line);
122
123        let url = format!("{}/datasets/{}/commit/main", HF_API_URL, repo_id);
124
125        let mut resp = self
126            .client
127            .post(&url)
128            .header("Authorization", format!("Bearer {}", self.token))
129            .header("Content-Type", "application/x-ndjson")
130            .body(body.clone())
131            .send()
132            .await
133            .map_err(Error::Http)?;
134
135        // Retry on 404 — newly created repos need time to propagate
136        let mut attempts = 0;
137        while resp.status() == 404 && attempts < 5 {
138            attempts += 1;
139            tokio::time::sleep(Duration::from_secs(1)).await;
140            resp = self
141                .client
142                .post(&url)
143                .header("Authorization", format!("Bearer {}", self.token))
144                .header("Content-Type", "application/x-ndjson")
145                .body(body.clone())
146                .send()
147                .await
148                .map_err(Error::Http)?;
149        }
150
151        if !resp.status().is_success() {
152            let status = resp.status();
153            let body = resp.text().await.unwrap_or_default();
154            return Err(Error::Provider(format!("Upload failed: {} - {}", status, body)));
155        }
156
157        Ok(format!(
158            "https://huggingface.co/datasets/{}/blob/main/{}",
159            repo_id, path_in_repo
160        ))
161    }
162
163    pub async fn upload_file_from_path(
164        &self,
165        repo_id: &str,
166        local_path: &Path,
167        path_in_repo: &str,
168        commit_message: Option<&str>,
169    ) -> Result<String> {
170        let content = std::fs::read(local_path)
171            .map_err(|e| Error::Io(e))?;
172        self.upload_file(repo_id, path_in_repo, &content, commit_message).await
173    }
174}
175
176pub struct DatasetUploader {
177    client: HubClient,
178    repo_id: String,
179}
180
181impl DatasetUploader {
182    pub async fn new(repo_name: &str, private: bool, token: Option<String>) -> Result<Self> {
183        let client = HubClient::new(token)?;
184        let repo_id = client.create_dataset_repo(repo_name, private).await?;
185        Ok(Self { client, repo_id })
186    }
187
188    pub async fn from_config(config: &HubConfig) -> Result<Self> {
189        let repo = config.repo.as_ref()
190            .ok_or_else(|| Error::Config("hub.repo is required".into()))?;
191        Self::new(repo, config.private, config.token.clone()).await
192    }
193
194    pub fn repo_id(&self) -> &str {
195        &self.repo_id
196    }
197
198    pub fn repo_url(&self) -> String {
199        format!("https://huggingface.co/datasets/{}", self.repo_id)
200    }
201
202    pub async fn upload(&self, path_in_repo: &str, content: &[u8], commit_message: Option<&str>) -> Result<String> {
203        self.client.upload_file(&self.repo_id, path_in_repo, content, commit_message).await
204    }
205
206    pub async fn upload_file(&self, local_path: &Path, path_in_repo: &str, commit_message: Option<&str>) -> Result<String> {
207        self.client.upload_file_from_path(&self.repo_id, local_path, path_in_repo, commit_message).await
208    }
209
210    pub async fn upload_jsonl(&self, data: &[serde_json::Value], filename: &str) -> Result<String> {
211        let content: String = data
212            .iter()
213            .map(|v| serde_json::to_string(v).unwrap())
214            .collect::<Vec<_>>()
215            .join("\n");
216        
217        self.upload(filename, content.as_bytes(), Some(&format!("Upload {}", filename))).await
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// An explicitly supplied token is enough to construct a client.
    #[test]
    fn test_hub_client_with_token() {
        let result = HubClient::new(Some("test-token".to_string()));
        assert!(result.is_ok());
    }

    /// An explicit token wins over the `HF_TOKEN` environment variable.
    #[test]
    fn test_resolve_token_priority() {
        // Remember any pre-existing value so the test leaves the process
        // environment exactly as it found it.
        let previous = std::env::var_os("HF_TOKEN");
        std::env::set_var("HF_TOKEN", "env-token");

        let resolved = resolve_token(Some("direct-token".to_string()));

        // Restore before asserting so cleanup happens even if the check fails.
        match previous {
            Some(value) => std::env::set_var("HF_TOKEN", value),
            None => std::env::remove_var("HF_TOKEN"),
        }

        assert_eq!(resolved.unwrap(), "direct-token");
    }
}