1use std::fmt::Debug;
2
3use dyn_derive::dyn_trait;
4use floating_ui_utils::{Axis, Coords, Side, clamp, get_opposite_axis, get_side_axis};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 detect_overflow::{DetectOverflowOptions, detect_overflow},
9 middleware::{OFFSET_NAME, OffsetData},
10 types::{
11 Derivable, DerivableFn, Middleware, MiddlewareReturn, MiddlewareState,
12 MiddlewareWithOptions,
13 },
14};
15
/// Identifier returned by [`Shift`]'s `Middleware::name`.
pub const SHIFT_NAME: &str = "shift";
18
#[dyn_trait]
/// Post-processes the coordinates produced by [`Shift`].
///
/// After `Shift` has clamped the floating element's coordinates against the
/// detected overflow, it passes a state carrying those shifted `x`/`y` to
/// [`Limiter::compute`] and reports the returned [`Coords`] as the final
/// position. [`DefaultLimiter`] returns the coordinates unchanged, while
/// [`LimitShift`] restricts them relative to the reference element's rect.
///
/// NOTE(review): `#[dyn_trait]` presumably generates the glue that lets
/// `Box<dyn Limiter<..>>` satisfy the `Clone + PartialEq` supertraits (it is
/// stored in `ShiftOptions`, which derives both) — confirm against `dyn_derive`.
pub trait Limiter<Element: Clone + 'static, Window: Clone + 'static>: Clone + PartialEq {
    /// Returns the limited coordinates for the (already shifted) position in `state`.
    fn compute(&self, state: MiddlewareState<Element, Window>) -> Coords;
}
24
25#[derive(Clone, PartialEq)]
27pub struct ShiftOptions<Element: Clone + 'static, Window: Clone + 'static> {
28 pub detect_overflow: Option<DetectOverflowOptions<Element>>,
32
33 pub main_axis: Option<bool>,
37
38 pub cross_axis: Option<bool>,
42
43 pub limiter: Option<Box<dyn Limiter<Element, Window>>>,
47}
48
49impl<Element: Clone, Window: Clone> ShiftOptions<Element, Window> {
50 pub fn detect_overflow(mut self, value: DetectOverflowOptions<Element>) -> Self {
52 self.detect_overflow = Some(value);
53 self
54 }
55
56 pub fn main_axis(mut self, value: bool) -> Self {
58 self.main_axis = Some(value);
59 self
60 }
61
62 pub fn cross_axis(mut self, value: bool) -> Self {
64 self.cross_axis = Some(value);
65 self
66 }
67
68 pub fn limiter(mut self, value: Box<dyn Limiter<Element, Window>>) -> Self {
70 self.limiter = Some(value);
71 self
72 }
73}
74
75impl<Element: Clone, Window: Clone> Default for ShiftOptions<Element, Window> {
76 fn default() -> Self {
77 Self {
78 detect_overflow: Default::default(),
79 main_axis: Default::default(),
80 cross_axis: Default::default(),
81 limiter: Default::default(),
82 }
83 }
84}
85
86#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
88pub struct ShiftDataEnabled {
89 pub x: bool,
90 pub y: bool,
91}
92
93impl ShiftDataEnabled {
94 pub fn set_axis(mut self, axis: Axis, enabled: bool) -> Self {
95 match axis {
96 Axis::X => {
97 self.x = enabled;
98 }
99 Axis::Y => {
100 self.y = enabled;
101 }
102 }
103 self
104 }
105}
106
107#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
109pub struct ShiftData {
110 pub x: f64,
111 pub y: f64,
112 pub enabled: ShiftDataEnabled,
113}
114
115#[derive(PartialEq)]
121pub struct Shift<'a, Element: Clone + 'static, Window: Clone + 'static> {
122 options: Derivable<'a, Element, Window, ShiftOptions<Element, Window>>,
123}
124
125impl<'a, Element: Clone, Window: Clone> Shift<'a, Element, Window> {
126 pub fn new(options: ShiftOptions<Element, Window>) -> Self {
128 Shift {
129 options: options.into(),
130 }
131 }
132
133 pub fn new_derivable(
135 options: Derivable<'a, Element, Window, ShiftOptions<Element, Window>>,
136 ) -> Self {
137 Shift { options }
138 }
139
140 pub fn new_derivable_fn(
142 options: DerivableFn<'a, Element, Window, ShiftOptions<Element, Window>>,
143 ) -> Self {
144 Shift {
145 options: options.into(),
146 }
147 }
148}
149
150impl<Element: Clone, Window: Clone> Clone for Shift<'_, Element, Window> {
151 fn clone(&self) -> Self {
152 Self {
153 options: self.options.clone(),
154 }
155 }
156}
157
158impl<Element: Clone + PartialEq + 'static, Window: Clone + PartialEq + 'static>
159 Middleware<Element, Window> for Shift<'static, Element, Window>
160{
161 fn name(&self) -> &'static str {
162 SHIFT_NAME
163 }
164
165 fn compute(&self, state: MiddlewareState<Element, Window>) -> MiddlewareReturn {
166 let options = self.options.evaluate(state.clone());
167
168 let MiddlewareState {
169 x, y, placement, ..
170 } = state;
171
172 let check_main_axis = options.main_axis.unwrap_or(true);
173 let check_cross_axis = options.cross_axis.unwrap_or(false);
174 #[allow(clippy::unwrap_or_default)]
175 let limiter = options.limiter.unwrap_or(Box::<DefaultLimiter>::default());
176
177 let coords = Coords { x, y };
178 let overflow = detect_overflow(
179 MiddlewareState {
180 elements: state.elements.clone(),
181 ..state
182 },
183 options.detect_overflow.unwrap_or_default(),
184 );
185 let cross_axis = get_side_axis(placement);
186 let main_axis = get_opposite_axis(cross_axis);
187
188 let mut main_axis_coord = coords.axis(main_axis);
189 let mut cross_axis_coord = coords.axis(cross_axis);
190
191 if check_main_axis {
192 let min_side = match main_axis {
193 Axis::X => Side::Left,
194 Axis::Y => Side::Top,
195 };
196 let max_side = match main_axis {
197 Axis::X => Side::Right,
198 Axis::Y => Side::Bottom,
199 };
200 let min = main_axis_coord + overflow.side(min_side);
201 let max = main_axis_coord - overflow.side(max_side);
202
203 main_axis_coord = clamp(min, main_axis_coord, max);
204 }
205
206 if check_cross_axis {
207 let min_side = match cross_axis {
208 Axis::X => Side::Left,
209 Axis::Y => Side::Top,
210 };
211 let max_side = match cross_axis {
212 Axis::X => Side::Right,
213 Axis::Y => Side::Bottom,
214 };
215 let min = cross_axis_coord + overflow.side(min_side);
216 let max = cross_axis_coord - overflow.side(max_side);
217
218 cross_axis_coord = clamp(min, cross_axis_coord, max);
219 }
220
221 let limited_coords = limiter.compute(MiddlewareState {
222 x: match main_axis {
223 Axis::X => main_axis_coord,
224 Axis::Y => cross_axis_coord,
225 },
226 y: match main_axis {
227 Axis::X => cross_axis_coord,
228 Axis::Y => main_axis_coord,
229 },
230 ..state
231 });
232
233 MiddlewareReturn {
234 x: Some(limited_coords.x),
235 y: Some(limited_coords.y),
236 data: Some(
237 serde_json::to_value(ShiftData {
238 x: limited_coords.x - x,
239 y: limited_coords.y - y,
240 enabled: ShiftDataEnabled::default()
241 .set_axis(main_axis, check_main_axis)
242 .set_axis(cross_axis, check_cross_axis),
243 })
244 .expect("Data should be valid JSON."),
245 ),
246 reset: None,
247 }
248 }
249}
250
impl<Element: Clone, Window: Clone>
    MiddlewareWithOptions<Element, Window, ShiftOptions<Element, Window>>
    for Shift<'_, Element, Window>
{
    /// Returns the options this middleware was constructed with, as stored:
    /// either fixed options or a function deriving them from the state.
    fn options(&self) -> &Derivable<'_, Element, Window, ShiftOptions<Element, Window>> {
        &self.options
    }
}
259
/// Limiter used by [`Shift`] when none is configured; it passes the shifted
/// coordinates through unchanged.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct DefaultLimiter;
263
264impl<Element: Clone + 'static, Window: Clone + 'static> Limiter<Element, Window>
265 for DefaultLimiter
266{
267 fn compute(&self, state: MiddlewareState<Element, Window>) -> Coords {
268 Coords {
269 x: state.x,
270 y: state.y,
271 }
272 }
273}
274
/// Separate main-axis and cross-axis offsets for [`LimitShift`]. An unset axis
/// is treated as `0`.
#[derive(Clone, Default, Debug, PartialEq)]
pub struct LimitShiftOffsetValues {
    /// Offset along the main axis.
    pub main_axis: Option<f64>,

    /// Offset along the cross axis.
    pub cross_axis: Option<f64>,
}

impl LimitShiftOffsetValues {
    /// Returns these values with the main-axis offset set to `value`.
    pub fn main_axis(self, value: f64) -> Self {
        Self {
            main_axis: Some(value),
            ..self
        }
    }

    /// Returns these values with the cross-axis offset set to `value`.
    pub fn cross_axis(self, value: f64) -> Self {
        Self {
            cross_axis: Some(value),
            ..self
        }
    }
}

/// Offset given to [`LimitShift`]: one value, or per-axis values.
#[derive(Clone, Debug, PartialEq)]
pub enum LimitShiftOffset {
    /// A single offset; [`LimitShift`] applies it to the main axis only.
    Value(f64),
    /// Separate offsets for the main and cross axes.
    Values(LimitShiftOffsetValues),
}

impl Default for LimitShiftOffset {
    /// A single offset of `0`.
    fn default() -> Self {
        Self::Value(0.0)
    }
}
309
310#[derive(Clone, PartialEq)]
312pub struct LimitShiftOptions<'a, Element: Clone + 'static, Window: Clone> {
313 pub offset: Option<Derivable<'a, Element, Window, LimitShiftOffset>>,
314
315 pub main_axis: Option<bool>,
316
317 pub cross_axis: Option<bool>,
318}
319
320impl<'a, Element: Clone, Window: Clone> LimitShiftOptions<'a, Element, Window> {
321 pub fn offset(mut self, value: LimitShiftOffset) -> Self {
323 self.offset = Some(value.into());
324 self
325 }
326
327 pub fn offset_derivable(
329 mut self,
330 value: Derivable<'a, Element, Window, LimitShiftOffset>,
331 ) -> Self {
332 self.offset = Some(value);
333 self
334 }
335
336 pub fn offset_derivable_fn(
338 mut self,
339 value: DerivableFn<'a, Element, Window, LimitShiftOffset>,
340 ) -> Self {
341 self.offset = Some(value.into());
342 self
343 }
344
345 pub fn main_axis(mut self, value: bool) -> Self {
347 self.main_axis = Some(value);
348 self
349 }
350
351 pub fn cross_axis(mut self, value: bool) -> Self {
353 self.cross_axis = Some(value);
354 self
355 }
356}
357
358impl<Element: Clone + 'static, Window: Clone> Default for LimitShiftOptions<'_, Element, Window> {
359 fn default() -> Self {
360 Self {
361 offset: Default::default(),
362 main_axis: Default::default(),
363 cross_axis: Default::default(),
364 }
365 }
366}
367
368#[derive(Clone, Default, PartialEq)]
370pub struct LimitShift<'a, Element: Clone + 'static, Window: Clone> {
371 options: LimitShiftOptions<'a, Element, Window>,
372}
373
374impl<'a, Element: Clone, Window: Clone> LimitShift<'a, Element, Window> {
375 pub fn new(options: LimitShiftOptions<'a, Element, Window>) -> Self {
376 LimitShift { options }
377 }
378}
379
380impl<Element: Clone + PartialEq, Window: Clone + PartialEq> Limiter<Element, Window>
381 for LimitShift<'static, Element, Window>
382{
383 fn compute(&self, state: MiddlewareState<Element, Window>) -> Coords {
384 let MiddlewareState {
385 x,
386 y,
387 placement,
388 rects,
389 middleware_data,
390 ..
391 } = state;
392
393 let offset = self
394 .options
395 .offset
396 .clone()
397 .unwrap_or(Derivable::Value(LimitShiftOffset::default()));
398 let check_main_axis = self.options.main_axis.unwrap_or(true);
399 let check_cross_axis = self.options.cross_axis.unwrap_or(true);
400
401 let coords = Coords { x, y };
402 let cross_axis = get_side_axis(placement);
403 let main_axis = get_opposite_axis(cross_axis);
404
405 let mut main_axis_coord = coords.axis(main_axis);
406 let mut cross_axis_coord = coords.axis(cross_axis);
407
408 let raw_offset = offset.evaluate(state.clone());
409 let (computed_main_axis, computed_cross_axis) = match raw_offset {
410 LimitShiftOffset::Value(value) => (value, 0.0),
411 LimitShiftOffset::Values(values) => (
412 values.main_axis.unwrap_or(0.0),
413 values.cross_axis.unwrap_or(0.0),
414 ),
415 };
416
417 if check_main_axis {
418 let len = main_axis.length();
419 let limit_min =
420 rects.reference.axis(main_axis) - rects.floating.length(len) + computed_main_axis;
421 let limit_max =
422 rects.reference.axis(main_axis) + rects.reference.length(len) - computed_main_axis;
423
424 main_axis_coord = clamp(limit_min, main_axis_coord, limit_max);
425 }
426
427 if check_cross_axis {
428 let len = main_axis.length();
429 let is_origin_side = match placement.side() {
430 Side::Top | Side::Left => true,
431 Side::Bottom | Side::Right => false,
432 };
433
434 let data: Option<OffsetData> = middleware_data.get_as(OFFSET_NAME);
435 let data_cross_axis = data.map_or(0.0, |data| data.diff_coords.axis(cross_axis));
436
437 let limit_min = rects.reference.axis(cross_axis) - rects.floating.length(len)
438 + if is_origin_side { data_cross_axis } else { 0.0 }
439 + if is_origin_side {
440 0.0
441 } else {
442 computed_cross_axis
443 };
444 let limit_max = rects.reference.axis(cross_axis)
445 + rects.reference.length(len)
446 + if is_origin_side { 0.0 } else { data_cross_axis }
447 - if is_origin_side {
448 computed_cross_axis
449 } else {
450 0.0
451 };
452
453 cross_axis_coord = clamp(limit_min, cross_axis_coord, limit_max);
454 }
455
456 Coords {
457 x: match main_axis {
458 Axis::X => main_axis_coord,
459 Axis::Y => cross_axis_coord,
460 },
461 y: match main_axis {
462 Axis::X => cross_axis_coord,
463 Axis::Y => main_axis_coord,
464 },
465 }
466 }
467}