use std::any::TypeId;
use std::collections::HashMap;
use std::hash::Hash;
use solverforge_core::score::Score;
/// Identifies which storage layout an `ExistsKeyState` is currently using.
///
/// Only compiled for tests, where `ExistsKeyState::storage_kind` returns it.
#[cfg(test)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ExistsStorageKind {
/// `HashMap`-backed storage (`ExistsKeyState::Hashed`).
Hashed,
/// `Vec`-backed storage indexed by a `usize` key (`ExistsKeyState::IndexedUsize`).
IndexedUsize,
}
/// `HashMap`-backed bookkeeping for keys of an arbitrary type `K`.
///
/// NOTE(review): what the "a" and "b" sides represent is decided by callers
/// outside this file view; the field docs only describe the stored shape.
pub(super) struct HashedExistsState<K, Sc> {
/// For each key, the bucket of `usize` indices recorded on the "a" side.
a_indices_by_key: HashMap<K, Vec<usize>>,
/// For each key, the running score total accumulated on the "a" side.
a_score_totals_by_key: HashMap<K, Sc>,
/// For each key, the count recorded on the "b" side.
b_key_counts: HashMap<K, usize>,
}
impl<K, Sc> Default for HashedExistsState<K, Sc>
where
Sc: Score,
{
fn default() -> Self {
Self {
a_indices_by_key: HashMap::new(),
a_score_totals_by_key: HashMap::new(),
b_key_counts: HashMap::new(),
}
}
}
/// `Vec`-backed bookkeeping in which the key itself is the vector index.
///
/// The derived `Default` yields three empty vectors. Because it is derived, that
/// impl carries an `Sc: Default` bound.
#[derive(Default)]
pub(super) struct IndexedUsizeExistsState<Sc> {
/// `a_indices_by_key[key]` is the bucket of `usize` indices recorded for `key`.
a_indices_by_key: Vec<Vec<usize>>,
/// `a_score_totals_by_key[key]` is the running score total for `key`.
a_score_totals_by_key: Vec<Sc>,
/// `b_key_counts[key]` is the count recorded for `key`.
b_key_counts: Vec<usize>,
}
/// Per-key bookkeeping stored in one of two layouts.
///
/// `ExistsKeyState::new` selects the variant from the key type `K`.
pub(super) enum ExistsKeyState<K, Sc>
where
Sc: Score,
{
/// General layout: `HashMap`s keyed by `K`.
Hashed(HashedExistsState<K, Sc>),
/// Layout for `usize` keys: `Vec`s indexed directly by the key.
IndexedUsize(IndexedUsizeExistsState<Sc>),
}
/// Describes a bucket entry that changed position as a side effect of
/// `ExistsKeyState::remove_a_index` (which removes via `Vec::swap_remove`).
#[derive(Debug, Clone, Copy)]
pub(super) struct MovedAIndex {
/// The index value stored in the entry that was moved.
pub(super) idx: usize,
/// The position inside the bucket that the moved entry now occupies.
pub(super) bucket_pos: usize,
}
/// Operations on the per-key bookkeeping, implemented once for both layouts.
///
/// Each method matches on the active variant. The `Hashed` arm looks keys up in
/// `HashMap`s; the `IndexedUsize` arm converts the key with `usize_key` and uses
/// it as a direct index into `Vec`s that are grown on demand.
impl<K, Sc> ExistsKeyState<K, Sc>
where
    K: Eq + Hash + Clone + 'static,
    Sc: Score,
{
    /// Creates an empty state, picking the layout from the key type.
    ///
    /// `K == usize` selects the `Vec`-indexed layout; every other key type
    /// selects the `HashMap` layout.
    pub(super) fn new() -> Self {
        if TypeId::of::<K>() == TypeId::of::<usize>() {
            Self::IndexedUsize(IndexedUsizeExistsState::default())
        } else {
            Self::Hashed(HashedExistsState::default())
        }
    }

    /// Reports which layout is active (test-only helper).
    #[cfg(test)]
    pub(super) fn storage_kind(&self) -> ExistsStorageKind {
        match self {
            Self::Hashed(_) => ExistsStorageKind::Hashed,
            Self::IndexedUsize(_) => ExistsStorageKind::IndexedUsize,
        }
    }

    /// Drops every A-side index bucket and every A-side score total.
    ///
    /// The B-side counts are left untouched.
    pub(super) fn clear_a_buckets(&mut self) {
        match self {
            Self::Hashed(state) => {
                state.a_indices_by_key.clear();
                state.a_score_totals_by_key.clear();
            }
            Self::IndexedUsize(state) => {
                state.a_indices_by_key.clear();
                state.a_score_totals_by_key.clear();
            }
        }
    }

    /// Drops every B-side count. The A-side data is left untouched.
    pub(super) fn clear_b_counts(&mut self) {
        match self {
            Self::Hashed(state) => state.b_key_counts.clear(),
            Self::IndexedUsize(state) => state.b_key_counts.clear(),
        }
    }

    /// Appends `idx` to the bucket of `key`, creating the bucket if needed.
    ///
    /// Returns the position of the new entry inside the bucket; that position
    /// is what `remove_a_index` later expects as `bucket_pos`.
    pub(super) fn insert_a_index(&mut self, key: K, idx: usize) -> usize {
        let bucket = match self {
            Self::Hashed(state) => state.a_indices_by_key.entry(key).or_default(),
            Self::IndexedUsize(state) => {
                let slot = usize_key(&key);
                if state.a_indices_by_key.len() <= slot {
                    state.a_indices_by_key.resize_with(slot + 1, Vec::new);
                }
                &mut state.a_indices_by_key[slot]
            }
        };
        let bucket_pos = bucket.len();
        bucket.push(idx);
        bucket_pos
    }

    /// Adds `score` to the running total of `key`.
    ///
    /// A key without a total starts from `Sc::zero()`.
    pub(super) fn add_a_score(&mut self, key: &K, score: Sc) {
        match self {
            Self::Hashed(state) => {
                // Look up by reference first so the key is cloned only when a
                // new entry has to be created, not on every update.
                if let Some(total) = state.a_score_totals_by_key.get_mut(key) {
                    *total = *total + score;
                } else {
                    state
                        .a_score_totals_by_key
                        .insert(key.clone(), Sc::zero() + score);
                }
            }
            Self::IndexedUsize(state) => {
                let slot = usize_key(key);
                if state.a_score_totals_by_key.len() <= slot {
                    state.a_score_totals_by_key.resize(slot + 1, Sc::zero());
                }
                let total = &mut state.a_score_totals_by_key[slot];
                *total = *total + score;
            }
        }
    }

    /// Subtracts `score` from the running total of `key`.
    ///
    /// Does nothing when no total is stored for `key`.
    pub(super) fn subtract_a_score(&mut self, key: &K, score: Sc) {
        let stored = match self {
            Self::Hashed(state) => state.a_score_totals_by_key.get_mut(key),
            Self::IndexedUsize(state) => state.a_score_totals_by_key.get_mut(usize_key(key)),
        };
        if let Some(total) = stored {
            *total = *total - score;
        }
    }

    /// Returns the running total of `key`, or `Sc::zero()` when none is stored.
    pub(super) fn a_score_total(&self, key: &K) -> Sc {
        let stored = match self {
            Self::Hashed(state) => state.a_score_totals_by_key.get(key),
            Self::IndexedUsize(state) => state.a_score_totals_by_key.get(usize_key(key)),
        };
        stored.copied().unwrap_or_else(Sc::zero)
    }

    /// Removes the entry at `bucket_pos` from the bucket of `key`.
    ///
    /// Removal uses `Vec::swap_remove`, so the bucket's last entry may be moved
    /// into `bucket_pos`. When that happens the moved entry is returned so the
    /// caller can update whatever position it has recorded for it.
    ///
    /// Returns `None` when the removed entry was the last one in the bucket or
    /// when no bucket is stored for `key`. The hashed layout also drops the key
    /// once its bucket becomes empty; the indexed layout keeps the empty bucket.
    ///
    /// # Panics
    ///
    /// Panics if the bucket exists and `bucket_pos` is out of range for it. In
    /// debug builds it also panics if the removed entry is not `idx`.
    pub(super) fn remove_a_index(
        &mut self,
        key: &K,
        idx: usize,
        bucket_pos: usize,
    ) -> Option<MovedAIndex> {
        // Swap-removes `bucket[bucket_pos]` and reports the entry (if any) that
        // now sits in the vacated slot. Shared by both storage layouts.
        fn swap_remove_tracked(
            bucket: &mut Vec<usize>,
            idx: usize,
            bucket_pos: usize,
        ) -> Option<MovedAIndex> {
            let removed = bucket.swap_remove(bucket_pos);
            debug_assert_eq!(removed, idx);
            // The slot is occupied again only if a later entry was swapped in.
            bucket.get(bucket_pos).map(|&moved_idx| MovedAIndex {
                idx: moved_idx,
                bucket_pos,
            })
        }

        match self {
            Self::Hashed(state) => {
                let bucket = state.a_indices_by_key.get_mut(key)?;
                let moved = swap_remove_tracked(bucket, idx, bucket_pos);
                if bucket.is_empty() {
                    state.a_indices_by_key.remove(key);
                }
                moved
            }
            Self::IndexedUsize(state) => {
                let bucket = state.a_indices_by_key.get_mut(usize_key(key))?;
                swap_remove_tracked(bucket, idx, bucket_pos)
            }
        }
    }

    /// Adds `count` to the B-side count of `key`, starting from zero if absent.
    pub(super) fn increment_b_count(&mut self, key: &K, count: usize) {
        match self {
            Self::Hashed(state) => {
                // Clone the key only when a new entry has to be created.
                if let Some(existing) = state.b_key_counts.get_mut(key) {
                    *existing += count;
                } else {
                    state.b_key_counts.insert(key.clone(), count);
                }
            }
            Self::IndexedUsize(state) => {
                let slot = usize_key(key);
                if state.b_key_counts.len() <= slot {
                    state.b_key_counts.resize(slot + 1, 0);
                }
                state.b_key_counts[slot] += count;
            }
        }
    }

    /// Subtracts `count` from the B-side count of `key`, saturating at zero.
    ///
    /// Does nothing when no count is stored for `key`. The hashed layout removes
    /// the entry once its count reaches zero; the indexed layout keeps a zero.
    pub(super) fn decrement_b_count(&mut self, key: &K, count: usize) {
        match self {
            Self::Hashed(state) => {
                if let Some(entry) = state.b_key_counts.get_mut(key) {
                    *entry = entry.saturating_sub(count);
                    if *entry == 0 {
                        state.b_key_counts.remove(key);
                    }
                }
            }
            Self::IndexedUsize(state) => {
                if let Some(entry) = state.b_key_counts.get_mut(usize_key(key)) {
                    *entry = entry.saturating_sub(count);
                }
            }
        }
    }

    /// Returns the B-side count of `key`, or `0` when none is stored.
    pub(super) fn b_count(&self, key: &K) -> usize {
        let stored = match self {
            Self::Hashed(state) => state.b_key_counts.get(key),
            Self::IndexedUsize(state) => state.b_key_counts.get(usize_key(key)),
        };
        stored.copied().unwrap_or(0)
    }
}
/// Returns the `usize` value of a key whose type `K` is `usize`.
///
/// Generic code uses this to index `Vec`s directly once it has established
/// that the key type is `usize`. The conversion is a checked downcast, so a
/// mismatched `K` is rejected in every build profile instead of reading the
/// key's memory as if it were a `usize`.
///
/// # Panics
///
/// Panics if `K` is not `usize`.
#[inline]
fn usize_key<K: 'static>(key: &K) -> usize {
    let any_key: &dyn std::any::Any = key;
    *any_key
        .downcast_ref::<usize>()
        .expect("usize_key requires the key type to be usize")
}