use hashbrown::HashMap;
use rand::Rng;
use rand_distr::{Distribution, weighted::WeightedAliasIndex};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::token::Token;
/// A weighted distribution over tokens, sampled in proportion to how often
/// each token was recorded by a [`TokenDistributionBuilder`].
///
/// Construct one with [`TokenDistribution::builder`]. When the `serde`
/// feature is enabled, this type also derives `Serialize` and `Deserialize`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TokenDistribution {
    /// Alias-method sampler; each sampled value is an index into `choices`.
    dist: WeightedAliasIndex<u64>,
    /// Candidate tokens; the token at position `i` carries the `i`-th weight of `dist`.
    choices: Vec<Token>,
}
impl TokenDistribution {
    /// Starts an empty [`TokenDistributionBuilder`] for recording tokens.
    pub fn builder() -> TokenDistributionBuilder {
        TokenDistributionBuilder::default()
    }

    /// Draws one token, with probability proportional to its recorded count.
    ///
    /// The result borrows from `self`; clone it if an owned token is needed.
    pub fn get_random_token(&self, rng: &mut impl Rng) -> &Token {
        let picked = self.dist.sample(rng);
        &self.choices[picked]
    }
}
/// Accumulates per-token occurrence counts and turns them into a
/// [`TokenDistribution`] via [`TokenDistributionBuilder::build`].
///
/// When the `serde` feature is enabled, this type also derives `Serialize`
/// and `Deserialize`.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct TokenDistributionBuilder {
    /// Number of times each distinct token has been passed to `add_token`.
    map: HashMap<String, u64>,
}
impl TokenDistributionBuilder {
    /// Creates an empty builder with no recorded tokens.
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Consumes the builder and produces a [`TokenDistribution`] in which each
    /// token's sampling weight equals the number of times it was passed to
    /// [`add_token`](Self::add_token).
    ///
    /// Tokens are sorted before the sampler is built. `HashMap` iteration order
    /// is unspecified and may differ between builders with equal contents
    /// (e.g. one restored via serde) or between runs. Without sorting, the same
    /// counts sampled with the same seeded RNG could produce different tokens.
    ///
    /// # Panics
    ///
    /// Panics if no tokens have been added. Also panics if
    /// `WeightedAliasIndex::new` rejects the counts; see its documentation for
    /// the limits it enforces.
    pub fn build(self) -> TokenDistribution {
        assert!(
            !self.map.is_empty(),
            "cannot build a TokenDistribution from an empty builder; add at least one token"
        );

        let mut entries: Vec<(String, u64)> = self.map.into_iter().collect();
        // Keys are unique, so an unstable sort still gives a fully deterministic order.
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        let (choices, occurrences): (Vec<String>, Vec<u64>) = entries.into_iter().unzip();

        TokenDistribution {
            // `occurrences[i]` is the weight of `choices[i]`.
            dist: WeightedAliasIndex::new(occurrences)
                .expect("failed to create weighted alias index"),
            choices,
        }
    }

    /// Records one occurrence of `token`, incrementing its count.
    pub fn add_token(&mut self, token: &str) {
        // Look up by `&str` first so repeated tokens do not allocate; only the
        // first occurrence of a token pays for an owned `String` key.
        if let Some(count) = self.map.get_mut(token) {
            *count += 1;
        } else {
            self.map.insert(token.to_string(), 1);
        }
    }
}
impl Default for TokenDistributionBuilder {
fn default() -> Self {
Self::new()
}
}