use std::fmt::{Display, Formatter};
use std::str;
use std::{fmt::Debug, str::FromStr};
use pom::Error;
use pom::parser::{end, is_a, one_of, sym, Parser};
use serde::{de, Deserialize, Deserializer};
use strum_macros::{AsRefStr, Display, EnumString};
/// Separator placed between modifiers and the key in the textual form, e.g. `alt-f`.
const KEY_SEP: char = '-';
/// A parsed key combination: a set of modifier flags plus the key they apply to.
#[derive(Debug, Eq, Hash, PartialEq)]
pub(crate) struct Node {
/// Bitmask built from [`Modifier`] discriminants (`0` means no modifier).
pub modifiers: Modifiers,
/// The key the modifiers are applied to.
pub key: Key,
}
impl Node {
fn new(modifiers: Modifiers, key: Key) -> Self {
Self { modifiers, key }
}
}
/// Modifier keys that may prefix a key in a combination.
///
/// Each variant other than `None` has a distinct single-bit discriminant, so several
/// modifiers can be stored together in one [`Modifiers`] byte. The strum attribute below
/// makes the textual form of every variant its lowercase name (`alt`, `cmd`, ...).
#[derive(Copy, Clone, Debug, Display, Eq, Hash, PartialEq, EnumString, AsRefStr)]
#[strum(serialize_all = "lowercase")]
pub(crate) enum Modifier {
// No bit set; contributes nothing to a mask.
None = 0b0000,
Alt = 0b0001,
Cmd = 0b0010,
Ctrl = 0b0100,
Shift = 0b1000,
}
/// Bitmask holding any combination of [`Modifier`] discriminants.
pub(crate) type Modifiers = u8;
/// Every modifier that sets a bit (i.e. all but `Modifier::None`), in the order in which
/// they are written out when a [`Node`] is formatted.
const MODIFIERS: [Modifier; 4] = [
Modifier::Alt,
Modifier::Cmd,
Modifier::Ctrl,
Modifier::Shift,
];
/// Keys that can appear at the end of a key combination.
///
/// Unit variants are parsed and printed by their lowercase name (see the strum attribute).
/// `Char` and `F` carry a payload and are handled by dedicated parsers in this file.
///
/// NOTE(review): strum's `EnumString` is documented to fill payload fields with their
/// default values, so the bare word `char` presumably parses to `Key::Char('\0')` through
/// the name-based path — confirm whether that is acceptable.
#[derive(Debug, Display, Eq, Hash, PartialEq, EnumString, AsRefStr)]
#[strum(serialize_all = "lowercase")]
pub(crate) enum Key {
BackTab,
Backspace,
/// A single character key.
Char(char),
// Both `del` and `delete` are accepted on input.
#[strum(serialize = "del", serialize = "delete")]
Delete,
Down,
End,
Enter,
Esc,
Home,
/// A function key, identified by its number.
F(u8),
Insert,
Left,
PageDown,
PageUp,
Right,
Space,
Tab,
Up,
}
/// Parses a textual key combination such as `alt-f` or `space` into a [`Node`].
///
/// # Errors
///
/// Returns the underlying pom error when the input is not a single, complete combination.
pub(crate) fn parse(s: &str) -> Result<Node, Error> {
    let chars: Vec<char> = s.chars().collect();
    // The parser borrows `chars`; declaring it afterwards guarantees it is dropped first.
    let parser = node();
    parser.parse(&chars)
}
/// Parser for a complete binding: one key combination that must consume the whole input.
fn node<'a>() -> Parser<'a, char, Node> {
    let whole_combination = combination();
    whole_combination - end()
}
fn key<'a>() -> Parser<'a, char, Key> {
fn_key() | parse_enum::<Key>() | char()
}
/// Fallback key parser: accepts any single ASCII character and wraps it in `Key::Char`.
fn char<'a>() -> Parser<'a, char, Key> {
    let any_ascii = is_a(ascii);
    any_ascii.map(|c| Key::Char(c))
}
/// Parser for function keys: the letter `f` followed by `10`..`12` or a single digit.
///
/// The two-digit form is tried first so that `f12` is not read as `f1` followed by `2`.
fn fn_key<'a>() -> Parser<'a, char, Key> {
    // Both branches yield the raw ASCII code of the number; it is rebased to a plain
    // number once, in the final `map`.
    let two_digits = (sym('1') * one_of("012")).map(|units| 10 + units as u8);
    let one_digit = is_a(digit).map(|d| d as u8);
    sym('f') * (two_digits | one_digit).map(|code| Key::F(code - b'0'))
}
/// Parser for one modifier name, consuming the `KEY_SEP` that follows it when present.
fn modifier<'a>() -> Parser<'a, char, Modifier> {
    let name = parse_enum::<Modifier>();
    let separator = sym(KEY_SEP).opt();
    name - separator
}
fn combination<'a>() -> Parser<'a, char, Node> {
(modifier().repeat(..4) + key()).map(|(m, key)| {
let mods = m.into_iter().map(|v| v as u8).sum();
Node::new(mods, key)
})
}
/// Parser that reads a run of two or more ASCII letters and converts it with `T::from_str`.
///
/// Fails when fewer than two letters are available or when the word is not a valid `T`.
fn parse_enum<'a, T>() -> Parser<'a, char, T>
where
    T: FromStr + 'static,
    <T as FromStr>::Err: Debug,
{
    is_a(alpha).repeat(2..).convert(|letters| {
        let word: String = letters.into_iter().collect();
        T::from_str(&word)
    })
}
/// Returns `true` for ASCII letters (`a`-`z`, `A`-`Z`).
#[inline]
fn alpha(term: char) -> bool {
    matches!(term, 'a'..='z' | 'A'..='Z')
}
/// Returns `true` for any character in the ASCII range (U+0000 through U+007F).
#[inline]
fn ascii(term: char) -> bool {
    term <= '\u{7f}'
}
/// Returns `true` for the ASCII decimal digits `0` through `9`.
#[inline]
fn digit(term: char) -> bool {
    matches!(term, '0'..='9')
}
impl<'s> Deserialize<'s> for Node {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'s>,
{
let s = String::deserialize(deserializer)?;
parse(&s).map_err(de::Error::custom)
}
}
/// Formats a [`Node`] in the same textual form that [`parse`] accepts, e.g. `cmd-shift-f`.
///
/// Active modifiers are written in `MODIFIERS` order, each followed by `KEY_SEP`, and the
/// key comes last.
impl Display for Node {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        for m in MODIFIERS {
            if (self.modifiers & m as u8) != 0 {
                // Propagate formatter errors to the caller instead of panicking.
                write!(f, "{m}{}", KEY_SEP)?;
            }
        }
        match self.key {
            Key::Char(c) => write!(f, "{c}"),
            // The key's own `Display` yields the variant name; append the number to it.
            Key::F(n) => write!(f, "{}{n}", self.key),
            _ => write!(f, "{}", self.key),
        }
    }
}
#[cfg(test)]
mod tests {
    use pom::{parser::end, Error};
    use serde::Deserialize;

    use crate::parser::{parse_enum, Key, Modifier, Node};

    use super::{fn_key, parse};

    /// Whole-string parsing: accepted combinations and the exact errors for rejected ones.
    #[test]
    fn test_parse() {
        let mismatch = |message: &str, position: usize| -> Result<Node, Error> {
            Err(Error::Mismatch {
                message: message.into(),
                position,
            })
        };
        let cases = [
            ("alt-f", Ok(Node::new(Modifier::Alt as u8, Key::Char('f')))),
            ("space", Ok(Node::new(0, Key::Space))),
            ("delta", mismatch("expect end of input, found: e", 1)),
            (
                "shift-a",
                Ok(Node::new(Modifier::Shift as u8, Key::Char('a'))),
            ),
            (
                "shift-a-delete",
                mismatch("expect end of input, found: -", 7),
            ),
            ("al", mismatch("expect end of input, found: l", 1)),
        ];
        for (input, expected) in cases {
            let actual = parse(input);
            assert_eq!(expected, actual);
        }
    }

    /// `f0` through `f12` parse as function keys; higher numbers leave trailing input.
    #[test]
    fn test_parse_fn_key() {
        for n in 0..=12 {
            let input: Vec<char> = format!("f{n}").chars().collect();
            let parsed = (fn_key() - end()).parse(&input);
            assert_eq!(Key::F(n), parsed.unwrap());
        }
        for n in [13, 15] {
            let input: Vec<char> = format!("f{n}").chars().collect();
            let parsed = (fn_key() - end()).parse(&input);
            assert!(parsed.is_err());
        }
    }

    /// Named keys, including the `del` alias, are resolved through `parse_enum`.
    #[test]
    fn test_parse_enum() {
        for (name, expected) in [("up", Key::Up), ("esc", Key::Esc), ("del", Key::Delete)] {
            let input: Vec<char> = name.chars().collect();
            let parsed = parse_enum::<Key>().parse(&input);
            assert_eq!(parsed.unwrap(), expected);
        }
    }

    /// `Display` output for plain keys, payload keys and modifier combinations.
    #[test]
    fn test_format() {
        let cases = [
            (Node::new(0, Key::F(3)), "f3"),
            (Node::new(0, Key::Delete), "delete"),
            (Node::new(0, Key::Space), "space"),
            (Node::new(0, Key::Char('g')), "g"),
            (Node::new(0, Key::Char('#')), "#"),
            (Node::new(Modifier::Alt as u8, Key::Char('f')), "alt-f"),
            (
                Node::new(Modifier::Shift as u8 | Modifier::Cmd as u8, Key::Char('f')),
                "cmd-shift-f",
            ),
        ];
        for (node, expected) in cases {
            let rendered = node.to_string();
            assert_eq!(expected, rendered);
        }
    }

    /// Nodes can be used as map keys when deserializing a TOML table.
    #[test]
    fn test_deserialize() {
        use std::collections::HashMap;

        #[derive(Deserialize, Debug)]
        struct Test {
            keys: HashMap<Node, String>,
        }

        let result: Test = toml::from_str(
            r#"
[keys]
alt-d = "a"
cmd-shift-del = "b"
shift-cmd-del = "b" # this is the same as previous one
delete = "d"
"#,
        )
        .unwrap();

        let expected_nodes = [
            Node::new(Modifier::Alt as u8, Key::Char('d')),
            Node::new(Modifier::Cmd as u8 | Modifier::Shift as u8, Key::Delete),
            Node::new(0, Key::Delete),
        ];
        for wanted in expected_nodes {
            let (stored, _) = result.keys.get_key_value(&wanted).unwrap();
            assert_eq!(stored, &wanted);
        }
    }
}