kanata_parser/cfg/
key_override.rs1use anyhow::{Result, anyhow, bail};
4use rustc_hash::FxHashMap as HashMap;
5
6use crate::keys::*;
7
8use kanata_keyberon::key_code::KeyCode;
9use kanata_keyberon::layout::NORMAL_KEY_FLAG_CLEAR_ON_NEXT_ACTION;
10use kanata_keyberon::layout::NORMAL_KEY_FLAG_CLEAR_ON_NEXT_RELEASE;
11use kanata_keyberon::layout::State;
12
13#[derive(Debug, Clone, PartialEq, Eq, Hash)]
16pub struct OverrideStates {
17 mods_pressed: u8,
18 oscs_to_remove: Vec<OsCode>,
19 oscs_to_add: Vec<OsCode>,
20}
21
22impl Default for OverrideStates {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl OverrideStates {
29 pub fn new() -> Self {
30 Self {
31 mods_pressed: 0,
32 oscs_to_add: Vec::new(),
33 oscs_to_remove: Vec::new(),
34 }
35 }
36
37 fn cleanup(&mut self) {
38 self.oscs_to_add.clear();
39 self.oscs_to_remove.clear();
40 self.mods_pressed = 0;
41 }
42
43 fn update(&mut self, osc: OsCode, overrides: &Overrides) {
44 if let Some(mod_mask) = mask_for_key(osc) {
45 self.mods_pressed |= mod_mask;
46 } else {
47 overrides.update_keys(
48 osc,
49 self.mods_pressed,
50 &mut self.oscs_to_add,
51 &mut self.oscs_to_remove,
52 );
53 }
54 }
55
56 fn is_key_overridden(&self, osc: OsCode) -> bool {
57 self.oscs_to_remove.contains(&osc)
58 }
59
60 fn add_overrides(&self, oscs: &mut Vec<KeyCode>) {
61 oscs.extend(self.oscs_to_add.iter().copied().map(KeyCode::from));
62 }
63
64 pub fn removed_oscs(&self) -> impl Iterator<Item = OsCode> + '_ {
65 self.oscs_to_remove.iter().copied()
66 }
67}
68
69#[derive(Debug, Clone, PartialEq, Eq)]
71pub struct Overrides {
72 overrides_by_osc: HashMap<OsCode, Vec<Override>>,
73}
74
75impl Overrides {
76 pub fn new(overrides: &[Override]) -> Self {
77 let mut overrides_by_osc: HashMap<OsCode, Vec<Override>> = HashMap::default();
78 for o in overrides.iter() {
79 overrides_by_osc
80 .entry(o.in_non_mod_osc)
81 .and_modify(|ovd| ovd.push(o.clone()))
82 .or_insert_with(|| vec![o.clone()]);
83 }
84 for ovds in overrides_by_osc.values_mut() {
85 ovds.shrink_to_fit();
86 }
87 overrides_by_osc.shrink_to_fit();
88 Self { overrides_by_osc }
89 }
90
91 pub fn override_keys(&self, kcs: &mut Vec<KeyCode>, states: &mut OverrideStates) {
92 if self.is_empty() {
93 return;
94 }
95 states.cleanup();
96 for kc in kcs.iter().copied() {
97 states.update(kc.into(), self);
98 }
99 kcs.retain(|kc| !states.is_key_overridden((*kc).into()));
100 states.add_overrides(kcs);
101 }
102
103 pub fn output_non_mods_for_input_non_mod(&self, in_osc: OsCode) -> Vec<OsCode> {
104 let mut ret = Vec::new();
105 if let Some(ovds) = self.overrides_by_osc.get(&in_osc) {
106 for out_osc in ovds.iter().map(|ovd| ovd.out_non_mod_osc) {
107 ret.push(out_osc);
108 }
109 }
110 ret
111 }
112
113 fn is_empty(&self) -> bool {
114 self.overrides_by_osc.is_empty()
115 }
116
117 fn update_keys(
118 &self,
119 active_osc: OsCode,
120 active_mod_mask: u8,
121 oscs_to_add: &mut Vec<OsCode>,
122 oscs_to_remove: &mut Vec<OsCode>,
123 ) {
124 let Some(ovds) = self.overrides_by_osc.get(&active_osc) else {
125 return;
126 };
127 let mut cur_chord_size = 0;
128 if let Some(ovd) = ovds
129 .iter()
130 .filter(|ovd| {
131 let mask = ovd.get_mod_mask();
132 if mask & active_mod_mask == mask {
133 let chord_size = ovd.in_mod_oscs.len() + 1;
135 if chord_size <= cur_chord_size {
136 false
137 } else {
138 cur_chord_size = chord_size;
139 true
140 }
141 } else {
142 false
143 }
144 })
145 .next_back()
146 {
147 log::debug!("using override {ovd:?}");
148 ovd.add_override_keys(oscs_to_add);
149 ovd.add_removed_keys(oscs_to_remove);
150 }
151 }
152}
153
154#[derive(Debug, Clone, PartialEq, Eq, Hash)]
156pub struct Override {
157 in_non_mod_osc: OsCode,
158 out_non_mod_osc: OsCode,
159 in_mod_oscs: Vec<OsCode>,
160 out_mod_oscs: Vec<OsCode>,
161}
162
163impl Override {
164 pub fn try_new(in_oscs: &[OsCode], out_oscs: &[OsCode]) -> Result<Self> {
165 let mut in_nmoscs = in_oscs
166 .iter()
167 .copied()
168 .filter(|osc| mask_for_key(*osc).is_none());
169 let in_non_mod_osc = in_nmoscs.next().ok_or_else(|| {
170 anyhow!("override must contain exactly one input non-modifier key; found none")
171 })?;
172 if in_nmoscs.next().is_some() {
173 bail!("override must contain exactly one input non-modifier key; found multiple");
174 }
175 let mut out_nmoscs = out_oscs
176 .iter()
177 .copied()
178 .filter(|osc| mask_for_key(*osc).is_none());
179 let out_non_mod_osc = out_nmoscs.next().ok_or_else(|| {
180 anyhow!("override must contain exactly one output non-modifier key; found none")
181 })?;
182 if out_nmoscs.next().is_some() {
183 bail!("override must contain exactly one output non-modifier key; found multiple");
184 }
185 let mut in_mod_oscs = in_oscs
186 .iter()
187 .copied()
188 .filter(|osc| mask_for_key(*osc).is_some())
189 .collect::<Vec<_>>();
190 let mut out_mod_oscs = out_oscs
191 .iter()
192 .copied()
193 .filter(|osc| mask_for_key(*osc).is_some())
194 .collect::<Vec<_>>();
195 in_mod_oscs.shrink_to_fit();
196 out_mod_oscs.shrink_to_fit();
197 Ok(Self {
198 in_non_mod_osc,
199 out_non_mod_osc,
200 in_mod_oscs,
201 out_mod_oscs,
202 })
203 }
204
205 fn get_mod_mask(&self) -> u8 {
206 let mut mask = 0;
207 for osc in self.in_mod_oscs.iter().copied() {
208 mask |= mask_for_key(osc).expect("mod only");
209 }
210 mask
211 }
212
213 fn add_override_keys(&self, oscs_to_add: &mut Vec<OsCode>) {
214 for osc in self.out_mod_oscs.iter().copied() {
215 if !oscs_to_add.contains(&osc) {
216 oscs_to_add.push(osc);
217 }
218 }
219 if !oscs_to_add.contains(&self.out_non_mod_osc) {
220 oscs_to_add.push(self.out_non_mod_osc);
221 }
222 }
223
224 fn add_removed_keys(&self, oscs_to_remove: &mut Vec<OsCode>) {
225 for osc in self.in_mod_oscs.iter().copied() {
226 if !oscs_to_remove.contains(&osc) {
227 oscs_to_remove.push(osc);
228 }
229 }
230 if !oscs_to_remove.contains(&self.in_non_mod_osc) {
231 oscs_to_remove.push(self.in_non_mod_osc);
232 }
233 }
234}
235
236fn mask_for_key(osc: OsCode) -> Option<u8> {
237 match osc {
238 OsCode::KEY_LEFTCTRL => Some(1 << 0),
239 OsCode::KEY_LEFTSHIFT => Some(1 << 1),
240 OsCode::KEY_LEFTALT => Some(1 << 2),
241 OsCode::KEY_LEFTMETA => Some(1 << 3),
242 OsCode::KEY_RIGHTCTRL => Some(1 << 4),
243 OsCode::KEY_RIGHTSHIFT => Some(1 << 5),
244 OsCode::KEY_RIGHTALT => Some(1 << 6),
245 OsCode::KEY_RIGHTMETA => Some(1 << 7),
246 _ => None,
247 }
248}
249
250pub fn mark_overridden_nonmodkeys_for_eager_erasure<T>(
255 override_states: &OverrideStates,
256 kb_states: &mut [State<T>],
257) {
258 for osc_to_mark in override_states
259 .removed_oscs()
260 .filter(|osc| !osc.is_modifier())
261 {
262 let kc: KeyCode = osc_to_mark.into();
263 for kbstate in kb_states.iter_mut() {
264 if let State::NormalKey {
265 mut flags,
266 keycode,
267 coord,
268 } = kbstate
269 {
270 if kc == *keycode {
271 flags.0 |= NORMAL_KEY_FLAG_CLEAR_ON_NEXT_ACTION
272 | NORMAL_KEY_FLAG_CLEAR_ON_NEXT_RELEASE;
273 *kbstate = State::NormalKey {
274 flags,
275 keycode: *keycode,
276 coord: *coord,
277 };
278 }
279 }
280 }
281 }
282}