use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use more_asserts::debug_assert_le;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
#[cfg(not(target_family = "wasm"))]
use tokio::time::Instant;
use tracing::info;
#[cfg(target_family = "wasm")]
use web_time::Instant;
use xet_core_structures::ExpWeightedMovingAvg;
use xet_runtime::core::xet_config;
use xet_runtime::utils::adjustable_semaphore::{AdjustableSemaphore, AdjustableSemaphorePermit};
use super::super::progress_tracked_streams::ProgressCallback;
use super::rtt_prediction::RTTPredictor;
use crate::error::Result;
/// Minimum spacing, in milliseconds, between two accepted partial-progress reports of a
/// single connection permit.
const MIN_PARTIAL_REPORT_INTERVAL_MS: u64 = 200;

/// Upper bound on the fraction of a transfer's weight that partial-progress reports may
/// contribute; the remainder is applied when the transfer completes.
const PARTIAL_REPORT_WEIGHT_RATIO: f64 = 0.2;

/// Standard-normal z-score of the 95th percentile, applied in log-size space when estimating
/// the reference transmission size.
const REFERENCE_SIZE_QUANTILE_Z: f64 = 1.645;

/// Number of size observations required before a reference transmission size is estimated.
const MIN_SIZE_OBSERVATIONS_FOR_REFERENCE: u64 = 3;
/// Snapshot of the success-ratio model used to decide concurrency adjustments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CCSuccessModelState {
    /// Exponentially weighted success ratio, clamped to `[0.0, 1.0]`.
    pub success_ratio: f64,
    /// `(increase_threshold, decrease_threshold)`: above the first an increase is
    /// recommended, below the second a decrease is recommended.
    pub success_ratio_thresholds: (f64, f64),
    /// `1` to increase, `-1` to decrease, `0` to keep the current concurrency.
    pub recommended_adjustment: i8,
}
/// Snapshot of the RTT/bandwidth model.
///
/// Every field is `0.0` when the underlying predictor has no estimate yet.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CCLatencyModelState {
    /// Predicted RTT for a transfer of the configured maximum reference size at the current
    /// concurrency (presumably in seconds — TODO confirm against `RTTPredictor::predict`).
    pub predicted_max_rtt: f64,
    /// Standard error reported by the predictor for `predicted_max_rtt`.
    pub prediction_max_rtt_standard_error: f64,
    /// Bandwidth estimate from the RTT predictor (units defined by `RTTPredictor`;
    /// NOTE(review): presumably bytes per second — confirm).
    pub predicted_bandwidth: f64,
}
/// Mutable model state shared by all reports of one controller (guarded by a mutex).
struct ConcurrencyControllerState {
    /// Model used to predict RTTs, RTT quantiles and bandwidth from observed transfers.
    rtt_predictor: RTTPredictor,
    /// Exponentially weighted average of outcomes recorded as 1.0 (success) or 0.0 (failure).
    success_ratio_tracking: ExpWeightedMovingAvg,
    /// When concurrency was last changed; used to space out adjustments.
    last_adjustment_time: Instant,
    /// When the periodic status line was last logged.
    last_logging_time: Instant,
    /// Cumulative number of new bytes reported across all permits.
    bytes_sent_so_far: u64,
    /// Number of non-partial reports (completions and retryable failures).
    completed_transmissions_count: u64,
    /// Exponentially weighted mean of ln(transfer size).
    size_log_tracker: ExpWeightedMovingAvg,
    /// Exponentially weighted mean of ln(transfer size)^2, used to derive the variance.
    size_log_sq_tracker: ExpWeightedMovingAvg,
    /// Number of non-zero transfer sizes fed into the two size trackers.
    size_observation_count: u64,
}
impl ConcurrencyControllerState {
    /// Creates an empty state using the decay half-lives from the global configuration.
    ///
    /// The RTT predictor and both log-size trackers decay with `ac_latency_rtt_half_life`;
    /// the success-ratio tracker decays with `ac_success_tracking_half_life`.
    fn new() -> Self {
        let config = xet_config();
        let rtt_half_life_count = config.client.ac_latency_rtt_half_life;
        let success_half_life_count = config.client.ac_success_tracking_half_life;
        Self {
            rtt_predictor: RTTPredictor::new(rtt_half_life_count),
            success_ratio_tracking: ExpWeightedMovingAvg::new_count_decay(success_half_life_count),
            last_adjustment_time: Instant::now(),
            last_logging_time: Instant::now(),
            bytes_sent_so_far: 0,
            completed_transmissions_count: 0,
            size_log_tracker: ExpWeightedMovingAvg::new_count_decay(rtt_half_life_count),
            size_log_sq_tracker: ExpWeightedMovingAvg::new_count_decay(rtt_half_life_count),
            size_observation_count: 0,
        }
    }

    /// Returns `(increase_threshold, decrease_threshold)` for the success ratio, read from
    /// `ac_healthy_success_ratio_threshold` and `ac_unhealthy_success_ratio_threshold`.
    fn success_ratio_thresholds(&self) -> (f64, f64) {
        let config = xet_config();
        let increase_threshold = config.client.ac_healthy_success_ratio_threshold;
        let decrease_threshold = config.client.ac_unhealthy_success_ratio_threshold;
        (increase_threshold, decrease_threshold)
    }

    /// Records one outcome (1.0 for success, 0.0 for failure) with the given weight.
    fn update_success(&mut self, is_success: bool, weight: f64) {
        let value = if is_success { 1.0 } else { 0.0 };
        self.success_ratio_tracking.update_with_weight(value, weight);
    }

    /// Summarises the success-ratio model and its recommendation.
    ///
    /// Recommends `+1` when the (clamped) ratio is strictly above the increase threshold,
    /// `-1` when strictly below the decrease threshold, and `0` otherwise.
    #[inline]
    fn success_model_state(&self) -> CCSuccessModelState {
        let success_ratio_thresholds = self.success_ratio_thresholds();
        let success_ratio = self.success_ratio_tracking.value().clamp(0.0, 1.0);
        let recommended_adjustment = if success_ratio > success_ratio_thresholds.0 {
            1
        } else if success_ratio < success_ratio_thresholds.1 {
            -1
        } else {
            0
        };
        CCSuccessModelState {
            success_ratio,
            success_ratio_thresholds,
            recommended_adjustment,
        }
    }

    /// Summarises the RTT model.
    ///
    /// The RTT is predicted for a transfer of `ac_max_reference_transmission_size` bytes at
    /// `current_concurrency`. Missing estimates are reported as `0.0`.
    #[inline]
    fn latency_model_state(&self, current_concurrency: f64) -> CCLatencyModelState {
        let config = xet_config();
        let (predicted_max_rtt, prediction_max_rtt_standard_error) = self
            .rtt_predictor
            .predict(*config.client.ac_max_reference_transmission_size, current_concurrency);
        let predicted_bandwidth = self.rtt_predictor.predicted_bandwidth();
        CCLatencyModelState {
            predicted_max_rtt: predicted_max_rtt.unwrap_or(0.),
            prediction_max_rtt_standard_error: prediction_max_rtt_standard_error.unwrap_or(0.),
            predicted_bandwidth: predicted_bandwidth.unwrap_or(0.),
        }
    }

    /// Estimates a "large but typical" transfer size from the observed sizes.
    ///
    /// Treats ln(size) as approximately normal, with mean `mu` and standard deviation `sigma`
    /// taken from the exponentially weighted trackers, and returns
    /// `exp(mu + REFERENCE_SIZE_QUANTILE_Z * sigma)` (the ~95th percentile). The result is
    /// bounded to the configured min/max reference sizes.
    ///
    /// Returns `None` until `MIN_SIZE_OBSERVATIONS_FOR_REFERENCE` sizes have been recorded.
    fn estimated_reference_transmission_size(&self) -> Option<u64> {
        if self.size_observation_count < MIN_SIZE_OBSERVATIONS_FOR_REFERENCE {
            return None;
        }
        let mu = self.size_log_tracker.value();
        let mu_sq = self.size_log_sq_tracker.value();
        // E[x^2] - E[x]^2 can dip slightly below zero through floating-point error.
        let variance = (mu_sq - mu * mu).max(0.0);
        let sigma = variance.sqrt();
        let quantile_95 = (mu + REFERENCE_SIZE_QUANTILE_Z * sigma).exp();
        let config = xet_config();
        let min_size = *config.client.ac_min_reference_transmission_size;
        let max_size = *config.client.ac_max_reference_transmission_size;
        // A float-to-int `as` cast saturates (and maps NaN to 0). `Ord::clamp` would panic if
        // the configuration had min_size > max_size, so the bounds are applied separately;
        // the configured maximum wins in that case.
        Some((quantile_95 as u64).max(min_size).min(max_size))
    }

    /// Feeds one transfer size into the log-size trackers; zero-byte transfers are ignored
    /// because ln(0) is undefined.
    fn update_size_tracking(&mut self, n_bytes: u64) {
        if n_bytes == 0 {
            return;
        }
        let log_size = (n_bytes as f64).ln();
        self.size_log_tracker.update(log_size);
        self.size_log_sq_tracker.update(log_size * log_size);
        self.size_observation_count += 1;
    }
}
/// Controls how many transfers may run concurrently. Concurrency is adjusted up or down
/// from observed success ratios and RTT predictions, unless adjustment is disabled.
pub struct AdaptiveConcurrencyController {
    /// Model state updated on every report.
    state: Mutex<ConcurrencyControllerState>,
    /// Semaphore whose total permit count is the current concurrency limit.
    concurrency_semaphore: Arc<AdjustableSemaphore>,
    /// Minimum time since the last adjustment before concurrency may be increased.
    min_concurrency_increase_delay: Duration,
    /// Minimum time since the last adjustment before concurrency may be decreased.
    min_concurrency_decrease_delay: Duration,
    /// Tag included in log messages to identify this controller.
    logging_tag: &'static str,
    /// When true, reports are ignored and the concurrency stays fixed.
    adjustment_disabled: bool,
    /// Bytes that must have been reported before any adjustment is made.
    min_bytes_required_for_adjustment: u64,
    /// Completed transmissions that must have been reported before any adjustment is made.
    min_completed_transmissions_required_for_adjustment: u64,
}
impl AdaptiveConcurrencyController {
pub fn new(
logging_tag: &'static str,
concurrency: usize,
concurrency_bounds: (usize, usize),
min_bytes_required_for_adjustment: u64,
min_completed_transmissions_required_for_adjustment: u64,
) -> Arc<Self> {
let min_concurrency = concurrency_bounds.0.max(1);
let max_concurrency = concurrency_bounds.1.max(min_concurrency);
let current_concurrency = concurrency.clamp(min_concurrency, max_concurrency);
info!(
"Initializing Adaptive Concurrency Controller for {logging_tag} with starting concurrency = {current_concurrency}; min = {min_concurrency}, max = {max_concurrency}, min_bytes_for_adjustment = {min_bytes_required_for_adjustment}, min_completed_transmissions_for_adjustment = {min_completed_transmissions_required_for_adjustment}"
);
let config = xet_config();
Arc::new(Self {
state: Mutex::new(ConcurrencyControllerState::new()),
concurrency_semaphore: AdjustableSemaphore::new(
current_concurrency as u64,
(min_concurrency as u64, max_concurrency as u64),
),
min_concurrency_increase_delay: Duration::from_millis(config.client.ac_min_adjustment_window_ms),
min_concurrency_decrease_delay: Duration::from_millis(config.client.ac_min_adjustment_window_ms),
adjustment_disabled: false,
logging_tag,
min_bytes_required_for_adjustment,
min_completed_transmissions_required_for_adjustment,
})
}
pub fn new_fixed(logging_tag: &'static str, concurrency: usize) -> Arc<Self> {
info!("Fixing maximum concurrency for {logging_tag} at {concurrency}; adaptive concurrency disabled.");
Arc::new(Self {
state: Mutex::new(ConcurrencyControllerState::new()),
concurrency_semaphore: AdjustableSemaphore::new(
concurrency as u64,
(concurrency as u64, concurrency as u64),
),
adjustment_disabled: true,
min_concurrency_increase_delay: Default::default(),
min_concurrency_decrease_delay: Default::default(),
logging_tag,
min_bytes_required_for_adjustment: Default::default(),
min_completed_transmissions_required_for_adjustment: Default::default(),
})
}
pub fn new_upload(logging_tag: &'static str) -> Arc<Self> {
let config = xet_config();
Self::new(
logging_tag,
config.client.ac_initial_upload_concurrency,
(config.client.ac_min_upload_concurrency, config.client.ac_max_upload_concurrency),
config.client.ac_min_bytes_required_for_adjustment.into(),
config.client.ac_num_transmissions_required_for_adjustment,
)
}
pub fn new_download(logging_tag: &'static str) -> Arc<Self> {
let config = xet_config();
Self::new(
logging_tag,
config.client.ac_initial_download_concurrency,
(config.client.ac_min_download_concurrency, config.client.ac_max_download_concurrency),
config.client.ac_min_bytes_required_for_adjustment.into(),
config.client.ac_num_transmissions_required_for_adjustment,
)
}
pub async fn acquire_connection_permit(self: &Arc<Self>) -> Result<ConnectionPermit> {
let _permit = self.concurrency_semaphore.acquire().await?;
let info = Arc::new(ConnectionPermitInfo {
controller: Arc::clone(self),
transfer_start_time: Mutex::new(Instant::now()),
starting_concurrency: self.concurrency_semaphore.active_permits() as usize,
rtt_model_at_start: Some(self.state.lock().await.rtt_predictor.clone()),
report_portion: AtomicU32::new(0),
last_partial_report_ms: AtomicU64::new(0),
max_bytes_reported: AtomicU64::new(0),
});
Ok(ConnectionPermit { _permit, info })
}
pub fn total_permits(&self) -> usize {
self.concurrency_semaphore.total_permits() as usize
}
pub fn available_permits(&self) -> usize {
self.concurrency_semaphore.available_permits() as usize
}
pub fn active_permits(&self) -> usize {
self.concurrency_semaphore.active_permits() as usize
}
pub async fn success_model_state(&self) -> CCSuccessModelState {
self.state.lock().await.success_model_state()
}
pub async fn latency_model_state(&self) -> CCLatencyModelState {
self.state
.lock()
.await
.latency_model_state(self.concurrency_semaphore.active_permits() as f64)
}
async fn report_and_update(
&self,
permit_info: &ConnectionPermitInfo,
n_bytes_if_known: Option<u64>,
transmission_successful: bool,
partial_update: bool,
weight: f64,
) {
if self.adjustment_disabled {
return;
}
let transfer_start_time = *permit_info.transfer_start_time.lock().await;
let elapsed_time = transfer_start_time.elapsed();
let t_actual = elapsed_time.as_secs_f64().max(1e-4);
let config = xet_config();
let completed_in_time = elapsed_time < config.client.ac_max_healthy_rtt;
let mut state_lg = self.state.lock().await;
if let Some(n_bytes) = n_bytes_if_known {
let previous = permit_info.max_bytes_reported.fetch_max(n_bytes, Ordering::AcqRel);
state_lg.bytes_sent_so_far += n_bytes.saturating_sub(previous);
}
if !partial_update {
state_lg.completed_transmissions_count += 1;
}
let cur_concurrency = self.concurrency_semaphore.active_permits() as f64;
let avg_concurrency = (cur_concurrency + permit_info.starting_concurrency as f64) / 2.;
let track_as_success = transmission_successful && completed_in_time && {
if let Some(n_bytes) = n_bytes_if_known {
let quantile = permit_info
.rtt_model_at_start
.as_ref()
.map(|lm| lm.rtt_quantile(t_actual, n_bytes, avg_concurrency))
.unwrap_or(0.5);
quantile < config.client.ac_rtt_success_max_quantile
} else {
false
}
};
if track_as_success {
state_lg.update_success(true, weight);
} else {
state_lg.update_success(false, weight);
}
let model_state = state_lg.success_model_state();
if transmission_successful && let Some(n_bytes) = n_bytes_if_known {
state_lg.rtt_predictor.update(n_bytes, elapsed_time, avg_concurrency, weight);
}
if !partial_update && let Some(n_bytes) = n_bytes_if_known {
state_lg.update_size_tracking(n_bytes);
}
let reference_size = state_lg
.estimated_reference_transmission_size()
.unwrap_or(*config.client.ac_max_reference_transmission_size);
let target_rtt_secs = config.client.ac_target_rtt.as_secs_f64();
if model_state.recommended_adjustment == 1
&& state_lg.bytes_sent_so_far >= self.min_bytes_required_for_adjustment
&& state_lg.completed_transmissions_count >= self.min_completed_transmissions_required_for_adjustment
&& state_lg.last_adjustment_time.elapsed() > self.min_concurrency_increase_delay
{
let old_concurrency = self.concurrency_semaphore.total_permits();
let new_concurrency = 1. + old_concurrency as f64;
let predicted_rtt = state_lg
.rtt_predictor
.predicted_rtt(reference_size, new_concurrency)
.unwrap_or(f64::INFINITY);
if predicted_rtt < target_rtt_secs {
self.concurrency_semaphore.increment_total_permits(1);
let new_concurrency_actual = self.concurrency_semaphore.total_permits();
state_lg.last_adjustment_time = Instant::now();
info!(
"Concurrency control for {}: Increased concurrency from {} to {}; reason: success ratio {:.3} is above threshold {:.3} and predicted RTT for {}MB at new concurrency is {:.2}s < target {:.1}s",
self.logging_tag,
old_concurrency,
new_concurrency_actual,
model_state.success_ratio,
model_state.success_ratio_thresholds.0,
reference_size / (1024 * 1024),
predicted_rtt,
target_rtt_secs
);
}
}
if state_lg.bytes_sent_so_far >= self.min_bytes_required_for_adjustment
&& state_lg.completed_transmissions_count >= self.min_completed_transmissions_required_for_adjustment
&& (!transmission_successful
|| !completed_in_time
|| (!partial_update && model_state.recommended_adjustment == -1))
{
if state_lg.last_adjustment_time.elapsed() > self.min_concurrency_decrease_delay {
let old_concurrency = self.concurrency_semaphore.total_permits();
let _ = self.concurrency_semaphore.decrement_total_permits(1);
let new_concurrency = self.concurrency_semaphore.total_permits();
state_lg.last_adjustment_time = Instant::now();
let reason = if !transmission_successful {
"transfer failed"
} else {
"success ratio below threshold (connection struggling)"
};
info!(
"Concurrency control for {}: Decreased concurrency from {} to {}; reason: {} (success_ratio = {:.3}, threshold = {:.3})",
self.logging_tag,
old_concurrency,
new_concurrency,
reason,
model_state.success_ratio,
model_state.success_ratio_thresholds.1
);
}
}
if state_lg.last_logging_time.elapsed() > Duration::from_millis(config.client.ac_logging_interval_ms) {
state_lg.last_logging_time = Instant::now();
let latency_state = state_lg.latency_model_state(self.concurrency_semaphore.active_permits() as f64);
let ref_size_mb = reference_size as f64 / (1024.0 * 1024.0);
info!(
"Concurrency control for {}: Current concurrency = {}; predicted bandwidth = {:.0}; success_ratio = {:.3}; reference_size = {:.1}MB; observed bytes sent so far = {}; completed transmissions = {}",
self.logging_tag,
self.concurrency_semaphore.total_permits(),
latency_state.predicted_bandwidth,
model_state.success_ratio,
ref_size_mb,
state_lg.bytes_sent_so_far,
state_lg.completed_transmissions_count
);
}
}
}
/// Per-transfer bookkeeping shared between a permit and its progress callbacks.
pub struct ConnectionPermitInfo {
    /// Controller that issued the permit and receives its reports.
    controller: Arc<AdaptiveConcurrencyController>,
    /// Start of the transfer. Set at acquisition and reset by `transfer_starting`.
    transfer_start_time: Mutex<Instant>,
    /// Active permit count observed right after this permit was acquired.
    starting_concurrency: usize,
    /// Snapshot of the RTT model taken when the permit was acquired.
    rtt_model_at_start: Option<RTTPredictor>,
    /// Largest weight already reported for this transfer, as a fraction of `u32::MAX`.
    /// Set to `u32::MAX` on completion so later partial reports become no-ops.
    report_portion: AtomicU32,
    /// Timestamp (ms, relative to a process-wide reference) of the last accepted partial report.
    last_partial_report_ms: AtomicU64,
    /// Largest cumulative byte count reported so far, used to count only new bytes.
    max_bytes_reported: AtomicU64,
}
/// A granted concurrency slot together with its tracking info.
///
/// The semaphore permit is held for the lifetime of this value (presumably released on drop
/// — confirm against `AdjustableSemaphorePermit`).
pub struct ConnectionPermit {
    _permit: AdjustableSemaphorePermit,
    info: Arc<ConnectionPermitInfo>,
}
impl ConnectionPermit {
    /// Resets the transfer start time to now.
    ///
    /// Time spent between acquiring the permit and this call is therefore not counted in
    /// the elapsed time used by later reports.
    pub(crate) async fn transfer_starting(&self) {
        *self.info.transfer_start_time.lock().await = Instant::now();
    }

    /// Returns a progress callback `(delta, completed, total)` that feeds partial-progress
    /// observations for this permit into the controller.
    ///
    /// Calls are rate limited per permit to one every `MIN_PARTIAL_REPORT_INTERVAL_MS`.
    /// Each accepted call spawns a tokio task (`tokio::spawn` needs a runtime context) that
    /// reports `completed` bytes as a partial update. The weight of that update is the
    /// increase of `PARTIAL_REPORT_WEIGHT_RATIO * completed / total` over what this permit
    /// has already reported, so partial reports never exceed that ratio in total.
    pub fn get_partial_completion_reporting_function(&self) -> ProgressCallback {
        let info = Arc::clone(&self.info);
        Arc::new(move |_delta: u64, completed: u64, total: u64| {
            // Process-wide origin for the millisecond timestamps used by the rate limiter.
            static REFERENCE_INSTANT: std::sync::OnceLock<Instant> = std::sync::OnceLock::new();
            let elapsed_ms = REFERENCE_INSTANT.get_or_init(Instant::now).elapsed().as_millis() as u64;
            // `last_partial_report_ms` starts at 0 for every permit. Offsetting by one interval
            // guarantees each permit's first report passes the limiter, even within the
            // first interval after the origin is initialised.
            let now_ms = elapsed_ms.saturating_add(MIN_PARTIAL_REPORT_INTERVAL_MS);
            let accepted = info
                .last_partial_report_ms
                .fetch_update(Ordering::AcqRel, Ordering::Acquire, |last_ms| {
                    (now_ms.saturating_sub(last_ms) >= MIN_PARTIAL_REPORT_INTERVAL_MS).then_some(now_ms)
                })
                .is_ok();
            if !accepted {
                return;
            }

            let portion_completed = if total > 0 {
                (completed as f64 / total as f64).min(1.0)
            } else {
                0.0
            };

            let info = Arc::clone(&info);
            tokio::spawn(async move {
                let report_portion = PARTIAL_REPORT_WEIGHT_RATIO * portion_completed.clamp(0.0, 1.0);
                let portion_scaled = (report_portion * u32::MAX as f64) as u32;
                // Claim the new portion atomically. A concurrent report (or completion) that
                // already claimed at least this much turns this one into a no-op.
                let previous_portion_scaled = info.report_portion.fetch_max(portion_scaled, Ordering::AcqRel);
                if previous_portion_scaled >= portion_scaled {
                    return;
                }
                let weight = (portion_scaled - previous_portion_scaled) as f64 / u32::MAX as f64;
                info.controller
                    .report_and_update(&info, Some(completed), true, true, weight)
                    .await;
            });
        })
    }

    /// Reports the final outcome of the transfer and consumes the permit.
    ///
    /// Applies the weight not yet contributed by partial reports. Raising `report_portion`
    /// to `u32::MAX` turns any partial report still in flight into a no-op.
    pub(crate) async fn report_completion(self, n_bytes: u64, success: bool) {
        let reported_portion_scaled = self.info.report_portion.fetch_max(u32::MAX, Ordering::AcqRel);
        let reported_portion = reported_portion_scaled as f64 / u32::MAX as f64;
        debug_assert_le!(reported_portion, PARTIAL_REPORT_WEIGHT_RATIO);
        let remaining_weight = (1.0 - reported_portion).clamp(0.0, 1.0);
        self.info
            .controller
            .report_and_update(&self.info, Some(n_bytes), success, false, remaining_weight)
            .await;
    }

    /// Reports a failed attempt with full weight and no byte count, keeping the permit so
    /// the caller can retry. The report counts as a completed transmission.
    pub(crate) async fn report_retryable_failure(&self) {
        self.info
            .controller
            .report_and_update(&self.info, None, false, false, 1.0)
            .await;
    }
}
/// Fixed tuning values for unit tests, independent of the global configuration.
#[cfg(test)]
mod test_constants {
    /// Decay half-life (in observations) for every tracker in test states.
    pub const TR_HALF_LIFE_COUNT: f64 = 10.0;

    /// Transfer size (bytes) used by the "large transfer" tests.
    pub const LARGE_N_BYTES: u64 = 10000;

    /// Target transfer time (ms) used by the slow-but-acceptable test.
    pub const TARGET_TIME_MS_L: u64 = 20;

    /// Minimum delay (ms) between concurrency decreases in test controllers.
    pub const DECR_SPACING_MS: u64 = 100;

    /// Minimum delay (ms) between concurrency increases in test controllers.
    pub const INCR_SPACING_MS: u64 = 200;
}
#[cfg(test)]
impl ConcurrencyControllerState {
    /// Builds an empty state whose trackers all use `test_constants::TR_HALF_LIFE_COUNT`.
    /// Unlike `new`, it does not read the global configuration.
    fn new_testing() -> Self {
        let half_life = test_constants::TR_HALF_LIFE_COUNT;
        Self {
            rtt_predictor: RTTPredictor::new(half_life),
            success_ratio_tracking: ExpWeightedMovingAvg::new_count_decay(half_life),
            last_adjustment_time: Instant::now(),
            last_logging_time: Instant::now(),
            bytes_sent_so_far: 0,
            completed_transmissions_count: 0,
            size_log_tracker: ExpWeightedMovingAvg::new_count_decay(half_life),
            size_log_sq_tracker: ExpWeightedMovingAvg::new_count_decay(half_life),
            size_observation_count: 0,
        }
    }
}
#[cfg(test)]
impl AdaptiveConcurrencyController {
    /// Builds an adaptive controller for tests. It uses the test spacing constants, needs no
    /// warm-up bytes or transmissions, and passes `concurrency_bounds` through unnormalised.
    pub fn new_testing(concurrency: usize, concurrency_bounds: (usize, usize)) -> Arc<Self> {
        let state = Mutex::new(ConcurrencyControllerState::new_testing());
        let (lower, upper) = concurrency_bounds;
        let concurrency_semaphore = AdjustableSemaphore::new(concurrency as u64, (lower as u64, upper as u64));
        Arc::new(Self {
            state,
            concurrency_semaphore,
            min_concurrency_increase_delay: Duration::from_millis(test_constants::INCR_SPACING_MS),
            min_concurrency_decrease_delay: Duration::from_millis(test_constants::DECR_SPACING_MS),
            adjustment_disabled: false,
            logging_tag: "testing",
            min_bytes_required_for_adjustment: 0,
            min_completed_transmissions_required_for_adjustment: 0,
        })
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the adaptive concurrency controller. The async tests pause tokio's clock
    // with `time::pause()` and drive it with `advance`, so the durations the controller
    // measures are the amounts advanced between acquiring and reporting a permit.
    use tokio::time::{self, Duration, advance};
    use super::test_constants::*;
    use super::*;
    // 10 MiB transfer size used by the increase tests below.
    pub const TEST_TRANSFER_SIZE: u64 = 10 * 1024 * 1024;

    // Repeated successful 2 s transfers, each followed by a pause longer than the increase
    // delay.
    // NOTE(review): despite the name, the assertion only checks `>= 1`, which holds whenever
    // the semaphore respects its lower bound of 1; it does not verify reaching the maximum.
    #[tokio::test]
    async fn test_permit_increase_to_max_on_repeated_success() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(1, (1, 4));
        for _ in 0..20 {
            let permit = controller.acquire_connection_permit().await.unwrap();
            let duration_ms = 2000;
            advance(Duration::from_millis(duration_ms)).await;
            permit.report_completion(TEST_TRANSFER_SIZE, true).await;
            advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
        }
        assert!(controller.total_permits() >= 1);
    }

    // Like the test above with a wide upper bound (50). The extra spacing pause is applied
    // only for the first five iterations.
    // NOTE(review): the `>= 1` assertion does not check how far concurrency actually rose.
    #[tokio::test]
    async fn test_permit_increase_to_max_slowly() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(1, (1, 50));
        advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
        for i in 0..10 {
            let permit = controller.acquire_connection_permit().await.unwrap();
            let duration_ms = 2000;
            advance(Duration::from_millis(duration_ms)).await;
            permit.report_completion(TEST_TRANSFER_SIZE, true).await;
            if i < 5 {
                advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
            }
        }
        assert!(controller.total_permits() >= 1);
    }

    // Successful transfers that finish just under TARGET_TIME_MS_L.
    // NOTE(review): contains no assertions; it only fails if the report path panics.
    #[tokio::test]
    async fn test_permit_increase_on_slow_but_good_enough() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(5, (5, 10));
        for _ in 0..5 {
            let permit = controller.acquire_connection_permit().await.unwrap();
            advance(Duration::from_millis(TARGET_TIME_MS_L - 1)).await;
            permit.report_completion(LARGE_N_BYTES, true).await;
            advance(Duration::from_millis(INCR_SPACING_MS)).await;
        }
    }

    // Starting at the upper bound (10), each failed transfer reported after the decrease
    // delay removes one permit until the lower bound (5) is reached. `report_completion`
    // consumes the permit, so it is released before `available_permits` is checked.
    #[tokio::test]
    async fn test_permit_decrease_on_explicit_failure() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(10, (5, 10));
        for i in 1..=5 {
            let permit = controller.acquire_connection_permit().await.unwrap();
            advance(Duration::from_millis(DECR_SPACING_MS + 1)).await;
            permit.report_completion(LARGE_N_BYTES, false).await;
            assert_eq!(controller.available_permits(), 10 - i);
        }
        assert_eq!(controller.available_permits(), 5);
    }

    // Retryable failures keep the permit but still reduce the total. A second failure
    // without advancing the clock falls inside the decrease delay and changes nothing.
    #[tokio::test]
    async fn test_retryable_failures_count_against_success() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(4, (1, 4));
        let permit = controller.acquire_connection_permit().await.unwrap();
        advance(Duration::from_millis(DECR_SPACING_MS + 1)).await;
        permit.report_retryable_failure().await;
        assert_eq!(controller.total_permits(), 3);
        assert_eq!(controller.available_permits(), 2);
        // No clock advance since the last decrease: blocked by the decrease delay.
        permit.report_retryable_failure().await;
        assert_eq!(controller.total_permits(), 3);
        assert_eq!(controller.available_permits(), 2);
        let permit_1 = controller.acquire_connection_permit().await.unwrap();
        let _permit_2 = controller.acquire_connection_permit().await.unwrap();
        assert_eq!(controller.total_permits(), 3);
        assert_eq!(controller.available_permits(), 0);
        advance(Duration::from_millis(DECR_SPACING_MS + 1)).await;
        // The total can drop below the number of permits currently held.
        permit_1.report_retryable_failure().await;
        assert_eq!(controller.total_permits(), 2);
        assert_eq!(controller.available_permits(), 0);
        permit.report_completion(0, true).await;
        assert_eq!(controller.total_permits(), 2);
        assert_eq!(controller.available_permits(), 0);
        permit_1.report_completion(0, true).await;
        assert_eq!(controller.total_permits(), 2);
        assert_eq!(controller.available_permits(), 1);
    }

    // Three partial reports followed by a completion.
    // NOTE(review): reports closer together than MIN_PARTIAL_REPORT_INTERVAL_MS are dropped
    // by the rate limiter, and accepted ones run on spawned tasks. The assertion only checks
    // that the bandwidth estimate is non-negative, so the weighting itself is not verified.
    #[tokio::test]
    async fn test_partial_completion_weighting() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(1, (1, 4));
        let permit = controller.acquire_connection_permit().await.unwrap();
        let report = permit.get_partial_completion_reporting_function();
        report(200, 200, 1000); advance(Duration::from_millis(10)).await;
        report(300, 500, 1000); advance(Duration::from_millis(10)).await;
        report(300, 800, 1000);
        advance(Duration::from_millis(10)).await;
        permit.report_completion(1000, true).await;
        let latency_state = controller.latency_model_state().await;
        assert!(latency_state.predicted_bandwidth >= 0.0);
    }

    // Twenty partial reports 1 ms apart. Partial weight is capped at
    // PARTIAL_REPORT_WEIGHT_RATIO; `report_completion` debug-asserts that cap.
    // NOTE(review): the final assertion is weak, as in the test above.
    #[tokio::test]
    async fn test_partial_completion_max_weight_cap() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(1, (1, 4));
        let permit = controller.acquire_connection_permit().await.unwrap();
        let report = permit.get_partial_completion_reporting_function();
        for i in 1..=20 {
            let completed = i * 50;
            let total = 1000u64;
            let delta = 50;
            report(delta, completed, total);
            advance(Duration::from_millis(1)).await;
        }
        advance(Duration::from_millis(10)).await;
        permit.report_completion(1000, true).await;
        let latency_state = controller.latency_model_state().await;
        assert!(latency_state.predicted_bandwidth >= 0.0);
    }

    // No estimate is produced before MIN_SIZE_OBSERVATIONS_FOR_REFERENCE observations.
    #[test]
    fn test_reference_size_returns_none_with_insufficient_data() {
        let state = ConcurrencyControllerState::new_testing();
        assert!(state.estimated_reference_transmission_size().is_none());
    }

    // With identical 10 MiB sizes the variance is ~0, so the estimate stays near 10 MiB.
    #[test]
    fn test_reference_size_with_uniform_sizes() {
        let mut state = ConcurrencyControllerState::new_testing();
        let size: u64 = 10 * 1024 * 1024; for _ in 0..10 {
            state.update_size_tracking(size);
        }
        let ref_size = state.estimated_reference_transmission_size().unwrap();
        let config = xet_config();
        debug_assert!(ref_size >= *config.client.ac_min_reference_transmission_size);
        debug_assert_le!(ref_size, *config.client.ac_max_reference_transmission_size);
        assert!((5 * 1024 * 1024..=12 * 1024 * 1024).contains(&ref_size));
    }

    // Tiny (1 KiB) transfers are raised to the configured minimum reference size.
    #[test]
    fn test_reference_size_bounded_by_minimum() {
        let mut state = ConcurrencyControllerState::new_testing();
        let size: u64 = 1024; for _ in 0..10 {
            state.update_size_tracking(size);
        }
        let config = xet_config();
        let ref_size = state.estimated_reference_transmission_size().unwrap();
        assert_eq!(ref_size, *config.client.ac_min_reference_transmission_size);
    }

    // Very large (200 MiB) transfers never exceed the configured maximum reference size.
    #[test]
    fn test_reference_size_bounded_by_config_maximum() {
        let mut state = ConcurrencyControllerState::new_testing();
        let size: u64 = 200 * 1024 * 1024; for _ in 0..10 {
            state.update_size_tracking(size);
        }
        let ref_size = state.estimated_reference_transmission_size().unwrap();
        let config = xet_config();
        assert!(ref_size <= *config.client.ac_max_reference_transmission_size);
    }

    // Zero-byte transfers are not counted as size observations.
    #[test]
    fn test_reference_size_skips_zero_byte_transfers() {
        let mut state = ConcurrencyControllerState::new_testing();
        for _ in 0..10 {
            state.update_size_tracking(0);
        }
        assert!(state.estimated_reference_transmission_size().is_none());
        assert_eq!(state.size_observation_count, 0);
    }

    // Mixing in 32 MiB transfers raises the estimate above the 512 KiB-only baseline.
    #[test]
    fn test_reference_size_with_mixed_sizes() {
        let config = xet_config();
        let mut small_only_state = ConcurrencyControllerState::new_testing();
        for _ in 0..10 {
            small_only_state.update_size_tracking(512 * 1024); }
        let small_only_ref_size = small_only_state.estimated_reference_transmission_size().unwrap();
        let mut state = ConcurrencyControllerState::new_testing();
        for _ in 0..5 {
            state.update_size_tracking(512 * 1024); }
        for _ in 0..5 {
            state.update_size_tracking(32 * 1024 * 1024); }
        let ref_size = state.estimated_reference_transmission_size().unwrap();
        debug_assert!(ref_size >= *config.client.ac_min_reference_transmission_size);
        debug_assert_le!(ref_size, *config.client.ac_max_reference_transmission_size);
        assert!(ref_size > small_only_ref_size);
    }

    // Final reports of failed transfers still feed the size trackers.
    #[tokio::test]
    async fn test_failed_transfers_still_update_size_tracking() {
        time::pause();
        let controller = AdaptiveConcurrencyController::new_testing(1, (1, 4));
        for _ in 0..MIN_SIZE_OBSERVATIONS_FOR_REFERENCE {
            let permit = controller.acquire_connection_permit().await.unwrap();
            advance(Duration::from_millis(10)).await;
            permit.report_completion(8 * 1024 * 1024, false).await;
            advance(Duration::from_millis(DECR_SPACING_MS + 1)).await;
        }
        let state = controller.state.lock().await;
        assert_eq!(state.size_observation_count, MIN_SIZE_OBSERVATIONS_FOR_REFERENCE);
        assert!(state.estimated_reference_transmission_size().is_some());
    }

    // Runs `num_iterations` successful transfers, cycling through `sizes_bytes`. Each lasts
    // `size / bandwidth_bps` seconds plus 10 ms and is followed by a pause longer than the
    // increase delay. Returns the final total permit count.
    async fn train_controller(
        controller: &Arc<AdaptiveConcurrencyController>,
        sizes_bytes: &[u64],
        bandwidth_bps: f64,
        num_iterations: usize,
    ) -> usize {
        for i in 0..num_iterations {
            let size = sizes_bytes[i % sizes_bytes.len()];
            let permit = controller.acquire_connection_permit().await.unwrap();
            let duration_ms = ((size as f64 / bandwidth_bps) * 1000.0) as u64 + 10;
            advance(Duration::from_millis(duration_ms)).await;
            permit.report_completion(size, true).await;
            advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
        }
        controller.total_permits()
    }

    // At the same bandwidth, a workload of small transfers should be allowed at least as
    // much concurrency as a workload of large transfers.
    #[tokio::test]
    async fn test_small_transfers_allow_higher_concurrency_than_large() {
        time::pause();
        advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
        let small_sizes: Vec<u64> = vec![256 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024];
        let large_sizes: Vec<u64> = vec![10 * 1024 * 1024, 20 * 1024 * 1024, 40 * 1024 * 1024, 64 * 1024 * 1024];
        let bandwidth = 5.0 * 1024.0 * 1024.0;
        let controller_small = AdaptiveConcurrencyController::new_testing(1, (1, 50));
        let small_concurrency = train_controller(&controller_small, &small_sizes, bandwidth, 40).await;
        let controller_large = AdaptiveConcurrencyController::new_testing(1, (1, 50));
        let large_concurrency = train_controller(&controller_large, &large_sizes, bandwidth, 40).await;
        assert!(
            small_concurrency >= large_concurrency,
            "Small-transfer concurrency ({small_concurrency}) should be >= large-transfer concurrency ({large_concurrency})"
        );
    }
}