Skip to main content

cranpose_ui/
focus_dispatch.rs

1//! Focus invalidation manager for Cranpose.
2//!
3//! This module implements focus invalidation servicing that mirrors Jetpack Compose's
4//! `FocusInvalidationManager`. When focus modifiers change, they mark nodes for
5//! reprocessing without forcing layout/draw passes.
6
7use cranpose_core::NodeId;
8use std::cell::RefCell;
9use std::collections::HashSet;
10
11/// Manages focus invalidations across the UI tree.
12///
13/// Similar to Kotlin's `FocusInvalidationManager`, this tracks which
14/// layout nodes need focus state reprocessing and provides hooks for
15/// the runtime to service those invalidations.
16struct FocusInvalidationManager {
17    dirty_nodes: HashSet<NodeId>,
18    is_processing: bool,
19    active_focus_target: Option<NodeId>,
20}
21
22impl FocusInvalidationManager {
23    fn new() -> Self {
24        Self {
25            dirty_nodes: HashSet::new(),
26            is_processing: false,
27            active_focus_target: None,
28        }
29    }
30
31    fn schedule_invalidation(&mut self, node_id: NodeId) {
32        self.dirty_nodes.insert(node_id);
33    }
34
35    fn has_pending_invalidation(&self) -> bool {
36        !self.dirty_nodes.is_empty()
37    }
38
39    fn set_active_focus_target(&mut self, node_id: Option<NodeId>) {
40        self.active_focus_target = node_id;
41    }
42
43    fn active_focus_target(&self) -> Option<NodeId> {
44        self.active_focus_target
45    }
46
47    fn take_pending_for_processing(&mut self) -> Option<Vec<NodeId>> {
48        if self.is_processing {
49            return None;
50        }
51
52        self.is_processing = true;
53        Some(self.dirty_nodes.drain().collect())
54    }
55
56    fn finish_processing<I>(&mut self, remaining: I)
57    where
58        I: IntoIterator<Item = NodeId>,
59    {
60        self.dirty_nodes.extend(remaining);
61        self.is_processing = false;
62    }
63
64    fn clear(&mut self) {
65        self.dirty_nodes.clear();
66    }
67}
68
69pub(crate) struct FocusInvalidationState {
70    manager: RefCell<FocusInvalidationManager>,
71}
72
73impl FocusInvalidationState {
74    pub(crate) fn new() -> Self {
75        Self {
76            manager: RefCell::new(FocusInvalidationManager::new()),
77        }
78    }
79
80    fn schedule_invalidation(&self, node_id: NodeId) {
81        self.manager.borrow_mut().schedule_invalidation(node_id);
82    }
83
84    fn has_pending_invalidation(&self) -> bool {
85        self.manager.borrow().has_pending_invalidation()
86    }
87
88    fn set_active_focus_target(&self, node_id: Option<NodeId>) {
89        self.manager.borrow_mut().set_active_focus_target(node_id);
90    }
91
92    fn active_focus_target(&self) -> Option<NodeId> {
93        self.manager.borrow().active_focus_target()
94    }
95
96    fn process_invalidations<F>(&self, processor: F)
97    where
98        F: FnMut(NodeId),
99    {
100        let Some(nodes) = self.manager.borrow_mut().take_pending_for_processing() else {
101            return;
102        };
103
104        self.process_pending_nodes(nodes, processor);
105    }
106
107    fn clear(&self) {
108        self.manager.borrow_mut().clear();
109    }
110
111    fn process_pending_nodes<F>(&self, nodes: Vec<NodeId>, mut processor: F)
112    where
113        F: FnMut(NodeId),
114    {
115        let mut remaining = nodes.into_iter();
116        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
117            for node_id in remaining.by_ref() {
118                processor(node_id);
119            }
120        }));
121
122        self.manager.borrow_mut().finish_processing(remaining);
123
124        if let Err(payload) = result {
125            std::panic::resume_unwind(payload);
126        }
127    }
128}
129
130/// Schedules a focus invalidation for the specified node.
131///
132/// This is called automatically when focus modifiers invalidate
133/// and mirrors Kotlin's `FocusInvalidationManager.scheduleInvalidation`.
134pub fn schedule_focus_invalidation(node_id: NodeId) {
135    crate::render_state::with_focus_dispatch(|state| state.schedule_invalidation(node_id));
136}
137
138/// Returns true if any focus invalidations are pending.
139pub fn has_pending_focus_invalidations() -> bool {
140    crate::render_state::with_focus_dispatch(|state| state.has_pending_invalidation())
141}
142
143/// Sets the currently active focus target.
144///
145/// This mirrors Kotlin's `FocusOwner.activeFocusTargetNode` and allows
146/// the focus system to track which node currently has focus.
147pub fn set_active_focus_target(node_id: Option<NodeId>) {
148    crate::render_state::with_focus_dispatch(|state| state.set_active_focus_target(node_id));
149}
150
151/// Returns the currently active focus target, if any.
152pub fn active_focus_target() -> Option<NodeId> {
153    crate::render_state::with_focus_dispatch(|state| state.active_focus_target())
154}
155
156/// Processes all pending focus invalidations.
157///
158/// The host (e.g., app shell or layout engine) should call this after
159/// composition/layout to service focus invalidations without forcing
160/// measure/layout passes.
161pub fn process_focus_invalidations<F>(processor: F)
162where
163    F: FnMut(NodeId),
164{
165    crate::render_state::with_focus_dispatch(|state| state.process_invalidations(processor));
166}
167
168/// Clears all pending focus invalidations without processing them.
169pub fn clear_focus_invalidations() {
170    crate::render_state::with_focus_dispatch(|state| state.clear());
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn schedule_and_process_invalidations() {
        let _app_context = crate::render_state::app_context_test_scope();
        clear_focus_invalidations();

        let node1: NodeId = 1;
        let node2: NodeId = 2;

        schedule_focus_invalidation(node1);
        schedule_focus_invalidation(node2);

        assert!(has_pending_focus_invalidations());

        let mut processed = Vec::new();
        process_focus_invalidations(|node_id| {
            processed.push(node_id);
        });

        assert_eq!(processed.len(), 2);
        assert!(processed.contains(&node1));
        assert!(processed.contains(&node2));
        assert!(!has_pending_focus_invalidations());
    }

    #[test]
    fn active_focus_target_tracking() {
        let _app_context = crate::render_state::app_context_test_scope();
        set_active_focus_target(None);
        assert_eq!(active_focus_target(), None);

        let node: NodeId = 42;
        set_active_focus_target(Some(node));
        assert_eq!(active_focus_target(), Some(node));

        set_active_focus_target(None);
        assert_eq!(active_focus_target(), None);
    }

    #[test]
    fn duplicate_invalidations_deduplicated() {
        let _app_context = crate::render_state::app_context_test_scope();
        clear_focus_invalidations();

        let node: NodeId = 42;
        schedule_focus_invalidation(node);
        schedule_focus_invalidation(node);
        schedule_focus_invalidation(node);

        let mut count = 0;
        process_focus_invalidations(|_| {
            count += 1;
        });

        assert_eq!(count, 1);
    }

    #[test]
    fn process_invalidations_recovers_after_processor_panic() {
        let _app_context = crate::render_state::app_context_test_scope();
        clear_focus_invalidations();

        schedule_focus_invalidation(1);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            process_focus_invalidations(|_| panic!("focus processor panic"));
        }));
        assert!(result.is_err());

        schedule_focus_invalidation(2);
        let mut processed = Vec::new();
        process_focus_invalidations(|node_id| processed.push(node_id));

        assert!(
            processed.contains(&2),
            "focus invalidation processing must not stay stuck after a processor panic"
        );
        assert!(!has_pending_focus_invalidations());
    }

    /// Nodes the processor never reached before panicking must be re-queued,
    /// while the node it panicked on is dropped.
    #[test]
    fn process_invalidations_requeues_unreached_nodes_after_processor_panic() {
        let _app_context = crate::render_state::app_context_test_scope();
        clear_focus_invalidations();

        schedule_focus_invalidation(1);
        schedule_focus_invalidation(2);

        // Batch order comes from a HashSet, so record which node the processor
        // saw first instead of assuming it.
        let mut first_seen = None;
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            process_focus_invalidations(|node_id| {
                first_seen = Some(node_id);
                panic!("focus processor panic");
            });
        }));
        assert!(result.is_err());
        let first_seen: NodeId = first_seen.expect("processor must run before panicking");

        assert!(
            has_pending_focus_invalidations(),
            "nodes the processor never reached must be re-queued"
        );

        let mut processed = Vec::new();
        process_focus_invalidations(|node_id| processed.push(node_id));

        let expected: NodeId = if first_seen == 1 { 2 } else { 1 };
        assert_eq!(processed, vec![expected]);
        assert!(!has_pending_focus_invalidations());
    }

    #[test]
    fn process_invalidations_allows_processor_to_schedule_more_work() {
        let _app_context = crate::render_state::app_context_test_scope();
        clear_focus_invalidations();

        schedule_focus_invalidation(1);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            process_focus_invalidations(|_| schedule_focus_invalidation(2));
        }));
        assert!(
            result.is_ok(),
            "focus processors must be able to enqueue follow-up invalidations"
        );
        assert!(has_pending_focus_invalidations());

        let mut processed = Vec::new();
        process_focus_invalidations(|node_id| processed.push(node_id));

        assert_eq!(processed, vec![2]);
        assert!(!has_pending_focus_invalidations());
    }

    #[test]
    fn focus_state_is_scoped_by_app_context() {
        let _app_context = crate::render_state::app_context_test_scope();
        let first = crate::render_state::AppContext::new_with_density(1.0);
        let second = crate::render_state::AppContext::new_with_density(1.0);

        first.enter(|| {
            clear_focus_invalidations();
            schedule_focus_invalidation(7);
            set_active_focus_target(Some(17));
            assert!(has_pending_focus_invalidations());
            assert_eq!(active_focus_target(), Some(17));
        });

        second.enter(|| {
            clear_focus_invalidations();
            assert!(!has_pending_focus_invalidations());
            assert_eq!(active_focus_target(), None);
            schedule_focus_invalidation(9);
            set_active_focus_target(Some(19));
        });

        first.enter(|| {
            let mut processed = Vec::new();
            process_focus_invalidations(|node_id| processed.push(node_id));
            assert_eq!(processed, vec![7]);
            assert_eq!(active_focus_target(), Some(17));
        });

        second.enter(|| {
            let mut processed = Vec::new();
            process_focus_invalidations(|node_id| processed.push(node_id));
            assert_eq!(processed, vec![9]);
            assert_eq!(active_focus_target(), Some(19));
        });
    }
}