1use crate::common;
13use crate::error::Error;
14use crate::progress::Tracker;
15use crate::terminal;
16use crate::test_config::TestConfig;
17use owo_colors::OwoColorize;
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::sync::atomic::{AtomicU64, Ordering};
21use std::time::Instant;
22
/// Sampling interval, in milliseconds, between two throughput samples taken by
/// `LoopState::record_bytes` when callers pass this constant.
///
/// Note: `run_concurrent_streams` reads its interval from `TestConfig::default()`
/// rather than from this constant.
pub const SAMPLE_INTERVAL_MS: u64 = 50;
25
26pub struct LoopState {
30 pub total_bytes: Arc<AtomicU64>,
31 pub peak_bps: Arc<AtomicU64>,
32 pub speed_samples: Arc<Mutex<Vec<f64>>>,
33 pub start: Instant,
34 pub last_sample_ms: Arc<AtomicU64>,
35 pub estimated_total: u64,
36 pub progress: Arc<Tracker>,
37}
38
/// Final measurements of one bandwidth test, as produced by `LoopState::finish`.
///
/// Throughput values are whatever `common::calculate_bandwidth` returns. The
/// `_bps` naming suggests bits per second.
/// NOTE(review): confirm the units against `calculate_bandwidth`.
///
/// Derives `Clone` and `PartialEq` so callers can keep copies of a result and
/// compare results (for example in tests or when aggregating runs).
#[derive(Debug, Clone, PartialEq)]
pub struct BandwidthResult {
    /// Average throughput over the whole test duration.
    pub avg_bps: f64,
    /// Highest throughput seen in any single sample.
    pub peak_bps: f64,
    /// Total bytes transferred by all streams.
    pub total_bytes: u64,
    /// Duration of the test in seconds.
    pub duration_secs: f64,
    /// Individual throughput samples, in the order they were recorded.
    pub speed_samples: Vec<f64>,
}
48
49impl LoopState {
50 #[must_use]
52 pub fn new(estimated_total: u64, progress: Arc<Tracker>) -> Self {
53 Self {
54 total_bytes: Arc::new(AtomicU64::new(0)),
55 peak_bps: Arc::new(AtomicU64::new(0)),
56 speed_samples: Arc::new(Mutex::new(Vec::new())),
57 start: Instant::now(),
58 last_sample_ms: Arc::new(AtomicU64::new(0)),
59 estimated_total,
60 progress,
61 }
62 }
63
64 pub fn record_bytes(&self, len: u64, sample_interval_ms: u64) {
71 self.total_bytes.fetch_add(len, Ordering::Release);
73
74 let elapsed_ms = u64::try_from(self.start.elapsed().as_millis()).unwrap_or(u64::MAX);
75 let last_ms = self.last_sample_ms.load(Ordering::Relaxed);
76 let should_sample =
77 last_ms == 0 || elapsed_ms.saturating_sub(last_ms) >= sample_interval_ms;
78
79 if should_sample {
80 self.last_sample_ms.store(elapsed_ms, Ordering::Relaxed);
81 self.sample_now();
82 }
83 }
84
85 fn sample_now(&self) {
87 let total = self.total_bytes.load(Ordering::Acquire);
88 let elapsed = self.start.elapsed().as_secs_f64();
89 let speed = common::calculate_bandwidth(total, elapsed);
90
91 let current_peak = self.peak_bps.load(Ordering::Relaxed) as f64;
93 if speed > current_peak {
94 let peak_u64 = speed.clamp(0.0, u64::MAX as f64) as u64;
95 self.peak_bps.store(peak_u64, Ordering::Release);
97 }
98
99 if let Ok(mut samples) = self.speed_samples.lock() {
100 samples.push(speed);
101 }
102
103 let pct = (total as f64 / self.estimated_total as f64).min(1.0);
106 self.progress.update(speed / 1_000_000.0, pct, total);
107 }
108
109 #[must_use]
111 pub fn finish(&self) -> BandwidthResult {
112 let total = self.total_bytes.load(Ordering::Acquire);
114 let peak = self.peak_bps.load(Ordering::Acquire) as f64;
116 let duration = self.start.elapsed().as_secs_f64();
117 let samples = self
119 .speed_samples
120 .lock()
121 .map(|g| g.to_vec())
122 .unwrap_or_default();
123 let avg = common::calculate_bandwidth(total, duration);
124
125 BandwidthResult {
126 avg_bps: avg,
127 peak_bps: peak,
128 total_bytes: total,
129 duration_secs: duration,
130 speed_samples: samples,
131 }
132 }
133}
134
135#[must_use = "the BandwidthResult should be used to report test outcomes"]
159pub async fn run_concurrent_streams(
160 estimated_total: u64,
161 stream_count: usize,
162 progress: Arc<Tracker>,
163 label: &str,
164 mut spawn_fn: impl FnMut(usize, Arc<LoopState>, u64) -> tokio::task::JoinHandle<Result<(), Error>>,
165) -> Result<BandwidthResult, Error> {
166 let config = TestConfig::default();
167 let sample_interval_ms = config.sample_interval_ms;
168 let state = Arc::new(LoopState::new(estimated_total, progress));
169
170 let mut handles = Vec::with_capacity(stream_count);
171 for i in 0..stream_count {
172 handles.push(spawn_fn(i, Arc::clone(&state), sample_interval_ms));
173 }
174
175 let mut any_succeeded = false;
177 let mut first_error: Option<Error> = None;
178 for (i, handle) in handles.into_iter().enumerate() {
179 match handle.await {
180 Ok(Ok(())) => any_succeeded = true,
181 Ok(Err(err)) => {
182 let msg = format!("Warning: {label} stream {i} failed: {err}");
183 if terminal::no_color() {
184 eprintln!("\n{msg}");
185 } else {
186 eprintln!("\n{}", msg.yellow().bold());
187 }
188 if first_error.is_none() {
189 first_error = Some(err);
190 }
191 }
192 Err(e) => {
193 let msg = format!("Warning: {label} stream {i} failed: {e}");
194 if terminal::no_color() {
195 eprintln!("\n{msg}");
196 } else {
197 eprintln!("\n{}", msg.yellow().bold());
198 }
199 if first_error.is_none() {
200 first_error = Some(Error::context(format!("{label} stream {i} panicked: {e}")));
201 }
202 }
203 }
204 }
205
206 if !any_succeeded {
207 return Err(
208 first_error.unwrap_or_else(|| Error::context(format!("all {label} streams failed")))
209 );
210 }
211
212 let result = state.finish();
213 if result.total_bytes == 0 {
214 return Err(first_error.unwrap_or_else(|| match label {
215 "download" => {
216 Error::DownloadFailure("test completed without transferring data".to_string())
217 }
218 "upload" => {
219 Error::UploadFailure("test completed without transferring data".to_string())
220 }
221 _ => Error::context(format!("{label} test completed without transferring data")),
222 }));
223 }
224
225 Ok(result)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;
    use std::thread;
    use std::time::Duration;

    /// Builds a progress tracker labelled "test" for the cases below.
    fn make_tracker() -> Arc<Tracker> {
        Arc::new(Tracker::new("test"))
    }

    /// A fresh state has zero bytes, zero peak and no samples, and keeps the
    /// estimated total it was given.
    #[test]
    fn test_loop_state_new_fields() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 0);
        assert_eq!(state.peak_bps.load(Ordering::SeqCst), 0);
        assert_eq!(state.estimated_total, 100_000_000);
        assert!(state.speed_samples.lock().unwrap().is_empty());
    }

    /// 4 threads x 1000 calls x 100 bytes = 400_000. No bytes may be lost when
    /// several threads call `record_bytes` concurrently.
    #[test]
    fn test_loop_state_concurrent_atomic_updates() {
        let tracker = make_tracker();
        let state = Arc::new(LoopState::new(100_000_000, tracker));

        let handles: Vec<_> = (0..4)
            .map(|_| {
                let s = Arc::clone(&state);
                thread::spawn(move || {
                    for _ in 0..1000 {
                        s.record_bytes(100, SAMPLE_INTERVAL_MS);
                    }
                })
            })
            .collect();

        for h in handles {
            h.join().unwrap();
        }

        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 400_000);
    }

    /// Recording zero bytes leaves the total unchanged.
    #[test]
    fn test_record_bytes_zero_value() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        state.record_bytes(0, SAMPLE_INTERVAL_MS);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 0);
    }

    /// Successive calls add up: 1000 + 2000 + 3000 = 6000.
    #[test]
    fn test_record_bytes_accumulates() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        state.record_bytes(1000, SAMPLE_INTERVAL_MS);
        state.record_bytes(2000, SAMPLE_INTERVAL_MS);
        state.record_bytes(3000, SAMPLE_INTERVAL_MS);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 6000);
    }

    /// A large byte count against a `u64::MAX` estimate is stored exactly.
    #[test]
    fn test_record_bytes_large_values() {
        let tracker = make_tracker();
        let state = LoopState::new(u64::MAX, tracker);
        state.record_bytes(1_000_000_000, SAMPLE_INTERVAL_MS);
        assert_eq!(state.total_bytes.load(Ordering::SeqCst), 1_000_000_000);
    }

    /// The first call always takes a sample. A call after sleeping past the
    /// interval takes another one.
    #[test]
    fn test_record_bytes_throttle_mechanism() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        let interval_ms = 50u64;

        // No sample has been taken yet, so this call samples.
        state.record_bytes(1000, interval_ms);
        assert_eq!(state.speed_samples.lock().unwrap().len(), 1);

        // Lands well inside the interval; its sample count is not asserted.
        state.record_bytes(1000, interval_ms);

        // Sleep past the interval so the next call is due for a sample.
        thread::sleep(Duration::from_millis(100));
        state.record_bytes(1000, interval_ms);

        let samples = state.speed_samples.lock().unwrap();
        // ">= 2" rather than "== 2" keeps the assertion independent of whether
        // the immediate second call was sampled.
        assert!(
            samples.len() >= 2,
            "Expected at least 2 samples, got {}",
            samples.len()
        );
    }

    /// With a 5 ms interval and 10 ms sleeps, most calls are due for a sample.
    #[test]
    fn test_record_bytes_short_interval_samples_more() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        for _ in 0..3 {
            // Each iteration sleeps longer than the 5 ms interval.
            state.record_bytes(1_000_000, 5);
            thread::sleep(Duration::from_millis(10));
        }

        let samples = state.speed_samples.lock().unwrap();
        assert!(
            // Lower bound only, to tolerate scheduler jitter.
            samples.len() >= 2,
            "Expected >= 2 samples with short interval, got {}",
            samples.len()
        );
    }

    /// Two samples with real traffic must leave a non-zero peak.
    #[test]
    fn test_record_bytes_updates_peak() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        state.record_bytes(10_000_000, SAMPLE_INTERVAL_MS);
        thread::sleep(Duration::from_millis(60));
        state.record_bytes(10_000_000, SAMPLE_INTERVAL_MS);

        let peak = state.peak_bps.load(Ordering::SeqCst);
        assert!(peak > 0);
    }

    /// Finishing without any traffic gives zeros, no samples and a positive
    /// duration (the 10 ms sleep guarantees elapsed time > 0).
    #[test]
    fn test_finish_empty_state() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        thread::sleep(Duration::from_millis(10));
        let result = state.finish();

        assert_eq!(result.total_bytes, 0);
        assert_eq!(result.avg_bps, 0.0);
        assert_eq!(result.peak_bps, 0.0);
        assert!(result.duration_secs > 0.0);
        assert!(result.speed_samples.is_empty());
    }

    /// The reported total matches what was recorded, and the average is positive.
    #[test]
    fn test_finish_with_transfer() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        state.record_bytes(20_000_000, SAMPLE_INTERVAL_MS);
        thread::sleep(Duration::from_millis(100));

        let result = state.finish();
        assert_eq!(result.total_bytes, 20_000_000);
        assert!(result.avg_bps > 0.0);
    }

    /// `finish` runs after the last sample over a longer duration with the same
    /// total. Its average therefore cannot exceed the peak sample.
    #[test]
    fn test_finish_peak_gte_avg() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);

        for _ in 0..5 {
            state.record_bytes(5_000_000, SAMPLE_INTERVAL_MS);
            thread::sleep(Duration::from_millis(60));
        }

        let result = state.finish();
        assert!(result.peak_bps >= result.avg_bps);
    }

    /// The byte total is independent of the estimate. Estimates smaller than
    /// the transfer only cap the progress fraction at 1.0.
    #[test]
    fn test_finish_various_estimated_totals() {
        for estimated in [1u64, 1000, 1_000_000, u64::MAX / 2] {
            let tracker = make_tracker();
            let state = LoopState::new(estimated, tracker);
            state.record_bytes(100, SAMPLE_INTERVAL_MS);
            thread::sleep(Duration::from_millis(10));
            let result = state.finish();
            assert_eq!(result.total_bytes, 100);
        }
    }

    /// `finish` hands back the recorded samples, and none is negative.
    #[test]
    fn test_finish_returns_speed_samples() {
        let tracker = make_tracker();
        let state = LoopState::new(10_000_000, tracker);

        for _ in 0..3 {
            state.record_bytes(1_000_000, 10);
            thread::sleep(Duration::from_millis(20));
        }

        let result = state.finish();
        assert!(!result.speed_samples.is_empty());
        for sample in &result.speed_samples {
            assert!(*sample >= 0.0);
        }
    }

    /// Pins the public default interval.
    #[test]
    fn test_sample_interval_constant() {
        assert_eq!(SAMPLE_INTERVAL_MS, 50);
    }

    /// Sanity check on all `BandwidthResult` fields after a single transfer.
    #[test]
    fn test_bandwidth_result_struct() {
        let tracker = make_tracker();
        let state = LoopState::new(100_000_000, tracker);
        state.record_bytes(50_000_000, SAMPLE_INTERVAL_MS);
        thread::sleep(Duration::from_millis(100));

        let result = state.finish();

        assert!(result.avg_bps >= 0.0);
        assert!(result.peak_bps >= 0.0);
        assert!(result.total_bytes > 0);
        assert!(result.duration_secs > 0.0);
    }

    /// With zero streams nothing can succeed, so the call must return an error.
    #[tokio::test]
    async fn test_run_concurrent_streams_zero_streams() {
        let tracker = make_tracker();
        let result = run_concurrent_streams(100_000_000, 0, tracker, "test", |_, _, _| {
            tokio::spawn(async { Ok(()) })
        })
        .await;
        assert!(result.is_err());
    }

    /// One successful stream: its bytes are the whole result.
    #[tokio::test]
    async fn test_run_concurrent_streams_single_stream_success() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 1, tracker, "download", |_, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    s.record_bytes(10_000_000, interval);
                    Ok(())
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 10_000_000);
    }

    /// Four streams share one state, so their byte counts add up (4 x 1_000_000).
    #[tokio::test]
    async fn test_run_concurrent_streams_four_streams() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 4, tracker, "upload", |_, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    s.record_bytes(1_000_000, interval);
                    Ok(())
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 4_000_000);
    }

    /// When every stream returns an error, the call fails.
    #[tokio::test]
    async fn test_run_concurrent_streams_all_fail() {
        let tracker = make_tracker();
        let result = run_concurrent_streams(100_000_000, 3, tracker, "download", |_, _, _| {
            tokio::spawn(async { Err(Error::DownloadFailure("failed".into())) })
        })
        .await;

        assert!(result.is_err());
    }

    /// Two of four streams fail. The survivors' bytes are still reported as success.
    #[tokio::test]
    async fn test_run_concurrent_streams_partial_failure() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 4, tracker, "upload", |i, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    if i < 2 {
                        s.record_bytes(1_000_000, interval);
                        Ok(())
                    } else {
                        Err(Error::UploadFailure("failed".into()))
                    }
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 2_000_000);
    }

    /// A panicking task surfaces as a join error. It does not abort the run,
    /// because stream 0 succeeded with data.
    #[tokio::test]
    async fn test_run_concurrent_streams_stream_panic() {
        let tracker = make_tracker();
        let result =
            run_concurrent_streams(100_000_000, 2, tracker, "download", |i, state, interval| {
                let s = Arc::clone(&state);
                tokio::spawn(async move {
                    if i == 0 {
                        s.record_bytes(1_000_000, interval);
                        Ok(())
                    } else {
                        panic!("stream panicked");
                    }
                })
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap().total_bytes, 1_000_000);
    }

    /// Streams that "succeed" without recording any bytes still make the call fail.
    #[tokio::test]
    async fn test_run_concurrent_streams_zero_bytes_returns_error() {
        let tracker = make_tracker();
        let result = run_concurrent_streams(100_000_000, 2, tracker, "download", |_, _, _| {
            tokio::spawn(async { Ok(()) })
        })
        .await;

        assert!(result.is_err());
    }

    /// With zero streams, the fallback error message embeds the label. This
    /// assumes `Error`'s Debug output includes that message.
    #[tokio::test]
    async fn test_run_concurrent_streams_label_different_errors() {
        for label in ["download", "upload", "custom"] {
            let tracker = make_tracker();
            let result = run_concurrent_streams(100_000_000, 0, tracker, label, |_, _, _| {
                tokio::spawn(async { Ok(()) })
            })
            .await;

            assert!(result.is_err());
            let err_str = format!("{:?}", result.unwrap_err());
            assert!(err_str.contains(label));
        }
    }

    /// Success does not depend on the estimated total, even when the estimate
    /// is at or below the bytes actually transferred.
    #[tokio::test]
    async fn test_run_concurrent_streams_estimated_total_param() {
        for estimated in [1_000u64, 10_000_000, 1_000_000_000] {
            let tracker = make_tracker();
            let result =
                run_concurrent_streams(estimated, 1, tracker, "test", |_, state, interval| {
                    let s = Arc::clone(&state);
                    tokio::spawn(async move {
                        s.record_bytes(1000, interval);
                        Ok(())
                    })
                })
                .await;
            assert!(result.is_ok());
        }
    }
}