Skip to main content

cranpose_ui/
scroll.rs

//! Scroll state and node implementation for cranpose.
//!
//! This module provides the core scrolling components:
//! - `ScrollState`: Holds scroll position and provides scroll control methods
//! - `ScrollNode`: Layout modifier that applies scroll offset to content
//! - `ScrollElement`: Element for creating ScrollNode instances
//!
//! The actual `Modifier.horizontal_scroll()` and `Modifier.vertical_scroll()`
//! extension methods are defined in `modifier/scroll.rs`.
10
11use cranpose_core::{current_runtime_handle, ownedMutableStateOf, NodeId, OwnedMutableState};
12use cranpose_foundation::{
13    Constraints, DelegatableNode, LayoutModifierNode, Measurable, ModifierNode,
14    ModifierNodeContext, ModifierNodeElement, NodeCapabilities, NodeState,
15};
16use cranpose_ui_graphics::Size;
17use cranpose_ui_layout::LayoutModifierMeasureResult;
18use std::cell::{Cell, RefCell};
19use std::hash::{DefaultHasher, Hash, Hasher};
20use std::rc::Rc;
21use std::sync::atomic::{AtomicU64, Ordering};
22
/// Process-wide counter used to hand out `ScrollState` IDs; the first ID issued is 1.
static NEXT_SCROLL_STATE_ID: AtomicU64 = AtomicU64::new(1);
24
25/// State object for scroll position tracking.
26///
27/// Holds the current scroll offset and provides methods to programmatically
28/// control scrolling. Can be created with `rememberScrollState()`.
29///
30/// This is a pure scroll model - it does NOT store ephemeral gesture/pointer state.
31/// Gesture state is managed locally in the scroll modifier.
32#[derive(Clone)]
33pub struct ScrollState {
34    inner: Rc<ScrollStateInner>,
35}
36
37pub(crate) struct ScrollStateInner {
38    /// Unique ID for debugging
39    id: u64,
40    /// Current scroll offset in pixels.
41    /// Uses MutableState<f32> for reactivity - Composables can observe this value.
42    /// Layout reads use get_non_reactive() to avoid triggering recomposition.
43    value: OwnedMutableState<f32>,
44    /// Maximum scroll value (content_size - viewport_size)
45    /// Using RefCell instead of MutableState to avoid snapshot isolation issues
46    max_value: RefCell<f32>,
47    /// Callbacks to invalidate layout when scroll value changes
48    /// Using HashMap to allow multiple listeners (e.g. real node + clones)
49    invalidate_callbacks: RefCell<std::collections::HashMap<u64, Box<dyn Fn()>>>,
50    /// Tracks whether we need to invalidate once a callback is registered.
51    pending_invalidation: Cell<bool>,
52}
53
54impl ScrollState {
55    /// Creates a new ScrollState with the given initial scroll position.
56    pub fn new(initial: f32) -> Self {
57        let id = NEXT_SCROLL_STATE_ID.fetch_add(1, Ordering::Relaxed);
58
59        Self {
60            inner: Rc::new(ScrollStateInner {
61                id,
62                value: ownedMutableStateOf(initial),
63                max_value: RefCell::new(0.0),
64                invalidate_callbacks: RefCell::new(std::collections::HashMap::new()),
65                pending_invalidation: Cell::new(false),
66            }),
67        }
68    }
69
70    /// Get the unique ID of this ScrollState
71    pub fn id(&self) -> u64 {
72        self.inner.id
73    }
74
75    /// Gets the current scroll position in pixels (reactive - triggers recomposition).
76    ///
77    /// Use this in Composable functions when you want UI to update on scroll.
78    /// Example: `Text("Scroll position: ${scrollState.value()}")`
79    pub fn value(&self) -> f32 {
80        self.inner.value.with(|v| *v)
81    }
82
83    /// Gets the current scroll position in pixels (non-reactive).
84    ///
85    /// Use this in layout/measure phase to avoid triggering recomposition.
86    /// This is called internally by ScrollNode::measure().
87    pub fn value_non_reactive(&self) -> f32 {
88        self.inner.value.get_non_reactive()
89    }
90
91    /// Gets the maximum scroll value.
92    pub fn max_value(&self) -> f32 {
93        *self.inner.max_value.borrow()
94    }
95
96    /// Scrolls by the given delta, clamping to valid range [0, max_value].
97    /// Returns the actual amount scrolled.
98    pub fn dispatch_raw_delta(&self, delta: f32) -> f32 {
99        let current = self.value();
100        let max = self.max_value();
101        let new_value = (current + delta).clamp(0.0, max);
102        let actual_delta = new_value - current;
103
104        if actual_delta.abs() > 0.001 {
105            // Use MutableState::set which triggers snapshot observers for reactive updates
106            self.inner.value.set(new_value);
107
108            // Trigger layout invalidation callbacks
109            let callbacks = self.inner.invalidate_callbacks.borrow();
110            if callbacks.is_empty() {
111                // Defer invalidation until a node registers a callback.
112                self.inner.pending_invalidation.set(true);
113            } else {
114                for callback in callbacks.values() {
115                    callback();
116                }
117            }
118        }
119
120        actual_delta
121    }
122
123    /// Sets the maximum scroll value (internal use by ScrollNode).
124    pub(crate) fn set_max_value(&self, max: f32) {
125        *self.inner.max_value.borrow_mut() = max;
126    }
127
128    /// Scrolls to the given position immediately.
129    pub fn scroll_to(&self, position: f32) {
130        let max = self.max_value();
131        let clamped = position.clamp(0.0, max);
132
133        self.inner.value.set(clamped);
134
135        // Trigger layout invalidation callbacks
136        let callbacks = self.inner.invalidate_callbacks.borrow();
137        if callbacks.is_empty() {
138            self.inner.pending_invalidation.set(true);
139        } else {
140            for callback in callbacks.values() {
141                callback();
142            }
143        }
144    }
145
146    /// Adds an invalidation callback and returns its ID
147    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
148        static NEXT_CALLBACK_ID: std::sync::atomic::AtomicU64 =
149            std::sync::atomic::AtomicU64::new(1);
150        let id = NEXT_CALLBACK_ID.fetch_add(1, Ordering::Relaxed);
151        self.inner
152            .invalidate_callbacks
153            .borrow_mut()
154            .insert(id, callback);
155        if self.inner.pending_invalidation.replace(false) {
156            if let Some(callback) = self.inner.invalidate_callbacks.borrow().get(&id) {
157                callback();
158            }
159        }
160        id
161    }
162
163    /// Removes an invalidation callback by ID
164    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
165        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
166    }
167}
168
169#[derive(Clone)]
170pub(crate) struct ScrollMotionContext {
171    inner: Rc<ScrollMotionContextInner>,
172}
173
/// Backing storage shared by every clone of a `ScrollMotionContext`.
struct ScrollMotionContextInner {
    /// Whether the context is currently flagged as active.
    active: Cell<bool>,
    /// Counter bumped on every activation or active-flag change; frame
    /// callbacks compare against it to detect that they have been superseded.
    generation: Cell<u64>,
    /// Listeners run when the active flag changes, keyed by callback ID.
    invalidate_callbacks: RefCell<std::collections::HashMap<u64, Box<dyn Fn()>>>,
    /// Set when an invalidation happened while no listener was registered.
    pending_invalidation: Cell<bool>,
}
180
181impl ScrollMotionContext {
182    pub(crate) fn new() -> Self {
183        Self {
184            inner: Rc::new(ScrollMotionContextInner {
185                active: Cell::new(false),
186                generation: Cell::new(0),
187                invalidate_callbacks: RefCell::new(std::collections::HashMap::new()),
188                pending_invalidation: Cell::new(false),
189            }),
190        }
191    }
192
193    pub(crate) fn is_active(&self) -> bool {
194        self.inner.active.get()
195    }
196
197    pub(crate) fn ptr_eq(&self, other: &Self) -> bool {
198        Rc::ptr_eq(&self.inner, &other.inner)
199    }
200
201    pub(crate) fn stable_key(&self) -> usize {
202        Rc::as_ptr(&self.inner) as usize
203    }
204
205    pub(crate) fn set_active(&self, active: bool) {
206        if self.inner.active.replace(active) != active {
207            self.bump_generation();
208            self.invalidate();
209        }
210    }
211
212    pub(crate) fn activate_for_next_frame(&self) {
213        let was_active = self.inner.active.replace(true);
214        let generation = self.bump_generation();
215        if !was_active {
216            self.invalidate();
217        }
218        if let Some(runtime) = current_runtime_handle() {
219            let state = self.clone();
220            let _ = runtime.register_frame_callback(move |_| {
221                state.clear_if_generation(generation);
222            });
223            runtime.schedule();
224        } else {
225            self.clear_if_generation(generation);
226        }
227    }
228
229    pub(crate) fn add_invalidate_callback(&self, callback: Box<dyn Fn()>) -> u64 {
230        static NEXT_CALLBACK_ID: AtomicU64 = AtomicU64::new(1);
231        let id = NEXT_CALLBACK_ID.fetch_add(1, Ordering::Relaxed);
232        self.inner
233            .invalidate_callbacks
234            .borrow_mut()
235            .insert(id, callback);
236        if self.inner.pending_invalidation.replace(false) {
237            if let Some(callback) = self.inner.invalidate_callbacks.borrow().get(&id) {
238                callback();
239            }
240        }
241        id
242    }
243
244    pub(crate) fn remove_invalidate_callback(&self, id: u64) {
245        self.inner.invalidate_callbacks.borrow_mut().remove(&id);
246    }
247
248    fn bump_generation(&self) -> u64 {
249        let next = self.inner.generation.get().wrapping_add(1);
250        self.inner.generation.set(next);
251        next
252    }
253
254    fn clear_if_generation(&self, generation: u64) {
255        if self.inner.generation.get() == generation {
256            self.set_active(false);
257        }
258    }
259
260    fn invalidate(&self) {
261        let callbacks = self.inner.invalidate_callbacks.borrow();
262        if callbacks.is_empty() {
263            self.inner.pending_invalidation.set(true);
264            return;
265        }
266        for callback in callbacks.values() {
267            callback();
268        }
269    }
270}
271
272/// Element for creating a ScrollNode.
273#[derive(Clone)]
274pub struct ScrollElement {
275    state: ScrollState,
276    is_vertical: bool,
277    reverse_scrolling: bool,
278}
279
280impl ScrollElement {
281    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
282        Self {
283            state,
284            is_vertical,
285            reverse_scrolling,
286        }
287    }
288}
289
290impl std::fmt::Debug for ScrollElement {
291    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292        f.debug_struct("ScrollElement")
293            .field("is_vertical", &self.is_vertical)
294            .field("reverse_scrolling", &self.reverse_scrolling)
295            .finish()
296    }
297}
298
299impl PartialEq for ScrollElement {
300    fn eq(&self, other: &Self) -> bool {
301        // ScrollStates are equal if they point to the same underlying state
302        Rc::ptr_eq(&self.state.inner, &other.state.inner)
303            && self.is_vertical == other.is_vertical
304            && self.reverse_scrolling == other.reverse_scrolling
305    }
306}
307
308impl Eq for ScrollElement {}
309
310impl Hash for ScrollElement {
311    fn hash<H: Hasher>(&self, state: &mut H) {
312        (Rc::as_ptr(&self.state.inner) as usize).hash(state);
313        self.is_vertical.hash(state);
314        self.reverse_scrolling.hash(state);
315    }
316}
317
318impl ModifierNodeElement for ScrollElement {
319    type Node = ScrollNode;
320
321    fn create(&self) -> Self::Node {
322        // println!("ScrollElement::create");
323        ScrollNode::new(self.state.clone(), self.is_vertical, self.reverse_scrolling)
324    }
325
326    fn key(&self) -> Option<u64> {
327        let mut hasher = DefaultHasher::new();
328        self.state.id().hash(&mut hasher);
329        self.reverse_scrolling.hash(&mut hasher);
330        self.is_vertical.hash(&mut hasher);
331        Some(hasher.finish())
332    }
333
334    fn update(&self, node: &mut Self::Node) {
335        let needs_invalidation = !Rc::ptr_eq(&node.state.inner, &self.state.inner)
336            || node.is_vertical != self.is_vertical
337            || node.reverse_scrolling != self.reverse_scrolling;
338
339        if needs_invalidation {
340            node.state = self.state.clone();
341            node.is_vertical = self.is_vertical;
342            node.reverse_scrolling = self.reverse_scrolling;
343        }
344    }
345
346    fn capabilities(&self) -> NodeCapabilities {
347        NodeCapabilities::LAYOUT
348    }
349}
350
351/// ScrollNode layout modifier that physically moves content based on scroll position.
352/// This is the component that actually reads ScrollState and applies the visual offset.
353pub struct ScrollNode {
354    state: ScrollState,
355    is_vertical: bool,
356    reverse_scrolling: bool,
357    node_state: NodeState,
358    /// ID of the invalidation callback registered with ScrollState
359    invalidation_callback_id: Option<u64>,
360    /// We capture the NodeId when attached to ensure correct invalidation scope
361    node_id: Option<NodeId>,
362}
363
364impl ScrollNode {
365    pub fn new(state: ScrollState, is_vertical: bool, reverse_scrolling: bool) -> Self {
366        Self {
367            state,
368            is_vertical,
369            reverse_scrolling,
370            node_state: NodeState::default(),
371            invalidation_callback_id: None,
372            node_id: None,
373        }
374    }
375
376    /// Returns a reference to the ScrollState.
377    pub fn state(&self) -> &ScrollState {
378        &self.state
379    }
380}
381
382impl DelegatableNode for ScrollNode {
383    fn node_state(&self) -> &NodeState {
384        &self.node_state
385    }
386}
387
388impl ModifierNode for ScrollNode {
389    fn on_attach(&mut self, context: &mut dyn ModifierNodeContext) {
390        // Set up the invalidation callback to trigger layout when scroll state changes.
391        // We capture the node_id directly from the context, avoiding any global registry.
392
393        let node_id = context.node_id();
394        self.node_id = node_id;
395
396        if let Some(node_id) = node_id {
397            let callback_id = self.state.add_invalidate_callback(Box::new(move || {
398                // Schedule scoped layout repass for this node
399                crate::schedule_layout_repass(node_id);
400            }));
401            self.invalidation_callback_id = Some(callback_id);
402        } else {
403            log::debug!(
404                "ScrollNode attached without a NodeId; deferring invalidation registration."
405            );
406        }
407
408        // Initial invalidation
409        context.invalidate(cranpose_foundation::InvalidationKind::Layout);
410    }
411
412    fn on_detach(&mut self) {
413        // Remove invalidation callback
414        if let Some(id) = self.invalidation_callback_id.take() {
415            self.state.remove_invalidate_callback(id);
416        }
417    }
418
419    fn as_layout_node(&self) -> Option<&dyn LayoutModifierNode> {
420        Some(self)
421    }
422
423    fn as_layout_node_mut(&mut self) -> Option<&mut dyn LayoutModifierNode> {
424        Some(self)
425    }
426}
427
428impl LayoutModifierNode for ScrollNode {
429    fn measure(
430        &self,
431        _context: &mut dyn ModifierNodeContext,
432        measurable: &dyn Measurable,
433        constraints: Constraints,
434    ) -> LayoutModifierMeasureResult {
435        // Step 1: Give child infinite space in scroll direction
436        let scroll_constraints = if self.is_vertical {
437            Constraints {
438                min_height: 0.0,
439                max_height: f32::INFINITY,
440                ..constraints
441            }
442        } else {
443            Constraints {
444                min_width: 0.0,
445                max_width: f32::INFINITY,
446                ..constraints
447            }
448        };
449
450        // Step 2: Measure child
451        let placeable = measurable.measure(scroll_constraints);
452
453        // Step 3: Calculate viewport size (constrained size)
454        let width = placeable.width().min(constraints.max_width);
455        let height = placeable.height().min(constraints.max_height);
456
457        // Step 4: Calculate max scroll
458        let max_scroll = if self.is_vertical {
459            (placeable.height() - height).max(0.0)
460        } else {
461            (placeable.width() - width).max(0.0)
462        };
463
464        // Step 5: Update state with max scroll value
465        // Only update if the viewport is constrained (not infinite probe)
466        if (self.is_vertical && constraints.max_height.is_finite())
467            || (!self.is_vertical && constraints.max_width.is_finite())
468        {
469            self.state.set_max_value(max_scroll);
470        }
471
472        // Step 6: Read scroll value and calculate offset
473        // IMPORTANT: Use value_non_reactive() during measure to avoid triggering recomposition
474        let scroll = self.state.value_non_reactive().clamp(0.0, max_scroll);
475
476        let abs_scroll = if self.reverse_scrolling {
477            scroll - max_scroll
478        } else {
479            -scroll
480        };
481
482        let (x_offset, y_offset) = if self.is_vertical {
483            (0.0, abs_scroll)
484        } else {
485            (abs_scroll, 0.0)
486        };
487
488        // Step 7: Return result with viewport size and scroll offset as placement_offset
489        // This makes the scroll offset part of the layout modifier's placement, which will be
490        // correctly applied to children by the layout system
491        LayoutModifierMeasureResult::new(Size { width, height }, x_offset, y_offset)
492    }
493
494    fn min_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
495        measurable.min_intrinsic_width(height)
496    }
497
498    fn max_intrinsic_width(&self, measurable: &dyn Measurable, height: f32) -> f32 {
499        measurable.max_intrinsic_width(height)
500    }
501
502    fn min_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
503        measurable.min_intrinsic_height(width)
504    }
505
506    fn max_intrinsic_height(&self, measurable: &dyn Measurable, width: f32) -> f32 {
507        measurable.max_intrinsic_height(width)
508    }
509
510    fn create_measurement_proxy(&self) -> Option<Box<dyn cranpose_foundation::MeasurementProxy>> {
511        None
512    }
513}
514
/// Creates a remembered `ScrollState`.
///
/// This is a convenience macro for use in composable functions.
///
/// - `rememberScrollState!(initial)` remembers a state created with
///   `ScrollState::new(initial)` and yields a clone of the handle.
/// - `rememberScrollState!()` is shorthand for `rememberScrollState!(0.0)`.
#[macro_export]
macro_rules! rememberScrollState {
    ($initial:expr) => {
        cranpose_core::remember(|| $crate::scroll::ScrollState::new($initial))
            .with(|state| state.clone())
    };
    () => {
        // Macro names are resolved at the call site, so the recursive call is
        // qualified with `$crate` to keep working when this macro is invoked
        // by path from another crate without being imported there.
        $crate::rememberScrollState!(0.0)
    };
}
528
529#[cfg(test)]
530#[path = "tests/scroll_tests.rs"]
531mod tests;