floating_ui_core/middleware/arrow.rs

use floating_ui_utils::{
    Axis, Coords, OwnedElementOrWindow, Padding, Side, clamp, get_alignment, get_alignment_axis,
    get_axis_length, get_padding_object,
};
use serde::{Deserialize, Serialize};

use crate::types::{
    Derivable, DerivableFn, Middleware, MiddlewareReturn, MiddlewareState, MiddlewareWithOptions,
    Reset,
};
10
/// Identifier of the [`Arrow`] middleware.
///
/// This is the name the middleware reports for itself; it is also the key under which it
/// looks up its previously stored [`ArrowData`].
pub const ARROW_NAME: &str = "arrow";
13
14/// Options for [`Arrow`].
15#[derive(Clone, Debug, PartialEq)]
16pub struct ArrowOptions<Element: Clone> {
17    /// The arrow element to be positioned.
18    pub element: Element,
19
20    /// The padding between the arrow element and the floating element edges.
21    /// Useful when the floating element has rounded corners.
22    ///
23    /// Defaults to `0` on all sides.
24    pub padding: Option<Padding>,
25}
26
27impl<Element: Clone> ArrowOptions<Element> {
28    pub fn new(element: Element) -> Self {
29        ArrowOptions {
30            element,
31            padding: None,
32        }
33    }
34
35    /// Set `element` option.
36    pub fn element(mut self, value: Element) -> Self {
37        self.element = value;
38        self
39    }
40
41    /// Set `padding` option.
42    pub fn padding(mut self, value: Padding) -> Self {
43        self.padding = Some(value);
44        self
45    }
46}
47
48/// Data stored by [`Arrow`] middleware.
49#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
50pub struct ArrowData {
51    pub x: Option<f64>,
52    pub y: Option<f64>,
53    pub center_offset: f64,
54    pub alignment_offset: Option<f64>,
55}
56
57/// Arrow middleware.
58///
59/// Provides data to position an inner element of the floating element so that it appears centered to the reference element.
60///
61/// See [the Rust Floating UI book](https://floating-ui.rustforweb.org/middleware/arrow.html) for more documentation.
62#[derive(PartialEq)]
63pub struct Arrow<'a, Element: Clone + 'static, Window: Clone> {
64    options: Derivable<'a, Element, Window, ArrowOptions<Element>>,
65}
66
67impl<'a, Element: Clone + 'static, Window: Clone> Arrow<'a, Element, Window> {
68    /// Constructs a new instance of this middleware.
69    pub fn new(options: ArrowOptions<Element>) -> Self {
70        Arrow {
71            options: options.into(),
72        }
73    }
74
75    /// Constructs a new instance of this middleware with derivable options.
76    pub fn new_derivable(options: Derivable<'a, Element, Window, ArrowOptions<Element>>) -> Self {
77        Arrow { options }
78    }
79
80    /// Constructs a new instance of this middleware with derivable options function.
81    pub fn new_derivable_fn(
82        options: DerivableFn<'a, Element, Window, ArrowOptions<Element>>,
83    ) -> Self {
84        Arrow {
85            options: options.into(),
86        }
87    }
88}
89
90impl<Element: Clone + 'static, Window: Clone> Clone for Arrow<'_, Element, Window> {
91    fn clone(&self) -> Self {
92        Self {
93            options: self.options.clone(),
94        }
95    }
96}
97
98impl<Element: Clone + PartialEq, Window: Clone + PartialEq> Middleware<Element, Window>
99    for Arrow<'static, Element, Window>
100{
101    fn name(&self) -> &'static str {
102        ARROW_NAME
103    }
104
105    fn compute(&self, state: MiddlewareState<Element, Window>) -> MiddlewareReturn {
106        let options = self.options.evaluate(state.clone());
107
108        let MiddlewareState {
109            x,
110            y,
111            placement,
112            middleware_data,
113            elements,
114            rects,
115            platform,
116            ..
117        } = state;
118
119        let data: Option<ArrowData> = middleware_data.get_as(self.name());
120
121        let padding_object = get_padding_object(options.padding.unwrap_or(Padding::All(0.0)));
122        let coords = Coords { x, y };
123        let axis = get_alignment_axis(placement);
124        let length = get_axis_length(axis);
125        let arrow_dimensions = platform.get_dimensions(&options.element);
126        let min_prop = match axis {
127            Axis::X => Side::Left,
128            Axis::Y => Side::Top,
129        };
130        let max_prop = match axis {
131            Axis::X => Side::Right,
132            Axis::Y => Side::Bottom,
133        };
134
135        let start_diff = coords.axis(axis) - rects.reference.axis(axis);
136        let end_diff = rects.reference.length(length) + rects.reference.axis(axis)
137            - coords.axis(axis)
138            - rects.floating.length(length);
139
140        let arrow_offset_parent = platform.get_offset_parent(&options.element);
141        let client_size = arrow_offset_parent
142            .and_then(|arrow_offset_parent| match arrow_offset_parent {
143                OwnedElementOrWindow::Element(element) => {
144                    platform.get_client_length(&element, length)
145                }
146                OwnedElementOrWindow::Window(_) => {
147                    platform.get_client_length(elements.floating, length)
148                }
149            })
150            .unwrap_or(rects.floating.length(length));
151
152        let center_to_reference = end_diff / 2.0 - start_diff / 2.0;
153
154        // If the padding is large enough that it causes the arrow to no longer be centered, modify the padding so that it is centered.
155        let largest_possible_padding =
156            client_size / 2.0 - arrow_dimensions.length(length) / 2.0 - 1.0;
157        let min_padding = padding_object.side(min_prop).min(largest_possible_padding);
158        let max_padding = padding_object.side(max_prop).min(largest_possible_padding);
159
160        // Make sure the arrow doesn't overflow the floating element if the center point is outside the floating element's bounds.
161        let min = min_padding;
162        let max = client_size - arrow_dimensions.length(length) - max_padding;
163        let center =
164            client_size / 2.0 - arrow_dimensions.length(length) / 2.0 + center_to_reference;
165        let offset = clamp(min, center, max);
166
167        // If the reference is small enough that the arrow's padding causes it to to point to nothing for an aligned placement, adjust the offset of the floating element itself.
168        // To ensure `shift()` continues to take action, a single reset is performed when this is true.
169        let should_add_offset = data.is_none()
170            && get_alignment(placement).is_some()
171            && center != offset
172            && rects.reference.length(length) / 2.0
173                - (if center < min {
174                    min_padding
175                } else {
176                    max_padding
177                })
178                - arrow_dimensions.length(length) / 2.0
179                < 0.0;
180        let alignment_offset = if should_add_offset {
181            if center < min {
182                center - min
183            } else {
184                center - max
185            }
186        } else {
187            0.0
188        };
189
190        MiddlewareReturn {
191            x: match axis {
192                Axis::X => Some(coords.axis(axis) + alignment_offset),
193                Axis::Y => None,
194            },
195            y: match axis {
196                Axis::X => None,
197                Axis::Y => Some(coords.axis(axis) + alignment_offset),
198            },
199            data: Some(
200                serde_json::to_value(ArrowData {
201                    x: match axis {
202                        Axis::X => Some(offset),
203                        Axis::Y => None,
204                    },
205                    y: match axis {
206                        Axis::X => None,
207                        Axis::Y => Some(offset),
208                    },
209                    center_offset: center - offset - alignment_offset,
210                    alignment_offset: should_add_offset.then_some(alignment_offset),
211                })
212                .expect("Data should be valid JSON."),
213            ),
214            reset: None,
215        }
216    }
217}
218
219impl<Element: Clone, Window: Clone> MiddlewareWithOptions<Element, Window, ArrowOptions<Element>>
220    for Arrow<'_, Element, Window>
221{
222    fn options(&self) -> &Derivable<'_, Element, Window, ArrowOptions<Element>> {
223        &self.options
224    }
225}