1use std::{collections::HashMap, fmt::Display};
2
3use action_shortcuts::ActionShortcuts;
4use itertools::Itertools;
5use key_combo::{KeyCombo, KeyModifiers};
6use key_strike::KeyStrike;
7use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers as CKeyMods};
8use serde::{de::Visitor, ser::SerializeMap, Deserialize, Serialize};
9
10pub mod action_shortcuts;
11pub mod key_combo;
12pub mod key_strike;
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct KeyBindings {
16 bindings: HashMap<KeyCombo, ActionShortcuts>,
17}
18
19impl Serialize for KeyBindings {
20 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
21 where
22 S: serde::Serializer,
23 {
24 let kb_map = self.to_hashmap();
25 let mut map = serializer.serialize_map(Some(kb_map.len()))?;
26 for (k, v) in kb_map
27 .iter()
28 .sorted_by_key(|(action, _combo)| action.to_owned())
29 {
30 map.serialize_entry(&k, &v)?;
31 }
32 map.end()
33 }
34}
35
36struct DeserializeKeyBindingsVisitor;
37impl<'de> Visitor<'de> for DeserializeKeyBindingsVisitor {
38 type Value = KeyBindings;
39
40 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
41 formatter.write_str("A valid path with `/` separators, no need of starting `/`")
42 }
43 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
44 where
45 A: serde::de::MapAccess<'de>,
46 {
47 let mut bindings: HashMap<ActionShortcuts, Vec<KeyCombo>> =
48 HashMap::with_capacity(map.size_hint().unwrap_or(0));
49 while let Some((key, value)) = map.next_entry()? {
51 bindings.insert(key, value);
52 }
53 Ok(KeyBindings::from_hashmap(bindings))
54 }
55}
56
57impl<'de> Deserialize<'de> for KeyBindings {
58 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
59 where
60 D: serde::Deserializer<'de>,
61 {
62 deserializer.deserialize_map(DeserializeKeyBindingsVisitor)
63 }
64}
65
66impl Display for KeyBindings {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 let mut bindings: Vec<(ActionShortcuts, Vec<KeyCombo>)> = vec![];
69 for (key, value) in &self.bindings {
70 if let Some((_, combos)) = bindings
71 .iter_mut()
72 .find(|(shortcut, _combos)| shortcut.eq(value))
73 {
74 combos.push(key.to_owned());
75 combos.sort();
76 } else {
77 bindings.push((value.to_owned(), vec![key.to_owned()]));
78 }
79 }
80
81 bindings.sort_by_key(|(a, _v)| a.to_owned());
82 for (key, value) in &bindings {
83 writeln!(
84 f,
85 "{}: {}",
86 key,
87 value
88 .iter()
89 .map(|kc| kc.to_string())
90 .collect::<Vec<String>>()
91 .join(", ")
92 )?;
93 }
94
95 Ok(())
96 }
97}
98
99impl KeyBindings {
100 pub fn empty() -> Self {
101 KeyBindings {
102 bindings: HashMap::default(),
103 }
104 }
105
106 pub fn batch_add(&mut self) -> KeyBindBatch<'_> {
107 KeyBindBatch {
108 bindings: self,
109 modifiers: KeyModifiers::default(),
110 }
111 }
112
113 pub fn get_action(&self, combo: &KeyCombo) -> Option<ActionShortcuts> {
114 let bind = self.bindings.get(combo).map(|a| a.to_owned());
115 bind
116 }
117
118 pub fn first_combo_for(&self, action: &ActionShortcuts) -> Option<String> {
120 self.bindings
121 .iter()
122 .find(|(_, a)| *a == action)
123 .map(|(combo, _)| combo.to_string())
124 }
125
126 pub fn to_hashmap(&self) -> HashMap<ActionShortcuts, Vec<KeyCombo>> {
127 let mut bindings: HashMap<ActionShortcuts, Vec<KeyCombo>> = HashMap::new();
128 for (combo, action) in &self.bindings {
129 let entry = bindings.entry(action.to_owned()).or_default();
130 entry.push(combo.to_owned());
131 entry.sort();
132 }
133 bindings
134 }
135
136 pub fn from_hashmap(bindings: HashMap<ActionShortcuts, Vec<KeyCombo>>) -> KeyBindings {
137 let mut kb = KeyBindings::empty();
138 for (action, combos) in &bindings {
139 log::debug!("from_hashmap: action={} combos={:?}", action, combos);
140 }
141 for (action, combos) in bindings {
142 for combo in combos {
143 let valid = combo.is_valid_binding();
144 log::debug!("from_hashmap: combo='{}' key={:?} modifiers={:?} valid={}", combo, combo.key, combo.modifiers, valid);
145 if valid {
146 kb.bindings.insert(combo.to_owned(), action.to_owned());
147 } else {
148 log::warn!(
149 "Skipping invalid key combo '{}' for action '{}': \
150 only ctrl/alt (with optional shift) + a letter (a-z), or bare F1–F12 are supported",
151 combo,
152 action
153 );
154 }
155 }
156 }
157 kb
158 }
159}
160
161pub struct KeyBindBatch<'k> {
162 bindings: &'k mut KeyBindings,
163 modifiers: KeyModifiers,
164}
165
166impl<'k> KeyBindBatch<'k> {
167 pub fn with_shift(mut self) -> Self {
168 self.modifiers.with_shift();
169 self
170 }
171 pub fn with_ctrl(mut self) -> Self {
172 self.modifiers.with_ctrl();
173 self
174 }
175 pub fn with_alt(mut self) -> Self {
176 self.modifiers.with_alt();
177 self
178 }
179 pub fn with_meta(mut self) -> Self {
181 self.modifiers.with_meta_cmd();
182 self
183 }
184 pub fn with_cmd(mut self) -> Self {
185 self.modifiers.with_meta_cmd();
186 self
187 }
188 pub fn add(self, key: KeyStrike, action: ActionShortcuts) -> KeyBindBatch<'k> {
189 self.bindings
190 .bindings
191 .insert(KeyCombo::new(self.modifiers, key), action);
192 self
193 }
194}
195
196pub fn key_event_to_combo(event: &KeyEvent) -> Option<KeyCombo> {
201 let mut implied_ctrl = false;
205 let key = match event.code {
206 KeyCode::Char(c) => {
207 let c = if c as u8 >= 1 && c as u8 <= 26 {
208 implied_ctrl = true;
209 (c as u8 + b'a' - 1) as char
210 } else {
211 c
212 };
213 match c.to_ascii_lowercase() {
214 'a' => KeyStrike::KeyA,
215 'b' => KeyStrike::KeyB,
216 'c' => KeyStrike::KeyC,
217 'd' => KeyStrike::KeyD,
218 'e' => KeyStrike::KeyE,
219 'f' => KeyStrike::KeyF,
220 'g' => KeyStrike::KeyG,
221 'h' => KeyStrike::KeyH,
222 'i' => KeyStrike::KeyI,
223 'j' => KeyStrike::KeyJ,
224 'k' => KeyStrike::KeyK,
225 'l' => KeyStrike::KeyL,
226 'm' => KeyStrike::KeyM,
227 'n' => KeyStrike::KeyN,
228 'o' => KeyStrike::KeyO,
229 'p' => KeyStrike::KeyP,
230 'q' => KeyStrike::KeyQ,
231 'r' => KeyStrike::KeyR,
232 's' => KeyStrike::KeyS,
233 't' => KeyStrike::KeyT,
234 'u' => KeyStrike::KeyU,
235 'v' => KeyStrike::KeyV,
236 'w' => KeyStrike::KeyW,
237 'x' => KeyStrike::KeyX,
238 'y' => KeyStrike::KeyY,
239 'z' => KeyStrike::KeyZ,
240 '0' => KeyStrike::Digit0,
241 '1' => KeyStrike::Digit1,
242 '2' => KeyStrike::Digit2,
243 '3' => KeyStrike::Digit3,
244 '4' => KeyStrike::Digit4,
245 '5' => KeyStrike::Digit5,
246 '6' => KeyStrike::Digit6,
247 '7' => KeyStrike::Digit7,
248 '8' => KeyStrike::Digit8,
249 '9' => KeyStrike::Digit9,
250 ',' => KeyStrike::Comma,
251 '.' => KeyStrike::Period,
252 '/' => KeyStrike::Slash,
253 ';' => KeyStrike::Semicolon,
254 '\'' => KeyStrike::Quote,
255 '[' => KeyStrike::BracketLeft,
256 ']' => KeyStrike::BracketRight,
257 '\\' => KeyStrike::Backslash,
258 '`' => KeyStrike::Backquote,
259 '-' => KeyStrike::Minus,
260 '=' => KeyStrike::Equal,
261 _ => return None,
262 }},
263 KeyCode::Enter => KeyStrike::Enter,
264 KeyCode::Backspace => KeyStrike::Backspace,
265 KeyCode::Tab | KeyCode::BackTab => KeyStrike::Tab,
266 KeyCode::Esc => KeyStrike::Escape,
267 KeyCode::Up => KeyStrike::ArrowUp,
268 KeyCode::Down => KeyStrike::ArrowDown,
269 KeyCode::Left => KeyStrike::ArrowLeft,
270 KeyCode::Right => KeyStrike::ArrowRight,
271 KeyCode::Home => KeyStrike::Home,
272 KeyCode::End => KeyStrike::End,
273 KeyCode::PageUp => KeyStrike::PageUp,
274 KeyCode::PageDown => KeyStrike::PageDown,
275 KeyCode::Delete => KeyStrike::Delete,
276 KeyCode::Insert => KeyStrike::Insert,
277 KeyCode::F(n) => match n {
278 1 => KeyStrike::F1,
279 2 => KeyStrike::F2,
280 3 => KeyStrike::F3,
281 4 => KeyStrike::F4,
282 5 => KeyStrike::F5,
283 6 => KeyStrike::F6,
284 7 => KeyStrike::F7,
285 8 => KeyStrike::F8,
286 9 => KeyStrike::F9,
287 10 => KeyStrike::F10,
288 11 => KeyStrike::F11,
289 12 => KeyStrike::F12,
290 _ => return None,
291 },
292 _ => return None,
293 };
294
295 let mut modifiers = KeyModifiers::default();
296 if implied_ctrl || event.modifiers.contains(CKeyMods::CONTROL) {
297 modifiers.with_ctrl();
298 }
299 if event.modifiers.contains(CKeyMods::SHIFT) || matches!(event.code, KeyCode::BackTab) {
301 modifiers.with_shift();
302 }
303 if event.modifiers.contains(CKeyMods::ALT) {
304 modifiers.with_alt();
305 }
306 if event.modifiers.contains(CKeyMods::SUPER) || event.modifiers.contains(CKeyMods::META) {
307 modifiers.with_meta_cmd();
308 }
309
310 Some(KeyCombo::new(modifiers, key))
311}
312
#[cfg(test)]
mod tests {
    use super::{
        action_shortcuts::{ActionShortcuts, TextAction},
        key_strike::KeyStrike,
        KeyBindings,
    };

    /// Fixture: Ctrl+N -> TogglePreview, and Bold reachable from both
    /// Ctrl+H and Ctrl+Alt+L.
    fn bold_on_two_combos() -> KeyBindings {
        let mut km = KeyBindings::empty();
        km.batch_add()
            .with_ctrl()
            .add(KeyStrike::KeyN, ActionShortcuts::TogglePreview)
            .add(KeyStrike::KeyH, ActionShortcuts::Text(TextAction::Bold))
            .with_alt()
            .add(KeyStrike::KeyL, ActionShortcuts::Text(TextAction::Bold));
        km
    }

    #[test]
    fn serialize_key_binding() {
        let mut km = KeyBindings::empty();
        km.batch_add()
            .with_ctrl()
            .add(KeyStrike::KeyN, ActionShortcuts::TogglePreview)
            .add(KeyStrike::KeyH, ActionShortcuts::Text(TextAction::Bold))
            .with_alt()
            .add(KeyStrike::KeyL, ActionShortcuts::Text(TextAction::Header(2)));

        let serialized = toml::to_string(&km).unwrap();

        let expected = concat!(
            "TogglePreview = [\"ctrl&N\"]\n",
            "TextEditor-Bold = [\"ctrl&H\"]\n",
            "TextEditor-Header2 = [\"ctrl+alt&L\"]\n",
        );
        assert_eq!(expected, serialized);
    }

    #[test]
    fn serialize_key_binding_double_assignment() {
        let serialized = toml::to_string(&bold_on_two_combos()).unwrap();

        let expected = concat!(
            "TogglePreview = [\"ctrl&N\"]\n",
            "TextEditor-Bold = [\"ctrl&H\", \"ctrl+alt&L\"]\n",
        );
        assert_eq!(expected, serialized);
    }

    #[test]
    fn deserialize_key_binding_double_assignment() {
        let input = concat!(
            "TogglePreview = [\"ctrl & N\"]\n",
            "TextEditor-Bold = [\"ctrl & H\", \"ctrl+alt & L\"]\n",
        );

        let parsed: KeyBindings = toml::from_str(input).unwrap();

        assert_eq!(bold_on_two_combos(), parsed);
    }
}