use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Result;
use clap::{Args, Parser, Subcommand};
use http::header::{self, HeaderMap, HeaderValue};
use walkdir::WalkDir;
use xet_client::cas_client::RemoteClient;
use xet_client::cas_client::auth::TokenRefresher;
use xet_client::cas_types::{FileRange, QueryReconstructionResponse};
use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use xet_core_structures::merklehash::MerkleHash;
use xet_core_structures::xorb_object::CompressionScheme;
use xet_data::processing::data_client::default_config;
use xet_data::processing::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
use xet_data::processing::migration_tool::migrate::migrate_files_impl;
use xet_runtime::config::XetConfig;
use xet_runtime::core::XetRuntime;
/// Hub endpoint used when neither `--endpoint` nor the `HF_ENDPOINT`
/// environment variable supplies one.
const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
/// Value of the `User-Agent` header (`xtool/<crate version>`) handed to the Hub
/// client and to the CAS client configuration built in this file.
const USER_AGENT: &str = concat!("xtool", "/", env!("CARGO_PKG_VERSION"));
// Top-level command line: repository/connection overrides shared by every
// subcommand, followed by the subcommand to run. Plain `//` comments are used
// on clap-derived items so that `--help` output is left unchanged.
#[derive(Parser)]
struct XCommand {
// Global connection and repository settings (flattened into the top level).
#[clap(flatten)]
overrides: CliOverrides,
// The operation to perform; see `Command`.
#[clap(subcommand)]
command: Command,
}
// Connection and repository settings shared by all subcommands.
// (`//` rather than `///` keeps clap's generated `--help` text unchanged.)
#[derive(Args)]
struct CliOverrides {
    // Hub endpoint URL; when absent the environment (`HF_ENDPOINT`) and then
    // `DEFAULT_HF_ENDPOINT` are consulted.
    #[clap(long)]
    endpoint: Option<String>,

    // Hub access token; when absent `HF_TOKEN` is read (empty if unset).
    #[clap(long)]
    token: Option<String>,

    // Repository type, passed with `repo_id` to `RepoInfo::try_from`.
    #[clap(long)]
    repo_type: String,

    // Repository identifier, passed with `repo_type` to `RepoInfo::try_from`.
    #[clap(long)]
    repo_id: String,
}
impl XCommand {
    /// Resolves connection settings and runs the selected subcommand.
    ///
    /// Command-line flags take precedence over the environment:
    /// - endpoint: `--endpoint`, then a non-empty `HF_ENDPOINT`, then
    ///   [`DEFAULT_HF_ENDPOINT`];
    /// - token: `--token`, then `HF_TOKEN`, otherwise the empty string.
    ///
    /// The Hub client is created for the `main` revision and is given a
    /// `User-Agent` header set to [`USER_AGENT`].
    ///
    /// # Errors
    /// Fails if `RepoInfo::try_from` rejects the repository type/id pair, if the
    /// Hub client cannot be constructed, or if the subcommand itself fails.
    async fn run(self) -> Result<()> {
        let Self { overrides, command } = self;

        // The default is only materialized when actually needed (lazy fallback).
        // An empty `HF_ENDPOINT=` is treated as unset so it cannot yield an empty URL.
        let endpoint = overrides.endpoint.unwrap_or_else(|| {
            std::env::var("HF_ENDPOINT")
                .ok()
                .filter(|value| !value.is_empty())
                .unwrap_or_else(|| DEFAULT_HF_ENDPOINT.to_owned())
        });
        let token = overrides
            .token
            .unwrap_or_else(|| std::env::var("HF_TOKEN").unwrap_or_default());

        let mut headers = HeaderMap::new();
        headers.insert(header::USER_AGENT, HeaderValue::from_static(USER_AGENT));

        let cred_helper = BearerCredentialHelper::new(token, "");
        let hub_client = HubClient::new(
            &endpoint,
            RepoInfo::try_from(&overrides.repo_type, &overrides.repo_id)?,
            Some("main".to_owned()),
            "",
            Some(cred_helper),
            Some(headers),
        )?;
        command.run(hub_client).await
    }
}
// Subcommands supported by the tool.
#[derive(Subcommand)]
enum Command {
// Run local files through `migrate_files_impl`; see `DedupArg`.
Dedup(DedupArg),
// Query the CAS reconstruction info for a file hash; see `QueryArg`.
Query(QueryArg),
}
// Arguments for the `dedup` subcommand.
#[derive(Args)]
struct DedupArg {
// Paths to process; with `--recursive` each one is walked as a directory tree.
files: Vec<String>,
// Expand `files` recursively (symlinks not followed, Git metadata skipped).
#[clap(short, long)]
recursive: bool,
// Forwarded to `migrate_files_impl`; presumably disables concurrent
// processing — TODO confirm against that function.
#[clap(short, long)]
sequential: bool,
// Destination for the JSON file-info report (only written without
// `--migrate`); the report goes to stdout when omitted.
#[clap(short, long)]
output: Option<PathBuf>,
// Numeric compression scheme id applied to the runtime's xorb compression
// policy in `main`; the error text there lists 0, 1, 2 and 99 as valid.
#[clap(short, long)]
compression: Option<u8>,
// Passed to `migrate_files_impl` negated (looks like a dry-run flag — confirm);
// when set, no JSON report is written.
#[clap(short, long)]
migrate: bool,
}
// Arguments for the `query` subcommand.
#[derive(Args)]
struct QueryArg {
// File hash in hex, parsed with `MerkleHash::from_hex`.
hash: String,
// Optional byte range forwarded to the reconstruction query. Its textual
// format is defined by clap's parser for `FileRange` (not visible here — TODO document).
bytes_range: Option<FileRange>,
}
impl Command {
    /// Executes the selected subcommand, using `hub_client` for repository access.
    ///
    /// `dedup`:
    /// - expands the input paths and runs them through `migrate_files_impl`;
    /// - unless `--migrate` was given, writes the per-file info as a
    ///   newline-terminated JSON document to `--output` (or stdout);
    /// - prints a per-file size summary. The summary goes to stdout, except when
    ///   stdout already carries the JSON report, in which case it is sent to
    ///   stderr so the report stays machine-parseable.
    ///
    /// `query` parses the hex hash and prints the reconstruction response
    /// (Debug format) to stderr.
    ///
    /// # Errors
    /// Propagates failures from migration, from creating/writing the output
    /// file, from JSON serialization, from hash parsing and from the query.
    async fn run(self, hub_client: HubClient) -> Result<()> {
        match self {
            Command::Dedup(arg) => {
                let file_paths = walk_files(arg.files, arg.recursive);
                eprintln!("Dedupping {} files...", file_paths.len());
                // The last argument is `!arg.migrate`; without `--migrate` this looks
                // like a dry run — confirm against `migrate_files_impl`.
                let (all_file_info, clean_ret, total_bytes_trans) =
                    migrate_files_impl(file_paths, None, arg.sequential, hub_client, None, !arg.migrate).await?;

                // Decided before `arg.output` is moved into the writer below.
                let json_on_stdout = !arg.migrate && arg.output.is_none();
                if !arg.migrate {
                    let mut writer: Box<dyn Write> = match arg.output {
                        Some(path) => Box::new(BufWriter::new(File::create(path)?)),
                        None => Box::new(std::io::stdout()),
                    };
                    serde_json::to_writer(&mut writer, &all_file_info)?;
                    // Terminate the document so later output never runs into it.
                    writeln!(writer)?;
                    writer.flush()?;
                }

                eprintln!("\n\nClean results:");
                for (xf, new_bytes) in clean_ret {
                    let size = xf.file_size().map_or("?".to_string(), |s| s.to_string());
                    let line = format!("{}: {} bytes -> {} bytes", xf.hash(), size, new_bytes);
                    // Mixing these lines into a JSON report on stdout would corrupt it.
                    if json_on_stdout {
                        eprintln!("{line}");
                    } else {
                        println!("{line}");
                    }
                }
                eprintln!("Transmitted {total_bytes_trans} bytes in total.");
                Ok(())
            },
            Command::Query(arg) => {
                let file_hash = MerkleHash::from_hex(&arg.hash)?;
                let ret = query_reconstruction(file_hash, arg.bytes_range, hub_client).await?;
                eprintln!("{ret:?}");
                Ok(())
            },
        }
    }
}
fn walk_files(files: Vec<String>, recursive: bool) -> Vec<String> {
if recursive {
files
.iter()
.flat_map(|dir| {
WalkDir::new(dir)
.follow_links(false)
.max_depth(usize::MAX)
.into_iter()
.filter_entry(|e| !is_git_special_files(e.file_name().to_str().unwrap_or_default()))
.flatten()
.filter(|e| {
e.file_type().is_file() && !is_git_special_files(e.file_name().to_str().unwrap_or_default())
})
.filter_map(|e| e.path().to_str().map(|s| s.to_owned()))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>()
} else {
files
}
}
/// Reports whether `path` is exactly one of the Git metadata names that the
/// directory walk must skip: `.git`, `.gitignore` or `.gitattributes`.
///
/// The comparison is against the whole string, so callers pass an entry's file
/// name rather than a full path.
fn is_git_special_files(path: &str) -> bool {
    const GIT_SPECIAL_NAMES: [&str; 3] = [".git", ".gitignore", ".gitattributes"];
    GIT_SPECIAL_NAMES.contains(&path)
}
async fn query_reconstruction(
file_hash: MerkleHash,
bytes_range: Option<FileRange>,
hub_client: HubClient,
) -> Result<Option<QueryReconstructionResponse>> {
let operation = Operation::Download;
let jwt_info = hub_client.get_cas_jwt(operation).await?;
let token_refresher = Arc::new(HubClientTokenRefresher {
operation,
client: Arc::new(hub_client),
}) as Arc<dyn TokenRefresher>;
let mut headers = http::HeaderMap::new();
headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));
let config = default_config(
jwt_info.cas_url.clone(),
Some((jwt_info.access_token, jwt_info.exp)),
Some(token_refresher),
Some(Arc::new(headers)),
)?;
let remote_client =
RemoteClient::new(&jwt_info.cas_url, &config.session.auth, "", true, config.session.custom_headers.clone());
remote_client
.get_reconstruction_v1(&file_hash, bytes_range)
.await
.map_err(Into::into)
}
/// Program entry point: parses the command line, prepares the runtime
/// configuration and drives the chosen command on a [`XetRuntime`].
///
/// `--compression` (dedup only) is validated and applied to the
/// [`XetConfig`] here, before the runtime is created. An unknown scheme id, or
/// a policy the config rejects, is reported as an `InvalidInput` I/O error.
fn main() -> Result<()> {
    let cli = XCommand::parse();

    let requested_compression = match &cli.command {
        Command::Dedup(arg) => arg.compression,
        Command::Query(_) => None,
    };

    let mut config = XetConfig::new();
    if let Some(c) = requested_compression {
        let invalid_input = |msg: String| std::io::Error::new(std::io::ErrorKind::InvalidInput, msg);

        let scheme = CompressionScheme::try_from(c).map_err(|_| {
            invalid_input(format!(
                "Invalid compression value {c}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)"
            ))
        })?;
        config
            .xorb
            .compression_policy
            .try_set(<&str>::from(scheme))
            .map_err(|e| invalid_input(e.to_string()))?;
    }

    let runtime = XetRuntime::new_with_config(config)?;
    runtime.bridge_sync(async move { cli.run().await })??;
    Ok(())
}