Skip to main content

rgpui_component/
focus_trap.rs

1use rgpui::{
2    AnyElement, App, Bounds, Element, ElementId, FocusHandle, Global, GlobalElementId,
3    InteractiveElement, Interactivity, IntoElement, LayoutId, ParentElement, Pixels,
4    StatefulInteractiveElement, StyleRefinement, Styled, WeakFocusHandle, Window,
5};
6use std::collections::HashMap;
7
8/// Initialize the focus trap manager as a global
9pub(crate) fn init(cx: &mut App) {
10    cx.set_global(FocusTrapManager::new());
11}
12
13/// An extension trait to add `focus_trap` functionality to interactive elements.
14pub trait FocusTrapElement: InteractiveElement + Sized {
15    /// Enable focus trap for this element.
16    ///
17    /// When enabled, focus will automatically cycle within this container
18    /// instead of escaping to parent elements. This is useful for modal dialogs,
19    /// sheets, and other overlay components.
20    ///
21    /// The focus trap works by:
22    /// 1. Registering this element as a focus trap container
23    /// 2. When Tab/Shift-Tab is pressed, Root intercepts the event
24    /// 3. If focus would leave the container, it cycles back to the beginning/end
25    ///
26    /// # Example
27    ///
28    /// ```ignore
29    /// v_flex()
30    ///     .child(Button::new("btn1").label("Button 1"))
31    ///     .child(Button::new("btn2").label("Button 2"))
32    ///     .child(Button::new("btn3").label("Button 3"))
33    ///     .focus_trap("trap1", &self.container_focus_handle)
34    /// // Pressing Tab will cycle: btn1 -> btn2 -> btn3 -> btn1
35    /// // Focus will not escape to elements outside this container
36    /// ```
37    ///
38    /// See also: <https://github.com/focus-trap/focus-trap-react>
39    fn focus_trap(
40        self,
41        id: impl Into<ElementId>,
42        focus_handle: &FocusHandle,
43    ) -> FocusTrapContainer<Self>
44    where
45        Self: ParentElement + Styled + Element + 'static,
46    {
47        FocusTrapContainer::new(id, focus_handle.clone(), self)
48    }
49}
50impl<T: InteractiveElement + Sized> FocusTrapElement for T {}
51
/// Global state to manage all focus trap containers.
///
/// Stored as an application global (see `init`) and updated by each
/// `FocusTrapContainer` while it requests layout.
pub(crate) struct FocusTrapManager {
    /// Map from a container's global element ID to a weak handle for that
    /// container's focus handle. Entries whose handle can no longer be
    /// upgraded are pruned by `cleanup`.
    traps: HashMap<GlobalElementId, WeakFocusHandle>,
}
57
// Marks `FocusTrapManager` as an app global: `init` stores it with `set_global`
// and the manager's accessors read it back with `global` / `global_mut`.
impl Global for FocusTrapManager {}
59
60impl FocusTrapManager {
61    /// Create a new focus trap manager
62    fn new() -> Self {
63        Self {
64            traps: HashMap::new(),
65        }
66    }
67
68    pub(crate) fn global(cx: &App) -> &Self {
69        cx.global::<FocusTrapManager>()
70    }
71
72    fn global_mut(cx: &mut App) -> &mut Self {
73        cx.global_mut::<FocusTrapManager>()
74    }
75
76    /// Register a focus trap container
77    fn register_trap(id: &GlobalElementId, container_handle: WeakFocusHandle, cx: &mut App) {
78        let this = Self::global_mut(cx);
79        this.traps.insert(id.clone(), container_handle);
80        this.cleanup();
81    }
82
83    /// Find which focus trap contains the currently focused element
84    pub(crate) fn find_active_trap(window: &Window, cx: &App) -> Option<FocusHandle> {
85        for (_id, container_handle) in Self::global(cx).traps.iter() {
86            let Some(container) = container_handle.upgrade() else {
87                continue;
88            };
89
90            if container.contains_focused(window, cx) {
91                return Some(container.clone());
92            }
93        }
94        None
95    }
96
97    /// Cleanup any traps with dropped handles
98    fn cleanup(&mut self) {
99        self.traps.retain(|_, handle| handle.upgrade().is_some());
100    }
101}
102
103impl Default for FocusTrapManager {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109/// A wrapper element that implements focus trap behavior.
110///
111/// This element wraps another element and registers it as a focus trap container.
112/// Focus will automatically cycle within the container when Tab/Shift-Tab is pressed.
113pub struct FocusTrapContainer<E: InteractiveElement + ParentElement + Styled + Element> {
114    id: ElementId,
115    focus_handle: FocusHandle,
116    base: E,
117}
118
119impl<E: InteractiveElement + ParentElement + Styled + Element> FocusTrapContainer<E> {
120    pub(crate) fn new(id: impl Into<ElementId>, focus_handle: FocusHandle, child: E) -> Self {
121        Self {
122            id: id.into(),
123            base: child.track_focus(&focus_handle),
124            focus_handle,
125        }
126    }
127}
128
129impl<E: InteractiveElement + ParentElement + Styled + Element> IntoElement
130    for FocusTrapContainer<E>
131{
132    type Element = Self;
133
134    fn into_element(self) -> Self::Element {
135        self
136    }
137}
138impl<E: InteractiveElement + ParentElement + Styled + Element> ParentElement
139    for FocusTrapContainer<E>
140{
141    fn extend(&mut self, elements: impl IntoIterator<Item = AnyElement>) {
142        self.base.extend(elements);
143    }
144}
145impl<E: InteractiveElement + ParentElement + Styled + Element> InteractiveElement
146    for FocusTrapContainer<E>
147{
148    fn interactivity(&mut self) -> &mut Interactivity {
149        self.base.interactivity()
150    }
151}
152impl<E: InteractiveElement + ParentElement + Styled + Element> StatefulInteractiveElement
153    for FocusTrapContainer<E>
154{
155}
156impl<E: InteractiveElement + ParentElement + Styled + Element> Styled for FocusTrapContainer<E> {
157    fn style(&mut self) -> &mut StyleRefinement {
158        self.base.style()
159    }
160}
161
162impl<E: InteractiveElement + ParentElement + Styled + Element + 'static> Element
163    for FocusTrapContainer<E>
164{
165    type RequestLayoutState = E::RequestLayoutState;
166    type PrepaintState = E::PrepaintState;
167
168    fn id(&self) -> Option<ElementId> {
169        Some(self.id.clone())
170    }
171
172    fn source_location(&self) -> Option<&'static std::panic::Location<'static>> {
173        None
174    }
175
176    fn request_layout(
177        &mut self,
178        global_id: Option<&rgpui::GlobalElementId>,
179        _inspector_id: Option<&rgpui::InspectorElementId>,
180        window: &mut Window,
181        cx: &mut App,
182    ) -> (LayoutId, Self::RequestLayoutState) {
183        // Register this focus trap with the manager
184        FocusTrapManager::register_trap(global_id.unwrap(), self.focus_handle.downgrade(), cx);
185
186        self.base.request_layout(global_id, None, window, cx)
187    }
188
189    fn prepaint(
190        &mut self,
191        global_id: Option<&rgpui::GlobalElementId>,
192        inspector_id: Option<&rgpui::InspectorElementId>,
193        bounds: Bounds<Pixels>,
194        request_layout: &mut Self::RequestLayoutState,
195        window: &mut Window,
196        cx: &mut App,
197    ) -> Self::PrepaintState {
198        self.base
199            .prepaint(global_id, inspector_id, bounds, request_layout, window, cx)
200    }
201
202    fn paint(
203        &mut self,
204        global_id: Option<&rgpui::GlobalElementId>,
205        inspector_id: Option<&rgpui::InspectorElementId>,
206        bounds: Bounds<Pixels>,
207        request_layout: &mut Self::RequestLayoutState,
208        prepaint: &mut Self::PrepaintState,
209        window: &mut Window,
210        cx: &mut App,
211    ) {
212        self.base.paint(
213            global_id,
214            inspector_id,
215            bounds,
216            request_layout,
217            prepaint,
218            window,
219            cx,
220        )
221    }
222}