use std::collections::HashMap;
use keymap_parser::node::{CharGroup, Key, Node};
/// One node of the pattern trie used by [`Matcher`].
///
/// Children reached through concrete keys live in `exact` for hashed lookup.
/// Children reached through character groups (e.g. `@upper`, `@any`) live in
/// `groups`, kept in insertion order, because each group has to be tested
/// against the input one by one.
#[derive(Debug)]
struct Trie<T> {
    /// Value registered for the pattern that ends at this node, if any.
    value: Option<T>,
    /// Children keyed by non-group nodes.
    exact: HashMap<Node, Trie<T>>,
    /// Children keyed by group nodes, in the order they were first added.
    groups: Vec<(Node, Trie<T>)>,
}

impl<T> Trie<T> {
    /// Creates a node with no value and no children.
    fn new() -> Self {
        Trie {
            value: None,
            exact: HashMap::default(),
            groups: Vec::default(),
        }
    }
}
/// Maps key sequences (lists of [`Node`]s) to values.
///
/// Patterns may mix concrete keys with character-group nodes such as
/// `@digit` or `ctrl-@any`.
#[derive(Debug)]
pub struct Matcher<T> {
    root: Trie<T>,
}

impl<T> Default for Matcher<T> {
    /// Same as [`Matcher::new`]: an empty matcher.
    fn default() -> Self {
        Matcher::new()
    }
}

impl<T> FromIterator<(Vec<Node>, T)> for Matcher<T> {
    /// Builds a matcher by adding every `(pattern, value)` pair in iteration
    /// order; a later pair with an equal pattern replaces the earlier value.
    fn from_iter<I: IntoIterator<Item = (Vec<Node>, T)>>(iter: I) -> Self {
        iter.into_iter()
            .fold(Matcher::new(), |mut matcher, (pattern, value)| {
                matcher.add(pattern, value);
                matcher
            })
    }
}
impl<T> Matcher<T> {
pub fn new() -> Self {
Self { root: Trie::new() }
}
pub fn add(&mut self, pattern: Vec<Node>, value: T) {
let mut node = &mut self.root;
for input_node in pattern {
node = match input_node.key {
Key::Group(_) => {
if let Some(pos) = node.groups.iter().position(|(n, _)| n == &input_node) {
&mut node.groups[pos].1
} else {
node.groups.push((input_node, Trie::new()));
&mut node.groups.last_mut().unwrap().1
}
}
_ => node.exact.entry(input_node).or_insert_with(Trie::new),
};
}
node.value = Some(value);
}
pub fn get(&self, nodes: &[Node]) -> Option<&T> {
search(&self.root, nodes, 0)
}
}
/// Recursively matches `nodes[pos..]` against the subtree rooted at `node`.
///
/// Returns the value stored at the end of the first full match. At each
/// position, candidates are tried in priority order. Whenever a deeper match
/// fails, the search backtracks to the next candidate:
///
/// 1. the exact child keyed by the input node;
/// 2. for character inputs, group children (in insertion order) with the same
///    modifiers whose group accepts the character;
/// 3. `@any` group children with the same modifiers that step 2 did not
///    already try. This is what lets `@any` match non-character keys.
fn search<'a, T>(node: &'a Trie<T>, nodes: &[Node], pos: usize) -> Option<&'a T> {
    if pos == nodes.len() {
        return node.value.as_ref();
    }
    let input_node = &nodes[pos];

    // 1. Exact key.
    if let Some(result) = node
        .exact
        .get(input_node)
        .and_then(|child| search(child, nodes, pos + 1))
    {
        return Some(result);
    }

    // 2. Character groups (`@upper`, `@digit`, ...) with identical modifiers.
    if let Key::Char(ch) = input_node.key {
        if let Some(result) = node.groups.iter().find_map(|(n, child)| match n.key {
            Key::Group(group) if n.modifiers == input_node.modifiers && group.matches(ch) => {
                search(child, nodes, pos + 1)
            }
            _ => None,
        }) {
            return Some(result);
        }
    }

    // 3. `@any` fallback. Modifiers must match here too; otherwise `ctrl-@any`
    //    would accept a plain `x`, and `@any` would accept `ctrl-x`.
    //    Groups already explored in step 2 are skipped, because retrying them
    //    cannot succeed. Retrying them would also double the work at every
    //    level, so a failing lookup could take time exponential in the
    //    sequence length.
    node.groups.iter().find_map(|(n, child)| match n.key {
        Key::Group(group @ CharGroup::Any) if n.modifiers == input_node.modifiers => {
            let tried_in_step_2 =
                matches!(input_node.key, Key::Char(ch) if group.matches(ch));
            if tried_in_step_2 {
                None
            } else {
                search(child, nodes, pos + 1)
            }
        }
        _ => None,
    })
}
#[cfg(test)]
mod tests {
    use keymap_parser::parse_seq;

    use super::*;

    /// Table-driven check of the matcher.
    ///
    /// Each row is `(pattern, query, should_match)`. Every pattern is
    /// registered with its row index as the value. Then each query must
    /// resolve to its own row index when `should_match` is true, or to nothing
    /// when it is false.
    fn matches(inputs: &[(&'static str, &'static str, bool)]) {
        let matcher: Matcher<usize> = inputs
            .iter()
            .enumerate()
            .map(|(row, &(pattern, _, _))| (parse_seq(pattern).unwrap(), row))
            .collect();

        for (row, &(_, query, should_match)) in inputs.iter().enumerate() {
            let key = parse_seq(query).unwrap();
            let result = matcher.get(&key);
            if should_match {
                assert_eq!(result, Some(&row), "{key:?}");
            } else {
                assert_eq!(result, None);
            }
        }
    }

    #[test]
    fn test_exact_nodes() {
        matches(&[
            ("a", "a", true),
            ("ctrl-c", "ctrl-c", true),
            ("f12", "f12", true),
            ("f10", "f11", false),
            ("enter", "enter", true),
        ]);
    }

    #[test]
    fn test_groups() {
        matches(&[
            ("@upper", "A", true),
            ("@digit", "1", true),
            ("ctrl-@any", "ctrl-x", true),
            ("@any", "b", true),
            ("a", "a", true),
        ]);
    }

    #[test]
    fn test_sequences() {
        matches(&[
            ("a enter", "a enter", true),
            ("ctrl-@any shift-@upper", "ctrl-x shift-B", true),
        ]);
    }
}