use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, info};
use xet_client::cas_client::{Client, URLProvider};
use xet_client::cas_types::{FileRange, HttpRange};
use xet_core_structures::merklehash::MerkleHash;
use xet_runtime::utils::UniqueId;
use super::super::FileReconstructionError;
use super::super::error::Result;
use super::file_term::retrieve_file_term_block;
/// Retrieval URLs for the xorb blocks covering one byte range of a file,
/// together with the state needed to re-fetch them once they expire.
pub struct TermBlockRetrievalURLs {
/// Hash of the file; passed to `retrieve_file_term_block` when the URLs are refreshed.
pub file_hash: MerkleHash,
/// Byte range originally requested. A refresh re-requests exactly this range and
/// rejects a response whose returned range differs from it.
pub byte_range: FileRange,
/// `(acquisition id, per-block retrieval info)`, where each per-block entry is
/// `(url, http ranges)` and is addressed by xorb block index.
///
/// The acquisition id identifies the current snapshot of URLs: a refresh replaces
/// it with the id of the newly fetched snapshot, which lets callers holding an
/// older id detect that the URLs were already refreshed.
// The nested tuple type is spelled out inline rather than behind an alias, which is
// what trips `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
pub(crate) xorb_block_retrieval_urls: RwLock<(UniqueId, Vec<(String, Vec<HttpRange>)>)>,
}
impl TermBlockRetrievalURLs {
    /// Creates a URL set for `byte_range` of the file identified by `file_hash`.
    ///
    /// `retrieval_urls` holds one `(url, http ranges)` entry per xorb block, and
    /// `acquisition_id` tags that snapshot so later refreshes can be de-duplicated.
    pub fn new(
        file_hash: MerkleHash,
        byte_range: FileRange,
        acquisition_id: UniqueId,
        retrieval_urls: Vec<(String, Vec<HttpRange>)>,
    ) -> Self {
        let xorb_block_retrieval_urls = RwLock::new((acquisition_id, retrieval_urls));
        Self {
            file_hash,
            byte_range,
            xorb_block_retrieval_urls,
        }
    }

    /// Returns the current acquisition id together with a copy of the URL and
    /// HTTP ranges stored for `xorb_block_index`.
    ///
    /// # Panics
    ///
    /// Panics if `xorb_block_index` is out of bounds for the stored URL list.
    pub async fn get_retrieval_url(&self, xorb_block_index: usize) -> (UniqueId, String, Vec<HttpRange>) {
        let snapshot = self.xorb_block_retrieval_urls.read().await;
        let (acquisition_id, entries) = &*snapshot;
        let (url, url_ranges) = &entries[xorb_block_index];
        (*acquisition_id, url.clone(), url_ranges.clone())
    }

    /// Re-fetches the retrieval URLs, but only if `acquisition_id` still matches
    /// the snapshot currently stored.
    ///
    /// A mismatch means the URLs were already replaced since the caller obtained
    /// its id, so the call returns `Ok(())` without contacting the client. The id
    /// is checked once under the read lock and again after the write lock has been
    /// acquired; the write lock is then held for the whole fetch.
    ///
    /// # Errors
    ///
    /// Propagates any error from `retrieve_file_term_block`, and returns
    /// `FileReconstructionError::CorruptedReconstruction` when the response is
    /// `None`, covers a range other than `self.byte_range`, or contains no terms.
    pub async fn refresh_retrieval_urls(&self, client: Arc<dyn Client>, acquisition_id: UniqueId) -> Result<()> {
        let range_bounds = (self.byte_range.start, self.byte_range.end);

        // Cheap check under the shared lock; the guard is released at the end of this statement.
        let observed_id = self.xorb_block_retrieval_urls.read().await.0;
        if observed_id != acquisition_id {
            debug!(
                file_hash = %self.file_hash,
                byte_range = ?range_bounds,
                "URL refresh skipped - already refreshed by another request"
            );
            return Ok(());
        }

        // Re-check under the exclusive lock: the id may have changed while waiting for it.
        let mut current = self.xorb_block_retrieval_urls.write().await;
        if current.0 != acquisition_id {
            debug!(
                file_hash = %self.file_hash,
                byte_range = ?range_bounds,
                "URL refresh skipped - already refreshed while waiting for lock"
            );
            return Ok(());
        }

        info!(
            file_hash = %self.file_hash,
            byte_range = ?range_bounds,
            url_count = current.1.len(),
            "Refreshing expired retrieval URLs"
        );

        let (returned_range, _transfer_bytes, file_terms) =
            retrieve_file_term_block(client, self.file_hash, self.byte_range)
                .await?
                .ok_or_else(|| {
                    FileReconstructionError::CorruptedReconstruction(
                        "On URL refresh, the returned reconstruction was None.".to_owned(),
                    )
                })?;

        if returned_range != self.byte_range {
            return Err(FileReconstructionError::CorruptedReconstruction(
                "On URL refresh, the returned reconstruction range differs from expected.".to_owned(),
            ));
        }

        let first_term = file_terms.first().ok_or_else(|| {
            FileReconstructionError::CorruptedReconstruction(
                "On URL refresh, the returned reconstruction had no terms.".to_owned(),
            )
        })?;

        {
            // Move the freshly fetched id and URL list out of the first term's URL info
            // and install them as the current snapshot.
            let mut fresh = first_term.url_info.xorb_block_retrieval_urls.write().await;
            let fresh_urls = std::mem::take(&mut fresh.1);
            *current = (fresh.0, fresh_urls);
        }

        info!(
            file_hash = %self.file_hash,
            byte_range = ?range_bounds,
            "Retrieval URLs refreshed successfully"
        );
        Ok(())
    }
}
/// `URLProvider` implementation that serves the URL of a single xorb block out of a
/// shared `TermBlockRetrievalURLs` and can trigger a refresh of that shared set.
pub struct XorbURLProvider {
/// Client handed to `refresh_retrieval_urls` when a refresh is requested.
pub client: Arc<dyn Client>,
/// Shared URL set this provider reads from and refreshes.
pub url_info: Arc<TermBlockRetrievalURLs>,
/// Index of this provider's entry within the URL set.
pub xorb_block_index: usize,
/// Acquisition id observed by the most recent `retrieve_url` call; it is passed to
/// `refresh_retrieval_urls` so a refresh is skipped when the stored id has moved on.
pub last_acquisition_id: Mutex<UniqueId>,
}
#[async_trait::async_trait]
impl URLProvider for XorbURLProvider {
    /// Returns the current URL and HTTP ranges for this provider's xorb block and
    /// records the acquisition id they were read under, so that a later
    /// `refresh_url` call knows which snapshot it is asking to replace.
    ///
    /// As written this never returns `Err`; note that the underlying lookup panics
    /// if `xorb_block_index` is out of bounds for the shared URL list.
    async fn retrieve_url(&self) -> std::result::Result<(String, Vec<HttpRange>), xet_client::ClientError> {
        let (unique_id, url, http_ranges) = self.url_info.get_retrieval_url(self.xorb_block_index).await;
        *self.last_acquisition_id.lock().await = unique_id;
        Ok((url, http_ranges))
    }

    /// Requests a refresh of the shared URL set, identified by the acquisition id
    /// recorded by the last `retrieve_url` call.
    ///
    /// # Errors
    ///
    /// Any refresh failure is converted to `ClientError::Other` carrying the
    /// original error's display text.
    async fn refresh_url(&self) -> std::result::Result<(), xet_client::ClientError> {
        // Copy the id out in its own statement so the mutex guard is dropped here.
        // Locking inline in the call's argument list would keep the guard alive until
        // the end of the whole expression, i.e. across the entire awaited refresh.
        let acquisition_id = *self.last_acquisition_id.lock().await;

        self.url_info
            .refresh_retrieval_urls(self.client.clone(), acquisition_id)
            .await
            .map_err(|e| xet_client::ClientError::Other(e.to_string()))
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::Mutex;
    use xet_client::cas_client::{ClientTestingUtils, LocalClient, URLProvider};
    use xet_client::cas_types::{FileRange, HttpRange};
    use xet_core_structures::merklehash::MerkleHash;
    use xet_runtime::utils::UniqueId;

    use super::{TermBlockRetrievalURLs, XorbURLProvider};

    /// Builds `n` placeholder `(url, ranges)` entries with distinct URLs.
    fn sample_urls(n: usize) -> Vec<(String, Vec<HttpRange>)> {
        let mut urls = Vec::with_capacity(n);
        for i in 0..n {
            urls.push((format!("https://example.com/xorb_{i}"), vec![HttpRange::new(0, 100)]));
        }
        urls
    }

    /// Uploads a random file to a temporary local client, fetches the term block for
    /// the whole file, and returns the client plus the first term's URL info.
    async fn uploaded_url_info() -> (Arc<dyn xet_client::cas_client::Client>, Arc<TermBlockRetrievalURLs>) {
        let client = LocalClient::temporary().await.unwrap();
        let file_contents = client.upload_random_file(&[(1, (0, 3))], 64).await.unwrap();
        let file_range = FileRange::new(0, file_contents.data.len() as u64);
        let dyn_client: Arc<dyn xet_client::cas_client::Client> = client.clone();

        let (_, _, file_terms) =
            super::retrieve_file_term_block(dyn_client.clone(), file_contents.file_hash, file_range)
                .await
                .unwrap()
                .unwrap();

        (dyn_client, file_terms[0].url_info.clone())
    }

    /// Every entry passed to `new` must come back unchanged, tagged with the given id.
    #[tokio::test]
    async fn test_new_and_get_retrieval_url() {
        let id = UniqueId::new();
        let urls = sample_urls(3);
        let block = TermBlockRetrievalURLs::new(MerkleHash::default(), FileRange::new(0, 100), id, urls.clone());

        for (i, (expected_url, expected_ranges)) in urls.iter().enumerate() {
            let (ret_id, url, ranges) = block.get_retrieval_url(i).await;
            assert!(ret_id == id, "acquisition ID mismatch for block {i}");
            assert_eq!(&url, expected_url);
            assert_eq!(&ranges, expected_ranges);
        }
    }

    /// A refresh keyed on a non-matching id is a no-op; one keyed on the current id
    /// installs a new acquisition id.
    #[tokio::test]
    async fn test_refresh_skipped_when_already_refreshed() {
        let (dyn_client, url_info) = uploaded_url_info().await;
        let (original_id, _, _) = url_info.get_retrieval_url(0).await;

        // An id that was never stored must not trigger a refresh.
        let stale_id = UniqueId::new();
        url_info.refresh_retrieval_urls(dyn_client.clone(), stale_id).await.unwrap();
        let (id_after, _, _) = url_info.get_retrieval_url(0).await;
        assert!(id_after == original_id, "refresh with stale ID should not change acquisition ID");

        // The currently stored id must trigger one.
        url_info.refresh_retrieval_urls(dyn_client.clone(), original_id).await.unwrap();
        let (refreshed_id, _, _) = url_info.get_retrieval_url(0).await;
        assert!(refreshed_id != original_id, "refresh with correct ID should change acquisition ID");
    }

    /// The provider must hand out a usable URL both before and after a refresh.
    #[tokio::test]
    async fn test_xorb_url_provider_retrieve_and_refresh() {
        let (dyn_client, url_info) = uploaded_url_info().await;
        let provider = XorbURLProvider {
            client: dyn_client.clone(),
            url_info,
            xorb_block_index: 0,
            last_acquisition_id: Mutex::new(UniqueId::null()),
        };

        let (url, ranges) = provider.retrieve_url().await.unwrap();
        assert!(!url.is_empty());
        assert!(!ranges.is_empty());

        provider.refresh_url().await.unwrap();

        let (url2, ranges2) = provider.retrieve_url().await.unwrap();
        assert!(!url2.is_empty());
        assert!(!ranges2.is_empty());
    }
}