zi/component/bindings.rs

1use smallvec::{smallvec, SmallVec};
2use std::{
3    any::{Any, TypeId},
4    borrow::Cow,
5    collections::hash_map::HashMap,
6    fmt,
7    marker::PhantomData,
8};
9
10use super::{Component, DynamicMessage};
11use crate::terminal::Key;
12
/// Opaque handle naming a command registered in a [`Keymap`].
///
/// The wrapped value is the command's position in the keymap's list of names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct CommandId(usize);
15
16#[derive(Clone, Debug, PartialEq)]
17pub enum NamedBindingQuery {
18    Match(Cow<'static, str>),
19    PrefixOf(SmallVec<[Cow<'static, str>; 4]>),
20}
21
22impl NamedBindingQuery {
23    pub fn new(keymap: &Keymap, query: &BindingQuery) -> Self {
24        match query {
25            BindingQuery::Match(command_id) => Self::Match(keymap.names[command_id.0].clone()),
26            BindingQuery::PrefixOf(commands) => Self::PrefixOf(
27                commands
28                    .iter()
29                    .map(|command_id| keymap.names[command_id.0].clone())
30                    .collect(),
31            ),
32        }
33    }
34}
35
36#[derive(Clone, Debug, PartialEq)]
37pub enum BindingQuery {
38    Match(CommandId),
39    PrefixOf(SmallVec<[CommandId; 4]>),
40}
41
42impl BindingQuery {
43    pub fn matches(&self) -> Option<CommandId> {
44        match self {
45            Self::Match(command_id) => Some(*command_id),
46            _ => None,
47        }
48    }
49
50    pub fn prefix_of(&self) -> Option<&[CommandId]> {
51        match self {
52            Self::PrefixOf(commands) => Some(commands),
53            _ => None,
54        }
55    }
56}
57
58#[derive(Debug, Default)]
59pub struct Keymap {
60    names: Vec<Cow<'static, str>>,
61    keymap: HashMap<KeyPattern, BindingQuery>,
62}
63
64impl Keymap {
65    pub fn new() -> Self {
66        Self::default()
67    }
68
69    pub fn name(&self, command_id: &CommandId) -> &str {
70        &self.names[command_id.0]
71    }
72
73    pub fn is_empty(&self) -> bool {
74        self.keymap.is_empty()
75    }
76
77    pub fn add(
78        &mut self,
79        name: impl Into<Cow<'static, str>>,
80        pattern: impl Into<KeyPattern>,
81    ) -> CommandId {
82        let command_id = self.add_command(name).0;
83        self.bind_command(command_id, pattern);
84        command_id
85    }
86
87    pub fn add_command(&mut self, name: impl Into<Cow<'static, str>>) -> (CommandId, bool) {
88        let name = name.into();
89        let (command_id, is_new_command) = self
90            .names
91            .iter()
92            .enumerate()
93            .find(|(_index, existing)| **existing == name)
94            .map(|(index, _)| (CommandId(index), false))
95            .unwrap_or_else(|| (CommandId(self.names.len()), true));
96        if is_new_command {
97            self.names.push(name);
98        }
99        (command_id, is_new_command)
100    }
101
102    pub fn bind_command(&mut self, command_id: CommandId, pattern: impl Into<KeyPattern>) {
103        let name = &self.names[command_id.0];
104        let pattern = pattern.into();
105
106        // Add `BindingQuery::PrefixOf` entries for all prefixes of the key sequence
107        if let Some(keys) = pattern.keys() {
108            for prefix_len in 0..keys.len() {
109                let prefix = KeyPattern::Keys(keys.iter().copied().take(prefix_len).collect());
110                self.keymap
111                    .entry(prefix.clone())
112                    .and_modify(|entry| match entry {
113                        BindingQuery::Match(other_command_id) => panic_on_overlapping_key_bindings(
114                            &pattern,
115                            name,
116                            &prefix,
117                            &self.names[other_command_id.0],
118                        ),
119                        BindingQuery::PrefixOf(prefix_of) => {
120                            prefix_of.push(command_id);
121                        }
122                    })
123                    .or_insert_with(|| BindingQuery::PrefixOf(smallvec![command_id]));
124            }
125        }
126
127        // Add a `BindingQuery::Match` for the full key sequence
128        self.keymap
129            .entry(pattern.clone())
130            .and_modify(|entry| match entry {
131                BindingQuery::Match(other_command_id) => panic_on_overlapping_key_bindings(
132                    &pattern,
133                    name,
134                    &pattern,
135                    &self.names[other_command_id.0],
136                ),
137                BindingQuery::PrefixOf(prefix_of) => panic_on_overlapping_key_bindings(
138                    &pattern,
139                    name,
140                    &pattern,
141                    &self.names[prefix_of[0].0],
142                ),
143            })
144            .or_insert_with(|| BindingQuery::Match(command_id));
145    }
146
147    pub fn check_sequence(&self, keys: &[Key]) -> Option<&BindingQuery> {
148        let pattern: KeyPattern = keys.iter().copied().into();
149        self.keymap
150            .get(&pattern)
151            .or_else(|| match keys {
152                &[Key::Char(_)] => self.keymap.get(&KeyPattern::AnyCharacter),
153                _ => None,
154            })
155            .or_else(|| match keys {
156                &[_, key] | &[key] => self.keymap.get(&KeyPattern::EndsWith([key])),
157                _ => None,
158            })
159    }
160}
161
162#[allow(clippy::type_complexity)]
163struct DynamicCommandFn(Box<dyn Fn(&dyn Any, &[Key]) -> Option<DynamicMessage>>);
164
165impl fmt::Debug for DynamicCommandFn {
166    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
167        write!(formatter, "CommandFn@{:?})", &self.0 as *const _)
168    }
169}
170
171#[derive(Debug)]
172pub(crate) struct DynamicBindings {
173    keymap: Keymap,
174    commands: Vec<DynamicCommandFn>,
175    focused: bool,
176    notify: bool,
177    type_id: TypeId,
178}
179
180impl DynamicBindings {
181    pub fn new<ComponentT: Component>() -> Self {
182        Self {
183            keymap: Keymap::new(),
184            commands: Vec::new(),
185            focused: false,
186            notify: false,
187            type_id: TypeId::of::<ComponentT>(),
188        }
189    }
190
191    #[inline]
192    pub fn keymap(&self) -> &Keymap {
193        &self.keymap
194    }
195
196    #[inline]
197    pub fn set_focus(&mut self, focused: bool) {
198        self.focused = focused;
199    }
200
201    #[inline]
202    pub fn focused(&self) -> bool {
203        self.focused
204    }
205
206    #[inline]
207    pub fn set_notify(&mut self, notify: bool) {
208        self.notify = notify;
209    }
210
211    #[inline]
212    pub fn notify(&self) -> bool {
213        self.notify
214    }
215
216    pub fn add<ComponentT: Component, const VARIANT: usize>(
217        &mut self,
218        name: impl Into<Cow<'static, str>>,
219        keys: impl Into<KeyPattern>,
220        command_fn: impl CommandFn<ComponentT, VARIANT> + 'static,
221    ) -> CommandId {
222        let command_id = self.add_command(name, command_fn);
223        self.bind_command(command_id, keys);
224        command_id
225    }
226
227    pub fn add_command<ComponentT: Component, const VARIANT: usize>(
228        &mut self,
229        name: impl Into<Cow<'static, str>>,
230        command_fn: impl CommandFn<ComponentT, VARIANT> + 'static,
231    ) -> CommandId {
232        assert_eq!(self.type_id, TypeId::of::<ComponentT>());
233
234        let (command_id, is_new_command) = self.keymap.add_command(name);
235        let dyn_command_fn = DynamicCommandFn(Box::new(move |erased: &dyn Any, keys: &[Key]| {
236            let component = erased
237                .downcast_ref()
238                .expect("Incorrect `Component` type when downcasting");
239            command_fn
240                .call(component, keys)
241                .map(|message| DynamicMessage(Box::new(message)))
242        }));
243        if is_new_command {
244            self.commands.push(dyn_command_fn);
245        } else {
246            self.commands[command_id.0] = dyn_command_fn;
247        }
248
249        command_id
250    }
251
252    pub fn bind_command(&mut self, command_id: CommandId, keys: impl Into<KeyPattern>) {
253        self.keymap.bind_command(command_id, keys);
254    }
255
256    pub fn execute_command<ComponentT: Component>(
257        &self,
258        component: &ComponentT,
259        id: CommandId,
260        keys: &[Key],
261    ) -> Option<DynamicMessage> {
262        assert_eq!(self.type_id, TypeId::of::<ComponentT>());
263
264        (self.commands[id.0].0)(component, keys)
265    }
266
267    pub fn typed<ComponentT: Component>(
268        &mut self,
269        callback: impl FnOnce(&mut Bindings<ComponentT>),
270    ) {
271        assert_eq!(self.type_id, TypeId::of::<ComponentT>());
272
273        let mut bindings = Self::new::<ComponentT>();
274        std::mem::swap(self, &mut bindings);
275        let mut typed = Bindings::<ComponentT>::new(bindings);
276        callback(&mut typed);
277        std::mem::swap(self, &mut typed.bindings);
278    }
279}
280
281#[derive(Debug)]
282pub struct Bindings<ComponentT> {
283    bindings: DynamicBindings,
284    _component: PhantomData<fn() -> ComponentT>,
285}
286
287impl<ComponentT: Component> Bindings<ComponentT> {
288    fn new(bindings: DynamicBindings) -> Self {
289        Self {
290            bindings,
291            _component: PhantomData,
292        }
293    }
294
295    #[inline]
296    pub fn is_empty(&self) -> bool {
297        self.bindings.keymap.is_empty()
298    }
299
300    #[inline]
301    pub fn set_focus(&mut self, focused: bool) {
302        self.bindings.set_focus(focused)
303    }
304
305    #[inline]
306    pub fn focused(&self) -> bool {
307        self.bindings.focused()
308    }
309
310    #[inline]
311    pub fn set_notify(&mut self, notify: bool) {
312        self.bindings.set_notify(notify)
313    }
314
315    #[inline]
316    pub fn notify(&self) -> bool {
317        self.bindings.notify()
318    }
319
320    #[inline]
321    pub fn add<const VARIANT: usize>(
322        &mut self,
323        name: impl Into<Cow<'static, str>>,
324        keys: impl Into<KeyPattern>,
325        command_fn: impl CommandFn<ComponentT, VARIANT> + 'static,
326    ) {
327        self.bindings.add(name, keys, command_fn);
328    }
329
330    #[inline]
331    pub fn command<const VARIANT: usize>(
332        &mut self,
333        name: impl Into<Cow<'static, str>>,
334        command_fn: impl CommandFn<ComponentT, VARIANT> + 'static,
335    ) -> BindingBuilder<ComponentT> {
336        let command_id = self.bindings.add_command(name, command_fn);
337        BindingBuilder {
338            wrapped: self,
339            command_id,
340        }
341    }
342}
343
344#[derive(Debug)]
345pub struct BindingBuilder<'a, ComponentT> {
346    wrapped: &'a mut Bindings<ComponentT>,
347    command_id: CommandId,
348}
349
350impl<ComponentT: Component> BindingBuilder<'_, ComponentT> {
351    pub fn with(self, keys: impl Into<KeyPattern>) -> Self {
352        self.wrapped.bindings.bind_command(self.command_id, keys);
353        self
354    }
355}
356
357#[derive(Debug, Clone, PartialEq, Eq, Hash)]
358pub enum KeyPattern {
359    AnyCharacter,
360    EndsWith([Key; 1]),
361    Keys(SmallVec<[Key; 8]>),
362}
363
364impl KeyPattern {
365    fn keys(&self) -> Option<&[Key]> {
366        match self {
367            Self::AnyCharacter => None,
368            Self::EndsWith(key) => Some(key.as_slice()),
369            Self::Keys(keys) => Some(keys.as_slice()),
370        }
371    }
372}
373
374impl<IterT: IntoIterator<Item = Key>> From<IterT> for KeyPattern {
375    fn from(keys: IterT) -> Self {
376        Self::Keys(keys.into_iter().collect())
377    }
378}
379
380impl std::fmt::Display for KeyPattern {
381    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> {
382        match self {
383            Self::AnyCharacter => {
384                write!(formatter, "Char(*)")
385            }
386            Self::Keys(keys) => KeySequenceSlice(keys.as_slice()).fmt(formatter),
387            Self::EndsWith(keys) => KeySequenceSlice(keys.as_slice()).fmt(formatter),
388        }
389    }
390}
391
392#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
393pub struct AnyCharacter;
394
395impl From<AnyCharacter> for KeyPattern {
396    fn from(_: AnyCharacter) -> Self {
397        Self::AnyCharacter
398    }
399}
400
401#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
402pub struct EndsWith(pub Key);
403
404impl From<EndsWith> for KeyPattern {
405    fn from(ends_with: EndsWith) -> Self {
406        Self::EndsWith([ends_with.0])
407    }
408}
409
410pub trait CommandFn<ComponentT: Component, const VARIANT: usize> {
411    fn call(&self, component: &ComponentT, keys: &[Key]) -> Option<ComponentT::Message>;
412}
413
414// Specializations for callbacks that take either a component or slice with keys
415// and return an option
416impl<ComponentT, FnT> CommandFn<ComponentT, 0> for FnT
417where
418    ComponentT: Component,
419    FnT: Fn(&ComponentT, &[Key]) -> Option<ComponentT::Message> + 'static,
420{
421    fn call(&self, component: &ComponentT, keys: &[Key]) -> Option<ComponentT::Message> {
422        (self)(component, keys)
423    }
424}
425
426impl<ComponentT, FnT> CommandFn<ComponentT, 1> for FnT
427where
428    ComponentT: Component,
429    FnT: Fn(&ComponentT) -> Option<ComponentT::Message> + 'static,
430{
431    #[inline]
432    fn call(&self, component: &ComponentT, _keys: &[Key]) -> Option<ComponentT::Message> {
433        (self)(component)
434    }
435}
436
437impl<ComponentT, FnT> CommandFn<ComponentT, 2> for FnT
438where
439    ComponentT: Component,
440    FnT: Fn(&[Key]) -> Option<ComponentT::Message> + 'static,
441{
442    #[inline]
443    fn call(&self, _component: &ComponentT, keys: &[Key]) -> Option<ComponentT::Message> {
444        (self)(keys)
445    }
446}
447
448// Specializations for callbacks that take a component and optionally a slice with keys
449impl<ComponentT, FnT> CommandFn<ComponentT, 3> for FnT
450where
451    ComponentT: Component,
452    FnT: Fn(&ComponentT, &[Key]) + 'static,
453{
454    #[inline]
455    fn call(&self, component: &ComponentT, keys: &[Key]) -> Option<ComponentT::Message> {
456        (self)(component, keys);
457        None
458    }
459}
460
461impl<ComponentT, FnT> CommandFn<ComponentT, 4> for FnT
462where
463    ComponentT: Component,
464    FnT: Fn(&ComponentT) + 'static,
465{
466    #[inline]
467    fn call(&self, component: &ComponentT, _keys: &[Key]) -> Option<ComponentT::Message> {
468        (self)(component);
469        None
470    }
471}
472
473// Specialization for callbacks that take no parameters and return a message
474impl<ComponentT, FnT> CommandFn<ComponentT, 5> for FnT
475where
476    ComponentT: Component,
477    FnT: Fn() -> ComponentT::Message + 'static,
478{
479    #[inline]
480    fn call(&self, _component: &ComponentT, _keys: &[Key]) -> Option<ComponentT::Message> {
481        Some((self)())
482    }
483}
484
485#[derive(Debug, Clone, PartialEq, Eq)]
486pub struct KeySequenceSlice<'a>(&'a [Key]);
487
488impl<'a> From<&'a [Key]> for KeySequenceSlice<'a> {
489    fn from(keys: &'a [Key]) -> Self {
490        Self(keys)
491    }
492}
493
494impl<'a> std::fmt::Display for KeySequenceSlice<'a> {
495    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> {
496        for (index, key) in self.0.iter().enumerate() {
497            match key {
498                Key::Char(' ') => write!(formatter, "SPC")?,
499                Key::Char('\n') => write!(formatter, "RET")?,
500                Key::Char('\t') => write!(formatter, "TAB")?,
501                Key::Char(char) => write!(formatter, "{}", char)?,
502                Key::Ctrl(char) => write!(formatter, "C-{}", char)?,
503                Key::Alt(char) => write!(formatter, "A-{}", char)?,
504                Key::F(number) => write!(formatter, "F{}", number)?,
505                Key::Esc => write!(formatter, "ESC")?,
506                key => write!(formatter, "{:?}", key)?,
507            }
508            if index < self.0.len().saturating_sub(1) {
509                write!(formatter, " ")?;
510            }
511        }
512        Ok(())
513    }
514}
515
516fn panic_on_overlapping_key_bindings(
517    new_pattern: &KeyPattern,
518    new_name: &str,
519    existing_pattern: &KeyPattern,
520    existing_name: &str,
521) -> ! {
522    panic!(
523        "Binding `{}` for `{}` is ambiguous as it overlaps with binding `{}` for command `{}`",
524        new_pattern, new_name, existing_pattern, existing_name,
525    );
526}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prelude::*;
    use smallvec::smallvec;
    use std::{cell::RefCell, rc::Rc};

    /// Minimal component used to exercise `DynamicBindings`.
    struct Empty;

    impl Component for Empty {
        type Message = ();
        type Properties = ();

        fn create(_: Self::Properties, _: Rect, _: ComponentLink<Self>) -> Self {
            Self
        }

        fn view(&self) -> Layout {
            Canvas::new(Size::new(10, 10)).into()
        }
    }

    #[test]
    fn keymap_alternative_binding_for_same_command() {
        let mut keymap = Keymap::new();
        let right_id = keymap.add("right", [Key::Right]);
        let left_id = keymap.add("left", [Key::Left]);
        assert_ne!(left_id, right_id);
        let alternate_left_id = keymap.add("left", [Key::Ctrl('b')]);
        assert_eq!(left_id, alternate_left_id);
    }

    #[test]
    #[should_panic(expected = "is ambiguous")]
    fn keymap_binding_that_extends_another_binding_panics() {
        let mut keymap = Keymap::new();
        keymap.add("exit", [Key::Ctrl('x')]);
        // `C-x` is already a complete binding, so `C-x C-f` could never be reached.
        keymap.add("open", [Key::Ctrl('x'), Key::Ctrl('f')]);
    }

    #[test]
    fn key_pattern_display() {
        let pattern: KeyPattern = [Key::Ctrl('x'), Key::Char(' ')].into();
        assert_eq!(pattern.to_string(), "C-x SPC");
        assert_eq!(KeyPattern::AnyCharacter.to_string(), "Char(*)");
    }

    #[test]
    fn controller_one_command_end_to_end() {
        let called = Rc::new(RefCell::new(false));

        // Create a controller with one registered command
        let mut controller = DynamicBindings::new::<Empty>();
        let test_command_id = controller.add("test-command", [Key::Ctrl('x'), Key::Ctrl('f')], {
            let called = Rc::clone(&called);
            move |_: &Empty| {
                *called.borrow_mut() = true;
                None
            }
        });

        // Check no key sequence is a prefix of test-command
        assert_eq!(
            controller.keymap().check_sequence(&[]),
            Some(&BindingQuery::PrefixOf(smallvec![test_command_id]))
        );
        // Check C-x is a prefix of test-command
        assert_eq!(
            controller.keymap().check_sequence(&[Key::Ctrl('x')]),
            Some(&BindingQuery::PrefixOf(smallvec![test_command_id]))
        );
        // Check C-x C-f is a match for test-command
        assert_eq!(
            controller
                .keymap()
                .check_sequence(&[Key::Ctrl('x'), Key::Ctrl('f')]),
            Some(&BindingQuery::Match(test_command_id))
        );

        // Check C-f doesn't match any command
        assert_eq!(controller.keymap().check_sequence(&[Key::Ctrl('f')]), None);
        // Check C-x C-x doesn't match any command
        assert_eq!(
            controller
                .keymap()
                .check_sequence(&[Key::Ctrl('x'), Key::Ctrl('x')]),
            None
        );

        controller.execute_command(&Empty, test_command_id, &[]);
        assert!(*called.borrow(), "test-command wasn't called");
    }
}