1use std::fmt;
2
3use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
4
5use serde::{
6 de::{self, Visitor},
7 Deserialize, Deserializer,
8};
9
10use crate::{command::Command, config::ConfigError};
11
12#[derive(Clone, Debug, PartialEq, Deserialize)]
13pub(crate) struct KeyBinding {
14 pub key: Key,
15 pub command: Command,
16}
17
18impl From<(Key, Command)> for KeyBinding {
19 fn from((key, command): (Key, Command)) -> Self {
20 Self::new(key, command)
21 }
22}
23
24impl KeyBinding {
25 pub const fn new(key: Key, command: Command) -> Self {
26 Self { key, command }
27 }
28}
29
30#[derive(Clone, Debug, Eq, Hash, PartialEq)]
31pub struct Keystroke {
32 pub code: KeyCode,
33 pub modifiers: KeyModifiers,
34}
35
36impl Keystroke {
37 pub const fn new(code: KeyCode, modifiers: KeyModifiers) -> Self {
38 Self { code, modifiers }
39 }
40}
41
42impl From<KeyEvent> for Keystroke {
43 fn from(value: KeyEvent) -> Self {
44 Self::from((value.code, value.modifiers))
45 }
46}
47
48impl From<KeyCode> for Keystroke {
49 fn from(code: KeyCode) -> Self {
50 Keystroke::from((code, KeyModifiers::NONE))
51 }
52}
53
54impl From<(KeyCode, KeyModifiers)> for Keystroke {
55 fn from((code, mut modifiers): (KeyCode, KeyModifiers)) -> Self {
56 let code = match code {
57 KeyCode::Char(ch) if ch.is_uppercase() => {
58 modifiers.insert(KeyModifiers::SHIFT);
59 code
60 }
61 KeyCode::Char(ch)
62 if modifiers.contains(KeyModifiers::SHIFT) && ch.is_ascii_lowercase() =>
63 {
64 KeyCode::Char(ch.to_ascii_uppercase())
66 }
67 _ => code,
68 };
69 Self { code, modifiers }
70 }
71}
72
73impl From<(char, KeyModifiers)> for Keystroke {
74 fn from((c, modifiers): (char, KeyModifiers)) -> Self {
75 Keystroke::from((KeyCode::Char(c), modifiers))
76 }
77}
78
79impl From<&KeyEvent> for Keystroke {
80 fn from(event: &KeyEvent) -> Self {
81 Self::from((event.code, event.modifiers))
82 }
83}
84
85impl fmt::Display for Keystroke {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 let code = self.code.to_string().replace(" ", "_");
88
89 let modifiers = match self.code {
92 KeyCode::Char(ch) if ch.is_uppercase() => self.modifiers - KeyModifiers::SHIFT,
93 _ => self.modifiers,
94 };
95
96 if modifiers.is_empty() {
97 write!(f, "{code}")
98 } else {
99 write!(f, "{}-{code}", modifiers.to_string().to_ascii_lowercase())
100 }
101 }
102}
103
104#[derive(Clone, Debug, Eq, Hash, PartialEq)]
105pub enum Key {
106 Single(Keystroke),
107 Chord(Vec<Keystroke>),
108}
109
110impl fmt::Display for Key {
111 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112 match self {
113 Key::Single(key) => key.fmt(f),
114 Key::Chord(keys) => keys.iter().try_for_each(|key| key.fmt(f)),
115 }
116 }
117}
118
119impl Key {
120 pub const CTRL_C: Key = Key::new(KeyCode::Char('c'), KeyModifiers::CONTROL);
121
122 pub const fn new(code: KeyCode, modifiers: KeyModifiers) -> Self {
123 Key::Single(Keystroke::new(code, modifiers))
124 }
125
126 pub fn chord(iter: impl IntoIterator<Item = Keystroke>) -> Self {
127 Key::Chord(iter.into_iter().collect())
128 }
129}
130
131impl From<KeyEvent> for Key {
132 fn from(value: KeyEvent) -> Self {
133 Self::Single(Keystroke::from(value))
134 }
135}
136
137impl From<KeyCode> for Key {
138 fn from(value: KeyCode) -> Self {
139 Self::Single(Keystroke::from(value))
140 }
141}
142
143impl From<(KeyCode, KeyModifiers)> for Key {
144 fn from(value: (KeyCode, KeyModifiers)) -> Self {
145 Self::Single(Keystroke::from(value))
146 }
147}
148
149impl From<char> for Key {
150 fn from(value: char) -> Self {
151 Self::from(KeyCode::Char(value))
152 }
153}
154
155impl From<(char, KeyModifiers)> for Key {
156 fn from(value: (char, KeyModifiers)) -> Self {
157 Self::Single(Keystroke::from(value))
158 }
159}
160
161impl From<Keystroke> for Key {
162 fn from(value: Keystroke) -> Self {
163 Self::Single(value)
164 }
165}
166
167impl FromIterator<Keystroke> for Key {
168 fn from_iter<T: IntoIterator<Item = Keystroke>>(iter: T) -> Self {
169 Key::chord(iter)
170 }
171}
172
173impl From<Vec<Keystroke>> for Key {
174 fn from(value: Vec<Keystroke>) -> Self {
175 Key::from_iter(value)
176 }
177}
178
179impl<'de> Deserialize<'de> for Key {
180 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
181 where
182 D: Deserializer<'de>,
183 {
184 deserializer.deserialize_str(KeyVisitor)
185 }
186}
187
188struct KeyVisitor;
189
190impl Visitor<'_> for KeyVisitor {
191 type Value = Key;
192
193 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
194 formatter.write_str("a single key (\"a\"), named key (\"esc\"), modified key (\"ctrl+x\"), or key sequence (\"gg\")")
195 }
196
197 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
198 where
199 E: de::Error,
200 {
201 let mut parts = value.split('+');
202 let code = parts
203 .next_back()
204 .ok_or(ConfigError::UnknownKeyCode(value.to_string()))
205 .map_err(de::Error::custom)?;
206
207 let mut modifiers = KeyModifiers::NONE;
208 for part in parts {
209 modifiers |= parse_modifiers(&part.to_lowercase()).map_err(de::Error::custom)?;
210 }
211
212 parse_key(code, modifiers).map_err(de::Error::custom)
213 }
214}
215
216fn parse_key(code: &str, modifiers: KeyModifiers) -> Result<Key, ConfigError> {
217 if code.is_empty() {
218 return Ok(Key::from((KeyCode::Null, modifiers)));
219 }
220
221 let key_code = match code {
222 "esc" => KeyCode::Esc,
223 "space" => KeyCode::Char(' '),
224 "backspace" => KeyCode::Backspace,
225 "backtab" => KeyCode::BackTab,
226 "delete" => KeyCode::Delete,
227 "down" => KeyCode::Down,
228 "end" => KeyCode::End,
229 "enter" => KeyCode::Enter,
230 "home" => KeyCode::Home,
231 "insert" => KeyCode::Insert,
232 "left" => KeyCode::Left,
233 "page_down" => KeyCode::PageDown,
234 "page_up" => KeyCode::PageUp,
235 "right" => KeyCode::Right,
236 "tab" => KeyCode::Tab,
237 "up" => KeyCode::Up,
238 c if c.chars().count() == 1 => c
240 .chars()
241 .next()
242 .map(KeyCode::Char)
243 .ok_or_else(|| ConfigError::UnknownKeyCode(c.to_string()))?,
244 c if c.starts_with('f') => c[1..]
246 .parse::<u8>()
247 .map(KeyCode::F)
248 .map_err(|_| ConfigError::UnknownKeyCode(c.to_string()))?,
249 c => {
251 return Ok(Key::chord(
252 c.chars().map(KeyCode::Char).map(Keystroke::from),
253 ))
254 }
255 };
256
257 Ok(Key::from((key_code, modifiers)))
258}
259
260fn parse_modifiers(modifiers: &str) -> Result<KeyModifiers, ConfigError> {
261 if modifiers.is_empty() {
262 return Ok(KeyModifiers::NONE);
263 }
264
265 match modifiers {
266 "alt" => Ok(KeyModifiers::ALT),
267 "ctrl" | "control" => Ok(KeyModifiers::CONTROL),
268 "hyper" => Ok(KeyModifiers::HYPER),
269 "meta" => Ok(KeyModifiers::META),
270 "shift" => Ok(KeyModifiers::SHIFT),
271 "super" => Ok(KeyModifiers::SUPER),
272 _ => Err(ConfigError::UnknownKeyModifiers(modifiers.to_string())),
273 }
274}
275
276impl de::Error for ConfigError {
277 fn custom<T>(msg: T) -> Self
278 where
279 T: fmt::Display,
280 {
281 ConfigError::InvalidKeybinding(msg.to_string())
282 }
283}
284
#[cfg(test)]
mod tests {
    use ratatui::crossterm::event::{KeyCode, KeyModifiers};
    use serde::de::IntoDeserializer;

    use super::*;

    /// Parses `s` through the real `Deserialize` impl for [`Key`].
    fn key_from_str(s: &str) -> Result<Key, ConfigError> {
        Key::deserialize(s.into_deserializer())
    }

    /// Asserts that every input string deserializes to its paired [`Key`].
    fn assert_parses_to<'a>(cases: impl IntoIterator<Item = (&'a str, Key)>) {
        for (input, expected) in cases {
            assert_eq!(key_from_str(input).unwrap(), expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_named_keys() {
        let named = [
            ("esc", KeyCode::Esc),
            ("enter", KeyCode::Enter),
            ("space", KeyCode::Char(' ')),
            ("backspace", KeyCode::Backspace),
            ("backtab", KeyCode::BackTab),
            ("delete", KeyCode::Delete),
            ("tab", KeyCode::Tab),
            ("up", KeyCode::Up),
            ("down", KeyCode::Down),
            ("left", KeyCode::Left),
            ("right", KeyCode::Right),
            ("home", KeyCode::Home),
            ("end", KeyCode::End),
            ("page_up", KeyCode::PageUp),
            ("page_down", KeyCode::PageDown),
            ("insert", KeyCode::Insert),
        ];

        assert_parses_to(named.map(|(input, code)| (input, Key::from(code))));
    }

    #[test]
    fn test_single_char_keys() {
        let chars = [
            ("a", 'a'),
            ("z", 'z'),
            ("A", 'A'),
            ("0", '0'),
            ("?", '?'),
            ("/", '/'),
            (":", ':'),
        ];

        assert_parses_to(chars.map(|(input, ch)| (input, Key::from(ch))));
    }

    #[test]
    fn test_function_keys() {
        let numbers = [("f1", 1), ("f5", 5), ("f12", 12)];

        assert_parses_to(numbers.map(|(input, n)| (input, Key::from(KeyCode::F(n)))));
    }

    #[test]
    fn test_modified_keys() {
        let ctrl_shift = KeyModifiers::CONTROL | KeyModifiers::SHIFT;

        assert_parses_to([
            ("ctrl+c", Key::from(('c', KeyModifiers::CONTROL))),
            ("control+c", Key::from(('c', KeyModifiers::CONTROL))),
            ("alt+x", Key::from(('x', KeyModifiers::ALT))),
            ("shift+a", Key::from(('a', KeyModifiers::SHIFT))),
            ("ctrl+shift+k", Key::from((KeyCode::Char('k'), ctrl_shift))),
            ("ctrl+enter", Key::from((KeyCode::Enter, KeyModifiers::CONTROL))),
            ("alt+esc", Key::from((KeyCode::Esc, KeyModifiers::ALT))),
            ("ctrl+f5", Key::from((KeyCode::F(5), KeyModifiers::CONTROL))),
        ]);
    }

    #[test]
    fn test_key_sequences() {
        let cases: [(&str, &[char]); 3] = [
            ("gg", &['g', 'g']),
            ("gG", &['g', 'G']),
            ("crn", &['c', 'r', 'n']),
        ];

        for (input, chars) in cases {
            let expected: Vec<Keystroke> = chars
                .iter()
                .map(|&ch| Keystroke::from(KeyCode::Char(ch)))
                .collect();
            let Key::Chord(keys) = key_from_str(input).unwrap() else {
                panic!("Expected sequence for {input:?}, got plain key");
            };
            assert_eq!(keys, expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_invalid_keys() {
        for input in ["unknown_modifier+c", "badmod+x", "f999"] {
            let parsed = key_from_str(input);
            assert!(parsed.is_err(), "Expected error for {input:?}");
        }
    }

    #[test]
    fn test_keystroke_display() {
        let cases = [
            (Keystroke::new(KeyCode::Char('a'), KeyModifiers::NONE), "a"),
            (Keystroke::new(KeyCode::Char('c'), KeyModifiers::CONTROL), "control-c"),
            // The Shift implied by an uppercase letter is not printed.
            (Keystroke::from(KeyCode::Char('G')), "G"),
            (Keystroke::from((KeyCode::Char('G'), KeyModifiers::CONTROL)), "control-G"),
        ];

        for (key, expected) in cases {
            assert_eq!(key.to_string(), expected, "key: {key:?}");
        }
    }

    #[test]
    fn test_key_sequence_display() {
        let chord = Key::chord(['g', 'G'].map(|ch| Keystroke::from(KeyCode::Char(ch))));

        assert_eq!(chord.to_string(), "gG");
    }

    #[test]
    fn test_uppercase_implies_shift() {
        let shifted_g = Keystroke::new(KeyCode::Char('G'), KeyModifiers::SHIFT);

        assert_eq!(key_from_str("G").unwrap(), Key::Single(shifted_g.clone()));
        assert_eq!(
            key_from_str("gG").unwrap(),
            Key::chord([
                Keystroke::new(KeyCode::Char('g'), KeyModifiers::NONE),
                shifted_g,
            ])
        );
    }
}
475}