use super::{Bid, Call, Strain};
use core::convert::Infallible;
use core::iter::Enumerate;
use core::mem::MaybeUninit;
use core::ops::{Deref, DerefMut, Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive};
use dds_bridge::Level;
/// Number of distinct calls: the three non-bid calls (Pass, Double, Redouble)
/// followed by one bid for each of 7 levels times 5 strains.
const CALL_VARIANTS: usize = 3 + 5 * 7;
/// Maps a call to its slot index in `Array`.
///
/// The non-bid calls take slots 0 (Pass), 1 (Double) and 2 (Redouble);
/// bids are delegated to `encode_bid`.
const fn encode_call(call: Call) -> usize {
    match call {
        Call::Bid(bid) => encode_bid(bid),
        Call::Pass => 0,
        Call::Double => 1,
        Call::Redouble => 2,
    }
}
/// Maps a bid to its slot index.
///
/// Bids start right after the three non-bid calls. They are grouped by level,
/// with five consecutive slots per level. Within a level they are ordered by
/// the strain's discriminant.
const fn encode_bid(bid: Bid) -> usize {
    let level_offset = (bid.level.get() as usize - 1) * 5;
    let strain_offset = bid.strain as usize;
    3 + level_offset + strain_offset
}
// Compile-time check of the encoding.
// - Every one of the 7 x 5 bids is written into the slot `encode_call` picks for it.
// - The three non-bid calls must land on slots 0, 1 and 2.
// - Every slot from 3 upward must end up holding a bid, i.e. the bids cover
//   `3..CALL_VARIANTS` with no gaps and no collisions.
const _: () = {
    let mut table = [Call::Pass; CALL_VARIANTS];
    let mut level: u8 = 1;
    while level <= 7 {
        let mut strain = 0;
        while strain < 5 {
            let bid = Bid {
                level: Level::new(level),
                strain: Strain::ASC[strain],
            };
            table[encode_call(Call::Bid(bid))] = Call::Bid(bid);
            strain += 1;
        }
        level += 1;
    }

    assert!(encode_call(Call::Pass) == 0);
    assert!(encode_call(Call::Double) == 1);
    assert!(encode_call(Call::Redouble) == 2);

    let mut slot = 3;
    while slot < CALL_VARIANTS {
        assert!(matches!(table[slot], Call::Bid(_)));
        slot += 1;
    }
};
const fn decode_call(index: usize) -> Call {
match index {
0 => Call::Pass,
1 => Call::Double,
2 => Call::Redouble,
3..CALL_VARIANTS => {
let code = index - 3 + 5;
let (level, strain) = (code / 5, code % 5);
Call::Bid(super::Bid {
#[allow(clippy::cast_possible_truncation)]
level: Level::new(level as u8),
strain: super::Strain::ASC[strain],
})
}
_ => unreachable!(),
}
}
// Compile-time check that `decode_call` followed by `encode_call` returns every
// slot index unchanged.
const _: () = {
    let mut slot = CALL_VARIANTS;
    while slot > 0 {
        slot -= 1;
        assert!(encode_call(decode_call(slot)) == slot);
    }
};
#[test]
fn test_encode_special_calls() {
    // (call, expected slot) for the three non-bid calls.
    let expected = [(Call::Pass, 0), (Call::Double, 1), (Call::Redouble, 2)];
    for (call, slot) in expected {
        assert_eq!(encode_call(call), slot);
    }
}
/// A fixed-size table holding one `T` for every possible call.
///
/// Entries are stored in the slot order defined by `encode_call`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Array<T>([T; CALL_VARIANTS]);

/// Iterator over shared references to the entries, in slot order.
pub type Values<'s, T> = core::slice::Iter<'s, T>;
/// Iterator over mutable references to the entries, in slot order.
pub type ValuesMut<'s, T> = core::slice::IterMut<'s, T>;
/// Owning iterator over the entries, in slot order.
pub type IntoValues<T> = core::array::IntoIter<T, CALL_VARIANTS>;
impl<T> Array<T> {
#[must_use]
pub fn from_fn(mut f: impl FnMut(Call) -> T) -> Self {
Self(core::array::from_fn(|index| f(decode_call(index))))
}
#[must_use]
#[inline]
pub const fn get(&self, call: Call) -> &T {
&self.0[encode_call(call)]
}
#[must_use]
#[inline]
pub const fn get_mut(&mut self, call: Call) -> &mut T {
&mut self.0[encode_call(call)]
}
pub const fn each_ref(&self) -> Array<&T> {
Array(self.0.each_ref())
}
pub const fn each_mut(&mut self) -> Array<&mut T> {
Array(self.0.each_mut())
}
pub fn iter(&self) -> Iter<'_, T> {
self.into_iter()
}
pub fn iter_mut(&mut self) -> IterMut<'_, T> {
self.into_iter()
}
pub fn try_map<U, E>(self, mut f: impl FnMut(Call, T) -> Result<U, E>) -> Result<Array<U>, E> {
let mut result = [const { MaybeUninit::uninit() }; CALL_VARIANTS];
for (index, value) in self.0.into_iter().enumerate() {
match f(decode_call(index), value) {
Ok(u) => result[index] = MaybeUninit::new(u),
Err(e) => {
unsafe { result[..index].assume_init_drop() };
return Err(e);
}
}
}
Ok(Array(unsafe { core::mem::transmute_copy(&result) }))
}
#[allow(clippy::missing_panics_doc)]
pub fn map<U>(self, mut f: impl FnMut(Call, T) -> U) -> Array<U> {
self.try_map::<_, Infallible>(|call, value| Ok(f(call, value)))
.unwrap()
}
pub fn values(&self) -> Values<'_, T> {
self.0.iter()
}
pub fn values_mut(&mut self) -> ValuesMut<'_, T> {
self.0.iter_mut()
}
pub fn into_values(self) -> IntoValues<T> {
self.0.into_iter()
}
}
impl<T> Array<Option<T>> {
#[must_use]
pub const fn new() -> Self {
Self([const { None }; CALL_VARIANTS])
}
}
impl<T: Clone> Array<T> {
#[must_use]
pub fn repeat(value: T) -> Self {
Self(core::array::repeat(value))
}
}
impl<T> Index<Call> for Array<T> {
type Output = T;
fn index(&self, call: Call) -> &Self::Output {
self.get(call)
}
}
impl<T> IndexMut<Call> for Array<T> {
fn index_mut(&mut self, call: Call) -> &mut Self::Output {
self.get_mut(call)
}
}
impl<T> Index<RangeFull> for Array<T> {
    type Output = [T];

    /// Views every entry as a slice, in slot order.
    fn index(&self, _: RangeFull) -> &[T] {
        self.0.as_slice()
    }
}

impl<T> IndexMut<RangeFull> for Array<T> {
    /// Views every entry as a mutable slice, in slot order.
    fn index_mut(&mut self, _: RangeFull) -> &mut [T] {
        self.0.as_mut_slice()
    }
}
impl<T> Index<Range<Bid>> for Array<T> {
    type Output = [T];

    /// Slices the entries for bids from `range.start` up to, but excluding, `range.end`.
    ///
    /// Panics, like any slice index, if the start slot is after the end slot.
    fn index(&self, range: Range<Bid>) -> &[T] {
        let Range { start, end } = range;
        &self.0[encode_bid(start)..encode_bid(end)]
    }
}

impl<T> IndexMut<Range<Bid>> for Array<T> {
    /// Mutable counterpart of the `Range<Bid>` index.
    fn index_mut(&mut self, range: Range<Bid>) -> &mut [T] {
        let Range { start, end } = range;
        &mut self.0[encode_bid(start)..encode_bid(end)]
    }
}
impl<T> Index<RangeFrom<Bid>> for Array<T> {
    type Output = [T];

    /// Slices the entries from the bid `range.start` through the last slot.
    fn index(&self, range: RangeFrom<Bid>) -> &[T] {
        &self.0[encode_bid(range.start)..]
    }
}

impl<T> IndexMut<RangeFrom<Bid>> for Array<T> {
    /// Mutable counterpart of the `RangeFrom<Bid>` index.
    fn index_mut(&mut self, range: RangeFrom<Bid>) -> &mut [T] {
        &mut self.0[encode_bid(range.start)..]
    }
}
impl<T> Index<RangeInclusive<Bid>> for Array<T> {
    type Output = [T];

    /// Slices the entries for bids from `range.start()` through `range.end()` inclusive.
    fn index(&self, range: RangeInclusive<Bid>) -> &[T] {
        let slots = encode_bid(*range.start())..=encode_bid(*range.end());
        &self.0[slots]
    }
}

impl<T> IndexMut<RangeInclusive<Bid>> for Array<T> {
    /// Mutable counterpart of the `RangeInclusive<Bid>` index.
    fn index_mut(&mut self, range: RangeInclusive<Bid>) -> &mut [T] {
        let slots = encode_bid(*range.start())..=encode_bid(*range.end());
        &mut self.0[slots]
    }
}
impl<T: Default> Default for Array<T> {
fn default() -> Self {
Self::from_fn(|_| T::default())
}
}
/// Iterator over `(call, &entry)` pairs in slot order, as returned by `Array::iter`.
///
/// The mapping function's references use elided, higher-ranked lifetimes.
/// Keep this exact form so the closures in the `IntoIterator` impls coerce to it.
pub type Iter<'a, T> = core::iter::Map<Enumerate<Values<'a, T>>, fn((usize, &T)) -> (Call, &T)>;
/// Iterator over `(call, &mut entry)` pairs in slot order, as returned by `Array::iter_mut`.
pub type IterMut<'a, T> =
core::iter::Map<Enumerate<ValuesMut<'a, T>>, fn((usize, &mut T)) -> (Call, &mut T)>;
/// Owning iterator over `(call, entry)` pairs in slot order, produced by consuming an `Array`.
pub type IntoIter<T> = core::iter::Map<Enumerate<IntoValues<T>>, fn((usize, T)) -> (Call, T)>;
impl<'a, T> IntoIterator for &'a Array<T> {
    type Item = (Call, &'a T);
    type IntoIter = Iter<'a, T>;

    /// Pairs each shared entry with its call, in slot order.
    fn into_iter(self) -> Self::IntoIter {
        let indexed = self.values().enumerate();
        indexed.map(|(slot, value)| (decode_call(slot), value))
    }
}
impl<'a, T> IntoIterator for &'a mut Array<T> {
    type Item = (Call, &'a mut T);
    type IntoIter = IterMut<'a, T>;

    /// Pairs each mutable entry with its call, in slot order.
    fn into_iter(self) -> Self::IntoIter {
        let indexed = self.values_mut().enumerate();
        indexed.map(|(slot, value)| (decode_call(slot), value))
    }
}
impl<T> IntoIterator for Array<T> {
    type Item = (Call, T);
    type IntoIter = IntoIter<T>;

    /// Consumes the table, pairing each owned entry with its call, in slot order.
    fn into_iter(self) -> Self::IntoIter {
        let indexed = self.into_values().enumerate();
        indexed.map(|(slot, value)| (decode_call(slot), value))
    }
}
/// Raw scores, one `f32` per call, laid out like `Array`.
///
/// `Logits::softmax` turns them into probabilities. A score of negative
/// infinity contributes zero probability.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Logits(pub Array<f32>);
impl Logits {
#[must_use]
pub const fn new() -> Self {
Self(Array([f32::NEG_INFINITY; CALL_VARIANTS]))
}
#[must_use]
pub fn softmax(self) -> Option<Array<f32>> {
let max = self.into_values().fold(f32::NEG_INFINITY, f32::max);
(max > f32::NEG_INFINITY).then(|| {
let exp: [_; CALL_VARIANTS] = core::array::from_fn(|i| (self.0.0[i] - max).exp());
let sum: f32 = exp.iter().sum();
Array(core::array::from_fn(|i| exp[i] / sum))
})
}
}
impl Default for Logits {
fn default() -> Self {
Self::new()
}
}
impl Deref for Logits {
type Target = Array<f32>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Logits {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}