use std::collections::HashMap;
use std::fmt;
use std::sync::Mutex;
use itertools::{iproduct, join};
use linked_hash_map::LinkedHashMap;
/// One production of a grammar: `left_symbol` derives any of the alternatives in `right_symbol`.
///
/// `right_symbol` lists alternatives separated by `" | "`. [`CYK::parse`] only ever matches an
/// alternative that is either a single terminal (e.g. `"verb"`) or exactly two symbols joined by
/// one space (e.g. `"Verb Noun"`); alternatives of any other shape never match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GrammarRule<'symbol> {
    /// The symbol being defined (left-hand side of the rule).
    left_symbol: &'symbol str,
    /// The `" | "`-separated alternatives (right-hand side of the rule).
    right_symbol: &'symbol str,
}

impl<'symbol> GrammarRule<'symbol> {
    /// Creates a rule from its left-hand side and its `" | "`-separated alternatives.
    ///
    /// The fields are private, so this is the only way for code outside this module
    /// (e.g. a downstream [`Grammar`] implementation) to build a rule.
    pub fn new(left_symbol: &'symbol str, right_symbol: &'symbol str) -> Self {
        Self {
            left_symbol,
            right_symbol,
        }
    }

    /// The rule's left-hand side.
    pub fn left_symbol(&self) -> &'symbol str {
        self.left_symbol
    }

    /// The rule's right-hand side, with alternatives still joined by `" | "`.
    pub fn right_symbol(&self) -> &'symbol str {
        self.right_symbol
    }
}
/// Anything that can describe itself as a flat list of [`GrammarRule`]s for a [`CYK`] parser.
pub trait Grammar<'g> {
    /// Produces the complete rule list; [`CYK::new`] calls this once and keeps the result.
    fn convert(&self) -> Vec<GrammarRule<'g>>;
}
/// Maps input words to the terminal symbols that appear in grammar rule alternatives.
pub trait WordBank {
    /// Returns the terminal for `word` (e.g. `"verb"` for a verb).
    ///
    /// `CYK::parse` calls this once per whitespace-separated input word and then searches the
    /// rule alternatives for the returned terminal. There is no separate "unknown word" signal,
    /// so implementations presumably return a terminal that no rule mentions for words they do
    /// not know. NOTE(review): confirm this convention with implementers.
    fn lookup(&self, word: &str) -> &str;
}
/// A CYK parser: a fixed rule list plus the word bank used to turn input words into terminals.
pub struct CYK<'grammar, W> {
    /// Rules obtained from [`Grammar::convert`] when the parser was constructed.
    grammar_rules: Vec<GrammarRule<'grammar>>,
    /// Source of the terminal symbol for each input word.
    word_bank: W,
}
/// Coordinates of one CYK table cell: the span of input words `x..=y` (0-based, inclusive).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct MatrixIndicator {
    /// Index of the first word in the span.
    x: usize,
    /// Index of the last word in the span (inclusive).
    y: usize,
}

impl fmt::Display for MatrixIndicator {
    /// Writes the raw 0-based coordinates as `(x, y)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let MatrixIndicator { x, y } = self;
        write!(f, "({}, {})", x, y)
    }
}
/// The CYK table produced by [`CYK::parse`].
///
/// Each cell, keyed by a [`MatrixIndicator`] `(x, y)`, holds the space-separated symbols found
/// for input words `x..=y`.
#[derive(Clone, Debug)]
pub struct MatrixResult {
    /// Table cells keyed by (first word, last word).
    map: HashMap<MatrixIndicator, String>,
    /// Contents of the cell covering the whole input, once recorded; `None` until then.
    final_res: Option<String>,
    /// Number of words in the parsed input; 0 until recorded.
    num_words: usize,
}

impl MatrixResult {
    /// Creates an empty table: no cells, no final result, zero words.
    fn new() -> Self {
        MatrixResult {
            num_words: 0,
            final_res: None,
            map: HashMap::new(),
        }
    }

    /// Returns a copy of the top-level cell's contents, or `None` if none was recorded.
    pub fn get_final(&self) -> Option<String> {
        self.final_res.as_ref().cloned()
    }

    /// Records the contents of the top-level cell, replacing any earlier value.
    fn set_final(&mut self, value: String) {
        self.final_res = Some(value);
    }

    /// Stores `symbols` in `cell`, overwriting whatever the cell held before.
    fn insert(&mut self, cell: MatrixIndicator, symbols: String) {
        self.map.insert(cell, symbols);
    }

    /// Records how many words the parsed input contained.
    fn set_num_words(&mut self, count: usize) {
        self.num_words = count;
    }

    /// Number of words in the parsed input (0 for a table that was never filled).
    pub fn get_num_words(&self) -> usize {
        self.num_words
    }
}
impl fmt::Display for MatrixResult {
    /// Renders the table one span length per line.
    ///
    /// Line `LN# n:` lists the cells that cover `n` consecutive words, each written as
    /// `\t(first, last):symbols` with 1-based word positions. The first line shows every word,
    /// even one with no symbols. Later lines omit cells with no symbols. Every line ends with
    /// a newline. A table with no words renders as `No result calculated.`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if self.num_words == 0 {
            return write!(f, "No result calculated.");
        }
        for offset in 0..self.num_words {
            write!(f, "LN# {}:", offset + 1)?;
            for x in 0..self.num_words - offset {
                let y = x + offset;
                // Show a missing cell like an empty one instead of panicking mid-format.
                let entry = self
                    .map
                    .get(&MatrixIndicator { x, y })
                    .map(String::as_str)
                    .unwrap_or("");
                if offset > 0 && entry.is_empty() {
                    continue;
                }
                // 1-based on every line, matching the `LN#` numbering.
                write!(f, "\t({}, {}):{}", x + 1, y + 1, entry)?;
            }
            writeln!(f)?;
        }
        Ok(())
    }
}
lazy_static::lazy_static! {
    /// Process-wide cache of parse results used by `CYK::memoized_parse`.
    ///
    /// Keyed only by the raw input string and kept in insertion order. `memoized_parse` trims it
    /// back to 100 entries after inserting; the initial capacity of 101 presumably leaves room
    /// for that one extra entry.
    static ref MEMO: Mutex<LinkedHashMap<&'static str, MatrixResult>> = {
        let entries = LinkedHashMap::with_capacity(101);
        Mutex::new(entries)
    };
}
fn vec_production(str1: &str, str2: &str) -> Vec<String> {
iproduct!(
str1.split(" ").collect::<Vec<&str>>(),
str2.split(" ").collect::<Vec<&str>>())
.map(|vals| {
join(&[vals.0, vals.1], " ")
})
.collect::<Vec<String>>()
}
impl<'grammar, W> CYK<'grammar, W>
where
    W: WordBank,
{
    /// Builds a parser from a grammar and a word bank.
    ///
    /// The grammar is converted to its rule list once, here. The word bank is stored and
    /// queried on every parse.
    pub fn new<G>(rules: G, word_bank: W) -> Self
    where
        G: Grammar<'grammar>,
    {
        Self {
            grammar_rules: rules.convert(),
            word_bank,
        }
    }

    /// Returns every left-hand symbol that has `rhs` as one of its `" | "`-separated
    /// alternatives, in rule order. A symbol appears once per matching alternative.
    fn producers_of(&self, rhs: &str) -> Vec<&'grammar str> {
        let mut producers = Vec::new();
        for rule in &self.grammar_rules {
            for alternative in rule.right_symbol.split(" | ") {
                if alternative == rhs {
                    producers.push(rule.left_symbol);
                }
            }
        }
        producers
    }

    /// Returns the left-hand symbols that produce `terminal`, joined by single spaces.
    /// Returns an empty string when no rule matches.
    fn find_terminal_assign(&self, terminal: &str) -> String {
        join(self.producers_of(terminal), " ")
    }

    /// Runs the CYK algorithm over `input` and returns the filled table.
    ///
    /// `input` is split on whitespace. Each word is mapped to a terminal through the word bank
    /// and then to the symbols that produce it. Longer spans are filled bottom-up: for every
    /// split of a span into two non-empty halves, each pairing of a left-half symbol with a
    /// right-half symbol is matched against the rule alternatives. A cell stores its distinct
    /// symbols, space-separated, in order of discovery.
    ///
    /// For empty input the table has no words and [`MatrixResult::get_final`] returns `None`.
    /// Otherwise it returns the top cell's contents, which is an empty string if nothing
    /// derives the whole input.
    pub fn parse<'word>(&self, input: &'word str) -> MatrixResult {
        let mut result = MatrixResult::new();
        let words: Vec<&str> = input.split_whitespace().collect();
        let num_words = words.len();
        result.set_num_words(num_words);
        // Nothing to fill; returning here also avoids underflowing `num_words - 1` below.
        if num_words == 0 {
            return result;
        }

        // Single-word spans (the diagonal) come straight from the word bank.
        for (pos, word) in words.iter().enumerate() {
            let terminal = self.word_bank.lookup(word);
            result.insert(
                MatrixIndicator { x: pos, y: pos },
                self.find_terminal_assign(terminal),
            );
        }

        // `span` is y - x of the cell being filled; all shorter spans are complete by now.
        for span in 1..num_words {
            for i in 0..num_words - span {
                let j = i + span;
                let mut symbols: Vec<&'grammar str> = Vec::new();
                // Left half covers words i..=i+k-1 and right half i+k..=j; both halves are
                // non-empty exactly when 1 <= k <= span.
                for k in 1..=span {
                    let left = result.map.get(&MatrixIndicator { x: i, y: i + k - 1 });
                    let right = result.map.get(&MatrixIndicator { x: i + k, y: j });
                    let (left, right) = match (left, right) {
                        (Some(l), Some(r)) if !l.is_empty() && !r.is_empty() => (l, r),
                        _ => continue,
                    };
                    for pair in vec_production(left, right) {
                        for symbol in self.producers_of(&pair) {
                            // Whole-symbol comparison: a substring test would wrongly treat
                            // "Noun" as already present in "NounClause".
                            if !symbols.contains(&symbol) {
                                symbols.push(symbol);
                            }
                        }
                    }
                }
                result.insert(MatrixIndicator { x: i, y: j }, join(symbols, " "));
            }
        }

        // The top cell always exists: it is the diagonal when there is one word, otherwise it
        // was filled by the last `span` iteration with i == 0.
        let top = MatrixIndicator {
            x: 0,
            y: num_words - 1,
        };
        let final_result = result.map.get(&top).expect("Can not be empty").to_owned();
        result.set_final(final_result);
        result
    }

    /// Like [`parse`](Self::parse), but caches results in a process-wide, size-bounded memo.
    ///
    /// At most 100 inputs are kept. When the limit is exceeded, the least recently used entry
    /// is evicted. A cache hit returns a clone of the stored table and marks it as recently used.
    ///
    /// NOTE: the memo is one global keyed only by `input` and shared by every `CYK` instance.
    /// A result computed with one grammar or word bank can therefore be returned to a parser
    /// built with another. Only use this with a single parser configuration per process.
    ///
    /// # Panics
    /// Panics if the memo mutex was poisoned by a panic in another thread.
    pub fn memoized_parse<'word>(&self, input: &'static str) -> MatrixResult {
        const MEMO_LIMIT: usize = 100;

        // Check and read under one lock, so no other thread can evict the entry in between.
        {
            let mut memo = MEMO.lock().expect("Memo should not be NONE.");
            if let Some(hit) = memo.get_refresh(input) {
                return hit.clone();
            }
        }

        // Parse without holding the lock. Concurrent misses on the same input may both parse;
        // the later insert simply replaces the earlier one.
        let res = self.parse(input);
        let mut memo = MEMO.lock().expect("Memo should not be NONE.");
        memo.insert(input, res.clone());
        // Least recently used entries sit at the front. `pop_back` would instead discard the
        // entry that was just inserted, freezing the cache at its first 100 inputs.
        while memo.len() > MEMO_LIMIT {
            memo.pop_front();
        }
        res
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small imperative/descriptive grammar shared by every test.
    struct G;

    impl<'grammar> Grammar<'grammar> for G {
        fn convert(&self) -> Vec<GrammarRule<'grammar>> {
            vec![
                GrammarRule {
                    left_symbol: "ActionSentence",
                    right_symbol: "Verb Noun | Verb NounClause | ActionSentence PrepClause",
                },
                GrammarRule {
                    left_symbol: "DescriptiveSentence",
                    right_symbol: "Noun Verb Adjective",
                },
                GrammarRule {
                    left_symbol: "NounClause",
                    right_symbol: "Count ANoun | Adjective Noun",
                },
                GrammarRule {
                    left_symbol: "PrepClause",
                    right_symbol: "Prep Noun",
                },
                GrammarRule {
                    left_symbol: "ANoun",
                    right_symbol: "Adjective Noun",
                },
                GrammarRule {
                    left_symbol: "Adjective",
                    right_symbol: "adjective",
                },
                GrammarRule {
                    left_symbol: "Prep",
                    right_symbol: "prep",
                },
                GrammarRule {
                    left_symbol: "Verb",
                    right_symbol: "verb",
                },
                GrammarRule {
                    left_symbol: "Noun",
                    right_symbol: "noun",
                },
                GrammarRule {
                    left_symbol: "Count",
                    right_symbol: "definiteArticle | indefiniteArticle | number",
                },
            ]
        }
    }

    /// Fixed word bank. Unknown words map to a terminal that no rule uses.
    struct WB;

    impl WordBank for WB {
        fn lookup(&self, word: &str) -> &str {
            match word {
                "examine" | "google" | "is" | "take" => "verb",
                "sword" | "apple" | "table" => "noun",
                "rusty" | "cool" => "adjective",
                "from" => "prep",
                _ => "unknown",
            }
        }
    }

    /// Builds the parser used by every test.
    fn parser() -> CYK<'static, WB> {
        CYK::new(G, WB)
    }

    #[test]
    fn basic_test() {
        let res = parser().parse("examine rusty sword");
        assert_eq!(Some("ActionSentence".to_owned()), res.get_final());
        assert_eq!(3, res.get_num_words());
    }

    #[test]
    fn double_meaning_test() {
        let res = parser().parse("google sword");
        assert_eq!(Some("ActionSentence".to_owned()), res.get_final());
    }

    #[test]
    fn complicated_test() {
        let res = parser().parse("take apple from table");
        assert_eq!(Some("ActionSentence".to_owned()), res.get_final());
        assert_eq!(4, res.get_num_words());
    }

    #[test]
    fn memoized_parse_agrees_with_parse() {
        let cyk = parser();
        let input = "take apple from table";
        let first = cyk.memoized_parse(input);
        let second = cyk.memoized_parse(input);
        assert_eq!(cyk.parse(input).get_final(), first.get_final());
        assert_eq!(first.get_final(), second.get_final());
        assert_eq!(first.get_num_words(), second.get_num_words());
    }
}