use std::path::{Path, PathBuf};
use std::sync::Arc;
use http::HeaderMap;
use tracing::info;
use xet_client::cas_client::auth::AuthConfig;
use xet_runtime::core::{xet_cache_root, xet_config};
use crate::error::Result;
/// Connection details for a CAS session.
#[derive(Debug)]
pub struct SessionContext {
    /// CAS endpoint. One of:
    /// * a local path prefixed with the configured `local_cas_scheme` (see [`Self::is_local`]),
    /// * the literal `memory://` (see [`Self::is_memory`]),
    /// * anything else, e.g. an HTTP URL, which is treated as a remote endpoint.
    pub endpoint: String,
    /// Optional authentication configuration.
    pub auth: Option<AuthConfig>,
    /// Optional extra HTTP headers, shared behind an `Arc`.
    pub custom_headers: Option<Arc<HeaderMap>>,
    /// Repository paths associated with this session.
    pub repo_paths: Vec<String>,
    /// Optional session identifier.
    pub session_id: Option<String>,
}

impl SessionContext {
    /// Endpoint string that selects the in-memory backend.
    const MEMORY_ENDPOINT: &str = "memory://";

    /// Builds a session for `endpoint` with no auth, no custom headers, no session id,
    /// and a single empty repo path.
    fn unauthenticated(endpoint: String) -> Self {
        Self {
            endpoint,
            auth: None,
            custom_headers: None,
            repo_paths: vec![String::new()],
            session_id: None,
        }
    }

    /// Returns `true` if the endpoint starts with the configured local CAS scheme.
    pub fn is_local(&self) -> bool {
        self.endpoint.starts_with(&xet_config().data.local_cas_scheme)
    }

    /// Returns the filesystem path of a local CAS endpoint (the endpoint with the
    /// local CAS scheme stripped), or `None` if the endpoint is not local.
    pub fn local_path(&self) -> Option<PathBuf> {
        let path = self.endpoint.strip_prefix(&xet_config().data.local_cas_scheme)?;
        Some(PathBuf::from(path))
    }

    /// Returns `true` if the endpoint is exactly `memory://`.
    pub fn is_memory(&self) -> bool {
        self.endpoint == Self::MEMORY_ENDPOINT
    }

    /// Creates a session backed by a local CAS rooted at `base_dir`.
    ///
    /// The path is embedded via `Path::display`, so non-UTF-8 components are
    /// converted lossily.
    pub fn for_local_path(base_dir: impl AsRef<Path>) -> Self {
        let path = base_dir.as_ref().to_path_buf();
        let endpoint = format!("{}{}", xet_config().data.local_cas_scheme, path.display());
        Self::unauthenticated(endpoint)
    }

    /// Creates a session using the in-memory (`memory://`) backend.
    pub fn for_memory() -> Self {
        Self::unauthenticated(Self::MEMORY_ENDPOINT.into())
    }
}
/// Translator configuration: the CAS session plus the directories used for
/// shard caching and per-session shard state.
#[derive(Debug)]
pub struct TranslatorConfig {
    /// Session (endpoint, auth, headers) this translator uses.
    pub session: SessionContext,
    /// Directory for cached shards (`shard.cache_subdir` under the chosen base directory).
    pub shard_cache_directory: PathBuf,
    /// Directory for session shard state (`session.dir_name` under the chosen base directory).
    pub shard_session_directory: PathBuf,
    /// Set by [`TranslatorConfig::disable_progress_aggregation`]; `false` from every constructor.
    pub force_disable_progress_aggregation: bool,
}

impl TranslatorConfig {
    /// Creates (if missing) and returns `<base_dir>/xet`.
    ///
    /// # Errors
    /// Propagates any I/O error from creating the directory.
    fn create_base_xet_dir(base_dir: impl AsRef<Path>) -> Result<PathBuf> {
        let base_path = base_dir.as_ref().join("xet");
        std::fs::create_dir_all(&base_path)?;
        Ok(base_path)
    }

    /// Builds a config for `session` with both shard directories placed directly
    /// under `<base_dir>/xet`. That directory is created. The shard subdirectories are not.
    fn with_base_xet_dir(session: SessionContext, base_dir: impl AsRef<Path>) -> Result<Self> {
        let config = xet_config();
        let base_path = Self::create_base_xet_dir(base_dir)?;
        Ok(Self {
            session,
            shard_cache_directory: base_path.join(&config.shard.cache_subdir),
            shard_session_directory: base_path.join(&config.session.dir_name),
            force_disable_progress_aggregation: false,
        })
    }

    /// Builds a config for `session`, choosing shard directories by endpoint kind:
    ///
    /// * local CAS endpoint: `<local path>/xet/<cache_subdir>` and `<local path>/xet/<dir_name>`;
    /// * `memory://`: the same layout under `<xet cache root>/memory`;
    /// * anything else: a per-endpoint directory from [`compute_cache_path`], with the
    ///   session directory nested under its staging subdirectory.
    ///
    /// Only the base (and, for remote endpoints, staging) directories are created here.
    ///
    /// # Errors
    /// Returns an error if any of those directories cannot be created.
    pub fn new(session: SessionContext) -> Result<Self> {
        let config = xet_config();
        let (shard_cache_directory, shard_session_directory) = if let Some(local_path) = session.local_path() {
            let base_path = Self::create_base_xet_dir(&local_path)?;
            (base_path.join(&config.shard.cache_subdir), base_path.join(&config.session.dir_name))
        } else if session.is_memory() {
            // Every memory session built via `new` shares this directory under the global cache root.
            let cache_path = xet_cache_root().join("memory");
            std::fs::create_dir_all(&cache_path)?;
            (cache_path.join(&config.shard.cache_subdir), cache_path.join(&config.session.dir_name))
        } else {
            let cache_path = compute_cache_path(&session.endpoint);
            std::fs::create_dir_all(&cache_path)?;
            let staging_directory = cache_path.join(&config.data.staging_subdir);
            std::fs::create_dir_all(&staging_directory)?;
            (cache_path.join(&config.shard.cache_subdir), staging_directory.join(&config.session.dir_name))
        };
        info!(
            endpoint = %session.endpoint,
            session_id = ?session.session_id,
            shard_cache = %shard_cache_directory.display(),
            shard_session = %shard_session_directory.display(),
            "TranslatorConfig initialized"
        );
        Ok(Self {
            session,
            shard_cache_directory,
            shard_session_directory,
            force_disable_progress_aggregation: false,
        })
    }

    /// Builds a config backed by a local CAS rooted at `base_dir` (see [`Self::new`]).
    ///
    /// # Errors
    /// Returns an error if `<base_dir>/xet` cannot be created.
    pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
        Self::new(SessionContext::for_local_path(base_dir))
    }

    /// Builds an in-memory config whose shard directories live under `<base_dir>/xet`
    /// instead of the global cache root that [`Self::new`] uses for memory sessions.
    ///
    /// # Errors
    /// Returns an error if `<base_dir>/xet` cannot be created.
    pub fn memory_config(base_dir: impl AsRef<Path>) -> Result<Self> {
        Self::with_base_xet_dir(SessionContext::for_memory(), base_dir)
    }

    /// Builds a config for an arbitrary `endpoint` (e.g. a test server) with shard
    /// directories under `<base_dir>/xet` instead of a per-endpoint cache directory.
    ///
    /// # Errors
    /// Returns an error if `<base_dir>/xet` cannot be created.
    pub fn test_server_config(endpoint: impl AsRef<str>, base_dir: impl AsRef<Path>) -> Result<Self> {
        let session = SessionContext {
            endpoint: endpoint.as_ref().to_string(),
            auth: None,
            custom_headers: None,
            repo_paths: vec!["".into()],
            session_id: None,
        };
        Self::with_base_xet_dir(session, base_dir)
    }

    /// Builder-style setter that forces progress aggregation off.
    pub fn disable_progress_aggregation(mut self) -> Self {
        self.force_disable_progress_aggregation = true;
        self
    }
}
/// Derives the per-endpoint cache directory under [`xet_cache_root`].
///
/// The directory name is `<readable>-<digest>`, where `<readable>` is the first 16
/// characters of `endpoint` with every non-alphanumeric character replaced by `_`,
/// and `<digest>` is the first 16 characters of the base64 data hash of `endpoint`.
fn compute_cache_path(endpoint: &str) -> PathBuf {
    let root = xet_cache_root();

    let sanitize = |ch: char| if ch.is_alphanumeric() { ch } else { '_' };
    let readable: String = endpoint.chars().take(16).map(sanitize).collect();

    // NOTE(review): if `base64()` uses the standard alphabet, a '/' in the first 16
    // characters would create a nested directory. Confirm that it is URL-safe.
    let digest = xet_core_structures::merklehash::compute_data_hash(endpoint.as_bytes()).base64();
    let short_digest = &digest[..16];

    root.join(format!("{readable}-{short_digest}"))
}
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use super::{SessionContext, TranslatorConfig};

    /// Each constructor's endpoint is classified as exactly one of local / memory / remote.
    #[test]
    fn test_session_context_mode_detection() {
        let temp_dir = tempdir().unwrap();
        let local_session = SessionContext::for_local_path(temp_dir.path());
        assert!(local_session.is_local());
        assert!(!local_session.is_memory());
        assert_eq!(local_session.local_path().unwrap(), temp_dir.path().to_path_buf());

        let memory_session = SessionContext::for_memory();
        assert!(memory_session.is_memory());
        assert!(!memory_session.is_local());
        assert!(memory_session.local_path().is_none());

        let remote_session = SessionContext {
            endpoint: "http://localhost:8080".into(),
            auth: None,
            custom_headers: None,
            repo_paths: Vec::new(),
            session_id: None,
        };
        assert!(!remote_session.is_local());
        assert!(!remote_session.is_memory());
        assert!(remote_session.local_path().is_none());
    }

    /// `memory_config` and `test_server_config` place shard state under `<base_dir>/xet`.
    #[test]
    fn test_memory_and_server_configs_use_base_xet_layout() {
        let temp_dir = tempdir().unwrap();
        let memory_config = TranslatorConfig::memory_config(temp_dir.path()).unwrap();
        assert!(memory_config.shard_cache_directory.starts_with(temp_dir.path().join("xet")));
        assert!(memory_config.shard_session_directory.starts_with(temp_dir.path().join("xet")));

        let server_config = TranslatorConfig::test_server_config("http://localhost:8080", temp_dir.path()).unwrap();
        assert!(server_config.shard_cache_directory.starts_with(temp_dir.path().join("xet")));
        assert!(server_config.shard_session_directory.starts_with(temp_dir.path().join("xet")));
    }

    /// `local_config` goes through `TranslatorConfig::new`'s local-CAS branch and creates `<base_dir>/xet`.
    #[test]
    fn test_local_config_uses_local_xet_layout() {
        let temp_dir = tempdir().unwrap();
        let base = temp_dir.path().join("xet");
        let local_config = TranslatorConfig::local_config(temp_dir.path()).unwrap();
        assert!(local_config.session.is_local());
        assert!(base.is_dir());
        assert!(local_config.shard_cache_directory.starts_with(&base));
        assert!(local_config.shard_session_directory.starts_with(&base));
        assert!(!local_config.force_disable_progress_aggregation);
    }

    /// The builder setter flips the flag that every constructor initializes to `false`.
    #[test]
    fn test_disable_progress_aggregation_sets_flag() {
        let temp_dir = tempdir().unwrap();
        let config = TranslatorConfig::memory_config(temp_dir.path()).unwrap();
        assert!(!config.force_disable_progress_aggregation);
        let config = config.disable_progress_aggregation();
        assert!(config.force_disable_progress_aggregation);
    }
}