beekeeper/hive/
weighted.rs

1//! Weighted value used for task submission with the `local-batch` feature.
2use num::ToPrimitive;
3use std::ops::Deref;
4
/// A value of type `T` paired with a `u32` weight.
///
/// Comparison and hashing are derived, so they take both the value and the weight into account
/// (fields are considered in declaration order: `value` first, then `weight`).
#[derive(Debug, Clone, Default)]
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct Weighted<T> {
    /// The wrapped value.
    value: T,
    /// The weight associated with `value`.
    weight: u32,
}
11
12impl<T> Weighted<T> {
13    /// Creates a new `Weighted` instance with the given value and weight.
14    pub fn new<P: ToPrimitive>(value: T, weight: P) -> Self {
15        Self {
16            value,
17            weight: weight.to_u32().unwrap(),
18        }
19    }
20
21    /// Creates a new `Weighted` instance with the given value and weight obtained from calling the
22    /// given function on `value`.
23    pub fn from_fn<F>(value: T, f: F) -> Self
24    where
25        F: FnOnce(&T) -> u32,
26    {
27        let weight = f(&value);
28        Self::new(value, weight)
29    }
30
31    /// Creates a new `Weighted` instance with the given value and weight obtained by converting
32    /// the value into a `u32`.
33    pub fn from_identity(value: T) -> Self
34    where
35        T: ToPrimitive + Clone,
36    {
37        let weight = value.clone().to_u32().unwrap();
38        Self::new(value, weight)
39    }
40
41    /// Returns the weight associated with this `Weighted` value.
42    pub fn weight(&self) -> u32 {
43        self.weight
44    }
45
46    /// Returns the value and weight as a tuple.
47    pub fn into_parts(self) -> (T, u32) {
48        (self.value, self.weight)
49    }
50}
51
52impl<T> Deref for Weighted<T> {
53    type Target = T;
54
55    fn deref(&self) -> &Self::Target {
56        &self.value
57    }
58}
59
60impl<T> From<T> for Weighted<T> {
61    fn from(value: T) -> Self {
62        Self::new(value, 0)
63    }
64}
65
66impl<T, P: ToPrimitive> From<(T, P)> for Weighted<T> {
67    fn from((value, weight): (T, P)) -> Self {
68        Self::new(value, weight)
69    }
70}
71
72/// Extends `IntoIterator` to add methods to convert any iterator into an iterator over `Weighted`
73/// items.
74pub trait WeightedIteratorExt: IntoIterator + Sized {
75    /// Converts this iterator over (T, P) items into an iterator over `Weighted<T>` items with
76    /// weights set to `P::into_u32()`.
77    fn into_weighted<T, P>(self) -> impl Iterator<Item = Weighted<T>>
78    where
79        P: ToPrimitive,
80        Self: IntoIterator<Item = (T, P)>,
81    {
82        self.into_iter()
83            .map(|(value, weight)| Weighted::new(value, weight))
84    }
85
86    /// Converts this iterator into an iterator over `Weighted<Self::Item>` with weights set to 0.
87    fn into_default_weighted(self) -> impl Iterator<Item = Weighted<Self::Item>> {
88        self.into_iter().map(Into::into)
89    }
90
91    /// Converts this iterator into an iterator over `Weighted<Self::Item>` with weights set to
92    /// `weight`.
93    fn into_const_weighted(self, weight: u32) -> impl Iterator<Item = Weighted<Self::Item>> {
94        self.into_iter()
95            .map(move |item| Weighted::new(item, weight))
96    }
97
98    /// Converts this iterator into an iterator over `Weighted<Self::Item>` with weights set to
99    /// `item.clone().into_u32()`.
100    fn into_identity_weighted(self) -> impl Iterator<Item = Weighted<Self::Item>>
101    where
102        Self::Item: ToPrimitive + Clone,
103    {
104        self.into_iter().map(Weighted::from_identity)
105    }
106
107    /// Zips this iterator with `weights` and converts each tuple into a `Weighted<Self::Item>`
108    /// with the weight set to the corresponding value from `weights`.
109    fn into_weighted_zip<P, W>(self, weights: W) -> impl Iterator<Item = Weighted<Self::Item>>
110    where
111        P: ToPrimitive + Clone + Default,
112        W: IntoIterator<Item = P>,
113        W::IntoIter: 'static,
114    {
115        self.into_iter()
116            .zip(weights.into_iter().chain(std::iter::repeat(P::default())))
117            .map(Into::into)
118    }
119
120    /// Converts this interator into an iterator over `Weighted<Self::Item>` with weights set to
121    /// the result of calling `f` on each item.
122    fn into_weighted_with<F>(self, f: F) -> impl Iterator<Item = Weighted<Self::Item>>
123    where
124        F: Fn(&Self::Item) -> u32,
125    {
126        self.into_iter().map(move |item| {
127            let weight = f(&item);
128            Weighted::new(item, weight)
129        })
130    }
131
132    /// Converts this `ExactSizeIterator` over (T, P) items into an `ExactSizeIterator` over
133    /// `Weighted<T>` items with weights set to `P::into_u32()`.
134    fn into_weighted_exact<T>(self) -> impl ExactSizeIterator<Item = Weighted<T>>
135    where
136        Self: IntoIterator<Item = (T, u32)>,
137        Self::IntoIter: ExactSizeIterator + 'static,
138    {
139        self.into_iter()
140            .map(|(value, weight)| Weighted::new(value, weight))
141    }
142
143    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted<Self::Item>`
144    /// with weights set to 0.
145    fn into_default_weighted_exact(self) -> impl ExactSizeIterator<Item = Weighted<Self::Item>>
146    where
147        Self::IntoIter: ExactSizeIterator + 'static,
148    {
149        self.into_iter().map(Into::into)
150    }
151
152    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted<Self::Item>`
153    /// with weights set to `weight`.
154    fn into_const_weighted_exact(
155        self,
156        weight: u32,
157    ) -> impl ExactSizeIterator<Item = Weighted<Self::Item>>
158    where
159        Self::IntoIter: ExactSizeIterator + 'static,
160    {
161        self.into_iter()
162            .map(move |item| Weighted::new(item, weight))
163    }
164
165    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted<Self::Item>`
166    /// with weights set to `item.clone().into_u32()`.
167    fn into_identity_weighted_exact(self) -> impl ExactSizeIterator<Item = Weighted<Self::Item>>
168    where
169        Self::Item: ToPrimitive + Clone,
170        Self::IntoIter: ExactSizeIterator + 'static,
171    {
172        self.into_iter().map(Weighted::from_identity)
173    }
174
175    /// Converts this `ExactSizeIterator` into an `ExactSizeIterator` over `Weighted<Self::Item>`
176    /// with weights set to the result of calling `f` on each item.
177    fn into_weighted_exact_with<F>(
178        self,
179        f: F,
180    ) -> impl ExactSizeIterator<Item = Weighted<Self::Item>>
181    where
182        Self::IntoIter: ExactSizeIterator + 'static,
183        F: Fn(&Self::Item) -> u32,
184    {
185        self.into_iter().map(move |item| {
186            let weight = f(&item);
187            Weighted::new(item, weight)
188        })
189    }
190}
191
192impl<T: IntoIterator> WeightedIteratorExt for T {}
193
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    //! Unit tests for `Weighted` and the `WeightedIteratorExt` adapters.
    use super::*;

    #[test]
    fn test_new() {
        let weighted = Weighted::new(42, 10);
        assert_eq!(*weighted, 42);
        assert_eq!(weighted.weight(), 10);
        assert_eq!(weighted.into_parts(), (42, 10));
    }

    /// A weight that cannot be represented as a `u32` makes construction panic.
    #[test]
    #[should_panic]
    fn test_new_negative_weight_panics() {
        let _ = Weighted::new(42, -1);
    }

    #[test]
    fn test_from_fn() {
        let weighted = Weighted::from_fn(42, |x| x * 2);
        assert_eq!(*weighted, 42);
        assert_eq!(weighted.weight(), 84);
    }

    #[test]
    fn test_from_identity() {
        let weighted = Weighted::from_identity(42);
        assert_eq!(*weighted, 42);
        assert_eq!(weighted.weight(), 42);
    }

    #[test]
    fn test_from_unweighted() {
        let weighted = Weighted::from(42);
        assert_eq!(*weighted, 42);
        assert_eq!(weighted.weight(), 0);
    }

    #[test]
    fn test_from_tuple() {
        let weighted: Weighted<usize> = Weighted::from((42, 10));
        assert_eq!(*weighted, 42);
        assert_eq!(weighted.weight(), 10);
        assert_eq!(weighted.into_parts(), (42, 10));
    }

    #[test]
    fn test_into_weighted() {
        (0..10)
            .map(|i| (i, i))
            .into_weighted()
            .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value));
    }

    #[test]
    fn test_into_default_weighted() {
        (0..10)
            .into_default_weighted()
            .for_each(|weighted| assert_eq!(weighted.weight(), 0));
    }

    #[test]
    fn test_into_identity_weighted() {
        (0..10)
            .into_identity_weighted()
            .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value));
    }

    #[test]
    fn test_into_const_weighted() {
        (0..10)
            .into_const_weighted(5)
            .for_each(|weighted| assert_eq!(weighted.weight(), 5));
    }

    #[test]
    fn test_into_weighted_zip() {
        (0..10)
            .into_weighted_zip(10..20)
            .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value + 10));
    }

    /// When fewer weights than items are supplied, the remaining items get the default weight.
    #[test]
    fn test_into_weighted_zip_pads_missing_weights() {
        let weights: Vec<u32> = (0..4)
            .into_weighted_zip(vec![7u32, 8])
            .map(|weighted| weighted.weight())
            .collect();
        assert_eq!(weights, vec![7, 8, 0, 0]);
    }

    #[test]
    fn test_into_weighted_with() {
        (0..10)
            .into_weighted_with(|i| i * 2)
            .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value * 2));
    }

    #[test]
    fn test_into_weighted_exact() {
        (0..10)
            .map(|i| (i, i))
            .into_weighted_exact()
            .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value));
    }

    #[test]
    fn test_into_default_weighted_exact() {
        (0..10)
            .into_default_weighted_exact()
            .for_each(|weighted| assert_eq!(weighted.weight(), 0));
    }

    #[test]
    fn test_into_identity_weighted_exact() {
        (0..10)
            .into_identity_weighted_exact()
            .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value));
    }

    #[test]
    fn test_into_const_weighted_exact() {
        (0..10)
            .into_const_weighted_exact(5)
            .for_each(|weighted| assert_eq!(weighted.weight(), 5));
    }

    #[test]
    fn test_into_weighted_exact_with() {
        (0..10)
            .into_weighted_exact_with(|i| i * 2)
            .for_each(|weighted| assert_eq!(weighted.weight(), weighted.value * 2));
    }
}