kanata_parser/cfg/
key_override.rs

//! Contains code to handle global key overrides, which replace an active input chord (modifiers plus one non-modifier key) with an output chord.
2
3use anyhow::{Result, anyhow, bail};
4use rustc_hash::FxHashMap as HashMap;
5
6use crate::keys::*;
7
8use kanata_keyberon::key_code::KeyCode;
9use kanata_keyberon::layout::NORMAL_KEY_FLAG_CLEAR_ON_NEXT_ACTION;
10use kanata_keyberon::layout::NORMAL_KEY_FLAG_CLEAR_ON_NEXT_RELEASE;
11use kanata_keyberon::layout::State;
12
13/// Scratch space containing allocations used to process override information. Exists as an
14/// optimization to reuse allocations between iterations.
15#[derive(Debug, Clone, PartialEq, Eq, Hash)]
16pub struct OverrideStates {
17    mods_pressed: u8,
18    oscs_to_remove: Vec<OsCode>,
19    oscs_to_add: Vec<OsCode>,
20}
21
22impl Default for OverrideStates {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl OverrideStates {
29    pub fn new() -> Self {
30        Self {
31            mods_pressed: 0,
32            oscs_to_add: Vec::new(),
33            oscs_to_remove: Vec::new(),
34        }
35    }
36
37    fn cleanup(&mut self) {
38        self.oscs_to_add.clear();
39        self.oscs_to_remove.clear();
40        self.mods_pressed = 0;
41    }
42
43    fn update(&mut self, osc: OsCode, overrides: &Overrides) {
44        if let Some(mod_mask) = mask_for_key(osc) {
45            self.mods_pressed |= mod_mask;
46        } else {
47            overrides.update_keys(
48                osc,
49                self.mods_pressed,
50                &mut self.oscs_to_add,
51                &mut self.oscs_to_remove,
52            );
53        }
54    }
55
56    fn is_key_overridden(&self, osc: OsCode) -> bool {
57        self.oscs_to_remove.contains(&osc)
58    }
59
60    fn add_overrides(&self, oscs: &mut Vec<KeyCode>) {
61        oscs.extend(self.oscs_to_add.iter().copied().map(KeyCode::from));
62    }
63
64    pub fn removed_oscs(&self) -> impl Iterator<Item = OsCode> + '_ {
65        self.oscs_to_remove.iter().copied()
66    }
67}
68
69/// A collection of global key overrides.
70#[derive(Debug, Clone, PartialEq, Eq)]
71pub struct Overrides {
72    overrides_by_osc: HashMap<OsCode, Vec<Override>>,
73}
74
75impl Overrides {
76    pub fn new(overrides: &[Override]) -> Self {
77        let mut overrides_by_osc: HashMap<OsCode, Vec<Override>> = HashMap::default();
78        for o in overrides.iter() {
79            overrides_by_osc
80                .entry(o.in_non_mod_osc)
81                .and_modify(|ovd| ovd.push(o.clone()))
82                .or_insert_with(|| vec![o.clone()]);
83        }
84        for ovds in overrides_by_osc.values_mut() {
85            ovds.shrink_to_fit();
86        }
87        overrides_by_osc.shrink_to_fit();
88        Self { overrides_by_osc }
89    }
90
91    pub fn override_keys(&self, kcs: &mut Vec<KeyCode>, states: &mut OverrideStates) {
92        if self.is_empty() {
93            return;
94        }
95        states.cleanup();
96        for kc in kcs.iter().copied() {
97            states.update(kc.into(), self);
98        }
99        kcs.retain(|kc| !states.is_key_overridden((*kc).into()));
100        states.add_overrides(kcs);
101    }
102
103    pub fn output_non_mods_for_input_non_mod(&self, in_osc: OsCode) -> Vec<OsCode> {
104        let mut ret = Vec::new();
105        if let Some(ovds) = self.overrides_by_osc.get(&in_osc) {
106            for out_osc in ovds.iter().map(|ovd| ovd.out_non_mod_osc) {
107                ret.push(out_osc);
108            }
109        }
110        ret
111    }
112
113    fn is_empty(&self) -> bool {
114        self.overrides_by_osc.is_empty()
115    }
116
117    fn update_keys(
118        &self,
119        active_osc: OsCode,
120        active_mod_mask: u8,
121        oscs_to_add: &mut Vec<OsCode>,
122        oscs_to_remove: &mut Vec<OsCode>,
123    ) {
124        let Some(ovds) = self.overrides_by_osc.get(&active_osc) else {
125            return;
126        };
127        let mut cur_chord_size = 0;
128        if let Some(ovd) = ovds
129            .iter()
130            .filter(|ovd| {
131                let mask = ovd.get_mod_mask();
132                if mask & active_mod_mask == mask {
133                    // keep only the longest matching prefix.
134                    let chord_size = ovd.in_mod_oscs.len() + 1;
135                    if chord_size <= cur_chord_size {
136                        false
137                    } else {
138                        cur_chord_size = chord_size;
139                        true
140                    }
141                } else {
142                    false
143                }
144            })
145            .next_back()
146        {
147            log::debug!("using override {ovd:?}");
148            ovd.add_override_keys(oscs_to_add);
149            ovd.add_removed_keys(oscs_to_remove);
150        }
151    }
152}
153
154/// A global key override.
155#[derive(Debug, Clone, PartialEq, Eq, Hash)]
156pub struct Override {
157    in_non_mod_osc: OsCode,
158    out_non_mod_osc: OsCode,
159    in_mod_oscs: Vec<OsCode>,
160    out_mod_oscs: Vec<OsCode>,
161}
162
163impl Override {
164    pub fn try_new(in_oscs: &[OsCode], out_oscs: &[OsCode]) -> Result<Self> {
165        let mut in_nmoscs = in_oscs
166            .iter()
167            .copied()
168            .filter(|osc| mask_for_key(*osc).is_none());
169        let in_non_mod_osc = in_nmoscs.next().ok_or_else(|| {
170            anyhow!("override must contain exactly one input non-modifier key; found none")
171        })?;
172        if in_nmoscs.next().is_some() {
173            bail!("override must contain exactly one input non-modifier key; found multiple");
174        }
175        let mut out_nmoscs = out_oscs
176            .iter()
177            .copied()
178            .filter(|osc| mask_for_key(*osc).is_none());
179        let out_non_mod_osc = out_nmoscs.next().ok_or_else(|| {
180            anyhow!("override must contain exactly one output non-modifier key; found none")
181        })?;
182        if out_nmoscs.next().is_some() {
183            bail!("override must contain exactly one output non-modifier key; found multiple");
184        }
185        let mut in_mod_oscs = in_oscs
186            .iter()
187            .copied()
188            .filter(|osc| mask_for_key(*osc).is_some())
189            .collect::<Vec<_>>();
190        let mut out_mod_oscs = out_oscs
191            .iter()
192            .copied()
193            .filter(|osc| mask_for_key(*osc).is_some())
194            .collect::<Vec<_>>();
195        in_mod_oscs.shrink_to_fit();
196        out_mod_oscs.shrink_to_fit();
197        Ok(Self {
198            in_non_mod_osc,
199            out_non_mod_osc,
200            in_mod_oscs,
201            out_mod_oscs,
202        })
203    }
204
205    fn get_mod_mask(&self) -> u8 {
206        let mut mask = 0;
207        for osc in self.in_mod_oscs.iter().copied() {
208            mask |= mask_for_key(osc).expect("mod only");
209        }
210        mask
211    }
212
213    fn add_override_keys(&self, oscs_to_add: &mut Vec<OsCode>) {
214        for osc in self.out_mod_oscs.iter().copied() {
215            if !oscs_to_add.contains(&osc) {
216                oscs_to_add.push(osc);
217            }
218        }
219        if !oscs_to_add.contains(&self.out_non_mod_osc) {
220            oscs_to_add.push(self.out_non_mod_osc);
221        }
222    }
223
224    fn add_removed_keys(&self, oscs_to_remove: &mut Vec<OsCode>) {
225        for osc in self.in_mod_oscs.iter().copied() {
226            if !oscs_to_remove.contains(&osc) {
227                oscs_to_remove.push(osc);
228            }
229        }
230        if !oscs_to_remove.contains(&self.in_non_mod_osc) {
231            oscs_to_remove.push(self.in_non_mod_osc);
232        }
233    }
234}
235
236fn mask_for_key(osc: OsCode) -> Option<u8> {
237    match osc {
238        OsCode::KEY_LEFTCTRL => Some(1 << 0),
239        OsCode::KEY_LEFTSHIFT => Some(1 << 1),
240        OsCode::KEY_LEFTALT => Some(1 << 2),
241        OsCode::KEY_LEFTMETA => Some(1 << 3),
242        OsCode::KEY_RIGHTCTRL => Some(1 << 4),
243        OsCode::KEY_RIGHTSHIFT => Some(1 << 5),
244        OsCode::KEY_RIGHTALT => Some(1 << 6),
245        OsCode::KEY_RIGHTMETA => Some(1 << 7),
246        _ => None,
247    }
248}
249
250/// For every `OsCode` marked for removal by overrides that is not a modifier,
251/// mark its state in the keyberon layout
252/// with `NORMAL_KEY_FLAG_CLEAR_ON_NEXT_ACTION` and `NORMAL_KEY_FLAG_CLEAR_ON_NEXT_RELEASE`
253/// so that it gets eagerly cleared, avoiding weird character outputs.
254pub fn mark_overridden_nonmodkeys_for_eager_erasure<T>(
255    override_states: &OverrideStates,
256    kb_states: &mut [State<T>],
257) {
258    for osc_to_mark in override_states
259        .removed_oscs()
260        .filter(|osc| !osc.is_modifier())
261    {
262        let kc: KeyCode = osc_to_mark.into();
263        for kbstate in kb_states.iter_mut() {
264            if let State::NormalKey {
265                mut flags,
266                keycode,
267                coord,
268            } = kbstate
269            {
270                if kc == *keycode {
271                    flags.0 |= NORMAL_KEY_FLAG_CLEAR_ON_NEXT_ACTION
272                        | NORMAL_KEY_FLAG_CLEAR_ON_NEXT_RELEASE;
273                    *kbstate = State::NormalKey {
274                        flags,
275                        keycode: *keycode,
276                        coord: *coord,
277                    };
278                }
279            }
280        }
281    }
282}