1use crate::env::ScrollStateMap;
2use crate::ui::custom_render::downcast_render_object;
3use fission_diagnostics::prelude as diag;
4use fission_ir::{CoreIR, LayoutOp, Op, WidgetId};
5use fission_layout::{LayoutPoint, LayoutSnapshot};
6use glam::{Mat4, Vec4};
7
/// Screen-space direction used for directional (arrow-key style) focus
/// navigation; consumed by `find_neighbor_focus_node`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FocusDirection {
    /// Towards smaller y.
    Up,
    /// Towards larger y.
    Down,
    /// Towards smaller x.
    Left,
    /// Towards larger x.
    Right,
}
15
16pub fn hit_test(
17 ir: &CoreIR,
18 layout: &LayoutSnapshot,
19 scroll_map: &ScrollStateMap,
20 point: LayoutPoint,
21) -> Option<WidgetId> {
22 hit_test_internal(ir, layout, Some(scroll_map), point)
23}
24
25pub fn hit_test_with_scroll(
26 ir: &CoreIR,
27 layout: &LayoutSnapshot,
28 scroll_map: &ScrollStateMap,
29 point: LayoutPoint,
30) -> Option<WidgetId> {
31 hit_test_internal(ir, layout, Some(scroll_map), point)
32}
33
34fn hit_test_internal(
35 ir: &CoreIR,
36 layout: &LayoutSnapshot,
37 scroll_map: Option<&ScrollStateMap>,
38 point: LayoutPoint,
39) -> Option<WidgetId> {
40 let result = ir
41 .root
42 .and_then(|root| hit_test_recursive(root, ir, layout, scroll_map, point));
43
44 if let Some(id) = result {
45 diag::emit(
46 diag::DiagCategory::Input,
47 diag::DiagLevel::Debug,
48 diag::DiagEventKind::InputEvent {
49 kind: "hit_test_result".into(),
50 target: Some(id.as_u128()),
51 position: Some((point.x, point.y)),
52 },
53 );
54 }
55 result
56}
57
58fn hit_test_recursive(
59 node_id: WidgetId,
60 ir: &CoreIR,
61 layout: &LayoutSnapshot,
62 scroll_map: Option<&ScrollStateMap>,
63 point: LayoutPoint,
64) -> Option<WidgetId> {
65 let node = ir.nodes.get(&node_id)?;
66 let geom = layout.get_node_geometry(node_id)?;
67
68 let is_clip_container = matches!(
69 node.op,
70 Op::Layout(LayoutOp::Clip { .. }) | Op::Layout(LayoutOp::Scroll { .. })
71 );
72
73 if is_clip_container && !geom.rect.contains(point) {
74 return None;
75 }
76
77 let mut child_point = point;
78
79 if let (Some(map), Op::Layout(LayoutOp::Scroll { direction, .. })) = (scroll_map, &node.op) {
80 let offset = map.get_offset(node_id);
81 match direction {
82 fission_ir::FlexDirection::Column => {
83 child_point.y += offset;
84 }
85 fission_ir::FlexDirection::Row => {
86 child_point.x += offset;
87 }
88 }
89 }
90
91 if let Op::Layout(LayoutOp::Transform { transform }) = &node.op {
92 let mat = Mat4::from_cols_array(transform);
93 let inv = mat.inverse();
94 let local_x = point.x - geom.rect.origin.x;
95 let local_y = point.y - geom.rect.origin.y;
96 let p = Vec4::new(local_x, local_y, 0.0, 1.0);
97 let transformed = inv * p;
98 child_point = LayoutPoint::new(
99 transformed.x + geom.rect.origin.x,
100 transformed.y + geom.rect.origin.y,
101 );
102 }
103
104 for child_id in node.children.iter().rev() {
105 if let Some(hit) = hit_test_recursive(*child_id, ir, layout, scroll_map, child_point) {
106 return Some(hit);
107 }
108 }
109
110 if geom.rect.contains(point) {
114 if let Some(any_ro) = ir.custom_render_objects.get(&node_id) {
115 if let Some(render_obj) = downcast_render_object(any_ro) {
116 let local_point =
117 LayoutPoint::new(point.x - geom.rect.origin.x, point.y - geom.rect.origin.y);
118 let result = render_obj.hit_test(local_point, geom.rect);
119 if result.hit {
120 return Some(node_id);
121 }
122 }
123 }
124 }
125
126 let mut current_is_hit = false;
127 if geom.rect.contains(point) {
128 match &node.op {
129 Op::Layout(LayoutOp::Scroll { .. }) | Op::Layout(LayoutOp::Embed { .. }) => {
130 current_is_hit = true;
131 }
132 Op::Semantics(semantics) => {
133 if !semantics.actions.entries.is_empty()
134 || semantics.focusable
135 || semantics.draggable
136 || semantics.scrollable_x
137 || semantics.scrollable_y
138 {
139 current_is_hit = true;
140 }
141 }
142 _ => {}
143 }
144 }
145
146 if current_is_hit {
147 Some(node_id)
148 } else {
149 None
150 }
151}
152
153pub fn find_next_focus_node(
154 ir: &CoreIR,
155 current: Option<WidgetId>,
156 reverse: bool,
157) -> Option<WidgetId> {
158 let (current_scope_id, current_is_barrier) = if let Some(id) = current {
160 let scope = find_parent_scope(id, ir);
161 let mut is_barrier = false;
162 if let Some(sid) = scope {
163 if let Some(node) = ir.nodes.get(&sid) {
164 if let Op::Semantics(s) = &node.op {
165 is_barrier = s.is_focus_barrier;
166 }
167 }
168 }
169 (scope, is_barrier)
170 } else {
171 (None, false)
172 };
173
174 let nodes_in_scope = if current_is_barrier {
175 let scope_id = current_scope_id.unwrap();
176 let mut list = Vec::new();
177 if let Some(node) = ir.nodes.get(&scope_id) {
179 for child in &node.children {
180 collect_focusable_nodes(*child, ir, &mut list, true, 0);
181 }
182 }
183 sort_focusable_nodes(ir, list)
184 } else {
185 get_all_focusable_nodes(ir)
186 };
187
188 if nodes_in_scope.is_empty() {
189 return None;
190 }
191
192 let idx = if let Some(curr_id) = current {
193 nodes_in_scope.iter().position(|id| *id == curr_id)
194 } else {
195 None
196 };
197
198 match idx {
199 Some(i) => {
200 if reverse {
201 if i == 0 {
202 Some(nodes_in_scope[nodes_in_scope.len() - 1])
203 } else {
204 Some(nodes_in_scope[i - 1])
205 }
206 } else if i == nodes_in_scope.len() - 1 {
207 Some(nodes_in_scope[0])
208 } else {
209 Some(nodes_in_scope[i + 1])
210 }
211 }
212 None => {
213 if reverse {
214 Some(nodes_in_scope[nodes_in_scope.len() - 1])
215 } else {
216 Some(nodes_in_scope[0])
217 }
218 }
219 }
220}
221
222pub fn get_all_focusable_nodes(ir: &CoreIR) -> Vec<WidgetId> {
223 let mut list = Vec::new();
224 if let Some(root) = ir.root {
225 collect_focusable_nodes(root, ir, &mut list, false, 0);
226 }
227 sort_focusable_nodes(ir, list)
228}
229
230fn sort_focusable_nodes(ir: &CoreIR, mut list: Vec<(WidgetId, usize)>) -> Vec<WidgetId> {
231 list.sort_by(|(id_a, order_a), (id_b, order_b)| {
232 let idx_a = ir.nodes.get(id_a).and_then(|n| {
233 if let Op::Semantics(s) = &n.op {
234 s.focus_index
235 } else {
236 None
237 }
238 });
239 let idx_b = ir.nodes.get(id_b).and_then(|n| {
240 if let Op::Semantics(s) = &n.op {
241 s.focus_index
242 } else {
243 None
244 }
245 });
246
247 match (idx_a, idx_b) {
248 (Some(a), Some(b)) => a.cmp(&b).then(order_a.cmp(order_b)),
249 (Some(_), None) => std::cmp::Ordering::Less,
250 (None, Some(_)) => std::cmp::Ordering::Greater,
251 (None, None) => order_a.cmp(order_b),
252 }
253 });
254 list.into_iter().map(|(id, _)| id).collect()
255}
256
257fn collect_focusable_nodes(
258 node_id: WidgetId,
259 ir: &CoreIR,
260 list: &mut Vec<(WidgetId, usize)>,
261 stop_at_barriers: bool,
262 mut order: usize,
263) {
264 if let Some(node) = ir.nodes.get(&node_id) {
265 let mut is_barrier = false;
266 if let Op::Semantics(s) = &node.op {
267 if s.focusable && !s.disabled {
268 list.push((node_id, order));
269 order += 1;
270 }
271 is_barrier = s.is_focus_barrier;
272 }
273
274 if stop_at_barriers && is_barrier {
275 return;
276 }
277
278 let mut children = node.children.clone();
279 children.sort_by_key(|cid| {
281 ir.nodes
282 .get(cid)
283 .and_then(|n| {
284 if let Op::Semantics(s) = &n.op {
285 s.focus_index
286 } else {
287 None
288 }
289 })
290 .unwrap_or(i32::MAX)
291 });
292
293 for child in children {
294 collect_focusable_nodes(child, ir, list, stop_at_barriers, order);
295 order = list.last().map(|(_, o)| *o + 1).unwrap_or(order);
296 }
297 }
298}
299
300fn find_parent_scope(node_id: WidgetId, ir: &CoreIR) -> Option<WidgetId> {
301 let mut curr = ir.nodes.get(&node_id)?.parent;
302 while let Some(pid) = curr {
303 if let Some(node) = ir.nodes.get(&pid) {
304 if let Op::Semantics(s) = &node.op {
305 if s.is_focus_scope {
306 return Some(pid);
307 }
308 }
309 curr = node.parent;
310 } else {
311 break;
312 }
313 }
314 None
315}
316
317pub fn find_neighbor_focus_node(
318 ir: &CoreIR,
319 layout: &LayoutSnapshot,
320 current: WidgetId,
321 direction: FocusDirection,
322) -> Option<WidgetId> {
323 let current_rect = layout.get_node_rect(current)?;
324 let focusable_nodes = get_all_focusable_nodes(ir);
325
326 let mut best_candidate = None;
327 let mut best_dist = f32::INFINITY;
328
329 let (cx, cy) = (
330 current_rect.x() + current_rect.width() / 2.0,
331 current_rect.y() + current_rect.height() / 2.0,
332 );
333
334 for node_id in focusable_nodes {
335 if node_id == current {
336 continue;
337 }
338 let rect = match layout.get_node_rect(node_id) {
339 Some(r) => r,
340 None => continue,
341 };
342
343 let (nx, ny) = (
344 rect.x() + rect.width() / 2.0,
345 rect.y() + rect.height() / 2.0,
346 );
347
348 let is_in_dir = match direction {
349 FocusDirection::Up => ny < cy && (nx - cx).abs() < (ny - cy).abs(),
350 FocusDirection::Down => ny > cy && (nx - cx).abs() < (ny - cy).abs(),
351 FocusDirection::Left => nx < cx && (ny - cy).abs() < (nx - cx).abs(),
352 FocusDirection::Right => nx > cx && (ny - cy).abs() < (nx - cx).abs(),
353 };
354
355 if is_in_dir {
356 let dist = (nx - cx).powi(2) + (ny - cy).powi(2);
357 if dist < best_dist {
358 best_dist = dist;
359 best_candidate = Some(node_id);
360 }
361 }
362 }
363
364 best_candidate
365}