use std::collections::{HashMap, VecDeque};
use trustformers_core::Result;
use super::types::{
AllocationStrategy, CompressionStats, DeltaAlgorithm, DeltaCompressionConfig,
GradientCompressionAlgorithm, NetworkAdaptationConfig, NetworkQuantizationConfig,
PruningConfig,
};
/// Coordinates compression, traffic shaping and data-usage tracking for
/// outgoing network payloads.
///
/// `optimize_transmission` compresses a payload according to its data type and
/// then passes the result through the traffic shaper.
pub struct BandwidthOptimizer {
    /// Active adaptation settings; `data_usage_limits` is read by the usage-limit checks.
    config: NetworkAdaptationConfig,
    /// Per-data-type compressors plus cumulative compression statistics.
    compression_engine: NetworkCompressionEngine,
    /// Rate limiting, queuing of deferred payloads and bandwidth bookkeeping.
    traffic_shaper: TrafficShaper,
    /// Per-type byte counters and usage prediction.
    usage_tracker: DataUsageTracker,
}
/// Error statistics carried by a `GradientCompressor`.
///
/// No code in this module writes to these fields, so they keep their
/// `Default` values (zero / `None`) unless something else updates them.
#[derive(Debug, Clone, Default)]
pub struct ErrorFeedbackBuffer {
    /// Error counter (not incremented anywhere in this module).
    pub error_count: u32,
    /// Time of the most recent recorded error, if any.
    pub last_error_timestamp: Option<std::time::Instant>,
    /// Error rate; range is not enforced here — presumably 0.0..=1.0 (TODO confirm).
    pub error_rate: f32,
}
/// Holds one compressor per payload family and running totals of bytes in and out.
pub struct NetworkCompressionEngine {
    /// Used for payloads the optimizer labels `"gradient"`.
    gradient_compressor: GradientCompressor,
    /// Used for payloads the optimizer labels `"model"`.
    model_compressor: ModelCompressor,
    /// Used for payloads the optimizer labels `"differential"`.
    differential_compressor: DifferentialCompressor,
    /// Cumulative original/compressed byte counts and their ratio.
    compression_stats: CompressionStats,
}
/// Lossy byte-level compressor for gradient payloads.
pub struct GradientCompressor {
    /// Strategy dispatched on in `compress`.
    algorithm: GradientCompressionAlgorithm,
    /// Error statistics associated with this compressor.
    error_feedback: ErrorFeedbackBuffer,
    /// Fraction of input bytes to keep; higher means less compression.
    compression_ratio: f32,
    /// Quality target clamped to 0.5..=1.0 by `update_parameters`; not read by
    /// the compression paths.
    quality_threshold: f32,
}
/// Applies pruning, quantization and optional delta encoding to model payloads.
pub struct ModelCompressor {
    /// Supplies `gradient_bits`, the bit width used when quantizing bytes.
    quantization_config: NetworkQuantizationConfig,
    /// Whether pruning runs, how much is pruned, and whether it is structured.
    pruning_config: PruningConfig,
    /// Whether delta encoding runs and which algorithm it uses.
    delta_compression: DeltaCompressionConfig,
}
/// Encodes payloads as byte-wise differences from a stored baseline and caches
/// the results.
pub struct DifferentialCompressor {
    /// Baselines registered through `set_baseline`, keyed by model id.
    baseline_models: HashMap<String, Vec<u8>>,
    /// Selected delta algorithm.
    delta_algorithm: DeltaAlgorithm,
    /// Previously computed outputs keyed by a hash of the input bytes.
    compression_cache: CompressionCache,
}
/// Size-bounded map from input hash string to compressed bytes.
pub struct CompressionCache {
    /// Cached outputs keyed by content hash.
    cached_deltas: HashMap<String, Vec<u8>>,
    /// Hit-rate statistic; starts at 0.0 and is not updated by the methods in this file.
    cache_hit_rate: f32,
    /// Budget for the cached values, in MiB (1024 * 1024 bytes).
    max_cache_size_mb: u32,
}
/// Admits outgoing payloads through a rate limiter, queues the ones that do not
/// fit, and records per-type bandwidth allocation.
pub struct TrafficShaper {
    /// Admission control for outgoing bytes.
    rate_limiter: RateLimiter,
    /// Deferred payloads keyed by data type. Despite the name, the methods in
    /// this file do not order the queues by priority.
    priority_queues: HashMap<String, VecDeque<Vec<u8>>>,
    /// Per-data-type allocation bookkeeping.
    bandwidth_allocator: BandwidthAllocator,
}
/// Cumulative budget check for outgoing data.
///
/// NOTE(review): despite the `_mbps` names, `check_rate_limit` accumulates
/// megabytes (bytes / 2^20) with no time window; the running total only drops
/// when `reset_counters` is called.
pub struct RateLimiter {
    /// Megabytes admitted since the last reset.
    current_rate_mbps: f32,
    /// Nominal budget; `update_target_rate` keeps it at least 0.1.
    target_rate_mbps: f32,
    /// Extra headroom allowed on top of `target_rate_mbps`.
    burst_allowance_mb: f32,
    /// Intended window length in milliseconds; not read by any method in this file.
    window_size_ms: u64,
}
/// Tracks how much of a fixed total has been handed out to each data type.
pub struct BandwidthAllocator {
    /// Total capacity; `update_total_bandwidth` keeps it at least 0.1.
    total_bandwidth_mbps: f32,
    /// Running sum of the amounts passed to `allocate_bandwidth`, per data type.
    allocated_bandwidth: HashMap<String, f32>,
    /// Strategy setting; not read by any method in this file.
    allocation_strategy: AllocationStrategy,
}
/// Per-data-type byte counters plus a bounded, timestamped history used for prediction.
pub struct DataUsageTracker {
    /// Bytes per data type. Named "daily", but nothing in this file resets it.
    daily_usage: HashMap<String, u64>,
    /// Bytes per data type. Named "monthly", but nothing in this file resets it.
    monthly_usage: HashMap<String, u64>,
    /// `(timestamp, bytes)` per tracked transfer; only the newest 1000 are kept.
    usage_history: VecDeque<(std::time::Instant, u64)>,
    /// Turns `usage_history` into a forecast.
    usage_predictor: UsagePredictor,
}
/// Forecasts future data usage from recent history.
pub struct UsagePredictor {
    /// Reserved for per-type models; not read by any method in this file.
    prediction_models: HashMap<String, Vec<f32>>,
    /// Exponential moving average maintained by `update_accuracy`, kept in 0.0..=1.0.
    prediction_accuracy: f32,
    /// Configured window length; not read by any method in this file.
    prediction_window_hours: u32,
}
impl BandwidthOptimizer {
    /// Creates an optimizer with default compression, shaping and tracking state.
    ///
    /// # Errors
    /// Currently infallible; the `Result` leaves room for validating `config`.
    pub fn new(config: NetworkAdaptationConfig) -> Result<Self> {
        Ok(Self {
            config,
            compression_engine: NetworkCompressionEngine::new(),
            traffic_shaper: TrafficShaper::new(),
            usage_tracker: DataUsageTracker::new(),
        })
    }

    /// Starts the optimizer. Currently a no-op.
    pub fn start(&mut self) -> Result<()> {
        Ok(())
    }

    /// Stops the optimizer. Currently a no-op.
    pub fn stop(&mut self) -> Result<()> {
        Ok(())
    }

    /// Compresses `data` according to `data_type`, passes it through the traffic
    /// shaper, and records the bytes released for sending.
    ///
    /// `"gradient"`, `"model"` and `"differential"` select a compressor; any other
    /// type is forwarded uncompressed. If the rate limiter defers the payload, the
    /// shaper queues it and an empty vector is returned.
    ///
    /// # Errors
    /// Propagates errors from the selected compressor or the traffic shaper.
    pub fn optimize_transmission(&mut self, data: Vec<u8>, data_type: &str) -> Result<Vec<u8>> {
        let compressed_data = match data_type {
            "gradient" => self.compression_engine.compress_gradient(data)?,
            "model" => self.compression_engine.compress_model(data)?,
            "differential" => self.compression_engine.compress_differential(data)?,
            _ => data,
        };
        let shaped = self.traffic_shaper.shape_traffic(compressed_data, data_type)?;
        // Record what actually goes out now. Without this the usage tracker never
        // sees traffic, so usage stats, limit checks and predictions stay at zero.
        // Deferred (empty) results are skipped so they do not dilute the history.
        if !shaped.is_empty() {
            self.usage_tracker.track_usage(data_type, shaped.len() as u64);
        }
        Ok(shaped)
    }

    /// Cumulative compression statistics across all payload types.
    pub fn get_compression_stats(&self) -> &CompressionStats {
        &self.compression_engine.compression_stats
    }

    /// Current rate-limiter usage relative to its target (not clamped).
    pub fn get_bandwidth_utilization(&self) -> f32 {
        self.traffic_shaper.get_current_utilization()
    }

    /// Bytes recorded per data type.
    pub fn get_usage_stats(&self) -> HashMap<String, u64> {
        self.usage_tracker.get_current_usage()
    }

    /// Projected usage (bytes) over the next `hours`.
    ///
    /// # Errors
    /// Propagates errors from the usage predictor.
    pub fn predict_bandwidth_requirements(&self, hours: u32) -> Result<f32> {
        self.usage_tracker.predict_usage(hours)
    }

    /// Replaces the active configuration.
    pub fn update_config(&mut self, config: NetworkAdaptationConfig) -> Result<()> {
        self.config = config;
        Ok(())
    }

    /// Whether recorded usage is close to the configured data-usage limits.
    pub fn is_approaching_limit(&self) -> bool {
        self.usage_tracker.is_approaching_limit(&self.config.data_usage_limits)
    }

    /// Suggested compression level in 0.0..=1.0: the mean of data-usage pressure
    /// and bandwidth utilization.
    pub fn get_recommended_compression_level(&self) -> f32 {
        // A NaN component (possible if the usage ratio divides by a zero limit)
        // is treated as "no pressure" so the result is always a usable number.
        let normalise = |ratio: f32| if ratio.is_nan() { 0.0 } else { ratio.clamp(0.0, 1.0) };
        let usage_ratio =
            normalise(self.usage_tracker.get_usage_ratio(&self.config.data_usage_limits));
        let bandwidth_ratio = normalise(self.traffic_shaper.get_utilization_ratio());
        (usage_ratio + bandwidth_ratio) / 2.0
    }
}
impl NetworkCompressionEngine {
pub fn new() -> Self {
Self {
gradient_compressor: GradientCompressor::new(),
model_compressor: ModelCompressor::new(),
differential_compressor: DifferentialCompressor::new(),
compression_stats: CompressionStats::default(),
}
}
pub fn compress_gradient(&mut self, data: Vec<u8>) -> Result<Vec<u8>> {
let original_size = data.len();
let compressed = self.gradient_compressor.compress(data)?;
self.update_stats(original_size, compressed.len());
Ok(compressed)
}
pub fn compress_model(&mut self, data: Vec<u8>) -> Result<Vec<u8>> {
let original_size = data.len();
let compressed = self.model_compressor.compress(data)?;
self.update_stats(original_size, compressed.len());
Ok(compressed)
}
pub fn compress_differential(&mut self, data: Vec<u8>) -> Result<Vec<u8>> {
let original_size = data.len();
let compressed = self.differential_compressor.compress(data)?;
self.update_stats(original_size, compressed.len());
Ok(compressed)
}
fn update_stats(&mut self, original_size: usize, compressed_size: usize) {
self.compression_stats.original_size_bytes += original_size;
self.compression_stats.compressed_size_bytes += compressed_size;
if self.compression_stats.original_size_bytes > 0 {
self.compression_stats.compression_ratio = self.compression_stats.compressed_size_bytes
as f32
/ self.compression_stats.original_size_bytes as f32;
}
}
pub fn get_compression_efficiency(&self) -> f32 {
1.0 - self.compression_stats.compression_ratio
}
}
impl GradientCompressor {
    /// Creates a compressor using the adaptive strategy and keeping 70% of bytes.
    pub fn new() -> Self {
        Self {
            algorithm: GradientCompressionAlgorithm::Adaptive,
            error_feedback: ErrorFeedbackBuffer::default(),
            compression_ratio: 0.7,
            quality_threshold: 0.95,
        }
    }

    /// Compresses `data` with the configured algorithm.
    ///
    /// `ThresholdBased` currently shares the top-k path.
    ///
    /// # Errors
    /// None of the strategies currently fail; the `Result` mirrors the other compressors.
    pub fn compress(&mut self, data: Vec<u8>) -> Result<Vec<u8>> {
        match self.algorithm {
            GradientCompressionAlgorithm::TopK { .. } => self.compress_top_k(data),
            GradientCompressionAlgorithm::RandomSparsification { .. } => {
                self.compress_randomized(data)
            },
            GradientCompressionAlgorithm::Adaptive => self.compress_adaptive(data),
            GradientCompressionAlgorithm::Quantized { .. } => self.compress_quantized(data),
            GradientCompressionAlgorithm::None => Ok(data),
            GradientCompressionAlgorithm::ThresholdBased { .. } => self.compress_top_k(data),
        }
    }

    /// Number of bytes to keep out of `len` for a keep fraction, capped at `len`.
    fn keep_count(len: usize, keep_ratio: f32) -> usize {
        ((len as f32 * keep_ratio) as usize).min(len)
    }

    /// Keeps the leading `compression_ratio` fraction of bytes.
    fn compress_top_k(&self, mut data: Vec<u8>) -> Result<Vec<u8>> {
        let keep = Self::keep_count(data.len(), self.compression_ratio);
        data.truncate(keep);
        Ok(data)
    }

    /// Keeps `compression_ratio` of the bytes, sampled at evenly spaced positions.
    ///
    /// When the fraction rounds down to zero bytes the payload is returned
    /// unchanged, as before.
    fn compress_randomized(&self, data: Vec<u8>) -> Result<Vec<u8>> {
        let len = data.len();
        let keep = Self::keep_count(len, self.compression_ratio);
        if keep == 0 || keep >= len {
            return Ok(data);
        }
        // Index `i * len / keep` for i in 0..keep gives exactly `keep` distinct,
        // evenly spread positions. The former `len / keep` stride rounded down to 1
        // whenever more than half of the bytes were to be kept, so nothing was dropped.
        Ok((0..keep).map(|i| data[i * len / keep]).collect())
    }

    /// Keeps `compression_ratio` of the bytes, raised by the observed error rate,
    /// with the effective fraction clamped to 0.1..=0.9.
    fn compress_adaptive(&self, mut data: Vec<u8>) -> Result<Vec<u8>> {
        // Previously a constant 1.0 was added here, which always saturated the ratio
        // at 0.9 and made `compression_ratio`/`update_parameters` ineffective.
        // A higher error rate now keeps more data; `f32::max` maps NaN to 0.0.
        let feedback_adjustment = self.error_feedback.error_rate.max(0.0);
        let adjusted_ratio = (self.compression_ratio + feedback_adjustment).clamp(0.1, 0.9);
        let keep = Self::keep_count(data.len(), adjusted_ratio);
        data.truncate(keep);
        Ok(data)
    }

    /// Rounds every byte down to a multiple of 4.
    fn compress_quantized(&self, data: Vec<u8>) -> Result<Vec<u8>> {
        Ok(data.into_iter().map(|b| (b / 4) * 4).collect())
    }

    /// Sets the keep fraction (clamped to 0.1..=0.9) and quality threshold
    /// (clamped to 0.5..=1.0).
    pub fn update_parameters(&mut self, compression_ratio: f32, quality_threshold: f32) {
        self.compression_ratio = compression_ratio.clamp(0.1, 0.9);
        self.quality_threshold = quality_threshold.clamp(0.5, 1.0);
    }
}
impl ModelCompressor {
    /// Creates a compressor from the default quantization, pruning and delta configs.
    pub fn new() -> Self {
        Self {
            quantization_config: NetworkQuantizationConfig::default(),
            pruning_config: PruningConfig::default(),
            delta_compression: DeltaCompressionConfig::default(),
        }
    }

    /// Runs pruning (if enabled), then quantization, then delta encoding (if enabled).
    ///
    /// # Errors
    /// Propagates errors from the individual stages (currently all infallible).
    pub fn compress(&self, data: Vec<u8>) -> Result<Vec<u8>> {
        let mut compressed_data = data;
        if self.pruning_config.enable_pruning {
            compressed_data = self.apply_pruning(compressed_data)?;
        }
        compressed_data = self.apply_quantization(compressed_data)?;
        if self.delta_compression.enable_delta {
            compressed_data = self.apply_delta_compression(compressed_data)?;
        }
        Ok(compressed_data)
    }

    /// Drops `pruning_ratio` of the bytes.
    ///
    /// Structured pruning keeps one contiguous leading block. Unstructured pruning
    /// keeps evenly spaced bytes. A non-finite ratio is treated as 0.0 (keep
    /// everything), and other values are clamped to 0.0..=1.0.
    fn apply_pruning(&self, mut data: Vec<u8>) -> Result<Vec<u8>> {
        let raw_ratio = self.pruning_config.pruning_ratio;
        let pruning_ratio = if raw_ratio.is_finite() { raw_ratio.clamp(0.0, 1.0) } else { 0.0 };
        let len = data.len();
        let keep_count = ((len as f32 * (1.0 - pruning_ratio)) as usize).min(len);
        if keep_count == len {
            return Ok(data);
        }
        // Guard before any division: the old structured path computed
        // `len / keep_count` and panicked here.
        if keep_count == 0 {
            return Ok(Vec::new());
        }
        if self.pruning_config.structured_pruning {
            // The old chunk-based form re-emitted roughly every byte whenever the
            // block size exceeded 1. Its intended output is this leading block.
            data.truncate(keep_count);
            Ok(data)
        } else {
            // Evenly spaced indices yield exactly `keep_count` bytes. The former
            // `step_by(1 / (1 - ratio))` pruned nothing for ratios below 0.5 and
            // panicked with a zero step for ratios outside 0..1.
            Ok((0..keep_count).map(|i| data[i * len / keep_count]).collect())
        }
    }

    /// Rounds each byte down to a multiple of `256 >> gradient_bits`, leaving
    /// `2^gradient_bits` distinct levels.
    ///
    /// The bit width is clamped to 1..=8. Widths of 0 or above 8 previously
    /// produced a zero step and panicked on division by zero; 8 or more is lossless.
    fn apply_quantization(&self, data: Vec<u8>) -> Result<Vec<u8>> {
        let bits = (self.quantization_config.gradient_bits as u32).clamp(1, 8);
        let step = (256u32 >> bits) as u8;
        if step == 1 {
            return Ok(data);
        }
        Ok(data.into_iter().map(|b| (b / step) * step).collect())
    }

    /// Applies the configured delta algorithm.
    ///
    /// `SimpleDiff` and `SemanticDiff` both use consecutive-byte differences.
    /// `OptimizedDiff` and `CompressedDiff` currently pass data through unchanged.
    fn apply_delta_compression(&self, data: Vec<u8>) -> Result<Vec<u8>> {
        match self.delta_compression.delta_algorithm {
            DeltaAlgorithm::SimpleDiff | DeltaAlgorithm::SemanticDiff => {
                Ok(Self::consecutive_diff(&data))
            },
            DeltaAlgorithm::OptimizedDiff | DeltaAlgorithm::CompressedDiff => Ok(data),
        }
    }

    /// Encodes the first byte verbatim and every later byte as the wrapping
    /// difference from its predecessor.
    fn consecutive_diff(data: &[u8]) -> Vec<u8> {
        let mut result = Vec::with_capacity(data.len());
        result.extend(data.first().copied());
        result.extend(data.windows(2).map(|pair| pair[1].wrapping_sub(pair[0])));
        result
    }
}
impl DifferentialCompressor {
    /// Creates a compressor with no baselines, the `OptimizedDiff` algorithm and
    /// an empty cache.
    pub fn new() -> Self {
        Self {
            baseline_models: HashMap::new(),
            delta_algorithm: DeltaAlgorithm::OptimizedDiff,
            compression_cache: CompressionCache::new(),
        }
    }

    /// Delta-encodes `data` against a stored baseline. Returns the cached result
    /// when the same bytes were compressed before.
    ///
    /// # Errors
    /// Propagates errors from the diff routines (currently infallible).
    pub fn compress(&mut self, data: Vec<u8>) -> Result<Vec<u8>> {
        let data_hash = self.calculate_hash(&data);
        if let Some(cached) = self.compression_cache.get(&data_hash) {
            return Ok(cached);
        }
        let compressed = match self.delta_algorithm {
            DeltaAlgorithm::SimpleDiff | DeltaAlgorithm::SemanticDiff => {
                self.simple_diff_compress(data)?
            },
            DeltaAlgorithm::OptimizedDiff | DeltaAlgorithm::CompressedDiff => {
                self.optimized_diff_compress(data)?
            },
        };
        self.compression_cache.insert(data_hash, compressed.clone());
        Ok(compressed)
    }

    /// Subtracts (wrapping) a baseline from `data` byte by byte. Bytes beyond the
    /// baseline's length are kept as-is. With no baseline, `data` is returned unchanged.
    ///
    /// NOTE(review): when several baselines are registered, the one used is
    /// whichever `HashMap` iteration yields first, not a specific model id.
    fn simple_diff_compress(&self, mut data: Vec<u8>) -> Result<Vec<u8>> {
        if let Some(baseline) = self.baseline_models.values().next() {
            for (byte, &base) in data.iter_mut().zip(baseline.iter()) {
                *byte = byte.wrapping_sub(base);
            }
        }
        Ok(data)
    }

    /// Currently identical to `simple_diff_compress`.
    fn optimized_diff_compress(&self, data: Vec<u8>) -> Result<Vec<u8>> {
        self.simple_diff_compress(data)
    }

    /// Cache key for `data`: a 64-bit `DefaultHasher` digest of the bytes, as 16 hex digits.
    ///
    /// The previous `len ^ byte_sum` key collided for any reordering of the same
    /// bytes (e.g. `[1, 2]` vs `[2, 1]`), so the cache returned another payload's
    /// result. A 64-bit digest makes such collisions vanishingly unlikely.
    fn calculate_hash(&self, data: &[u8]) -> String {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        std::hash::Hash::hash(data, &mut hasher);
        format!("{:016x}", std::hash::Hasher::finish(&hasher))
    }

    /// Registers `baseline` under `model_id` and drops all cached results, since
    /// they were computed against the previous set of baselines.
    pub fn set_baseline(&mut self, model_id: String, baseline: Vec<u8>) {
        self.baseline_models.insert(model_id, baseline);
        self.compression_cache.cached_deltas.clear();
    }
}
impl Default for CompressionCache {
fn default() -> Self {
Self::new()
}
}
impl CompressionCache {
    /// Creates an empty cache with a 100 MiB budget.
    pub fn new() -> Self {
        Self {
            cached_deltas: HashMap::new(),
            cache_hit_rate: 0.0,
            max_cache_size_mb: 100,
        }
    }

    /// Returns a copy of the value cached under `key`, if any.
    pub fn get(&mut self, key: &str) -> Option<Vec<u8>> {
        self.cached_deltas.get(key).cloned()
    }

    /// Caches `value` under `key`, evicting entries until the total cached bytes
    /// stay within `max_cache_size_mb`.
    ///
    /// A value larger than the whole budget is not cached; any older entry under
    /// `key` is still dropped so a stale value is never served.
    pub fn insert(&mut self, key: String, value: Vec<u8>) {
        // Drop the previous value for this key so it is not counted against the budget.
        self.cached_deltas.remove(&key);
        let budget = self.max_cache_size_bytes();
        if value.len() > budget {
            return;
        }
        // The old code evicted at most one entry, and only after the (MiB-truncated)
        // size had already reached the budget, so the cache could grow unbounded.
        let mut used = self.cache_size_bytes();
        while used + value.len() > budget {
            match self.evict_oldest() {
                Some(freed) => used -= freed,
                None => break,
            }
        }
        self.cached_deltas.insert(key, value);
    }

    /// Budget in bytes (MiB * 2^20), saturating instead of overflowing.
    fn max_cache_size_bytes(&self) -> usize {
        (self.max_cache_size_mb as usize).saturating_mul(1024 * 1024)
    }

    /// Total bytes currently held in cached values.
    fn cache_size_bytes(&self) -> usize {
        self.cached_deltas.values().map(Vec::len).sum()
    }

    /// Removes one entry and returns its size in bytes, or `None` if empty.
    ///
    /// `HashMap` keeps no insertion order, so despite the name this evicts an
    /// arbitrary entry, not necessarily the oldest.
    fn evict_oldest(&mut self) -> Option<usize> {
        let key = self.cached_deltas.keys().next().cloned()?;
        self.cached_deltas.remove(&key).map(|evicted| evicted.len())
    }
}
impl TrafficShaper {
    /// Creates a shaper with a default rate limiter, allocator and no queued data.
    pub fn new() -> Self {
        Self {
            rate_limiter: RateLimiter::new(),
            priority_queues: HashMap::new(),
            bandwidth_allocator: BandwidthAllocator::new(),
        }
    }

    /// Admits `data` through the rate limiter and records its size in MB against
    /// `data_type`.
    ///
    /// If the rate limit would be exceeded, `data` is queued under `data_type`
    /// and an empty vector is returned. Callers cannot distinguish that case from
    /// an empty payload.
    ///
    /// # Errors
    /// Currently infallible.
    pub fn shape_traffic(&mut self, data: Vec<u8>, data_type: &str) -> Result<Vec<u8>> {
        if !self.rate_limiter.check_rate_limit(data.len()) {
            // `data` is no longer needed here, so move it into the queue instead of cloning.
            self.queue_data(data_type.to_string(), data);
            return Ok(Vec::new());
        }
        // The allocator expects megabytes; the raw byte count was passed before.
        self.bandwidth_allocator.allocate_bandwidth(data_type, Self::bytes_to_mb(data.len()));
        Ok(data)
    }

    /// Converts a byte count to MB (bytes / 2^20), matching the rate limiter's unit.
    fn bytes_to_mb(bytes: usize) -> f32 {
        bytes as f32 / (1024.0 * 1024.0)
    }

    /// Appends `data` to the queue for `data_type`, creating the queue if needed.
    fn queue_data(&mut self, data_type: String, data: Vec<u8>) {
        self.priority_queues.entry(data_type).or_default().push_back(data);
    }

    /// Rate-limiter total divided by its target (may exceed 1.0).
    pub fn get_current_utilization(&self) -> f32 {
        self.rate_limiter.current_rate_mbps / self.rate_limiter.target_rate_mbps
    }

    /// `get_current_utilization` clamped to 0.0..=1.0.
    pub fn get_utilization_ratio(&self) -> f32 {
        self.get_current_utilization().clamp(0.0, 1.0)
    }

    /// Releases queued payloads while the rate limiter admits them. Within each
    /// queue, stops at the first payload that does not fit.
    ///
    /// Released payloads are recorded in the bandwidth allocator, just like those
    /// admitted directly by `shape_traffic`.
    ///
    /// # Errors
    /// Currently infallible.
    pub fn process_queue(&mut self) -> Result<Vec<Vec<u8>>> {
        let mut processed = Vec::new();
        for (data_type, queue) in self.priority_queues.iter_mut() {
            while let Some(data) = queue.pop_front() {
                if !self.rate_limiter.check_rate_limit(data.len()) {
                    queue.push_front(data);
                    break;
                }
                self.bandwidth_allocator
                    .allocate_bandwidth(data_type, Self::bytes_to_mb(data.len()));
                processed.push(data);
            }
        }
        Ok(processed)
    }
}
impl RateLimiter {
pub fn new() -> Self {
Self {
current_rate_mbps: 0.0,
target_rate_mbps: 10.0,
burst_allowance_mb: 5.0,
window_size_ms: 1000,
}
}
pub fn check_rate_limit(&mut self, data_size_bytes: usize) -> bool {
let data_size_mb = data_size_bytes as f32 / (1024.0 * 1024.0);
if self.current_rate_mbps + data_size_mb <= self.target_rate_mbps + self.burst_allowance_mb
{
self.current_rate_mbps += data_size_mb;
true
} else {
false
}
}
pub fn update_target_rate(&mut self, target_mbps: f32) {
self.target_rate_mbps = target_mbps.max(0.1); }
pub fn reset_counters(&mut self) {
self.current_rate_mbps = 0.0;
}
}
impl BandwidthAllocator {
pub fn new() -> Self {
Self {
total_bandwidth_mbps: 10.0,
allocated_bandwidth: HashMap::new(),
allocation_strategy: AllocationStrategy::PriorityBased,
}
}
pub fn allocate_bandwidth(&mut self, data_type: &str, data_size_mb: f32) {
let current = self.allocated_bandwidth.entry(data_type.to_string()).or_insert(0.0);
*current += data_size_mb;
}
pub fn get_available_bandwidth(&self) -> f32 {
let allocated: f32 = self.allocated_bandwidth.values().sum();
(self.total_bandwidth_mbps - allocated).max(0.0)
}
pub fn update_total_bandwidth(&mut self, bandwidth_mbps: f32) {
self.total_bandwidth_mbps = bandwidth_mbps.max(0.1);
}
pub fn reset_allocations(&mut self) {
self.allocated_bandwidth.clear();
}
}
impl DataUsageTracker {
    /// Bytes per megabyte used when converting the `*_limit_mb` settings.
    const BYTES_PER_MB: u64 = 1024 * 1024;

    /// Creates a tracker with empty counters and history.
    pub fn new() -> Self {
        Self {
            daily_usage: HashMap::new(),
            monthly_usage: HashMap::new(),
            usage_history: VecDeque::new(),
            usage_predictor: UsagePredictor::new(),
        }
    }

    /// Adds `bytes` to the daily and monthly counters for `data_type` and appends
    /// a timestamped entry to the history, keeping only the newest 1000 entries.
    pub fn track_usage(&mut self, data_type: &str, bytes: u64) {
        *self.daily_usage.entry(data_type.to_string()).or_insert(0) += bytes;
        *self.monthly_usage.entry(data_type.to_string()).or_insert(0) += bytes;
        self.usage_history.push_back((std::time::Instant::now(), bytes));
        if self.usage_history.len() > 1000 {
            self.usage_history.pop_front();
        }
    }

    /// Snapshot of the "daily" per-type byte counters.
    pub fn get_current_usage(&self) -> HashMap<String, u64> {
        self.daily_usage.clone()
    }

    /// Projected usage over the next `hours`, delegated to the predictor.
    ///
    /// # Errors
    /// Propagates errors from the predictor.
    pub fn predict_usage(&self, hours: u32) -> Result<f32> {
        self.usage_predictor.predict(hours, &self.usage_history)
    }

    /// `true` once total daily or monthly usage exceeds 80% of the corresponding
    /// cellular limit. A `None` limit is treated as unlimited.
    ///
    /// Limits are widened to `u64` before scaling to bytes. The old
    /// `(mb * 1024 * 1024) as u64` multiplied in the field's own type, which
    /// overflows for narrow integer types (e.g. >= 4096 MB as `u32`).
    pub fn is_approaching_limit(&self, limits: &super::types::DataUsageLimits) -> bool {
        let total_daily: u64 = self.daily_usage.values().sum();
        let total_monthly: u64 = self.monthly_usage.values().sum();
        let daily_limit = limits
            .cellular_daily_limit_mb
            .map(|mb| (mb as u64).saturating_mul(Self::BYTES_PER_MB));
        let monthly_limit = limits
            .cellular_monthly_limit_mb
            .map(|mb| (mb as u64).saturating_mul(Self::BYTES_PER_MB));
        Self::exceeds_80_percent(total_daily, daily_limit)
            || Self::exceeds_80_percent(total_monthly, monthly_limit)
    }

    /// Fraction of the daily cellular limit used so far, in 0.0..=1.0.
    ///
    /// Returns 0.0 when no daily limit is set. With a zero limit, any usage
    /// counts as 1.0. The old code divided by zero in both cases and returned
    /// NaN when nothing had been used.
    pub fn get_usage_ratio(&self, limits: &super::types::DataUsageLimits) -> f32 {
        let total_daily: u64 = self.daily_usage.values().sum();
        let daily_limit = limits
            .cellular_daily_limit_mb
            .map(|mb| (mb as u64).saturating_mul(Self::BYTES_PER_MB));
        let ratio = match daily_limit {
            None => 0.0,
            Some(0) => {
                if total_daily > 0 {
                    1.0
                } else {
                    0.0
                }
            },
            Some(limit) => total_daily as f32 / limit as f32,
        };
        ratio.clamp(0.0, 1.0)
    }

    /// `true` when `used` is above 80% of `limit_bytes`; `false` when there is no limit.
    fn exceeds_80_percent(used: u64, limit_bytes: Option<u64>) -> bool {
        match limit_bytes {
            Some(limit) => used > limit.saturating_mul(80) / 100,
            None => false,
        }
    }
}
impl UsagePredictor {
    /// Creates a predictor with 0.5 initial accuracy and a 24-hour window setting.
    pub fn new() -> Self {
        Self {
            prediction_models: HashMap::new(),
            prediction_accuracy: 0.5,
            prediction_window_hours: 24,
        }
    }

    /// Projects usage over the next `hours` from the newest quarter of `history`
    /// (at least one entry).
    ///
    /// The rate is the bytes in that window divided by the time the window spans,
    /// first to last entry. The span is floored at one hour so that samples recorded
    /// close together are not extrapolated to an extreme rate. The old code divided
    /// by the number of samples, yielding bytes per sample rather than per hour.
    /// Returns 0.0 for an empty history.
    ///
    /// # Errors
    /// Currently infallible.
    pub fn predict(
        &self,
        hours: u32,
        history: &VecDeque<(std::time::Instant, u64)>,
    ) -> Result<f32> {
        if history.is_empty() {
            return Ok(0.0);
        }
        let recent_count = (history.len() / 4).max(1);
        let window_start = history.len() - recent_count;
        let recent_usage: u64 = history.range(window_start..).map(|&(_, bytes)| bytes).sum();
        let first_seen = history[window_start].0;
        let last_seen = history[history.len() - 1].0;
        let span_hours = last_seen.saturating_duration_since(first_seen).as_secs_f32() / 3600.0;
        let usage_per_hour = recent_usage as f32 / span_hours.max(1.0);
        Ok(usage_per_hour * hours as f32)
    }

    /// Folds one observation into the accuracy EMA (weight 0.1). Accuracy is
    /// `1 - |predicted - actual| / max(actual, 1)`, and the result is clamped to 0.0..=1.0.
    pub fn update_accuracy(&mut self, predicted: f32, actual: f32) {
        let error = (predicted - actual).abs() / actual.max(1.0);
        let accuracy = 1.0 - error;
        self.prediction_accuracy = (self.prediction_accuracy * 0.9) + (accuracy * 0.1);
        self.prediction_accuracy = self.prediction_accuracy.clamp(0.0, 1.0);
    }
}
impl Default for NetworkCompressionEngine {
fn default() -> Self {
Self::new()
}
}
impl Default for GradientCompressor {
fn default() -> Self {
Self::new()
}
}
impl Default for ModelCompressor {
fn default() -> Self {
Self::new()
}
}
impl Default for DifferentialCompressor {
fn default() -> Self {
Self::new()
}
}
impl Default for TrafficShaper {
fn default() -> Self {
Self::new()
}
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
impl Default for BandwidthAllocator {
fn default() -> Self {
Self::new()
}
}
impl Default for DataUsageTracker {
fn default() -> Self {
Self::new()
}
}
impl Default for UsagePredictor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::super::types::NetworkAdaptationConfig;
    use super::*;

    #[test]
    fn test_bandwidth_optimizer_new() {
        let config = NetworkAdaptationConfig::default();
        let opt = BandwidthOptimizer::new(config).expect("optimizer should be created");
        let stats = opt.get_compression_stats();
        assert_eq!(stats.original_size_bytes, 0);
        assert_eq!(stats.compressed_size_bytes, 0);
    }

    #[test]
    fn test_optimize_transmission_unknown_type_passthrough() {
        let config = NetworkAdaptationConfig::default();
        let mut opt = BandwidthOptimizer::new(config).expect("optimizer should be created");
        let data = vec![1u8, 2, 3, 4, 5];
        let result = opt.optimize_transmission(data.clone(), "unknown").expect("should succeed");
        // Unknown types skip compression, and 5 bytes are well within the rate limit,
        // so the payload must come back untouched (the old test asserted nothing).
        assert_eq!(result, data);
    }

    #[test]
    fn test_bandwidth_optimizer_start_stop() {
        let config = NetworkAdaptationConfig::default();
        let mut opt = BandwidthOptimizer::new(config).expect("optimizer should be created");
        opt.start().expect("start should succeed");
        opt.stop().expect("stop should succeed");
    }

    #[test]
    fn test_bandwidth_utilization_initial() {
        let config = NetworkAdaptationConfig::default();
        let opt = BandwidthOptimizer::new(config).expect("optimizer should be created");
        let utilization = opt.get_bandwidth_utilization();
        assert_eq!(utilization, 0.0);
    }

    #[test]
    fn test_recommended_compression_level_range() {
        let config = NetworkAdaptationConfig::default();
        let opt = BandwidthOptimizer::new(config).expect("optimizer should be created");
        let level = opt.get_recommended_compression_level();
        assert!(level >= 0.0);
    }

    #[test]
    fn test_compression_engine_gradient_compress_nonempty() {
        let mut engine = NetworkCompressionEngine::new();
        let data: Vec<u8> = (0u8..=20).collect();
        let result = engine.compress_gradient(data.clone()).expect("should succeed");
        assert!(!result.is_empty());
    }

    #[test]
    fn test_compression_engine_model_compress() {
        let mut engine = NetworkCompressionEngine::new();
        let data: Vec<u8> = (0u8..=30).collect();
        let result = engine.compress_model(data.clone()).expect("should succeed");
        assert!(!result.is_empty());
    }

    #[test]
    fn test_compression_engine_differential_compress() {
        let mut engine = NetworkCompressionEngine::new();
        let data: Vec<u8> = (0u8..=15).collect();
        let result = engine.compress_differential(data.clone()).expect("should succeed");
        assert!(!result.is_empty());
    }

    #[test]
    fn test_compression_engine_stats_update_after_compress() {
        let mut engine = NetworkCompressionEngine::new();
        let data: Vec<u8> = (0u8..=99).collect();
        let original_len = data.len();
        engine.compress_gradient(data).expect("should succeed");
        assert_eq!(engine.compression_stats.original_size_bytes, original_len);
        assert!(engine.compression_stats.compressed_size_bytes > 0);
    }

    #[test]
    fn test_compression_efficiency_range() {
        let mut engine = NetworkCompressionEngine::new();
        let data: Vec<u8> = (0u8..=50).collect();
        engine.compress_gradient(data).expect("should succeed");
        let eff = engine.get_compression_efficiency();
        assert!(eff >= 0.0);
    }

    #[test]
    fn test_gradient_compressor_compress_empty() {
        let mut c = GradientCompressor::new();
        let result = c.compress(vec![]).expect("empty compress should succeed");
        assert!(result.is_empty());
    }

    #[test]
    fn test_gradient_compressor_update_parameters_clamps() {
        let mut c = GradientCompressor::new();
        c.update_parameters(1.5, 2.0);
        assert!((0.1..=0.9).contains(&c.compression_ratio));
        assert!((0.5..=1.0).contains(&c.quality_threshold));
    }

    #[test]
    fn test_rate_limiter_allows_small_data() {
        let mut rl = RateLimiter::new();
        let allowed = rl.check_rate_limit(100);
        assert!(allowed);
    }

    #[test]
    fn test_rate_limiter_blocks_massive_data() {
        let mut rl = RateLimiter::new();
        // 10 MB fits within the 10 MB target + 5 MB burst; a second 10 MB does not.
        rl.check_rate_limit(10 * 1024 * 1024);
        let blocked = rl.check_rate_limit(10 * 1024 * 1024);
        assert!(!blocked);
    }

    #[test]
    fn test_rate_limiter_reset_clears_counter() {
        let mut rl = RateLimiter::new();
        rl.check_rate_limit(9 * 1024 * 1024);
        rl.reset_counters();
        assert_eq!(rl.current_rate_mbps, 0.0);
    }

    #[test]
    fn test_rate_limiter_update_target_minimum() {
        let mut rl = RateLimiter::new();
        rl.update_target_rate(-5.0);
        assert!(rl.target_rate_mbps >= 0.1);
    }

    #[test]
    fn test_bandwidth_allocator_allocate_and_available() {
        let mut alloc = BandwidthAllocator::new();
        alloc.allocate_bandwidth("gradient", 2.0);
        let avail = alloc.get_available_bandwidth();
        assert!((avail - 8.0).abs() < 0.01);
    }

    #[test]
    fn test_bandwidth_allocator_reset() {
        let mut alloc = BandwidthAllocator::new();
        alloc.allocate_bandwidth("model", 5.0);
        alloc.reset_allocations();
        assert!((alloc.get_available_bandwidth() - 10.0).abs() < 0.01);
    }

    #[test]
    fn test_usage_tracker_tracks_bytes() {
        let mut tracker = DataUsageTracker::new();
        tracker.track_usage("gradient", 1024);
        tracker.track_usage("gradient", 512);
        let usage = tracker.get_current_usage();
        assert_eq!(usage.get("gradient").copied().unwrap_or(0), 1536);
    }

    #[test]
    fn test_usage_tracker_multiple_types() {
        let mut tracker = DataUsageTracker::new();
        tracker.track_usage("gradient", 100);
        tracker.track_usage("model", 200);
        let usage = tracker.get_current_usage();
        assert_eq!(usage.get("gradient").copied().unwrap_or(0), 100);
        assert_eq!(usage.get("model").copied().unwrap_or(0), 200);
    }

    #[test]
    fn test_error_feedback_buffer_default() {
        let buf = ErrorFeedbackBuffer::default();
        assert_eq!(buf.error_count, 0);
        assert!(buf.last_error_timestamp.is_none());
        assert_eq!(buf.error_rate, 0.0);
    }

    #[test]
    fn test_compression_cache_roundtrip() {
        let mut cache = CompressionCache::new();
        cache.insert("k".to_string(), vec![1, 2, 3]);
        assert_eq!(cache.get("k"), Some(vec![1, 2, 3]));
        assert_eq!(cache.get("missing"), None);
    }

    #[test]
    fn test_traffic_shaper_defers_and_releases() {
        let mib = 1024 * 1024;
        let mut shaper = TrafficShaper::new();
        let first = shaper.shape_traffic(vec![0u8; 12 * mib], "model").expect("should succeed");
        assert_eq!(first.len(), 12 * mib);
        // 12 + 4 MB exceeds the default 10 MB target plus 5 MB burst, so this is queued.
        let deferred = shaper.shape_traffic(vec![1u8; 4 * mib], "model").expect("should succeed");
        assert!(deferred.is_empty());
        assert!(shaper.process_queue().expect("should succeed").is_empty());
        shaper.rate_limiter.reset_counters();
        let released = shaper.process_queue().expect("should succeed");
        assert_eq!(released, vec![vec![1u8; 4 * mib]]);
    }

    #[test]
    fn test_differential_without_baseline_is_identity() {
        let mut c = DifferentialCompressor::new();
        let data: Vec<u8> = vec![9, 8, 7];
        assert_eq!(c.compress(data.clone()).expect("should succeed"), data);
        // The second call is served from the cache and must match.
        assert_eq!(c.compress(data.clone()).expect("should succeed"), data);
    }

    #[test]
    fn test_usage_predictor_empty_history() {
        let p = UsagePredictor::new();
        let prediction = p.predict(5, &VecDeque::new()).expect("should succeed");
        assert_eq!(prediction, 0.0);
    }
}