diffusion_rs/
util.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
use std::{
    path::PathBuf,
    sync::{OnceLock, RwLock},
};

use hf_hub::api::sync::{ApiBuilder, ApiError};

/// Process-wide Hugging Face access token, lazily created by the first call to [`set_hf_token`].
static TOKEN: OnceLock<RwLock<String>> = OnceLock::new();

/// Set the huggingface hub token used to access "protected" (gated/private) models.
///
/// The token is stored process-wide; each call replaces whatever token a previous call stored.
/// See <https://huggingface.co/settings/tokens>.
///
/// # Panics
///
/// Panics if the token lock was poisoned by a thread that panicked while holding it.
pub fn set_hf_token(token: &str) {
    let lock = TOKEN.get_or_init(|| RwLock::new(String::new()));
    *lock.write().unwrap() = String::from(token);
}

/// Download `file` from the Hugging Face Hub model repository `repo` and return its local path.
///
/// Authentication: if a non-empty token was registered with [`set_hf_token`], it is used for
/// the request. Otherwise the token configured by `ApiBuilder::new()` is left in place (hf-hub
/// loads it from the local Hugging Face cache, e.g. after `huggingface-cli login`) instead of
/// being cleared.
///
/// # Errors
///
/// Returns an [`ApiError`] if the API client cannot be built or the file cannot be retrieved.
pub fn download_file_hf_hub(repo: &str, file: &str) -> Result<PathBuf, ApiError> {
    // Recover the stored value from a poisoned lock rather than panicking. An empty string is
    // never a usable token, so treat it the same as "no token set".
    let explicit_token = TOKEN
        .get()
        .map(|lock| {
            lock.read()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .clone()
        })
        .filter(|token| !token.is_empty());

    let mut builder = ApiBuilder::new();
    // Only override when a token was explicitly set: `with_token(None)` would discard the
    // token that `ApiBuilder::new()` already picked up on its own.
    if let Some(token) = explicit_token {
        builder = builder.with_token(Some(token));
    }

    let api_repo = builder.build()?.model(repo.to_owned());
    api_repo.get(file)
}