Skip to main content

xet_data/processing/
configurations.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use http::HeaderMap;
5use tracing::info;
6use xet_client::cas_client::auth::AuthConfig;
7use xet_runtime::core::{xet_cache_root, xet_config};
8
9use crate::error::Result;
10
/// Session-specific configuration that varies per upload/download session.
/// These are runtime values that cannot be configured via environment variables.
#[derive(Debug)]
pub struct SessionContext {
    /// The endpoint URL. Use the `local://` prefix (configurable via `HF_XET_DATA_LOCAL_CAS_SCHEME`)
    /// to specify a local filesystem path, or `memory://` for in-memory storage.
    pub endpoint: String,
    /// Optional authentication settings. The local, in-memory and test-server constructors
    /// in this file all leave this as `None`.
    pub auth: Option<AuthConfig>,
    /// Optional shared header map.
    /// NOTE(review): presumably extra HTTP headers attached to outgoing requests — confirm at the use site.
    pub custom_headers: Option<Arc<HeaderMap>>,
    /// Repository paths associated with the session. The convenience constructors in this
    /// file populate it with a single empty string.
    /// NOTE(review): the meaning of the empty entry is not visible here — confirm with the consumer.
    pub repo_paths: Vec<String>,
    /// Optional session identifier; it is included in the log line emitted when a
    /// `TranslatorConfig` is built from this context.
    pub session_id: Option<String>,
}
23
24impl SessionContext {
25    /// Returns true if this endpoint points to a local filesystem path.
26    pub fn is_local(&self) -> bool {
27        self.endpoint.starts_with(&xet_config().data.local_cas_scheme)
28    }
29
30    /// Returns the local filesystem path if this is a local endpoint.
31    pub fn local_path(&self) -> Option<PathBuf> {
32        let path = self.endpoint.strip_prefix(&xet_config().data.local_cas_scheme)?;
33        Some(PathBuf::from(path))
34    }
35
36    /// Returns true if this endpoint uses in-memory storage.
37    pub fn is_memory(&self) -> bool {
38        self.endpoint == "memory://"
39    }
40
41    /// Creates a SessionContext for local filesystem-based operations.
42    pub fn for_local_path(base_dir: impl AsRef<Path>) -> Self {
43        let path = base_dir.as_ref().to_path_buf();
44        let endpoint = format!("{}{}", xet_config().data.local_cas_scheme, path.display());
45        Self {
46            endpoint,
47            auth: None,
48            custom_headers: None,
49            repo_paths: vec!["".into()],
50            session_id: None,
51        }
52    }
53
54    /// Creates a SessionContext for in-memory storage.
55    pub fn for_memory() -> Self {
56        Self {
57            endpoint: "memory://".into(),
58            auth: None,
59            custom_headers: None,
60            repo_paths: vec!["".into()],
61            session_id: None,
62        }
63    }
64}
65
/// Main configuration for file upload/download operations.
/// Combines session-specific values with runtime-computed paths derived from the endpoint.
#[derive(Debug)]
pub struct TranslatorConfig {
    /// The session context this configuration was built from; the constructors store it unchanged.
    pub session: SessionContext,

    /// Directory for caching shard files.
    /// The constructors in this file create the directory this path is joined onto,
    /// but do not create this path itself.
    pub shard_cache_directory: PathBuf,

    /// Directory for session-specific shard files.
    /// As with the cache directory, only the directory it is joined onto is created by the
    /// constructors in this file.
    pub shard_session_directory: PathBuf,

    /// Per-session override: when true, progress aggregation is disabled
    /// regardless of the global `HF_XET_DATA_AGGREGATE_PROGRESS` config value.
    /// Every constructor initializes this to `false`.
    pub force_disable_progress_aggregation: bool,
}
82
83impl TranslatorConfig {
84    fn create_base_xet_dir(base_dir: impl AsRef<Path>) -> Result<PathBuf> {
85        let base_path = base_dir.as_ref().join("xet");
86        std::fs::create_dir_all(&base_path)?;
87        Ok(base_path)
88    }
89
90    /// Creates a new TranslatorConfig from a SessionContext, computing all derived paths.
91    pub fn new(session: SessionContext) -> Result<Self> {
92        let config = xet_config();
93
94        let (shard_cache_directory, shard_session_directory) = if let Some(local_path) = session.local_path() {
95            let base_path = local_path.join("xet");
96            std::fs::create_dir_all(&base_path)?;
97
98            (base_path.join(&config.shard.cache_subdir), base_path.join(&config.session.dir_name))
99        } else if session.is_memory() {
100            let cache_path = xet_cache_root().join("memory");
101            std::fs::create_dir_all(&cache_path)?;
102
103            (cache_path.join(&config.shard.cache_subdir), cache_path.join(&config.session.dir_name))
104        } else {
105            let cache_path = compute_cache_path(&session.endpoint);
106            std::fs::create_dir_all(&cache_path)?;
107
108            let staging_directory = cache_path.join(&config.data.staging_subdir);
109            std::fs::create_dir_all(&staging_directory)?;
110
111            (cache_path.join(&config.shard.cache_subdir), staging_directory.join(&config.session.dir_name))
112        };
113
114        info!(
115            endpoint = %session.endpoint,
116            session_id = ?session.session_id,
117            shard_cache = %shard_cache_directory.display(),
118            shard_session = %shard_session_directory.display(),
119            "TranslatorConfig initialized"
120        );
121
122        Ok(Self {
123            session,
124            shard_cache_directory,
125            shard_session_directory,
126            force_disable_progress_aggregation: false,
127        })
128    }
129
130    /// Creates a TranslatorConfig for local filesystem-based storage.
131    pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
132        Self::new(SessionContext::for_local_path(base_dir))
133    }
134
135    /// Creates a TranslatorConfig that uses in-memory storage for XORBs.
136    /// Shard data still uses file-based storage in the provided base directory.
137    pub fn memory_config(base_dir: impl AsRef<Path>) -> Result<Self> {
138        let session = SessionContext::for_memory();
139        let config = xet_config();
140        let base_path = Self::create_base_xet_dir(base_dir)?;
141
142        Ok(Self {
143            session,
144            shard_cache_directory: base_path.join(&config.shard.cache_subdir),
145            shard_session_directory: base_path.join(&config.session.dir_name),
146            force_disable_progress_aggregation: false,
147        })
148    }
149
150    /// Creates a TranslatorConfig that connects to a CAS server at the given endpoint.
151    /// Shard cache and session directories are created under the provided base directory.
152    /// Useful for tests that use LocalTestServer.
153    pub fn test_server_config(endpoint: impl AsRef<str>, base_dir: impl AsRef<Path>) -> Result<Self> {
154        let session = SessionContext {
155            endpoint: endpoint.as_ref().to_string(),
156            auth: None,
157            custom_headers: None,
158            repo_paths: vec!["".into()],
159            session_id: None,
160        };
161        let config = xet_config();
162        let base_path = Self::create_base_xet_dir(base_dir)?;
163
164        Ok(Self {
165            session,
166            shard_cache_directory: base_path.join(&config.shard.cache_subdir),
167            shard_session_directory: base_path.join(&config.session.dir_name),
168            force_disable_progress_aggregation: false,
169        })
170    }
171
172    pub fn disable_progress_aggregation(mut self) -> Self {
173        self.force_disable_progress_aggregation = true;
174        self
175    }
176}
177
178/// Computes a cache-safe path from an endpoint URL.
179fn compute_cache_path(endpoint: &str) -> PathBuf {
180    let cache_root = xet_cache_root();
181
182    let endpoint_prefix = endpoint
183        .chars()
184        .take(16)
185        .map(|c| if c.is_alphanumeric() { c } else { '_' })
186        .collect::<String>();
187
188    let endpoint_hash = xet_core_structures::merklehash::compute_data_hash(endpoint.as_bytes()).base64();
189    let endpoint_tag = format!("{endpoint_prefix}-{}", &endpoint_hash[..16]);
190
191    cache_root.join(endpoint_tag)
192}
193
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use super::{SessionContext, TranslatorConfig};

    /// A local, an in-memory and a remote endpoint must each be classified correctly by
    /// `is_local`, `is_memory` and `local_path`.
    #[test]
    fn test_session_context_mode_detection() {
        let scratch = tempdir().unwrap();

        // Local endpoint: round-trips the directory it was built from.
        let local = SessionContext::for_local_path(scratch.path());
        assert!(local.is_local());
        assert!(!local.is_memory());
        assert_eq!(local.local_path().unwrap(), scratch.path().to_path_buf());

        // In-memory endpoint: neither local nor path-backed.
        let memory = SessionContext::for_memory();
        assert!(memory.is_memory());
        assert!(!memory.is_local());
        assert!(memory.local_path().is_none());

        // Remote endpoint: matches neither special scheme.
        let remote = SessionContext {
            endpoint: "http://localhost:8080".into(),
            auth: None,
            custom_headers: None,
            repo_paths: Vec::new(),
            session_id: None,
        };
        assert!(!remote.is_local());
        assert!(!remote.is_memory());
        assert!(remote.local_path().is_none());
    }

    /// Both base-directory constructors must place their shard directories under `<base>/xet`.
    #[test]
    fn test_memory_and_server_configs_use_base_xet_layout() {
        let scratch = tempdir().unwrap();
        let assert_under_xet = |cfg: &TranslatorConfig| {
            assert!(cfg.shard_cache_directory.starts_with(scratch.path().join("xet")));
            assert!(cfg.shard_session_directory.starts_with(scratch.path().join("xet")));
        };

        let in_memory = TranslatorConfig::memory_config(scratch.path()).unwrap();
        assert_under_xet(&in_memory);

        let against_server = TranslatorConfig::test_server_config("http://localhost:8080", scratch.path()).unwrap();
        assert_under_xet(&against_server);
    }
}
237}