// floating_ui_core/middleware/arrow.rs
use floating_ui_utils::{
    Axis, Coords, OwnedElementOrWindow, Padding, Side, clamp, get_alignment, get_alignment_axis,
    get_axis_length, get_padding_object,
};
use serde::{Deserialize, Serialize};

use crate::types::{
    Derivable, DerivableFn, Middleware, MiddlewareReturn, MiddlewareState, MiddlewareWithOptions,
    Reset,
};

/// Name of the [`Arrow`] middleware.
///
/// It is also the key under which the middleware reads its own data back
/// from the middleware data.
pub const ARROW_NAME: &str = "arrow";

14#[derive(Clone, Debug, PartialEq)]
16pub struct ArrowOptions<Element: Clone> {
17 pub element: Element,
19
20 pub padding: Option<Padding>,
25}
26
27impl<Element: Clone> ArrowOptions<Element> {
28 pub fn new(element: Element) -> Self {
29 ArrowOptions {
30 element,
31 padding: None,
32 }
33 }
34
35 pub fn element(mut self, value: Element) -> Self {
37 self.element = value;
38 self
39 }
40
41 pub fn padding(mut self, value: Padding) -> Self {
43 self.padding = Some(value);
44 self
45 }
46}
47
48#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
50pub struct ArrowData {
51 pub x: Option<f64>,
52 pub y: Option<f64>,
53 pub center_offset: f64,
54 pub alignment_offset: Option<f64>,
55}
56
57#[derive(PartialEq)]
63pub struct Arrow<'a, Element: Clone + 'static, Window: Clone> {
64 options: Derivable<'a, Element, Window, ArrowOptions<Element>>,
65}
66
67impl<'a, Element: Clone + 'static, Window: Clone> Arrow<'a, Element, Window> {
68 pub fn new(options: ArrowOptions<Element>) -> Self {
70 Arrow {
71 options: options.into(),
72 }
73 }
74
75 pub fn new_derivable(options: Derivable<'a, Element, Window, ArrowOptions<Element>>) -> Self {
77 Arrow { options }
78 }
79
80 pub fn new_derivable_fn(
82 options: DerivableFn<'a, Element, Window, ArrowOptions<Element>>,
83 ) -> Self {
84 Arrow {
85 options: options.into(),
86 }
87 }
88}
89
90impl<Element: Clone + 'static, Window: Clone> Clone for Arrow<'_, Element, Window> {
91 fn clone(&self) -> Self {
92 Self {
93 options: self.options.clone(),
94 }
95 }
96}
97
98impl<Element: Clone + PartialEq, Window: Clone + PartialEq> Middleware<Element, Window>
99 for Arrow<'static, Element, Window>
100{
101 fn name(&self) -> &'static str {
102 ARROW_NAME
103 }
104
105 fn compute(&self, state: MiddlewareState<Element, Window>) -> MiddlewareReturn {
106 let options = self.options.evaluate(state.clone());
107
108 let MiddlewareState {
109 x,
110 y,
111 placement,
112 middleware_data,
113 elements,
114 rects,
115 platform,
116 ..
117 } = state;
118
119 let data: Option<ArrowData> = middleware_data.get_as(self.name());
120
121 let padding_object = get_padding_object(options.padding.unwrap_or(Padding::All(0.0)));
122 let coords = Coords { x, y };
123 let axis = get_alignment_axis(placement);
124 let length = get_axis_length(axis);
125 let arrow_dimensions = platform.get_dimensions(&options.element);
126 let min_prop = match axis {
127 Axis::X => Side::Left,
128 Axis::Y => Side::Top,
129 };
130 let max_prop = match axis {
131 Axis::X => Side::Right,
132 Axis::Y => Side::Bottom,
133 };
134
135 let start_diff = coords.axis(axis) - rects.reference.axis(axis);
136 let end_diff = rects.reference.length(length) + rects.reference.axis(axis)
137 - coords.axis(axis)
138 - rects.floating.length(length);
139
140 let arrow_offset_parent = platform.get_offset_parent(&options.element);
141 let client_size = arrow_offset_parent
142 .and_then(|arrow_offset_parent| match arrow_offset_parent {
143 OwnedElementOrWindow::Element(element) => {
144 platform.get_client_length(&element, length)
145 }
146 OwnedElementOrWindow::Window(_) => {
147 platform.get_client_length(elements.floating, length)
148 }
149 })
150 .unwrap_or(rects.floating.length(length));
151
152 let center_to_reference = end_diff / 2.0 - start_diff / 2.0;
153
154 let largest_possible_padding =
156 client_size / 2.0 - arrow_dimensions.length(length) / 2.0 - 1.0;
157 let min_padding = padding_object.side(min_prop).min(largest_possible_padding);
158 let max_padding = padding_object.side(max_prop).min(largest_possible_padding);
159
160 let min = min_padding;
162 let max = client_size - arrow_dimensions.length(length) - max_padding;
163 let center =
164 client_size / 2.0 - arrow_dimensions.length(length) / 2.0 + center_to_reference;
165 let offset = clamp(min, center, max);
166
167 let should_add_offset = data.is_none()
170 && get_alignment(placement).is_some()
171 && center != offset
172 && rects.reference.length(length) / 2.0
173 - (if center < min {
174 min_padding
175 } else {
176 max_padding
177 })
178 - arrow_dimensions.length(length) / 2.0
179 < 0.0;
180 let alignment_offset = if should_add_offset {
181 if center < min {
182 center - min
183 } else {
184 center - max
185 }
186 } else {
187 0.0
188 };
189
190 MiddlewareReturn {
191 x: match axis {
192 Axis::X => Some(coords.axis(axis) + alignment_offset),
193 Axis::Y => None,
194 },
195 y: match axis {
196 Axis::X => None,
197 Axis::Y => Some(coords.axis(axis) + alignment_offset),
198 },
199 data: Some(
200 serde_json::to_value(ArrowData {
201 x: match axis {
202 Axis::X => Some(offset),
203 Axis::Y => None,
204 },
205 y: match axis {
206 Axis::X => None,
207 Axis::Y => Some(offset),
208 },
209 center_offset: center - offset - alignment_offset,
210 alignment_offset: should_add_offset.then_some(alignment_offset),
211 })
212 .expect("Data should be valid JSON."),
213 ),
214 reset: None,
215 }
216 }
217}
218
219impl<Element: Clone, Window: Clone> MiddlewareWithOptions<Element, Window, ArrowOptions<Element>>
220 for Arrow<'_, Element, Window>
221{
222 fn options(&self) -> &Derivable<'_, Element, Window, ArrowOptions<Element>> {
223 &self.options
224 }
225}