diffusion_rs/
util.rs

1use std::{
2    path::PathBuf,
3    sync::{OnceLock, RwLock},
4};
5
6use hf_hub::api::sync::{ApiBuilder, ApiError};
7
/// Process-wide Hugging Face Hub token.
///
/// The cell stays empty until [`set_hf_token`] is called for the first time.
static TOKEN: OnceLock<RwLock<String>> = OnceLock::new();

/// Set the huggingface hub token to access "protected" models. See <https://huggingface.co/settings/tokens>
///
/// The token is kept in a process-wide static and replaces any value stored by an earlier
/// call.
pub fn set_hf_token(token: &str) {
    let lock = TOKEN.get_or_init(RwLock::default);
    // The guarded value is a plain `String`, which a panicking lock holder cannot leave in
    // an inconsistent state, so a poisoned lock is recovered instead of panicking here.
    let mut stored = lock
        .write()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *stored = token.to_owned();
}
16
17/// Download file from huggingface hub
18pub fn download_file_hf_hub(repo: &str, file: &str) -> Result<PathBuf, ApiError> {
19    let token = TOKEN.get().map(|token| token.read().unwrap().to_owned());
20    let repo = ApiBuilder::new()
21        .with_token(token)
22        .build()?
23        .model(repo.to_string());
24    repo.get(file)
25}