1use std::time::Duration;
2
3use rust_decimal::{Decimal, MathematicalOps, dec};
4
5use crate::{FileSizeFormat, format_parts};
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
11pub struct DownloadAcceleration {
12 bytes_per_second_sq: Decimal,
13}
14
15impl DownloadAcceleration {
16 #[inline]
32 pub fn from_raw(bytes_per_second_sq: i64) -> Self {
33 Self {
34 bytes_per_second_sq: Decimal::from(bytes_per_second_sq),
35 }
36 }
37
38 pub fn new(initial_speed: u64, final_speed: u64, duration: Duration) -> Self {
52 let seconds = Decimal::from(duration.as_secs())
53 + Decimal::from(duration.subsec_nanos()) / Decimal::from(1_000_000_000);
54 let speed_diff = Decimal::from(final_speed) - Decimal::from(initial_speed);
55 let bytes_per_second_sq = if seconds.is_zero() {
56 Decimal::ZERO
57 } else {
58 speed_diff / seconds
59 };
60
61 Self {
62 bytes_per_second_sq,
63 }
64 }
65
66 #[inline]
70 pub fn as_decimal(&self) -> Decimal {
71 self.bytes_per_second_sq
72 }
73
74 #[inline]
78 pub fn as_i64(&self) -> i64 {
79 self.bytes_per_second_sq.floor().try_into().unwrap_or(0)
80 }
81
82 fn format_parts(
83 &self,
84 base: Decimal,
85 units: &'static [&'static str],
86 ) -> (String, &'static str) {
87 let mut value = Decimal::from(self.bytes_per_second_sq);
88 let is_negative = value.is_sign_negative();
89
90 if is_negative {
91 value.set_sign_positive(true);
92 }
93
94 let (formatted_value, unit) = format_parts(value, base, units);
95 let formatted_value = if is_negative {
96 format!("-{}", formatted_value)
97 } else {
98 formatted_value
99 };
100
101 (formatted_value, unit)
102 }
103
104 #[inline]
108 fn linear_prediction(current_speed: Decimal, remaining_bytes: Decimal) -> Option<Decimal> {
109 (current_speed > Decimal::ZERO).then(|| remaining_bytes / current_speed)
110 }
111
112 pub fn predict_eta(&self, current_speed: Decimal, remaining_bytes: u64) -> Option<Decimal> {
143 if !current_speed.is_sign_positive() {
145 return None;
146 }
147
148 let remaining_bytes_decimal = Decimal::from(remaining_bytes);
149 let accel = self.bytes_per_second_sq;
150
151 const ACCEL_THRESHOLD: Decimal = dec!(0.1);
153
154 if accel.abs() < ACCEL_THRESHOLD {
156 return Self::linear_prediction(current_speed, remaining_bytes_decimal);
157 }
158
159 self.solve_quadratic_eta(current_speed, remaining_bytes_decimal, accel)
161 }
162
163 fn solve_quadratic_eta(
167 &self,
168 current_speed: Decimal,
169 remaining_bytes: Decimal,
170 accel: Decimal,
171 ) -> Option<Decimal> {
172 const HALF: Decimal = dec!(0.5);
173 const TWO: Decimal = dec!(2);
174
175 let a = HALF * accel;
177 let b = current_speed;
178 let c = -remaining_bytes;
179
180 let discriminant = self.calculate_discriminant(a, b, c)?;
182
183 if discriminant < Decimal::ZERO {
185 return Self::linear_prediction(current_speed, remaining_bytes);
186 }
187
188 let sqrt_discriminant = discriminant.sqrt()?;
190 let two_a = TWO * a;
191
192 let t1 = (-b + sqrt_discriminant) / two_a;
194 let t2 = (-b - sqrt_discriminant) / two_a;
195
196 match (t1.is_sign_positive(), t2.is_sign_positive()) {
198 (true, true) => Some(t1.min(t2)),
199 (true, false) => Some(t1),
200 (false, true) => Some(t2),
201 (false, false) => Self::linear_prediction(current_speed, remaining_bytes),
202 }
203 }
204
205 #[inline]
209 fn calculate_discriminant(&self, a: Decimal, b: Decimal, c: Decimal) -> Option<Decimal> {
210 let b_squared = b.checked_mul(b)?;
211 let four_ac = dec!(4).checked_mul(a)?.checked_mul(c)?;
212 b_squared.checked_sub(four_ac)
213 }
214}
215
216impl FileSizeFormat for DownloadAcceleration {
217 fn get_si_parts(&self) -> (String, &'static str) {
221 const UNITS: &[&str] = &[
222 "B/s²", "KB/s²", "MB/s²", "GB/s²", "TB/s²", "PB/s²", "EB/s²", "ZB/s²", "YB/s²",
223 ];
224 self.format_parts(Decimal::from(1000), UNITS)
225 }
226
227 fn get_iec_parts(&self) -> (String, &'static str) {
231 const UNITS: &[&str] = &[
232 "B/s²", "KiB/s²", "MiB/s²", "GiB/s²", "TiB/s²", "PiB/s²", "EiB/s²", "ZiB/s²", "YiB/s²",
233 ];
234 self.format_parts(Decimal::from(1024), UNITS)
235 }
236}
237
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rust_decimal::Decimal;

    use super::DownloadAcceleration;
    use crate::{FormattedValue, SizeStandard};

    /// Renders a whole-number acceleration with SI units.
    fn format_test_si(accel: i64) -> String {
        FormattedValue::new(DownloadAcceleration::from_raw(accel), SizeStandard::SI).to_string()
    }

    /// Renders a whole-number acceleration with IEC units.
    fn format_test_iec(accel: i64) -> String {
        FormattedValue::new(DownloadAcceleration::from_raw(accel), SizeStandard::IEC).to_string()
    }

    #[test]
    fn test_predict_eta() {
        // Speeding up: must finish sooner than the constant-speed 10 s.
        let acc = DownloadAcceleration::from_raw(1000);
        let eta = acc
            .predict_eta(Decimal::new(100, 0), 1000)
            .expect("Should have a valid ETA");
        assert!(eta < Decimal::new(10, 0));
        assert!(eta > Decimal::ZERO);

        // Slowing down: cannot finish sooner than the constant-speed 5 s.
        let acc = DownloadAcceleration::from_raw(-50);
        let eta = acc
            .predict_eta(Decimal::new(200, 0), 1000)
            .expect("Should have a valid ETA");
        assert!(eta >= Decimal::new(5, 0));

        // No acceleration: plain remaining / speed.
        let acc = DownloadAcceleration::from_raw(0);
        let eta = acc
            .predict_eta(Decimal::new(100, 0), 1000)
            .expect("Should have a valid ETA");
        assert_eq!(eta, Decimal::new(10, 0));

        // Zero speed with no acceleration never completes.
        let eta = acc.predict_eta(Decimal::ZERO, 1000);
        assert!(eta.is_none());

        // A negative speed is rejected outright.
        let eta = acc.predict_eta(Decimal::new(-100, 0), 1000);
        assert!(eta.is_none());
    }

    #[test]
    fn test_predict_eta_edge_cases() {
        // Very large acceleration.
        let acc = DownloadAcceleration::from_raw(1_000_000);
        let eta = acc
            .predict_eta(Decimal::new(1000, 0), 1_000_000)
            .expect("Should have a valid ETA");
        assert!(eta > Decimal::ZERO);
        assert!(eta < Decimal::new(1000, 0));

        // Very small remaining size.
        let acc = DownloadAcceleration::from_raw(100);
        let eta = acc
            .predict_eta(Decimal::new(1000, 0), 1)
            .expect("Should have a valid ETA");
        assert!(eta > Decimal::ZERO);
        assert!(eta < Decimal::new(1, 0));

        // Deceleration strong enough that the quadratic has no real root.
        let acc = DownloadAcceleration::from_raw(-100);
        let eta = acc
            .predict_eta(Decimal::new(500, 0), 2000)
            .expect("Should have a valid ETA");
        assert!(eta > Decimal::ZERO);

        // Baseline: zero acceleration uses the linear estimate.
        let acc = DownloadAcceleration::from_raw(0);
        let eta_linear = acc
            .predict_eta(Decimal::new(100, 0), 1000)
            .expect("Should have a valid ETA");

        // A non-zero acceleration below the 0.1 threshold (1 B/s gained over
        // 20 s = 0.05 B/s²) must take the same linear path.
        let acc_threshold = DownloadAcceleration::new(100, 101, Duration::from_secs(20));
        assert!(!acc_threshold.as_decimal().is_zero());
        assert!(acc_threshold.as_decimal().abs() < Decimal::new(1, 1));
        let eta_threshold = acc_threshold
            .predict_eta(Decimal::new(100, 0), 1000)
            .expect("Should have a valid ETA");

        assert_eq!(eta_linear, eta_threshold);
    }

    #[test]
    fn test_si_acceleration() {
        assert_eq!(format_test_si(512), "512.0 B/s²");
        assert_eq!(format_test_si(1000), "1.00 KB/s²");
        assert_eq!(format_test_si(1024), "1.02 KB/s²");

        assert_eq!(format_test_si(-1500), "-1.50 KB/s²");
    }

    #[test]
    fn test_iec_acceleration() {
        assert_eq!(format_test_iec(1024), "1.00 KiB/s²");
        assert_eq!(format_test_iec(1500), "1.46 KiB/s²");

        assert_eq!(format_test_iec(-2048), "-2.00 KiB/s²");
    }
}