1use std::ops::ControlFlow;
2
3use cast_trait::Cast;
4use internal_iterator::{InternalIterator, IntoInternalIterator};
5
6use crate::board::{Outcome, Player};
7use crate::pov::{NonPov, Pov, ScalarAbs};
8
/// A win/draw/loss result expressed relative to one player's point of view.
///
/// Convert to and from the absolute [`Outcome`] with the `NonPov` / `Pov` traits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OutcomeWDL {
    /// The player whose point of view this is won.
    Win,
    /// Neither player won.
    Draw,
    /// The player whose point of view this is lost.
    Loss,
}
16
/// Per-outcome values (`win`, `draw`, `loss`) expressed from one player's point of view.
///
/// `V` is unconstrained here; each `impl` block adds only the bounds it needs.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WDL<V> {
    /// Value associated with the point-of-view player winning.
    pub win: V,
    /// Value associated with a draw.
    pub draw: V,
    /// Value associated with the point-of-view player losing.
    pub loss: V,
}
24
/// Per-outcome values in absolute terms: `win_a` for a win by `Player::A`,
/// `draw` for a draw, and `win_b` for a win by `Player::B`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WDLAbs<V> {
    /// Value associated with player A winning.
    pub win_a: V,
    /// Value associated with a draw.
    pub draw: V,
    /// Value associated with player B winning.
    pub win_b: V,
}
31
32impl Outcome {
33 pub fn to_wdl_abs<V: num_traits::One + Default>(self) -> WDLAbs<V> {
35 let mut result = WDLAbs::default();
36 *match self {
37 Outcome::WonBy(Player::A) => &mut result.win_a,
38 Outcome::WonBy(Player::B) => &mut result.win_b,
39 Outcome::Draw => &mut result.draw,
40 } = V::one();
41 result
42 }
43
44 pub fn sign<V: num_traits::Zero + num_traits::One + std::ops::Neg<Output = V>>(self) -> ScalarAbs<V> {
46 match self {
47 Outcome::WonBy(Player::A) => ScalarAbs::new(V::one()),
48 Outcome::Draw => ScalarAbs::new(V::zero()),
49 Outcome::WonBy(Player::B) => ScalarAbs::new(-V::one()),
50 }
51 }
52}
53
54impl OutcomeWDL {
55 pub fn to_wdl<V: num_traits::One + Default>(self) -> WDL<V> {
57 let mut result = WDL::default();
58 *match self {
59 OutcomeWDL::Win => &mut result.win,
60 OutcomeWDL::Draw => &mut result.draw,
61 OutcomeWDL::Loss => &mut result.loss,
62 } = V::one();
63 result
64 }
65
66 pub fn sign<V: num_traits::Zero + num_traits::One + std::ops::Neg<Output = V>>(self) -> V {
68 match self {
69 OutcomeWDL::Win => V::one(),
70 OutcomeWDL::Draw => V::zero(),
71 OutcomeWDL::Loss => -V::one(),
72 }
73 }
74
75 pub fn un_pov(self, pov: Player) -> Outcome {
77 match self {
78 OutcomeWDL::Win => Outcome::WonBy(pov),
79 OutcomeWDL::Draw => Outcome::Draw,
80 OutcomeWDL::Loss => Outcome::WonBy(pov.other()),
81 }
82 }
83
84 pub fn best<I: IntoInternalIterator<Item = OutcomeWDL>>(children: I) -> OutcomeWDL {
87 Self::best_maybe(children.into_internal_iter().map(Some)).unwrap()
88 }
89
90 pub fn best_maybe<I: IntoInternalIterator<Item = Option<OutcomeWDL>>>(children: I) -> Option<OutcomeWDL> {
93 let mut any_unknown = false;
94 let mut all_known_are_loss = true;
95
96 let control = children.into_internal_iter().try_for_each(|child| {
97 match child {
98 None => {
99 any_unknown = true;
100 }
101 Some(OutcomeWDL::Win) => {
102 return ControlFlow::Break(());
103 }
104 Some(OutcomeWDL::Draw) => {
105 all_known_are_loss = false;
106 }
107 Some(OutcomeWDL::Loss) => {}
108 }
109
110 ControlFlow::Continue(())
111 });
112
113 if let ControlFlow::Break(()) = control {
114 Some(OutcomeWDL::Win)
115 } else if any_unknown {
116 None
117 } else if all_known_are_loss {
118 Some(OutcomeWDL::Loss)
119 } else {
120 Some(OutcomeWDL::Draw)
121 }
122 }
123}
124
125impl<V> NonPov for WDLAbs<V> {
126 type Output = WDL<V>;
127
128 fn pov(self, pov: Player) -> WDL<V> {
129 let (win, loss) = match pov {
130 Player::A => (self.win_a, self.win_b),
131 Player::B => (self.win_b, self.win_a),
132 };
133
134 WDL {
135 win,
136 draw: self.draw,
137 loss,
138 }
139 }
140}
141
142impl<V> Pov for WDL<V> {
143 type Output = WDLAbs<V>;
144
145 fn un_pov(self, pov: Player) -> Self::Output {
146 let (win_a, win_b) = match pov {
147 Player::A => (self.win, self.loss),
148 Player::B => (self.loss, self.win),
149 };
150
151 WDLAbs {
152 win_a,
153 draw: self.draw,
154 win_b,
155 }
156 }
157}
158
159impl<V> WDLAbs<V> {
160 pub fn new(win_a: V, draw: V, win_b: V) -> Self {
161 Self { win_a, draw, win_b }
162 }
163}
164
165impl<V> WDL<V> {
166 pub fn new(win: V, draw: V, loss: V) -> Self {
167 WDL { win, draw, loss }
168 }
169
170 pub fn to_slice(self) -> [V; 3] {
171 [self.win, self.draw, self.loss]
172 }
173}
174
175impl<V: num_traits::Float> WDL<V> {
176 pub fn nan() -> WDL<V> {
177 WDL {
178 win: V::nan(),
179 draw: V::nan(),
180 loss: V::nan(),
181 }
182 }
183
184 pub fn normalized(self) -> WDL<V> {
185 self / self.sum()
186 }
187}
188
189impl<V: num_traits::Float> WDLAbs<V> {
190 pub fn nan() -> WDLAbs<V> {
191 WDLAbs {
192 win_a: V::nan(),
193 draw: V::nan(),
194 win_b: V::nan(),
195 }
196 }
197}
198
199impl<V: num_traits::One + Default + PartialEq> WDLAbs<V> {
200 pub fn try_to_outcome(self) -> Option<Outcome> {
201 let outcomes = [Outcome::WonBy(Player::A), Outcome::Draw, Outcome::WonBy(Player::B)];
202 outcomes.iter().copied().find(|&o| o.to_wdl_abs() == self)
203 }
204}
205
206impl<V: num_traits::One + Default + PartialEq> WDL<V> {
207 pub fn try_to_outcome_wdl(self) -> Option<OutcomeWDL> {
208 let outcomes = [OutcomeWDL::Win, OutcomeWDL::Draw, OutcomeWDL::Loss];
209 outcomes.iter().copied().find(|&o| o.to_wdl() == self)
210 }
211}
212
213impl<V: Copy> WDL<V> {
214 pub fn cast<W>(self) -> WDL<W>
215 where
216 V: Cast<W>,
217 {
218 WDL {
219 win: self.win.cast(),
220 draw: self.draw.cast(),
221 loss: self.loss.cast(),
222 }
223 }
224}
225
226impl<V: Copy + std::ops::Sub<V, Output = V>> WDL<V> {
227 pub fn value(self) -> V {
228 self.win - self.loss
229 }
230}
231
232impl<V: Copy + std::ops::Sub<V, Output = V>> WDLAbs<V> {
233 pub fn value(self) -> ScalarAbs<V> {
234 ScalarAbs::new(self.win_a - self.win_b)
235 }
236}
237
238impl<V: Copy + std::ops::Add<V, Output = V>> WDL<V> {
239 pub fn sum(self) -> V {
240 self.win + self.draw + self.loss
241 }
242}
243
244impl<V: Copy + std::ops::Add<V, Output = V>> WDLAbs<V> {
245 pub fn sum(self) -> V {
246 self.win_a + self.draw + self.win_b
247 }
248}
249
250impl NonPov for Outcome {
251 type Output = OutcomeWDL;
252 fn pov(self, pov: Player) -> OutcomeWDL {
253 match self {
254 Outcome::WonBy(player) => {
255 if player == pov {
256 OutcomeWDL::Win
257 } else {
258 OutcomeWDL::Loss
259 }
260 }
261 Outcome::Draw => OutcomeWDL::Draw,
262 }
263 }
264}
265
266impl Pov for OutcomeWDL {
267 type Output = Outcome;
268 fn un_pov(self, pov: Player) -> Outcome {
269 match self {
270 OutcomeWDL::Win => Outcome::WonBy(pov),
271 OutcomeWDL::Draw => Outcome::Draw,
272 OutcomeWDL::Loss => Outcome::WonBy(pov.other()),
273 }
274 }
275}
276
277impl<V: std::ops::Add<V, Output = V>> std::ops::Add<WDL<V>> for WDL<V> {
278 type Output = WDL<V>;
279
280 fn add(self, rhs: WDL<V>) -> Self::Output {
281 WDL {
282 win: self.win + rhs.win,
283 draw: self.draw + rhs.draw,
284 loss: self.loss + rhs.loss,
285 }
286 }
287}
288
289impl<V: Copy + std::ops::Sub<V, Output = V>> std::ops::Sub<WDL<V>> for WDL<V> {
290 type Output = WDL<V>;
291
292 fn sub(self, rhs: WDL<V>) -> Self::Output {
293 WDL {
294 win: self.win - rhs.win,
295 draw: self.draw - rhs.draw,
296 loss: self.loss - rhs.loss,
297 }
298 }
299}
300
301impl<V: Copy + std::ops::Add<V, Output = V>> std::ops::AddAssign<WDL<V>> for WDL<V> {
302 fn add_assign(&mut self, rhs: WDL<V>) {
303 *self = *self + rhs;
304 }
305}
306
307impl<V: Copy + std::ops::Mul<V, Output = V>> std::ops::Mul<V> for WDL<V> {
308 type Output = WDL<V>;
309
310 fn mul(self, rhs: V) -> Self::Output {
311 WDL {
312 win: self.win * rhs,
313 draw: self.draw * rhs,
314 loss: self.loss * rhs,
315 }
316 }
317}
318
319impl<V: Copy + std::ops::Div<V, Output = V>> std::ops::Div<V> for WDL<V> {
320 type Output = WDL<V>;
321
322 fn div(self, rhs: V) -> Self::Output {
323 WDL {
324 win: self.win / rhs,
325 draw: self.draw / rhs,
326 loss: self.loss / rhs,
327 }
328 }
329}
330
331impl<V: Default + Copy + std::ops::Add<Output = V>> std::iter::Sum<Self> for WDL<V> {
332 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
333 iter.fold(Self::default(), |a, v| a + v)
334 }
335}
336
337impl<'a, V: Default + Copy + std::ops::Add<Output = V>> std::iter::Sum<&'a Self> for WDL<V> {
338 fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
339 iter.fold(Self::default(), |a, &v| a + v)
340 }
341}
342
343impl<V: std::ops::Add<V, Output = V>> std::ops::Add<WDLAbs<V>> for WDLAbs<V> {
344 type Output = WDLAbs<V>;
345
346 fn add(self, rhs: WDLAbs<V>) -> Self::Output {
347 WDLAbs {
348 win_a: self.win_a + rhs.win_a,
349 draw: self.draw + rhs.draw,
350 win_b: self.win_b + rhs.win_b,
351 }
352 }
353}
354
355impl<V: Copy + std::ops::Sub<V, Output = V>> std::ops::Sub<WDLAbs<V>> for WDLAbs<V> {
356 type Output = WDLAbs<V>;
357
358 fn sub(self, rhs: WDLAbs<V>) -> Self::Output {
359 WDLAbs {
360 win_a: self.win_a - rhs.win_a,
361 draw: self.draw - rhs.draw,
362 win_b: self.win_b - rhs.win_b,
363 }
364 }
365}
366
367impl<V: Copy + std::ops::Add<V, Output = V>> std::ops::AddAssign<WDLAbs<V>> for WDLAbs<V> {
368 fn add_assign(&mut self, rhs: WDLAbs<V>) {
369 *self = *self + rhs;
370 }
371}
372
373impl<V: Copy + std::ops::Mul<V, Output = V>> std::ops::Mul<V> for WDLAbs<V> {
374 type Output = WDLAbs<V>;
375
376 fn mul(self, rhs: V) -> Self::Output {
377 WDLAbs {
378 win_a: self.win_a * rhs,
379 draw: self.draw * rhs,
380 win_b: self.win_b * rhs,
381 }
382 }
383}
384
385impl<V: Copy + std::ops::Div<V, Output = V>> std::ops::Div<V> for WDLAbs<V> {
386 type Output = WDLAbs<V>;
387
388 fn div(self, rhs: V) -> Self::Output {
389 WDLAbs {
390 win_a: self.win_a / rhs,
391 draw: self.draw / rhs,
392 win_b: self.win_b / rhs,
393 }
394 }
395}
396
397impl<V: Default + Copy + std::ops::Add<Output = V>> std::iter::Sum for WDLAbs<V> {
398 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
399 iter.fold(Self::default(), |a, v| a + v)
400 }
401}