use alloc::vec::Vec;
use bevy_app::{App, Plugin, Startup};
use bevy_ecs::{
component::Component,
entity::Entity,
hierarchy::{ChildOf, Children},
observer::On,
query::{With, Without},
system::{Commands, Query, Res, ResMut, SystemParam},
};
use bevy_input::{
keyboard::{KeyCode, KeyboardInput},
ButtonInput, ButtonState,
};
use bevy_window::{PrimaryWindow, Window};
use log::warn;
use thiserror::Error;
use crate::{AcquireFocus, FocusedInput, InputFocus, InputFocusVisible};
#[cfg(feature = "bevy_reflect")]
use {
bevy_ecs::prelude::ReflectComponent,
bevy_reflect::{prelude::*, Reflect},
};
/// Component that marks an entity as a candidate for tab navigation.
///
/// Tab traversal only collects entities whose index is non-negative; an entity with a
/// negative index is skipped when tabbing. Any entity carrying a `TabIndex` (regardless of
/// sign) is accepted as a focus target by the `AcquireFocus` observer in this module.
///
/// Within one tab group, stops are visited in ascending index order. Entities with equal
/// indices keep the order in which the hierarchy traversal discovered them, because the
/// sort used is stable.
#[derive(Debug, Default, Component, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(
feature = "bevy_reflect",
derive(Reflect),
reflect(Debug, Default, Component, PartialEq, Clone)
)]
pub struct TabIndex(pub i32);
/// Component that marks an entity as the root of a group of tab stops.
///
/// The focusable descendants of a group are navigated together. Navigation looks up the
/// nearest ancestor of the focused entity that carries a `TabGroup` to decide whether
/// traversal is confined to a modal group.
#[derive(Debug, Default, Component, Copy, Clone)]
#[cfg_attr(
feature = "bevy_reflect",
derive(Reflect),
reflect(Debug, Default, Component, Clone)
)]
pub struct TabGroup {
/// Sort key among non-modal groups: groups are visited in ascending `order`.
pub order: i32,
/// When `true` and focus is inside this group, navigation is restricted to this group's
/// descendants. Modal groups are left out of the global (non-modal) traversal.
pub modal: bool,
}
impl TabGroup {
    /// Creates a non-modal tab group that sorts at position `order` among other groups.
    pub fn new(order: i32) -> Self {
        let modal = false;
        Self { order, modal }
    }

    /// Creates a modal tab group; its `order` is 0.
    pub fn modal() -> Self {
        Self {
            modal: true,
            ..Self::new(0)
        }
    }
}
/// A navigation request passed to [`TabNavigation`].
///
/// `Next` and `Previous` are relative to the currently focused entity and wrap around at
/// the ends of the tab sequence; `First` and `Last` are absolute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NavAction {
    /// Move to the tab stop after the current one (or the first stop if the current focus
    /// is not part of the sequence).
    Next,
    /// Move to the tab stop before the current one (or the last stop if the current focus
    /// is not part of the sequence).
    Previous,
    /// Move to the first tab stop in the sequence.
    First,
    /// Move to the last tab stop in the sequence.
    Last,
}
/// Errors returned by [`TabNavigation`] when a navigation request cannot be satisfied
/// cleanly.
#[derive(Debug, Error, PartialEq, Eq, Clone)]
pub enum TabNavigationError {
/// No entity matched the tab group query, or the entity passed to
/// `TabNavigation::initialize` is not such a tab group.
#[error("No tab groups found")]
NoTabGroups,
/// Traversal of the relevant tab group(s) produced no tab stops.
#[error("No focusable entities found")]
NoFocusableEntities,
/// The computed target position did not correspond to a collected tab stop.
#[error("Failed to navigate to next focusable entity")]
FailedToNavigateToNextFocusableEntity,
/// An entity is focused but none of its ancestors is a tab group. Navigation still
/// picked a target, which is carried in `new_focus`.
#[error("No tab group found for currently focused entity {previous_focus}. Users will not be able to navigate back to this entity.")]
NoTabGroupForCurrentFocus {
/// The entity that held focus when navigation was requested.
previous_focus: Entity,
/// The entity navigation selected despite the missing tab group.
new_focus: Entity,
},
}
/// System parameter bundling the queries needed to compute tab-navigation targets.
#[doc(hidden)]
#[derive(SystemParam)]
pub struct TabNavigation<'w, 's> {
// Tab group roots. Only groups that have `Children` match this query.
tabgroup_query: Query<'w, 's, (Entity, &'static TabGroup, &'static Children)>,
// Every entity that is not itself a tab group; `TabIndex` and `Children` are optional so
// the query can be used to walk through intermediate, non-focusable nodes.
tabindex_query: Query<
'w,
's,
(Entity, Option<&'static TabIndex>, Option<&'static Children>),
Without<TabGroup>,
>,
// Used to walk from the focused entity up through its ancestors.
parent_query: Query<'w, 's, &'static ChildOf>,
}
impl TabNavigation<'_, '_> {
    /// Computes the entity that should receive focus when `action` is applied relative to
    /// the currently focused entity.
    ///
    /// The governing tab group is the nearest ancestor of the focused entity that matches
    /// the tab group query. If that group is modal, traversal is confined to it; otherwise
    /// all non-modal groups are traversed in ascending [`TabGroup::order`].
    ///
    /// When nothing is focused (or the focused entity is not one of the collected tab
    /// stops), `Next`/`First` yield the first stop and `Previous`/`Last` yield the last.
    ///
    /// # Errors
    ///
    /// - [`TabNavigationError::NoTabGroups`] if the tab group query is empty.
    /// - [`TabNavigationError::NoFocusableEntities`] if traversal finds no tab stops.
    /// - [`TabNavigationError::NoTabGroupForCurrentFocus`] if an entity is focused but has
    ///   no tab group ancestor; the error carries the entity that was selected anyway.
    /// - [`TabNavigationError::FailedToNavigateToNextFocusableEntity`] if the computed
    ///   position does not map to a collected stop.
    pub fn navigate(
        &self,
        focus: &InputFocus,
        action: NavAction,
    ) -> Result<Entity, TabNavigationError> {
        // Without any tab group there is nothing to traverse.
        if self.tabgroup_query.is_empty() {
            return Err(TabNavigationError::NoTabGroups);
        }

        // Find the closest enclosing tab group of the focused entity; this determines
        // whether navigation is trapped inside a modal group.
        let tabgroup = focus.0.and_then(|focus_ent| {
            self.parent_query
                .iter_ancestors(focus_ent)
                .find_map(|ancestor| {
                    self.tabgroup_query
                        .get(ancestor)
                        .ok()
                        .map(|(_, group, _)| (ancestor, group))
                })
        });

        self.navigate_internal(focus.0, action, tabgroup)
    }

    /// Computes the first or last tab stop (depending on `action`) inside the tab group
    /// rooted at `parent`, as if nothing were focused.
    ///
    /// # Errors
    ///
    /// Returns [`TabNavigationError::NoTabGroups`] if `parent` does not match the tab group
    /// query, and otherwise any error produced by the traversal itself (see
    /// [`TabNavigation::navigate`]).
    pub fn initialize(
        &self,
        parent: Entity,
        action: NavAction,
    ) -> Result<Entity, TabNavigationError> {
        // A failed lookup also covers the "no tab groups at all" case, so no separate
        // emptiness check is needed to produce the same error.
        let Ok((_, group, _)) = self.tabgroup_query.get(parent) else {
            return Err(TabNavigationError::NoTabGroups);
        };
        self.navigate_internal(None, action, Some((parent, group)))
    }

    /// Shared implementation behind [`TabNavigation::navigate`] and
    /// [`TabNavigation::initialize`].
    ///
    /// Runs the traversal and then reports [`TabNavigationError::NoTabGroupForCurrentFocus`]
    /// when an entity is focused but `tabgroup` is `None`, attaching the entity the
    /// traversal selected so the caller can still move focus there.
    pub fn navigate_internal(
        &self,
        focus: Option<Entity>,
        action: NavAction,
        tabgroup: Option<(Entity, &TabGroup)>,
    ) -> Result<Entity, TabNavigationError> {
        let new_focus = self.navigate_in_group(tabgroup, focus, action)?;

        match focus {
            Some(previous_focus) if tabgroup.is_none() => {
                Err(TabNavigationError::NoTabGroupForCurrentFocus {
                    previous_focus,
                    new_focus,
                })
            }
            _ => Ok(new_focus),
        }
    }

    /// Collects the relevant tab stops, orders them, and picks the one `action` selects
    /// relative to `focus`.
    fn navigate_in_group(
        &self,
        tabgroup: Option<(Entity, &TabGroup)>,
        focus: Option<Entity>,
        action: NavAction,
    ) -> Result<Entity, TabNavigationError> {
        // Deliberately not pre-sized from `tabindex_query`: that query matches every entity
        // that is not a tab group, so its length is the size of the world rather than the
        // number of tab stops.
        let mut focusable: Vec<(Entity, TabIndex, usize)> = Vec::new();

        match tabgroup {
            Some((tg_entity, tg)) if tg.modal => {
                // Focus is inside a modal group: only that group's descendants count.
                if let Ok((_, _, children)) = self.tabgroup_query.get(tg_entity) {
                    for child in children.iter() {
                        self.gather_focusable(&mut focusable, *child, 0);
                    }
                }
            }
            _ => {
                // Otherwise walk every non-modal group, in ascending `order`. The position
                // of a group in this sorted list becomes the group key of its tab stops.
                let mut tab_groups: Vec<(Entity, TabGroup)> = self
                    .tabgroup_query
                    .iter()
                    .filter(|(_, tg, _)| !tg.modal)
                    .map(|(e, tg, _)| (e, *tg))
                    .collect();
                tab_groups.sort_by_key(|(_, tg)| tg.order);

                for (idx, (tg_entity, _)) in tab_groups.iter().enumerate() {
                    self.gather_focusable(&mut focusable, *tg_entity, idx);
                }
            }
        }

        if focusable.is_empty() {
            return Err(TabNavigationError::NoFocusableEntities);
        }

        // Order by group first, then by tab index. The sort is stable, so stops with equal
        // keys stay in traversal order.
        focusable.sort_by_key(|&(_, tab_index, group)| (group, tab_index));

        let index = focusable.iter().position(|stop| Some(stop.0) == focus);
        let count = focusable.len();
        let next = match (index, action) {
            (Some(idx), NavAction::Next) => (idx + 1).rem_euclid(count),
            // Adding `count` before subtracting avoids underflow at index 0.
            (Some(idx), NavAction::Previous) => (idx + count - 1).rem_euclid(count),
            (None, NavAction::Next) | (_, NavAction::First) => 0,
            (None, NavAction::Previous) | (_, NavAction::Last) => count - 1,
        };

        focusable
            .get(next)
            .map(|(entity, _, _)| *entity)
            .ok_or(TabNavigationError::FailedToNavigateToNextFocusableEntity)
    }

    /// Recursively appends the tab stops found at and below `parent` to `out`, tagging each
    /// with `tab_group_idx`.
    ///
    /// Only entities with a non-negative [`TabIndex`] are appended. If `parent` is itself a
    /// tab group, its children are visited only when the group is non-modal.
    fn gather_focusable(
        &self,
        out: &mut Vec<(Entity, TabIndex, usize)>,
        parent: Entity,
        tab_group_idx: usize,
    ) {
        if let Ok((entity, tabindex, children)) = self.tabindex_query.get(parent) {
            if let Some(tabindex) = tabindex
                && tabindex.0 >= 0
            {
                out.push((entity, *tabindex, tab_group_idx));
            }

            if let Some(children) = children {
                for child in children.iter() {
                    // Nested tab groups are not descended into from here.
                    if !self.tabgroup_query.contains(*child) {
                        self.gather_focusable(out, *child, tab_group_idx);
                    }
                }
            }
        } else if let Ok((_, tabgroup, children)) = self.tabgroup_query.get(parent) {
            if !tabgroup.modal {
                for child in children.iter() {
                    self.gather_focusable(out, *child, tab_group_idx);
                }
            }
        }
    }
}
/// Observer for [`AcquireFocus`] that decides whether the event's entity takes focus.
///
/// If the entity has a [`TabIndex`], propagation stops and it becomes the focused entity.
/// If the entity is a [`Window`], propagation stops and focus is cleared. For any other
/// entity the event is left to keep propagating.
pub(crate) fn acquire_focus(
    mut acquire_focus: On<AcquireFocus>,
    focusable: Query<(), With<TabIndex>>,
    windows: Query<(), With<Window>>,
    mut focus: ResMut<InputFocus>,
) {
    let target = acquire_focus.focused_entity;

    let has_tab_index = focusable.contains(target);
    if !has_tab_index && !windows.contains(target) {
        // Neither a tab stop nor a window: let the event continue on.
        return;
    }

    acquire_focus.propagate(false);

    // Only write to the resource when the value actually changes, so that it is not
    // marked as changed needlessly.
    if has_tab_index {
        if focus.0 != Some(target) {
            focus.0 = Some(target);
        }
    } else if focus.0.is_some() {
        focus.clear();
    }
}
/// Plugin that wires up tab navigation: a startup system that attaches the keyboard
/// handler to the primary window, the [`AcquireFocus`] observer, and (with the
/// `bevy_picking` feature) the click-to-focus observer.
pub struct TabNavigationPlugin;

impl Plugin for TabNavigationPlugin {
    fn build(&self, app: &mut App) {
        app.add_systems(Startup, setup_tab_navigation)
            .add_observer(acquire_focus);

        #[cfg(feature = "bevy_picking")]
        app.add_observer(click_to_focus);
    }
}
/// Startup system: attaches [`handle_tab_navigation`] as an observer on every entity
/// matched by the primary-window query.
fn setup_tab_navigation(mut commands: Commands, window: Query<Entity, With<PrimaryWindow>>) {
    window.iter().for_each(|primary| {
        commands.entity(primary).observe(handle_tab_navigation);
    });
}
/// Observer for pointer presses that requests focus for the pressed entity.
///
/// Only the press on the original event target is handled, so a press is not processed
/// again for other entities the event reaches. Handling a press turns the focus-visible
/// flag off and, if a single primary window exists, triggers [`AcquireFocus`] for the
/// pressed entity.
#[cfg(feature = "bevy_picking")]
fn click_to_focus(
    press: On<bevy_picking::events::Pointer<bevy_picking::events::Press>>,
    mut focus_visible: ResMut<InputFocusVisible>,
    windows: Query<Entity, With<PrimaryWindow>>,
    mut commands: Commands,
) {
    if press.entity != press.original_event_target() {
        return;
    }

    // Only write when the flag actually changes, to avoid marking the resource as changed.
    if focus_visible.0 {
        focus_visible.0 = false;
    }

    let Ok(window) = windows.single() else {
        return;
    };

    commands.trigger(AcquireFocus {
        focused_entity: press.entity,
        window,
    });
}
/// Observer that moves input focus in response to the Tab key.
///
/// Reacts only to a non-repeated Tab press. With either Shift key held the direction is
/// [`NavAction::Previous`], otherwise [`NavAction::Next`]. On success the event stops
/// propagating, focus moves to the selected entity, and the focus indicator is made
/// visible. Navigation errors are logged as warnings; for
/// [`TabNavigationError::NoTabGroupForCurrentFocus`] focus is still moved to the entity the
/// error carries.
pub fn handle_tab_navigation(
    mut event: On<FocusedInput<KeyboardInput>>,
    nav: TabNavigation,
    mut focus: ResMut<InputFocus>,
    mut visible: ResMut<InputFocusVisible>,
    keys: Res<ButtonInput<KeyCode>>,
) {
    let input = &event.input;
    let is_fresh_tab_press =
        input.key_code == KeyCode::Tab && input.state == ButtonState::Pressed && !input.repeat;
    if !is_fresh_tab_press {
        return;
    }

    let shift_held = keys.pressed(KeyCode::ShiftLeft) || keys.pressed(KeyCode::ShiftRight);
    let action = if shift_held {
        NavAction::Previous
    } else {
        NavAction::Next
    };

    // Resolve the entity to focus, if any; both the success path and the recoverable
    // error path end up applying focus the same way.
    let destination = match nav.navigate(&focus, action) {
        Ok(next) => Some(next),
        Err(e) => {
            warn!("Tab navigation error: {e}");
            if let TabNavigationError::NoTabGroupForCurrentFocus { new_focus, .. } = e {
                Some(new_focus)
            } else {
                None
            }
        }
    };

    if let Some(entity) = destination {
        event.propagate(false);
        focus.set(entity);
        visible.0 = true;
    }
}
#[cfg(test)]
mod tests {
    use bevy_ecs::system::SystemState;

    use super::*;

    /// A single group with two stops: relative and absolute navigation.
    #[test]
    fn test_tab_navigation() {
        let mut app = App::new();
        let world = app.world_mut();

        let group = world.spawn(TabGroup::new(0)).id();
        let first = world.spawn((TabIndex(0), ChildOf(group))).id();
        let second = world.spawn((TabIndex(1), ChildOf(group))).id();

        let mut state: SystemState<TabNavigation> = SystemState::new(world);
        let nav = state.get(world);

        assert_eq!(nav.tabgroup_query.iter().count(), 1);
        assert!(nav.tabindex_query.iter().count() >= 2);

        // Helpers: navigate starting from a given entity, or with nothing focused.
        let from = |entity, action| nav.navigate(&InputFocus::from_entity(entity), action);
        let unfocused = |action| nav.navigate(&InputFocus::default(), action);

        assert_eq!(from(first, NavAction::Next), Ok(second));
        assert_eq!(from(second, NavAction::Previous), Ok(first));
        assert_eq!(unfocused(NavAction::First), Ok(first));
        assert_eq!(unfocused(NavAction::Last), Ok(second));
    }

    /// Two groups: stops are ordered by group first, and navigation crosses group
    /// boundaries in that order.
    #[test]
    fn test_tab_navigation_between_groups_is_sorted_by_group() {
        let mut app = App::new();
        let world = app.world_mut();

        let group_a = world.spawn(TabGroup::new(0)).id();
        let a_first = world.spawn((TabIndex(0), ChildOf(group_a))).id();
        let a_second = world.spawn((TabIndex(1), ChildOf(group_a))).id();

        let group_b = world.spawn(TabGroup::new(1)).id();
        let b_first = world.spawn((TabIndex(0), ChildOf(group_b))).id();
        let b_second = world.spawn((TabIndex(1), ChildOf(group_b))).id();

        let mut state: SystemState<TabNavigation> = SystemState::new(world);
        let nav = state.get(world);

        assert_eq!(nav.tabgroup_query.iter().count(), 2);
        assert!(nav.tabindex_query.iter().count() >= 4);

        let from = |entity, action| nav.navigate(&InputFocus::from_entity(entity), action);
        let unfocused = |action| nav.navigate(&InputFocus::default(), action);

        // Within the first group.
        assert_eq!(from(a_first, NavAction::Next), Ok(a_second));
        assert_eq!(from(a_second, NavAction::Previous), Ok(a_first));

        // Absolute positions span both groups.
        assert_eq!(unfocused(NavAction::First), Ok(a_first));
        assert_eq!(unfocused(NavAction::Last), Ok(b_second));

        // Crossing the boundary between the groups in both directions.
        assert_eq!(from(a_second, NavAction::Next), Ok(b_first));
        assert_eq!(from(b_first, NavAction::Previous), Ok(a_second));
    }
}