use super::{Call, Hand, Map, RelativeVulnerability};
use core::fmt;
use core::iter::FusedIterator;
use core::ops::{Index, IndexMut};
use std::sync::Arc;
/// Scores a hand for the current auction state.
///
/// Any `Fn(Hand, RelativeVulnerability) -> Logits` closure that is
/// `Send + Sync` implements this trait automatically (ignoring `prefixes`).
pub trait Classifier: Send + Sync {
/// Classifies `hand` under relative vulnerability `vul`.
///
/// `prefixes` yields `(prefix, classifier)` pairs for every prefix of some
/// query auction that has a classifier registered in a [`Trie`]. Which trie
/// and auction the caller supplies is not visible here (TODO confirm).
///
/// Returns `super::array::Logits` (presumably per-call scores; confirm in
/// `super::array`).
fn classify(
&self,
hand: Hand,
vul: RelativeVulnerability,
prefixes: CommonPrefixes<'_, '_>,
) -> super::array::Logits;
}
/// Formats a classifier trait object by the address of the object it points
/// to, which identifies that particular classifier instance.
impl fmt::Debug for dyn Classifier {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `self` already points at the classifier. Formatting `&self` would print
        // the address of this stack-local reference instead, which differs on
        // every call. Casting to a thin pointer drops the vtable so only the
        // object's data address is printed.
        let addr = (self as *const Self).cast::<()>();
        write!(f, "Classifier({addr:p})")
    }
}
/// Lets any thread-safe function or closure of `(Hand, RelativeVulnerability)`
/// act as a [`Classifier`]; the prefix iterator is not passed through.
impl<F> Classifier for F
where
    F: Fn(Hand, RelativeVulnerability) -> super::array::Logits + Send + Sync,
{
    fn classify(
        &self,
        hand: Hand,
        vul: RelativeVulnerability,
        _prefixes: CommonPrefixes<'_, '_>,
    ) -> super::array::Logits {
        // Plain closures only see the hand and vulnerability.
        (*self)(hand, vul)
    }
}
/// A trie keyed by sequences of [`Call`]s, with an optional classifier stored
/// at every node. The root node corresponds to the empty auction.
#[derive(Debug, Clone)]
pub struct Trie {
/// Child nodes, keyed by the next call of the auction.
children: Map<Box<Self>>,
/// Classifier registered for the auction ending at this node, if any.
/// Held in an `Arc`, so cloning the trie shares classifiers rather than
/// copying them.
classify: Option<Arc<dyn Classifier>>,
}
impl Default for Trie {
fn default() -> Self {
Self::new()
}
}
impl Trie {
    /// Creates an empty trie with no children and no classifier at the root.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            classify: None,
            children: Map::new(),
        }
    }

    /// Follows `auction` from this node and returns the node it ends at, or
    /// `None` as soon as a call has no matching child.
    #[must_use]
    fn subtrie(&self, auction: &[Call]) -> Option<&Self> {
        let mut current = self;
        for call in auction.iter().copied() {
            let next = current.children.get(call)?;
            current = next;
        }
        Some(current)
    }

    /// Returns the classifier registered for exactly `auction`, or `None` if
    /// that path is absent or has no classifier.
    #[must_use]
    pub fn get(&self, auction: &[Call]) -> Option<&dyn Classifier> {
        let node = self.subtrie(auction)?;
        node.classify.as_deref()
    }

    /// Returns `true` if `auction` is a path in the trie, whether or not a
    /// classifier is registered at its end.
    #[must_use]
    pub fn is_prefix(&self, auction: &[Call]) -> bool {
        self.subtrie(auction).is_some()
    }

    /// Returns the longest prefix of `auction` that has a classifier, paired
    /// with that classifier. The empty prefix counts if the root has one.
    #[must_use]
    pub fn longest_prefix<'a>(&self, auction: &'a [Call]) -> Option<(&'a [Call], &dyn Classifier)> {
        // `common_prefixes` yields classified prefixes from shortest to longest,
        // so the final item is the longest one.
        self.common_prefixes(auction).last()
    }

    /// Registers `f` as the classifier for `auction`, creating any missing
    /// intermediate nodes. Returns the classifier previously registered there.
    pub fn insert(
        &mut self,
        auction: &[Call],
        f: impl Classifier + 'static,
    ) -> Option<Arc<dyn Classifier>> {
        let mut cursor = self;
        for call in auction.iter().copied() {
            cursor = cursor.children.entry(call).get_or_insert_with(Box::default);
        }
        let fresh: Arc<dyn Classifier> = Arc::new(f);
        cursor.classify.replace(fresh)
    }

    /// Iterates over every `(auction, classifier)` pair stored in the trie.
    #[must_use]
    pub fn iter(&self) -> Suffixes<'_> {
        Suffixes::new(self, &[])
    }

    /// Iterates depth-first over the classifiers in the subtrie at `auction`.
    /// Each item's path is relative to `auction`. The empty suffix comes first
    /// if `auction` itself has a classifier. Yields nothing if `auction` is not
    /// a path in the trie.
    #[must_use]
    pub fn suffixes(&self, auction: &[Call]) -> Suffixes<'_> {
        Suffixes::new(self, auction)
    }

    /// Iterates over the prefixes of `query` that have a classifier, shortest first.
    #[must_use]
    pub fn common_prefixes<'q>(&self, query: &'q [Call]) -> CommonPrefixes<'_, 'q> {
        CommonPrefixes::new(self, query)
    }
}
impl<'a> IntoIterator for &'a Trie {
type Item = (Box<[Call]>, &'a dyn Classifier);
type IntoIter = Suffixes<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// A node waiting to be visited by the depth-first traversal in [`Suffixes`].
#[derive(Debug, Clone, Copy)]
struct StackEntry<'a> {
/// Depth of `node` below the traversal root; the root's direct children are 0.
depth: usize,
/// The call on the edge leading into `node`.
call: Call,
/// The child node reached through `call`.
node: &'a Trie,
}
/// Produces one [`StackEntry`] per child of `node`, each tagged with `depth`.
fn collect_children(node: &'_ Trie, depth: usize) -> impl Iterator<Item = StackEntry<'_>> {
    node.children
        .iter()
        .map(move |(edge, subtree)| StackEntry {
            node: subtree,
            call: edge,
            depth,
        })
}
/// Depth-first iterator over the `(suffix, classifier)` pairs of a subtrie.
/// Built by [`Trie::suffixes`] and [`Trie::iter`].
#[derive(Clone)]
pub struct Suffixes<'a> {
/// Nodes still to be visited.
stack: Vec<StackEntry<'a>>,
/// The starting auction (its first `separator` calls), followed by the path
/// to the most recently visited node.
auction: Vec<Call>,
/// Length of the starting auction; yielded suffixes begin at this index.
separator: usize,
/// Classifier of the most recently visited node, if it has not been yielded yet.
value: Option<&'a dyn Classifier>,
}
impl<'a> Suffixes<'a> {
#[must_use]
pub const fn empty() -> Self {
Self {
stack: Vec::new(),
auction: Vec::new(),
separator: 0,
value: None,
}
}
#[must_use]
pub fn new(trie: &'a Trie, auction: &[Call]) -> Self {
let Some(node) = trie.subtrie(auction) else {
return Self::empty();
};
Self {
stack: collect_children(node, 0).collect(),
separator: auction.len(),
value: node.classify.as_deref(),
auction: auction.to_vec(),
}
}
}
impl fmt::Debug for Suffixes<'_> {
    /// Summarizes the traversal state. The classifier is opaque, so only
    /// whether one is pending is shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pending = self.stack.len();
        let has_value = self.value.is_some();
        let mut builder = f.debug_struct("Suffixes");
        builder.field("auction", &self.auction);
        builder.field("separator", &self.separator);
        builder.field("pending", &pending);
        builder.field("has_value", &has_value);
        builder.finish()
    }
}
impl<'a> Iterator for Suffixes<'a> {
type Item = (Box<[Call]>, &'a dyn Classifier);
fn next(&mut self) -> Option<Self::Item> {
while self.value.is_none() {
let entry = self.stack.pop()?;
self.stack
.extend(collect_children(entry.node, entry.depth + 1));
self.value = entry.node.classify.as_deref();
self.auction.truncate(self.separator + entry.depth);
self.auction.push(entry.call);
}
Some((self.auction[self.separator..].into(), self.value.take()?))
}
}
impl FusedIterator for Suffixes<'_> {}
/// Iterator over the prefixes of a query auction that have a classifier in a
/// [`Trie`], shortest first. Built by [`Trie::common_prefixes`].
#[derive(Clone)]
pub struct CommonPrefixes<'trie, 'q> {
/// Node reached after following `query[..depth]` from the root.
trie: &'trie Trie,
/// The auction whose prefixes are being matched.
query: &'q [Call],
/// Number of calls of `query` consumed so far.
depth: usize,
/// Classifier of `trie`, if it has not been yielded yet.
value: Option<&'trie dyn Classifier>,
}
impl<'trie, 'q> CommonPrefixes<'trie, 'q> {
    /// Starts matching `query` at the root of `trie`. The root's classifier, if
    /// any, is the first item (paired with the empty prefix).
    #[must_use]
    pub fn new(trie: &'trie Trie, query: &'q [Call]) -> Self {
        let value = trie.classify.as_deref();
        Self {
            depth: 0,
            query,
            trie,
            value,
        }
    }
}
impl fmt::Debug for CommonPrefixes<'_, '_> {
    /// Shows the query, how far it has been matched, and whether a classifier
    /// is pending. The trie node itself is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_value = self.value.is_some();
        let mut builder = f.debug_struct("CommonPrefixes");
        builder.field("query", &self.query);
        builder.field("depth", &self.depth);
        builder.field("has_value", &has_value);
        builder.finish()
    }
}
impl<'trie, 'q> Iterator for CommonPrefixes<'trie, 'q> {
    type Item = (&'q [Call], &'trie dyn Classifier);

    /// Advances call by call through the trie until it reaches a node with a
    /// classifier, then yields the prefix consumed so far. Stops when the query
    /// runs out or leaves the trie.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if let Some(classifier) = self.value.take() {
                return Some((&self.query[..self.depth], classifier));
            }
            let call = *self.query.get(self.depth)?;
            let child = self.trie.children.get(call)?;
            self.trie = child;
            self.value = self.trie.classify.as_deref();
            self.depth += 1;
        }
    }
}

// A failed lookup leaves the state unchanged, so later calls fail the same way.
impl FusedIterator for CommonPrefixes<'_, '_> {}
/// One [`Trie`] per [`RelativeVulnerability`], accessed through `Index`/`IndexMut`
/// at slot `usize::from(vul.bits())`.
#[derive(Clone, Debug)]
pub struct Forest([Trie; 4]);
impl Forest {
#[must_use]
pub const fn new() -> Self {
Self([Trie::new(), Trie::new(), Trie::new(), Trie::new()])
}
#[must_use]
pub fn from_fn(mut f: impl FnMut(RelativeVulnerability) -> Trie) -> Self {
Self([
f(RelativeVulnerability::NONE),
f(RelativeVulnerability::WE),
f(RelativeVulnerability::THEY),
f(RelativeVulnerability::ALL),
])
}
}
impl Default for Forest {
fn default() -> Self {
Self::new()
}
}
impl Index<RelativeVulnerability> for Forest {
    type Output = Trie;

    /// Returns the trie stored in slot `usize::from(index.bits())`.
    fn index(&self, index: RelativeVulnerability) -> &Trie {
        let slot = usize::from(index.bits());
        &self.0[slot]
    }
}
impl IndexMut<RelativeVulnerability> for Forest {
    /// Mutable access to the trie stored in slot `usize::from(index.bits())`.
    fn index_mut(&mut self, index: RelativeVulnerability) -> &mut Trie {
        let slot = usize::from(index.bits());
        &mut self.0[slot]
    }
}