use std::collections::HashMap;
use std::sync::{Arc, Mutex, Weak};
use tracing::info;
use uuid::Uuid;
use xet_data::progress_tracking::UniqueID;
use xet_runtime::RuntimeError;
use xet_runtime::config::XetConfig;
use xet_runtime::core::XetRuntime;
#[cfg(feature = "fd-track")]
use xet_runtime::fd_diagnostics::{report_fd_count, track_fd_scope};
use super::download_stream_group::{
XetDownloadStreamGroup, XetDownloadStreamGroupBuilder, XetDownloadStreamGroupInner,
};
use super::errors::SessionError;
use super::file_download_group::XetFileDownloadGroupBuilder;
use super::task_runtime::{TaskRuntime, XetTaskState};
use super::upload_commit::XetUploadCommitBuilder;
/// Shared state behind a [`XetSession`].
///
/// `XetSession` stores this behind an `Arc`, so every clone of a session
/// observes the same runtime, configuration, registry and id.
#[doc(hidden)]
pub struct XetSessionInner {
// Runtime the session operates on, as selected by `XetSessionBuilder::build`.
pub(super) runtime: Arc<XetRuntime>,
// Configuration the session was built with.
pub(super) config: XetConfig,
// Root task runtime created from `runtime` in `XetSession::new`; used for
// state checks (`check_state`, `status`) and cancellation (`cancel_subtree`).
pub(super) task_runtime: Arc<TaskRuntime>,
// Registry of download stream groups, keyed by group id. Only weak references
// are held, so the registry itself does not keep a group alive; `abort` and
// `sigint_abort` drain it and abort whatever can still be upgraded.
pub(super) active_download_stream_groups: Mutex<HashMap<UniqueID, Weak<XetDownloadStreamGroupInner>>>,
// Session identifier, generated with `Uuid::now_v7()` at construction.
pub(super) id: Uuid,
}
/// Builder for [`XetSession`].
///
/// Carries the configuration and, optionally, a caller-supplied tokio handle
/// that `build` will try to wrap instead of creating its own runtime.
pub struct XetSessionBuilder {
// Configuration handed to the runtime and stored in the resulting session.
config: XetConfig,
// Handle accepted by `with_tokio_handle`; `None` when no handle was supplied
// or the supplied one was rejected.
tokio_handle: Option<tokio::runtime::Handle>,
}
impl Default for XetSessionBuilder {
fn default() -> Self {
Self::new()
}
}
impl XetSessionBuilder {
pub fn new() -> Self {
Self {
config: XetConfig::new(),
tokio_handle: None,
}
}
pub fn new_with_config(config: XetConfig) -> Self {
Self {
config,
tokio_handle: None,
}
}
pub fn with_tokio_handle(self, handle: tokio::runtime::Handle) -> Self {
let accept = XetRuntime::handle_meets_requirements(&handle);
if !accept {
info!("supplied tokio handle rejected (missing drivers or wrong flavor); falling back to Owned mode");
}
Self {
tokio_handle: accept.then_some(handle),
..self
}
}
pub fn build(self) -> Result<XetSession, SessionError> {
#[cfg(feature = "fd-track")]
let _fd_scope = track_fd_scope("XetSessionBuilder::build");
let handle = self.tokio_handle.or_else(|| {
tokio::runtime::Handle::try_current()
.ok()
.filter(XetRuntime::handle_meets_requirements)
});
let runtime = match handle {
Some(h) => {
info!("XetSession using External runtime (wrapping caller's tokio handle)");
let result = XetRuntime::from_external_with_config(h, self.config.clone());
match result {
Ok(runtime) => runtime,
Err(RuntimeError::ExternalAlreadyAttached(_)) => {
info!(
"An existing XetSession already wraps caller's tokio handle, switching to creating Owned runtime"
);
XetRuntime::new_with_config(self.config.clone())?
},
Err(e) => Err(e)?,
}
},
None => {
info!("XetSession creating Owned runtime (new thread pool)");
XetRuntime::new_with_config(self.config.clone())?
},
};
let session = XetSession::new(self.config, runtime);
info!("Session created, session_id={}", session.inner.id);
#[cfg(feature = "fd-track")]
report_fd_count("XetSessionBuilder::build complete");
Ok(session)
}
}
/// Handle to a session.
///
/// Cloning copies only the `Arc`, so every clone refers to the same
/// `XetSessionInner` (same id, runtime, task runtime and stream-group registry).
#[derive(Clone)]
pub struct XetSession {
// Shared session state; visible to the parent module.
pub(super) inner: Arc<XetSessionInner>,
}
impl XetSession {
fn new(config: XetConfig, runtime: Arc<XetRuntime>) -> Self {
let task_runtime = TaskRuntime::new_root(runtime.clone());
Self {
inner: Arc::new(XetSessionInner {
runtime,
config,
task_runtime,
active_download_stream_groups: Mutex::new(HashMap::new()),
id: Uuid::now_v7(),
}),
}
}
pub fn new_upload_commit(&self) -> Result<XetUploadCommitBuilder, SessionError> {
self.inner.task_runtime.check_state("new_upload_commit")?;
#[cfg(feature = "fd-track")]
report_fd_count("XetSession::new_upload_commit");
Ok(XetUploadCommitBuilder::new(self.clone()))
}
pub fn new_file_download_group(&self) -> Result<XetFileDownloadGroupBuilder, SessionError> {
self.inner.task_runtime.check_state("new_file_download_group")?;
#[cfg(feature = "fd-track")]
report_fd_count("XetSession::new_file_download_group");
Ok(XetFileDownloadGroupBuilder::new(self.clone()))
}
pub fn new_download_stream_group(&self) -> Result<XetDownloadStreamGroupBuilder, SessionError> {
self.inner.task_runtime.check_state("new_download_stream_group")?;
#[cfg(feature = "fd-track")]
report_fd_count("XetSession::new_download_stream_group");
Ok(XetDownloadStreamGroupBuilder::new(self.clone()))
}
pub fn status(&self) -> Result<XetTaskState, SessionError> {
self.inner.task_runtime.status()
}
pub fn abort(&self) -> Result<(), SessionError> {
#[cfg(feature = "fd-track")]
let _fd_scope = track_fd_scope(format!("XetSession::abort({})", self.inner.id));
info!("Session abort, session_id={}", self.inner.id);
self.inner.task_runtime.cancel_subtree()?;
let active_download_stream_groups = std::mem::take(&mut *self.inner.active_download_stream_groups.lock()?);
for (_id, weak_group) in active_download_stream_groups {
if let Some(inner) = weak_group.upgrade() {
inner.abort();
}
}
#[cfg(feature = "fd-track")]
report_fd_count("XetSession::abort complete");
Ok(())
}
pub fn sigint_abort(&self) -> Result<(), SessionError> {
#[cfg(feature = "fd-track")]
let _fd_scope = track_fd_scope(format!("XetSession::sigint_abort({})", self.inner.id));
info!("Session SIGINT abort, session_id={}", self.inner.id);
self.inner.runtime.perform_sigint_shutdown();
let active_download_stream_groups = std::mem::take(&mut *self.inner.active_download_stream_groups.lock()?);
for (_id, weak_group) in active_download_stream_groups {
if let Some(inner) = weak_group.upgrade() {
inner.abort();
}
}
#[cfg(feature = "fd-track")]
report_fd_count("XetSession::sigint_abort complete");
Ok(())
}
#[cfg(test)]
pub(super) fn check_alive(&self) -> Result<(), SessionError> {
if self.inner.runtime.in_sigint_shutdown() {
return Err(SessionError::KeyboardInterrupt);
}
self.inner.task_runtime.check_state("session")
}
pub(super) fn register_download_stream_group(&self, group: &XetDownloadStreamGroup) -> Result<(), SessionError> {
self.inner
.active_download_stream_groups
.lock()?
.insert(group.id(), Arc::downgrade(&group.inner));
Ok(())
}
pub(super) fn id(&self) -> &Uuid {
&self.inner.id
}
}
// Unit tests for session construction, abort behaviour, tokio-handle
// requirements, runtime-mode selection and download streaming.
#[cfg(test)]
mod tests {
use tempfile::tempdir;
use xet_data::processing::{Sha256Policy, XetFileInfo};
use xet_runtime::core::{RuntimeMode, XetRuntime};
use super::*;
// A cloned session reports the same id as the session it was cloned from.
#[test]
fn test_session_clone_shares_state() {
let s1 = XetSessionBuilder::new().build().unwrap();
let s2 = s1.clone();
assert_eq!(s1.inner.id, s2.inner.id);
}
// Independently built sessions get different ids.
#[test]
fn test_two_sessions_have_distinct_ids() {
let s1 = XetSessionBuilder::new().build().unwrap();
let s2 = XetSessionBuilder::new().build().unwrap();
assert_ne!(s1.inner.id, s2.inner.id);
}
// The id round-trips through its string form and reports version `SortRand`
// (the uuid crate's name for UUID v7).
#[test]
fn test_session_id_is_uuid_v7() {
let s = XetSessionBuilder::new().build().unwrap();
let parsed: uuid::Uuid = s.inner.id.to_string().parse().expect("session id must parse as Uuid");
assert_eq!(parsed.get_version(), Some(uuid::Version::SortRand));
}
// After `abort`, the liveness probe fails with `UserCancelled`.
#[test]
fn test_check_alive_after_abort() {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.check_alive().unwrap_err();
assert!(matches!(err, SessionError::UserCancelled(_)));
}
// After `sigint_abort`, the liveness probe fails with `KeyboardInterrupt`.
#[test]
fn test_check_alive_after_sigint_abort() {
let session = XetSessionBuilder::new().build().unwrap();
session.sigint_abort().unwrap();
let err = session.check_alive().unwrap_err();
assert!(matches!(err, SessionError::KeyboardInterrupt));
}
// Creating an upload commit on an aborted session fails with `UserCancelled`.
#[test]
fn test_new_upload_commit_after_abort_returns_aborted() {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.new_upload_commit().err().unwrap();
assert!(matches!(err, SessionError::UserCancelled(_)));
}
// Creating a file download group on an aborted session fails with `UserCancelled`.
#[test]
fn test_new_file_download_group_after_abort_returns_aborted() {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.new_file_download_group().err().unwrap();
assert!(matches!(err, SessionError::UserCancelled(_)));
}
// Same post-abort checks as above, but with the session built and used from
// inside a multi-thread tokio test.
#[tokio::test(flavor = "multi_thread")]
async fn test_async_new_after_abort_returns_aborted() {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let commit_err = session.new_upload_commit().err().unwrap();
let group_err = session.new_file_download_group().err().unwrap();
assert!(matches!(commit_err, SessionError::UserCancelled(_)));
assert!(matches!(group_err, SessionError::UserCancelled(_)));
}
// A multi-thread runtime built with `enable_all` is accepted.
#[test]
fn test_handle_multi_thread_all_features_returns_true() {
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
assert!(XetRuntime::handle_meets_requirements(rt.handle()));
}
// A current-thread runtime is rejected even with all drivers enabled
// (test is compiled out on wasm targets).
#[test]
#[cfg(not(target_family = "wasm"))]
fn test_handle_current_thread_returns_false() {
let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
}
// A multi-thread runtime with no drivers enabled is rejected.
#[test]
fn test_handle_without_any_driver_returns_false() {
let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
}
// Only the time driver enabled (no IO driver): rejected.
#[test]
fn test_handle_without_io_driver_returns_false() {
let rt = tokio::runtime::Builder::new_multi_thread().enable_time().build().unwrap();
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
}
// Only the IO driver enabled (no time driver): rejected.
#[test]
fn test_handle_without_time_driver_returns_false() {
let rt = tokio::runtime::Builder::new_multi_thread().enable_io().build().unwrap();
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
}
// A session built inside a multi-thread tokio test is in External mode, and
// `build_blocking` on an upload commit then fails with `WrongRuntimeMode`.
#[tokio::test(flavor = "multi_thread")]
async fn test_new_upload_commit_blocking_errors_in_external_mode() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.inner.runtime.mode(), RuntimeMode::External);
let err = session.new_upload_commit().unwrap().build_blocking().err().unwrap();
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
}
// Same as above for a file download group.
#[tokio::test(flavor = "multi_thread")]
async fn test_new_file_download_group_blocking_errors_in_external_mode() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.inner.runtime.mode(), RuntimeMode::External);
let err = session.new_file_download_group().unwrap().build_blocking().err().unwrap();
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
}
// A session built from a plain `#[test]` is in Owned mode. Calling
// `build_blocking` from inside another runtime's `block_on` is expected to
// panic; `catch_unwind` turns that panic into an `Err` the test can assert on.
#[test]
fn test_new_upload_commit_blocking_panics_in_async_context() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
rt.block_on(async { session.new_upload_commit().unwrap().build_blocking() })
}));
assert!(result.is_err(), "build_blocking() must panic when called from async");
}
// Same as above for a file download group.
#[test]
fn test_new_file_download_group_blocking_panics_in_async_context() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
rt.block_on(async { session.new_file_download_group().unwrap().build_blocking() })
}));
assert!(result.is_err(), "build_blocking() must panic when called from async");
}
// Creating a download stream group on an aborted session fails with `UserCancelled`.
#[test]
fn test_new_download_stream_group_after_abort_returns_aborted() {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.new_download_stream_group().err().unwrap();
assert!(matches!(err, SessionError::UserCancelled(_)));
}
// After `abort`, the session's stream-group registry is empty. The group is
// bound to `_g1` so it stays alive until the end of the test.
#[test]
fn test_abort_clears_active_download_stream_groups() {
let session = XetSessionBuilder::new().build().unwrap();
let _g1 = session.new_download_stream_group().unwrap().build_blocking().unwrap();
session.abort().unwrap();
assert_eq!(session.inner.active_download_stream_groups.lock().unwrap().len(), 0);
}
// Async test helper: uploads `data` under `name` through a new upload commit
// against `endpoint`, commits, and returns the `XetFileInfo` of the first
// entry in the commit results (panics if there is none).
async fn upload_bytes(
session: &XetSession,
endpoint: &str,
data: &[u8],
name: &str,
) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit()?.with_endpoint(endpoint).build().await?;
let _handle = commit
.upload_bytes(data.to_vec(), Sha256Policy::Compute, Some(name.into()))
.await?;
let results = commit.commit().await?;
let meta = results.uploads.into_values().next().expect("one uploaded file");
Ok(meta.xet_info)
}
// Blocking counterpart of `upload_bytes`, using the `*_blocking` APIs.
fn upload_bytes_blocking(
session: &XetSession,
endpoint: &str,
data: &[u8],
name: &str,
) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit()?.with_endpoint(endpoint).build_blocking()?;
let _handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into()))?;
let results = commit.commit_blocking()?;
let meta = results.uploads.into_values().next().expect("one uploaded file");
Ok(meta.xet_info)
}
// Uploads a small payload to a `local://` endpoint under a temp directory,
// streams it back chunk by chunk and checks the bytes match.
#[tokio::test(flavor = "multi_thread")]
async fn test_download_stream_round_trip() {
let temp = tempdir().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let endpoint = format!("local://{}", temp.path().join("cas").display());
let original = b"Hello, streaming download!";
let file_info = upload_bytes(&session, &endpoint, original, "stream.bin").await.unwrap();
let mut stream = session
.new_download_stream_group()
.unwrap()
.with_endpoint(&endpoint)
.build()
.await
.unwrap()
.download_stream(file_info, None)
.await
.unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await.unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
}
// Blocking variant of the round-trip test.
#[test]
fn test_download_stream_blocking_round_trip() {
let temp = tempdir().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let endpoint = format!("local://{}", temp.path().join("cas").display());
let original = b"Hello, blocking streaming download!";
let file_info = upload_bytes_blocking(&session, &endpoint, original, "stream.bin").unwrap();
let mut stream = session
.new_download_stream_group()
.unwrap()
.with_endpoint(&endpoint)
.build_blocking()
.unwrap()
.download_stream_blocking(file_info, None)
.unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.blocking_next().unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
}
// Progress starts at zero bytes completed with the full size as total, and
// reports all bytes completed once the stream has been fully consumed.
#[tokio::test(flavor = "multi_thread")]
async fn test_download_stream_progress_reports_completion() {
let temp = tempdir().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let endpoint = format!("local://{}", temp.path().join("cas").display());
let original = b"progress tracking test data for streaming";
let file_info = upload_bytes(&session, &endpoint, original, "progress.bin").await.unwrap();
let mut stream = session
.new_download_stream_group()
.unwrap()
.with_endpoint(&endpoint)
.build()
.await
.unwrap()
.download_stream(file_info, None)
.await
.unwrap();
let initial = stream.progress();
assert_eq!(initial.total_bytes, original.len() as u64);
assert_eq!(initial.bytes_completed, 0);
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await.unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
let final_progress = stream.progress();
assert_eq!(final_progress.total_bytes, original.len() as u64);
assert_eq!(final_progress.bytes_completed, original.len() as u64);
}
// Blocking variant: checks only the final progress after full consumption.
#[test]
fn test_download_stream_blocking_progress_reports_completion() {
let temp = tempdir().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let endpoint = format!("local://{}", temp.path().join("cas").display());
let original = b"blocking progress tracking test data";
let file_info = upload_bytes_blocking(&session, &endpoint, original, "progress.bin").unwrap();
let mut stream = session
.new_download_stream_group()
.unwrap()
.with_endpoint(&endpoint)
.build_blocking()
.unwrap()
.download_stream_blocking(file_info, None)
.unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.blocking_next().unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
let final_progress = stream.progress();
assert_eq!(final_progress.total_bytes, original.len() as u64);
assert_eq!(final_progress.bytes_completed, original.len() as u64);
}
// Two files are streamed one after the other from the same stream group and
// each yields its own payload.
#[tokio::test(flavor = "multi_thread")]
async fn test_download_stream_multiple_sequential() {
let temp = tempdir().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let endpoint = format!("local://{}", temp.path().join("cas").display());
let data_a = b"first stream payload";
let data_b = b"second stream payload";
let info_a = upload_bytes(&session, &endpoint, data_a, "a.bin").await.unwrap();
let info_b = upload_bytes(&session, &endpoint, data_b, "b.bin").await.unwrap();
let group = session
.new_download_stream_group()
.unwrap()
.with_endpoint(&endpoint)
.build()
.await
.unwrap();
let mut stream_a = group.download_stream(info_a, None).await.unwrap();
let mut collected_a = Vec::new();
while let Some(chunk) = stream_a.next().await.unwrap() {
collected_a.extend_from_slice(&chunk);
}
assert_eq!(collected_a, data_a);
let mut stream_b = group.download_stream(info_b, None).await.unwrap();
let mut collected_b = Vec::new();
while let Some(chunk) = stream_b.next().await.unwrap() {
collected_b.extend_from_slice(&chunk);
}
assert_eq!(collected_b, data_b);
}
// While a first session still wraps a tokio handle in External mode, a second
// build with the same handle succeeds but ends up in Owned mode.
#[test]
fn test_build_with_same_handle_falls_back_to_owned() {
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let handle = tokio_rt.handle().clone();
let first = XetSessionBuilder::new().with_tokio_handle(handle.clone()).build().unwrap();
assert_eq!(first.inner.runtime.mode(), RuntimeMode::External, "first build must use External runtime");
let second = XetSessionBuilder::new().with_tokio_handle(handle).build();
assert!(second.is_ok(), "second build with the same tokio handle must still succeed");
assert_eq!(
second.unwrap().inner.runtime.mode(),
RuntimeMode::Owned,
"second build must fall back to Owned runtime when External handle is already in use"
);
}
// Once the first session is dropped, building again with the same handle
// succeeds (the resulting runtime mode is not asserted here).
#[test]
fn test_build_with_same_handle_succeeds_after_first_is_dropped() {
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let handle = tokio_rt.handle().clone();
let first = XetSessionBuilder::new().with_tokio_handle(handle.clone()).build().unwrap();
drop(first);
let second = XetSessionBuilder::new().with_tokio_handle(handle).build();
assert!(second.is_ok(), "build must succeed after the previous session holding the same handle is dropped");
}
}