use std::sync::Arc;
use http::header::HeaderMap;
use urlencoding::encode;
use super::types::{CasJWTInfo, RepoInfo};
use crate::cas_client::exports::ClientWithMiddleware;
use crate::cas_client::retry_wrapper::RetryWrapper;
use crate::common::auth::CredentialHelper;
use crate::common::http_client::{Api, build_http_client};
use crate::error::Result;
/// The kind of CAS access a Hub token is requested for.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Operation {
    /// Writing data; requests a `write` token.
    Upload,
    /// Reading data; requests a `read` token.
    Download,
}

impl Operation {
    /// Lowercase name of the operation: `"upload"` or `"download"`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Upload => "upload",
            Self::Download => "download",
        }
    }

    /// Token type to request for this operation: `"write"` for uploads and `"read"` for
    /// downloads. It is interpolated into the `xet-{token_type}-token` Hub endpoint path.
    pub fn token_type(&self) -> &'static str {
        match self {
            Self::Upload => "write",
            Self::Download => "read",
        }
    }
}
/// Client for the Hugging Face Hub endpoint that issues CAS (xet) access tokens for one
/// repository.
pub struct HubClient {
    /// Base URL of the Hub (e.g. `https://huggingface.co`), stored without trailing `/`.
    endpoint: String,
    /// Repository whose tokens are requested.
    repo_info: RepoInfo,
    /// Revision to request tokens for; `None` means `main` (plus `create_pr=1` on upload).
    reference: Option<String>,
    /// HTTP client used for every request.
    client: ClientWithMiddleware,
    /// Optional helper that attaches credentials to each outgoing request.
    cred_helper: Option<Arc<dyn CredentialHelper>>,
}

impl HubClient {
    /// Creates a client for `repo_info` on the Hub located at `endpoint`.
    ///
    /// Trailing `/` characters are stripped from `endpoint`, so request URLs built as
    /// `{endpoint}/api/...` never contain a double slash.
    ///
    /// # Errors
    /// Returns an error if the underlying HTTP client cannot be built.
    pub fn new(
        endpoint: &str,
        repo_info: RepoInfo,
        reference: Option<String>,
        session_id: &str,
        cred_helper: Option<Arc<dyn CredentialHelper>>,
        custom_headers: Option<HeaderMap>,
    ) -> Result<Self> {
        Ok(HubClient {
            // Normalize e.g. "https://huggingface.co/" so formatted URLs don't become "//api/".
            endpoint: endpoint.trim_end_matches('/').to_owned(),
            repo_info,
            reference,
            client: build_http_client(session_id, None, custom_headers.map(|ch| ch.into()))?,
            cred_helper,
        })
    }

    /// Requests a CAS JWT for `operation` from
    /// `{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/{rev}`.
    ///
    /// `rev` is the percent-encoded reference, defaulting to `main`. For an upload with no
    /// explicit reference, `?create_pr=1` is appended. The request is executed through
    /// [`RetryWrapper`], and the credential helper (if any) is applied to every attempt.
    ///
    /// # Errors
    /// Propagates any error reported by `RetryWrapper::run_and_extract_json` while sending
    /// the request or decoding its JSON body into [`CasJWTInfo`].
    pub async fn get_cas_jwt(&self, operation: Operation) -> Result<CasJWTInfo> {
        let endpoint = self.endpoint.as_str();
        let repo_type = self.repo_info.repo_type.as_str();
        let repo_id = self.repo_info.full_name.as_str();
        let token_type = operation.token_type();
        // Percent-encode so refs containing slashes (e.g. `refs/pr/1`) occupy one path segment.
        let rev = encode(self.reference.as_deref().unwrap_or("main"));
        // An upload without an explicit revision sends `create_pr=1`
        // (presumably so the Hub opens a PR instead of targeting `main` directly).
        let query = if matches!(operation, Operation::Upload) && self.reference.is_none() {
            "?create_pr=1"
        } else {
            ""
        };
        let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/{rev}{query}");

        let client = self.client.clone();
        let cred_helper = self.cred_helper.clone();
        // The closure may run once per retry attempt, so each call clones its own copies of the
        // captured state to move into the `async` block.
        let info: CasJWTInfo = RetryWrapper::new("xet-token")
            .run_and_extract_json(move || {
                let url = url.clone();
                let client = client.clone();
                let cred_helper = cred_helper.clone();
                async move {
                    let mut req = client.get(&url).with_extension(Api("xet-token"));
                    if let Some(cred) = cred_helper {
                        req = cred.fill_credential(req).await.map_err(reqwest_middleware::Error::middleware)?;
                    }
                    req.send().await
                }
            })
            .await?;
        Ok(info)
    }
}
#[cfg(test)]
mod tests {
    use http::header::{self, HeaderMap, HeaderValue};

    use super::super::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo};
    use super::HubClient;
    use crate::error::Result;

    /// Builds a `HubClient` for the `seanses/tm` model repo on huggingface.co that
    /// authenticates with `token` via `BearerCredentialHelper` and sends `User-Agent: xtool`.
    fn build_client(token: &str, reference: Option<&str>) -> Result<HubClient> {
        let cred_helper = BearerCredentialHelper::new(token.to_owned(), "");
        let mut headers = HeaderMap::new();
        headers.insert(header::USER_AGENT, HeaderValue::from_static("xtool"));
        HubClient::new(
            "https://huggingface.co",
            RepoInfo {
                repo_type: HFRepoType::Model,
                full_name: "seanses/tm".into(),
            },
            reference.map(str::to_owned),
            "",
            Some(cred_helper),
            Some(headers),
        )
    }

    /// Requests an upload token and checks that every field of the response is populated.
    async fn assert_upload_jwt_is_valid(hub_client: &HubClient) -> Result<()> {
        let info = hub_client.get_cas_jwt(Operation::Upload).await?;
        assert!(!info.access_token.is_empty());
        assert!(!info.cas_url.is_empty());
        assert!(info.exp > 0);
        Ok(())
    }

    #[tokio::test]
    #[ignore = "need valid write token"]
    async fn test_get_jwt_token_with_hf_write_token() -> Result<()> {
        let hub_client = build_client("[hf_write_token]", Some("main"))?;
        assert_upload_jwt_is_valid(&hub_client).await
    }

    #[tokio::test]
    #[ignore = "need valid read token and pr created on hub"]
    async fn test_get_jwt_token_with_hf_read_token_pr_branch() -> Result<()> {
        let hub_client = build_client("[hf_read_token]", Some("refs/pr/1"))?;
        assert_upload_jwt_is_valid(&hub_client).await
    }

    #[tokio::test]
    #[ignore = "need valid read token"]
    async fn test_get_jwt_token_with_hf_read_token_create_pr() -> Result<()> {
        // No reference: the upload request is sent with `create_pr=1`.
        let hub_client = build_client("[hf_read_token]", None)?;
        assert_upload_jwt_is_valid(&hub_client).await
    }
}