use itertools::Itertools;
use log::{debug, warn};
use crate::{ast, inheritance, errors::ParseError};
use std::{cmp::Reverse, collections::HashMap};
/// Sorted trigger and substitution data produced by [`sort_triggers`].
#[derive(Debug, Clone, Default)]
pub struct SortedResult {
    /// Sorted triggers of every topic, keyed by topic name.
    pub topics: HashMap<String, Vec<ast::Trigger>>,
    /// Sorted triggers of every topic gathered with the `thats` flag set,
    /// keyed by topic name (presumably the `%Previous` triggers — confirm in
    /// `inheritance::get_topic_triggers`).
    pub thats: HashMap<String, Vec<ast::Trigger>>,
    /// Substitution keys, in the order produced by `sort_list`.
    pub subs: Vec<String>,
    /// Person-substitution keys, in the order produced by `sort_list`.
    pub person: Vec<String>,
}
/// Sort the triggers of every topic in `brain`, plus its substitution and
/// person lists, into the order in which they should be matched.
///
/// For each topic, the triggers returned by
/// `inheritance::get_topic_triggers` (once with `thats = false`, once with
/// `thats = true`) are ordered by `sort_trigger_set`. The `subs` and
/// `person` maps are ordered by `sort_list`.
///
/// # Errors
///
/// Returns a [`ParseError`] when `brain` contains no topics at all.
pub fn sort_triggers(brain: &ast::AST) -> Result<SortedResult, ParseError> {
    if brain.topics.is_empty() {
        return Err(ParseError::new(
            "sort_triggers: no topics were found. Did you load any RiveScript code?",
        ));
    }
    // Routine progress trace, not a warning condition.
    debug!("Sorting triggers...");

    let mut topics = HashMap::new();
    let mut thats = HashMap::new();
    for (name, topic) in brain.topics.iter() {
        debug!("Analyzing topic {}", name);
        // get_topic_triggers already hands back an owned Vec; no copy needed.
        let all_triggers = inheritance::get_topic_triggers(brain, topic, false);
        topics.insert(name.to_string(), sort_trigger_set(all_triggers));
        let that_triggers = inheritance::get_topic_triggers(brain, topic, true);
        thats.insert(name.to_string(), sort_trigger_set(that_triggers));
    }

    Ok(SortedResult {
        topics,
        thats,
        subs: sort_list(brain.subs.clone()),
        person: sort_list(brain.person.clone()),
    })
}
fn sort_trigger_set(triggers: Vec<ast::Trigger>) -> Vec<ast::Trigger> {
let mut running: Vec<ast::Trigger> = Vec::new();
let mut prior: HashMap<isize, Vec<ast::Trigger>> = HashMap::new();
for trigger in triggers {
let mut weight: isize = 0;
match rivescript_core::regex::WEIGHT.captures(trigger.trigger.as_str()) {
Some(cap) => {
weight = cap.get(1).unwrap().as_str().parse::<isize>().unwrap_or(0);
},
None => (),
}
if !prior.contains_key(&weight) {
prior.insert(weight, Vec::new());
}
let mut vt: Vec<ast::Trigger> = prior.get(&weight).unwrap().to_vec();
vt.insert(vt.len(), trigger);
prior.insert(weight, vt);
}
for p in prior.keys().sorted().rev() {
warn!("Sorting triggers with weight: {}", p);
let mut inherits = -1;
let mut highest_inherits = -1;
let mut track: HashMap<isize, SortTracker> = HashMap::new();
track.insert(inherits, SortTracker::new());
for trig in prior.get(p).unwrap().to_vec() {
let mut pattern = trig.trigger.clone();
debug!("Looking at trigger: {pattern}");
match rivescript_core::regex::INHERITS.captures(pattern.as_str()) {
Some(cap) => {
inherits = cap.get(1).unwrap().as_str().parse::<isize>().unwrap_or(-1);
if inherits > highest_inherits {
highest_inherits = inherits;
}
debug!("Trigger belongs to a topic that inherits other topics. Level={inherits}");
pattern = rivescript_core::regex::INHERITS.replace_all(&pattern, "").to_string();
},
None => (),
}
let this_track = track.entry(inherits).or_insert_with(SortTracker::new);
if pattern.contains("_") {
let wc = word_count(&pattern, false);
debug!("Has a _ wildcard and {wc} words");
if wc > 0 {
let entries = this_track.alpha.entry(wc).or_insert_with(Vec::new);
entries.push(SortedTriggerEntry {
text: pattern,
pointer: trig,
});
} else {
this_track.under.push(SortedTriggerEntry {
text: pattern,
pointer: trig,
});
}
} else if pattern.contains("#") {
let wc = word_count(&pattern, false);
debug!("Has a # wildcard and {wc} words");
if wc > 0 {
let entries = this_track.number.entry(wc).or_insert_with(Vec::new);
entries.push(SortedTriggerEntry {
text: pattern,
pointer: trig,
});
} else {
this_track.pound.push(SortedTriggerEntry {
text: pattern,
pointer: trig,
});
}
} else if pattern.contains("*") {
let wc = word_count(&pattern, false);
debug!("Has a * wildcard and {wc} words");
if wc > 0 {
let entries = this_track.wild.entry(wc).or_insert_with(Vec::new);
entries.push(SortedTriggerEntry {
text: pattern,
pointer: trig,
});
} else {
this_track.star.push(SortedTriggerEntry {
text: pattern,
pointer: trig,
});
}
} else if pattern.contains("[") {
let wc = word_count(&pattern, false);
debug!("Has optionals with {wc} words");
let entries = this_track.option.entry(wc).or_insert_with(Vec::new);
entries.push(SortedTriggerEntry {
text: pattern,
pointer: trig,
});
} else {
let wc = word_count(&pattern, false);
debug!("Totally atomic trigger with {wc} words");
let entries = this_track.atomic.entry(wc).or_insert_with(Vec::new);
entries.push(SortedTriggerEntry{
text: pattern,
pointer: trig,
});
}
}
let mut track_sorted: Vec<isize> = Vec::new();
for k in track.keys() {
track_sorted.push(k.clone());
}
track_sorted.sort();
for ip in track_sorted {
let ip_track = track.entry(ip).or_insert_with(SortTracker::new);
sort_by_words(&mut running, &ip_track.atomic);
sort_by_words(&mut running, &ip_track.option);
sort_by_words(&mut running, &ip_track.alpha);
sort_by_words(&mut running, &ip_track.number);
sort_by_words(&mut running, &ip_track.wild);
sort_by_length(&mut running, &ip_track.under);
sort_by_length(&mut running, &ip_track.pound);
sort_by_length(&mut running, &ip_track.star);
}
}
running
}
/// Buckets for the triggers that share one `{inherits}` level while a
/// priority bucket is being sorted.
///
/// The map-valued buckets are keyed by the pattern's word count, as computed
/// by `word_count(pattern, false)`. The vector-valued buckets hold wildcard
/// triggers whose pattern has no words besides the wildcards.
#[derive(Debug, Default)]
struct SortTracker {
    /// Triggers with no `_`, `#` or `*` wildcard and no `[` optional.
    atomic: HashMap<isize, Vec<SortedTriggerEntry>>,
    /// Triggers with a `[` optional but no `_`, `#` or `*` wildcard.
    option: HashMap<isize, Vec<SortedTriggerEntry>>,
    /// Triggers with a `_` wildcard and at least one word.
    alpha: HashMap<isize, Vec<SortedTriggerEntry>>,
    /// Triggers with a `#` wildcard (and no `_`) and at least one word.
    number: HashMap<isize, Vec<SortedTriggerEntry>>,
    /// Triggers with a `*` wildcard (and no `_` or `#`) and at least one word.
    wild: HashMap<isize, Vec<SortedTriggerEntry>>,
    /// `#` wildcard triggers with no words.
    pound: Vec<SortedTriggerEntry>,
    /// `_` wildcard triggers with no words.
    under: Vec<SortedTriggerEntry>,
    /// `*` wildcard triggers with no words.
    star: Vec<SortedTriggerEntry>,
}

impl SortTracker {
    /// Create a tracker whose buckets are all empty.
    pub fn new() -> Self {
        Self::default()
    }
}
/// A trigger paired with the text it is ordered by.
///
/// `text` is the trigger's pattern with any `{inherits}` tag removed, and is
/// what the sort helpers compare. `pointer` is the trigger itself, which is
/// what gets appended to the sorted output.
#[derive(Clone, Debug)]
struct SortedTriggerEntry {
    text: String,
    pointer: ast::Trigger,
}
/// Count the words in a trigger or substitution pattern.
///
/// With `all` set, a word is any run of non-whitespace characters, so
/// wildcards count as words. Without it, the wildcard characters `*`, `#`
/// and `_` and the alternation bar `|` also separate words, so wildcards are
/// not counted. Whitespace of any kind (spaces, tabs, newlines) separates
/// words in both modes.
fn word_count(pattern: &str, all: bool) -> isize {
    let count = if all {
        pattern.split_whitespace().count()
    } else {
        pattern
            .split(|c: char| c.is_whitespace() || matches!(c, '*' | '#' | '_' | '|'))
            .filter(|word| !word.is_empty())
            .count()
    };
    // The count is bounded by the string's byte length, which never exceeds
    // isize::MAX, so this cast cannot truncate.
    count as isize
}
fn sort_list(dict: HashMap<String, String>) -> Vec<String> {
let mut track: HashMap<isize, Vec<&String>> = HashMap::new();
for phrase in dict.keys() {
let wc = word_count(phrase, true);
let entries = track.entry(wc).or_insert_with(Vec::new);
entries.push(phrase);
}
let distinct_counts = track.keys().unique().sorted().rev();
let mut sorted_patterns: Vec<String> = Vec::new();
debug!("distinct_counts: {:?}", distinct_counts);
for wc in distinct_counts {
let entries = track.get(wc).unwrap();
for entry in entries {
sorted_patterns.push(entry.to_string());
}
}
sorted_patterns.sort_by(|a, b| b.len().cmp(&a.len()));
sorted_patterns
}
/// Append the triggers held in a word-count-keyed bucket map to `running`.
///
/// Buckets are visited from the highest word count to the lowest. Inside a
/// bucket, distinct pattern texts are ordered longest first; texts of equal
/// length keep the order in which they first appear. Every trigger sharing a
/// text is appended together, in its original order, exactly once.
fn sort_by_words(running: &mut Vec<ast::Trigger>, triggers: &HashMap<isize, Vec<SortedTriggerEntry>>) {
    let mut counts: Vec<&isize> = triggers.keys().collect();
    counts.sort_unstable_by(|x, y| y.cmp(x));

    for &wc in counts {
        // Group the bucket's entries by text, remembering first appearances.
        let mut order: Vec<&str> = Vec::new();
        let mut groups: HashMap<&str, Vec<&SortedTriggerEntry>> = HashMap::new();
        for entry in &triggers[&wc] {
            let text = entry.text.as_str();
            groups
                .entry(text)
                .or_insert_with(|| {
                    order.push(text);
                    Vec::new()
                })
                .push(entry);
        }

        // Stable sort: equal-length texts keep their first-appearance order.
        order.sort_by(|x, y| y.len().cmp(&x.len()));

        for pattern in order {
            for entry in &groups[&pattern] {
                debug!("sort_by_words: wc={wc} pattern={pattern}");
                running.push(entry.pointer.clone());
            }
        }
    }
}
/// Append the triggers of one bucket to `running`, longest pattern text first.
///
/// Triggers whose texts are identical are appended together, in their
/// original order, and each exactly once. Distinct texts of equal length
/// keep the order in which they first appear.
fn sort_by_length(running: &mut Vec<ast::Trigger>, triggers: &[SortedTriggerEntry]) {
    // Group by text, recording each distinct text once. Listing a text once
    // per trigger (as before) appended every shared-text group N times.
    let mut order: Vec<&str> = Vec::new();
    let mut groups: HashMap<&str, Vec<&SortedTriggerEntry>> = HashMap::new();
    for entry in triggers {
        let text = entry.text.as_str();
        groups
            .entry(text)
            .or_insert_with(|| {
                order.push(text);
                Vec::new()
            })
            .push(entry);
    }

    // Stable sort: equal-length texts keep their first-appearance order.
    order.sort_by(|a, b| b.len().cmp(&a.len()));

    for pattern in order {
        running.extend(groups[&pattern].iter().map(|entry| entry.pointer.clone()));
    }
}