floating_ui_core/middleware/
flip.rs1use floating_ui_utils::{
2    Alignment, Axis, Placement, get_alignment, get_alignment_sides, get_expanded_placements,
3    get_opposite_axis_placements, get_opposite_placement, get_side, get_side_axis,
4};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    detect_overflow::{DetectOverflowOptions, detect_overflow},
9    middleware::arrow::{ARROW_NAME, ArrowData},
10    types::{
11        Derivable, DerivableFn, Middleware, MiddlewareReturn, MiddlewareState,
12        MiddlewareWithOptions, Reset, ResetValue,
13    },
14};
15
/// Name of the flip middleware.
///
/// Returned by the middleware's `name()` and used as the key under which its
/// [`FlipData`] is read back from the middleware data.
pub const FLIP_NAME: &str = "flip";
18
/// What the flip middleware does once every candidate placement has been tried
/// and none of them fits on the main axis.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum FallbackStrategy {
    /// Pick the tried placement whose positive overflows add up to the smallest total.
    #[default]
    BestFit,
    /// Go back to the initial placement.
    InitialPlacement,
}
26
27#[derive(Clone, Debug, PartialEq)]
29pub struct FlipOptions<Element: Clone> {
30    pub detect_overflow: Option<DetectOverflowOptions<Element>>,
34
35    pub main_axis: Option<bool>,
39
40    pub cross_axis: Option<bool>,
44
45    pub fallback_placements: Option<Vec<Placement>>,
49
50    pub fallback_strategy: Option<FallbackStrategy>,
54
55    pub fallback_axis_side_direction: Option<Alignment>,
59
60    pub flip_alignment: Option<bool>,
64}
65
66impl<Element: Clone> FlipOptions<Element> {
67    pub fn detect_overflow(mut self, value: DetectOverflowOptions<Element>) -> Self {
69        self.detect_overflow = Some(value);
70        self
71    }
72
73    pub fn main_axis(mut self, value: bool) -> Self {
75        self.main_axis = Some(value);
76        self
77    }
78
79    pub fn cross_axis(mut self, value: bool) -> Self {
81        self.cross_axis = Some(value);
82        self
83    }
84
85    pub fn fallback_placements(mut self, value: Vec<Placement>) -> Self {
87        self.fallback_placements = Some(value);
88        self
89    }
90
91    pub fn fallback_strategy(mut self, value: FallbackStrategy) -> Self {
93        self.fallback_strategy = Some(value);
94        self
95    }
96
97    pub fn fallback_axis_side_direction(mut self, value: Alignment) -> Self {
99        self.fallback_axis_side_direction = Some(value);
100        self
101    }
102
103    pub fn flip_alignment(mut self, value: bool) -> Self {
105        self.flip_alignment = Some(value);
106        self
107    }
108}
109
110impl<Element: Clone> Default for FlipOptions<Element> {
111    fn default() -> Self {
112        Self {
113            detect_overflow: Default::default(),
114            main_axis: Default::default(),
115            cross_axis: Default::default(),
116            fallback_placements: Default::default(),
117            fallback_strategy: Default::default(),
118            fallback_axis_side_direction: Default::default(),
119            flip_alignment: Default::default(),
120        }
121    }
122}
123
124#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
126pub struct FlipDataOverflow {
127    pub placement: Placement,
128    pub overflows: Vec<f64>,
129}
130
131#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
133pub struct FlipData {
134    pub index: usize,
135    pub overflows: Vec<FlipDataOverflow>,
136}
137
138#[derive(PartialEq)]
145pub struct Flip<'a, Element: Clone + 'static, Window: Clone> {
146    options: Derivable<'a, Element, Window, FlipOptions<Element>>,
147}
148
149impl<'a, Element: Clone + 'static, Window: Clone> Flip<'a, Element, Window> {
150    pub fn new(options: FlipOptions<Element>) -> Self {
152        Flip {
153            options: options.into(),
154        }
155    }
156
157    pub fn new_derivable(options: Derivable<'a, Element, Window, FlipOptions<Element>>) -> Self {
159        Flip { options }
160    }
161
162    pub fn new_derivable_fn(
164        options: DerivableFn<'a, Element, Window, FlipOptions<Element>>,
165    ) -> Self {
166        Flip {
167            options: options.into(),
168        }
169    }
170}
171
172impl<Element: Clone + 'static, Window: Clone> Clone for Flip<'_, Element, Window> {
173    fn clone(&self) -> Self {
174        Self {
175            options: self.options.clone(),
176        }
177    }
178}
179
180impl<Element: Clone + PartialEq, Window: Clone + PartialEq> Middleware<Element, Window>
181    for Flip<'static, Element, Window>
182{
183    fn name(&self) -> &'static str {
184        FLIP_NAME
185    }
186
187    fn compute(&self, state: MiddlewareState<Element, Window>) -> MiddlewareReturn {
188        let options = self.options.evaluate(state.clone());
189
190        let MiddlewareState {
191            placement,
192            initial_placement,
193            middleware_data,
194            elements,
195            rects,
196            platform,
197            ..
198        } = state;
199
200        let data: FlipData = middleware_data.get_as(self.name()).unwrap_or(FlipData {
201            index: 0,
202            overflows: vec![],
203        });
204
205        let check_main_axis = options.main_axis.unwrap_or(true);
206        let check_cross_axis = options.cross_axis.unwrap_or(true);
207        let specified_fallback_placements = options.fallback_placements.clone();
208        let fallback_strategy = options.fallback_strategy.unwrap_or_default();
209        let fallback_axis_side_direction = options.fallback_axis_side_direction;
210        let flip_alignment = options.flip_alignment.unwrap_or(true);
211
212        let arrow_data: Option<ArrowData> = middleware_data.get_as(ARROW_NAME);
215        if arrow_data
216            .and_then(|arrow_data| arrow_data.alignment_offset)
217            .is_some()
218        {
219            return MiddlewareReturn {
220                x: None,
221                y: None,
222                data: None,
223                reset: None,
224            };
225        }
226
227        let side = get_side(placement);
228        let initial_side_axis = get_side_axis(initial_placement);
229        let is_base_placement = get_alignment(initial_placement).is_none();
230        let rtl = platform.is_rtl(elements.floating);
231
232        let has_specified_fallback_placements = specified_fallback_placements.is_some();
233        let mut placements =
234            specified_fallback_placements.unwrap_or(if is_base_placement || !flip_alignment {
235                vec![get_opposite_placement(initial_placement)]
236            } else {
237                get_expanded_placements(initial_placement)
238            });
239
240        let has_fallback_axis_side_direction = fallback_axis_side_direction.is_some();
241
242        if !has_specified_fallback_placements && has_fallback_axis_side_direction {
243            placements.append(&mut get_opposite_axis_placements(
244                initial_placement,
245                flip_alignment,
246                fallback_axis_side_direction,
247                rtl,
248            ));
249        }
250
251        placements.insert(0, initial_placement);
252
253        let overflow = detect_overflow(
254            MiddlewareState {
255                elements: elements.clone(),
256                ..state
257            },
258            options.detect_overflow.unwrap_or_default(),
259        );
260
261        let mut overflows: Vec<f64> = Vec::new();
262        let mut overflows_data = data.overflows;
263
264        if check_main_axis {
265            overflows.push(overflow.side(side));
266        }
267        if check_cross_axis {
268            let sides = get_alignment_sides(placement, rects, rtl);
269            overflows.push(overflow.side(sides.0));
270            overflows.push(overflow.side(sides.1));
271        }
272
273        overflows_data.push(FlipDataOverflow {
274            placement,
275            overflows: overflows.clone(),
276        });
277
278        if !overflows.into_iter().all(|side| side <= 0.0) {
280            let next_index = data.index + 1;
281            let next_placement = placements.get(next_index);
282
283            if let Some(next_placement) = next_placement {
284                return MiddlewareReturn {
286                    x: None,
287                    y: None,
288                    data: Some(
289                        serde_json::to_value(FlipData {
290                            index: next_index,
291                            overflows: overflows_data,
292                        })
293                        .expect("Data should be valid JSON."),
294                    ),
295                    reset: Some(Reset::Value(ResetValue {
296                        placement: Some(*next_placement),
297                        rects: None,
298                    })),
299                };
300            }
301
302            let mut reset_placement: Vec<&FlipDataOverflow> = overflows_data
304                .iter()
305                .filter(|overflow| overflow.overflows[0] <= 0.0)
306                .collect();
307            reset_placement.sort_by(|a, b| a.overflows[1].total_cmp(&b.overflows[1]));
308
309            let mut reset_placement = reset_placement.first().map(|overflow| overflow.placement);
310
311            if reset_placement.is_none() {
313                match fallback_strategy {
314                    FallbackStrategy::BestFit => {
315                        let mut placement: Vec<(Placement, f64)> = overflows_data
316                            .into_iter()
317                            .filter(|overflow| {
318                                if has_fallback_axis_side_direction {
319                                    let current_side_axis = get_side_axis(overflow.placement);
320
321                                    current_side_axis == initial_side_axis
323                                        || current_side_axis == Axis::Y
324                                } else {
325                                    true
326                                }
327                            })
328                            .map(|overflow| {
329                                (
330                                    overflow.placement,
331                                    overflow
332                                        .overflows
333                                        .into_iter()
334                                        .filter(|overflow| *overflow > 0.0)
335                                        .sum::<f64>(),
336                                )
337                            })
338                            .collect();
339                        placement.sort_by(|a, b| a.1.total_cmp(&b.1));
340
341                        let placement = placement.first().map(|v| v.0);
342                        if placement.is_some() {
343                            reset_placement = placement;
344                        }
345                    }
346                    FallbackStrategy::InitialPlacement => {
347                        reset_placement = Some(initial_placement);
348                    }
349                }
350            }
351
352            if placement != reset_placement.expect("Reset placement is not none.") {
353                return MiddlewareReturn {
354                    x: None,
355                    y: None,
356                    data: None,
357                    reset: Some(Reset::Value(ResetValue {
358                        placement: reset_placement,
359                        rects: None,
360                    })),
361                };
362            }
363        }
364
365        MiddlewareReturn {
366            x: None,
367            y: None,
368            data: None,
369            reset: None,
370        }
371    }
372}
373
374impl<Element: Clone, Window: Clone> MiddlewareWithOptions<Element, Window, FlipOptions<Element>>
375    for Flip<'_, Element, Window>
376{
377    fn options(&self) -> &Derivable<Element, Window, FlipOptions<Element>> {
378        &self.options
379    }
380}