use std::collections::HashMap;
use crossterm::event::{KeyCode, KeyModifiers};
/// A single key chord: a [`KeyCode`] together with the exact set of [`KeyModifiers`] held.
///
/// Equality and hashing compare both fields, so `Ctrl+s` and a plain `s` are distinct
/// bindings.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Keybind {
    /// The key that is pressed.
    pub code: KeyCode,
    /// The modifier keys held with `code`; compared exactly by the derived `PartialEq`/`Hash`.
    pub modifiers: KeyModifiers,
}
impl Keybind {
    /// Creates a binding for `code` pressed with exactly `modifiers` held.
    pub fn new(code: KeyCode, modifiers: KeyModifiers) -> Self {
        Self { code, modifiers }
    }

    /// Creates a binding for `code` with no modifiers held.
    pub fn key(code: KeyCode) -> Self {
        Self::new(code, KeyModifiers::NONE)
    }

    /// Creates a binding for `code` with only Ctrl held.
    pub fn ctrl(code: KeyCode) -> Self {
        Self::new(code, KeyModifiers::CONTROL)
    }

    /// Creates a binding for `code` with only Alt held.
    pub fn alt(code: KeyCode) -> Self {
        Self::new(code, KeyModifiers::ALT)
    }

    /// Renders the binding for humans, e.g. `"Ctrl+S"` or `"Alt+Shift+Enter"`.
    ///
    /// Modifiers appear in the fixed order Ctrl, Alt, Shift, followed by the key name, all
    /// joined with `+`. Only those three modifiers are rendered; any other modifier bits in
    /// `self.modifiers` are not shown.
    pub fn display(&self) -> String {
        let mut rendered = String::new();
        for (flag, label) in [
            (KeyModifiers::CONTROL, "Ctrl"),
            (KeyModifiers::ALT, "Alt"),
            (KeyModifiers::SHIFT, "Shift"),
        ] {
            if self.modifiers.contains(flag) {
                rendered.push_str(label);
                rendered.push('+');
            }
        }
        rendered.push_str(&self.key_name());
        rendered
    }

    /// Returns the display name of the key alone, without modifiers.
    ///
    /// Characters are upper-cased. A space is spelled out as `"Space"` because a literal
    /// `" "` is invisible in a help listing. Function keys render as `"F1"`, `"F2"`, and so on.
    /// Arrow keys use arrow glyphs. Any other key falls back to its `Debug` representation.
    fn key_name(&self) -> String {
        match self.code {
            KeyCode::Char(' ') => "Space".to_string(),
            KeyCode::Char(c) => c.to_uppercase().to_string(),
            // `Debug` would produce "F(1)"; the conventional spelling is "F1".
            KeyCode::F(n) => format!("F{}", n),
            KeyCode::Enter => "Enter".to_string(),
            KeyCode::Esc => "Esc".to_string(),
            KeyCode::Tab => "Tab".to_string(),
            KeyCode::Backspace => "Backspace".to_string(),
            KeyCode::Up => "↑".to_string(),
            KeyCode::Down => "↓".to_string(),
            KeyCode::Left => "←".to_string(),
            KeyCode::Right => "→".to_string(),
            _ => format!("{:?}", self.code),
        }
    }
}
/// The UI context a key press is interpreted in.
///
/// Bindings are registered per context. [`KeyContext::Global`] bindings act as the fallback
/// that the registry consults when the active context has no binding for a key.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum KeyContext {
    /// Bindings that apply in every context unless overridden there.
    Global,
    Status,
    Commits,
    Branches,
    Stashes,
    Diff,
    Input,
}
/// Maps a key pressed in a given [`KeyContext`] to the name of the action it triggers.
///
/// Each `(context, key)` pair holds at most one action, and actions are plain strings.
pub struct KeybindRegistry {
    /// Action names, keyed by the context they apply in and the key that triggers them.
    bindings: HashMap<(KeyContext, Keybind), String>,
}
impl KeybindRegistry {
pub fn new() -> Self {
Self {
bindings: HashMap::new(),
}
}
pub fn bind(&mut self, context: KeyContext, keybind: Keybind, action: impl Into<String>) {
self.bindings.insert((context, keybind), action.into());
}
pub fn lookup(&self, context: KeyContext, keybind: &Keybind) -> Option<&str> {
self.bindings.get(&(context, keybind.clone()))
.or_else(|| self.bindings.get(&(KeyContext::Global, keybind.clone())))
.map(|s| s.as_str())
}
pub fn bindings_for(&self, context: KeyContext) -> Vec<(&Keybind, &str)> {
self.bindings
.iter()
.filter(|((ctx, _), _)| *ctx == context || *ctx == KeyContext::Global)
.map(|((_, kb), action)| (kb, action.as_str()))
.collect()
}
pub fn with_defaults() -> Self {
let mut r = Self::new();
r.bind(KeyContext::Global, Keybind::key(KeyCode::Char('q')), "Quit");
r.bind(KeyContext::Global, Keybind::key(KeyCode::Char('?')), "Help");
r.bind(KeyContext::Global, Keybind::key(KeyCode::Tab), "NextPane");
r.bind(KeyContext::Global, Keybind::ctrl(KeyCode::Char('r')), "Refresh");
r.bind(KeyContext::Status, Keybind::key(KeyCode::Char('s')), "Stage");
r.bind(KeyContext::Status, Keybind::key(KeyCode::Char('u')), "Unstage");
r.bind(KeyContext::Status, Keybind::key(KeyCode::Enter), "ShowDiff");
r.bind(KeyContext::Commits, Keybind::key(KeyCode::Char('c')), "Commit");
r.bind(KeyContext::Commits, Keybind::key(KeyCode::Char('a')), "Amend");
r
}
}
impl Default for KeybindRegistry {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_keybind_display() {
        let ctrl_s = Keybind::ctrl(KeyCode::Char('s'));
        assert_eq!(ctrl_s.display(), "Ctrl+S");
    }

    #[test]
    fn test_lookup() {
        let registry = KeybindRegistry::default();
        let resolve_in_status =
            |c: char| registry.lookup(KeyContext::Status, &Keybind::key(KeyCode::Char(c)));

        // Bound directly in the Status context.
        assert_eq!(resolve_in_status('s'), Some("Stage"));
        // Not bound in Status, so resolution falls through to the Global binding.
        assert_eq!(resolve_in_status('q'), Some("Quit"));
    }
}