Skip to main content

cranpose_ui/
scroll.rs

1//! Scroll state and node implementation for cranpose.
2//!
3//! This module provides the core scrolling components:
4//! - `ScrollState`: Holds scroll position and provides scroll control methods
5//! - `ScrollNode`: Layout modifier that applies scroll offset to content
6//! - `ScrollElement`: Element for creating ScrollNode instances
7//!
8//! The actual `Modifier.horizontal_scroll()` and `Modifier.vertical_scroll()`
9//! extension methods are defined in `modifier/scroll.rs`.
10
11use cranpose_core::{ownedMutableStateOf, NodeId, OwnedMutableState};
12use cranpose_foundation::{
13    Constraints, DelegatableNode, LayoutModifierNode, Measurable, ModifierNode,
14    ModifierNodeContext, ModifierNodeElement, NodeCapabilities, NodeState,
15};
16use cranpose_ui_graphics::Size;
17use cranpose_ui_layout::LayoutModifierMeasureResult;
18use std::cell::{Cell, RefCell};
19use std::collections::HashMap;
20use std::hash::{DefaultHasher, Hash, Hasher};
21use std::rc::{Rc, Weak};
22
23/// State object for scroll position tracking.
24///
25/// Holds the current scroll offset and provides methods to programmatically
26/// control scrolling. Can be created with `rememberScrollState()`.
27///
28/// This is a pure scroll model - it does NOT store ephemeral gesture/pointer state.
29/// Gesture state is managed locally in the scroll modifier.
30#[derive(Clone)]
31pub struct ScrollState {
32    inner: Rc<ScrollStateInner>,
33}
34
35pub(crate) struct ScrollStateInner {
36    /// Current scroll offset in pixels.
37    /// Uses MutableState<f32> for reactivity - Composables can observe this value.
38    /// Layout reads use get_non_reactive() to avoid triggering recomposition.
39    value: OwnedMutableState<f32>,
40    /// Maximum scroll value (content_size - viewport_size)
41    /// Using RefCell instead of MutableState to avoid snapshot isolation issues
42    max_value: RefCell<f32>,
43    /// Callbacks to invalidate layout when scroll value changes
44    /// Using HashMap to allow multiple listeners (e.g. real node + clones)
45    invalidate_callbacks: RefCell<HashMap<u64, Rc<dyn Fn()>>>,
46    next_invalidate_callback_id: Cell<u64>,
47    /// Tracks whether we need to invalidate once a callback is registered.
48    pending_invalidation: Cell<bool>,
49}
50
51impl ScrollState {
52    /// Creates a new ScrollState with the given initial scroll position.
53    pub fn new(initial: f32) -> Self {
54        Self {
55            inner: Rc::new(ScrollStateInner {
56                value: ownedMutableStateOf(initial),
57                max_value: RefCell::new(0.0),
58                invalidate_callbacks: RefCell::new(HashMap::new()),
59                next_invalidate_callback_id: Cell::new(1),
60                pending_invalidation: Cell::new(false),
61            }),
62        }
63    }
64
65    /// Get the unique ID of this ScrollState
66    pub fn id(&self) -> u64 {
67        Rc::as_ptr(&self.inner) as usize as u64
68    }
69
70    /// Gets the current scroll position in pixels (reactive - triggers recomposition).
71    ///
72    /// Use this in Composable functions when you want UI to update on scroll.
73    /// Example: `Text("Scroll position: ${scrollState.value()}")`
74    pub fn value(&self) -> f32 {
75        self.inner.value.with(|v| *v)
76    }
77
78    /// Gets the current scroll position in pixels (non-reactive).
79    ///
80    /// Use this in layout/measure phase to avoid triggering recomposition.
81    /// This is called internally by ScrollNode::measure().
82    pub fn value_non_reactive(&self) -> f32 {
83        self.inner.value.get_non_reactive()
84    }
85
86    /// Gets the maximum scroll value.
87    pub fn max_value(&self) -> f32 {
88        *self.inner.max_value.borrow()
89    }
90
91    /// Scrolls by the given delta, clamping to valid range [0, max_value].
92    /// Returns the actual amount scrolled.
93    pub fn dispatch_raw_delta(&self, delta: f32) -> f32 {
94        let current = self.value();
95        let max = self.max_value();
96        let new_value = (current + delta).clamp(0.0, max);
97        let actual_delta = new_value - current;
98
99        if actual_delta.abs() > 0.001 {
100            // Use MutableState::set which triggers snapshot observers for reactive updates
101            self.inner.value.set(new_value);
102
103            self.invalidate();
104        }
105
106        actual_delta
107    }
108
109    /// Sets the maximum scroll value (internal use by ScrollNode).
110    pub(crate) fn set_max_value(&self, max: f32) {
111        *self.inner.max_value.borrow_mut() = max;
112    }
113
114    /// Scrolls to the given position immediately.
115    pub fn scroll_to(&self, position: f32) {
116        let max = self.max_value();
117        let clamped = position.clamp(0.0, max);
118
119        self.inner.value.set(clamped);
120
121        self.invalidate();
122    }
123
124    /// Adds an invalidation callback and returns its ID
125    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
126        let id = self.inner.next_invalidate_callback_id.get();
127        self.inner
128            .next_invalidate_callback_id
129            .set(id.saturating_add(1));
130        let callback: Rc<dyn Fn()> = Rc::from(callback);
131        self.inner
132            .invalidate_callbacks
133            .borrow_mut()
134            .insert(id, Rc::clone(&callback));
135        if self.inner.pending_invalidation.replace(false) {
136            callback();
137        }
138        id
139    }
140
141    /// Removes an invalidation callback by ID
142    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
143        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
144    }
145
146    fn invalidate(&self) {
147        let callbacks: Vec<Rc<dyn Fn()>> = {
148            let callbacks = self.inner.invalidate_callbacks.borrow();
149            if callbacks.is_empty() {
150                self.inner.pending_invalidation.set(true);
151                return;
152            }
153            callbacks.values().cloned().collect()
154        };
155        for callback in callbacks {
156            callback();
157        }
158    }
159}
160
161#[derive(Clone)]
162pub(crate) struct ScrollMotionContext {
163    inner: Rc<ScrollMotionContextInner>,
164}
165
/// Lookup key under which a `ScrollMotionContext` is registered in the
/// `ScrollMotionContextStore`.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum ScrollMotionContextKey {
    /// Context belonging to a scroll container driven by a `ScrollState`.
    ScrollState {
        /// Identifier of the scroll state (presumably `ScrollState::id()` — confirm at call sites).
        state_id: u64,
        /// Whether the container scrolls along the vertical axis.
        is_vertical: bool,
        /// Whether the scroll direction is reversed.
        reverse_scrolling: bool,
    },
    /// Context belonging to a lazy list.
    LazyList {
        /// Identity of the lazy list state (looks like a pointer-derived value — confirm at call sites).
        state_identity: usize,
        /// Whether the list scrolls along the vertical axis.
        is_vertical: bool,
        /// Whether the scroll direction is reversed.
        reverse_scrolling: bool,
    },
}
179
/// Shared record behind every clone of a `ScrollMotionContext`.
struct ScrollMotionContextInner {
    /// Persistent activity flag, written by `set_active`.
    active: Cell<bool>,
    /// Per-frame activity flag, raised by `activate_for_current_frame` and
    /// lowered by `clear_transient_after_frame` (or by `set_active(false)`).
    transient_active: Cell<bool>,
    /// Wrapping counter bumped by the activity-changing methods.
    generation: Cell<u64>,
    /// Invalidation listeners, keyed by the ID handed out on registration.
    invalidate_callbacks: RefCell<HashMap<u64, Rc<dyn Fn()>>>,
    /// ID that will be assigned to the next registered listener.
    next_invalidate_callback_id: Cell<u64>,
    /// Set when an invalidation was requested while no listener was registered.
    pending_invalidation: Cell<bool>,
}
188
189pub(crate) struct ScrollMotionContextStore {
190    contexts: RefCell<HashMap<ScrollMotionContextKey, Weak<ScrollMotionContextInner>>>,
191}
192
193impl ScrollMotionContextStore {
194    pub(crate) fn new() -> Self {
195        Self {
196            contexts: RefCell::new(HashMap::new()),
197        }
198    }
199
200    fn context_for_key(&self, key: ScrollMotionContextKey) -> ScrollMotionContext {
201        let mut contexts = self.contexts.borrow_mut();
202        if let Some(inner) = contexts.get(&key).and_then(Weak::upgrade) {
203            return ScrollMotionContext { inner };
204        }
205
206        let context = ScrollMotionContext::new();
207        contexts.insert(key, Rc::downgrade(&context.inner));
208        contexts.retain(|_, weak| weak.strong_count() > 0);
209        context
210    }
211
212    pub(crate) fn clear_transient_after_frame(&self) {
213        let contexts = {
214            let mut contexts = self.contexts.borrow_mut();
215            let live = contexts
216                .values()
217                .filter_map(Weak::upgrade)
218                .collect::<Vec<_>>();
219            contexts.retain(|_, weak| weak.strong_count() > 0);
220            live
221        };
222        for inner in contexts {
223            ScrollMotionContext { inner }.clear_transient_after_frame();
224        }
225    }
226}
227
228pub(crate) fn scroll_motion_context_for_key(key: ScrollMotionContextKey) -> ScrollMotionContext {
229    crate::render_state::with_scroll_motion_context_store(|store| store.context_for_key(key))
230}
231
232impl ScrollMotionContext {
233    pub(crate) fn new() -> Self {
234        Self {
235            inner: Rc::new(ScrollMotionContextInner {
236                active: Cell::new(false),
237                transient_active: Cell::new(false),
238                generation: Cell::new(0),
239                invalidate_callbacks: RefCell::new(HashMap::new()),
240                next_invalidate_callback_id: Cell::new(1),
241                pending_invalidation: Cell::new(false),
242            }),
243        }
244    }
245
246    pub(crate) fn is_active(&self) -> bool {
247        self.inner.active.get() || self.inner.transient_active.get()
248    }
249
250    pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
251        Rc::ptr_eq(&self.inner, &other.inner)
252    }
253
254    pub(crate) fn stable_key(&self) -> usize {
255        Rc::as_ptr(&self.inner) as usize
256    }
257
258    pub(crate) fn set_active(&self, active: bool) {
259        let was_active = self.is_active();
260        self.inner.active.set(active);
261        if !active {
262            self.inner.transient_active.set(false);
263        }
264        if was_active != self.is_active() {
265            self.bump_generation();
266            self.invalidate();
267        }
268    }
269
270    pub(crate) fn activate_for_current_frame(&self) {
271        let was_active = self.is_active();
272        self.inner.transient_active.set(true);
273        self.bump_generation();
274        if !was_active {
275            self.invalidate();
276        }
277    }
278
279    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
280        let id = self.inner.next_invalidate_callback_id.get();
281        self.inner
282            .next_invalidate_callback_id
283            .set(id.saturating_add(1));
284        let callback: Rc<dyn Fn()> = Rc::from(callback);
285        self.inner
286            .invalidate_callbacks
287            .borrow_mut()
288            .insert(id, Rc::clone(&callback));
289        if self.inner.pending_invalidation.replace(false) {
290            callback();
291        }
292        id
293    }
294
295    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
296        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
297    }
298
299    fn bump_generation(&self) -> u64 {
300        let next = self.inner.generation.get().wrapping_add(1);
301        self.inner.generation.set(next);
302        next
303    }
304
305    fn clear_transient_after_frame(&self) {
306        let was_active = self.is_active();
307        if self.inner.transient_active.replace(false) {
308            self.bump_generation();
309            if was_active != self.is_active() {
310                self.invalidate();
311            }
312        }
313    }
314
315    fn invalidate(&self) {
316        let callbacks: Vec<Rc<dyn Fn()>> = {
317            let callbacks = self.inner.invalidate_callbacks.borrow();
318            if callbacks.is_empty() {
319                self.inner.pending_invalidation.set(true);
320                return;
321            }
322            callbacks.values().cloned().collect()
323        };
324        for callback in callbacks {
325            callback();
326        }
327    }
328}
329
330/// Element for creating a ScrollNode.
331#[derive(Clone)]
332pub struct ScrollElement {
333    state: ScrollState,
334    is_vertical: bool,
335    reverse_scrolling: bool,
336}
337
338impl ScrollElement {
339    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
340        Self {
341            state,
342            is_vertical,
343            reverse_scrolling,
344        }
345    }
346}
347
348impl std::fmt::Debug for ScrollElement {
349    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350        f.debug_struct("ScrollElement")
351            .field("is_vertical", &self.is_vertical)
352            .field("reverse_scrolling", &self.reverse_scrolling)
353            .finish()
354    }
355}
356
357impl PartialEq for ScrollElement {
358    fn eq(&self, other: &Self) -> bool {
359        // ScrollStates are equal if they point to the same underlying state
360        Rc::ptr_eq(&self.state.inner, &other.state.inner)
361            && self.is_vertical == other.is_vertical
362            && self.reverse_scrolling == other.reverse_scrolling
363    }
364}
365
366impl Eq for ScrollElement {}
367
368impl Hash for ScrollElement {
369    fn hash<H: Hasher>(&self, state: &mut H) {
370        (Rc::as_ptr(&self.state.inner) as usize).hash(state);
371        self.is_vertical.hash(state);
372        self.reverse_scrolling.hash(state);
373    }
374}
375
376impl ModifierNodeElement for ScrollElement {
377    type Node = ScrollNode;
378
379    fn create(&self) -> Self::Node {
380        // println!("ScrollElement::create");
381        ScrollNode::new(self.state.clone(), self.is_vertical, self.reverse_scrolling)
382    }
383
384    fn key(&self) -> Option<u64> {
385        let mut hasher = DefaultHasher::new();
386        self.state.id().hash(&mut hasher);
387        self.reverse_scrolling.hash(&mut hasher);
388        self.is_vertical.hash(&mut hasher);
389        Some(hasher.finish())
390    }
391
392    fn update(&self, node: &mut Self::Node) {
393        let needs_invalidation = !Rc::ptr_eq(&node.state.inner, &self.state.inner)
394            || node.is_vertical != self.is_vertical
395            || node.reverse_scrolling != self.reverse_scrolling;
396
397        if needs_invalidation {
398            node.state = self.state.clone();
399            node.is_vertical = self.is_vertical;
400            node.reverse_scrolling = self.reverse_scrolling;
401        }
402    }
403
404    fn capabilities(&self) -> NodeCapabilities {
405        NodeCapabilities::LAYOUT
406    }
407}
408
409/// ScrollNode layout modifier that physically moves content based on scroll position.
410/// This is the component that actually reads ScrollState and applies the visual offset.
411pub struct ScrollNode {
412    state: ScrollState,
413    is_vertical: bool,
414    reverse_scrolling: bool,
415    node_state: NodeState,
416    /// ID of the invalidation callback registered with ScrollState
417    invalidation_callback_id: Option<u64>,
418    /// We capture the NodeId when attached to ensure correct invalidation scope
419    node_id: Option<NodeId>,
420}
421
422impl ScrollNode {
423    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
424        Self {
425            state,
426            is_vertical,
427            reverse_scrolling,
428            node_state: NodeState::default(),
429            invalidation_callback_id: None,
430            node_id: None,
431        }
432    }
433
434    /// Returns a reference to the ScrollState.
435    pub fn state(&self) -> &ScrollState {
436        &self.state
437    }
438}
439
440impl DelegatableNode for ScrollNode {
441    fn node_state(&self) -> &NodeState {
442        &self.node_state
443    }
444}
445
446impl ModifierNode for ScrollNode {
447    fn on_attach(&mut self, context: &mut dyn ModifierNodeContext) {
448        // Set up the invalidation callback to trigger layout when scroll state changes.
449        // We capture the node_id directly from the context, avoiding any global registry.
450
451        let node_id = context.node_id();
452        self.node_id = node_id;
453
454        if let Some(node_id) = node_id {
455            let callback_id = self.state.add_invalidate_callback(Box::new(move || {
456                // Schedule scoped layout repass for this node
457                crate::schedule_layout_repass(node_id);
458            }));
459            self.invalidation_callback_id = Some(callback_id);
460        } else {
461            log::debug!(
462                "ScrollNode attached without a NodeId; deferring invalidation registration."
463            );
464        }
465
466        // Initial invalidation
467        context.invalidate(cranpose_foundation::InvalidationKind::Layout);
468    }
469
470    fn on_detach(&mut self) {
471        // Remove invalidation callback
472        if let Some(id) = self.invalidation_callback_id.take() {
473            self.state.remove_invalidate_callback(id);
474        }
475    }
476
477    fn as_layout_node(&self) -> Option<&dyn LayoutModifierNode> {
478        Some(self)
479    }
480
481    fn as_layout_node_mut(&mut self) -> Option<&mut dyn LayoutModifierNode> {
482        Some(self)
483    }
484}
485
486impl LayoutModifierNode for ScrollNode {
487    fn measure(
488        &self,
489        _context: &mut dyn ModifierNodeContext,
490        measurable: &dyn Measurable,
491        constraints: Constraints,
492    ) -> LayoutModifierMeasureResult {
493        // Step 1: Give child infinite space in scroll direction
494        let scroll_constraints = if self.is_vertical {
495            Constraints {
496                min_height: 0.0,
497                max_height: f32::INFINITY,
498                ..constraints
499            }
500        } else {
501            Constraints {
502                min_width: 0.0,
503                max_width: f32::INFINITY,
504                ..constraints
505            }
506        };
507
508        // Step 2: Measure child
509        let placeable = measurable.measure(scroll_constraints);
510
511        // Step 3: Calculate viewport size (constrained size)
512        let width = placeable.width().min(constraints.max_width);
513        let height = placeable.height().min(constraints.max_height);
514
515        // Step 4: Calculate max scroll
516        let max_scroll = if self.is_vertical {
517            (placeable.height() - height).max(0.0)
518        } else {
519            (placeable.width() - width).max(0.0)
520        };
521
522        // Step 5: Update state with max scroll value
523        // Only update if the viewport is constrained (not infinite probe)
524        if (self.is_vertical && constraints.max_height.is_finite())
525            || (!self.is_vertical && constraints.max_width.is_finite())
526        {
527            self.state.set_max_value(max_scroll);
528        }
529
530        // Step 6: Read scroll value and calculate offset
531        // IMPORTANT: Use value_non_reactive() during measure to avoid triggering recomposition
532        let scroll = self.state.value_non_reactive().clamp(0.0, max_scroll);
533
534        let abs_scroll = if self.reverse_scrolling {
535            scroll - max_scroll
536        } else {
537            -scroll
538        };
539
540        let (x_offset, y_offset) = if self.is_vertical {
541            (0.0, abs_scroll)
542        } else {
543            (abs_scroll, 0.0)
544        };
545
546        // Step 7: Return result with viewport size and scroll offset as placement_offset
547        // This makes the scroll offset part of the layout modifier's placement, which will be
548        // correctly applied to children by the layout system
549        LayoutModifierMeasureResult::new(Size { width, height }, x_offset, y_offset)
550    }
551
552    fn min_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
553        measurable.min_intrinsic_width(height)
554    }
555
556    fn max_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
557        measurable.max_intrinsic_width(height)
558    }
559
560    fn min_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
561        measurable.min_intrinsic_height(width)
562    }
563
564    fn max_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
565        measurable.max_intrinsic_height(width)
566    }
567}
568
/// Creates a remembered ScrollState.
///
/// This is a convenience macro for use in composable functions.
///
/// - `rememberScrollState!(initial)` remembers a state starting at `initial`.
/// - `rememberScrollState!()` is shorthand for `rememberScrollState!(0.0)`.
///
/// The expansion evaluates to a clone of the remembered `ScrollState` handle.
#[macro_export]
macro_rules! rememberScrollState {
    ($initial:expr) => {
        cranpose_core::remember(|| $crate::scroll::ScrollState::new($initial))
            .with(|state| state.clone())
    };
    () => {
        // Recurse through `$crate` so the macro still resolves when it is
        // invoked by path from another crate without being imported.
        $crate::rememberScrollState!(0.0)
    };
}
582
583#[cfg(test)]
584#[path = "tests/scroll_tests.rs"]
585mod tests;