floating_ui_leptos/
use_floating.rs

1use std::{
2    ops::Deref,
3    rc::Rc,
4    sync::{Arc, Mutex},
5};
6
7use floating_ui_dom::{
8    ComputePositionConfig, MiddlewareData, OwnedElementOrVirtual, Placement, Strategy,
9    VirtualElement, compute_position,
10};
11use leptos::{html::ElementType, prelude::*};
12use leptos_node_ref::AnyNodeRef;
13use send_wrapper::SendWrapper;
14use web_sys::wasm_bindgen::{JsCast, JsValue};
15
16use crate::{
17    types::{FloatingStyles, UseFloatingOptions, UseFloatingReturn, WhileElementsMountedCleanupFn},
18    utils::{get_dpr::get_dpr, round_by_dpr::round_by_dpr},
19};
20
21pub struct Virtual;
22
23impl ElementType for Virtual {
24    type Output = JsValue;
25
26    const TAG: &'static str = "virtual";
27    const SELF_CLOSING: bool = true;
28    const ESCAPE_CHILDREN: bool = true;
29    const NAMESPACE: Option<&'static str> = None;
30
31    fn tag(&self) -> &str {
32        Self::TAG
33    }
34}
35
36#[derive(Clone)]
37pub enum VirtualElementOrNodeRef {
38    VirtualElement(SendWrapper<Box<dyn VirtualElement<web_sys::Element>>>),
39    NodeRef(AnyNodeRef),
40}
41
42impl VirtualElementOrNodeRef {
43    pub fn get(&self) -> Option<OwnedElementOrVirtual> {
44        match self {
45            VirtualElementOrNodeRef::VirtualElement(virtual_element) => {
46                Some((**virtual_element).clone().into())
47            }
48            VirtualElementOrNodeRef::NodeRef(node_ref) => node_ref
49                .get()
50                .and_then(|element| element.dyn_into::<web_sys::Element>().ok())
51                .map(|element| element.into()),
52        }
53    }
54
55    pub fn get_untracked(&self) -> Option<OwnedElementOrVirtual> {
56        match self {
57            VirtualElementOrNodeRef::VirtualElement(virtual_element) => {
58                Some((**virtual_element).clone().into())
59            }
60            VirtualElementOrNodeRef::NodeRef(node_ref) => node_ref
61                .get_untracked()
62                .and_then(|element| element.dyn_into::<web_sys::Element>().ok())
63                .map(|element| element.into()),
64        }
65    }
66}
67
68// impl<E: ElementType> Clone for VirtualElementOrNodeRef<E> {
69//     fn clone(&self) -> Self {
70//         match self {
71//             Self::VirtualElement(virtual_element) => Self::VirtualElement(virtual_element.clone()),
72//             Self::NodeRef(node_ref) => Self::NodeRef(*node_ref),
73//         }
74//     }
75// }
76
77impl From<Box<dyn VirtualElement<web_sys::Element>>> for VirtualElementOrNodeRef {
78    fn from(value: Box<dyn VirtualElement<web_sys::Element>>) -> Self {
79        VirtualElementOrNodeRef::VirtualElement(SendWrapper::new(value))
80    }
81}
82
83impl From<AnyNodeRef> for VirtualElementOrNodeRef {
84    fn from(value: AnyNodeRef) -> Self {
85        VirtualElementOrNodeRef::NodeRef(value)
86    }
87}
88
89#[derive(Clone, Copy)]
90pub struct Reference(MaybeProp<VirtualElementOrNodeRef>);
91
92impl Deref for Reference {
93    type Target = MaybeProp<VirtualElementOrNodeRef>;
94
95    fn deref(&self) -> &Self::Target {
96        &self.0
97    }
98}
99
100impl From<MaybeProp<VirtualElementOrNodeRef>> for Reference {
101    fn from(value: MaybeProp<VirtualElementOrNodeRef>) -> Self {
102        Reference(value)
103    }
104}
105
106impl From<Memo<VirtualElementOrNodeRef>> for Reference {
107    fn from(value: Memo<VirtualElementOrNodeRef>) -> Self {
108        Reference(value.into())
109    }
110}
111
112impl From<ReadSignal<VirtualElementOrNodeRef>> for Reference {
113    fn from(value: ReadSignal<VirtualElementOrNodeRef>) -> Self {
114        Reference(value.into())
115    }
116}
117
118impl From<RwSignal<VirtualElementOrNodeRef>> for Reference {
119    fn from(value: RwSignal<VirtualElementOrNodeRef>) -> Self {
120        Reference(value.into())
121    }
122}
123
124impl From<Signal<VirtualElementOrNodeRef>> for Reference {
125    fn from(value: Signal<VirtualElementOrNodeRef>) -> Self {
126        Reference(value.into())
127    }
128}
129
130impl From<VirtualElementOrNodeRef> for Reference {
131    fn from(value: VirtualElementOrNodeRef) -> Self {
132        Reference(value.into())
133    }
134}
135
136impl From<Box<dyn VirtualElement<web_sys::Element>>> for Reference {
137    fn from(value: Box<dyn VirtualElement<web_sys::Element>>) -> Self {
138        Reference(VirtualElementOrNodeRef::from(value).into())
139    }
140}
141
142impl From<AnyNodeRef> for Reference {
143    fn from(value: AnyNodeRef) -> Self {
144        Reference(VirtualElementOrNodeRef::from(value).into())
145    }
146}
147
148/// Computes the `x` and `y` coordinates that will place the floating element next to a reference element.
149pub fn use_floating<R: Into<Reference>>(
150    reference: R,
151    floating: AnyNodeRef,
152    options: UseFloatingOptions,
153) -> UseFloatingReturn {
154    let reference: Reference = reference.into();
155
156    let open_option = Signal::derive(move || options.open.get().unwrap_or(true));
157    let placement_option_untracked = move || {
158        options
159            .placement
160            .get_untracked()
161            .unwrap_or(Placement::Bottom)
162    };
163    let strategy_option_untracked = move || {
164        options
165            .strategy
166            .get_untracked()
167            .unwrap_or(Strategy::Absolute)
168    };
169    let middleware_option_untracked = move || options.middleware.get_untracked();
170    let transform_option = move || options.transform.get().unwrap_or(true);
171    let while_elements_mounted_untracked = move || options.while_elements_mounted.get_untracked();
172
173    let (x, set_x) = signal(0.0);
174    let (y, set_y) = signal(0.0);
175    let (strategy, set_strategy) = signal(strategy_option_untracked());
176    let (placement, set_placement) = signal(placement_option_untracked());
177    let (middleware_data, set_middleware_data) = signal(MiddlewareData::default());
178    let (is_positioned, set_is_positioned) = signal(false);
179    let floating_styles = Memo::new(move |_| {
180        let initial_styles = FloatingStyles {
181            position: strategy.get(),
182            top: "0".to_owned(),
183            left: "0".to_owned(),
184            transform: None,
185            will_change: None,
186        };
187
188        match floating
189            .get()
190            .and_then(|floating| floating.dyn_into::<web_sys::Element>().ok())
191        {
192            Some(floating_element) => {
193                let x_val = round_by_dpr(&floating_element, x.get());
194                let y_val = round_by_dpr(&floating_element, y.get());
195
196                if transform_option() {
197                    FloatingStyles {
198                        transform: Some(format!("translate({x_val}px, {y_val}px)")),
199                        will_change: (get_dpr(&floating_element) >= 1.5)
200                            .then_some("transform".to_owned()),
201                        ..initial_styles
202                    }
203                } else {
204                    FloatingStyles {
205                        left: format!("{x_val}px"),
206                        top: format!("{y_val}px"),
207                        ..initial_styles
208                    }
209                }
210            }
211            _ => initial_styles,
212        }
213    });
214
215    let update = Rc::new({
216        move || {
217            if let Some(reference) = reference.get_untracked()
218                && let Some(reference_element) = reference.get_untracked()
219                && let Some(floating_element) = floating
220                    .get_untracked()
221                    .and_then(|floating| floating.dyn_into::<web_sys::Element>().ok())
222            {
223                let config = ComputePositionConfig {
224                    placement: Some(placement_option_untracked()),
225                    strategy: Some(strategy_option_untracked()),
226                    middleware: middleware_option_untracked()
227                        .map(|middleware| middleware.deref().clone()),
228                };
229
230                let open = open_option.get_untracked();
231
232                let position =
233                    compute_position((&reference_element).into(), &floating_element, config);
234                set_x.set(position.x);
235                set_y.set(position.y);
236                set_strategy.set(position.strategy);
237                set_placement.set(position.placement);
238                set_middleware_data.set(position.middleware_data);
239                // The floating element's position may be recomputed while it's closed
240                // but still mounted (such as when transitioning out). To ensure
241                // `is_positioned` will be `false` initially on the next open,
242                // avoid setting it to `true` when `open === false` (must be specified).
243                set_is_positioned.set(open);
244            }
245        }
246    });
247
248    let while_elements_mounted_cleanup: Arc<
249        Mutex<Option<SendWrapper<WhileElementsMountedCleanupFn>>>,
250    > = Arc::new(Mutex::new(None));
251
252    let cleanup = Arc::new({
253        let while_elements_mounted_cleanup = while_elements_mounted_cleanup.clone();
254
255        move || {
256            if let Some(while_elements_mounted_cleanup) = while_elements_mounted_cleanup
257                .lock()
258                .expect("Lock should be acquired.")
259                .as_ref()
260            {
261                while_elements_mounted_cleanup();
262            }
263        }
264    });
265
266    let attach = Rc::new({
267        let update = update.clone();
268        let cleanup = cleanup.clone();
269        let while_elements_mounted_cleanup = while_elements_mounted_cleanup.clone();
270
271        move || {
272            cleanup();
273
274            match while_elements_mounted_untracked() {
275                Some(while_elements_mounted) => {
276                    if let Some(reference) = reference.get_untracked()
277                        && let Some(reference_element) = reference.get_untracked()
278                        && let Some(floating_element) = floating
279                            .get_untracked()
280                            .and_then(|floating| floating.dyn_into::<web_sys::Element>().ok())
281                    {
282                        *while_elements_mounted_cleanup
283                            .lock()
284                            .expect("Lock should be acquired.") =
285                            Some(SendWrapper::new(while_elements_mounted(
286                                (&reference_element).into(),
287                                &floating_element,
288                                update.clone(),
289                            )));
290                    }
291                }
292                _ => {
293                    update();
294                }
295            }
296        }
297    });
298
299    let reset = move || {
300        if !open_option.get_untracked() {
301            set_is_positioned.set(false);
302        }
303    };
304
305    Effect::new({
306        let attach = attach.clone();
307
308        move |_| {
309            if let Some(reference) = reference.get() {
310                match reference {
311                    VirtualElementOrNodeRef::VirtualElement(_) => {
312                        attach();
313                    }
314                    VirtualElementOrNodeRef::NodeRef(reference) => {
315                        if reference
316                            .get()
317                            .and_then(|reference| reference.dyn_into::<web_sys::Element>().ok())
318                            .is_some()
319                        {
320                            attach();
321                        }
322                    }
323                }
324            }
325        }
326    });
327
328    Effect::new({
329        let attach = attach.clone();
330
331        move |_| {
332            if floating
333                .get()
334                .and_then(|floating| floating.dyn_into::<web_sys::Element>().ok())
335                .is_some()
336            {
337                attach();
338            }
339        }
340    });
341
342    Effect::new(move |_| {
343        reset();
344    });
345
346    _ = Effect::watch(
347        move || open_option.get(),
348        {
349            let update = update.clone();
350
351            move |_, _, _| {
352                update();
353            }
354        },
355        false,
356    );
357    _ = Effect::watch(
358        move || options.placement.get(),
359        {
360            let update = update.clone();
361
362            move |_, _, _| {
363                update();
364            }
365        },
366        false,
367    );
368    _ = Effect::watch(
369        move || options.strategy.get(),
370        {
371            let update = update.clone();
372
373            move |_, _, _| {
374                update();
375            }
376        },
377        false,
378    );
379    _ = Effect::watch(
380        move || options.middleware.get(),
381        {
382            let update = update.clone();
383
384            move |_, _, _| {
385                update();
386            }
387        },
388        false,
389    );
390    _ = Effect::watch(
391        move || options.while_elements_mounted.get(),
392        move |_, _, _| {
393            attach();
394        },
395        false,
396    );
397
398    on_cleanup(move || {
399        cleanup();
400    });
401
402    UseFloatingReturn {
403        x: x.into(),
404        y: y.into(),
405        placement: placement.into(),
406        strategy: strategy.into(),
407        middleware_data: middleware_data.into(),
408        is_positioned: is_positioned.into(),
409        floating_styles: floating_styles.into(),
410        update: SendWrapper::new(update.clone()),
411    }
412}
413
#[cfg(target_arch = "wasm32")]
#[cfg(test)]
mod tests {
    use leptos::prelude::*;
    use leptos_node_ref::AnyNodeRef;
    use wasm_bindgen_test::*;

    use super::*;

    wasm_bindgen_test_configure!(run_in_browser);

    /// Mounts a reference/floating pair and renders `is_positioned` into the document.
    #[wasm_bindgen_test]
    fn updates_is_positioned_when_position_is_computed() {
        #[component]
        fn Component() -> impl IntoView {
            let reference_ref = AnyNodeRef::new();
            let floating_ref = AnyNodeRef::new();
            let UseFloatingReturn {
                is_positioned: positioned,
                ..
            } = use_floating(reference_ref, floating_ref, UseFloatingOptions::default());

            view! {
                <div node_ref=reference_ref />
                <div node_ref=floating_ref />
                <div id="test-is-positioned">{positioned}</div>
            }
        }

        mount_to_body(Component);

        // NOTE(review): this assertion is disabled, so the test currently only checks that
        // mounting does not panic. Re-enabling it presumably requires waiting for the
        // positioning effects to run after mount — confirm before turning it back on.
        // assert_eq!(
        //     document
        //         .get_element_by_id("test-is-positioned")
        //         .and_then(|element| element.text_content()),
        //     Some("true".to_owned())
        // );
    }
}