use std::sync::Arc;
use http::header;
use tracing::{Instrument, Span, info_span, instrument};
use xet_client::cas_client::auth::TokenRefresher;
use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
use xet_runtime::core::XetRuntime;
use xet_runtime::core::par_utils::run_constrained;
use super::super::data_client::{clean_file, default_config};
use super::super::{FileUploadSession, Sha256Policy, XetFileInfo};
use super::hub_client_token_refresher::HubClientTokenRefresher;
use crate::error::{DataError, Result};
/// User-agent string of the form `<package name>/<package version>`, assembled at compile time
/// from Cargo's `CARGO_PKG_NAME` and `CARGO_PKG_VERSION` environment variables.
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// Builds a [`HubClient`] for the given repository and runs [`migrate_files_impl`] on `file_paths`.
///
/// The client is created with a bearer credential helper wrapping `hub_token` and a default
/// header map containing only the crate's `USER_AGENT`. The migration is started with
/// `sequential = false` and `dry_run = false`; the returned migration info is discarded.
///
/// # Arguments
/// * `file_paths` - paths of the files to migrate.
/// * `sha256s` - optional hex SHA-256 strings, forwarded unchanged to [`migrate_files_impl`].
/// * `hub_endpoint` - endpoint handed to [`HubClient::new`].
/// * `cas_endpoint` - optional CAS endpoint, forwarded unchanged to [`migrate_files_impl`].
/// * `hub_token` - token used for the bearer credential helper.
/// * `repo_type`, `repo_id` - parsed into a [`RepoInfo`] via `RepoInfo::try_from`.
///
/// # Errors
/// Propagates any error from `RepoInfo::try_from`, `HubClient::new`, or [`migrate_files_impl`].
pub async fn migrate_with_external_runtime(
    file_paths: Vec<String>,
    sha256s: Option<Vec<String>>,
    hub_endpoint: &str,
    cas_endpoint: Option<String>,
    hub_token: &str,
    repo_type: &str,
    repo_id: &str,
) -> Result<()> {
    let bearer = BearerCredentialHelper::new(hub_token.to_string(), "");

    // Every hub request carries this crate's user agent.
    let agent = header::HeaderValue::from_static(USER_AGENT);
    let mut default_headers = header::HeaderMap::new();
    default_headers.insert(header::USER_AGENT, agent);

    let repo_info = RepoInfo::try_from(repo_type, repo_id)?;
    // NOTE(review): "main" is presumably the repository revision - confirm against `HubClient::new`.
    let client = HubClient::new(
        hub_endpoint,
        repo_info,
        Some(String::from("main")),
        "",
        Some(bearer),
        Some(default_headers),
    )?;

    let sequential = false;
    let dry_run = false;
    migrate_files_impl(file_paths, sha256s, sequential, client, cas_endpoint, dry_run)
        .await
        .map(|_| ())
}
/// Value returned by [`migrate_files_impl`], as a tuple of:
///
/// 1. the `MDBFileInfo` list obtained from `finalize_with_file_info` on a dry run (an empty
///    vector otherwise),
/// 2. the results collected from cleaning the files, each a `(XetFileInfo, new_bytes)` pair where
///    `new_bytes` comes from that file's clean metrics,
/// 3. the `total_bytes_uploaded` value of the metrics returned when the session is finalized.
pub type MigrationInfo = (Vec<MDBFileInfo>, Vec<(XetFileInfo, u64)>, u64);
/// Cleans every file in `file_paths` through one [`FileUploadSession`] and finalizes the session.
///
/// Steps, in order:
/// 1. Validate `sha256s` against `file_paths` and build one [`Sha256Policy`] per file. This is done
///    first so that an invalid argument list fails before any token request is made or any upload
///    session is created.
/// 2. Request a CAS JWT for `Operation::Upload` from `hub_client`, and wrap the client in a
///    [`HubClientTokenRefresher`] used as the session's token refresher.
/// 3. Build the session config via `default_config`, record its session id on the current tracing
///    span, and create the session (`FileUploadSession::dry_run` when `dry_run` is set,
///    `FileUploadSession::new` otherwise).
/// 4. Run `clean_file` for every `(path, policy)` pair through `run_constrained`, then finalize.
///
/// # Arguments
/// * `file_paths` - paths of the files to clean.
/// * `sha256s` - optional hex SHA-256 strings, one per entry of `file_paths`; each is turned into a
///   policy with `Sha256Policy::from_hex`. When `None`, every file uses `Sha256Policy::Compute`.
/// * `sequential` - when `true` the worker limit passed to `run_constrained` is 1; otherwise it is
///   `XetRuntime::current().num_worker_threads()`.
/// * `hub_client` - client used to obtain the CAS token.
/// * `cas_endpoint` - CAS endpoint override; when `None` the `cas_url` from the JWT response is used.
/// * `dry_run` - selects the dry-run session constructor and `finalize_with_file_info`.
///
/// # Returns
/// A [`MigrationInfo`] tuple: the file info list from `finalize_with_file_info` (dry run only,
/// empty otherwise), the collected `(XetFileInfo, new_bytes)` clean results, and the finalized
/// session's `total_bytes_uploaded`.
///
/// # Errors
/// Returns `DataError::ParameterError` when `sha256s` is provided with a length different from
/// `file_paths`. Otherwise propagates errors from the token request, `default_config`, session
/// creation, `clean_file` / `run_constrained`, and session finalization.
#[instrument(skip_all, name = "migrate_files", fields(session_id = tracing::field::Empty, num_files = file_paths.len()))]
pub async fn migrate_files_impl(
    file_paths: Vec<String>,
    sha256s: Option<Vec<String>>,
    sequential: bool,
    hub_client: HubClient,
    cas_endpoint: Option<String>,
    dry_run: bool,
) -> Result<MigrationInfo> {
    // Fail fast on bad arguments: this check is pure, so doing it before the token request and the
    // session creation avoids a wasted round trip and an upload session that is never finalized.
    let sha256_policies: Vec<Sha256Policy> = match sha256s {
        Some(hashes) => {
            if hashes.len() != file_paths.len() {
                return Err(DataError::ParameterError(
                    "mismatched length of the file list and the sha256 list".to_string(),
                ));
            }
            hashes.iter().map(|s| Sha256Policy::from_hex(s)).collect()
        },
        None => vec![Sha256Policy::Compute; file_paths.len()],
    };

    let operation = Operation::Upload;
    let jwt_info = hub_client.get_cas_jwt(operation).await?;

    // The hub client is kept alive inside the refresher so the session can request new tokens.
    let token_refresher = Arc::new(HubClientTokenRefresher {
        operation,
        client: Arc::new(hub_client),
    }) as Arc<dyn TokenRefresher>;

    // An explicitly supplied endpoint takes precedence over the one reported with the JWT.
    let cas = cas_endpoint.unwrap_or(jwt_info.cas_url);

    let mut headers = http::HeaderMap::new();
    headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));

    let config = default_config(
        cas,
        Some((jwt_info.access_token, jwt_info.exp)),
        Some(token_refresher),
        Some(Arc::new(headers)),
    )?;

    // Fill in the `session_id` span field declared as `Empty` in the `#[instrument]` attribute.
    Span::current().record("session_id", &config.session.session_id);

    let num_workers = if sequential {
        1
    } else {
        XetRuntime::current().num_worker_threads()
    };

    let processor = if dry_run {
        FileUploadSession::dry_run(config.into()).await?
    } else {
        FileUploadSession::new(config.into()).await?
    };

    // One future per file; each gets its own handle to the shared session and its own span.
    let clean_futs = file_paths.into_iter().zip(sha256_policies).map(|(file_path, policy)| {
        let proc = processor.clone();
        async move {
            let (pf, metrics) = clean_file(proc, file_path, policy).await?;
            Ok::<(XetFileInfo, u64), DataError>((pf, metrics.new_bytes))
        }
        .instrument(info_span!("clean_file"))
    });
    let clean_ret = run_constrained(clean_futs, num_workers).await?;

    if dry_run {
        let (metrics, all_file_info) = processor.finalize_with_file_info().await?;
        Ok((all_file_info, clean_ret, metrics.total_bytes_uploaded))
    } else {
        let metrics = processor.finalize().await?;
        Ok((vec![], clean_ret, metrics.total_bytes_uploaded as u64))
    }
}