Skip to main content

basalt_tui/config/
key_binding.rs

1use std::fmt;
2
3use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
4
5use serde::{
6    de::{self, Visitor},
7    Deserialize, Deserializer,
8};
9
10use crate::{command::Command, config::ConfigError};
11
12#[derive(Clone, Debug, PartialEq, Deserialize)]
13pub(crate) struct KeyBinding {
14    pub key: Key,
15    pub command: Command,
16}
17
18impl From<(Key, Command)> for KeyBinding {
19    fn from((key, command): (Key, Command)) -> Self {
20        Self::new(key, command)
21    }
22}
23
24impl KeyBinding {
25    pub const fn new(key: Key, command: Command) -> Self {
26        Self { key, command }
27    }
28}
29
30#[derive(Clone, Debug, Eq, Hash, PartialEq)]
31pub struct Keystroke {
32    pub code: KeyCode,
33    pub modifiers: KeyModifiers,
34}
35
36impl Keystroke {
37    pub const fn new(code: KeyCode, modifiers: KeyModifiers) -> Self {
38        Self { code, modifiers }
39    }
40}
41
42impl From<KeyEvent> for Keystroke {
43    fn from(value: KeyEvent) -> Self {
44        Self::from((value.code, value.modifiers))
45    }
46}
47
48impl From<KeyCode> for Keystroke {
49    fn from(code: KeyCode) -> Self {
50        Keystroke::from((code, KeyModifiers::NONE))
51    }
52}
53
54impl From<(KeyCode, KeyModifiers)> for Keystroke {
55    fn from((code, mut modifiers): (KeyCode, KeyModifiers)) -> Self {
56        let code = match code {
57            KeyCode::Char(ch) if ch.is_uppercase() => {
58                modifiers.insert(KeyModifiers::SHIFT);
59                code
60            }
61            KeyCode::Char(ch)
62                if modifiers.contains(KeyModifiers::SHIFT) && ch.is_ascii_lowercase() =>
63            {
64                // Normalize lowercase+SHIFT to uppercase
65                KeyCode::Char(ch.to_ascii_uppercase())
66            }
67            _ => code,
68        };
69        Self { code, modifiers }
70    }
71}
72
73impl From<(char, KeyModifiers)> for Keystroke {
74    fn from((c, modifiers): (char, KeyModifiers)) -> Self {
75        Keystroke::from((KeyCode::Char(c), modifiers))
76    }
77}
78
79impl From<&KeyEvent> for Keystroke {
80    fn from(event: &KeyEvent) -> Self {
81        Self::from((event.code, event.modifiers))
82    }
83}
84
85impl fmt::Display for Keystroke {
86    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87        let code = self.code.to_string().replace(" ", "_");
88
89        // Uppercase chars carry SHIFT implicitly — strip it from the display
90        // so the string representation stays canonical (e.g. "G" not "shift-G")
91        let modifiers = match self.code {
92            KeyCode::Char(ch) if ch.is_uppercase() => self.modifiers - KeyModifiers::SHIFT,
93            _ => self.modifiers,
94        };
95
96        if modifiers.is_empty() {
97            write!(f, "{code}")
98        } else {
99            write!(f, "{}-{code}", modifiers.to_string().to_ascii_lowercase())
100        }
101    }
102}
103
104#[derive(Clone, Debug, Eq, Hash, PartialEq)]
105pub enum Key {
106    Single(Keystroke),
107    Chord(Vec<Keystroke>),
108}
109
110impl fmt::Display for Key {
111    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
112        match self {
113            Key::Single(key) => key.fmt(f),
114            Key::Chord(keys) => keys.iter().try_for_each(|key| key.fmt(f)),
115        }
116    }
117}
118
119impl Key {
120    pub const CTRL_C: Key = Key::new(KeyCode::Char('c'), KeyModifiers::CONTROL);
121
122    pub const fn new(code: KeyCode, modifiers: KeyModifiers) -> Self {
123        Key::Single(Keystroke::new(code, modifiers))
124    }
125
126    pub fn chord(iter: impl IntoIterator<Item = Keystroke>) -> Self {
127        Key::Chord(iter.into_iter().collect())
128    }
129}
130
131impl From<KeyEvent> for Key {
132    fn from(value: KeyEvent) -> Self {
133        Self::Single(Keystroke::from(value))
134    }
135}
136
137impl From<KeyCode> for Key {
138    fn from(value: KeyCode) -> Self {
139        Self::Single(Keystroke::from(value))
140    }
141}
142
143impl From<(KeyCode, KeyModifiers)> for Key {
144    fn from(value: (KeyCode, KeyModifiers)) -> Self {
145        Self::Single(Keystroke::from(value))
146    }
147}
148
149impl From<char> for Key {
150    fn from(value: char) -> Self {
151        Self::from(KeyCode::Char(value))
152    }
153}
154
155impl From<(char, KeyModifiers)> for Key {
156    fn from(value: (char, KeyModifiers)) -> Self {
157        Self::Single(Keystroke::from(value))
158    }
159}
160
161impl From<Keystroke> for Key {
162    fn from(value: Keystroke) -> Self {
163        Self::Single(value)
164    }
165}
166
167impl FromIterator<Keystroke> for Key {
168    fn from_iter<T: IntoIterator<Item = Keystroke>>(iter: T) -> Self {
169        Key::chord(iter)
170    }
171}
172
173impl From<Vec<Keystroke>> for Key {
174    fn from(value: Vec<Keystroke>) -> Self {
175        Key::from_iter(value)
176    }
177}
178
179impl<'de> Deserialize<'de> for Key {
180    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
181    where
182        D: Deserializer<'de>,
183    {
184        deserializer.deserialize_str(KeyVisitor)
185    }
186}
187
188struct KeyVisitor;
189
190impl Visitor<'_> for KeyVisitor {
191    type Value = Key;
192
193    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
194        formatter.write_str("a single key (\"a\"), named key (\"esc\"), modified key (\"ctrl+x\"), or key sequence (\"gg\")")
195    }
196
197    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
198    where
199        E: de::Error,
200    {
201        let mut parts = value.split('+');
202        let code = parts
203            .next_back()
204            .ok_or(ConfigError::UnknownKeyCode(value.to_string()))
205            .map_err(de::Error::custom)?;
206
207        let mut modifiers = KeyModifiers::NONE;
208        for part in parts {
209            modifiers |= parse_modifiers(&part.to_lowercase()).map_err(de::Error::custom)?;
210        }
211
212        parse_key(code, modifiers).map_err(de::Error::custom)
213    }
214}
215
216fn parse_key(code: &str, modifiers: KeyModifiers) -> Result<Key, ConfigError> {
217    if code.is_empty() {
218        return Ok(Key::from((KeyCode::Null, modifiers)));
219    }
220
221    let key_code = match code {
222        "esc" => KeyCode::Esc,
223        "space" => KeyCode::Char(' '),
224        "backspace" => KeyCode::Backspace,
225        "backtab" => KeyCode::BackTab,
226        "delete" => KeyCode::Delete,
227        "down" => KeyCode::Down,
228        "end" => KeyCode::End,
229        "enter" => KeyCode::Enter,
230        "home" => KeyCode::Home,
231        "insert" => KeyCode::Insert,
232        "left" => KeyCode::Left,
233        "page_down" => KeyCode::PageDown,
234        "page_up" => KeyCode::PageUp,
235        "right" => KeyCode::Right,
236        "tab" => KeyCode::Tab,
237        "up" => KeyCode::Up,
238        // Single char — uppercase SHIFT is handled by Keystroke::from
239        c if c.chars().count() == 1 => c
240            .chars()
241            .next()
242            .map(KeyCode::Char)
243            .ok_or_else(|| ConfigError::UnknownKeyCode(c.to_string()))?,
244        // F-n keys
245        c if c.starts_with('f') => c[1..]
246            .parse::<u8>()
247            .map(KeyCode::F)
248            .map_err(|_| ConfigError::UnknownKeyCode(c.to_string()))?,
249        // Multi-char sequence like "gG" or "ciw" — uppercase SHIFT via Keystroke::from
250        c => {
251            return Ok(Key::chord(
252                c.chars().map(KeyCode::Char).map(Keystroke::from),
253            ))
254        }
255    };
256
257    Ok(Key::from((key_code, modifiers)))
258}
259
260fn parse_modifiers(modifiers: &str) -> Result<KeyModifiers, ConfigError> {
261    if modifiers.is_empty() {
262        return Ok(KeyModifiers::NONE);
263    }
264
265    match modifiers {
266        "alt" => Ok(KeyModifiers::ALT),
267        "ctrl" | "control" => Ok(KeyModifiers::CONTROL),
268        "hyper" => Ok(KeyModifiers::HYPER),
269        "meta" => Ok(KeyModifiers::META),
270        "shift" => Ok(KeyModifiers::SHIFT),
271        "super" => Ok(KeyModifiers::SUPER),
272        _ => Err(ConfigError::UnknownKeyModifiers(modifiers.to_string())),
273    }
274}
275
276impl de::Error for ConfigError {
277    fn custom<T>(msg: T) -> Self
278    where
279        T: fmt::Display,
280    {
281        ConfigError::InvalidKeybinding(msg.to_string())
282    }
283}
284
#[cfg(test)]
mod tests {
    use ratatui::crossterm::event::{KeyCode, KeyModifiers};
    use serde::de::IntoDeserializer;

    use super::*;

    /// Deserializes `input` through the same `Deserialize` impl used for config files.
    fn parse(input: &str) -> Result<Key, ConfigError> {
        Key::deserialize(input.into_deserializer())
    }

    /// Asserts that each `(input, expected)` entry parses successfully to `expected`.
    fn assert_parses(table: &[(&str, Key)]) {
        for (input, expected) in table {
            let actual = parse(input).unwrap();
            assert_eq!(&actual, expected, "input: {input:?}");
        }
    }

    #[test]
    fn test_named_keys() {
        assert_parses(&[
            ("esc", Key::from(KeyCode::Esc)),
            ("enter", Key::from(KeyCode::Enter)),
            ("space", Key::from(KeyCode::Char(' '))),
            ("backspace", Key::from(KeyCode::Backspace)),
            ("backtab", Key::from(KeyCode::BackTab)),
            ("delete", Key::from(KeyCode::Delete)),
            ("tab", Key::from(KeyCode::Tab)),
            ("up", Key::from(KeyCode::Up)),
            ("down", Key::from(KeyCode::Down)),
            ("left", Key::from(KeyCode::Left)),
            ("right", Key::from(KeyCode::Right)),
            ("home", Key::from(KeyCode::Home)),
            ("end", Key::from(KeyCode::End)),
            ("page_up", Key::from(KeyCode::PageUp)),
            ("page_down", Key::from(KeyCode::PageDown)),
            ("insert", Key::from(KeyCode::Insert)),
        ]);
    }

    #[test]
    fn test_single_char_keys() {
        assert_parses(&[
            ("a", Key::from('a')),
            ("z", Key::from('z')),
            ("A", Key::from('A')),
            ("0", Key::from('0')),
            ("?", Key::from('?')),
            ("/", Key::from('/')),
            (":", Key::from(':')),
        ]);
    }

    #[test]
    fn test_function_keys() {
        assert_parses(&[
            ("f1", Key::from(KeyCode::F(1))),
            ("f5", Key::from(KeyCode::F(5))),
            ("f12", Key::from(KeyCode::F(12))),
        ]);
    }

    #[test]
    fn test_modified_keys() {
        let ctrl_shift = KeyModifiers::CONTROL | KeyModifiers::SHIFT;
        assert_parses(&[
            ("ctrl+c", Key::from(('c', KeyModifiers::CONTROL))),
            ("control+c", Key::from(('c', KeyModifiers::CONTROL))),
            ("alt+x", Key::from(('x', KeyModifiers::ALT))),
            ("shift+a", Key::from(('a', KeyModifiers::SHIFT))),
            ("ctrl+shift+k", Key::from((KeyCode::Char('k'), ctrl_shift))),
            ("ctrl+enter", Key::from((KeyCode::Enter, KeyModifiers::CONTROL))),
            ("alt+esc", Key::from((KeyCode::Esc, KeyModifiers::ALT))),
            ("ctrl+f5", Key::from((KeyCode::F(5), KeyModifiers::CONTROL))),
        ]);
    }

    #[test]
    fn test_key_sequences() {
        let plain = |ch| Keystroke::from(KeyCode::Char(ch));
        let table: [(&str, Vec<Keystroke>); 3] = [
            ("gg", vec![plain('g'), plain('g')]),
            ("gG", vec![plain('g'), plain('G')]),
            ("crn", vec![plain('c'), plain('r'), plain('n')]),
        ];

        for (input, expected) in table {
            match parse(input).unwrap() {
                Key::Chord(strokes) => assert_eq!(strokes, expected, "input: {input:?}"),
                Key::Single(_) => panic!("Expected sequence for {input:?}, got plain key"),
            }
        }
    }

    #[test]
    fn test_invalid_keys() {
        for input in ["unknown_modifier+c", "badmod+x", "f999"] {
            assert!(parse(input).is_err(), "Expected error for {input:?}");
        }
    }

    #[test]
    fn test_keystroke_display() {
        let table = [
            (Keystroke::new(KeyCode::Char('a'), KeyModifiers::NONE), "a"),
            (
                Keystroke::new(KeyCode::Char('c'), KeyModifiers::CONTROL),
                "control-c",
            ),
            // Uppercase char: SHIFT is implicit and therefore not displayed.
            (Keystroke::from(KeyCode::Char('G')), "G"),
            // Uppercase char with an additional modifier.
            (
                Keystroke::from((KeyCode::Char('G'), KeyModifiers::CONTROL)),
                "control-G",
            ),
        ];

        for (stroke, expected) in table {
            assert_eq!(stroke.to_string(), expected, "key: {stroke:?}");
        }
    }

    #[test]
    fn test_key_sequence_display() {
        let chord = Key::chord([
            Keystroke::from(KeyCode::Char('g')),
            Keystroke::from(KeyCode::Char('G')),
        ]);

        assert_eq!(chord.to_string(), "gG");
    }

    #[test]
    fn test_uppercase_implies_shift() {
        // A lone uppercase letter parses with SHIFT set in its modifiers.
        let expected_single = Key::Single(Keystroke::new(KeyCode::Char('G'), KeyModifiers::SHIFT));
        assert_eq!(parse("G").unwrap(), expected_single);

        // In the sequence "gG" only the second keystroke carries SHIFT.
        let expected_chord = Key::chord([
            Keystroke::new(KeyCode::Char('g'), KeyModifiers::NONE),
            Keystroke::new(KeyCode::Char('G'), KeyModifiers::SHIFT),
        ]);
        assert_eq!(parse("gG").unwrap(), expected_chord);
    }
}