Skip to main content

cranpose_ui/
scroll.rs

1//! Scroll state and node implementation for cranpose.
2//!
3//! This module provides the core scrolling components:
4//! - `ScrollState`: Holds scroll position and provides scroll control methods
5//! - `ScrollNode`: Layout modifier that applies scroll offset to content
6//! - `ScrollElement`: Element for creating ScrollNode instances
7//!
8//! The actual `Modifier.horizontal_scroll()` and `Modifier.vertical_scroll()`
9//! extension methods are defined in `modifier/scroll.rs`.
10
11use cranpose_core::{
12    current_runtime_handle, ownedMutableStateOf, NodeId, OwnedMutableState, RuntimeHandle,
13};
14use cranpose_foundation::{
15    Constraints, DelegatableNode, LayoutModifierNode, Measurable, ModifierNode,
16    ModifierNodeContext, ModifierNodeElement, NodeCapabilities, NodeState,
17};
18use cranpose_ui_graphics::Size;
19use cranpose_ui_layout::LayoutModifierMeasureResult;
20use std::cell::{Cell, RefCell};
21use std::collections::HashMap;
22use std::hash::{DefaultHasher, Hash, Hasher};
23use std::rc::{Rc, Weak};
24use std::sync::atomic::{AtomicU64, Ordering};
25
/// Counter that hands out the unique ID of each new `ScrollState`; the first ID issued is 1.
static NEXT_SCROLL_STATE_ID: AtomicU64 = AtomicU64::new(1);
/// Number of frame callbacks `ScrollMotionContext::activate_for_next_frame` waits through
/// before clearing the active flag, unless a newer generation supersedes the request.
const SCROLL_MOTION_ACTIVE_FRAME_COUNT: u8 = 6;
28
/// State object for scroll position tracking.
///
/// Holds the current scroll offset and provides methods to programmatically
/// control scrolling. Can be created with the `rememberScrollState!` macro.
///
/// This is a pure scroll model - it does NOT store ephemeral gesture/pointer state.
/// Gesture state is managed locally in the scroll modifier.
///
/// Cloning produces another handle to the same underlying state (the data lives
/// behind an `Rc`), so all clones observe the same scroll position.
#[derive(Clone)]
pub struct ScrollState {
    /// Shared, reference-counted storage; every clone points at the same instance.
    inner: Rc<ScrollStateInner>,
}
40
/// Shared storage behind a `ScrollState` handle.
pub(crate) struct ScrollStateInner {
    /// Unique ID for debugging, taken from `NEXT_SCROLL_STATE_ID` at construction.
    id: u64,
    /// Current scroll offset in pixels.
    /// Stored in an `OwnedMutableState<f32>` for reactivity - composables can observe this value.
    /// Layout reads use `get_non_reactive()` to avoid triggering recomposition.
    value: OwnedMutableState<f32>,
    /// Maximum scroll value (content_size - viewport_size); starts at 0.0.
    /// Using RefCell instead of MutableState to avoid snapshot isolation issues
    max_value: RefCell<f32>,
    /// Callbacks to invalidate layout when scroll value changes, keyed by the ID
    /// returned from `add_invalidate_callback`.
    /// Using HashMap to allow multiple listeners (e.g. real node + clones)
    invalidate_callbacks: RefCell<std::collections::HashMap<u64, Box<dyn Fn()>>>,
    /// Tracks whether we need to invalidate once a callback is registered.
    /// Set when the position changes while no callback is registered.
    pending_invalidation: Cell<bool>,
}
57
58impl ScrollState {
59    /// Creates a new ScrollState with the given initial scroll position.
60    pub fn new(initial: f32) -> Self {
61        let id = NEXT_SCROLL_STATE_ID.fetch_add(1, Ordering::Relaxed);
62
63        Self {
64            inner: Rc::new(ScrollStateInner {
65                id,
66                value: ownedMutableStateOf(initial),
67                max_value: RefCell::new(0.0),
68                invalidate_callbacks: RefCell::new(std::collections::HashMap::new()),
69                pending_invalidation: Cell::new(false),
70            }),
71        }
72    }
73
74    /// Get the unique ID of this ScrollState
75    pub fn id(&self) -> u64 {
76        self.inner.id
77    }
78
79    /// Gets the current scroll position in pixels (reactive - triggers recomposition).
80    ///
81    /// Use this in Composable functions when you want UI to update on scroll.
82    /// Example: `Text("Scroll position: ${scrollState.value()}")`
83    pub fn value(&self) -> f32 {
84        self.inner.value.with(|v| *v)
85    }
86
87    /// Gets the current scroll position in pixels (non-reactive).
88    ///
89    /// Use this in layout/measure phase to avoid triggering recomposition.
90    /// This is called internally by ScrollNode::measure().
91    pub fn value_non_reactive(&self) -> f32 {
92        self.inner.value.get_non_reactive()
93    }
94
95    /// Gets the maximum scroll value.
96    pub fn max_value(&self) -> f32 {
97        *self.inner.max_value.borrow()
98    }
99
100    /// Scrolls by the given delta, clamping to valid range [0, max_value].
101    /// Returns the actual amount scrolled.
102    pub fn dispatch_raw_delta(&self, delta: f32) -> f32 {
103        let current = self.value();
104        let max = self.max_value();
105        let new_value = (current + delta).clamp(0.0, max);
106        let actual_delta = new_value - current;
107
108        if actual_delta.abs() > 0.001 {
109            // Use MutableState::set which triggers snapshot observers for reactive updates
110            self.inner.value.set(new_value);
111
112            // Trigger layout invalidation callbacks
113            let callbacks = self.inner.invalidate_callbacks.borrow();
114            if callbacks.is_empty() {
115                // Defer invalidation until a node registers a callback.
116                self.inner.pending_invalidation.set(true);
117            } else {
118                for callback in callbacks.values() {
119                    callback();
120                }
121            }
122        }
123
124        actual_delta
125    }
126
127    /// Sets the maximum scroll value (internal use by ScrollNode).
128    pub(crate) fn set_max_value(&self, max: f32) {
129        *self.inner.max_value.borrow_mut() = max;
130    }
131
132    /// Scrolls to the given position immediately.
133    pub fn scroll_to(&self, position: f32) {
134        let max = self.max_value();
135        let clamped = position.clamp(0.0, max);
136
137        self.inner.value.set(clamped);
138
139        // Trigger layout invalidation callbacks
140        let callbacks = self.inner.invalidate_callbacks.borrow();
141        if callbacks.is_empty() {
142            self.inner.pending_invalidation.set(true);
143        } else {
144            for callback in callbacks.values() {
145                callback();
146            }
147        }
148    }
149
150    /// Adds an invalidation callback and returns its ID
151    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
152        static NEXT_CALLBACK_ID: std::sync::atomic::AtomicU64 =
153            std::sync::atomic::AtomicU64::new(1);
154        let id = NEXT_CALLBACK_ID.fetch_add(1, Ordering::Relaxed);
155        self.inner
156            .invalidate_callbacks
157            .borrow_mut()
158            .insert(id, callback);
159        if self.inner.pending_invalidation.replace(false) {
160            if let Some(callback) = self.inner.invalidate_callbacks.borrow().get(&id) {
161                callback();
162            }
163        }
164        id
165    }
166
167    /// Removes an invalidation callback by ID
168    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
169        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
170    }
171}
172
/// Cloneable handle to shared scroll-motion activity state.
///
/// Clones share one `ScrollMotionContextInner` through an `Rc`, so toggling the
/// active flag through any clone is visible through all of them.
#[derive(Clone)]
pub(crate) struct ScrollMotionContext {
    /// Shared, reference-counted storage for the activity flag and its listeners.
    inner: Rc<ScrollMotionContextInner>,
}
177
/// Key under which a `ScrollMotionContext` is stored in the thread-local
/// `SCROLL_MOTION_CONTEXTS` registry.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum ScrollMotionContextKey {
    /// Context associated with a scroll-state based scroller.
    ScrollState {
        // NOTE(review): presumably the value of `ScrollState::id()` - confirm against callers.
        state_id: u64,
        // NOTE(review): presumably mirrors the scroll modifier's orientation - confirm.
        is_vertical: bool,
        // NOTE(review): presumably mirrors the scroll modifier's reverse flag - confirm.
        reverse_scrolling: bool,
    },
    /// Context associated with a lazy list.
    LazyList {
        // NOTE(review): looks like a pointer-derived identity of the list state - confirm.
        state_identity: usize,
        // NOTE(review): presumably the list's scroll orientation - confirm.
        is_vertical: bool,
        // NOTE(review): presumably the list's reverse flag - confirm.
        reverse_scrolling: bool,
    },
}
191
/// Shared storage behind a `ScrollMotionContext` handle.
struct ScrollMotionContextInner {
    /// Whether scroll motion is currently considered active.
    active: Cell<bool>,
    /// Counter bumped on every activation and on every change of `active`;
    /// scheduled clears compare against it to detect that they were superseded.
    generation: Cell<u64>,
    /// Listeners notified when `active` changes, keyed by the ID returned on registration.
    invalidate_callbacks: RefCell<std::collections::HashMap<u64, Box<dyn Fn()>>>,
    /// Set when an invalidation happened with no listener registered; consumed by
    /// the next listener registration.
    pending_invalidation: Cell<bool>,
}
198
199thread_local! {
200    static SCROLL_MOTION_CONTEXTS: RefCell<HashMap<ScrollMotionContextKey, Weak<ScrollMotionContextInner>>> =
201        RefCell::new(HashMap::new());
202}
203
204pub(crate) fn scroll_motion_context_for_key(key: ScrollMotionContextKey) -> ScrollMotionContext {
205    SCROLL_MOTION_CONTEXTS.with(|contexts| {
206        let mut contexts = contexts.borrow_mut();
207        if let Some(inner) = contexts.get(&key).and_then(Weak::upgrade) {
208            return ScrollMotionContext { inner };
209        }
210
211        let context = ScrollMotionContext::new();
212        contexts.insert(key, Rc::downgrade(&context.inner));
213        contexts.retain(|_, weak| weak.strong_count() > 0);
214        context
215    })
216}
217
218impl ScrollMotionContext {
219    pub(crate) fn new() -> Self {
220        Self {
221            inner: Rc::new(ScrollMotionContextInner {
222                active: Cell::new(false),
223                generation: Cell::new(0),
224                invalidate_callbacks: RefCell::new(std::collections::HashMap::new()),
225                pending_invalidation: Cell::new(false),
226            }),
227        }
228    }
229
230    pub(crate) fn is_active(&self) -> bool {
231        self.inner.active.get()
232    }
233
234    pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
235        Rc::ptr_eq(&self.inner, &other.inner)
236    }
237
238    pub(crate) fn stable_key(&self) -> usize {
239        Rc::as_ptr(&self.inner) as usize
240    }
241
242    pub(crate) fn set_active(&self, active: bool) {
243        if self.inner.active.replace(active) != active {
244            self.bump_generation();
245            self.invalidate();
246        }
247    }
248
249    pub(crate) fn activate_for_next_frame(&self) {
250        let was_active = self.inner.active.replace(true);
251        let generation = self.bump_generation();
252        if !was_active {
253            self.invalidate();
254        }
255        if let Some(runtime) = current_runtime_handle() {
256            self.schedule_clear_after_frames(runtime, generation, SCROLL_MOTION_ACTIVE_FRAME_COUNT);
257        } else {
258            self.clear_if_generation(generation);
259        }
260    }
261
262    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
263        static NEXT_CALLBACK_ID: AtomicU64 = AtomicU64::new(1);
264        let id = NEXT_CALLBACK_ID.fetch_add(1, Ordering::Relaxed);
265        self.inner
266            .invalidate_callbacks
267            .borrow_mut()
268            .insert(id, callback);
269        if self.inner.pending_invalidation.replace(false) {
270            if let Some(callback) = self.inner.invalidate_callbacks.borrow().get(&id) {
271                callback();
272            }
273        }
274        id
275    }
276
277    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
278        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
279    }
280
281    fn bump_generation(&self) -> u64 {
282        let next = self.inner.generation.get().wrapping_add(1);
283        self.inner.generation.set(next);
284        next
285    }
286
287    fn clear_if_generation(&self, generation: u64) {
288        if self.inner.generation.get() == generation {
289            self.set_active(false);
290        }
291    }
292
293    fn schedule_clear_after_frames(
294        &self,
295        runtime: RuntimeHandle,
296        generation: u64,
297        frames_remaining: u8,
298    ) {
299        let state = self.clone();
300        let runtime_for_next = runtime.clone();
301        let _ = runtime.register_frame_callback(move |_| {
302            if state.inner.generation.get() != generation {
303                return;
304            }
305            if frames_remaining <= 1 {
306                state.clear_if_generation(generation);
307            } else {
308                state.schedule_clear_after_frames(
309                    runtime_for_next,
310                    generation,
311                    frames_remaining - 1,
312                );
313            }
314        });
315        runtime.schedule();
316    }
317
318    fn invalidate(&self) {
319        let callbacks = self.inner.invalidate_callbacks.borrow();
320        if callbacks.is_empty() {
321            self.inner.pending_invalidation.set(true);
322            return;
323        }
324        for callback in callbacks.values() {
325            callback();
326        }
327    }
328}
329
/// Element for creating a ScrollNode.
///
/// Carries the configuration that `ScrollNode::new` receives when the node is created.
#[derive(Clone)]
pub struct ScrollElement {
    /// Scroll state the created node reads its offset from.
    state: ScrollState,
    /// `true` to scroll along the vertical axis, `false` for the horizontal axis.
    is_vertical: bool,
    /// When `true`, the node derives its placement offset as `scroll - max_scroll`
    /// instead of `-scroll`.
    reverse_scrolling: bool,
}
337
338impl ScrollElement {
339    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
340        Self {
341            state,
342            is_vertical,
343            reverse_scrolling,
344        }
345    }
346}
347
348impl std::fmt::Debug for ScrollElement {
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        f.debug_struct("ScrollElement")
351            .field("is_vertical", &self.is_vertical)
352            .field("reverse_scrolling", &self.reverse_scrolling)
353            .finish()
354    }
355}
356
357impl PartialEq for ScrollElement {
358    fn eq(&self, other: &Self) -> bool {
359        // ScrollStates are equal if they point to the same underlying state
360        Rc::ptr_eq(&self.state.inner, &other.state.inner)
361            && self.is_vertical == other.is_vertical
362            && self.reverse_scrolling == other.reverse_scrolling
363    }
364}
365
366impl Eq for ScrollElement {}
367
368impl Hash for ScrollElement {
369    fn hash<H: Hasher>(&self, state: &mut H) {
370        (Rc::as_ptr(&self.state.inner) as usize).hash(state);
371        self.is_vertical.hash(state);
372        self.reverse_scrolling.hash(state);
373    }
374}
375
376impl ModifierNodeElement for ScrollElement {
377    type Node = ScrollNode;
378
379    fn create(&self) -> Self::Node {
380        // println!("ScrollElement::create");
381        ScrollNode::new(self.state.clone(), self.is_vertical, self.reverse_scrolling)
382    }
383
384    fn key(&self) -> Option<u64> {
385        let mut hasher = DefaultHasher::new();
386        self.state.id().hash(&mut hasher);
387        self.reverse_scrolling.hash(&mut hasher);
388        self.is_vertical.hash(&mut hasher);
389        Some(hasher.finish())
390    }
391
392    fn update(&self, node: &mut Self::Node) {
393        let needs_invalidation = !Rc::ptr_eq(&node.state.inner, &self.state.inner)
394            || node.is_vertical != self.is_vertical
395            || node.reverse_scrolling != self.reverse_scrolling;
396
397        if needs_invalidation {
398            node.state = self.state.clone();
399            node.is_vertical = self.is_vertical;
400            node.reverse_scrolling = self.reverse_scrolling;
401        }
402    }
403
404    fn capabilities(&self) -> NodeCapabilities {
405        NodeCapabilities::LAYOUT
406    }
407}
408
409/// ScrollNode layout modifier that physically moves content based on scroll position.
410/// This is the component that actually reads ScrollState and applies the visual offset.
411pub struct ScrollNode {
412    state: ScrollState,
413    is_vertical: bool,
414    reverse_scrolling: bool,
415    node_state: NodeState,
416    /// ID of the invalidation callback registered with ScrollState
417    invalidation_callback_id: Option<u64>,
418    /// We capture the NodeId when attached to ensure correct invalidation scope
419    node_id: Option<NodeId>,
420}
421
422impl ScrollNode {
423    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
424        Self {
425            state,
426            is_vertical,
427            reverse_scrolling,
428            node_state: NodeState::default(),
429            invalidation_callback_id: None,
430            node_id: None,
431        }
432    }
433
434    /// Returns a reference to the ScrollState.
435    pub fn state(&self) -> &ScrollState {
436        &self.state
437    }
438}
439
440impl DelegatableNode for ScrollNode {
441    fn node_state(&self) -> &NodeState {
442        &self.node_state
443    }
444}
445
446impl ModifierNode for ScrollNode {
447    fn on_attach(&mut self, context: &mut dyn ModifierNodeContext) {
448        // Set up the invalidation callback to trigger layout when scroll state changes.
449        // We capture the node_id directly from the context, avoiding any global registry.
450
451        let node_id = context.node_id();
452        self.node_id = node_id;
453
454        if let Some(node_id) = node_id {
455            let callback_id = self.state.add_invalidate_callback(Box::new(move || {
456                // Schedule scoped layout repass for this node
457                crate::schedule_layout_repass(node_id);
458            }));
459            self.invalidation_callback_id = Some(callback_id);
460        } else {
461            log::debug!(
462                "ScrollNode attached without a NodeId; deferring invalidation registration."
463            );
464        }
465
466        // Initial invalidation
467        context.invalidate(cranpose_foundation::InvalidationKind::Layout);
468    }
469
470    fn on_detach(&mut self) {
471        // Remove invalidation callback
472        if let Some(id) = self.invalidation_callback_id.take() {
473            self.state.remove_invalidate_callback(id);
474        }
475    }
476
477    fn as_layout_node(&self) -> Option<&dyn LayoutModifierNode> {
478        Some(self)
479    }
480
481    fn as_layout_node_mut(&mut self) -> Option<&mut dyn LayoutModifierNode> {
482        Some(self)
483    }
484}
485
486impl LayoutModifierNode for ScrollNode {
487    fn measure(
488        &self,
489        _context: &mut dyn ModifierNodeContext,
490        measurable: &dyn Measurable,
491        constraints: Constraints,
492    ) -> LayoutModifierMeasureResult {
493        // Step 1: Give child infinite space in scroll direction
494        let scroll_constraints = if self.is_vertical {
495            Constraints {
496                min_height: 0.0,
497                max_height: f32::INFINITY,
498                ..constraints
499            }
500        } else {
501            Constraints {
502                min_width: 0.0,
503                max_width: f32::INFINITY,
504                ..constraints
505            }
506        };
507
508        // Step 2: Measure child
509        let placeable = measurable.measure(scroll_constraints);
510
511        // Step 3: Calculate viewport size (constrained size)
512        let width = placeable.width().min(constraints.max_width);
513        let height = placeable.height().min(constraints.max_height);
514
515        // Step 4: Calculate max scroll
516        let max_scroll = if self.is_vertical {
517            (placeable.height() - height).max(0.0)
518        } else {
519            (placeable.width() - width).max(0.0)
520        };
521
522        // Step 5: Update state with max scroll value
523        // Only update if the viewport is constrained (not infinite probe)
524        if (self.is_vertical && constraints.max_height.is_finite())
525            || (!self.is_vertical && constraints.max_width.is_finite())
526        {
527            self.state.set_max_value(max_scroll);
528        }
529
530        // Step 6: Read scroll value and calculate offset
531        // IMPORTANT: Use value_non_reactive() during measure to avoid triggering recomposition
532        let scroll = self.state.value_non_reactive().clamp(0.0, max_scroll);
533
534        let abs_scroll = if self.reverse_scrolling {
535            scroll - max_scroll
536        } else {
537            -scroll
538        };
539
540        let (x_offset, y_offset) = if self.is_vertical {
541            (0.0, abs_scroll)
542        } else {
543            (abs_scroll, 0.0)
544        };
545
546        // Step 7: Return result with viewport size and scroll offset as placement_offset
547        // This makes the scroll offset part of the layout modifier's placement, which will be
548        // correctly applied to children by the layout system
549        LayoutModifierMeasureResult::new(Size { width, height }, x_offset, y_offset)
550    }
551
552    fn min_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
553        measurable.min_intrinsic_width(height)
554    }
555
556    fn max_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
557        measurable.max_intrinsic_width(height)
558    }
559
560    fn min_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
561        measurable.min_intrinsic_height(width)
562    }
563
564    fn max_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
565        measurable.max_intrinsic_height(width)
566    }
567
568    fn create_measurement_proxy(&self) -> Option<Box<dyn cranpose_foundation::MeasurementProxy>> {
569        None
570    }
571}
572
/// Creates a remembered ScrollState.
///
/// This is a convenience macro for use in composable functions. Call it with
/// an initial scroll position (`rememberScrollState!(120.0)`) or with no
/// arguments to start at `0.0`.
///
/// The zero-argument form recurses through `$crate::` so it also resolves when
/// the macro is invoked by path from another crate without being imported.
#[macro_export]
macro_rules! rememberScrollState {
    ($initial:expr) => {
        cranpose_core::remember(|| $crate::scroll::ScrollState::new($initial))
            .with(|state| state.clone())
    };
    () => {
        $crate::rememberScrollState!(0.0)
    };
}
586
587#[cfg(test)]
588#[path = "tests/scroll_tests.rs"]
589mod tests;