floating_ui_core/middleware/
auto_placement.rs

1use floating_ui_utils::{
2    ALL_PLACEMENTS, Alignment, Placement, get_alignment, get_alignment_sides,
3    get_opposite_alignment_placement, get_side,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    detect_overflow::{DetectOverflowOptions, detect_overflow},
9    types::{
10        Derivable, DerivableFn, Middleware, MiddlewareReturn, MiddlewareState,
11        MiddlewareWithOptions, Reset, ResetValue,
12    },
13};
14
15fn get_placement_list(
16    alignment: Option<Alignment>,
17    auto_alignment: bool,
18    allowed_placements: Vec<Placement>,
19) -> Vec<Placement> {
20    let allowed_placements_sorted_by_alignment: Vec<Placement> = match alignment {
21        Some(alignment) => {
22            let mut list = vec![];
23
24            list.append(
25                &mut allowed_placements
26                    .clone()
27                    .into_iter()
28                    .filter(|placement| get_alignment(*placement) == Some(alignment))
29                    .collect(),
30            );
31
32            list.append(
33                &mut allowed_placements
34                    .clone()
35                    .into_iter()
36                    .filter(|placement| get_alignment(*placement) != Some(alignment))
37                    .collect(),
38            );
39
40            list
41        }
42        None => allowed_placements
43            .into_iter()
44            .filter(|placement| get_alignment(*placement).is_none())
45            .collect(),
46    };
47
48    allowed_placements_sorted_by_alignment
49        .into_iter()
50        .filter(|placement| match alignment {
51            Some(alignment) => {
52                get_alignment(*placement) == Some(alignment)
53                    || (if auto_alignment {
54                        get_opposite_alignment_placement(*placement) != *placement
55                    } else {
56                        false
57                    })
58            }
59            None => true,
60        })
61        .collect()
62}
63
/// Identifier of the [`AutoPlacement`] middleware, also used as the key for its stored data.
pub const AUTO_PLACEMENT_NAME: &str = "autoPlacement";
66
67/// Options for [`AutoPlacement`] middleware.
68#[derive(Clone, Debug, PartialEq)]
69pub struct AutoPlacementOptions<Element: Clone> {
70    /// Options for [`detect_overflow`].
71    ///
72    /// Defaults to [`DetectOverflowOptions::default`].
73    pub detect_overflow: Option<DetectOverflowOptions<Element>>,
74
75    /// The axis that runs along the alignment of the floating element. Determines whether to check for most space along this axis.
76    ///
77    /// Defaults to `false`.
78    pub cross_axis: Option<bool>,
79
80    /// Choose placements with a particular alignment.
81    ///
82    /// Defaults to [`Option::None`].
83    pub alignment: Option<Alignment>,
84
85    /// Whether to choose placements with the opposite alignment if the preferred alignment does not fit.
86    ///
87    /// Defaults to `true`.
88    pub auto_alignment: Option<bool>,
89
90    /// Which placements are allowed to be chosen. Placements must be within the [`alignment`][`Self::alignment`] option if explicitly set.
91    ///
92    /// Defaults to all possible placements.
93    pub allowed_placements: Option<Vec<Placement>>,
94}
95
96impl<Element: Clone> AutoPlacementOptions<Element> {
97    /// Set `detect_overflow` option.
98    pub fn detect_overflow(mut self, value: DetectOverflowOptions<Element>) -> Self {
99        self.detect_overflow = Some(value);
100        self
101    }
102
103    /// Set `cross_axis` option.
104    pub fn cross_axis(mut self, value: bool) -> Self {
105        self.cross_axis = Some(value);
106        self
107    }
108
109    /// Set `alignment` option.
110    pub fn alignment(mut self, value: Alignment) -> Self {
111        self.alignment = Some(value);
112        self
113    }
114
115    /// Set `auto_alignment` option.
116    pub fn auto_alignment(mut self, value: bool) -> Self {
117        self.auto_alignment = Some(value);
118        self
119    }
120
121    /// Set `alignment` option.
122    pub fn allowed_placements(mut self, value: Vec<Placement>) -> Self {
123        self.allowed_placements = Some(value);
124        self
125    }
126}
127
128impl<Element: Clone> Default for AutoPlacementOptions<Element> {
129    fn default() -> Self {
130        Self {
131            detect_overflow: Default::default(),
132            cross_axis: Default::default(),
133            alignment: Default::default(),
134            auto_alignment: Default::default(),
135            allowed_placements: Default::default(),
136        }
137    }
138}
139
140/// An overflow stored in [`AutoPlacementData`].
141#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
142pub struct AutoPlacementDataOverflow {
143    pub placement: Placement,
144    pub overflows: Vec<f64>,
145}
146
147/// Data stored by [`AutoPlacement`] middleware.
148#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
149pub struct AutoPlacementData {
150    pub index: usize,
151    pub overflows: Vec<AutoPlacementDataOverflow>,
152}
153
154/// Auto placement middleware.
155///
156/// Optimizes the visibility of the floating element by choosing the placement that has the most space available automatically, without needing to specify a preferred placement.
157/// Alternative to [`Flip`][`crate::middleware::Flip`].
158///
159/// See [the Rust Floating UI book](https://floating-ui.rustforweb.org/middleware/auto-placement.html) for more documentation.
160#[derive(PartialEq)]
161pub struct AutoPlacement<'a, Element: Clone + 'static, Window: Clone> {
162    options: Derivable<'a, Element, Window, AutoPlacementOptions<Element>>,
163}
164
165impl<Element: Clone + 'static, Window: Clone> Clone for AutoPlacement<'_, Element, Window> {
166    fn clone(&self) -> Self {
167        Self {
168            options: self.options.clone(),
169        }
170    }
171}
172
173impl<'a, Element: Clone + 'static, Window: Clone> AutoPlacement<'a, Element, Window> {
174    /// Constructs a new instance of this middleware.
175    pub fn new(options: AutoPlacementOptions<Element>) -> Self {
176        AutoPlacement {
177            options: options.into(),
178        }
179    }
180
181    /// Constructs a new instance of this middleware with derivable options.
182    pub fn new_derivable(
183        options: Derivable<'a, Element, Window, AutoPlacementOptions<Element>>,
184    ) -> Self {
185        AutoPlacement { options }
186    }
187
188    /// Constructs a new instance of this middleware with derivable options function.
189    pub fn new_derivable_fn(
190        options: DerivableFn<'a, Element, Window, AutoPlacementOptions<Element>>,
191    ) -> Self {
192        AutoPlacement {
193            options: options.into(),
194        }
195    }
196}
197
198impl<Element: Clone + PartialEq, Window: Clone + PartialEq> Middleware<Element, Window>
199    for AutoPlacement<'static, Element, Window>
200{
201    fn name(&self) -> &'static str {
202        AUTO_PLACEMENT_NAME
203    }
204
205    fn compute(&self, state: MiddlewareState<Element, Window>) -> MiddlewareReturn {
206        let options = self.options.evaluate(state.clone());
207
208        let MiddlewareState {
209            rects,
210            middleware_data,
211            placement,
212            platform,
213            elements,
214            ..
215        } = state;
216
217        let data: AutoPlacementData =
218            middleware_data
219                .get_as(self.name())
220                .unwrap_or(AutoPlacementData {
221                    index: 0,
222                    overflows: vec![],
223                });
224
225        let cross_axis = options.cross_axis.unwrap_or(false);
226        let alignment = options.alignment;
227        let has_allowed_placements = options.allowed_placements.is_some();
228        let allowed_placements = options
229            .allowed_placements
230            .unwrap_or(Vec::from(ALL_PLACEMENTS));
231        let auto_alignment = options.auto_alignment.unwrap_or(true);
232
233        let placements: Vec<Placement> = if alignment.is_some() || !has_allowed_placements {
234            get_placement_list(alignment, auto_alignment, allowed_placements)
235        } else {
236            allowed_placements
237        };
238
239        let overflow = detect_overflow(
240            MiddlewareState {
241                elements: elements.clone(),
242                ..state
243            },
244            options.detect_overflow.unwrap_or_default(),
245        );
246
247        let current_index = data.index;
248        let current_placement = placements.get(current_index);
249
250        if let Some(current_placement) = current_placement {
251            let current_placement = *current_placement;
252
253            let alignment_sides =
254                get_alignment_sides(current_placement, rects, platform.is_rtl(elements.floating));
255
256            // Make `compute_coords` start from the right place.
257            if placement != current_placement {
258                return MiddlewareReturn {
259                    x: None,
260                    y: None,
261                    data: None,
262                    reset: Some(Reset::Value(ResetValue {
263                        placement: Some(placements[0]),
264                        rects: None,
265                    })),
266                };
267            }
268
269            let current_overflows = vec![
270                overflow.side(get_side(current_placement)),
271                overflow.side(alignment_sides.0),
272                overflow.side(alignment_sides.1),
273            ];
274
275            let mut all_overflows = data.overflows.clone();
276            all_overflows.push(AutoPlacementDataOverflow {
277                placement,
278                overflows: current_overflows,
279            });
280
281            let next_placement = placements.get(current_index + 1);
282
283            // There are more placements to check.
284            if let Some(next_placement) = next_placement {
285                return MiddlewareReturn {
286                    x: None,
287                    y: None,
288                    data: Some(
289                        serde_json::to_value(AutoPlacementData {
290                            index: current_index + 1,
291                            overflows: all_overflows.clone(),
292                        })
293                        .expect("Data should be valid JSON."),
294                    ),
295                    reset: Some(Reset::Value(ResetValue {
296                        placement: Some(*next_placement),
297                        rects: None,
298                    })),
299                };
300            }
301
302            let mut placements_sorted_by_most_space: Vec<_> = all_overflows
303                .clone()
304                .into_iter()
305                .map(|overflow| {
306                    let alignment = get_alignment(overflow.placement);
307
308                    (
309                        overflow.placement,
310                        if alignment.is_some() && cross_axis {
311                            // Check along the main axis and main cross axis side.
312                            overflow.overflows[0..2].iter().sum()
313                        } else {
314                            // Check only the main axis.
315                            overflow.overflows[0]
316                        },
317                        overflow.overflows,
318                    )
319                })
320                .collect();
321
322            placements_sorted_by_most_space.sort_by(|a, b| a.1.total_cmp(&b.1));
323
324            let placements_that_fit_on_each_side: Vec<_> = placements_sorted_by_most_space
325                .clone()
326                .into_iter()
327                .filter(|overflow| {
328                    // Aligned placements should not check their opposite cross axis side.
329                    overflow.2[0..match get_alignment(overflow.0) {
330                        Some(_) => 2,
331                        None => 3,
332                    }]
333                        .iter()
334                        .all(|v| *v <= 0.0)
335                })
336                .collect();
337
338            let reset_placement = placements_that_fit_on_each_side
339                .first()
340                .map(|v| v.0)
341                .unwrap_or(placements_sorted_by_most_space[0].0);
342
343            if reset_placement != placement {
344                return MiddlewareReturn {
345                    x: None,
346                    y: None,
347                    data: Some(
348                        serde_json::to_value(AutoPlacementData {
349                            index: current_index + 1,
350                            overflows: all_overflows,
351                        })
352                        .expect("Data should be valid JSON."),
353                    ),
354                    reset: Some(Reset::Value(ResetValue {
355                        placement: Some(reset_placement),
356                        rects: None,
357                    })),
358                };
359            }
360        }
361
362        MiddlewareReturn {
363            x: None,
364            y: None,
365            data: None,
366            reset: None,
367        }
368    }
369}
370
371impl<Element: Clone, Window: Clone>
372    MiddlewareWithOptions<Element, Window, AutoPlacementOptions<Element>>
373    for AutoPlacement<'_, Element, Window>
374{
375    fn options(&self) -> &Derivable<'_, Element, Window, AutoPlacementOptions<Element>> {
376        &self.options
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Mix of base and aligned placements shared by the alignment tests.
    fn mixed_placements() -> Vec<Placement> {
        vec![
            Placement::Top,
            Placement::Bottom,
            Placement::Left,
            Placement::Right,
            Placement::TopStart,
            Placement::RightEnd,
            Placement::LeftStart,
        ]
    }

    #[test]
    fn test_base_placement() {
        let allowed = vec![
            Placement::Top,
            Placement::Bottom,
            Placement::Left,
            Placement::Right,
            Placement::TopStart,
            Placement::RightEnd,
        ];
        let expected = vec![
            Placement::Top,
            Placement::Bottom,
            Placement::Left,
            Placement::Right,
        ];

        assert_eq!(get_placement_list(None, false, allowed), expected)
    }

    #[test]
    fn test_start_alignment_without_auto_alignment() {
        let actual = get_placement_list(Some(Alignment::Start), false, mixed_placements());

        assert_eq!(actual, vec![Placement::TopStart, Placement::LeftStart])
    }

    #[test]
    fn test_start_alignment_with_auto_alignment() {
        let actual = get_placement_list(Some(Alignment::Start), true, mixed_placements());

        assert_eq!(
            actual,
            vec![
                Placement::TopStart,
                Placement::LeftStart,
                Placement::RightEnd,
            ]
        )
    }

    #[test]
    fn test_end_alignment_without_auto_alignment() {
        let actual = get_placement_list(Some(Alignment::End), false, mixed_placements());

        assert_eq!(actual, vec![Placement::RightEnd])
    }

    #[test]
    fn test_end_alignment_with_auto_alignment() {
        let actual = get_placement_list(Some(Alignment::End), true, mixed_placements());

        assert_eq!(
            actual,
            vec![
                Placement::RightEnd,
                Placement::TopStart,
                Placement::LeftStart,
            ]
        )
    }
}