board_game/
wdl.rs

1use std::ops::ControlFlow;
2
3use cast_trait::Cast;
4use internal_iterator::{InternalIterator, IntoInternalIterator};
5
6use crate::board::{Outcome, Player};
7use crate::pov::{NonPov, Pov, ScalarAbs};
8
/// A game outcome expressed relative to one particular player (the "POV" player).
///
/// Usually obtained using [Outcome::pov]; convert back with [OutcomeWDL::un_pov].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutcomeWDL {
    /// The POV player won.
    Win,
    /// Nobody won.
    Draw,
    /// The POV player lost, i.e. the other player won.
    Loss,
}
16
/// Three values indexed by outcome from the point of view of a single player: win, draw, loss.
///
/// The absolute (player-independent) counterpart is [WDLAbs]; convert with [Pov::un_pov].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WDL<V> {
    /// Value associated with the POV player winning.
    pub win: V,
    /// Value associated with a draw.
    pub draw: V,
    /// Value associated with the POV player losing.
    pub loss: V,
}
24
/// Three values indexed by absolute outcome: player A wins, draw, player B wins.
///
/// This is the player-independent counterpart of [WDL]; view it from one player's side with
/// [NonPov::pov].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct WDLAbs<V> {
    /// Value associated with player A winning.
    pub win_a: V,
    /// Value associated with a draw.
    pub draw: V,
    /// Value associated with player B winning.
    pub win_b: V,
}
31
32impl Outcome {
33    /// Convert this to a [WDLAbs] with a one at the correct place and zero otherwise.
34    pub fn to_wdl_abs<V: num_traits::One + Default>(self) -> WDLAbs<V> {
35        let mut result = WDLAbs::default();
36        *match self {
37            Outcome::WonBy(Player::A) => &mut result.win_a,
38            Outcome::WonBy(Player::B) => &mut result.win_b,
39            Outcome::Draw => &mut result.draw,
40        } = V::one();
41        result
42    }
43
44    /// Convert a win (for a) to `1`, draw to `0` and loss (for a) to `-1`.
45    pub fn sign<V: num_traits::Zero + num_traits::One + std::ops::Neg<Output = V>>(self) -> ScalarAbs<V> {
46        match self {
47            Outcome::WonBy(Player::A) => ScalarAbs::new(V::one()),
48            Outcome::Draw => ScalarAbs::new(V::zero()),
49            Outcome::WonBy(Player::B) => ScalarAbs::new(-V::one()),
50        }
51    }
52}
53
54impl OutcomeWDL {
55    /// Convert this to a [WDL] with a one at the correct place and zero otherwise.
56    pub fn to_wdl<V: num_traits::One + Default>(self) -> WDL<V> {
57        let mut result = WDL::default();
58        *match self {
59            OutcomeWDL::Win => &mut result.win,
60            OutcomeWDL::Draw => &mut result.draw,
61            OutcomeWDL::Loss => &mut result.loss,
62        } = V::one();
63        result
64    }
65
66    /// Convert a win to `1`, draw to `0` and loss to `-1`.
67    pub fn sign<V: num_traits::Zero + num_traits::One + std::ops::Neg<Output = V>>(self) -> V {
68        match self {
69            OutcomeWDL::Win => V::one(),
70            OutcomeWDL::Draw => V::zero(),
71            OutcomeWDL::Loss => -V::one(),
72        }
73    }
74
75    /// The reverse of [Outcome::pov].
76    pub fn un_pov(self, pov: Player) -> Outcome {
77        match self {
78            OutcomeWDL::Win => Outcome::WonBy(pov),
79            OutcomeWDL::Draw => Outcome::Draw,
80            OutcomeWDL::Loss => Outcome::WonBy(pov.other()),
81        }
82    }
83
84    /// Pick the best possible outcome, assuming `Win > Draw > Loss`.
85    /// Make sure to flip the child values as appropriate, this function assumes everything is form the parent POV.
86    pub fn best<I: IntoInternalIterator<Item = OutcomeWDL>>(children: I) -> OutcomeWDL {
87        Self::best_maybe(children.into_internal_iter().map(Some)).unwrap()
88    }
89
90    /// Pick the best possible outcome, assuming `Some(Win) > None > Some(Draw) > Some(Loss)`.
91    /// Make sure to flip the child values as appropriate, this function assumes everything is form the parent POV.
92    pub fn best_maybe<I: IntoInternalIterator<Item = Option<OutcomeWDL>>>(children: I) -> Option<OutcomeWDL> {
93        let mut any_unknown = false;
94        let mut all_known_are_loss = true;
95
96        let control = children.into_internal_iter().try_for_each(|child| {
97            match child {
98                None => {
99                    any_unknown = true;
100                }
101                Some(OutcomeWDL::Win) => {
102                    return ControlFlow::Break(());
103                }
104                Some(OutcomeWDL::Draw) => {
105                    all_known_are_loss = false;
106                }
107                Some(OutcomeWDL::Loss) => {}
108            }
109
110            ControlFlow::Continue(())
111        });
112
113        if let ControlFlow::Break(()) = control {
114            Some(OutcomeWDL::Win)
115        } else if any_unknown {
116            None
117        } else if all_known_are_loss {
118            Some(OutcomeWDL::Loss)
119        } else {
120            Some(OutcomeWDL::Draw)
121        }
122    }
123}
124
125impl<V> NonPov for WDLAbs<V> {
126    type Output = WDL<V>;
127
128    fn pov(self, pov: Player) -> WDL<V> {
129        let (win, loss) = match pov {
130            Player::A => (self.win_a, self.win_b),
131            Player::B => (self.win_b, self.win_a),
132        };
133
134        WDL {
135            win,
136            draw: self.draw,
137            loss,
138        }
139    }
140}
141
142impl<V> Pov for WDL<V> {
143    type Output = WDLAbs<V>;
144
145    fn un_pov(self, pov: Player) -> Self::Output {
146        let (win_a, win_b) = match pov {
147            Player::A => (self.win, self.loss),
148            Player::B => (self.loss, self.win),
149        };
150
151        WDLAbs {
152            win_a,
153            draw: self.draw,
154            win_b,
155        }
156    }
157}
158
159impl<V> WDLAbs<V> {
160    pub fn new(win_a: V, draw: V, win_b: V) -> Self {
161        Self { win_a, draw, win_b }
162    }
163}
164
165impl<V> WDL<V> {
166    pub fn new(win: V, draw: V, loss: V) -> Self {
167        WDL { win, draw, loss }
168    }
169
170    pub fn to_slice(self) -> [V; 3] {
171        [self.win, self.draw, self.loss]
172    }
173}
174
175impl<V: num_traits::Float> WDL<V> {
176    pub fn nan() -> WDL<V> {
177        WDL {
178            win: V::nan(),
179            draw: V::nan(),
180            loss: V::nan(),
181        }
182    }
183
184    pub fn normalized(self) -> WDL<V> {
185        self / self.sum()
186    }
187}
188
189impl<V: num_traits::Float> WDLAbs<V> {
190    pub fn nan() -> WDLAbs<V> {
191        WDLAbs {
192            win_a: V::nan(),
193            draw: V::nan(),
194            win_b: V::nan(),
195        }
196    }
197}
198
199impl<V: num_traits::One + Default + PartialEq> WDLAbs<V> {
200    pub fn try_to_outcome(self) -> Option<Outcome> {
201        let outcomes = [Outcome::WonBy(Player::A), Outcome::Draw, Outcome::WonBy(Player::B)];
202        outcomes.iter().copied().find(|&o| o.to_wdl_abs() == self)
203    }
204}
205
206impl<V: num_traits::One + Default + PartialEq> WDL<V> {
207    pub fn try_to_outcome_wdl(self) -> Option<OutcomeWDL> {
208        let outcomes = [OutcomeWDL::Win, OutcomeWDL::Draw, OutcomeWDL::Loss];
209        outcomes.iter().copied().find(|&o| o.to_wdl() == self)
210    }
211}
212
213impl<V: Copy> WDL<V> {
214    pub fn cast<W>(self) -> WDL<W>
215    where
216        V: Cast<W>,
217    {
218        WDL {
219            win: self.win.cast(),
220            draw: self.draw.cast(),
221            loss: self.loss.cast(),
222        }
223    }
224}
225
226impl<V: Copy + std::ops::Sub<V, Output = V>> WDL<V> {
227    pub fn value(self) -> V {
228        self.win - self.loss
229    }
230}
231
232impl<V: Copy + std::ops::Sub<V, Output = V>> WDLAbs<V> {
233    pub fn value(self) -> ScalarAbs<V> {
234        ScalarAbs::new(self.win_a - self.win_b)
235    }
236}
237
238impl<V: Copy + std::ops::Add<V, Output = V>> WDL<V> {
239    pub fn sum(self) -> V {
240        self.win + self.draw + self.loss
241    }
242}
243
244impl<V: Copy + std::ops::Add<V, Output = V>> WDLAbs<V> {
245    pub fn sum(self) -> V {
246        self.win_a + self.draw + self.win_b
247    }
248}
249
250impl NonPov for Outcome {
251    type Output = OutcomeWDL;
252    fn pov(self, pov: Player) -> OutcomeWDL {
253        match self {
254            Outcome::WonBy(player) => {
255                if player == pov {
256                    OutcomeWDL::Win
257                } else {
258                    OutcomeWDL::Loss
259                }
260            }
261            Outcome::Draw => OutcomeWDL::Draw,
262        }
263    }
264}
265
266impl Pov for OutcomeWDL {
267    type Output = Outcome;
268    fn un_pov(self, pov: Player) -> Outcome {
269        match self {
270            OutcomeWDL::Win => Outcome::WonBy(pov),
271            OutcomeWDL::Draw => Outcome::Draw,
272            OutcomeWDL::Loss => Outcome::WonBy(pov.other()),
273        }
274    }
275}
276
277impl<V: std::ops::Add<V, Output = V>> std::ops::Add<WDL<V>> for WDL<V> {
278    type Output = WDL<V>;
279
280    fn add(self, rhs: WDL<V>) -> Self::Output {
281        WDL {
282            win: self.win + rhs.win,
283            draw: self.draw + rhs.draw,
284            loss: self.loss + rhs.loss,
285        }
286    }
287}
288
289impl<V: Copy + std::ops::Sub<V, Output = V>> std::ops::Sub<WDL<V>> for WDL<V> {
290    type Output = WDL<V>;
291
292    fn sub(self, rhs: WDL<V>) -> Self::Output {
293        WDL {
294            win: self.win - rhs.win,
295            draw: self.draw - rhs.draw,
296            loss: self.loss - rhs.loss,
297        }
298    }
299}
300
301impl<V: Copy + std::ops::Add<V, Output = V>> std::ops::AddAssign<WDL<V>> for WDL<V> {
302    fn add_assign(&mut self, rhs: WDL<V>) {
303        *self = *self + rhs;
304    }
305}
306
307impl<V: Copy + std::ops::Mul<V, Output = V>> std::ops::Mul<V> for WDL<V> {
308    type Output = WDL<V>;
309
310    fn mul(self, rhs: V) -> Self::Output {
311        WDL {
312            win: self.win * rhs,
313            draw: self.draw * rhs,
314            loss: self.loss * rhs,
315        }
316    }
317}
318
319impl<V: Copy + std::ops::Div<V, Output = V>> std::ops::Div<V> for WDL<V> {
320    type Output = WDL<V>;
321
322    fn div(self, rhs: V) -> Self::Output {
323        WDL {
324            win: self.win / rhs,
325            draw: self.draw / rhs,
326            loss: self.loss / rhs,
327        }
328    }
329}
330
331impl<V: Default + Copy + std::ops::Add<Output = V>> std::iter::Sum<Self> for WDL<V> {
332    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
333        iter.fold(Self::default(), |a, v| a + v)
334    }
335}
336
337impl<'a, V: Default + Copy + std::ops::Add<Output = V>> std::iter::Sum<&'a Self> for WDL<V> {
338    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
339        iter.fold(Self::default(), |a, &v| a + v)
340    }
341}
342
343impl<V: std::ops::Add<V, Output = V>> std::ops::Add<WDLAbs<V>> for WDLAbs<V> {
344    type Output = WDLAbs<V>;
345
346    fn add(self, rhs: WDLAbs<V>) -> Self::Output {
347        WDLAbs {
348            win_a: self.win_a + rhs.win_a,
349            draw: self.draw + rhs.draw,
350            win_b: self.win_b + rhs.win_b,
351        }
352    }
353}
354
355impl<V: Copy + std::ops::Sub<V, Output = V>> std::ops::Sub<WDLAbs<V>> for WDLAbs<V> {
356    type Output = WDLAbs<V>;
357
358    fn sub(self, rhs: WDLAbs<V>) -> Self::Output {
359        WDLAbs {
360            win_a: self.win_a - rhs.win_a,
361            draw: self.draw - rhs.draw,
362            win_b: self.win_b - rhs.win_b,
363        }
364    }
365}
366
367impl<V: Copy + std::ops::Add<V, Output = V>> std::ops::AddAssign<WDLAbs<V>> for WDLAbs<V> {
368    fn add_assign(&mut self, rhs: WDLAbs<V>) {
369        *self = *self + rhs;
370    }
371}
372
373impl<V: Copy + std::ops::Mul<V, Output = V>> std::ops::Mul<V> for WDLAbs<V> {
374    type Output = WDLAbs<V>;
375
376    fn mul(self, rhs: V) -> Self::Output {
377        WDLAbs {
378            win_a: self.win_a * rhs,
379            draw: self.draw * rhs,
380            win_b: self.win_b * rhs,
381        }
382    }
383}
384
385impl<V: Copy + std::ops::Div<V, Output = V>> std::ops::Div<V> for WDLAbs<V> {
386    type Output = WDLAbs<V>;
387
388    fn div(self, rhs: V) -> Self::Output {
389        WDLAbs {
390            win_a: self.win_a / rhs,
391            draw: self.draw / rhs,
392            win_b: self.win_b / rhs,
393        }
394    }
395}
396
397impl<V: Default + Copy + std::ops::Add<Output = V>> std::iter::Sum for WDLAbs<V> {
398    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
399        iter.fold(Self::default(), |a, v| a + v)
400    }
401}