use std::collections::BTreeMap;
use std::sync::Arc;
use crate::{BindingGroup, Key, Keymap, NodeResult};
/// Fallback callback for a scope: `WhichKeyState::handle_key` calls it with the
/// key that produced no keymap match. Returning `Some(action)` hands that action
/// back to the caller; returning `None` yields no action.
///
/// Held in an `Arc`, so cloning a handler (as happens when the handler map is
/// cloned in `WhichKeyState::new`) only bumps a reference count. `Send + Sync`
/// lets handlers be shared across threads.
pub type CatchAllHandler<K, A> = Arc<dyn Fn(K) -> Option<A> + Send + Sync>;
/// Tracks a partially typed key sequence against a [`Keymap`] for one scope
/// and resolves completed sequences to actions.
#[derive(Clone)]
pub struct WhichKeyState<K, S, A, C>
where
K: Key,
{
/// Set when a key sequence reaches a branch (more keys expected), flipped by
/// `toggle`, and cleared by `dismiss` or when a leaf action fires.
pub active: bool,
/// Keys entered so far in the pending sequence; empty when nothing is pending.
pub current_sequence: Vec<K>,
/// Scope passed to keymap navigation and used to pick the catch-all handler.
scope: S,
/// The keymap this state navigates.
keymap: Keymap<K, S, A, C>,
/// Top-level bindings for `scope`: computed in `new`, refreshed by
/// `set_scope`, and returned by `current_bindings` when no keys are pending.
cached_bindings: Vec<BindingGroup<K>>,
/// Copy of the keymap's catch-all handlers (keyed by scope), taken in `new`.
catch_all_handlers: BTreeMap<S, CatchAllHandler<K, A>>,
}
impl<K, S, A, C> std::fmt::Debug for WhichKeyState<K, S, A, C>
where
    K: Key + std::fmt::Debug,
    S: std::fmt::Debug,
    A: std::fmt::Debug,
    C: std::fmt::Debug,
{
    /// Prints every field except the catch-all handler map.
    ///
    /// The handlers are trait-object closures with no `Debug` representation,
    /// so they are skipped and the output is closed with
    /// `finish_non_exhaustive`, which appends a trailing `..` marker.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("WhichKeyState");
        out.field("active", &self.active);
        out.field("current_sequence", &self.current_sequence);
        out.field("scope", &self.scope);
        out.field("keymap", &self.keymap);
        out.field("cached_bindings", &self.cached_bindings);
        out.finish_non_exhaustive()
    }
}
impl<K, S, A, C> WhichKeyState<K, S, A, C>
where
    K: Key,
{
    /// Returns the scope used for keymap navigation and catch-all lookup.
    #[must_use]
    pub fn scope(&self) -> &S {
        &self.scope
    }

    /// Flips the `active` flag and discards any pending key sequence.
    ///
    /// The sequence is cleared on both transitions. Activating starts from an
    /// empty sequence. Deactivating matches [`Self::dismiss`], so a key handled
    /// afterwards is not appended to a stale prefix typed before the toggle.
    pub fn toggle(&mut self) {
        self.active = !self.active;
        self.current_sequence.clear();
    }

    /// Deactivates the state and discards any pending key sequence.
    pub fn dismiss(&mut self) {
        self.active = false;
        self.current_sequence.clear();
    }

    /// Returns `true` while at least one key of an unresolved sequence has been
    /// entered.
    #[must_use]
    pub fn is_pending(&self) -> bool {
        !self.current_sequence.is_empty()
    }

    /// Returns the keymap this state navigates.
    #[must_use]
    pub fn keymap(&self) -> &Keymap<K, S, A, C> {
        &self.keymap
    }
}
impl<K, S, A, C> WhichKeyState<K, S, A, C>
where
    K: Key + Clone + PartialEq,
    S: Clone + Ord + PartialEq + Send + Sync,
    A: Clone + Send + Sync,
    C: Clone + std::fmt::Display,
{
    /// Creates an inactive state with no pending keys, bound to `scope`.
    ///
    /// The top-level bindings for `scope` and a copy of the keymap's catch-all
    /// handler map are computed once here instead of on every lookup.
    #[must_use]
    pub fn new(keymap: Keymap<K, S, A, C>, scope: S) -> Self {
        let cached_bindings = keymap.bindings_for_scope(scope.clone());
        let catch_all_handlers = keymap.catch_all_handlers().clone();
        Self {
            active: false,
            current_sequence: Vec::new(),
            scope,
            keymap,
            cached_bindings,
            catch_all_handlers,
        }
    }

    /// Switches to `scope` and recomputes the cached top-level bindings for it.
    ///
    /// `active` and any pending key sequence are left untouched.
    // NOTE(review): a sequence started under the old scope keeps being navigated
    // under the new one — confirm that is intended.
    pub fn set_scope(&mut self, scope: S) {
        self.cached_bindings = self.keymap.bindings_for_scope(scope.clone());
        self.scope = scope;
    }

    /// Feeds one key press into the state machine and returns the action it
    /// resolves to, if any.
    ///
    /// - Backspace removes the last pending key. If no keys remain afterwards
    ///   (including when none were pending), the state is dismissed. Returns `None`.
    /// - Any other key is appended to the pending sequence, and the sequence is
    ///   looked up in the keymap under the current scope:
    ///   - a branch (more keys expected) activates the state and returns `None`;
    ///   - a leaf dismisses the state and returns the bound action;
    ///   - no match invokes the current scope's catch-all handler (if one is
    ///     registered) with this key, dismisses the state, and returns the
    ///     handler's result, or `None` when there is no handler.
    ///
    /// Keys are processed whether or not the state is currently `active`.
    pub fn handle_key(&mut self, key: K) -> Option<A> {
        if key.is_backspace() {
            self.current_sequence.pop();
            if self.current_sequence.is_empty() {
                self.dismiss();
            }
            return None;
        }

        self.current_sequence.push(key.clone());
        match self.keymap.navigate(&self.current_sequence, &self.scope) {
            Some(NodeResult::Branch { .. }) => {
                self.active = true;
                None
            }
            Some(NodeResult::Leaf { action }) => {
                self.dismiss();
                Some(action)
            }
            None => {
                // The handler only sees the key that broke the match, not the
                // whole pending sequence.
                let action = self
                    .catch_all_handlers
                    .get(&self.scope)
                    .and_then(|handler| handler(key));
                self.dismiss();
                action
            }
        }
    }

    /// Returns the bindings available from the current position.
    ///
    /// With no pending keys this is the cached top-level bindings for the scope.
    /// Otherwise it is the children of the pending path, wrapped in a single
    /// group with an empty category, or an empty list if the path has none.
    // NOTE(review): `children_at_path` is not given `self.scope` here while
    // `navigate` is — confirm the child lookup is meant to be scope-independent.
    #[must_use]
    pub fn current_bindings(&self) -> Vec<BindingGroup<K>> {
        if self.current_sequence.is_empty() {
            return self.cached_bindings.clone();
        }
        self.keymap
            .children_at_path(&self.current_sequence)
            .map(|children| {
                vec![BindingGroup {
                    category: String::new(),
                    bindings: children,
                }]
            })
            .unwrap_or_default()
    }

    /// Renders the pending key sequence as the keys' display strings joined by
    /// `" > "`; returns an empty string when no keys are pending.
    #[must_use]
    pub fn format_path(&self) -> String {
        self.current_sequence
            .iter()
            .map(K::display)
            .collect::<Vec<_>>()
            .join(" > ")
    }
}
#[cfg(test)]
mod tests {
    // NOTE(review): every fixture below appears to be used; this allow may be
    // removable — confirm with a warning-clean `cargo test` before dropping it.
    #![allow(dead_code)]

    use derive_more::Display;

    use super::*;
    use crate::test_utils::state_with_binding_and_sequence;

    #[derive(Display, Debug, Clone, Copy, PartialEq, Eq)]
    enum TestCategory {
        General,
    }

    #[derive(Debug, Clone, PartialEq)]
    enum TestAction {
        Quit,
        Save,
    }

    impl std::fmt::Display for TestAction {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                TestAction::Quit => write!(f, "quit"),
                TestAction::Save => write!(f, "save"),
            }
        }
    }

    #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
    enum TestScope {
        Global,
        Insert,
    }

    #[derive(Debug, Clone, PartialEq, Eq)]
    enum TestKey {
        Char(char),
        Backspace,
    }

    impl Key for TestKey {
        fn display(&self) -> String {
            match self {
                TestKey::Char(c) => c.to_string(),
                TestKey::Backspace => "BS".to_string(),
            }
        }

        fn is_backspace(&self) -> bool {
            matches!(self, TestKey::Backspace)
        }

        fn from_char(c: char) -> Option<Self> {
            Some(TestKey::Char(c))
        }

        fn space() -> Self {
            TestKey::Char(' ')
        }
    }

    /// Empty keymap over the test types; individual tests add what they need.
    fn create_test_keymap() -> Keymap<TestKey, TestScope, TestAction, TestCategory> {
        Keymap::new()
    }

    #[test]
    fn new_creates_inactive_state() {
        let keymap = create_test_keymap();
        let state = WhichKeyState::new(keymap, TestScope::Global);
        assert!(!state.active);
        assert!(state.current_sequence.is_empty());
    }

    #[test]
    fn new_state_is_not_pending_and_has_empty_path() {
        let state = WhichKeyState::new(create_test_keymap(), TestScope::Global);
        assert!(!state.is_pending());
        assert_eq!(state.format_path(), "");
    }

    #[test]
    fn toggle_activates_inactive_state() {
        let keymap = create_test_keymap();
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        assert!(!state.active);
        state.toggle();
        assert!(state.active);
    }

    #[test]
    fn toggle_deactivates_active_state() {
        let keymap = create_test_keymap();
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.active = true;
        state.toggle();
        assert!(!state.active);
    }

    #[test]
    fn dismiss_clears_state() {
        let keymap = create_test_keymap();
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.active = true;
        state.current_sequence.push(TestKey::Char('a'));
        state.dismiss();
        assert!(!state.active);
        assert!(state.current_sequence.is_empty());
    }

    #[test]
    fn is_pending_returns_true_when_keys_present() {
        let keymap = create_test_keymap();
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.current_sequence.push(TestKey::Char('a'));
        assert!(state.is_pending());
    }

    #[test]
    fn format_path_joins_keys() {
        let keymap = create_test_keymap();
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.current_sequence.push(TestKey::Char('a'));
        state.current_sequence.push(TestKey::Char('b'));
        assert_eq!(state.format_path(), "a > b");
    }

    #[test]
    fn set_scope_updates_scope() {
        let keymap = create_test_keymap();
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.set_scope(TestScope::Insert);
        assert_eq!(*state.scope(), TestScope::Insert);
    }

    #[test]
    fn leaf_action_clears_sequence() {
        let mut state = state_with_binding_and_sequence(
            "qw",
            TestAction::Quit,
            TestCategory::General,
            TestScope::Global,
            &[],
        );
        state.handle_key(TestKey::Char('q'));
        let result = state.handle_key(TestKey::Char('w'));
        assert_eq!(result, Some(TestAction::Quit));
        assert!(!state.active);
        assert!(state.current_sequence.is_empty());
        assert_eq!(state.format_path(), "");
    }

    #[test]
    fn backspace_dismisses_when_single_key_in_sequence() {
        let mut state = state_with_binding_and_sequence(
            "qw",
            TestAction::Quit,
            TestCategory::General,
            TestScope::Global,
            &[TestKey::Char('q')],
        );
        state.handle_key(TestKey::Backspace);
        assert!(!state.active);
        assert!(state.current_sequence.is_empty());
    }

    #[test]
    fn backspace_with_no_pending_keys_dismisses() {
        let mut state = WhichKeyState::new(create_test_keymap(), TestScope::Global);
        state.active = true;
        assert_eq!(state.handle_key(TestKey::Backspace), None);
        assert!(!state.active);
        assert!(state.current_sequence.is_empty());
    }

    #[test]
    fn backspace_pops_last_key_and_keeps_state_active() {
        let mut keymap: Keymap<TestKey, TestScope, TestAction, TestCategory> =
            Keymap::new().with_leader(TestKey::Char('a'));
        keymap.bind(
            "<leader>gg",
            TestAction::Quit,
            TestCategory::General,
            TestScope::Global,
        );
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.handle_key(TestKey::Char('a'));
        state.handle_key(TestKey::Char('g'));
        assert_eq!(state.handle_key(TestKey::Backspace), None);
        assert!(state.active);
        assert_eq!(state.format_path(), "a");
    }

    #[test]
    fn catch_all_returns_action_for_unmatched_key() {
        let mut keymap = create_test_keymap();
        keymap.register_catch_all(TestScope::Global, |key| {
            matches!(key, TestKey::Char(_)).then_some(TestAction::Save)
        });
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        let result = state.handle_key(TestKey::Char('x'));
        assert_eq!(result, Some(TestAction::Save));
    }

    #[test]
    fn catch_all_returns_none_dismisses() {
        let mut keymap = create_test_keymap();
        keymap.register_catch_all(TestScope::Global, |_key| None);
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.active = true;
        let result = state.handle_key(TestKey::Char('x'));
        assert!(result.is_none());
        assert!(!state.active);
    }

    #[test]
    fn no_catch_all_dismisses_on_unmatched() {
        let keymap = create_test_keymap();
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.active = true;
        let result = state.handle_key(TestKey::Char('x'));
        assert!(result.is_none());
        assert!(!state.active);
    }

    #[test]
    fn catch_all_only_applies_to_matching_scope() {
        let mut keymap = create_test_keymap();
        keymap.register_catch_all(TestScope::Insert, |_key| Some(TestAction::Save));
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        state.active = true;
        let result = state.handle_key(TestKey::Char('x'));
        assert!(result.is_none());
        assert!(!state.active);
    }

    #[test]
    fn handle_key_with_custom_leader_triggers_action() {
        let mut keymap: Keymap<TestKey, TestScope, TestAction, TestCategory> =
            Keymap::new().with_leader(TestKey::Char('a'));
        keymap.bind(
            "<leader>gg",
            TestAction::Quit,
            TestCategory::General,
            TestScope::Global,
        );
        let mut state = WhichKeyState::new(keymap, TestScope::Global);
        let result = state.handle_key(TestKey::Char('a'));
        assert!(state.active);
        assert!(result.is_none());
        let result = state.handle_key(TestKey::Char('g'));
        assert!(state.active);
        assert!(result.is_none());
        let result = state.handle_key(TestKey::Char('g'));
        assert_eq!(result, Some(TestAction::Quit));
    }
}