use super::{Adjective, Noun, Verb, Word};
use rand::seq::IndexedRandom;
use serde::{Deserialize, Serialize};
/// A word category that [`Dictionary`] keeps in a dedicated list.
///
/// Each implementor names the `Vec` field of the dictionary that stores it.
/// This lets generic methods such as `Dictionary::add_word::<T>` and
/// `Dictionary::get_all::<T>` pick the right storage from the type parameter alone.
pub trait DictionaryWordType
where
    Self: Word + 'static,
{
    /// Borrows the dictionary's list of words of this type.
    fn get_words(dict: &Dictionary) -> &Vec<Self>;

    /// Mutably borrows the dictionary's list of words of this type.
    fn get_words_mut(dict: &mut Dictionary) -> &mut Vec<Self>;
}
impl DictionaryWordType for Noun {
fn get_words(dictionary: &Dictionary) -> &Vec<Self> {
&dictionary.nouns
}
fn get_words_mut(dictionary: &mut Dictionary) -> &mut Vec<Self> {
&mut dictionary.nouns
}
}
impl DictionaryWordType for Verb {
fn get_words(dictionary: &Dictionary) -> &Vec<Self> {
&dictionary.verbs
}
fn get_words_mut(dictionary: &mut Dictionary) -> &mut Vec<Self> {
&mut dictionary.verbs
}
}
impl DictionaryWordType for Adjective {
fn get_words(dictionary: &Dictionary) -> &Vec<Self> {
&dictionary.adjectives
}
fn get_words_mut(dictionary: &mut Dictionary) -> &mut Vec<Self> {
&mut dictionary.adjectives
}
}
/// Random selection of words from some source of words.
pub trait DictionarySampling {
    /// Picks a random word of type `T` using `rng`.
    ///
    /// Returns `None` when the source has no word of that type to offer.
    fn choose<'a, T: DictionaryWordType + 'static>(
        &'a self,
        rng: &mut impl rand::Rng,
    ) -> Option<&'a T>;

    /// Like [`DictionarySampling::choose`], but only words for which `filter`
    /// returns `true` are candidates.
    ///
    /// Returns `None` when no candidate remains after filtering.
    fn choose_filtered<T, F>(&self, filter: F, rng: &mut impl rand::Rng) -> Option<&T>
    where
        T: DictionaryWordType,
        F: Fn(&T) -> bool;
}
/// A collection of words grouped by part of speech.
///
/// Every word type has its own list. [`DictionaryWordType`] tells generic
/// methods which list to operate on. The struct is (de)serializable with serde.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct Dictionary {
    // Field order is significant: it determines serde field order and Debug output.
    /// Nouns, in the order they were added.
    nouns: Vec<Noun>,
    /// Verbs, in the order they were added.
    verbs: Vec<Verb>,
    /// Adjectives, in the order they were added.
    adjectives: Vec<Adjective>,
}
/// An ordered list of [`Dictionary`] layers that can be sampled as one source.
///
/// When sampling, layers are consulted front to back. The first layer that
/// yields a word wins, so earlier layers take precedence over later ones.
#[derive(Clone, Debug, Default)]
pub struct DictionaryStack {
    /// Layers, highest priority first.
    dictionaries: Vec<Dictionary>,
}
impl Dictionary {
    /// Creates an empty dictionary with no nouns, verbs, or adjectives.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `word` to the list for its word type.
    pub fn add_word<T: DictionaryWordType>(&mut self, word: T) {
        T::get_words_mut(self).push(word);
    }

    /// Appends every word yielded by `words` to the list for their word type.
    ///
    /// Accepts any `IntoIterator<Item = T>`: a `Vec`, an array, or an iterator
    /// chain. Callers that pass a `Vec<T>` keep working unchanged.
    pub fn add_words<T: DictionaryWordType>(&mut self, words: impl IntoIterator<Item = T>) {
        T::get_words_mut(self).extend(words);
    }

    /// Returns references to every stored word of type `T`, in insertion order.
    #[must_use]
    pub fn get_all<T: DictionaryWordType>(&self) -> Vec<&T> {
        T::get_words(self).iter().collect()
    }

    /// Returns references to the stored words of type `T` for which `filter`
    /// returns `true`, in insertion order.
    #[must_use]
    pub fn get_filtered<T: DictionaryWordType, F>(&self, filter: F) -> Vec<&T>
    where
        F: Fn(&T) -> bool,
    {
        T::get_words(self)
            .iter()
            .filter(|&word| filter(word))
            .collect()
    }

    /// Appends clones of all of `other`'s words to `self`, per word type.
    ///
    /// Existing words stay first and `other`'s words follow in their original
    /// order. Duplicates are not removed.
    pub fn join(&mut self, other: &Dictionary) {
        self.nouns.extend_from_slice(&other.nouns);
        self.verbs.extend_from_slice(&other.verbs);
        self.adjectives.extend_from_slice(&other.adjectives);
    }

    /// Returns a new dictionary holding `a`'s words followed by `b`'s.
    ///
    /// Neither input is modified.
    #[must_use]
    pub fn combine(a: &Dictionary, b: &Dictionary) -> Dictionary {
        let mut combined = a.clone();
        combined.join(b);
        combined
    }
}
impl DictionarySampling for Dictionary {
    /// Picks a word of type `T` uniformly at random.
    ///
    /// Returns `None` if the dictionary holds no word of that type.
    fn choose<'a, T: DictionaryWordType + 'static>(
        &'a self,
        rng: &mut impl rand::Rng,
    ) -> Option<&'a T> {
        T::get_words(self).choose(rng)
    }

    /// Picks uniformly at random among the words of type `T` accepted by `filter`.
    ///
    /// Returns `None` if no word passes the filter.
    fn choose_filtered<T: DictionaryWordType, F>(
        &self,
        filter: F,
        rng: &mut impl rand::Rng,
    ) -> Option<&T>
    where
        F: Fn(&T) -> bool,
    {
        // Collect the candidates first. `IndexedRandom::choose` needs a slice
        // so it can pick an index uniformly.
        let candidates: Vec<&T> = T::get_words(self)
            .iter()
            .filter(|&word| filter(word))
            .collect();
        // `choose` yields `Option<&&T>`; `.copied()` strips the extra reference.
        candidates.choose(rng).copied()
    }
}
impl DictionaryStack {
    /// Creates a stack with no layers. Sampling from it always yields `None`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style: appends `dictionary` as the last (lowest-priority) layer
    /// and returns the stack.
    ///
    /// The stack is taken by value. Ignoring the return value drops both the
    /// stack and the added dictionary, hence `#[must_use]`.
    #[must_use]
    pub fn with_dictionary(mut self, dictionary: Dictionary) -> Self {
        self.dictionaries.push(dictionary);
        self
    }
}
impl From<Vec<Dictionary>> for DictionaryStack {
    /// Uses `dictionaries` as the layers, keeping their order (first = highest priority).
    fn from(dictionaries: Vec<Dictionary>) -> Self {
        Self { dictionaries }
    }
}

impl From<Dictionary> for DictionaryStack {
    /// Wraps a single dictionary as a one-layer stack.
    fn from(dictionary: Dictionary) -> Self {
        vec![dictionary].into()
    }
}
impl DictionarySampling for DictionaryStack {
    /// Asks each layer in order for a word of type `T` and returns the first `Some`.
    ///
    /// Later layers are consulted only when every earlier layer returned `None`,
    /// so the result is not uniform across the whole stack.
    fn choose<'a, T: DictionaryWordType + 'static>(
        &'a self,
        rng: &mut impl rand::Rng,
    ) -> Option<&'a T> {
        self.dictionaries
            .iter()
            .find_map(|layer| layer.choose::<T>(rng))
    }

    /// Asks each layer in order for a word of type `T` that satisfies `filter`,
    /// and returns the first `Some`.
    fn choose_filtered<T: DictionaryWordType, F>(
        &self,
        filter: F,
        rng: &mut impl rand::Rng,
    ) -> Option<&T>
    where
        F: Fn(&T) -> bool,
    {
        // Pass the filter by reference (`&F` is also `Fn`) so it can be
        // reused for every layer.
        self.dictionaries
            .iter()
            .find_map(|layer| layer.choose_filtered::<T, _>(&filter, rng))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::language::Noun;
    use crate::language::Verb;

    #[test]
    fn dictionary_test() {
        let mut dict = Dictionary::new();
        dict.add_word(Noun::new_proper("bilbo"));
        dict.add_word(Noun::new_common("ring"));
        let nouns = dict.get_all::<Noun>();
        assert_eq!(nouns.len(), 2);
        assert_eq!(nouns[0].as_ref(), "Bilbo");
        assert_eq!(nouns[1].as_ref(), "ring");
    }

    #[test]
    fn dictionary_filter_test() {
        let mut dict = Dictionary::new();
        dict.add_word(Noun::new_proper("bilbo"));
        dict.add_word(Noun::new_common("ring"));
        dict.add_word(Noun::new_collective("fellowship"));
        let common_nouns = dict.get_filtered::<Noun, _>(|n| n.is_common());
        assert_eq!(common_nouns.len(), 1);
        assert_eq!(common_nouns[0].as_ref(), "ring");
        let proper_nouns = dict.get_filtered::<Noun, _>(|n| n.is_proper());
        assert_eq!(proper_nouns.len(), 1);
        assert_eq!(proper_nouns[0].as_ref(), "Bilbo");
        let collective_nouns = dict.get_filtered::<Noun, _>(|n| n.is_collective());
        assert_eq!(collective_nouns.len(), 1);
        assert_eq!(collective_nouns[0].as_ref(), "fellowship");
    }

    #[test]
    fn dictionary_filter_no_match_test() {
        let mut rng = rand::rng();
        let mut dict = Dictionary::new();
        dict.add_word(Noun::new_common("ring"));
        assert!(dict.get_filtered::<Noun, _>(|n| n.is_proper()).is_empty());
        assert!(dict
            .choose_filtered::<Noun, _>(|n| n.is_proper(), &mut rng)
            .is_none());
    }

    #[test]
    fn dictionary_random_choice_test() {
        let mut dict = Dictionary::new();
        dict.add_word(Noun::new_proper("Aragorn"));
        dict.add_word(Noun::new_common("king"));
        dict.add_word(Verb::new_regular("walk"));
        let mut rng = rand::rng();
        let random_noun = dict.choose::<Noun>(&mut rng);
        assert!(random_noun.is_some());
        let random_proper_noun = dict.choose_filtered::<Noun, _>(|n| n.is_proper(), &mut rng);
        assert!(random_proper_noun.is_some());
        assert_eq!(random_proper_noun.unwrap().as_ref(), "Aragorn");
    }

    #[test]
    fn empty_dictionary_and_stack_yield_none() {
        let mut rng = rand::rng();
        let dict = Dictionary::new();
        assert!(dict.choose::<Noun>(&mut rng).is_none());
        assert!(dict.choose_filtered::<Verb, _>(|_| true, &mut rng).is_none());
        assert!(dict.get_all::<Adjective>().is_empty());
        let stack = DictionaryStack::new();
        assert!(stack.choose::<Adjective>(&mut rng).is_none());
        assert!(stack.choose_filtered::<Noun, _>(|_| true, &mut rng).is_none());
    }

    #[test]
    fn dictionary_clone_test() {
        let mut original_dict = Dictionary::new();
        original_dict.add_word(Noun::new_proper("Frodo"));
        original_dict.add_word(Verb::new_regular("run"));
        original_dict.add_word(Adjective::new_regular("brave"));
        let cloned_dict = original_dict.clone();
        let original_nouns = original_dict.get_all::<Noun>();
        let cloned_nouns = cloned_dict.get_all::<Noun>();
        assert_eq!(original_nouns.len(), cloned_nouns.len());
        assert_eq!(original_nouns[0].as_ref(), cloned_nouns[0].as_ref());
        let original_verbs = original_dict.get_all::<Verb>();
        let cloned_verbs = cloned_dict.get_all::<Verb>();
        assert_eq!(original_verbs.len(), cloned_verbs.len());
        assert_eq!(original_verbs[0].as_ref(), cloned_verbs[0].as_ref());
        let original_adjectives = original_dict.get_all::<Adjective>();
        let cloned_adjectives = cloned_dict.get_all::<Adjective>();
        assert_eq!(original_adjectives.len(), cloned_adjectives.len());
        assert_eq!(
            original_adjectives[0].as_ref(),
            cloned_adjectives[0].as_ref()
        );
    }

    #[test]
    fn dictionary_join_test() {
        let mut dict1 = Dictionary::new();
        dict1.add_word(Noun::new_proper("Frodo"));
        dict1.add_word(Verb::new_regular("carry"));
        let mut dict2 = Dictionary::new();
        dict2.add_word(Noun::new_common("ring"));
        dict2.add_word(Adjective::new_regular("heavy"));
        dict1.join(&dict2);
        let nouns = dict1.get_all::<Noun>();
        assert_eq!(nouns.len(), 2);
        assert_eq!(nouns[0].as_ref(), "Frodo");
        assert_eq!(nouns[1].as_ref(), "ring");
        assert_eq!(dict1.get_all::<Verb>().len(), 1);
        assert_eq!(dict1.get_all::<Adjective>().len(), 1);
    }

    #[test]
    fn dictionary_combine_test() {
        let mut a = Dictionary::new();
        a.add_word(Noun::new_proper("Gandalf"));
        let mut b = Dictionary::new();
        b.add_word(Noun::new_common("staff"));
        b.add_word(Verb::new_regular("run"));
        let combined = Dictionary::combine(&a, &b);
        let nouns = combined.get_all::<Noun>();
        assert_eq!(nouns.len(), 2);
        assert_eq!(nouns[0].as_ref(), "Gandalf");
        assert_eq!(nouns[1].as_ref(), "staff");
        assert_eq!(combined.get_all::<Verb>().len(), 1);
        // `combine` must leave both inputs untouched.
        assert_eq!(a.get_all::<Noun>().len(), 1);
        assert!(a.get_all::<Verb>().is_empty());
        assert_eq!(b.get_all::<Noun>().len(), 1);
    }

    #[test]
    fn dictionary_stack_test() {
        let mut rng = rand::rng();
        let mut dict1 = Dictionary::new();
        dict1.add_word(Verb::new_regular("run"));
        let mut dict2 = Dictionary::new();
        dict2.add_word(Noun::new_proper("Sauron"));
        dict2.add_word(Verb::new_regular("sit"));
        let stack = DictionaryStack::from(vec![dict1, dict2]);
        let noun = stack.choose::<Noun>(&mut rng);
        assert!(noun.is_some(), "Should find a noun in the second dictionary");
        assert_eq!(noun.unwrap().as_ref(), "Sauron");
        let verb = stack.choose::<Verb>(&mut rng);
        assert!(verb.is_some(), "Should find a verb in the first dictionary");
        assert_eq!(verb.unwrap().as_ref(), "run");
    }

    #[test]
    fn dictionary_stack_builder_and_filter_test() {
        let mut rng = rand::rng();
        let mut dict1 = Dictionary::new();
        dict1.add_word(Noun::new_common("ring"));
        let mut dict2 = Dictionary::new();
        dict2.add_word(Noun::new_proper("Sauron"));
        let stack = DictionaryStack::new()
            .with_dictionary(dict1)
            .with_dictionary(dict2);
        // The first layer has a noun, so it wins for an unfiltered choice.
        let noun = stack
            .choose::<Noun>(&mut rng)
            .expect("first layer holds a noun");
        assert_eq!(noun.as_ref(), "ring");
        // No proper noun in the first layer, so the filter falls through to the second.
        let proper = stack
            .choose_filtered::<Noun, _>(|n| n.is_proper(), &mut rng)
            .expect("second layer holds a proper noun");
        assert_eq!(proper.as_ref(), "Sauron");
    }

    #[test]
    fn dictionary_stack_from_single_test() {
        let mut rng = rand::rng();
        let mut dict = Dictionary::new();
        dict.add_word(Verb::new_regular("run"));
        let stack = DictionaryStack::from(dict);
        let verb = stack
            .choose::<Verb>(&mut rng)
            .expect("single layer holds a verb");
        assert_eq!(verb.as_ref(), "run");
        assert!(stack.choose::<Noun>(&mut rng).is_none());
    }
}