use super::map::Map;
use super::{Auction, Call, Hand, IllegalCall, Vulnerability};
use core::ops::{Index, IndexMut};
pub type Classifier = fn(Hand) -> super::array::Logits;
#[derive(Clone)]
pub struct Trie {
children: Map<Box<Self>>,
classify: Option<Classifier>,
}
impl Default for Trie {
fn default() -> Self {
Self::new()
}
}
impl Trie {
#[must_use]
pub const fn new() -> Self {
Self {
children: Map::new(),
classify: None,
}
}
#[must_use]
fn subtrie(&self, auction: &[Call]) -> Option<&Self> {
let mut node = self;
for &call in auction {
node = node.children.get(call)?;
}
Some(node)
}
#[must_use]
pub fn get(&self, auction: &[Call]) -> Option<&Classifier> {
self.subtrie(auction)
.and_then(|node| node.classify.as_ref())
}
#[must_use]
pub fn is_prefix(&self, auction: &[Call]) -> bool {
self.subtrie(auction).is_some()
}
#[must_use]
pub fn longest_prefix<'a>(&self, auction: &'a [Call]) -> Option<(&'a [Call], &Classifier)> {
let mut prefix = self.classify.as_ref().map(|f| (&[][..], f));
let mut node = self;
for (depth, &call) in auction.iter().enumerate() {
node = match node.children.get(call) {
Some(child) => child,
None => break,
};
if let Some(f) = node.classify.as_ref() {
prefix.replace((&auction[..=depth], f));
}
}
prefix
}
pub fn insert(&mut self, auction: &[Call], f: Classifier) -> Option<Classifier> {
let mut node = self;
for &call in auction {
node = node.children.entry(call).get_or_insert_with(Box::default);
}
node.classify.replace(f)
}
#[must_use]
pub fn iter(&'_ self) -> Suffixes<'_> {
self.suffixes(Auction::new())
}
#[must_use]
pub fn suffixes(&'_ self, auction: Auction) -> Suffixes<'_> {
Suffixes::new(self, auction)
}
#[must_use]
pub const fn common_prefixes(&'_ self, auction: Auction) -> CommonPrefixes<'_> {
CommonPrefixes::new(self, auction)
}
}
impl<'a> IntoIterator for &'a Trie {
type Item = (Box<[Call]>, Result<&'a Classifier, IllegalCall>);
type IntoIter = Suffixes<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[derive(Clone, Copy)]
struct StackEntry<'a> {
depth: usize,
call: Call,
node: &'a Trie,
}
fn collect_children(node: &'_ Trie, depth: usize) -> impl Iterator<Item = StackEntry<'_>> {
node.children.iter().map(move |(call, child)| StackEntry {
depth,
call,
node: child,
})
}
#[derive(Clone)]
pub struct Suffixes<'a> {
stack: Vec<StackEntry<'a>>,
auction: Auction,
separator: usize,
value: Option<&'a Classifier>,
}
impl<'a> Suffixes<'a> {
#[must_use]
pub const fn empty() -> Self {
Self {
stack: Vec::new(),
auction: Auction::new(),
separator: 0,
value: None,
}
}
#[must_use]
pub fn new(trie: &'a Trie, auction: Auction) -> Self {
let Some(node) = trie.subtrie(&auction) else {
return Self::empty();
};
Self {
stack: collect_children(node, 0).collect(),
separator: auction.len(),
value: node.classify.as_ref(),
auction,
}
}
}
impl<'a> Iterator for Suffixes<'a> {
type Item = (Box<[Call]>, Result<&'a Classifier, IllegalCall>);
fn next(&mut self) -> Option<Self::Item> {
while self.value.is_none() {
let entry = self.stack.pop()?;
self.stack
.extend(collect_children(entry.node, entry.depth + 1));
self.value = entry.node.classify.as_ref();
self.auction.truncate(self.separator + entry.depth);
if let Err(e) = self.auction.force_push(entry.call) {
return Some((self.auction[self.separator..].into(), Err(e)));
}
}
Some((
self.auction[self.separator..].into(),
Ok(self.value.take()?),
))
}
}
#[derive(Clone)]
pub struct CommonPrefixes<'a> {
trie: &'a Trie,
query: Auction,
depth: usize,
value: Option<&'a Classifier>,
}
impl<'a> CommonPrefixes<'a> {
#[must_use]
pub const fn new(trie: &'a Trie, query: Auction) -> Self {
Self {
trie,
query,
depth: 0,
value: trie.classify.as_ref(),
}
}
}
impl<'a> Iterator for CommonPrefixes<'a> {
type Item = (Box<[Call]>, &'a Classifier);
fn next(&mut self) -> Option<Self::Item> {
while self.value.is_none() {
let &call = self.query.get(self.depth)?;
self.trie = self.trie.children.get(call)?;
self.value = self.trie.classify.as_ref();
self.depth += 1;
}
Some((self.query[..self.depth].into(), self.value.take()?))
}
}
impl Index<Vulnerability> for Trie {
type Output = Self;
fn index(&self, _: Vulnerability) -> &Self {
self
}
}
#[derive(Clone)]
pub struct Forest([Trie; 4]);
impl Index<Vulnerability> for Forest {
type Output = Trie;
fn index(&self, index: Vulnerability) -> &Trie {
&self.0[usize::from(index.bits())]
}
}
impl IndexMut<Vulnerability> for Forest {
fn index_mut(&mut self, index: Vulnerability) -> &mut Trie {
&mut self.0[usize::from(index.bits())]
}
}