use crate::VKey;
use windows::Win32::UI::Input::KeyboardAndMouse::GetAsyncKeyState;
/// Bitset recording which virtual-key codes (0..256) are currently marked as down.
///
/// Key code `n` is stored in bit `n % 128` of word `flags[n / 128]`, so `flags[0]`
/// covers codes 0..128 and `flags[1]` covers codes 128..256.
#[derive(Clone, Copy)]
pub struct KeyboardState {
    /// Low word (codes 0..128) followed by high word (codes 128..256).
    pub flags: [u128; 2],
}
impl KeyboardState {
    /// Number of virtual-key codes the two `u128` words can represent.
    const KEY_COUNT: u16 = 256;

    /// Creates a state in which no key is marked as down.
    pub const fn new() -> KeyboardState {
        KeyboardState { flags: [0, 0] }
    }

    /// Maps `key` to its `(word index, bit mask)` in `flags`.
    ///
    /// Returns `None` for codes that do not fit in the 256-bit set, so callers never
    /// index `flags` out of bounds (a panic inside a keyboard hook would be fatal).
    fn slot(key: u16) -> Option<(usize, u128)> {
        if key < Self::KEY_COUNT {
            Some((usize::from(key / 128), 1u128 << (key % 128)))
        } else {
            None
        }
    }

    /// If `key` is a side-specific modifier (left/right Shift, Control or Menu),
    /// returns `(left, right, generic)` codes for its group.
    fn modifier_group(key: u16) -> Option<(u16, u16, u16)> {
        let groups = [
            (
                VKey::LShift.to_vk_code(),
                VKey::RShift.to_vk_code(),
                VKey::Shift.to_vk_code(),
            ),
            (
                VKey::LControl.to_vk_code(),
                VKey::RControl.to_vk_code(),
                VKey::Control.to_vk_code(),
            ),
            (
                VKey::LMenu.to_vk_code(),
                VKey::RMenu.to_vk_code(),
                VKey::Menu.to_vk_code(),
            ),
        ];
        groups
            .iter()
            .copied()
            .find(|&(left, right, _)| key == left || key == right)
    }

    /// Marks `key` as down.
    ///
    /// Pressing a side-specific modifier (e.g. `LShift`) also marks the matching
    /// generic modifier (e.g. `Shift`) as down. Codes `>= 256` are ignored.
    pub fn keydown(&mut self, key: u16) {
        if let Some((index, mask)) = Self::slot(key) {
            self.flags[index] |= mask;
        }
        if let Some((_, _, generic)) = Self::modifier_group(key) {
            self.keydown(generic);
        }
    }

    /// Marks `key` as released.
    ///
    /// Releasing a side-specific modifier also releases the generic modifier, but
    /// only once neither the left nor the right variant is still marked as down.
    /// Codes `>= 256` are ignored.
    pub fn keyup(&mut self, key: u16) {
        if let Some((index, mask)) = Self::slot(key) {
            self.flags[index] &= !mask;
        }
        if let Some((left, right, generic)) = Self::modifier_group(key) {
            // The generic modifier stays down while the other side is still held.
            if !self.is_down(left) && !self.is_down(right) {
                self.keyup(generic);
            }
        }
    }

    /// Returns `true` if `key` is marked as down; codes `>= 256` are never down.
    pub fn is_down(&self, key: u16) -> bool {
        match Self::slot(key) {
            Some((index, mask)) => self.flags[index] & mask != 0,
            None => false,
        }
    }

    /// Releases every tracked key that `GetAsyncKeyState` reports as no longer down.
    ///
    /// Keys that are physically down but not tracked are left untouched. Codes are
    /// visited in ascending order, so each generic modifier is checked before its
    /// side-specific variants.
    pub fn sync(&mut self) {
        for vk_code in 0..Self::KEY_COUNT {
            if self.is_down(vk_code) && !Self::get_async_key_state(vk_code) {
                self.keyup(vk_code);
            }
        }
    }

    /// Marks every key as released.
    pub fn clear(&mut self) {
        self.flags = [0, 0];
    }

    /// Queries Win32 `GetAsyncKeyState` for `key` and returns `true` when the most
    /// significant bit of the result is set (the key is down at the time of the call).
    pub fn get_async_key_state(key: u16) -> bool {
        // SAFETY: GetAsyncKeyState takes a plain integer key code and has no pointer
        // arguments or other caller-side preconditions.
        unsafe { (GetAsyncKeyState(key.into()) & -0x8000) != 0 }
    }
}
impl std::fmt::Debug for KeyboardState {
    /// Lists the keys marked as down, in ascending virtual-key order.
    ///
    /// The generic Shift/Control/Menu codes are omitted, because `keydown`/`keyup`
    /// maintain them as mirrors of the side-specific modifier keys.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let generic_modifiers = [
            VKey::Shift.to_vk_code(),
            VKey::Control.to_vk_code(),
            VKey::Menu.to_vk_code(),
        ];
        // Scan all 256 codes individually. Skipping a generic modifier must not also
        // hide the code 128 above it, which lives in the second word.
        let keys: Vec<_> = (0..256u16)
            .filter(|code| !generic_modifiers.contains(code))
            .filter(|&code| self.flags[usize::from(code / 128)] & (1u128 << (code % 128)) != 0)
            .map(VKey::from_vk_code)
            .collect();
        f.debug_struct("KeyboardState").field("Keys", &keys).finish()
    }
}
impl Default for KeyboardState {
    /// Same as `KeyboardState::new`: every key starts out released.
    fn default() -> Self {
        Self { flags: [0; 2] }
    }
}
impl PartialEq for KeyboardState {
    /// Two states are equal when exactly the same key bits are marked as down.
    fn eq(&self, other: &Self) -> bool {
        let [low, high] = self.flags;
        let [other_low, other_high] = other.flags;
        low == other_low && high == other_high
    }
}
// Bit-for-bit comparison is reflexive, so equality is a full equivalence relation.
impl Eq for KeyboardState {}
#[cfg(test)]
mod tests {
    use super::*;

    /// `(left, right, generic)` codes for each mirrored modifier group.
    fn modifier_triples() -> [(u16, u16, u16); 3] {
        [
            (
                VKey::LShift.to_vk_code(),
                VKey::RShift.to_vk_code(),
                VKey::Shift.to_vk_code(),
            ),
            (
                VKey::LControl.to_vk_code(),
                VKey::RControl.to_vk_code(),
                VKey::Control.to_vk_code(),
            ),
            (
                VKey::LMenu.to_vk_code(),
                VKey::RMenu.to_vk_code(),
                VKey::Menu.to_vk_code(),
            ),
        ]
    }

    #[test]
    fn test_new_keyboard_state() {
        let keyboard = KeyboardState::new();
        assert_eq!(
            keyboard.flags,
            [0, 0],
            "New KeyboardState should have all flags cleared"
        );
    }

    #[test]
    fn test_keydown() {
        let mut keyboard = KeyboardState::new();
        keyboard.keydown(65);
        assert_eq!(keyboard.flags[0], 1 << (65 % 128), "Key 65 should be set");
        keyboard.keydown(129);
        assert_eq!(keyboard.flags[1], 1 << (129 % 128), "Key 129 should be set");
    }

    #[test]
    fn test_keyup() {
        let mut keyboard = KeyboardState::new();
        keyboard.keydown(65);
        keyboard.keyup(65);
        assert_eq!(keyboard.flags[0], 0, "Key 65 should be cleared");
        keyboard.keydown(129);
        keyboard.keyup(129);
        assert_eq!(keyboard.flags[1], 0, "Key 129 should be cleared");
    }

    #[test]
    fn test_clear() {
        let mut keyboard = KeyboardState::new();
        keyboard.keydown(65);
        keyboard.keydown(129);
        keyboard.clear();
        assert_eq!(
            keyboard.flags,
            [0, 0],
            "KeyboardState should be cleared after clear()"
        );
    }

    // Deliberately exercises `Clone` even though the type is also `Copy`.
    #[allow(clippy::clone_on_copy)]
    #[test]
    fn test_clone() {
        let mut keyboard = KeyboardState::new();
        keyboard.keydown(65);
        let cloned_keyboard = keyboard.clone();
        assert_eq!(
            keyboard, cloned_keyboard,
            "Cloned KeyboardState should be equal to the original"
        );
        keyboard.keydown(129);
        assert_ne!(
            keyboard, cloned_keyboard,
            "Cloned KeyboardState should not reflect changes to the original"
        );
    }

    #[test]
    fn test_equality() {
        let mut keyboard1 = KeyboardState::new();
        let mut keyboard2 = KeyboardState::new();
        assert_eq!(
            keyboard1, keyboard2,
            "Two empty KeyboardState instances should be equal"
        );
        keyboard1.keydown(65);
        assert_ne!(
            keyboard1, keyboard2,
            "KeyboardState instances with different flags should not be equal"
        );
        keyboard2.keydown(65);
        assert_eq!(
            keyboard1, keyboard2,
            "KeyboardState instances with the same flags should be equal"
        );
    }

    #[test]
    fn test_multiple_keys() {
        let mut keyboard = KeyboardState::new();
        keyboard.keydown(65);
        keyboard.keydown(70);
        keyboard.keydown(129);
        assert!(keyboard.is_down(65), "Key 65 should be set");
        assert!(keyboard.is_down(70), "Key 70 should be set");
        assert!(keyboard.is_down(129), "Key 129 should be set");

        // Release one key at a time so each release is verified on its own.
        keyboard.keyup(65);
        assert!(!keyboard.is_down(65), "Key 65 should be cleared");
        assert!(keyboard.is_down(70), "Key 70 should still be set");
        keyboard.keyup(70);
        assert!(!keyboard.is_down(70), "Key 70 should be cleared");

        assert_eq!(keyboard.flags[0], 0, "Low word should be empty");
        assert_eq!(
            keyboard.flags[1],
            1 << (129 % 128),
            "Key 129 should remain set"
        );
    }

    #[test]
    fn test_modifier_mirroring() {
        for (left, right, generic) in modifier_triples().iter().copied() {
            let mut keyboard = KeyboardState::new();
            keyboard.keydown(left);
            assert!(keyboard.is_down(generic), "Left modifier should mark generic down");
            keyboard.keydown(right);
            keyboard.keyup(left);
            assert!(
                keyboard.is_down(generic),
                "Generic modifier should stay down while the right side is held"
            );
            keyboard.keyup(right);
            assert!(
                !keyboard.is_down(generic),
                "Generic modifier should be released once both sides are up"
            );
            assert_eq!(keyboard, KeyboardState::new());
        }
    }
}