Skip to main content

cranpose_ui/modifier/
pointer_input.rs

1use super::{inspector_metadata, Modifier, PointerEvent};
2use cranpose_core::hash::default;
3use cranpose_foundation::{
4    impl_pointer_input_node, DelegatableNode, ModifierNode, ModifierNodeContext,
5    ModifierNodeElement, NodeCapabilities, NodeState, PointerInputNode,
6};
7use cranpose_ui_graphics::Size;
8use futures_task::{waker, ArcWake};
9use std::any::TypeId;
10use std::cell::{Cell, RefCell};
11use std::collections::{HashMap, VecDeque};
12use std::fmt;
13use std::future::Future;
14use std::hash::{Hash, Hasher};
15use std::pin::Pin;
16use std::rc::Rc;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::Arc;
19use std::task::{Context, Poll, Waker};
20
21impl Modifier {
22    pub fn pointer_input<K, F, Fut>(self, key: K, handler: F) -> Self
23    where
24        K: Hash + 'static,
25        F: Fn(PointerInputScope) -> Fut + 'static,
26        Fut: Future<Output = ()> + 'static,
27    {
28        let element =
29            PointerInputElement::new(vec![KeyToken::new(&key)], pointer_input_handler(handler));
30        let key_count = element.key_count();
31        let handler_id = element.handler_id();
32        self.then(
33            Self::with_element(element).with_inspector_metadata(inspector_metadata(
34                "pointerInput",
35                move |info| {
36                    info.add_property("keyCount", key_count.to_string());
37                    info.add_property("handlerId", handler_id.to_string());
38                },
39            )),
40        )
41    }
42}
43
44fn pointer_input_handler<F, Fut>(handler: F) -> PointerInputHandler
45where
46    F: Fn(PointerInputScope) -> Fut + 'static,
47    Fut: Future<Output = ()> + 'static,
48{
49    Rc::new(move |scope| Box::pin(handler(scope.clone())))
50}
51
52type PointerInputFuture = Pin<Box<dyn Future<Output = ()>>>;
53type PointerInputHandler = Rc<dyn Fn(PointerInputScope) -> PointerInputFuture>;
54
55thread_local! {
56    static POINTER_INPUT_TASKS: RefCell<HashMap<u64, Rc<PointerInputTaskInner>>> = RefCell::new(HashMap::new());
57}
58
59#[derive(Clone)]
60struct PointerInputElement {
61    keys: Vec<KeyToken>,
62    handler: PointerInputHandler,
63    handler_id: u64,
64}
65
66impl PointerInputElement {
67    fn new(keys: Vec<KeyToken>, handler: PointerInputHandler) -> Self {
68        static NEXT_HANDLER_ID: AtomicU64 = AtomicU64::new(1);
69        let handler_id = NEXT_HANDLER_ID.fetch_add(1, Ordering::Relaxed);
70        Self {
71            keys,
72            handler,
73            handler_id,
74        }
75    }
76
77    fn key_count(&self) -> usize {
78        self.keys.len()
79    }
80
81    fn handler_id(&self) -> u64 {
82        self.handler_id
83    }
84}
85
86impl fmt::Debug for PointerInputElement {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("PointerInputElement")
89            .field("keys", &self.keys)
90            .field("handler", &Rc::as_ptr(&self.handler))
91            .field("handler_id", &self.handler_id)
92            .finish()
93    }
94}
95
96impl PartialEq for PointerInputElement {
97    fn eq(&self, other: &Self) -> bool {
98        // Only compare keys, not handler_id. In Compose, elements are equal if their
99        // keys match, even if the handler closure is recreated on recomposition.
100        // This ensures nodes are reused instead of being dropped and recreated.
101        self.keys == other.keys
102    }
103}
104
105impl Eq for PointerInputElement {}
106
107impl Hash for PointerInputElement {
108    fn hash<H: Hasher>(&self, state: &mut H) {
109        // Only hash keys, not handler_id. This ensures stable hashing across
110        // recompositions when the closure is recreated but keys remain the same.
111        self.keys.hash(state);
112    }
113}
114
115impl ModifierNodeElement for PointerInputElement {
116    type Node = SuspendingPointerInputNode;
117
118    fn create(&self) -> Self::Node {
119        SuspendingPointerInputNode::new(self.keys.clone(), self.handler.clone())
120    }
121
122    fn update(&self, node: &mut Self::Node) {
123        node.update(self.keys.clone(), self.handler.clone());
124    }
125
126    fn capabilities(&self) -> NodeCapabilities {
127        NodeCapabilities::POINTER_INPUT
128    }
129}
130
131#[derive(Clone)]
132pub struct PointerInputScope {
133    state: Rc<PointerInputScopeState>,
134}
135
136impl PointerInputScope {
137    fn new(state: Rc<PointerInputScopeState>) -> Self {
138        Self { state }
139    }
140
141    pub fn size(&self) -> Size {
142        self.state.size.get()
143    }
144
145    pub async fn await_pointer_event_scope<R, F, Fut>(&self, block: F) -> R
146    where
147        F: FnOnce(AwaitPointerEventScope) -> Fut,
148        Fut: Future<Output = R>,
149    {
150        let scope = AwaitPointerEventScope {
151            state: self.state.clone(),
152        };
153        block(scope).await
154    }
155}
156
157#[derive(Clone)]
158pub struct AwaitPointerEventScope {
159    state: Rc<PointerInputScopeState>,
160}
161
162impl AwaitPointerEventScope {
163    pub fn size(&self) -> Size {
164        self.state.size.get()
165    }
166
167    pub async fn await_pointer_event(&self) -> PointerEvent {
168        NextPointerEvent {
169            state: self.state.clone(),
170        }
171        .await
172    }
173
174    pub async fn with_timeout_or_null<R, F, Fut>(&self, _time_millis: u64, block: F) -> Option<R>
175    where
176        F: FnOnce(&AwaitPointerEventScope) -> Fut,
177        Fut: Future<Output = R>,
178    {
179        Some(block(self).await)
180    }
181
182    pub async fn with_timeout<R, F, Fut>(&self, _time_millis: u64, block: F) -> R
183    where
184        F: FnOnce(&AwaitPointerEventScope) -> Fut,
185        Fut: Future<Output = R>,
186    {
187        block(self).await
188    }
189}
190
191struct NextPointerEvent {
192    state: Rc<PointerInputScopeState>,
193}
194
195impl Future for NextPointerEvent {
196    type Output = PointerEvent;
197
198    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
199        self.state.poll_event(cx)
200    }
201}
202
203struct PointerInputScopeState {
204    events: RefCell<VecDeque<PointerEvent>>,
205    waiting: RefCell<Option<Waker>>,
206    size: Cell<Size>,
207}
208
209impl PointerInputScopeState {
210    fn new() -> Self {
211        Self {
212            events: RefCell::new(VecDeque::new()),
213            waiting: RefCell::new(None),
214            size: Cell::new(Size {
215                width: 0.0,
216                height: 0.0,
217            }),
218        }
219    }
220
221    fn push_event(&self, event: PointerEvent) {
222        self.events.borrow_mut().push_back(event);
223        let waker = {
224            let mut waiting = self.waiting.borrow_mut();
225            waiting.take()
226        };
227        if let Some(waker) = waker {
228            waker.wake();
229        }
230    }
231
232    fn poll_event(&self, cx: &mut Context<'_>) -> Poll<PointerEvent> {
233        if let Some(event) = self.events.borrow_mut().pop_front() {
234            Poll::Ready(event)
235        } else {
236            self.waiting.replace(Some(cx.waker().clone()));
237            Poll::Pending
238        }
239    }
240}
241
242struct PointerEventDispatcher {
243    state: Rc<RefCell<Option<Rc<PointerInputScopeState>>>>,
244    handler: Rc<dyn Fn(PointerEvent)>,
245}
246
247impl PointerEventDispatcher {
248    fn new() -> Self {
249        let state = Rc::new(RefCell::new(None::<Rc<PointerInputScopeState>>));
250        let state_for_handler = state.clone();
251        let handler = Rc::new(move |event: PointerEvent| {
252            if let Some(inner) = state_for_handler.borrow().as_ref() {
253                inner.push_event(event);
254            }
255        });
256        Self { state, handler }
257    }
258
259    fn handler(&self) -> Rc<dyn Fn(PointerEvent)> {
260        self.handler.clone()
261    }
262
263    fn set_state(&self, state: Option<Rc<PointerInputScopeState>>) {
264        *self.state.borrow_mut() = state;
265    }
266}
267
268struct PointerInputTask {
269    id: u64,
270    inner: Rc<PointerInputTaskInner>,
271}
272
273impl PointerInputTask {
274    fn new(future: PointerInputFuture) -> Self {
275        static NEXT_TASK_ID: AtomicU64 = AtomicU64::new(1);
276        let id = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed);
277        let inner = Rc::new(PointerInputTaskInner::new(future));
278        POINTER_INPUT_TASKS.with(|registry| {
279            registry.borrow_mut().insert(id, inner.clone());
280        });
281        Self { id, inner }
282    }
283
284    fn poll(&self) {
285        self.inner.poll(self.id);
286    }
287
288    fn cancel(self) {
289        self.inner.cancel();
290        POINTER_INPUT_TASKS.with(|registry| {
291            registry.borrow_mut().remove(&self.id);
292        });
293    }
294}
295
296impl Drop for PointerInputTask {
297    fn drop(&mut self) {
298        self.inner.cancel();
299        POINTER_INPUT_TASKS.with(|registry| {
300            registry.borrow_mut().remove(&self.id);
301        });
302    }
303}
304
305struct PointerInputTaskInner {
306    future: RefCell<Option<PointerInputFuture>>,
307    is_polling: Cell<bool>,
308    needs_poll: Cell<bool>,
309}
310
311impl PointerInputTaskInner {
312    fn new(future: PointerInputFuture) -> Self {
313        Self {
314            future: RefCell::new(Some(future)),
315            is_polling: Cell::new(false),
316            needs_poll: Cell::new(false),
317        }
318    }
319
320    fn cancel(&self) {
321        self.future.borrow_mut().take();
322    }
323
324    fn request_poll(&self, task_id: u64) {
325        if self.is_polling.get() {
326            self.needs_poll.set(true);
327        } else {
328            self.poll(task_id);
329        }
330    }
331
332    fn poll(&self, task_id: u64) {
333        if self.is_polling.replace(true) {
334            self.needs_poll.set(true);
335            return;
336        }
337        loop {
338            self.needs_poll.set(false);
339            let waker = waker(Arc::new(PointerInputTaskWaker { task_id }));
340            let mut cx = Context::from_waker(&waker);
341            let mut future_slot = self.future.borrow_mut();
342            if let Some(future) = future_slot.as_mut() {
343                let poll_result = future.as_mut().poll(&mut cx);
344                if poll_result.is_ready() {
345                    future_slot.take();
346                }
347            }
348            if !self.needs_poll.get() {
349                break;
350            }
351        }
352        self.is_polling.set(false);
353    }
354}
355
356struct PointerInputTaskWaker {
357    task_id: u64,
358}
359
360impl ArcWake for PointerInputTaskWaker {
361    fn wake_by_ref(arc_self: &Arc<Self>) {
362        POINTER_INPUT_TASKS.with(|registry| {
363            if let Some(task) = registry.borrow().get(&arc_self.task_id).cloned() {
364                task.request_poll(arc_self.task_id);
365            }
366        });
367    }
368}
369
370pub struct SuspendingPointerInputNode {
371    keys: Vec<KeyToken>,
372    handler: PointerInputHandler,
373    dispatcher: PointerEventDispatcher,
374    task: Option<PointerInputTask>,
375    state: NodeState,
376}
377
378impl SuspendingPointerInputNode {
379    fn new(keys: Vec<KeyToken>, handler: PointerInputHandler) -> Self {
380        Self {
381            keys,
382            handler,
383            dispatcher: PointerEventDispatcher::new(),
384            task: None,
385            state: NodeState::new(),
386        }
387    }
388
389    fn update(&mut self, keys: Vec<KeyToken>, handler: PointerInputHandler) {
390        // Only restart if keys changed - not if handler Rc pointer changed.
391        // In Compose, closures are recreated every composition but the task should
392        // continue running as long as the keys are the same. This matches Jetpack
393        // Compose behavior where rememberUpdatedState keeps the task alive.
394        let should_restart = self.keys != keys;
395        self.keys = keys;
396        self.handler = handler; // Update handler even if not restarting
397        if should_restart {
398            self.restart();
399        }
400    }
401
402    fn restart(&mut self) {
403        self.cancel();
404        self.start();
405    }
406
407    fn start(&mut self) {
408        let state = Rc::new(PointerInputScopeState::new());
409        self.dispatcher.set_state(Some(state.clone()));
410        let scope = PointerInputScope::new(state);
411        let future = (self.handler)(scope);
412        let task = PointerInputTask::new(future);
413        task.poll();
414        self.task = Some(task);
415    }
416
417    fn cancel(&mut self) {
418        if let Some(task) = self.task.take() {
419            task.cancel();
420        }
421        self.dispatcher.set_state(None);
422    }
423}
424
425impl Drop for SuspendingPointerInputNode {
426    fn drop(&mut self) {
427        self.cancel();
428    }
429}
430
431impl ModifierNode for SuspendingPointerInputNode {
432    fn on_attach(&mut self, _context: &mut dyn ModifierNodeContext) {
433        self.start();
434    }
435
436    fn on_detach(&mut self) {
437        self.cancel();
438    }
439
440    fn on_reset(&mut self) {
441        // Don't restart on reset - only restart when keys/handler actually change
442        // (which is handled by update() method). Restarting here would kill the
443        // active task and lose its registered waker, preventing events from being delivered.
444    }
445
446    // Capability-driven implementation using helper macro
447    impl_pointer_input_node!();
448}
449
450impl DelegatableNode for SuspendingPointerInputNode {
451    fn node_state(&self) -> &NodeState {
452        &self.state
453    }
454}
455
456impl PointerInputNode for SuspendingPointerInputNode {
457    fn pointer_input_handler(&self) -> Option<Rc<dyn Fn(PointerEvent)>> {
458        Some(self.dispatcher.handler())
459    }
460}
461
462#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
463struct KeyToken {
464    type_id: TypeId,
465    hash: u64,
466}
467
468impl KeyToken {
469    fn new<T: Hash + 'static>(value: &T) -> Self {
470        let mut hasher = default::new();
471        value.hash(&mut hasher);
472        Self {
473            type_id: TypeId::of::<T>(),
474            hash: hasher.finish(),
475        }
476    }
477}