1#![cfg_attr(not(feature = "std"), no_std)]
21
/// Assert that `$x` lies within `$error` of `$y`, i.e. `$y - $error <= $x <= $y + $error`.
///
/// Both bounds are inclusive. On failure the three values are printed with `{:?}`.
///
/// Note: each argument expression is pasted into the expansion several times, so it may be
/// evaluated more than once; avoid passing expressions with side effects. The bounds are computed
/// with plain `-` and `+`, so for primitive unsigned integers `$y - $error` underflows (panicking
/// when overflow checks are enabled) if `$error > $y`.
#[macro_export]
macro_rules! assert_eq_error_rate {
    ($x:expr, $y:expr, $error:expr $(,)?) => {
        assert!(
            ($x) >= (($y) - ($error)) && ($x) <= (($y) + ($error)),
            "{:?} != {:?} (with error rate {:?})",
            $x,
            $y,
            $error,
        );
    };
}
35
36pub mod biguint;
37pub mod helpers_128bit;
38pub mod traits;
39pub mod per_things;
40pub mod fixed_point;
41pub mod rational;
42
43pub use fixed_point::{FixedPointNumber, FixedPointOperand, FixedI64, FixedI128, FixedU128};
44pub use per_things::{PerThing, InnerOf, UpperOf, Percent, PerU16, Permill, Perbill, Perquintill};
45pub use rational::{Rational128, RationalInfinite};
46
47use tetcore_std::{prelude::*, cmp::Ordering, fmt::Debug, convert::TryInto};
48use traits::{BaseArithmetic, One, Zero, SaturatedConversion, Unsigned};
49
/// An ordering in which values that are close enough to each other compare as equal.
pub trait ThresholdOrd<T> {
    /// Compare `self` with `other`, returning [`Ordering::Equal`] when `self` is within
    /// `epsilon` of `other`, and the usual `Greater`/`Less` otherwise.
    ///
    /// The blanket implementation in this module uses the inclusive band
    /// `[other - epsilon, other + epsilon]`, computed with saturating arithmetic. It falls back
    /// to a plain `cmp` when `epsilon` is zero.
    fn tcmp(&self, other: &T, epsilon: T) -> Ordering;
}
60
61impl<T> ThresholdOrd<T> for T
62where
63 T: Ord + PartialOrd + Copy + Clone + traits::Zero + traits::Saturating,
64{
65 fn tcmp(&self, other: &T, threshold: T) -> Ordering {
66 if threshold.is_zero() {
68 return self.cmp(&other)
69 }
70
71 let upper_bound = other.saturating_add(threshold);
72 let lower_bound = other.saturating_sub(threshold);
73
74 if upper_bound <= lower_bound {
75 self.cmp(&other)
77 } else {
78 match (self.cmp(&lower_bound), self.cmp(&upper_bound)) {
80 (Ordering::Greater, Ordering::Greater) => Ordering::Greater,
81 (Ordering::Less, Ordering::Less) => Ordering::Less,
82 _ => Ordering::Equal,
83 }
84 }
85
86 }
87}
88
/// A collection whose elements can be redistributed to add up to a target value.
pub trait Normalizable<T> {
    /// Return a copy of `self` whose elements are adjusted to sum to `targeted_sum`, where
    /// possible.
    ///
    /// Returns `Err` with a static message when normalization cannot be performed. The
    /// implementations in this module forward to the free function `normalize` and propagate
    /// its errors.
    fn normalize(&self, targeted_sum: T) -> Result<Vec<T>, &'static str>;
}
100
101macro_rules! impl_normalize_for_numeric {
102 ($($numeric:ty),*) => {
103 $(
104 impl Normalizable<$numeric> for Vec<$numeric> {
105 fn normalize(&self, targeted_sum: $numeric) -> Result<Vec<$numeric>, &'static str> {
106 normalize(self.as_ref(), targeted_sum)
107 }
108 }
109 )*
110 };
111}
112
113impl_normalize_for_numeric!(u8, u16, u32, u64, u128);
114
115impl<P: PerThing> Normalizable<P> for Vec<P> {
116 fn normalize(&self, targeted_sum: P) -> Result<Vec<P>, &'static str> {
117 let uppers =
118 self.iter().map(|p| <UpperOf<P>>::from(p.clone().deconstruct())).collect::<Vec<_>>();
119
120 let normalized =
121 normalize(uppers.as_ref(), <UpperOf<P>>::from(targeted_sum.deconstruct()))?;
122
123 Ok(normalized
124 .into_iter()
125 .map(|i: UpperOf<P>| P::from_parts(i.saturated_into::<P::Inner>()))
126 .collect())
127 }
128}
129
/// Normalize `input` so that its elements sum to `targeted_sum`.
///
/// The output has the same length as `input` and is returned in the same order. The adjustment
/// is greedy and best-effort rather than proportional. The values are first stable-sorted
/// ascending, then the difference is split into an equal share per element (`per_round`) plus a
/// remainder (`leftover`):
///
/// - **Sum below target:** the shortfall is added starting at the smallest element. The cursor
///   moves to the next element (wrapping around) once the current one reaches
///   `targeted_sum / input.len()`.
/// - **Sum above target:** the excess is removed starting at the largest element. The cursor
///   moves to the next-smaller element (wrapping around) once the current one drops to the
///   smallest input value or below. Elements never go below zero. Whatever could not be taken
///   from an element is carried over and taken one unit at a time from later elements.
///
/// The result is not necessarily as even as possible. For example, `[13, 13, 13, 10]` with a
/// target of `40` becomes `[12, 9, 9, 10]`.
///
/// An empty `input` yields an empty output regardless of `targeted_sum`. An `input` that already
/// sums to `targeted_sum` is returned unchanged.
///
/// # Errors
///
/// Returns an error string if the sum of `input` overflows `T`, or if `input.len()` cannot be
/// converted into `T`.
///
/// # Panics
///
/// The `expect`s below are not expected to fire; see the comments at each call site. In debug
/// builds a `debug_assert_eq!` checks that the output really sums to `targeted_sum`.
pub fn normalize<T>(input: &[T], targeted_sum: T) -> Result<Vec<T>, &'static str>
    where T: Clone + Copy + Ord + BaseArithmetic + Unsigned + Debug,
{
    // Checked accumulation: reject inputs whose total does not fit in `T`.
    let mut sum = T::zero();
    for t in input.iter() {
        sum = sum.checked_add(t).ok_or("sum of input cannot fit in `T`")?;
    }

    let count = input.len();
    // The element count is needed as a `T` to split the difference evenly.
    let count_t: T = count.try_into().map_err(|_| "length of `inputs` cannot fit in `T`")?;

    if count.is_zero() {
        return Ok(Vec::<T>::new());
    }

    // Absolute difference, computed without underflow on unsigned `T`.
    let diff = targeted_sum.max(sum) - targeted_sum.min(sum);
    if diff.is_zero() {
        return Ok(input.to_vec());
    }

    let needs_bump = targeted_sum > sum;
    // `diff == per_round * count + leftover`: every element gets one `per_round` share first,
    // then `leftover` is handed out one unit at a time.
    let per_round = diff / count_t;
    let mut leftover = diff % count_t;

    // Remember each value's original position so the order can be restored at the end. The
    // stable sort keeps equal values in their original relative order.
    let mut output_with_idx = input.iter().cloned().enumerate().collect::<Vec<(usize, T)>>();
    output_with_idx.sort_by_key(|x| x.1);

    if needs_bump {
        // After the ascending sort, index 0 holds the smallest value.
        let mut min_index = 0;
        // Once an element reaches the average target value, move on to the next one.
        let threshold = targeted_sum / count_t;

        if !per_round.is_zero() {
            for _ in 0..count {
                // Cannot overflow: values only grow, and together they end at exactly
                // `targeted_sum`, so no single value ever exceeds `targeted_sum`.
                output_with_idx[min_index].1 = output_with_idx[min_index].1
                    .checked_add(&per_round)
                    .expect("Proof provided in the module doc; qed.");
                if output_with_idx[min_index].1 >= threshold {
                    min_index += 1;
                    min_index = min_index % count;
                }
            }
        }

        // Hand out the remainder one unit at a time, continuing from the current `min_index`.
        while !leftover.is_zero() {
            output_with_idx[min_index].1 = output_with_idx[min_index].1
                .checked_add(&T::one())
                .expect("Proof provided in the module doc; qed.");
            if output_with_idx[min_index].1 >= threshold {
                min_index += 1;
                min_index = min_index % count;
            }
            leftover -= One::one()
        }
    } else {
        // After the ascending sort, the last index holds the largest value.
        let mut max_index = count - 1;
        // Once an element drops to the smallest input value, move on to the next-largest one.
        let threshold = output_with_idx
            .first()
            .expect("length of input is greater than zero; it must have a first; qed")
            .1;

        if !per_round.is_zero() {
            for _ in 0..count {
                // If the element is smaller than its share, zero it and carry the shortfall into
                // `leftover`. The total removed plus `leftover` therefore stays equal to `diff`.
                output_with_idx[max_index].1 = output_with_idx[max_index].1
                    .checked_sub(&per_round)
                    .unwrap_or_else(|| {
                        let remainder = per_round - output_with_idx[max_index].1;
                        leftover += remainder;
                        output_with_idx[max_index].1.saturating_sub(per_round)
                    });
                if output_with_idx[max_index].1 <= threshold {
                    max_index = max_index.checked_sub(1).unwrap_or(count - 1);
                }
            }
        }

        // Remove the rest one unit at a time, skipping elements that are already zero.
        // Terminates: the current total is `targeted_sum + leftover`, so while `leftover > 0`
        // some element is still non-zero.
        while !leftover.is_zero() {
            if let Some(next) = output_with_idx[max_index].1.checked_sub(&One::one()) {
                output_with_idx[max_index].1 = next;
                if output_with_idx[max_index].1 <= threshold {
                    max_index = max_index.checked_sub(1).unwrap_or(count - 1);
                }
                leftover -= One::one()
            } else {
                max_index = max_index.checked_sub(1).unwrap_or(count - 1);
            }
        }
    }

    debug_assert_eq!(
        output_with_idx.iter().fold(T::zero(), |acc, (_, x)| acc + *x),
        targeted_sum,
        "sum({:?}) != {:?}",
        output_with_idx,
        targeted_sum,
    );

    // Restore the caller's original element order.
    output_with_idx.sort_by_key(|x| x.0);
    Ok(output_with_idx.into_iter().map(|(_, t)| t).collect())
}
274
#[cfg(test)]
mod normalize_tests {
    use super::*;

    /// Check that `normalize(input, target)` yields `expected` for every four-element `u32` case.
    fn check_u32_cases(cases: &[([u32; 4], u32, [u32; 4])]) {
        for &(input, target, expected) in cases {
            let actual = normalize(&input[..], target).unwrap();
            assert_eq!(actual, expected.to_vec(), "normalize({:?}, {})", input, target);
        }
    }

    #[test]
    fn work_for_all_types() {
        macro_rules! check_type {
            ($($ty:ty),+) => {
                $(
                    let input: Vec<$ty> = vec![8, 9, 7, 10];
                    let expected: Vec<$ty> = vec![10; 4];
                    assert_eq!(normalize(&input[..], 40).unwrap(), expected);
                )+
            };
        }
        check_type!(u128, u64, u32, u16, u8);
    }

    #[test]
    fn fails_on_if_input_sum_large() {
        let fits = vec![1u8; 255];
        let overflows = vec![1u8; 256];

        assert!(normalize(&fits[..], 10).is_ok());
        assert_eq!(normalize(&overflows[..], 10), Err("sum of input cannot fit in `T`"));
    }

    #[test]
    fn does_not_fail_on_subtraction_overflow() {
        let cases: [([u8; 3], u8, [u8; 3]); 2] = [
            ([1, 100, 100], 10, [1, 9, 0]),
            ([1, 8, 9], 1, [0, 1, 0]),
        ];
        for (input, target, expected) in cases.iter() {
            assert_eq!(normalize(&input[..], *target).unwrap(), expected.to_vec());
        }
    }

    #[test]
    fn works_for_vec() {
        let input = vec![8u32, 9, 7, 10];
        assert_eq!(input.normalize(40).unwrap(), vec![10u32; 4]);
    }

    #[test]
    fn works_for_per_thing() {
        let thirds = vec![Perbill::from_percent(33); 3]
            .normalize(Perbill::one())
            .unwrap();
        assert_eq!(
            thirds,
            vec![
                Perbill::from_parts(333333334),
                Perbill::from_parts(333333333),
                Perbill::from_parts(333333333),
            ],
        );

        let mixed = vec![
            Perbill::from_percent(20),
            Perbill::from_percent(15),
            Perbill::from_percent(30),
        ]
        .normalize(Perbill::one())
        .unwrap();
        assert_eq!(
            mixed,
            vec![
                Perbill::from_parts(316666668),
                Perbill::from_parts(383333332),
                Perbill::from_parts(300000000),
            ],
        );
    }

    #[test]
    fn can_work_for_peru16() {
        // 3 * 40% overshoots 100%; the excess is split evenly so each lands on a third of 65535.
        let normalized = vec![PerU16::from_percent(40); 3]
            .normalize(PerU16::one())
            .unwrap();
        assert_eq!(normalized, vec![PerU16::from_parts(21845); 3]);
    }

    #[test]
    fn normalize_works_all_le() {
        check_u32_cases(&[
            ([8, 9, 7, 10], 40, [10, 10, 10, 10]),
            ([7, 7, 7, 7], 40, [10, 10, 10, 10]),
            ([7, 7, 7, 10], 40, [11, 11, 8, 10]),
            ([7, 8, 7, 10], 40, [11, 8, 11, 10]),
            ([7, 7, 8, 10], 40, [11, 11, 8, 10]),
        ]);
    }

    #[test]
    fn normalize_works_some_ge() {
        check_u32_cases(&[([8, 11, 9, 10], 40, [10, 11, 9, 10])]);
    }

    #[test]
    fn always_inc_min() {
        check_u32_cases(&[
            ([10, 7, 10, 10], 40, [10, 10, 10, 10]),
            ([10, 10, 7, 10], 40, [10, 10, 10, 10]),
            ([10, 10, 10, 7], 40, [10, 10, 10, 10]),
        ]);
    }

    #[test]
    fn normalize_works_all_ge() {
        check_u32_cases(&[
            ([12, 11, 13, 10], 40, [10, 10, 10, 10]),
            ([13, 13, 13, 13], 40, [10, 10, 10, 10]),
            ([13, 13, 13, 10], 40, [12, 9, 9, 10]),
            ([13, 12, 13, 10], 40, [9, 12, 9, 10]),
            ([13, 13, 12, 10], 40, [9, 9, 12, 10]),
        ]);
    }
}
451
#[cfg(test)]
mod threshold_compare_tests {
    use super::*;
    use crate::traits::Saturating;
    use tetcore_std::cmp::Ordering;

    #[test]
    fn epsilon_ord_works() {
        let b = 115u32;
        // 10% of 115 is 11.5, which `mul_ceil` rounds up, giving the band [103, 127].
        let e = Perbill::from_percent(10).mul_ceil(b);

        for &v in [103u32, 104, 115, 120, 126, 127].iter() {
            assert_eq!(v.tcmp(&b, e), Ordering::Equal, "{} should be within {} of {}", v, e, b);
        }

        assert_eq!(128u32.tcmp(&b, e), Ordering::Greater);
        assert_eq!(102u32.tcmp(&b, e), Ordering::Less);
    }

    #[test]
    fn epsilon_ord_works_with_small_epc() {
        let b = 115u32;
        // 100 parts per billion of 115 is far below 1, so `tcmp` should agree with plain `cmp`.
        let e = Perbill::from_parts(100) * b;

        for &v in [103u32, 104, 115, 120, 126, 127, 128, 102].iter() {
            assert_eq!(v.tcmp(&b, e), v.cmp(&b), "value {}", v);
        }
    }

    #[test]
    fn peru16_rational_does_not_overflow() {
        // Only checks that the call completes without panicking; the value itself is unused.
        let _ = PerU16::from_rational_approximation(17424870u32, 17424870);
    }

    #[test]
    fn saturating_mul_works() {
        let (min, max) = (i32::min_value(), i32::max_value());
        assert_eq!(Saturating::saturating_mul(2, min), min);
        assert_eq!(Saturating::saturating_mul(2, max), max);
    }

    #[test]
    fn saturating_pow_works() {
        let (min, max) = (i32::min_value(), i32::max_value());
        // Any base to the power zero is one.
        assert_eq!(Saturating::saturating_pow(min, 0), 1);
        assert_eq!(Saturating::saturating_pow(max, 0), 1);
        // Odd powers of a negative base saturate downwards, even powers upwards.
        assert_eq!(Saturating::saturating_pow(min, 3), min);
        assert_eq!(Saturating::saturating_pow(min, 2), max);
        assert_eq!(Saturating::saturating_pow(max, 2), max);
    }
}