// freya_core/accessibility/tree.rs

use std::{
    mem,
    sync::atomic::{
        AtomicU64,
        Ordering,
    },
};

use accesskit::{
    Action,
    Affine,
    Node,
    NodeId as AccessibilityId,
    Rect,
    Role,
    TextDirection,
    Tree,
    TreeUpdate,
};
use freya_engine::prelude::{
    Color,
    Slant,
    TextAlign,
    TextDecoration,
    TextDecorationStyle,
};
use freya_native_core::{
    node::NodeType,
    prelude::NodeImmutable,
    tags::TagName,
    NodeId,
};
use rustc_hash::{
    FxHashMap,
    FxHashSet,
};
use torin::{
    prelude::LayoutNode,
    torin::Torin,
};

use super::NodeAccessibility;
use crate::{
    dom::{
        DioxusDOM,
        DioxusNode,
    },
    states::{
        AccessibilityNodeState,
        FontStyleState,
        StyleState,
        TransformState,
    },
    values::{
        Fill,
        OverflowMode,
    },
};

57/// Strategy focusing an Accessibility Node.
58#[derive(PartialEq, Debug, Clone)]
59pub enum AccessibilityFocusStrategy {
60    Forward,
61    Backward,
62    Node(accesskit::NodeId),
63}
64
65#[derive(Default)]
66pub struct AccessibilityDirtyNodes {
67    pub requested_focus: Option<AccessibilityFocusStrategy>,
68    pub added_or_updated: FxHashSet<NodeId>,
69    pub removed: FxHashMap<NodeId, NodeId>,
70}
71
72impl AccessibilityDirtyNodes {
73    pub fn request_focus(&mut self, node_id: AccessibilityFocusStrategy) {
74        self.requested_focus = Some(node_id);
75    }
76
77    pub fn add_or_update(&mut self, node_id: NodeId) {
78        self.added_or_updated.insert(node_id);
79    }
80
81    pub fn remove(&mut self, node_id: NodeId, parent_id: NodeId) {
82        self.removed.insert(node_id, parent_id);
83    }
84
85    pub fn clear(&mut self) {
86        self.requested_focus.take();
87        self.added_or_updated.clear();
88        self.removed.clear();
89    }
90}
91
/// Source of sequential accessibility ids.
///
/// Ids are produced by an atomic counter, so `new_id` works through a shared reference.
pub struct AccessibilityGenerator {
    counter: AtomicU64,
}

impl Default for AccessibilityGenerator {
    fn default() -> Self {
        // Id 0 is reserved for the Root, so the first id handed out is 1.
        let first_free_id = 1;
        Self {
            counter: AtomicU64::new(first_free_id),
        }
    }
}

impl AccessibilityGenerator {
    /// Return the current counter value and advance it by one.
    pub fn new_id(&self) -> u64 {
        // A plain atomic increment; no other memory accesses are ordered against it.
        self.counter.fetch_add(1, Ordering::Relaxed)
    }
}

110pub const ACCESSIBILITY_ROOT_ID: AccessibilityId = AccessibilityId(0);
111
112pub struct AccessibilityTree {
113    pub map: FxHashMap<AccessibilityId, NodeId>,
114    // Current focused Accessibility Node.
115    pub focused_id: AccessibilityId,
116}
117
118impl AccessibilityTree {
119    pub fn new(focused_id: AccessibilityId) -> Self {
120        Self {
121            focused_id,
122            map: FxHashMap::default(),
123        }
124    }
125
126    pub fn focused_node_id(&self) -> Option<NodeId> {
127        self.map.get(&self.focused_id).cloned()
128    }
129
130    /// Initialize the Accessibility Tree
131    pub fn init(
132        &self,
133        rdom: &DioxusDOM,
134        layout: &Torin<NodeId>,
135        dirty_nodes: &mut AccessibilityDirtyNodes,
136    ) -> TreeUpdate {
137        dirty_nodes.clear();
138
139        let mut nodes = vec![];
140
141        rdom.traverse_depth_first_advanced(|node_ref| {
142            if !node_ref.node_type().is_element() {
143                return false;
144            }
145
146            let accessibility_id = node_ref.get_accessibility_id();
147            let layout_node = layout.get(node_ref.id());
148
149            // Layout nodes might not exist yet when the app is lauched
150            if let Some((accessibility_id, layout_node)) = accessibility_id.zip(layout_node) {
151                let node_accessibility_state = node_ref.get::<AccessibilityNodeState>().unwrap();
152                let accessibility_node =
153                    Self::create_node(&node_ref, layout_node, &node_accessibility_state);
154                nodes.push((accessibility_id, accessibility_node));
155            }
156
157            if let Some(tag) = node_ref.node_type().tag() {
158                if *tag == TagName::Paragraph || *tag == TagName::Label {
159                    return false;
160                }
161            }
162
163            true
164        });
165
166        #[cfg(debug_assertions)]
167        tracing::info!(
168            "Initialized the Accessibility Tree with {} nodes.",
169            nodes.len()
170        );
171
172        TreeUpdate {
173            nodes,
174            tree: Some(Tree::new(ACCESSIBILITY_ROOT_ID)),
175            focus: ACCESSIBILITY_ROOT_ID,
176        }
177    }
178
179    /// Process any pending Accessibility Tree update
180    pub fn process_updates(
181        &mut self,
182        rdom: &DioxusDOM,
183        layout: &Torin<NodeId>,
184        dirty_nodes: &mut AccessibilityDirtyNodes,
185    ) -> (TreeUpdate, NodeId) {
186        let requested_focus = dirty_nodes.requested_focus.take();
187        let removed_ids = dirty_nodes.removed.drain().collect::<FxHashMap<_, _>>();
188        let mut added_or_updated_ids = dirty_nodes
189            .added_or_updated
190            .drain()
191            .collect::<FxHashSet<_>>();
192
193        #[cfg(debug_assertions)]
194        if !removed_ids.is_empty() || !added_or_updated_ids.is_empty() {
195            tracing::info!(
196                "Updating the Accessibility Tree with {} removals and {} additions/modifications",
197                removed_ids.len(),
198                added_or_updated_ids.len()
199            );
200        }
201
202        // Remove all the removed nodes from the update list
203        for (node_id, _) in removed_ids.iter() {
204            added_or_updated_ids.remove(node_id);
205            self.map.retain(|_, id| id != node_id);
206        }
207
208        // Mark the parent of the removed nodes as updated
209        for (_, parent_id) in removed_ids.iter() {
210            if !removed_ids.contains_key(parent_id) {
211                added_or_updated_ids.insert(*parent_id);
212            }
213        }
214
215        // Mark the ancestors as modified
216        for node_id in added_or_updated_ids.clone() {
217            let node_ref = rdom.get(node_id).unwrap();
218            let node_ref_parent = node_ref.parent_id().unwrap_or(rdom.root_id());
219            added_or_updated_ids.insert(node_ref_parent);
220            self.map
221                .insert(node_ref.get_accessibility_id().unwrap(), node_id);
222        }
223
224        // Create the updated nodes
225        let mut nodes = Vec::new();
226        for node_id in added_or_updated_ids {
227            let node_ref = rdom.get(node_id).unwrap();
228            let node_accessibility_state = node_ref.get::<AccessibilityNodeState>();
229            let layout_node = layout.get(node_id);
230
231            if let Some((node_accessibility_state, layout_node)) =
232                node_accessibility_state.as_ref().zip(layout_node)
233            {
234                let accessibility_node =
235                    Self::create_node(&node_ref, layout_node, node_accessibility_state);
236                let accessibility_id = node_ref.get_accessibility_id().unwrap();
237
238                nodes.push((accessibility_id, accessibility_node));
239            }
240        }
241
242        // Focus the requested node id if there is one
243        if let Some(requested_focus) = requested_focus {
244            self.focus_node_with_strategy(requested_focus, rdom);
245        }
246
247        // Fallback the focused id to the root if the focused node no longer exists
248        if !self.map.contains_key(&self.focused_id) {
249            self.focused_id = ACCESSIBILITY_ROOT_ID;
250        }
251
252        let node_id = self.map.get(&self.focused_id).cloned().unwrap();
253
254        (
255            TreeUpdate {
256                nodes,
257                tree: Some(Tree::new(ACCESSIBILITY_ROOT_ID)),
258                focus: self.focused_id,
259            },
260            node_id,
261        )
262    }
263
264    /// Focus a Node given the strategy.
265    pub fn focus_node_with_strategy(
266        &mut self,
267        stragegy: AccessibilityFocusStrategy,
268        rdom: &DioxusDOM,
269    ) {
270        if let AccessibilityFocusStrategy::Node(id) = stragegy {
271            self.focused_id = id;
272            return;
273        }
274
275        let mut nodes = Vec::new();
276
277        rdom.traverse_depth_first_advanced(|node_ref| {
278            if !node_ref.node_type().is_element() {
279                return false;
280            }
281
282            let accessibility_id = node_ref.get_accessibility_id();
283
284            if let Some(accessibility_id) = accessibility_id {
285                let accessibility_state = node_ref.get::<AccessibilityNodeState>().unwrap();
286                if accessibility_state.a11y_focusable.is_enabled() {
287                    nodes.push(accessibility_id)
288                }
289            }
290
291            if let Some(tag) = node_ref.node_type().tag() {
292                if *tag == TagName::Paragraph || *tag == TagName::Label {
293                    return false;
294                }
295            }
296
297            true
298        });
299
300        let node_index = nodes
301            .iter()
302            .position(|accessibility_id| *accessibility_id == self.focused_id);
303
304        let target_node = if stragegy == AccessibilityFocusStrategy::Forward {
305            // Find the next Node
306            if let Some(node_index) = node_index {
307                if node_index == nodes.len() - 1 {
308                    nodes.first()
309                } else {
310                    nodes.get(node_index + 1)
311                }
312            } else {
313                nodes.first()
314            }
315        } else {
316            // Find the previous Node
317            if let Some(node_index) = node_index {
318                if node_index == 0 {
319                    nodes.last()
320                } else {
321                    nodes.get(node_index - 1)
322                }
323            } else {
324                nodes.last()
325            }
326        };
327
328        self.focused_id = target_node.copied().unwrap_or(ACCESSIBILITY_ROOT_ID);
329
330        #[cfg(debug_assertions)]
331        tracing::info!("Focused {:?} node.", self.focused_id);
332    }
333
334    /// Create an accessibility node
335    pub fn create_node(
336        node_ref: &DioxusNode,
337        layout_node: &LayoutNode,
338        node_accessibility: &AccessibilityNodeState,
339    ) -> Node {
340        let font_style_state = &*node_ref.get::<FontStyleState>().unwrap();
341        let style_state = &*node_ref.get::<StyleState>().unwrap();
342        let transform_state = &*node_ref.get::<TransformState>().unwrap();
343        let node_type = node_ref.node_type();
344
345        let mut builder = match node_type.tag() {
346            // Make the root accessibility node.
347            Some(&TagName::Root) => Node::new(Role::Window),
348
349            // All other node types will either don't have a builder (but don't support
350            // accessibility attributes like with `text`) or have their builder made for
351            // them already.
352            Some(_) => node_accessibility.builder.clone().unwrap(),
353
354            // Tag-less nodes can't have accessibility state
355            None => unreachable!(),
356        };
357
358        // Set children
359        let children = node_ref.get_accessibility_children();
360        builder.set_children(children);
361
362        // Set the area
363        let area = layout_node.area.to_f64();
364        builder.set_bounds(Rect {
365            x0: area.min_x(),
366            x1: area.max_x(),
367            y0: area.min_y(),
368            y1: area.max_y(),
369        });
370
371        if let NodeType::Element(node) = &*node_type {
372            if matches!(node.tag, TagName::Label | TagName::Paragraph) && builder.value().is_none()
373            {
374                if let Some(inner_text) = node_ref.get_inner_texts() {
375                    builder.set_value(inner_text);
376                }
377            }
378        }
379
380        // Set focusable action
381        // This will cause assistive technology to offer the user an option
382        // to focus the current element if it supports it.
383        if node_accessibility.a11y_focusable.is_enabled() {
384            builder.add_action(Action::Focus);
385        }
386
387        // Rotation transform
388        if let Some((_, rotation)) = transform_state
389            .rotations
390            .iter()
391            .find(|(id, _)| id == &node_ref.id())
392        {
393            let rotation = rotation.to_radians() as f64;
394            let (s, c) = rotation.sin_cos();
395            builder.set_transform(Affine::new([c, s, -s, c, 0.0, 0.0]));
396        }
397
398        // Clipping overflow
399        if style_state.overflow == OverflowMode::Clip {
400            builder.set_clips_children();
401        }
402
403        // Foreground/Background color
404        builder.set_foreground_color(skia_color_to_rgba_u32(font_style_state.color));
405        if let Fill::Color(color) = style_state.background {
406            builder.set_background_color(skia_color_to_rgba_u32(color));
407        }
408
409        // If the node is a block-level element in the layout, indicate that it will cause a linebreak.
410        if !node_type.is_text() {
411            if let NodeType::Element(node) = &*node_type {
412                // This should be impossible currently but i'm checking for it just in case.
413                // In the future, inline text spans should have their own own accessibility node,
414                // but that's not a concern yet.
415                if node.tag != TagName::Text {
416                    builder.set_is_line_breaking_object();
417                }
418            }
419        }
420
421        // Font size
422        builder.set_font_size(font_style_state.font_size as _);
423
424        // If the font family has changed since the parent node, then we inform accesskit of this change.
425        if let Some(parent_node) = node_ref.parent() {
426            if parent_node.get::<FontStyleState>().unwrap().font_family
427                != font_style_state.font_family
428            {
429                builder.set_font_family(font_style_state.font_family.join(", "));
430            }
431        } else {
432            // Element has no parent elements, so we set the initial font style.
433            builder.set_font_family(font_style_state.font_family.join(", "));
434        }
435
436        // Set bold flag for weights above 700
437        if font_style_state.font_weight > 700.into() {
438            builder.set_bold();
439        }
440
441        // Text alignment
442        builder.set_text_align(match font_style_state.text_align {
443            TextAlign::Center => accesskit::TextAlign::Center,
444            TextAlign::Justify => accesskit::TextAlign::Justify,
445            // TODO: change representation of `Start` and `End` once RTL text/writing modes are supported.
446            TextAlign::Left | TextAlign::Start => accesskit::TextAlign::Left,
447            TextAlign::Right | TextAlign::End => accesskit::TextAlign::Right,
448        });
449
450        // TODO: Adjust this once text direction support other than RTL is properly added
451        builder.set_text_direction(TextDirection::LeftToRight);
452
453        // Set italic property for italic/oblique font slants
454        match font_style_state.font_slant {
455            Slant::Italic | Slant::Oblique => builder.set_italic(),
456            _ => {}
457        }
458
459        // Text decoration
460        if font_style_state
461            .decoration
462            .ty
463            .contains(TextDecoration::LINE_THROUGH)
464        {
465            builder.set_strikethrough(skia_decoration_style_to_accesskit(
466                font_style_state.decoration.style,
467            ));
468        }
469        if font_style_state
470            .decoration
471            .ty
472            .contains(TextDecoration::UNDERLINE)
473        {
474            builder.set_underline(skia_decoration_style_to_accesskit(
475                font_style_state.decoration.style,
476            ));
477        }
478        if font_style_state
479            .decoration
480            .ty
481            .contains(TextDecoration::OVERLINE)
482        {
483            builder.set_overline(skia_decoration_style_to_accesskit(
484                font_style_state.decoration.style,
485            ));
486        }
487
488        builder
489    }
490}
491
492fn skia_decoration_style_to_accesskit(style: TextDecorationStyle) -> accesskit::TextDecoration {
493    match style {
494        TextDecorationStyle::Solid => accesskit::TextDecoration::Solid,
495        TextDecorationStyle::Dotted => accesskit::TextDecoration::Dotted,
496        TextDecorationStyle::Dashed => accesskit::TextDecoration::Dashed,
497        TextDecorationStyle::Double => accesskit::TextDecoration::Double,
498        TextDecorationStyle::Wavy => accesskit::TextDecoration::Wavy,
499    }
500}
501
502fn skia_color_to_rgba_u32(color: Color) -> u32 {
503    ((color.a() as u32) << 24)
504        | ((color.b() as u32) << 16)
505        | (((color.g() as u32) << 8) + (color.r() as u32))
506}