Skip to main content

fret_ui_headless/table/
tanstack_memo.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use super::memo::Memo;
5use super::{
6    ColumnDef, FilterFnDef, FilteringFnSpec, GlobalFilterState, Row, RowIndex, RowKey, RowModel,
7    SortSpec, SortingFnDef, TableOptions, filter_row_model, sort_row_model,
8};
9
/// Dependency snapshot for [`TanStackUngroupedRowModelOrderCache::ungrouped_order`].
///
/// Together with the columns signature, this whole struct forms the memo key: if any field
/// compares unequal to the previous call, the ordering is recomputed. Anything the `compute`
/// closure reads that is not represented here is not tracked by the cache.
#[derive(Debug, Clone, PartialEq)]
pub struct TanStackUngroupedRowModelOrderDeps {
    /// Caller-maintained revision of the row data; bump it whenever the items change.
    pub items_revision: u64,
    /// Number of data items.
    pub data_len: usize,
    /// Sorting state.
    pub sorting: Vec<SortSpec>,
    /// Column filter state.
    pub column_filters: super::ColumnFiltersState,
    /// Global filter state.
    pub global_filter: GlobalFilterState,
    /// Expanded-rows state.
    pub expanding: super::ExpandingState,
    /// Pagination state.
    pub pagination: super::PaginationState,
    /// Table options (e.g. the manual filtering/sorting flags).
    pub options: TableOptions,
    /// Which global filter function is selected.
    pub global_filter_fn: FilteringFnSpec,
    /// Whether a `get_column_can_global_filter` callback is in use. Closures cannot be compared,
    /// so only their presence is part of the key.
    pub has_get_column_can_global_filter: bool,
}
23
/// Filter-stage dependencies for [`TanStackSortedFlatRowOrderCache`].
///
/// This is the subset of [`TanStackSortedFlatRowOrderDeps`] that the filtering step reads. It is
/// derived from the sorted deps inside `sorted_order` and used (with the columns signature) as the
/// filter memo key. Because sorting is excluded, changing only the sort leaves the filtered stage
/// cached.
#[derive(Debug, Clone, PartialEq)]
pub struct TanStackFilteredFlatRowOrderDeps {
    /// Caller-maintained revision of the row data.
    pub items_revision: u64,
    /// Number of data items.
    pub data_len: usize,
    /// Column filter state passed to the filtering step.
    pub column_filters: super::ColumnFiltersState,
    /// Global filter value passed to the filtering step.
    pub global_filter: GlobalFilterState,
    /// Table options; `manual_filtering` bypasses the filtering step entirely.
    pub options: TableOptions,
    /// Which global filter function is selected.
    pub global_filter_fn: FilteringFnSpec,
    /// Whether a `get_column_can_global_filter` callback is supplied (only presence is tracked).
    pub has_get_column_can_global_filter: bool,
}
34
/// Full dependency snapshot for [`TanStackSortedFlatRowOrderCache::sorted_order`].
///
/// All fields participate in the sorted-stage memo key; the filter-related fields are also
/// copied into a [`TanStackFilteredFlatRowOrderDeps`] for the filtering stage.
#[derive(Debug, Clone, PartialEq)]
pub struct TanStackSortedFlatRowOrderDeps {
    /// Caller-maintained revision of the row data; bump it whenever the items change.
    pub items_revision: u64,
    /// Number of data items; `sorted_order` debug-asserts it equals `data.len()`.
    pub data_len: usize,
    /// Sort specs; an empty list means the filtered order is returned unsorted.
    pub sorting: Vec<SortSpec>,
    /// Column filter state.
    pub column_filters: super::ColumnFiltersState,
    /// Global filter value.
    pub global_filter: GlobalFilterState,
    /// Table options (`manual_filtering` / `manual_sorting` skip the respective step).
    pub options: TableOptions,
    /// Which global filter function is selected.
    pub global_filter_fn: FilteringFnSpec,
    /// Must equal `get_column_can_global_filter.is_some()` (debug-asserted in `sorted_order`).
    pub has_get_column_can_global_filter: bool,
}
46
/// One position in a flat row ordering.
#[derive(Debug, Clone, PartialEq)]
pub struct FlatRowOrderEntry {
    /// The row's original data index (`Row::index`).
    pub index: usize,
    /// The row's key as produced by the caller's `get_row_key`.
    pub key: RowKey,
}
52
/// Key-only snapshot of a row model's ordering, cached by
/// [`TanStackUngroupedRowModelOrderCache`].
#[derive(Debug, Clone, PartialEq)]
pub struct TanStackRowModelOrderSnapshot {
    /// TanStack `getRowModel().rows` equivalent (in Fret's row model shape: `root_rows` keys).
    pub rows: Arc<[RowKey]>,
    /// TanStack `getRowModel().flatRows` equivalent (Fret: `flat_rows` keys).
    pub flat_rows: Arc<[RowKey]>,
}
60
/// Memo cache for the final ungrouped row-model ordering, keyed on a caller-supplied columns
/// signature plus [`TanStackUngroupedRowModelOrderDeps`].
#[derive(Default)]
pub struct TanStackUngroupedRowModelOrderCache {
    // Keyed on `(columns_signature, deps)`; the snapshot is shared out via `Arc`.
    memo: Memo<(u64, TanStackUngroupedRowModelOrderDeps), Arc<TanStackRowModelOrderSnapshot>>,
    // Columns signature from the last call; a different value resets `memo`.
    columns_signature: u64,
    // Number of times the snapshot was (re)computed; saturates instead of overflowing.
    recompute_count: u64,
}
67
68impl TanStackUngroupedRowModelOrderCache {
69    pub fn recompute_count(&self) -> u64 {
70        self.recompute_count
71    }
72
73    /// Returns a stable, memoized snapshot of the final **ungrouped** row model ordering.
74    ///
75    /// This cache is designed for "rebuild each frame" callers: keep it outside ephemeral table
76    /// instances, and drive invalidation via `deps` (plus `items_revision`).
77    ///
78    /// Notes:
79    /// - This is intentionally scoped to the **ungrouped** pipeline (`state.grouping.is_empty()`).
80    ///   Grouped row models introduce non-core rows (group headers) whose keys are not stable across
81    ///   consumer-defined rebuild strategies; that will be covered by a separate cache surface.
82    pub fn ungrouped_order(
83        &mut self,
84        columns_signature: u64,
85        deps: TanStackUngroupedRowModelOrderDeps,
86        compute: impl FnOnce() -> TanStackRowModelOrderSnapshot,
87    ) -> (&Arc<TanStackRowModelOrderSnapshot>, bool) {
88        if columns_signature != self.columns_signature {
89            self.columns_signature = columns_signature;
90            self.memo.reset();
91        }
92
93        let sig_and_deps = (columns_signature, deps);
94        let (value, recomputed) = self
95            .memo
96            .get_or_compute(sig_and_deps, || Arc::new(compute()));
97        if recomputed {
98            self.recompute_count = self.recompute_count.saturating_add(1);
99        }
100        (value, recomputed)
101    }
102}
103
/// Two-stage memo cache (filter, then sort) for a flat root-row ordering.
#[derive(Default)]
pub struct TanStackSortedFlatRowOrderCache {
    // Stage 1: filtered order keyed on `(columns signature, filter deps)`.
    filtered_memo: Memo<(u64, TanStackFilteredFlatRowOrderDeps), Arc<[FlatRowOrderEntry]>>,
    // Stage 2: sorted order keyed on `(columns signature, full deps, hash of the stage-1 order)`.
    sorted_memo: Memo<(u64, TanStackSortedFlatRowOrderDeps, u64), Arc<[FlatRowOrderEntry]>>,
    // Columns signature from the last call; a different value resets both memos.
    columns_signature: u64,
    // Number of stage-2 (sorted) recomputes; saturating.
    recompute_count: u64,
    // Number of stage-1 (filtered) recomputes; saturating.
    filtered_recompute_count: u64,
}
112
113impl TanStackSortedFlatRowOrderCache {
114    pub fn recompute_count(&self) -> u64 {
115        self.recompute_count
116    }
117
118    pub fn filtered_recompute_count(&self) -> u64 {
119        self.filtered_recompute_count
120    }
121
122    /// Returns a stable, memoized ordering of the root row list after filtering + sorting.
123    ///
124    /// Notes:
125    /// - This cache is designed for “rebuild every frame” callers. Keep it outside the ephemeral
126    ///   table instance and feed dependency snapshots.
127    /// - Dependency tracking is explicit via `deps`. If you change any inputs that are not
128    ///   represented in `deps` (e.g. `filter_fns`, `sorting_fns`, or the closure identities),
129    ///   you must reset the cache (or bump a revision captured in `deps`).
130    pub fn sorted_order<'a, TData>(
131        &mut self,
132        data: &'a [TData],
133        columns: &[ColumnDef<TData>],
134        get_row_key: &dyn Fn(&TData, usize, Option<&RowKey>) -> RowKey,
135        filter_fns: &HashMap<Arc<str>, FilterFnDef>,
136        sorting_fns: &HashMap<Arc<str>, SortingFnDef<TData>>,
137        get_column_can_global_filter: Option<&dyn Fn(&ColumnDef<TData>, &TData) -> bool>,
138        deps: TanStackSortedFlatRowOrderDeps,
139    ) -> (&Arc<[FlatRowOrderEntry]>, bool) {
140        debug_assert_eq!(deps.data_len, data.len());
141        debug_assert_eq!(
142            deps.has_get_column_can_global_filter,
143            get_column_can_global_filter.is_some()
144        );
145
146        let signature = columns_signature(columns);
147        if signature != self.columns_signature {
148            self.columns_signature = signature;
149            self.filtered_memo.reset();
150            self.sorted_memo.reset();
151        }
152
153        let filtered_deps = TanStackFilteredFlatRowOrderDeps {
154            items_revision: deps.items_revision,
155            data_len: deps.data_len,
156            column_filters: deps.column_filters.clone(),
157            global_filter: deps.global_filter.clone(),
158            options: deps.options,
159            global_filter_fn: deps.global_filter_fn.clone(),
160            has_get_column_can_global_filter: deps.has_get_column_can_global_filter,
161        };
162
163        let (filtered_order, filtered_recomputed) =
164            self.filtered_memo
165                .get_or_compute((signature, filtered_deps.clone()), || {
166                    compute_filtered_order(
167                        data,
168                        columns,
169                        get_row_key,
170                        filter_fns,
171                        get_column_can_global_filter,
172                        &filtered_deps,
173                    )
174                });
175        if filtered_recomputed {
176            self.filtered_recompute_count = self.filtered_recompute_count.saturating_add(1);
177        }
178
179        let filtered_sig = flat_row_order_signature(filtered_order);
180        let sig_and_deps = (signature, deps.clone(), filtered_sig);
181        let filtered_for_sort = filtered_order.clone();
182
183        let (value, recomputed) = self.sorted_memo.get_or_compute(sig_and_deps, || {
184            compute_sorted_order_from_filtered(
185                data,
186                columns,
187                get_row_key,
188                sorting_fns,
189                &deps,
190                filtered_for_sort,
191            )
192        });
193        if recomputed {
194            self.recompute_count = self.recompute_count.saturating_add(1);
195        }
196        (value, recomputed)
197    }
198}
199
200pub(crate) fn columns_signature<TData>(columns: &[ColumnDef<TData>]) -> u64 {
201    use std::collections::hash_map::DefaultHasher;
202    use std::hash::{Hash, Hasher};
203
204    let mut hasher = DefaultHasher::new();
205    columns.len().hash(&mut hasher);
206    for col in columns {
207        col.id.as_ref().hash(&mut hasher);
208        col.sort_cmp.is_some().hash(&mut hasher);
209        col.sorting_fn.hash(&mut hasher);
210        col.sort_value.is_some().hash(&mut hasher);
211        col.sort_undefined.hash(&mut hasher);
212        col.sort_is_undefined.is_some().hash(&mut hasher);
213        col.filtering_fn.hash(&mut hasher);
214        col.filter_fn.is_some().hash(&mut hasher);
215        col.enable_column_filter.hash(&mut hasher);
216        col.enable_global_filter.hash(&mut hasher);
217        col.invert_sorting.hash(&mut hasher);
218        col.sort_desc_first.hash(&mut hasher);
219    }
220    hasher.finish()
221}
222
223fn flat_row_order_signature(order: &[FlatRowOrderEntry]) -> u64 {
224    use std::collections::hash_map::DefaultHasher;
225    use std::hash::{Hash, Hasher};
226
227    let mut hasher = DefaultHasher::new();
228    order.len().hash(&mut hasher);
229    for e in order {
230        e.index.hash(&mut hasher);
231        e.key.hash(&mut hasher);
232    }
233    hasher.finish()
234}
235
236fn build_flat_core_row_model<'a, TData>(
237    data: &'a [TData],
238    get_row_key: &dyn Fn(&TData, usize, Option<&RowKey>) -> RowKey,
239) -> RowModel<'a, TData> {
240    let mut root_rows: Vec<RowIndex> = Vec::with_capacity(data.len());
241    let mut flat_rows: Vec<RowIndex> = Vec::with_capacity(data.len());
242    let mut rows_by_key: HashMap<RowKey, RowIndex> = HashMap::with_capacity(data.len());
243    let mut rows_by_id: HashMap<super::RowId, RowIndex> = HashMap::with_capacity(data.len());
244    let mut arena: Vec<Row<'a, TData>> = Vec::with_capacity(data.len());
245
246    for (index, original) in data.iter().enumerate() {
247        let key = get_row_key(original, index, None);
248        let id = super::RowId(Arc::<str>::from(key.0.to_string()));
249        let row_index = arena.len();
250        arena.push(Row {
251            id: id.clone(),
252            key,
253            original,
254            index,
255            depth: 0,
256            parent: None,
257            parent_key: None,
258            sub_rows: Vec::new(),
259        });
260        root_rows.push(row_index);
261        flat_rows.push(row_index);
262        rows_by_key.insert(key, row_index);
263        rows_by_id.insert(id, row_index);
264    }
265
266    RowModel {
267        root_rows,
268        flat_rows,
269        rows_by_key,
270        rows_by_id,
271        arena,
272    }
273}
274
275fn compute_filtered_order<'a, TData>(
276    data: &'a [TData],
277    columns: &[ColumnDef<TData>],
278    get_row_key: &dyn Fn(&TData, usize, Option<&RowKey>) -> RowKey,
279    filter_fns: &HashMap<Arc<str>, FilterFnDef>,
280    get_column_can_global_filter: Option<&dyn Fn(&ColumnDef<TData>, &TData) -> bool>,
281    deps: &TanStackFilteredFlatRowOrderDeps,
282) -> Arc<[FlatRowOrderEntry]> {
283    let core = build_flat_core_row_model(data, get_row_key);
284
285    let filtered = if deps.options.manual_filtering {
286        core
287    } else {
288        filter_row_model(
289            &core,
290            columns,
291            &deps.column_filters,
292            deps.global_filter.clone(),
293            deps.options,
294            filter_fns,
295            &deps.global_filter_fn,
296            get_column_can_global_filter,
297        )
298    };
299
300    let mut out: Vec<FlatRowOrderEntry> = Vec::with_capacity(filtered.root_rows().len());
301    for &i in filtered.root_rows() {
302        let Some(r) = filtered.row(i) else {
303            continue;
304        };
305        out.push(FlatRowOrderEntry {
306            index: r.index,
307            key: r.key,
308        });
309    }
310    Arc::from(out.into_boxed_slice())
311}
312
313fn compute_sorted_order_from_filtered<'a, TData>(
314    data: &'a [TData],
315    columns: &[ColumnDef<TData>],
316    get_row_key: &dyn Fn(&TData, usize, Option<&RowKey>) -> RowKey,
317    sorting_fns: &HashMap<Arc<str>, SortingFnDef<TData>>,
318    deps: &TanStackSortedFlatRowOrderDeps,
319    filtered: Arc<[FlatRowOrderEntry]>,
320) -> Arc<[FlatRowOrderEntry]> {
321    if deps.options.manual_sorting || deps.sorting.is_empty() {
322        return filtered;
323    }
324
325    let core = build_flat_core_row_model(data, get_row_key);
326
327    // For the flat core row model, `RowIndex` equals the original data index. We can use the
328    // filtered ordering as an index view without re-evaluating filters.
329    let indices: Vec<RowIndex> = filtered.iter().map(|e| e.index).collect();
330    let mut view = core;
331    view.root_rows = indices.clone();
332    view.flat_rows = indices;
333
334    let sorted = sort_row_model(&view, columns, &deps.sorting, sorting_fns);
335
336    let mut out: Vec<FlatRowOrderEntry> = Vec::with_capacity(sorted.root_rows().len());
337    for &i in sorted.root_rows() {
338        let Some(r) = sorted.row(i) else {
339            continue;
340        };
341        out.push(FlatRowOrderEntry {
342            index: r.index,
343            key: r.key,
344        });
345    }
346    Arc::from(out.into_boxed_slice())
347}
348
#[cfg(test)]
mod tests {
    use super::{
        FlatRowOrderEntry, TanStackSortedFlatRowOrderCache, TanStackSortedFlatRowOrderDeps,
    };
    use crate::table::{
        ColumnDef, ColumnFiltersState, FilteringFnSpec, GlobalFilterState, RowKey, SortSpec,
        TableOptions, TanStackValue,
    };
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;

    #[derive(Debug)]
    struct Row {
        id: u64,
        name: &'static str,
    }

    /// Row-key callback shared by all tests: a row's `id` is its key.
    fn key_by_id(row: &Row, _index: usize, _parent: Option<&RowKey>) -> RowKey {
        RowKey(row.id)
    }

    /// A sortable, filterable `name` column.
    fn col_name() -> ColumnDef<Row> {
        ColumnDef::<Row>::new("name")
            .sort_value_by(|row: &Row| TanStackValue::String(Arc::<str>::from(row.name)))
            .sorting_fn_auto()
            .filtering_fn_auto()
    }

    /// A single sort spec on the `name` column.
    fn by_name(desc: bool) -> Vec<SortSpec> {
        vec![SortSpec {
            column: "name".into(),
            desc,
        }]
    }

    fn deps_for(
        data: &[Row],
        sorting: Vec<SortSpec>,
        column_filters: ColumnFiltersState,
        global_filter: GlobalFilterState,
    ) -> TanStackSortedFlatRowOrderDeps {
        TanStackSortedFlatRowOrderDeps {
            items_revision: 1,
            data_len: data.len(),
            sorting,
            column_filters,
            global_filter,
            options: TableOptions::default(),
            global_filter_fn: FilteringFnSpec::Auto,
            has_get_column_can_global_filter: false,
        }
    }

    /// Runs one `sorted_order` call with empty fn registries and no global-filter predicate.
    fn run(
        cache: &mut TanStackSortedFlatRowOrderCache,
        data: &[Row],
        columns: &[ColumnDef<Row>],
        deps: TanStackSortedFlatRowOrderDeps,
    ) -> (Arc<[FlatRowOrderEntry]>, bool) {
        let filter_fns = HashMap::new();
        let sorting_fns = HashMap::new();
        let (order, recomputed) = cache.sorted_order(
            data,
            columns,
            &key_by_id,
            &filter_fns,
            &sorting_fns,
            None,
            deps,
        );
        (Arc::clone(order), recomputed)
    }

    #[test]
    fn sorted_flat_row_order_cache_is_stable_when_deps_unchanged() {
        let data = [Row { id: 2, name: "b" }, Row { id: 1, name: "a" }];
        let columns = [col_name()];
        let mut cache = TanStackSortedFlatRowOrderCache::default();
        let deps = deps_for(&data, by_name(false), Vec::new(), None);

        let (first, first_recomputed) = run(&mut cache, &data, &columns, deps.clone());
        assert!(first_recomputed);
        assert_eq!(cache.recompute_count(), 1);
        assert_eq!(cache.filtered_recompute_count(), 1);
        let expected = [
            FlatRowOrderEntry {
                index: 1,
                key: RowKey(1),
            },
            FlatRowOrderEntry {
                index: 0,
                key: RowKey(2),
            },
        ];
        assert_eq!(&*first, &expected[..]);

        let (second, second_recomputed) = run(&mut cache, &data, &columns, deps);
        assert!(!second_recomputed);
        assert_eq!(cache.recompute_count(), 1);
        assert_eq!(cache.filtered_recompute_count(), 1);
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn sorted_flat_row_order_cache_reuses_filtered_step_when_only_sorting_changes() {
        let data = [
            Row { id: 2, name: "b" },
            Row { id: 1, name: "a" },
            Row { id: 3, name: "c" },
        ];
        let columns = [col_name()];
        let mut cache = TanStackSortedFlatRowOrderCache::default();

        let ascending = deps_for(&data, by_name(false), Vec::new(), None);
        let (_, ascending_recomputed) = run(&mut cache, &data, &columns, ascending);
        assert!(ascending_recomputed);
        assert_eq!(cache.filtered_recompute_count(), 1);
        assert_eq!(cache.recompute_count(), 1);

        let descending = deps_for(&data, by_name(true), Vec::new(), None);
        let (_, descending_recomputed) = run(&mut cache, &data, &columns, descending);
        assert!(descending_recomputed);

        // Only the sorted step should recompute when sorting changes.
        assert_eq!(cache.filtered_recompute_count(), 1);
        assert_eq!(cache.recompute_count(), 2);
    }

    #[test]
    fn sorted_flat_row_order_cache_recomputes_when_filters_change() {
        let data = [
            Row {
                id: 1,
                name: "alpha",
            },
            Row {
                id: 2,
                name: "beta",
            },
        ];
        let columns = [col_name()];
        let mut cache = TanStackSortedFlatRowOrderCache::default();

        let unfiltered = deps_for(&data, Vec::new(), Vec::new(), None);
        let (_, unfiltered_recomputed) = run(&mut cache, &data, &columns, unfiltered);
        assert!(unfiltered_recomputed);
        assert_eq!(cache.recompute_count(), 1);
        assert_eq!(cache.filtered_recompute_count(), 1);

        let globally_filtered = deps_for(&data, Vec::new(), Vec::new(), Some(json!("alp")));
        let (_, filtered_recomputed) = run(&mut cache, &data, &columns, globally_filtered);
        assert!(filtered_recomputed);
        assert_eq!(cache.recompute_count(), 2);
        assert_eq!(cache.filtered_recompute_count(), 2);
    }
}