floating_ui_core/middleware/shift.rs

1use std::fmt::Debug;
2
3use dyn_derive::dyn_trait;
4use floating_ui_utils::{Axis, Coords, Side, clamp, get_opposite_axis, get_side_axis};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    detect_overflow::{DetectOverflowOptions, detect_overflow},
9    middleware::{OFFSET_NAME, OffsetData},
10    types::{
11        Derivable, DerivableFn, Middleware, MiddlewareReturn, MiddlewareState,
12        MiddlewareWithOptions,
13    },
14};
15
/// Name of the [`Shift`] middleware, as reported by its `Middleware::name` implementation.
pub const SHIFT_NAME: &str = "shift";
18
19/// Limiter used by [`Shift`] middleware. Limits the shifting done in order to prevent detachment.
20#[dyn_trait]
21pub trait Limiter<Element: Clone + 'static, Window: Clone + 'static>: Clone + PartialEq {
22    fn compute(&self, state: MiddlewareState<Element, Window>) -> Coords;
23}
24
25/// Options for [`Shift`] middleware.
26#[derive(Clone, PartialEq)]
27pub struct ShiftOptions<Element: Clone + 'static, Window: Clone + 'static> {
28    /// Options for [`detect_overflow`].
29    ///
30    /// Defaults to [`DetectOverflowOptions::default`].
31    pub detect_overflow: Option<DetectOverflowOptions<Element>>,
32
33    /// The axis that runs along the alignment of the floating element. Determines whether overflow along this axis is checked to perform shifting.
34    ///
35    /// Defaults to `true`.
36    pub main_axis: Option<bool>,
37
38    /// The axis that runs along the side of the floating element. Determines whether overflow along this axis is checked to perform shifting.
39    ///
40    /// Defaults to `false`.
41    pub cross_axis: Option<bool>,
42
43    /// Accepts a limiter that limits the shifting done in order to prevent detachment.
44    ///
45    /// Defaults to [`DefaultLimiter`].
46    pub limiter: Option<Box<dyn Limiter<Element, Window>>>,
47}
48
49impl<Element: Clone, Window: Clone> ShiftOptions<Element, Window> {
50    /// Set `detect_overflow` option.
51    pub fn detect_overflow(mut self, value: DetectOverflowOptions<Element>) -> Self {
52        self.detect_overflow = Some(value);
53        self
54    }
55
56    /// Set `main_axis` option.
57    pub fn main_axis(mut self, value: bool) -> Self {
58        self.main_axis = Some(value);
59        self
60    }
61
62    /// Set `cross_axis` option.
63    pub fn cross_axis(mut self, value: bool) -> Self {
64        self.cross_axis = Some(value);
65        self
66    }
67
68    /// Set `limiter` option.
69    pub fn limiter(mut self, value: Box<dyn Limiter<Element, Window>>) -> Self {
70        self.limiter = Some(value);
71        self
72    }
73}
74
75impl<Element: Clone, Window: Clone> Default for ShiftOptions<Element, Window> {
76    fn default() -> Self {
77        Self {
78            detect_overflow: Default::default(),
79            main_axis: Default::default(),
80            cross_axis: Default::default(),
81            limiter: Default::default(),
82        }
83    }
84}
85
86/// Enabled sides stored in [`ShiftData`].
87#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
88pub struct ShiftDataEnabled {
89    pub x: bool,
90    pub y: bool,
91}
92
93impl ShiftDataEnabled {
94    pub fn set_axis(mut self, axis: Axis, enabled: bool) -> Self {
95        match axis {
96            Axis::X => {
97                self.x = enabled;
98            }
99            Axis::Y => {
100                self.y = enabled;
101            }
102        }
103        self
104    }
105}
106
107/// Data stored by [`Shift`] middleware.
108#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
109pub struct ShiftData {
110    pub x: f64,
111    pub y: f64,
112    pub enabled: ShiftDataEnabled,
113}
114
115/// Shift middleware.
116///
117/// Optimizes the visibility of the floating element by shifting it in order to keep it in view when it will overflow the clipping boundary.
118///
119/// See [the Rust Floating UI book](https://floating-ui.rustforweb.org/middleware/shift.html) for more documentation.
120#[derive(PartialEq)]
121pub struct Shift<'a, Element: Clone + 'static, Window: Clone + 'static> {
122    options: Derivable<'a, Element, Window, ShiftOptions<Element, Window>>,
123}
124
125impl<'a, Element: Clone, Window: Clone> Shift<'a, Element, Window> {
126    /// Constructs a new instance of this middleware.
127    pub fn new(options: ShiftOptions<Element, Window>) -> Self {
128        Shift {
129            options: options.into(),
130        }
131    }
132
133    /// Constructs a new instance of this middleware with derivable options.
134    pub fn new_derivable(
135        options: Derivable<'a, Element, Window, ShiftOptions<Element, Window>>,
136    ) -> Self {
137        Shift { options }
138    }
139
140    /// Constructs a new instance of this middleware with derivable options function.
141    pub fn new_derivable_fn(
142        options: DerivableFn<'a, Element, Window, ShiftOptions<Element, Window>>,
143    ) -> Self {
144        Shift {
145            options: options.into(),
146        }
147    }
148}
149
150impl<Element: Clone, Window: Clone> Clone for Shift<'_, Element, Window> {
151    fn clone(&self) -> Self {
152        Self {
153            options: self.options.clone(),
154        }
155    }
156}
157
158impl<Element: Clone + PartialEq + 'static, Window: Clone + PartialEq + 'static>
159    Middleware<Element, Window> for Shift<'static, Element, Window>
160{
161    fn name(&self) -> &'static str {
162        SHIFT_NAME
163    }
164
165    fn compute(&self, state: MiddlewareState<Element, Window>) -> MiddlewareReturn {
166        let options = self.options.evaluate(state.clone());
167
168        let MiddlewareState {
169            x, y, placement, ..
170        } = state;
171
172        let check_main_axis = options.main_axis.unwrap_or(true);
173        let check_cross_axis = options.cross_axis.unwrap_or(false);
174        #[allow(clippy::unwrap_or_default)]
175        let limiter = options.limiter.unwrap_or(Box::<DefaultLimiter>::default());
176
177        let coords = Coords { x, y };
178        let overflow = detect_overflow(
179            MiddlewareState {
180                elements: state.elements.clone(),
181                ..state
182            },
183            options.detect_overflow.unwrap_or_default(),
184        );
185        let cross_axis = get_side_axis(placement);
186        let main_axis = get_opposite_axis(cross_axis);
187
188        let mut main_axis_coord = coords.axis(main_axis);
189        let mut cross_axis_coord = coords.axis(cross_axis);
190
191        if check_main_axis {
192            let min_side = match main_axis {
193                Axis::X => Side::Left,
194                Axis::Y => Side::Top,
195            };
196            let max_side = match main_axis {
197                Axis::X => Side::Right,
198                Axis::Y => Side::Bottom,
199            };
200            let min = main_axis_coord + overflow.side(min_side);
201            let max = main_axis_coord - overflow.side(max_side);
202
203            main_axis_coord = clamp(min, main_axis_coord, max);
204        }
205
206        if check_cross_axis {
207            let min_side = match cross_axis {
208                Axis::X => Side::Left,
209                Axis::Y => Side::Top,
210            };
211            let max_side = match cross_axis {
212                Axis::X => Side::Right,
213                Axis::Y => Side::Bottom,
214            };
215            let min = cross_axis_coord + overflow.side(min_side);
216            let max = cross_axis_coord - overflow.side(max_side);
217
218            cross_axis_coord = clamp(min, cross_axis_coord, max);
219        }
220
221        let limited_coords = limiter.compute(MiddlewareState {
222            x: match main_axis {
223                Axis::X => main_axis_coord,
224                Axis::Y => cross_axis_coord,
225            },
226            y: match main_axis {
227                Axis::X => cross_axis_coord,
228                Axis::Y => main_axis_coord,
229            },
230            ..state
231        });
232
233        MiddlewareReturn {
234            x: Some(limited_coords.x),
235            y: Some(limited_coords.y),
236            data: Some(
237                serde_json::to_value(ShiftData {
238                    x: limited_coords.x - x,
239                    y: limited_coords.y - y,
240                    enabled: ShiftDataEnabled::default()
241                        .set_axis(main_axis, check_main_axis)
242                        .set_axis(cross_axis, check_cross_axis),
243                })
244                .expect("Data should be valid JSON."),
245            ),
246            reset: None,
247        }
248    }
249}
250
251impl<Element: Clone, Window: Clone>
252    MiddlewareWithOptions<Element, Window, ShiftOptions<Element, Window>>
253    for Shift<'_, Element, Window>
254{
255    fn options(&self) -> &Derivable<'_, Element, Window, ShiftOptions<Element, Window>> {
256        &self.options
257    }
258}
259
260/// Default [`Limiter`], which doesn't limit shifting.
261#[derive(Clone, Debug, Default, PartialEq)]
262pub struct DefaultLimiter;
263
264impl<Element: Clone + 'static, Window: Clone + 'static> Limiter<Element, Window>
265    for DefaultLimiter
266{
267    fn compute(&self, state: MiddlewareState<Element, Window>) -> Coords {
268        Coords {
269            x: state.x,
270            y: state.y,
271        }
272    }
273}
274
/// Per-axis offset configuration for [`LimitShiftOffset`].
#[derive(Clone, Default, Debug, PartialEq)]
pub struct LimitShiftOffsetValues {
    /// Offset applied to the main-axis limits. Treated as `0` when unset.
    pub main_axis: Option<f64>,

    /// Offset applied to the cross-axis limits. Treated as `0` when unset.
    pub cross_axis: Option<f64>,
}

impl LimitShiftOffsetValues {
    /// Returns these values with `main_axis` set to `value`.
    pub fn main_axis(mut self, value: f64) -> Self {
        self.main_axis = Some(value);
        self
    }

    /// Returns these values with `cross_axis` set to `value`.
    pub fn cross_axis(mut self, value: f64) -> Self {
        self.cross_axis = Some(value);
        self
    }
}
296
297/// Offset configuration for [`LimitShiftOptions`].
298#[derive(Clone, Debug, PartialEq)]
299pub enum LimitShiftOffset {
300    Value(f64),
301    Values(LimitShiftOffsetValues),
302}
303
304impl Default for LimitShiftOffset {
305    fn default() -> Self {
306        LimitShiftOffset::Value(0.0)
307    }
308}
309
310/// Options for [`LimitShift`] limiter.
311#[derive(Clone, PartialEq)]
312pub struct LimitShiftOptions<'a, Element: Clone + 'static, Window: Clone> {
313    pub offset: Option<Derivable<'a, Element, Window, LimitShiftOffset>>,
314
315    pub main_axis: Option<bool>,
316
317    pub cross_axis: Option<bool>,
318}
319
320impl<'a, Element: Clone, Window: Clone> LimitShiftOptions<'a, Element, Window> {
321    /// Set `offset` option.
322    pub fn offset(mut self, value: LimitShiftOffset) -> Self {
323        self.offset = Some(value.into());
324        self
325    }
326
327    /// Set `offset` option with derivable offset.
328    pub fn offset_derivable(
329        mut self,
330        value: Derivable<'a, Element, Window, LimitShiftOffset>,
331    ) -> Self {
332        self.offset = Some(value);
333        self
334    }
335
336    /// Set `offset` option with derivable offset function.
337    pub fn offset_derivable_fn(
338        mut self,
339        value: DerivableFn<'a, Element, Window, LimitShiftOffset>,
340    ) -> Self {
341        self.offset = Some(value.into());
342        self
343    }
344
345    /// Set `main_axis` option.
346    pub fn main_axis(mut self, value: bool) -> Self {
347        self.main_axis = Some(value);
348        self
349    }
350
351    /// Set `cross_axis` option.
352    pub fn cross_axis(mut self, value: bool) -> Self {
353        self.cross_axis = Some(value);
354        self
355    }
356}
357
358impl<Element: Clone + 'static, Window: Clone> Default for LimitShiftOptions<'_, Element, Window> {
359    fn default() -> Self {
360        Self {
361            offset: Default::default(),
362            main_axis: Default::default(),
363            cross_axis: Default::default(),
364        }
365    }
366}
367
368/// Built-in [`Limiter`], that will stop [`Shift`] at a certain point.
369#[derive(Clone, Default, PartialEq)]
370pub struct LimitShift<'a, Element: Clone + 'static, Window: Clone> {
371    options: LimitShiftOptions<'a, Element, Window>,
372}
373
374impl<'a, Element: Clone, Window: Clone> LimitShift<'a, Element, Window> {
375    pub fn new(options: LimitShiftOptions<'a, Element, Window>) -> Self {
376        LimitShift { options }
377    }
378}
379
380impl<Element: Clone + PartialEq, Window: Clone + PartialEq> Limiter<Element, Window>
381    for LimitShift<'static, Element, Window>
382{
383    fn compute(&self, state: MiddlewareState<Element, Window>) -> Coords {
384        let MiddlewareState {
385            x,
386            y,
387            placement,
388            rects,
389            middleware_data,
390            ..
391        } = state;
392
393        let offset = self
394            .options
395            .offset
396            .clone()
397            .unwrap_or(Derivable::Value(LimitShiftOffset::default()));
398        let check_main_axis = self.options.main_axis.unwrap_or(true);
399        let check_cross_axis = self.options.cross_axis.unwrap_or(true);
400
401        let coords = Coords { x, y };
402        let cross_axis = get_side_axis(placement);
403        let main_axis = get_opposite_axis(cross_axis);
404
405        let mut main_axis_coord = coords.axis(main_axis);
406        let mut cross_axis_coord = coords.axis(cross_axis);
407
408        let raw_offset = offset.evaluate(state.clone());
409        let (computed_main_axis, computed_cross_axis) = match raw_offset {
410            LimitShiftOffset::Value(value) => (value, 0.0),
411            LimitShiftOffset::Values(values) => (
412                values.main_axis.unwrap_or(0.0),
413                values.cross_axis.unwrap_or(0.0),
414            ),
415        };
416
417        if check_main_axis {
418            let len = main_axis.length();
419            let limit_min =
420                rects.reference.axis(main_axis) - rects.floating.length(len) + computed_main_axis;
421            let limit_max =
422                rects.reference.axis(main_axis) + rects.reference.length(len) - computed_main_axis;
423
424            main_axis_coord = clamp(limit_min, main_axis_coord, limit_max);
425        }
426
427        if check_cross_axis {
428            let len = main_axis.length();
429            let is_origin_side = match placement.side() {
430                Side::Top | Side::Left => true,
431                Side::Bottom | Side::Right => false,
432            };
433
434            let data: Option<OffsetData> = middleware_data.get_as(OFFSET_NAME);
435            let data_cross_axis = data.map_or(0.0, |data| data.diff_coords.axis(cross_axis));
436
437            let limit_min = rects.reference.axis(cross_axis) - rects.floating.length(len)
438                + if is_origin_side { data_cross_axis } else { 0.0 }
439                + if is_origin_side {
440                    0.0
441                } else {
442                    computed_cross_axis
443                };
444            let limit_max = rects.reference.axis(cross_axis)
445                + rects.reference.length(len)
446                + if is_origin_side { 0.0 } else { data_cross_axis }
447                - if is_origin_side {
448                    computed_cross_axis
449                } else {
450                    0.0
451                };
452
453            cross_axis_coord = clamp(limit_min, cross_axis_coord, limit_max);
454        }
455
456        Coords {
457            x: match main_axis {
458                Axis::X => main_axis_coord,
459                Axis::Y => cross_axis_coord,
460            },
461            y: match main_axis {
462                Axis::X => cross_axis_coord,
463                Axis::Y => main_axis_coord,
464            },
465        }
466    }
467}