xet_data/processing/
configurations.rs1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use http::HeaderMap;
5use tracing::info;
6use xet_client::cas_client::auth::AuthConfig;
7use xet_runtime::core::{xet_cache_root, xet_config};
8
9use crate::error::Result;
10
11#[derive(Debug)]
14pub struct SessionContext {
15 pub endpoint: String,
18 pub auth: Option<AuthConfig>,
19 pub custom_headers: Option<Arc<HeaderMap>>,
20 pub repo_paths: Vec<String>,
21 pub session_id: Option<String>,
22}
23
24impl SessionContext {
25 pub fn is_local(&self) -> bool {
27 self.endpoint.starts_with(&xet_config().data.local_cas_scheme)
28 }
29
30 pub fn local_path(&self) -> Option<PathBuf> {
32 let path = self.endpoint.strip_prefix(&xet_config().data.local_cas_scheme)?;
33 Some(PathBuf::from(path))
34 }
35
36 pub fn is_memory(&self) -> bool {
38 self.endpoint == "memory://"
39 }
40
41 pub fn for_local_path(base_dir: impl AsRef<Path>) -> Self {
43 let path = base_dir.as_ref().to_path_buf();
44 let endpoint = format!("{}{}", xet_config().data.local_cas_scheme, path.display());
45 Self {
46 endpoint,
47 auth: None,
48 custom_headers: None,
49 repo_paths: vec!["".into()],
50 session_id: None,
51 }
52 }
53
54 pub fn for_memory() -> Self {
56 Self {
57 endpoint: "memory://".into(),
58 auth: None,
59 custom_headers: None,
60 repo_paths: vec!["".into()],
61 session_id: None,
62 }
63 }
64}
65
66#[derive(Debug)]
69pub struct TranslatorConfig {
70 pub session: SessionContext,
71
72 pub shard_cache_directory: PathBuf,
74
75 pub shard_session_directory: PathBuf,
77
78 pub force_disable_progress_aggregation: bool,
81}
82
83impl TranslatorConfig {
84 fn create_base_xet_dir(base_dir: impl AsRef<Path>) -> Result<PathBuf> {
85 let base_path = base_dir.as_ref().join("xet");
86 std::fs::create_dir_all(&base_path)?;
87 Ok(base_path)
88 }
89
90 pub fn new(session: SessionContext) -> Result<Self> {
92 let config = xet_config();
93
94 let (shard_cache_directory, shard_session_directory) = if let Some(local_path) = session.local_path() {
95 let base_path = local_path.join("xet");
96 std::fs::create_dir_all(&base_path)?;
97
98 (base_path.join(&config.shard.cache_subdir), base_path.join(&config.session.dir_name))
99 } else if session.is_memory() {
100 let cache_path = xet_cache_root().join("memory");
101 std::fs::create_dir_all(&cache_path)?;
102
103 (cache_path.join(&config.shard.cache_subdir), cache_path.join(&config.session.dir_name))
104 } else {
105 let cache_path = compute_cache_path(&session.endpoint);
106 std::fs::create_dir_all(&cache_path)?;
107
108 let staging_directory = cache_path.join(&config.data.staging_subdir);
109 std::fs::create_dir_all(&staging_directory)?;
110
111 (cache_path.join(&config.shard.cache_subdir), staging_directory.join(&config.session.dir_name))
112 };
113
114 info!(
115 endpoint = %session.endpoint,
116 session_id = ?session.session_id,
117 shard_cache = %shard_cache_directory.display(),
118 shard_session = %shard_session_directory.display(),
119 "TranslatorConfig initialized"
120 );
121
122 Ok(Self {
123 session,
124 shard_cache_directory,
125 shard_session_directory,
126 force_disable_progress_aggregation: false,
127 })
128 }
129
130 pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
132 Self::new(SessionContext::for_local_path(base_dir))
133 }
134
135 pub fn memory_config(base_dir: impl AsRef<Path>) -> Result<Self> {
138 let session = SessionContext::for_memory();
139 let config = xet_config();
140 let base_path = Self::create_base_xet_dir(base_dir)?;
141
142 Ok(Self {
143 session,
144 shard_cache_directory: base_path.join(&config.shard.cache_subdir),
145 shard_session_directory: base_path.join(&config.session.dir_name),
146 force_disable_progress_aggregation: false,
147 })
148 }
149
150 pub fn test_server_config(endpoint: impl AsRef<str>, base_dir: impl AsRef<Path>) -> Result<Self> {
154 let session = SessionContext {
155 endpoint: endpoint.as_ref().to_string(),
156 auth: None,
157 custom_headers: None,
158 repo_paths: vec!["".into()],
159 session_id: None,
160 };
161 let config = xet_config();
162 let base_path = Self::create_base_xet_dir(base_dir)?;
163
164 Ok(Self {
165 session,
166 shard_cache_directory: base_path.join(&config.shard.cache_subdir),
167 shard_session_directory: base_path.join(&config.session.dir_name),
168 force_disable_progress_aggregation: false,
169 })
170 }
171
172 pub fn disable_progress_aggregation(mut self) -> Self {
173 self.force_disable_progress_aggregation = true;
174 self
175 }
176}
177
178fn compute_cache_path(endpoint: &str) -> PathBuf {
180 let cache_root = xet_cache_root();
181
182 let endpoint_prefix = endpoint
183 .chars()
184 .take(16)
185 .map(|c| if c.is_alphanumeric() { c } else { '_' })
186 .collect::<String>();
187
188 let endpoint_hash = xet_core_structures::merklehash::compute_data_hash(endpoint.as_bytes()).base64();
189 let endpoint_tag = format!("{endpoint_prefix}-{}", &endpoint_hash[..16]);
190
191 cache_root.join(endpoint_tag)
192}
193
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use super::{SessionContext, TranslatorConfig};

    /// Each kind of endpoint (local, memory, remote) must be classified by
    /// exactly the matching predicate, and only the local one yields a path.
    #[test]
    fn test_session_context_mode_detection() {
        let scratch = tempdir().unwrap();

        // Local: scheme-prefixed endpoint that round-trips back to the directory.
        let local = SessionContext::for_local_path(scratch.path());
        assert!(local.is_local());
        assert!(!local.is_memory());
        assert_eq!(local.local_path().as_deref(), Some(scratch.path()));

        // Memory: recognized as memory only, with no local path.
        let memory = SessionContext::for_memory();
        assert!(memory.is_memory());
        assert!(!memory.is_local());
        assert_eq!(memory.local_path(), None);

        // Remote: neither local nor memory.
        let remote = SessionContext {
            endpoint: String::from("http://localhost:8080"),
            auth: None,
            custom_headers: None,
            repo_paths: vec![],
            session_id: None,
        };
        assert!(!remote.is_local());
        assert!(!remote.is_memory());
        assert_eq!(remote.local_path(), None);
    }

    /// `memory_config` and `test_server_config` must both place their shard
    /// directories under `<base_dir>/xet`.
    #[test]
    fn test_memory_and_server_configs_use_base_xet_layout() {
        let scratch = tempdir().unwrap();
        let xet_root = scratch.path().join("xet");

        let memory_config = TranslatorConfig::memory_config(scratch.path()).unwrap();
        let server_config = TranslatorConfig::test_server_config("http://localhost:8080", scratch.path()).unwrap();

        for config in [&memory_config, &server_config] {
            assert!(config.shard_cache_directory.starts_with(&xet_root));
            assert!(config.shard_session_directory.starts_with(&xet_root));
        }
    }
}