arithmetic/
lib.rs

// This file is part of Tetcore.

// Copyright (C) 2019-2021 Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: Apache-2.0

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// 	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Minimal fixed point arithmetic primitives and types for runtime.
19
20#![cfg_attr(not(feature = "std"), no_std)]
21
/// Assert that `$x` is within `$error` of `$y`, i.e. that `|$x - $y| <= $error`.
///
/// Originally copied from `tp-runtime`.
///
/// The distance is computed as `max(x, y) - min(x, y)`, so unsigned operands never underflow:
/// `assert_eq_error_rate!(0u32, 1u32, 5u32)` passes instead of panicking on `1 - 5`. Only `-`
/// and comparisons are required of the operand types.
///
/// Each argument may be evaluated more than once, so pass side-effect-free expressions.
#[macro_export]
macro_rules! assert_eq_error_rate {
	($x:expr, $y:expr, $error:expr $(,)?) => {
		assert!(
			(if ($x) >= ($y) { ($x) - ($y) } else { ($y) - ($x) }) <= ($error),
			"{:?} != {:?} (with error rate {:?})",
			$x,
			$y,
			$error,
		);
	};
}
35
36pub mod biguint;
37pub mod helpers_128bit;
38pub mod traits;
39pub mod per_things;
40pub mod fixed_point;
41pub mod rational;
42
43pub use fixed_point::{FixedPointNumber, FixedPointOperand, FixedI64, FixedI128, FixedU128};
44pub use per_things::{PerThing, InnerOf, UpperOf, Percent, PerU16, Permill, Perbill, Perquintill};
45pub use rational::{Rational128, RationalInfinite};
46
47use tetcore_std::{prelude::*, cmp::Ordering, fmt::Debug, convert::TryInto};
48use traits::{BaseArithmetic, One, Zero, SaturatedConversion, Unsigned};
49
/// Trait for comparing two numbers with a threshold.
///
/// Returns:
/// - `Ordering::Greater` if `self` is greater than `other + threshold`.
/// - `Ordering::Less` if `self` is less than `other - threshold`.
/// - `Ordering::Equal` otherwise.
pub trait ThresholdOrd<T> {
	/// Compare `self` against `other`, treating values within `threshold` of `other` as equal.
	fn tcmp(&self, other: &T, threshold: T) -> Ordering;
}
60
61impl<T> ThresholdOrd<T> for T
62where
63	T: Ord + PartialOrd + Copy + Clone + traits::Zero + traits::Saturating,
64{
65	fn tcmp(&self, other: &T, threshold: T) -> Ordering {
66		// early exit.
67		if threshold.is_zero() {
68			return self.cmp(&other)
69		}
70
71		let upper_bound = other.saturating_add(threshold);
72		let lower_bound = other.saturating_sub(threshold);
73
74		if upper_bound <= lower_bound {
75			// defensive only. Can never happen.
76			self.cmp(&other)
77		} else {
78			// upper_bound is guaranteed now to be bigger than lower.
79			match (self.cmp(&lower_bound), self.cmp(&upper_bound)) {
80				(Ordering::Greater, Ordering::Greater) => Ordering::Greater,
81				(Ordering::Less, Ordering::Less) => Ordering::Less,
82				_ => Ordering::Equal,
83			}
84		}
85
86	}
87}
88
/// A collection of `T` values whose elements can be adjusted so that they add up to a chosen
/// total.
///
/// The outcome may depend on the order in which the items appear in the collection.
pub trait Normalizable<T> {
	/// Build a new set of values from `self` whose sum is exactly `targeted_sum`.
	///
	/// # Errors
	///
	/// Returns `Err` with a static description whenever the implementation cannot guarantee
	/// that the new values sum to `targeted_sum`.
	fn normalize(&self, targeted_sum: T) -> Result<Vec<T>, &'static str>;
}
100
101macro_rules! impl_normalize_for_numeric {
102	($($numeric:ty),*) => {
103		$(
104			impl Normalizable<$numeric> for Vec<$numeric> {
105				fn normalize(&self, targeted_sum: $numeric) -> Result<Vec<$numeric>, &'static str> {
106					normalize(self.as_ref(), targeted_sum)
107				}
108			}
109		)*
110	};
111}
112
113impl_normalize_for_numeric!(u8, u16, u32, u64, u128);
114
115impl<P: PerThing> Normalizable<P> for Vec<P> {
116	fn normalize(&self, targeted_sum: P) -> Result<Vec<P>, &'static str> {
117		let uppers =
118			self.iter().map(|p| <UpperOf<P>>::from(p.clone().deconstruct())).collect::<Vec<_>>();
119
120		let normalized =
121			normalize(uppers.as_ref(), <UpperOf<P>>::from(targeted_sum.deconstruct()))?;
122
123		Ok(normalized
124			.into_iter()
125			.map(|i: UpperOf<P>| P::from_parts(i.saturated_into::<P::Inner>()))
126			.collect())
127	}
128}
129
130/// Normalize `input` so that the sum of all elements reaches `targeted_sum`.
131///
132/// This implementation is currently in a balanced position between being performant and accurate.
133///
134/// 1. We prefer storing original indices, and sorting the `input` only once. This will save the
135///    cost of sorting per round at the cost of a little bit of memory.
136/// 2. The granularity of increment/decrements is determined by the number of elements in `input`
137///    and their sum difference with `targeted_sum`, namely `diff = diff(sum(input), target_sum)`.
138///    This value is then distributed into `per_round = diff / input.len()` and `leftover = diff %
139///    round`. First, per_round is applied to all elements of input, and then we move to leftover,
140///    in which case we add/subtract 1 by 1 until `leftover` is depleted.
141///
142/// When the sum is less than the target, the above approach always holds. In this case, then each
143/// individual element is also less than target. Thus, by adding `per_round` to each item, neither
144/// of them can overflow the numeric bound of `T`. In fact, neither of the can go beyond
145/// `target_sum`*.
146///
147/// If sum is more than target, there is small twist. The subtraction of `per_round`
148/// form each element might go below zero. In this case, we saturate and add the error to the
149/// `leftover` value. This ensures that the result will always stay accurate, yet it might cause the
150/// execution to become increasingly slow, since leftovers are applied one by one.
151///
152/// All in all, the complicated case above is rare to happen in most use cases within this repo ,
153/// hence we opt for it due to its simplicity.
154///
155/// This function will return an error is if length of `input` cannot fit in `T`, or if `sum(input)`
156/// cannot fit inside `T`.
157///
158/// * This proof is used in the implementation as well.
159pub fn normalize<T>(input: &[T], targeted_sum: T) -> Result<Vec<T>, &'static str>
160	where T: Clone + Copy + Ord + BaseArithmetic + Unsigned + Debug,
161{
162	// compute sum and return error if failed.
163	let mut sum = T::zero();
164	for t in input.iter() {
165		sum = sum.checked_add(t).ok_or("sum of input cannot fit in `T`")?;
166	}
167
168	// convert count and return error if failed.
169	let count = input.len();
170	let count_t: T = count.try_into().map_err(|_| "length of `inputs` cannot fit in `T`")?;
171
172	// Nothing to do here.
173	if count.is_zero() {
174		return Ok(Vec::<T>::new());
175	}
176
177	let diff = targeted_sum.max(sum) - targeted_sum.min(sum);
178	if diff.is_zero() {
179		return Ok(input.to_vec());
180	}
181
182	let needs_bump = targeted_sum > sum;
183	let per_round = diff / count_t;
184	let mut leftover = diff % count_t;
185
186	// sort output once based on diff. This will require more data transfer and saving original
187	// index, but we sort only twice instead: once now and once at the very end.
188	let mut output_with_idx = input.iter().cloned().enumerate().collect::<Vec<(usize, T)>>();
189	output_with_idx.sort_by_key(|x| x.1);
190
191	if needs_bump {
192		// must increase the values a bit. Bump from the min element. Index of minimum is now zero
193		// because we did a sort. If at any point the min goes greater or equal the `max_threshold`,
194		// we move to the next minimum.
195		let mut min_index = 0;
196		// at this threshold we move to next index.
197		let threshold = targeted_sum / count_t;
198
199		if !per_round.is_zero() {
200			for _ in 0..count {
201				output_with_idx[min_index].1 = output_with_idx[min_index].1
202					.checked_add(&per_round)
203					.expect("Proof provided in the module doc; qed.");
204				if output_with_idx[min_index].1 >= threshold {
205					min_index += 1;
206					min_index = min_index % count;
207				}
208			}
209		}
210
211		// continue with the previous min_index
212		while !leftover.is_zero() {
213			output_with_idx[min_index].1 = output_with_idx[min_index].1
214				.checked_add(&T::one())
215				.expect("Proof provided in the module doc; qed.");
216			if output_with_idx[min_index].1 >= threshold {
217				min_index += 1;
218				min_index = min_index % count;
219			}
220			leftover -= One::one()
221		}
222	} else {
223		// must decrease the stakes a bit. decrement from the max element. index of maximum is now
224		// last. if at any point the max goes less or equal the `min_threshold`, we move to the next
225		// maximum.
226		let mut max_index = count - 1;
227		// at this threshold we move to next index.
228		let threshold = output_with_idx
229			.first()
230			.expect("length of input is greater than zero; it must have a first; qed")
231			.1;
232
233		if !per_round.is_zero() {
234			for _ in 0..count {
235				output_with_idx[max_index].1 = output_with_idx[max_index].1
236					.checked_sub(&per_round)
237					.unwrap_or_else(|| {
238						let remainder = per_round - output_with_idx[max_index].1;
239						leftover += remainder;
240						output_with_idx[max_index].1.saturating_sub(per_round)
241					});
242				if output_with_idx[max_index].1 <= threshold {
243					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
244				}
245			}
246		}
247
248		// continue with the previous max_index
249		while !leftover.is_zero() {
250			if let Some(next) = output_with_idx[max_index].1.checked_sub(&One::one()) {
251				output_with_idx[max_index].1 = next;
252				if output_with_idx[max_index].1 <= threshold {
253					max_index = max_index.checked_sub(1).unwrap_or(count - 1);
254				}
255				leftover -= One::one()
256			} else {
257				max_index = max_index.checked_sub(1).unwrap_or(count - 1);
258			}
259		}
260	}
261
262	debug_assert_eq!(
263		output_with_idx.iter().fold(T::zero(), |acc, (_, x)| acc + *x),
264		targeted_sum,
265		"sum({:?}) != {:?}",
266		output_with_idx,
267		targeted_sum,
268	);
269
270	// sort again based on the original index.
271	output_with_idx.sort_by_key(|x| x.0);
272	Ok(output_with_idx.into_iter().map(|(_, t)| t).collect())
273}
274
#[cfg(test)]
mod normalize_tests {
	use super::*;

	#[test]
	fn work_for_all_types() {
		macro_rules! test_for {
			($type:ty) => {
				assert_eq!(
					normalize(&[8 as $type, 9, 7, 10], 40).unwrap(),
					vec![10, 10, 10, 10],
				);
			}
		}
		// it should work for all types as long as the length of vector can be converted to T.
		test_for!(u128);
		test_for!(u64);
		test_for!(u32);
		test_for!(u16);
		test_for!(u8);
	}

	#[test]
	fn fails_if_input_sum_large() {
		assert!(normalize(&[1u8; 255], 10).is_ok());
		assert_eq!(
			normalize(&[1u8; 256], 10),
			Err("sum of input cannot fit in `T`"),
		);
	}

	#[test]
	fn empty_input_yields_empty_output() {
		assert_eq!(normalize::<u32>(&[], 10), Ok(Vec::new()));
		assert_eq!(normalize::<u32>(&[], 0), Ok(Vec::new()));
	}

	#[test]
	fn single_element_is_set_to_target() {
		assert_eq!(normalize(&[5u32], 10).unwrap(), vec![10]);
		assert_eq!(normalize(&[15u32], 10).unwrap(), vec![10]);
	}

	#[test]
	fn already_normalized_input_is_returned_unchanged() {
		assert_eq!(normalize(&[3u32, 1, 6], 10).unwrap(), vec![3, 1, 6]);
	}

	#[test]
	fn does_not_fail_on_subtraction_overflow() {
		assert_eq!(
			normalize(&[1u8, 100, 100], 10).unwrap(),
			vec![1, 9, 0],
		);
		assert_eq!(
			normalize(&[1u8, 8, 9], 1).unwrap(),
			vec![0, 1, 0],
		);
	}

	#[test]
	fn works_for_vec() {
		assert_eq!(vec![8u32, 9, 7, 10].normalize(40).unwrap(), vec![10u32, 10, 10, 10]);
	}

	#[test]
	fn works_for_per_thing() {
		assert_eq!(
			vec![
				Perbill::from_percent(33),
				Perbill::from_percent(33),
				Perbill::from_percent(33)
			].normalize(Perbill::one()).unwrap(),
			vec![
				Perbill::from_parts(333333334),
				Perbill::from_parts(333333333),
				Perbill::from_parts(333333333),
			]
		);

		assert_eq!(
			vec![
				Perbill::from_percent(20),
				Perbill::from_percent(15),
				Perbill::from_percent(30)
			].normalize(Perbill::one()).unwrap(),
			vec![
				Perbill::from_parts(316666668),
				Perbill::from_parts(383333332),
				Perbill::from_parts(300000000),
			]
		);
	}

	#[test]
	fn can_work_for_peru16() {
		// PerU16 is a rather special case: since its inner type is exactly as large as its
		// capacity, the sum could not always be calculated in the inner type. Calculating
		// using the upper type of the per_thing should make this work.
		assert_eq!(
			vec![
				PerU16::from_percent(40),
				PerU16::from_percent(40),
				PerU16::from_percent(40),
			].normalize(PerU16::one()).unwrap(),
			vec![
				PerU16::from_parts(21845), // 33%
				PerU16::from_parts(21845), // 33%
				PerU16::from_parts(21845), // 33%
			]
		);
	}

	#[test]
	fn normalize_works_all_le() {
		assert_eq!(
			normalize(&[8u32, 9, 7, 10], 40).unwrap(),
			vec![10, 10, 10, 10],
		);

		assert_eq!(
			normalize(&[7u32, 7, 7, 7], 40).unwrap(),
			vec![10, 10, 10, 10],
		);

		assert_eq!(
			normalize(&[7u32, 7, 7, 10], 40).unwrap(),
			vec![11, 11, 8, 10],
		);

		assert_eq!(
			normalize(&[7u32, 8, 7, 10], 40).unwrap(),
			vec![11, 8, 11, 10],
		);

		assert_eq!(
			normalize(&[7u32, 7, 8, 10], 40).unwrap(),
			vec![11, 11, 8, 10],
		);
	}

	#[test]
	fn normalize_works_some_ge() {
		assert_eq!(
			normalize(&[8u32, 11, 9, 10], 40).unwrap(),
			vec![10, 11, 9, 10],
		);
	}

	#[test]
	fn always_inc_min() {
		assert_eq!(
			normalize(&[10u32, 7, 10, 10], 40).unwrap(),
			vec![10, 10, 10, 10],
		);
		assert_eq!(
			normalize(&[10u32, 10, 7, 10], 40).unwrap(),
			vec![10, 10, 10, 10],
		);
		assert_eq!(
			normalize(&[10u32, 10, 10, 7], 40).unwrap(),
			vec![10, 10, 10, 10],
		);
	}

	#[test]
	fn normalize_works_all_ge() {
		assert_eq!(
			normalize(&[12u32, 11, 13, 10], 40).unwrap(),
			vec![10, 10, 10, 10],
		);

		assert_eq!(
			normalize(&[13u32, 13, 13, 13], 40).unwrap(),
			vec![10, 10, 10, 10],
		);

		assert_eq!(
			normalize(&[13u32, 13, 13, 10], 40).unwrap(),
			vec![12, 9, 9, 10],
		);

		assert_eq!(
			normalize(&[13u32, 12, 13, 10], 40).unwrap(),
			vec![9, 12, 9, 10],
		);

		assert_eq!(
			normalize(&[13u32, 13, 12, 10], 40).unwrap(),
			vec![9, 9, 12, 10],
		);
	}
}
451
#[cfg(test)]
mod threshold_compare_tests {
	use super::*;
	use crate::traits::Saturating;
	use tetcore_std::cmp::Ordering;

	#[test]
	fn epsilon_ord_works() {
		let b = 115u32;
		// e = ceil(10% of 115) = ceil(11.5) = 12.
		let e = Perbill::from_percent(10).mul_ceil(b);

		// [115 - 12 (103), 115 + 12 (127)] is all equal.
		assert_eq!(103u32.tcmp(&b, e), Ordering::Equal);
		assert_eq!(104u32.tcmp(&b, e), Ordering::Equal);
		assert_eq!(115u32.tcmp(&b, e), Ordering::Equal);
		assert_eq!(120u32.tcmp(&b, e), Ordering::Equal);
		assert_eq!(126u32.tcmp(&b, e), Ordering::Equal);
		assert_eq!(127u32.tcmp(&b, e), Ordering::Equal);

		assert_eq!(128u32.tcmp(&b, e), Ordering::Greater);
		assert_eq!(102u32.tcmp(&b, e), Ordering::Less);
	}

	#[test]
	fn epsilon_ord_works_with_small_epc() {
		let b = 115u32;
		// way less than 1 percent. threshold will be zero. Result should be same as normal ord.
		let e = Perbill::from_parts(100) * b;
		assert_eq!(e, 0);

		// With a zero threshold there is no "equal" window: `tcmp` must agree with `cmp`.
		assert_eq!(103u32.tcmp(&b, e), 103u32.cmp(&b));
		assert_eq!(104u32.tcmp(&b, e), 104u32.cmp(&b));
		assert_eq!(115u32.tcmp(&b, e), 115u32.cmp(&b));
		assert_eq!(120u32.tcmp(&b, e), 120u32.cmp(&b));
		assert_eq!(126u32.tcmp(&b, e), 126u32.cmp(&b));
		assert_eq!(127u32.tcmp(&b, e), 127u32.cmp(&b));

		assert_eq!(128u32.tcmp(&b, e), 128u32.cmp(&b));
		assert_eq!(102u32.tcmp(&b, e), 102u32.cmp(&b));
	}

	#[test]
	fn negative_threshold_falls_back_to_plain_ord() {
		// For signed types a negative threshold inverts the window, so `tcmp` behaves like `cmp`.
		assert_eq!(5i32.tcmp(&3, -10), Ordering::Greater);
		assert_eq!(1i32.tcmp(&3, -10), Ordering::Less);
		assert_eq!(3i32.tcmp(&3, -10), Ordering::Equal);
	}

	#[test]
	fn peru16_rational_does_not_overflow() {
		// A historical example that will panic only for per_thing types that are created with
		// the maximum capacity of their type, e.g. PerU16.
		let _ = PerU16::from_rational_approximation(17424870u32, 17424870);
	}

	#[test]
	fn saturating_mul_works() {
		assert_eq!(Saturating::saturating_mul(2, i32::min_value()), i32::min_value());
		assert_eq!(Saturating::saturating_mul(2, i32::max_value()), i32::max_value());
	}

	#[test]
	fn saturating_pow_works() {
		assert_eq!(Saturating::saturating_pow(i32::min_value(), 0), 1);
		assert_eq!(Saturating::saturating_pow(i32::max_value(), 0), 1);
		assert_eq!(Saturating::saturating_pow(i32::min_value(), 3), i32::min_value());
		assert_eq!(Saturating::saturating_pow(i32::min_value(), 2), i32::max_value());
		assert_eq!(Saturating::saturating_pow(i32::max_value(), 2), i32::max_value());
	}
}