use std::sync::atomic::AtomicU64;
use std::sync::{Arc, Mutex};
use reqwest::Client;
use tokio::sync::Semaphore;
use crate::config::XetConfig;
use crate::utils::adjustable_semaphore::AdjustableSemaphore;
/// Shared resources built from an [`XetConfig`]: a cached HTTP client, concurrency limits
/// for file ingestion and file downloads, a reconstruction download buffer, and a download
/// counter.
#[derive(Debug)]
pub struct XetCommon {
    /// Single-slot cache of `(tag, client)`. It starts as `None` and is filled, or replaced
    /// when a different tag is requested, by `get_or_create_reqwest_client`.
    global_reqwest_client: Mutex<Option<(String, Client)>>,
    /// Limits concurrent file ingestion; `new` sizes it from
    /// `config.data.max_concurrent_file_ingestion`.
    pub file_ingestion_semaphore: Arc<Semaphore>,
    /// Limits concurrent file downloads; `new` sizes it from
    /// `config.data.max_concurrent_file_downloads`.
    pub file_download_semaphore: Arc<Semaphore>,
    /// Adjustable permit pool that `new` creates from
    /// `config.reconstruction.download_buffer_size`, passing the bounds
    /// `(download_buffer_size, download_buffer_limit)`.
    /// NOTE(review): permits presumably represent bytes of buffered download data — confirm
    /// against `AdjustableSemaphore`.
    pub reconstruction_download_buffer: Arc<AdjustableSemaphore>,
    /// Counter that `new` initializes to 0; nothing in this file updates it.
    /// NOTE(review): presumably tracks in-flight downloads — verify against its users.
    pub active_downloads: Arc<AtomicU64>,
}
impl XetCommon {
    /// Builds the shared resources from `config`.
    ///
    /// - `file_ingestion_semaphore` gets `config.data.max_concurrent_file_ingestion` permits.
    /// - `file_download_semaphore` gets `config.data.max_concurrent_file_downloads` permits.
    /// - `reconstruction_download_buffer` is created with
    ///   `config.reconstruction.download_buffer_size` as its initial value and
    ///   `(download_buffer_size, download_buffer_limit)` as its bounds.
    /// - The HTTP client cache starts empty and `active_downloads` starts at zero.
    ///
    /// # Panics
    ///
    /// `tokio::sync::Semaphore::new` panics if a configured permit count exceeds
    /// `Semaphore::MAX_PERMITS`.
    pub fn new(config: &XetConfig) -> Self {
        Self {
            global_reqwest_client: Mutex::new(None),
            file_ingestion_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_ingestion)),
            file_download_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_downloads)),
            reconstruction_download_buffer: {
                let base = config.reconstruction.download_buffer_size.as_u64();
                let limit = config.reconstruction.download_buffer_limit.as_u64();
                AdjustableSemaphore::new(base, (base, limit))
            },
            active_downloads: Arc::new(AtomicU64::new(0)),
        }
    }

    /// Returns the cached HTTP client if it was created under `tag`. Otherwise builds a new
    /// client with `create_client_fn`, caches it under `tag` and returns it.
    ///
    /// Only one client is cached at a time: a call with a different tag replaces the
    /// previous entry. The cache lock is held while `create_client_fn` runs, so concurrent
    /// callers asking for the same tag build at most one client.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `create_client_fn`. In that case the cache is left
    /// unchanged.
    pub fn get_or_create_reqwest_client<F>(&self, tag: String, create_client_fn: F) -> crate::error::Result<Client>
    where
        F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
    {
        // The cached value is only ever replaced by a single assignment after a client has
        // been built successfully, so it is always consistent. If a previous
        // `create_client_fn` panicked while the lock was held, recover the guard instead of
        // failing every later call with a poisoned-lock error.
        let mut guard = self
            .global_reqwest_client
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        match guard.as_ref() {
            Some((cached_tag, cached_client)) if cached_tag == &tag => Ok(cached_client.clone()),
            _ => {
                let new_client = create_client_fn()?;
                *guard = Some((tag, new_client.clone()));
                Ok(new_client)
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Returns a client factory that bumps `calls` each time it runs. If `user_agent` is
    /// given, it is applied to the builder.
    fn counting_factory<'a>(
        calls: &'a AtomicUsize,
        user_agent: Option<&'static str>,
    ) -> impl FnOnce() -> std::result::Result<reqwest::Client, reqwest::Error> + 'a {
        move || {
            calls.fetch_add(1, Ordering::SeqCst);
            match user_agent {
                Some(agent) => reqwest::Client::builder().user_agent(agent).build(),
                None => reqwest::Client::builder().build(),
            }
        }
    }

    /// Snapshot of the tag currently stored in the client cache, if any.
    fn cached_tag(common: &XetCommon) -> Option<String> {
        let slot = common.global_reqwest_client.lock().unwrap();
        slot.as_ref().map(|(tag, _)| tag.clone())
    }

    #[test]
    fn test_get_or_create_reqwest_client_caches_by_tag() {
        let common = XetCommon::new(&XetConfig::new());
        let calls = AtomicUsize::new(0);

        // Two lookups under the same tag: only the first may invoke the factory.
        for _ in 0..2 {
            common
                .get_or_create_reqwest_client("test-tag".to_string(), counting_factory(&calls, None))
                .unwrap();
        }

        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_get_or_create_reqwest_client_creates_new_for_different_tag() {
        let common = XetCommon::new(&XetConfig::new());
        let calls = AtomicUsize::new(0);

        // Distinct tags must each trigger a fresh build.
        for (tag, agent) in [("tag1", "client1"), ("tag2", "client2")] {
            common
                .get_or_create_reqwest_client(tag.to_string(), counting_factory(&calls, Some(agent)))
                .unwrap();
        }

        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_initializes_with_empty_client_cache() {
        let common = XetCommon::new(&XetConfig::new());
        assert!(cached_tag(&common).is_none());
    }

    #[test]
    fn test_replaces_client_when_tag_changes() {
        let common = XetCommon::new(&XetConfig::new());

        common
            .get_or_create_reqwest_client("tcp".to_string(), || {
                reqwest::Client::builder().user_agent("tcp-client").build()
            })
            .unwrap();
        assert_eq!(cached_tag(&common).as_deref(), Some("tcp"));

        common
            .get_or_create_reqwest_client("/tmp/socket.sock".to_string(), || {
                reqwest::Client::builder().user_agent("uds-client").build()
            })
            .unwrap();
        assert_eq!(cached_tag(&common).as_deref(), Some("/tmp/socket.sock"));
    }

    #[test]
    fn test_semaphores_initialized_from_config() {
        let config = XetConfig::new();
        let common = XetCommon::new(&config);

        let expected_ingestion = config.data.max_concurrent_file_ingestion;
        let expected_downloads = config.data.max_concurrent_file_downloads;
        let buffer_floor = config.reconstruction.download_buffer_size.as_u64();

        assert_eq!(common.file_ingestion_semaphore.available_permits(), expected_ingestion);
        assert_eq!(common.file_download_semaphore.available_permits(), expected_downloads);
        assert!(common.reconstruction_download_buffer.total_permits() >= buffer_floor);
        assert_eq!(common.active_downloads.load(Ordering::Relaxed), 0);
    }
}