Skip to main content

cranpose_ui/
scroll.rs

1//! Scroll state and node implementation for cranpose.
2//!
3//! This module provides the core scrolling components:
4//! - `ScrollState`: Holds scroll position and provides scroll control methods
5//! - `ScrollNode`: Layout modifier that applies scroll offset to content
6//! - `ScrollElement`: Element for creating ScrollNode instances
7//!
8//! The actual `Modifier.horizontal_scroll()` and `Modifier.vertical_scroll()`
9//! extension methods are defined in `modifier/scroll.rs`.
10
11use cranpose_core::{
12    current_runtime_handle, ownedMutableStateOf, NodeId, OwnedMutableState, RuntimeHandle,
13};
14use cranpose_foundation::{
15    Constraints, DelegatableNode, LayoutModifierNode, Measurable, ModifierNode,
16    ModifierNodeContext, ModifierNodeElement, NodeCapabilities, NodeState,
17};
18use cranpose_ui_graphics::Size;
19use cranpose_ui_layout::LayoutModifierMeasureResult;
20use std::cell::{Cell, RefCell};
21use std::collections::HashMap;
22use std::hash::{DefaultHasher, Hash, Hasher};
23use std::rc::{Rc, Weak};
24
/// Initial frame countdown used by `ScrollMotionContext::activate_for_next_frame`:
/// it is passed as `frames_remaining` to `schedule_clear_after_frames`, so an
/// activation is cleared after this many frame callbacks unless a newer
/// generation supersedes it first.
const SCROLL_MOTION_ACTIVE_FRAME_COUNT: u8 = 6;
26
/// State object for scroll position tracking.
///
/// Holds the current scroll offset and provides methods to programmatically
/// control scrolling. Can be created with the `rememberScrollState!()` macro.
///
/// Cloning is cheap: every clone shares the same reference-counted inner state,
/// so a change made through one handle is visible through all of them.
///
/// This is a pure scroll model - it does NOT store ephemeral gesture/pointer state.
/// Gesture state is managed locally in the scroll modifier.
#[derive(Clone)]
pub struct ScrollState {
    /// Shared backing state; `id()` and `ScrollElement` equality are derived
    /// from the address of this allocation.
    inner: Rc<ScrollStateInner>,
}
38
/// Backing storage shared by every clone of a `ScrollState`.
pub(crate) struct ScrollStateInner {
    /// Current scroll offset in pixels.
    /// Stored in an `OwnedMutableState<f32>` for reactivity - Composables can observe this value.
    /// Layout reads use get_non_reactive() to avoid triggering recomposition.
    value: OwnedMutableState<f32>,
    /// Maximum scroll value (content_size - viewport_size)
    /// Using RefCell instead of MutableState to avoid snapshot isolation issues
    max_value: RefCell<f32>,
    /// Callbacks to invalidate layout when scroll value changes
    /// Using HashMap to allow multiple listeners (e.g. real node + clones)
    invalidate_callbacks: RefCell<HashMap<u64, Rc<dyn Fn()>>>,
    /// ID that the next `add_invalidate_callback` call will hand out (starts at 1).
    next_invalidate_callback_id: Cell<u64>,
    /// Tracks whether we need to invalidate once a callback is registered.
    pending_invalidation: Cell<bool>,
}
54
55impl ScrollState {
56    /// Creates a new ScrollState with the given initial scroll position.
57    pub fn new(initial: f32) -> Self {
58        Self {
59            inner: Rc::new(ScrollStateInner {
60                value: ownedMutableStateOf(initial),
61                max_value: RefCell::new(0.0),
62                invalidate_callbacks: RefCell::new(HashMap::new()),
63                next_invalidate_callback_id: Cell::new(1),
64                pending_invalidation: Cell::new(false),
65            }),
66        }
67    }
68
69    /// Get the unique ID of this ScrollState
70    pub fn id(&self) -> u64 {
71        Rc::as_ptr(&self.inner) as usize as u64
72    }
73
74    /// Gets the current scroll position in pixels (reactive - triggers recomposition).
75    ///
76    /// Use this in Composable functions when you want UI to update on scroll.
77    /// Example: `Text("Scroll position: ${scrollState.value()}")`
78    pub fn value(&self) -> f32 {
79        self.inner.value.with(|v| *v)
80    }
81
82    /// Gets the current scroll position in pixels (non-reactive).
83    ///
84    /// Use this in layout/measure phase to avoid triggering recomposition.
85    /// This is called internally by ScrollNode::measure().
86    pub fn value_non_reactive(&self) -> f32 {
87        self.inner.value.get_non_reactive()
88    }
89
90    /// Gets the maximum scroll value.
91    pub fn max_value(&self) -> f32 {
92        *self.inner.max_value.borrow()
93    }
94
95    /// Scrolls by the given delta, clamping to valid range [0, max_value].
96    /// Returns the actual amount scrolled.
97    pub fn dispatch_raw_delta(&self, delta: f32) -> f32 {
98        let current = self.value();
99        let max = self.max_value();
100        let new_value = (current + delta).clamp(0.0, max);
101        let actual_delta = new_value - current;
102
103        if actual_delta.abs() > 0.001 {
104            // Use MutableState::set which triggers snapshot observers for reactive updates
105            self.inner.value.set(new_value);
106
107            self.invalidate();
108        }
109
110        actual_delta
111    }
112
113    /// Sets the maximum scroll value (internal use by ScrollNode).
114    pub(crate) fn set_max_value(&self, max: f32) {
115        *self.inner.max_value.borrow_mut() = max;
116    }
117
118    /// Scrolls to the given position immediately.
119    pub fn scroll_to(&self, position: f32) {
120        let max = self.max_value();
121        let clamped = position.clamp(0.0, max);
122
123        self.inner.value.set(clamped);
124
125        self.invalidate();
126    }
127
128    /// Adds an invalidation callback and returns its ID
129    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
130        let id = self.inner.next_invalidate_callback_id.get();
131        self.inner
132            .next_invalidate_callback_id
133            .set(id.saturating_add(1));
134        let callback: Rc<dyn Fn()> = Rc::from(callback);
135        self.inner
136            .invalidate_callbacks
137            .borrow_mut()
138            .insert(id, Rc::clone(&callback));
139        if self.inner.pending_invalidation.replace(false) {
140            callback();
141        }
142        id
143    }
144
145    /// Removes an invalidation callback by ID
146    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
147        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
148    }
149
150    fn invalidate(&self) {
151        let callbacks: Vec<Rc<dyn Fn()>> = {
152            let callbacks = self.inner.invalidate_callbacks.borrow();
153            if callbacks.is_empty() {
154                self.inner.pending_invalidation.set(true);
155                return;
156            }
157            callbacks.values().cloned().collect()
158        };
159        for callback in callbacks {
160            callback();
161        }
162    }
163}
164
/// Cheaply clonable handle to shared scroll-motion state: an "active" flag, a
/// generation counter and a set of invalidation callbacks. All clones point at
/// the same inner value.
#[derive(Clone)]
pub(crate) struct ScrollMotionContext {
    /// Shared state; `ptr_eq` and `stable_key` are based on this pointer.
    inner: Rc<ScrollMotionContextInner>,
}
169
/// Lookup key used by `ScrollMotionContextStore`: callers that build an equal
/// key receive the same `ScrollMotionContext` for as long as it stays alive.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub(crate) enum ScrollMotionContextKey {
    /// Key for a scroll container identified by a `u64` state id.
    /// NOTE(review): `state_id` presumably comes from `ScrollState::id()` — confirm at call sites.
    ScrollState {
        state_id: u64,
        is_vertical: bool,
        reverse_scrolling: bool,
    },
    /// Key for a lazy list identified by a `usize` identity.
    /// NOTE(review): `state_identity` is presumably a pointer-derived identity of
    /// the lazy list state — confirm at call sites.
    LazyList {
        state_identity: usize,
        is_vertical: bool,
        reverse_scrolling: bool,
    },
}
183
/// Backing storage shared by every clone of a `ScrollMotionContext`.
struct ScrollMotionContextInner {
    /// Whether scroll motion is currently flagged as active.
    active: Cell<bool>,
    /// Counter bumped (with wrap-around) on every activation and on every change
    /// of `active`; scheduled frame callbacks compare against it to detect that
    /// they were superseded.
    generation: Cell<u64>,
    /// Callbacks run when `active` changes, keyed by the ID returned from
    /// `add_invalidate_callback`.
    invalidate_callbacks: RefCell<HashMap<u64, Rc<dyn Fn()>>>,
    /// ID that the next `add_invalidate_callback` call will hand out (starts at 1).
    next_invalidate_callback_id: Cell<u64>,
    /// Set when an invalidation occurred while no callback was registered; the
    /// next registered callback is then invoked immediately.
    pending_invalidation: Cell<bool>,
}
191
/// Registry of motion contexts by key.
///
/// Only `Weak` references are stored, so the store never keeps a context alive
/// on its own; an entry becomes dead once every `ScrollMotionContext` handle
/// for it has been dropped.
pub(crate) struct ScrollMotionContextStore {
    contexts: RefCell<HashMap<ScrollMotionContextKey, Weak<ScrollMotionContextInner>>>,
}
195
196impl ScrollMotionContextStore {
197    pub(crate) fn new() -> Self {
198        Self {
199            contexts: RefCell::new(HashMap::new()),
200        }
201    }
202
203    fn context_for_key(&self, key: ScrollMotionContextKey) -> ScrollMotionContext {
204        let mut contexts = self.contexts.borrow_mut();
205        if let Some(inner) = contexts.get(&key).and_then(Weak::upgrade) {
206            return ScrollMotionContext { inner };
207        }
208
209        let context = ScrollMotionContext::new();
210        contexts.insert(key, Rc::downgrade(&context.inner));
211        contexts.retain(|_, weak| weak.strong_count() > 0);
212        context
213    }
214}
215
216pub(crate) fn scroll_motion_context_for_key(key: ScrollMotionContextKey) -> ScrollMotionContext {
217    crate::render_state::with_scroll_motion_context_store(|store| store.context_for_key(key))
218}
219
220impl ScrollMotionContext {
221    pub(crate) fn new() -> Self {
222        Self {
223            inner: Rc::new(ScrollMotionContextInner {
224                active: Cell::new(false),
225                generation: Cell::new(0),
226                invalidate_callbacks: RefCell::new(HashMap::new()),
227                next_invalidate_callback_id: Cell::new(1),
228                pending_invalidation: Cell::new(false),
229            }),
230        }
231    }
232
233    pub(crate) fn is_active(&self) -> bool {
234        self.inner.active.get()
235    }
236
237    pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
238        Rc::ptr_eq(&self.inner, &other.inner)
239    }
240
241    pub(crate) fn stable_key(&self) -> usize {
242        Rc::as_ptr(&self.inner) as usize
243    }
244
245    pub(crate) fn set_active(&self, active: bool) {
246        if self.inner.active.replace(active) != active {
247            self.bump_generation();
248            self.invalidate();
249        }
250    }
251
252    pub(crate) fn activate_for_next_frame(&self) {
253        let was_active = self.inner.active.replace(true);
254        let generation = self.bump_generation();
255        if !was_active {
256            self.invalidate();
257        }
258        if let Some(runtime) = current_runtime_handle() {
259            self.schedule_clear_after_frames(runtime, generation, SCROLL_MOTION_ACTIVE_FRAME_COUNT);
260        } else {
261            self.clear_if_generation(generation);
262        }
263    }
264
265    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
266        let id = self.inner.next_invalidate_callback_id.get();
267        self.inner
268            .next_invalidate_callback_id
269            .set(id.saturating_add(1));
270        let callback: Rc<dyn Fn()> = Rc::from(callback);
271        self.inner
272            .invalidate_callbacks
273            .borrow_mut()
274            .insert(id, Rc::clone(&callback));
275        if self.inner.pending_invalidation.replace(false) {
276            callback();
277        }
278        id
279    }
280
281    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
282        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
283    }
284
285    fn bump_generation(&self) -> u64 {
286        let next = self.inner.generation.get().wrapping_add(1);
287        self.inner.generation.set(next);
288        next
289    }
290
291    fn clear_if_generation(&self, generation: u64) {
292        if self.inner.generation.get() == generation {
293            self.set_active(false);
294        }
295    }
296
297    fn schedule_clear_after_frames(
298        &self,
299        runtime: RuntimeHandle,
300        generation: u64,
301        frames_remaining: u8,
302    ) {
303        let state = self.clone();
304        let runtime_for_next = runtime.clone();
305        let _ = runtime.register_frame_callback(move |_| {
306            if state.inner.generation.get() != generation {
307                return;
308            }
309            if frames_remaining <= 1 {
310                state.clear_if_generation(generation);
311            } else {
312                state.schedule_clear_after_frames(
313                    runtime_for_next,
314                    generation,
315                    frames_remaining - 1,
316                );
317            }
318        });
319        runtime.schedule();
320    }
321
322    fn invalidate(&self) {
323        let callbacks: Vec<Rc<dyn Fn()>> = {
324            let callbacks = self.inner.invalidate_callbacks.borrow();
325            if callbacks.is_empty() {
326                self.inner.pending_invalidation.set(true);
327                return;
328            }
329            callbacks.values().cloned().collect()
330        };
331        for callback in callbacks {
332            callback();
333        }
334    }
335}
336
/// Element for creating a ScrollNode.
///
/// Carries everything `ScrollNode::new` needs: the shared scroll state plus the
/// axis and direction flags.
#[derive(Clone)]
pub struct ScrollElement {
    /// Scroll state handed to the created node (shared, not copied).
    state: ScrollState,
    /// `true` scrolls along the vertical axis, `false` along the horizontal one.
    is_vertical: bool,
    /// When `true`, `ScrollNode::measure` places content at `scroll - max_scroll`
    /// instead of `-scroll`.
    reverse_scrolling: bool,
}
344
345impl ScrollElement {
346    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
347        Self {
348            state,
349            is_vertical,
350            reverse_scrolling,
351        }
352    }
353}
354
355impl std::fmt::Debug for ScrollElement {
356    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357        f.debug_struct("ScrollElement")
358            .field("is_vertical", &self.is_vertical)
359            .field("reverse_scrolling", &self.reverse_scrolling)
360            .finish()
361    }
362}
363
364impl PartialEq for ScrollElement {
365    fn eq(&self, other: &Self) -> bool {
366        // ScrollStates are equal if they point to the same underlying state
367        Rc::ptr_eq(&self.state.inner, &other.state.inner)
368            && self.is_vertical == other.is_vertical
369            && self.reverse_scrolling == other.reverse_scrolling
370    }
371}
372
373impl Eq for ScrollElement {}
374
375impl Hash for ScrollElement {
376    fn hash<H: Hasher>(&self, state: &mut H) {
377        (Rc::as_ptr(&self.state.inner) as usize).hash(state);
378        self.is_vertical.hash(state);
379        self.reverse_scrolling.hash(state);
380    }
381}
382
383impl ModifierNodeElement for ScrollElement {
384    type Node = ScrollNode;
385
386    fn create(&self) -> Self::Node {
387        // println!("ScrollElement::create");
388        ScrollNode::new(self.state.clone(), self.is_vertical, self.reverse_scrolling)
389    }
390
391    fn key(&self) -> Option<u64> {
392        let mut hasher = DefaultHasher::new();
393        self.state.id().hash(&mut hasher);
394        self.reverse_scrolling.hash(&mut hasher);
395        self.is_vertical.hash(&mut hasher);
396        Some(hasher.finish())
397    }
398
399    fn update(&self, node: &mut Self::Node) {
400        let needs_invalidation = !Rc::ptr_eq(&node.state.inner, &self.state.inner)
401            || node.is_vertical != self.is_vertical
402            || node.reverse_scrolling != self.reverse_scrolling;
403
404        if needs_invalidation {
405            node.state = self.state.clone();
406            node.is_vertical = self.is_vertical;
407            node.reverse_scrolling = self.reverse_scrolling;
408        }
409    }
410
411    fn capabilities(&self) -> NodeCapabilities {
412        NodeCapabilities::LAYOUT
413    }
414}
415
/// ScrollNode layout modifier that physically moves content based on scroll position.
/// This is the component that actually reads ScrollState and applies the visual offset.
pub struct ScrollNode {
    /// Scroll state read during measure and updated with the computed maximum.
    state: ScrollState,
    /// `true` scrolls along the vertical axis, `false` along the horizontal one.
    is_vertical: bool,
    /// When `true`, content is placed at `scroll - max_scroll` instead of `-scroll`.
    reverse_scrolling: bool,
    /// Node bookkeeping exposed through `DelegatableNode::node_state`.
    node_state: NodeState,
    /// ID of the invalidation callback registered with ScrollState
    /// (`None` while no callback is registered).
    invalidation_callback_id: Option<u64>,
    /// We capture the NodeId when attached to ensure correct invalidation scope.
    /// It is written in `on_attach` and is not cleared by `on_detach`.
    node_id: Option<NodeId>,
}
428
429impl ScrollNode {
430    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
431        Self {
432            state,
433            is_vertical,
434            reverse_scrolling,
435            node_state: NodeState::default(),
436            invalidation_callback_id: None,
437            node_id: None,
438        }
439    }
440
441    /// Returns a reference to the ScrollState.
442    pub fn state(&self) -> &ScrollState {
443        &self.state
444    }
445}
446
447impl DelegatableNode for ScrollNode {
448    fn node_state(&self) -> &NodeState {
449        &self.node_state
450    }
451}
452
453impl ModifierNode for ScrollNode {
454    fn on_attach(&mut self, context: &mut dyn ModifierNodeContext) {
455        // Set up the invalidation callback to trigger layout when scroll state changes.
456        // We capture the node_id directly from the context, avoiding any global registry.
457
458        let node_id = context.node_id();
459        self.node_id = node_id;
460
461        if let Some(node_id) = node_id {
462            let callback_id = self.state.add_invalidate_callback(Box::new(move || {
463                // Schedule scoped layout repass for this node
464                crate::schedule_layout_repass(node_id);
465            }));
466            self.invalidation_callback_id = Some(callback_id);
467        } else {
468            log::debug!(
469                "ScrollNode attached without a NodeId; deferring invalidation registration."
470            );
471        }
472
473        // Initial invalidation
474        context.invalidate(cranpose_foundation::InvalidationKind::Layout);
475    }
476
477    fn on_detach(&mut self) {
478        // Remove invalidation callback
479        if let Some(id) = self.invalidation_callback_id.take() {
480            self.state.remove_invalidate_callback(id);
481        }
482    }
483
484    fn as_layout_node(&self) -> Option<&dyn LayoutModifierNode> {
485        Some(self)
486    }
487
488    fn as_layout_node_mut(&mut self) -> Option<&mut dyn LayoutModifierNode> {
489        Some(self)
490    }
491}
492
493impl LayoutModifierNode for ScrollNode {
494    fn measure(
495        &self,
496        _context: &mut dyn ModifierNodeContext,
497        measurable: &dyn Measurable,
498        constraints: Constraints,
499    ) -> LayoutModifierMeasureResult {
500        // Step 1: Give child infinite space in scroll direction
501        let scroll_constraints = if self.is_vertical {
502            Constraints {
503                min_height: 0.0,
504                max_height: f32::INFINITY,
505                ..constraints
506            }
507        } else {
508            Constraints {
509                min_width: 0.0,
510                max_width: f32::INFINITY,
511                ..constraints
512            }
513        };
514
515        // Step 2: Measure child
516        let placeable = measurable.measure(scroll_constraints);
517
518        // Step 3: Calculate viewport size (constrained size)
519        let width = placeable.width().min(constraints.max_width);
520        let height = placeable.height().min(constraints.max_height);
521
522        // Step 4: Calculate max scroll
523        let max_scroll = if self.is_vertical {
524            (placeable.height() - height).max(0.0)
525        } else {
526            (placeable.width() - width).max(0.0)
527        };
528
529        // Step 5: Update state with max scroll value
530        // Only update if the viewport is constrained (not infinite probe)
531        if (self.is_vertical && constraints.max_height.is_finite())
532            || (!self.is_vertical && constraints.max_width.is_finite())
533        {
534            self.state.set_max_value(max_scroll);
535        }
536
537        // Step 6: Read scroll value and calculate offset
538        // IMPORTANT: Use value_non_reactive() during measure to avoid triggering recomposition
539        let scroll = self.state.value_non_reactive().clamp(0.0, max_scroll);
540
541        let abs_scroll = if self.reverse_scrolling {
542            scroll - max_scroll
543        } else {
544            -scroll
545        };
546
547        let (x_offset, y_offset) = if self.is_vertical {
548            (0.0, abs_scroll)
549        } else {
550            (abs_scroll, 0.0)
551        };
552
553        // Step 7: Return result with viewport size and scroll offset as placement_offset
554        // This makes the scroll offset part of the layout modifier's placement, which will be
555        // correctly applied to children by the layout system
556        LayoutModifierMeasureResult::new(Size { width, height }, x_offset, y_offset)
557    }
558
559    fn min_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
560        measurable.min_intrinsic_width(height)
561    }
562
563    fn max_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
564        measurable.max_intrinsic_width(height)
565    }
566
567    fn min_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
568        measurable.min_intrinsic_height(width)
569    }
570
571    fn max_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
572        measurable.max_intrinsic_height(width)
573    }
574}
575
/// Creates a remembered ScrollState.
///
/// This is a convenience macro for use in composable functions.
///
/// - `rememberScrollState!(initial)` remembers a state created with
///   `ScrollState::new(initial)` and returns a clone of it.
/// - `rememberScrollState!()` is shorthand for `rememberScrollState!(0.0)`.
///
/// The expansion refers to `cranpose_core::remember` by that exact path, so
/// the `cranpose_core` crate must be nameable at the call site.
#[macro_export]
macro_rules! rememberScrollState {
    ($initial:expr) => {
        cranpose_core::remember(|| $crate::scroll::ScrollState::new($initial))
            .with(|state| state.clone())
    };
    () => {
        // Macro names are not hygienic: the recursive call must go through
        // `$crate` so it also resolves when a downstream crate invokes this
        // macro by path without importing it.
        $crate::rememberScrollState!(0.0)
    };
}
589
// Unit tests for this module are kept in `tests/scroll_tests.rs` (resolved
// through the `#[path]` attribute) and are compiled only for test builds.
#[cfg(test)]
#[path = "tests/scroll_tests.rs"]
mod tests;