1use std::{collections::HashMap, fmt::Display};
2
3use action_shortcuts::ActionShortcuts;
4use itertools::Itertools;
5use key_combo::{KeyCombo, KeyModifiers};
6use key_strike::KeyStrike;
7use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers as CKeyMods};
8use serde::{Deserialize, Serialize, de::Visitor, ser::SerializeMap};
9
10pub mod action_shortcuts;
11pub mod key_combo;
12pub mod key_strike;
13
14#[derive(Debug, Clone, PartialEq, Eq)]
15pub struct KeyBindings {
16 bindings: HashMap<KeyCombo, ActionShortcuts>,
17}
18
19impl Serialize for KeyBindings {
20 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
21 where
22 S: serde::Serializer,
23 {
24 let kb_map = self.to_hashmap();
25 let mut map = serializer.serialize_map(Some(kb_map.len()))?;
26 for (k, v) in kb_map
27 .iter()
28 .sorted_by_key(|(action, _combo)| action.to_owned())
29 {
30 map.serialize_entry(&k, &v)?;
31 }
32 map.end()
33 }
34}
35
36struct DeserializeKeyBindingsVisitor;
37impl<'de> Visitor<'de> for DeserializeKeyBindingsVisitor {
38 type Value = KeyBindings;
39
40 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
41 formatter.write_str("A valid path with `/` separators, no need of starting `/`")
42 }
43 fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
44 where
45 A: serde::de::MapAccess<'de>,
46 {
47 let mut bindings: HashMap<ActionShortcuts, Vec<KeyCombo>> =
48 HashMap::with_capacity(map.size_hint().unwrap_or(0));
49 while let Some((key, value)) = map.next_entry()? {
51 bindings.insert(key, value);
52 }
53 Ok(KeyBindings::from_hashmap(bindings))
54 }
55}
56
57impl<'de> Deserialize<'de> for KeyBindings {
58 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
59 where
60 D: serde::Deserializer<'de>,
61 {
62 deserializer.deserialize_map(DeserializeKeyBindingsVisitor)
63 }
64}
65
66impl Display for KeyBindings {
67 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68 let mut bindings: Vec<(ActionShortcuts, Vec<KeyCombo>)> = vec![];
69 for (key, value) in &self.bindings {
70 if let Some((_, combos)) = bindings
71 .iter_mut()
72 .find(|(shortcut, _combos)| shortcut.eq(value))
73 {
74 combos.push(key.to_owned());
75 combos.sort();
76 } else {
77 bindings.push((value.to_owned(), vec![key.to_owned()]));
78 }
79 }
80
81 bindings.sort_by_key(|(a, _v)| a.to_owned());
82 for (key, value) in &bindings {
83 writeln!(
84 f,
85 "{}: {}",
86 key,
87 value
88 .iter()
89 .map(|kc| kc.to_string())
90 .collect::<Vec<String>>()
91 .join(", ")
92 )?;
93 }
94
95 Ok(())
96 }
97}
98
99impl KeyBindings {
100 pub fn empty() -> Self {
101 KeyBindings {
102 bindings: HashMap::default(),
103 }
104 }
105
106 pub fn batch_add(&mut self) -> KeyBindBatch<'_> {
107 KeyBindBatch {
108 bindings: self,
109 modifiers: KeyModifiers::default(),
110 }
111 }
112
113 pub fn get_action(&self, combo: &KeyCombo) -> Option<ActionShortcuts> {
114 self.bindings.get(combo).map(|a| a.to_owned())
115 }
116
117 pub fn first_combo_for(&self, action: &ActionShortcuts) -> Option<String> {
119 self.bindings
120 .iter()
121 .find(|(_, a)| *a == action)
122 .map(|(combo, _)| combo.to_string())
123 }
124
125 pub fn to_hashmap(&self) -> HashMap<ActionShortcuts, Vec<KeyCombo>> {
126 let mut bindings: HashMap<ActionShortcuts, Vec<KeyCombo>> = HashMap::new();
127 for (combo, action) in &self.bindings {
128 let entry = bindings.entry(action.to_owned()).or_default();
129 entry.push(combo.to_owned());
130 entry.sort();
131 }
132 bindings
133 }
134
135 pub fn from_hashmap(bindings: HashMap<ActionShortcuts, Vec<KeyCombo>>) -> KeyBindings {
136 let mut kb = KeyBindings::empty();
137 for (action, combos) in &bindings {
138 tracing::debug!("from_hashmap: action={} combos={:?}", action, combos);
139 }
140 for (action, combos) in bindings {
141 for combo in combos {
142 let valid = combo.is_valid_binding();
143 tracing::debug!(
144 "from_hashmap: combo='{}' key={:?} modifiers={:?} valid={}",
145 combo,
146 combo.key,
147 combo.modifiers,
148 valid
149 );
150 if valid {
151 kb.bindings.insert(combo.to_owned(), action.to_owned());
152 } else {
153 tracing::warn!(
154 "Skipping invalid key combo '{}' for action '{}': \
155 only ctrl/alt (with optional shift) + a letter (a-z), or bare F1–F12 are supported",
156 combo,
157 action
158 );
159 }
160 }
161 }
162 kb
163 }
164}
165
166pub struct KeyBindBatch<'k> {
167 bindings: &'k mut KeyBindings,
168 modifiers: KeyModifiers,
169}
170
171impl<'k> KeyBindBatch<'k> {
172 pub fn with_shift(mut self) -> Self {
173 self.modifiers.with_shift();
174 self
175 }
176 pub fn with_ctrl(mut self) -> Self {
177 self.modifiers.with_ctrl();
178 self
179 }
180 pub fn with_alt(mut self) -> Self {
181 self.modifiers.with_alt();
182 self
183 }
184 pub fn with_meta(mut self) -> Self {
186 self.modifiers.with_meta_cmd();
187 self
188 }
189 pub fn with_cmd(mut self) -> Self {
190 self.modifiers.with_meta_cmd();
191 self
192 }
193 pub fn add(self, key: KeyStrike, action: ActionShortcuts) -> KeyBindBatch<'k> {
194 self.bindings
195 .bindings
196 .insert(KeyCombo::new(self.modifiers, key), action);
197 self
198 }
199}
200
201pub fn key_event_to_combo(event: &KeyEvent) -> Option<KeyCombo> {
206 let mut implied_ctrl = false;
210 let key = match event.code {
211 KeyCode::Char(c) => {
212 let c = if c as u8 >= 1 && c as u8 <= 26 {
213 implied_ctrl = true;
214 (c as u8 + b'a' - 1) as char
215 } else {
216 c
217 };
218 match c.to_ascii_lowercase() {
219 'a' => KeyStrike::KeyA,
220 'b' => KeyStrike::KeyB,
221 'c' => KeyStrike::KeyC,
222 'd' => KeyStrike::KeyD,
223 'e' => KeyStrike::KeyE,
224 'f' => KeyStrike::KeyF,
225 'g' => KeyStrike::KeyG,
226 'h' => KeyStrike::KeyH,
227 'i' => KeyStrike::KeyI,
228 'j' => KeyStrike::KeyJ,
229 'k' => KeyStrike::KeyK,
230 'l' => KeyStrike::KeyL,
231 'm' => KeyStrike::KeyM,
232 'n' => KeyStrike::KeyN,
233 'o' => KeyStrike::KeyO,
234 'p' => KeyStrike::KeyP,
235 'q' => KeyStrike::KeyQ,
236 'r' => KeyStrike::KeyR,
237 's' => KeyStrike::KeyS,
238 't' => KeyStrike::KeyT,
239 'u' => KeyStrike::KeyU,
240 'v' => KeyStrike::KeyV,
241 'w' => KeyStrike::KeyW,
242 'x' => KeyStrike::KeyX,
243 'y' => KeyStrike::KeyY,
244 'z' => KeyStrike::KeyZ,
245 '0' => KeyStrike::Digit0,
246 '1' => KeyStrike::Digit1,
247 '2' => KeyStrike::Digit2,
248 '3' => KeyStrike::Digit3,
249 '4' => KeyStrike::Digit4,
250 '5' => KeyStrike::Digit5,
251 '6' => KeyStrike::Digit6,
252 '7' => KeyStrike::Digit7,
253 '8' => KeyStrike::Digit8,
254 '9' => KeyStrike::Digit9,
255 ',' => KeyStrike::Comma,
256 '.' => KeyStrike::Period,
257 '/' => KeyStrike::Slash,
258 ';' => KeyStrike::Semicolon,
259 '\'' => KeyStrike::Quote,
260 '[' => KeyStrike::BracketLeft,
261 ']' => KeyStrike::BracketRight,
262 '\\' => KeyStrike::Backslash,
263 '`' => KeyStrike::Backquote,
264 '-' => KeyStrike::Minus,
265 '=' => KeyStrike::Equal,
266 _ => return None,
267 }
268 }
269 KeyCode::Enter => KeyStrike::Enter,
270 KeyCode::Backspace => KeyStrike::Backspace,
271 KeyCode::Tab | KeyCode::BackTab => KeyStrike::Tab,
272 KeyCode::Esc => KeyStrike::Escape,
273 KeyCode::Up => KeyStrike::ArrowUp,
274 KeyCode::Down => KeyStrike::ArrowDown,
275 KeyCode::Left => KeyStrike::ArrowLeft,
276 KeyCode::Right => KeyStrike::ArrowRight,
277 KeyCode::Home => KeyStrike::Home,
278 KeyCode::End => KeyStrike::End,
279 KeyCode::PageUp => KeyStrike::PageUp,
280 KeyCode::PageDown => KeyStrike::PageDown,
281 KeyCode::Delete => KeyStrike::Delete,
282 KeyCode::Insert => KeyStrike::Insert,
283 KeyCode::F(n) => match n {
284 1 => KeyStrike::F1,
285 2 => KeyStrike::F2,
286 3 => KeyStrike::F3,
287 4 => KeyStrike::F4,
288 5 => KeyStrike::F5,
289 6 => KeyStrike::F6,
290 7 => KeyStrike::F7,
291 8 => KeyStrike::F8,
292 9 => KeyStrike::F9,
293 10 => KeyStrike::F10,
294 11 => KeyStrike::F11,
295 12 => KeyStrike::F12,
296 _ => return None,
297 },
298 _ => return None,
299 };
300
301 let mut modifiers = KeyModifiers::default();
302 if implied_ctrl || event.modifiers.contains(CKeyMods::CONTROL) {
303 modifiers.with_ctrl();
304 }
305 if event.modifiers.contains(CKeyMods::SHIFT) || matches!(event.code, KeyCode::BackTab) {
307 modifiers.with_shift();
308 }
309 if event.modifiers.contains(CKeyMods::ALT) {
310 modifiers.with_alt();
311 }
312 if event.modifiers.contains(CKeyMods::SUPER) || event.modifiers.contains(CKeyMods::META) {
313 modifiers.with_meta_cmd();
314 }
315
316 Some(KeyCombo::new(modifiers, key))
317}
318
#[cfg(test)]
mod tests {
    use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers as CKeyMods};

    use super::{
        KeyBindings,
        action_shortcuts::{ActionShortcuts, TextAction},
        key_combo::{KeyCombo, KeyModifiers},
        key_event_to_combo,
        key_strike::KeyStrike,
    };

    /// Fixture shared by the double-assignment tests: `Bold` is reachable
    /// through two different combos.
    fn double_assignment_bindings() -> KeyBindings {
        let mut km = KeyBindings::empty();
        km.batch_add()
            .with_ctrl()
            .add(KeyStrike::KeyN, ActionShortcuts::TogglePreview)
            .add(KeyStrike::KeyH, ActionShortcuts::Text(TextAction::Bold))
            .with_alt()
            .add(KeyStrike::KeyL, ActionShortcuts::Text(TextAction::Bold));
        km
    }

    #[test]
    fn serialize_key_binding() {
        let mut km = KeyBindings::empty();
        km.batch_add()
            .with_ctrl()
            .add(KeyStrike::KeyN, ActionShortcuts::TogglePreview)
            .add(KeyStrike::KeyH, ActionShortcuts::Text(TextAction::Bold))
            .with_alt()
            .add(
                KeyStrike::KeyL,
                ActionShortcuts::Text(TextAction::Header(2)),
            );
        let km_str = toml::to_string(&km).unwrap();

        let expected = r#"TogglePreview = ["ctrl&N"]
TextEditor-Bold = ["ctrl&H"]
TextEditor-Header2 = ["ctrl+alt&L"]
"#
        .to_string();
        assert_eq!(expected, km_str);
    }

    #[test]
    fn serialize_key_binding_double_assignment() {
        let km = double_assignment_bindings();
        let km_str = toml::to_string(&km).unwrap();

        let expected = r#"TogglePreview = ["ctrl&N"]
TextEditor-Bold = ["ctrl&H", "ctrl+alt&L"]
"#
        .to_string();
        assert_eq!(expected, km_str);
    }

    #[test]
    fn deserialize_key_binding_double_assignment() {
        let expected_km = double_assignment_bindings();

        let km_str = r#"TogglePreview = ["ctrl & N"]
TextEditor-Bold = ["ctrl & H", "ctrl+alt & L"]
"#;

        let km: KeyBindings = toml::from_str(km_str).unwrap();

        assert_eq!(expected_km, km);
    }

    #[test]
    fn control_char_maps_to_ctrl_letter() {
        // U+0001 is what Ctrl+A produces when no CONTROL flag is reported.
        let event = KeyEvent::new(KeyCode::Char('\u{1}'), CKeyMods::NONE);
        let mut modifiers = KeyModifiers::default();
        modifiers.with_ctrl();
        assert_eq!(
            key_event_to_combo(&event),
            Some(KeyCombo::new(modifiers, KeyStrike::KeyA))
        );
    }

    #[test]
    fn backtab_maps_to_shift_tab() {
        let event = KeyEvent::new(KeyCode::BackTab, CKeyMods::NONE);
        let mut modifiers = KeyModifiers::default();
        modifiers.with_shift();
        assert_eq!(
            key_event_to_combo(&event),
            Some(KeyCombo::new(modifiers, KeyStrike::Tab))
        );
    }

    #[test]
    fn unsupported_char_is_ignored() {
        let event = KeyEvent::new(KeyCode::Char('é'), CKeyMods::NONE);
        assert_eq!(key_event_to_combo(&event), None);
    }
}
387}