use std::fs::{File, create_dir_all, read_dir};
use std::io::{Read, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use itertools::multizip;
use rand::prelude::*;
use tempfile::TempDir;
use xet_client::cas_client::{Client, LocalClient};
#[cfg(feature = "simulation")]
use xet_client::cas_client::{LocalTestServer, LocalTestServerBuilder};
use super::configurations::TranslatorConfig;
use super::data_client::clean_file;
use super::file_cleaner::Sha256Policy;
use super::{FileDownloadSession, FileUploadSession, XetFileInfo};
/// Strategy used by [`HydrateDehydrateTest`] to download (hydrate) files in tests.
///
/// Variants other than `DirectClient` are only compiled in with the `simulation` feature; they
/// route downloads through a local test server configured in a specific way.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HydrationMode {
    /// Download through a `LocalClient` opened directly on the on-disk CAS directory (no server).
    DirectClient,
    /// Download through the local test server left in its default configuration.
    #[cfg(feature = "simulation")]
    ServerV2,
    /// Download through the local test server with v2 reconstruction disabled (it answers 404),
    /// presumably exercising the client's v1 fallback path.
    #[cfg(feature = "simulation")]
    ServerV1Fallback,
    /// Download through the local test server limited to two ranges per fetch.
    #[cfg(feature = "simulation")]
    ServerMaxRanges2,
}
impl HydrationMode {
    /// Returns every hydration mode compiled into this build.
    ///
    /// Without the `simulation` feature this is just `[DirectClient]`.
    pub fn all() -> &'static [HydrationMode] {
        &[
            HydrationMode::DirectClient,
            #[cfg(feature = "simulation")]
            HydrationMode::ServerV2,
            #[cfg(feature = "simulation")]
            HydrationMode::ServerV1Fallback,
            #[cfg(feature = "simulation")]
            HydrationMode::ServerMaxRanges2,
        ]
    }

    /// Returns `true` when this mode downloads through a local test server rather than a
    /// direct client.
    pub fn uses_server(&self) -> bool {
        // Every variant is named explicitly (no `_` arm) so that adding a new mode produces a
        // compile error here instead of silently being classified as server-backed.
        match self {
            HydrationMode::DirectClient => false,
            #[cfg(feature = "simulation")]
            HydrationMode::ServerV2 | HydrationMode::ServerV1Fallback | HydrationMode::ServerMaxRanges2 => true,
        }
    }
}
impl std::fmt::Display for HydrationMode {
    /// Writes the mode's snake_case label (e.g. `direct_client`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            HydrationMode::DirectClient => "direct_client",
            #[cfg(feature = "simulation")]
            HydrationMode::ServerV2 => "server_v2",
            #[cfg(feature = "simulation")]
            HydrationMode::ServerV1Fallback => "server_v1_fallback",
            #[cfg(feature = "simulation")]
            HydrationMode::ServerMaxRanges2 => "server_max_ranges_2",
        };
        f.write_str(label)
    }
}
/// Writes `size` pseudo-random bytes, generated from a `StdRng` seeded with `seed`, to `path`.
///
/// Missing parent directories are created first. Returns the number of bytes written, which is
/// always `size`.
///
/// # Panics
///
/// Panics if `path` has no parent, or if creating the directories or writing the file fails.
pub fn create_random_file(path: impl AsRef<Path>, size: usize, seed: u64) -> usize {
    let target = path.as_ref();
    create_dir_all(target.parent().unwrap()).unwrap();
    let mut out = File::create(target).unwrap();

    let mut contents = vec![0_u8; size];
    StdRng::seed_from_u64(seed).fill_bytes(&mut contents);
    out.write_all(&contents).unwrap();

    size
}
/// Creates each `(file_name, size)` entry under `dir` with random contents and returns the
/// total number of bytes written.
///
/// A `SmallRng` seeded with `seed` derives one per-file seed, in list order. The output is
/// therefore deterministic for a given seed on a given platform and `rand` version.
pub fn create_random_files(dir: impl AsRef<Path>, files: &[(impl AsRef<str>, usize)], seed: u64) -> usize {
    let base = dir.as_ref();
    let mut seeder = SmallRng::seed_from_u64(seed);
    files
        .iter()
        .map(|(name, len)| create_random_file(base.join(name.as_ref()), *len, seeder.random()))
        .sum()
}
/// Writes a file built from consecutive random segments and returns its total length.
///
/// Each `(size, seed)` pair appends `size` bytes produced by a fresh `StdRng` seeded with `seed`.
/// Segments with the same size and seed therefore have identical content. Missing parent
/// directories are created first.
///
/// # Panics
///
/// Panics if `path` has no parent, or if creating the directories or writing the file fails.
pub fn create_random_multipart_file(path: impl AsRef<Path>, segments: &[(usize, u64)]) -> usize {
    let target = path.as_ref();
    create_dir_all(target.parent().unwrap()).unwrap();
    let mut out = File::create(target).unwrap();

    segments.iter().fold(0, |written, &(segment_len, segment_seed)| {
        let mut chunk = vec![0_u8; segment_len];
        StdRng::seed_from_u64(segment_seed).fill_bytes(&mut chunk);
        out.write_all(&chunk).unwrap();
        written + segment_len
    })
}
/// Asserts that two directories contain exactly the same regular files with identical contents.
///
/// Only the top level of each directory is inspected; any non-file entry (e.g. a subdirectory)
/// is treated as an error.
///
/// # Panics
///
/// Panics in any of these cases:
/// - either directory cannot be read;
/// - either directory contains an entry that is not a regular file;
/// - the two sets of file names differ;
/// - a pair of same-named files has different contents. The message then reports both sizes and
///   the first differing byte offset.
pub fn verify_directories_match(dir1: impl AsRef<Path>, dir2: impl AsRef<Path>) {
    let dir1 = dir1.as_ref();
    let dir2 = dir2.as_ref();

    let files_in_dir1 = sorted_regular_file_names(dir1);
    let files_in_dir2 = sorted_regular_file_names(dir2);

    if files_in_dir1 != files_in_dir2 {
        panic!(
            "Directories differ: file sets are not the same.\n \
            dir1: {files_in_dir1:?}\n dir2: {files_in_dir2:?}"
        );
    }

    for file_name in &files_in_dir1 {
        let path1 = dir1.join(file_name);
        let path2 = dir2.join(file_name);
        let buf1 = read_file_bytes(&path1);
        let buf2 = read_file_bytes(&path2);
        if buf1 != buf2 {
            // Pinpoint the first mismatch (or the end of the shorter file) so failures are actionable.
            let first_diff = buf1
                .iter()
                .zip(&buf2)
                .position(|(a, b)| a != b)
                .unwrap_or(buf1.len().min(buf2.len()));
            panic!(
                "File contents differ for {file_name:?}\n \
                dir1 path: {path1:?} ({} bytes)\n dir2 path: {path2:?} ({} bytes)\n \
                first difference at byte offset {first_diff}",
                buf1.len(),
                buf2.len()
            );
        }
    }
}

/// Returns the sorted names of the entries in `dir`, asserting that each one is a regular file.
fn sorted_regular_file_names(dir: &Path) -> Vec<std::ffi::OsString> {
    let mut names = Vec::new();
    for entry in read_dir(dir).unwrap_or_else(|e| panic!("failed to read directory {dir:?}: {e}")) {
        let entry = entry.unwrap_or_else(|e| panic!("failed to read an entry of {dir:?}: {e}"));
        let file_type = entry
            .file_type()
            .unwrap_or_else(|e| panic!("failed to get file type of {:?}: {e}", entry.path()));
        assert!(file_type.is_file(), "expected only regular files in {dir:?}, found {:?}", entry.path());
        names.push(entry.file_name());
    }
    names.sort();
    names
}

/// Reads the full contents of the file at `path`, panicking with the path on failure.
fn read_file_bytes(path: &Path) -> Vec<u8> {
    let mut buf = Vec::new();
    File::open(path)
        .and_then(|mut f| f.read_to_end(&mut buf))
        .unwrap_or_else(|e| panic!("failed to read {path:?}: {e}"));
    buf
}
/// Test harness that dehydrates (cleans/uploads) files from `src_dir` into pointer files in
/// `ptr_dir`, then hydrates them back into `dest_dir` so the two can be compared.
///
/// All directories live inside one temporary directory owned by the harness.
pub struct HydrateDehydrateTest {
    /// Local CAS root; clients read xorbs from `cas_dir/xet/xorbs`.
    pub cas_dir: PathBuf,
    /// Directory holding the original files to dehydrate.
    pub src_dir: PathBuf,
    /// Directory receiving one pointer file per source file.
    pub ptr_dir: PathBuf,
    /// Directory receiving the hydrated files.
    pub dest_dir: PathBuf,
    /// Whether hydration goes through a local test server instead of a direct client.
    use_test_server: bool,
    /// Lazily started local test server (only populated when `use_test_server` is set).
    #[cfg(feature = "simulation")]
    test_server: Option<LocalTestServer>,
    // Declared last on purpose. Struct fields drop in declaration order, and the test server's
    // client points into this directory. The server must therefore shut down before the
    // directory is deleted.
    _temp_dir: TempDir,
}
impl Default for HydrateDehydrateTest {
fn default() -> Self {
Self::new(false)
}
}
impl HydrateDehydrateTest {
pub fn new(use_test_server: bool) -> Self {
let _temp_dir = TempDir::new().unwrap();
let temp_path = _temp_dir.path();
let cas_dir = temp_path.join("cas");
let src_dir = temp_path.join("src");
let ptr_dir = temp_path.join("pointers");
let dest_dir = temp_path.join("dest");
std::fs::create_dir_all(&cas_dir).unwrap();
std::fs::create_dir_all(&src_dir).unwrap();
std::fs::create_dir_all(&ptr_dir).unwrap();
std::fs::create_dir_all(&dest_dir).unwrap();
Self {
cas_dir,
src_dir,
ptr_dir,
dest_dir,
_temp_dir,
use_test_server,
#[cfg(feature = "simulation")]
test_server: None,
}
}
pub fn for_mode(mode: HydrationMode) -> Self {
Self::new(mode.uses_server())
}
pub async fn apply_hydration_mode(&mut self, mode: HydrationMode) {
match mode {
HydrationMode::DirectClient => {},
#[cfg(feature = "simulation")]
HydrationMode::ServerV2 => {
self.ensure_server_created().await;
},
#[cfg(feature = "simulation")]
HydrationMode::ServerV1Fallback => {
self.ensure_server_created().await;
self.test_server.as_ref().unwrap().client().disable_v2_reconstruction(404);
},
#[cfg(feature = "simulation")]
HydrationMode::ServerMaxRanges2 => {
self.ensure_server_created().await;
self.test_server.as_ref().unwrap().client().set_max_ranges_per_fetch(2);
},
}
}
#[cfg(feature = "simulation")]
pub async fn ensure_server_created(&mut self) {
if self.use_test_server && self.test_server.is_none() {
let local_client = LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap();
self.test_server = Some(LocalTestServerBuilder::new().with_client(local_client).start().await);
}
}
#[cfg(feature = "simulation")]
pub fn test_server(&self) -> Option<&LocalTestServer> {
self.test_server.as_ref()
}
async fn get_or_create_client(&mut self) -> Arc<dyn Client> {
if self.use_test_server {
#[cfg(feature = "simulation")]
{
if self.test_server.is_none() {
let local_client = LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap();
self.test_server = Some(LocalTestServerBuilder::new().with_client(local_client).start().await);
}
self.test_server.as_ref().unwrap().remote_client().clone() as Arc<dyn Client>
}
#[cfg(not(feature = "simulation"))]
{
panic!("test server requires the 'simulation' feature");
}
} else {
LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap() as Arc<dyn Client>
}
}
pub async fn new_upload_session(&self) -> Arc<FileUploadSession> {
let config = Arc::new(TranslatorConfig::local_config(&self.cas_dir).unwrap());
FileUploadSession::new(config.clone()).await.unwrap()
}
pub async fn clean_all_files(&self, upload_session: &Arc<FileUploadSession>, sequential: bool) {
create_dir_all(&self.ptr_dir).unwrap();
if sequential {
for entry in read_dir(&self.src_dir).unwrap() {
let entry = entry.unwrap();
let out_file = self.ptr_dir.join(entry.file_name());
let upload_session = upload_session.clone();
if sequential {
let (pf, metrics) = clean_file(upload_session.clone(), entry.path(), Sha256Policy::Compute)
.await
.unwrap();
assert_eq!({ metrics.total_bytes }, entry.metadata().unwrap().len());
std::fs::write(out_file, pf.as_pointer_file().unwrap().as_bytes()).unwrap();
upload_session.checkpoint().await.unwrap();
}
}
} else {
let files: Vec<PathBuf> = read_dir(&self.src_dir)
.unwrap()
.map(|entry| self.src_dir.join(entry.unwrap().file_name()))
.collect();
let files_and_sha256 = multizip((files.iter(), std::iter::repeat_with(|| Sha256Policy::Compute)));
let clean_results = upload_session.upload_files(files_and_sha256).await.unwrap();
for (i, xf) in clean_results.into_iter().enumerate() {
std::fs::write(self.ptr_dir.join(files[i].file_name().unwrap()), serde_json::to_string(&xf).unwrap())
.unwrap();
}
}
}
pub async fn dehydrate(&mut self, sequential: bool) {
let upload_session = self.new_upload_session().await;
self.clean_all_files(&upload_session, sequential).await;
upload_session.finalize().await.unwrap();
}
pub async fn hydrate(&mut self) {
let client = self.get_or_create_client().await;
let session = FileDownloadSession::from_client(client, None);
for entry in read_dir(&self.ptr_dir).unwrap() {
let entry = entry.unwrap();
let out_filename = self.dest_dir.join(entry.file_name());
let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap();
let (_id, _) = session.download_file(&xf, &out_filename).await.unwrap();
}
}
pub async fn hydrate_partitioned_writers(&mut self, partitions: usize) {
let client = self.get_or_create_client().await;
let session = FileDownloadSession::from_client(client, None);
for entry in read_dir(&self.ptr_dir).unwrap() {
let entry = entry.unwrap();
let out_filename = self.dest_dir.join(entry.file_name());
let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap();
let file_size = xf.file_size().expect("file size required for partitioned hydration");
let out_file = File::create(&out_filename).unwrap();
out_file.set_len(file_size).unwrap();
if file_size == 0 {
continue;
}
let partition_count = partitions.max(1) as u64;
let mut tasks = Vec::new();
for idx in 0..partition_count {
let start = (idx * file_size) / partition_count;
let end = ((idx + 1) * file_size) / partition_count;
if start == end {
continue;
}
let session = session.clone();
let xf = xf.clone();
let out_filename = out_filename.clone();
tasks.push(tokio::spawn(async move {
let mut writer = std::fs::OpenOptions::new().write(true).open(out_filename).unwrap();
writer.seek(SeekFrom::Start(start)).unwrap();
session.download_to_writer(&xf, start..end, writer).await
}));
}
for task in tasks {
task.await.unwrap().unwrap();
}
}
}
pub async fn hydrate_stream(&mut self) {
let client = self.get_or_create_client().await;
let session = FileDownloadSession::from_client(client, None);
for entry in read_dir(&self.ptr_dir).unwrap() {
let entry = entry.unwrap();
let out_filename = self.dest_dir.join(entry.file_name());
let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap();
let (_id, mut stream) = session.download_stream(&xf, None).await.unwrap();
let mut file = File::create(&out_filename).unwrap();
while let Some(chunk) = stream.next().await.unwrap() {
file.write_all(&chunk).unwrap();
}
}
}
pub fn verify_src_dest_match(&self) {
verify_directories_match(&self.src_dir, &self.dest_dir);
}
}
pub struct TestEnvironment {
_temp_dir: TempDir,
pub base_dir: PathBuf,
pub config: Arc<super::configurations::TranslatorConfig>,
#[cfg(feature = "simulation")]
_server: Option<LocalTestServer>,
}
impl TestEnvironment {
pub async fn new() -> Self {
let temp_dir = TempDir::new().unwrap();
let base_dir = temp_dir.path().to_path_buf();
#[cfg(feature = "simulation")]
let (config, server) = {
let server = LocalTestServerBuilder::new().start().await;
let config = Arc::new(
super::configurations::TranslatorConfig::test_server_config(server.http_endpoint(), &base_dir).unwrap(),
);
(config, Some(server))
};
#[cfg(not(feature = "simulation"))]
let config = Arc::new(super::configurations::TranslatorConfig::local_config(&base_dir).unwrap());
Self {
_temp_dir: temp_dir,
base_dir,
config,
#[cfg(feature = "simulation")]
_server: server,
}
}
}