Skip to main content

fret_ui_headless/table/
row_selection.rs

1use std::collections::{HashMap, HashSet};
2
3use super::{RowIndex, RowKey, RowModel};
4
5/// Selected rows keyed by [`RowKey`].
6pub type RowSelectionState = HashSet<RowKey>;
7
/// Aggregate selection status of a row's descendants, as reported by [`is_sub_row_selected`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SubRowSelection {
    /// No descendant is selected (also reported for rows that have no sub-rows).
    None,
    /// At least one descendant is selected, but not every selectable one.
    Some,
    /// Every selectable descendant is selected (vacuously so when none of them is selectable).
    All,
}
14
15pub fn is_row_selected(row_key: RowKey, selection: &RowSelectionState) -> bool {
16    selection.contains(&row_key)
17}
18
19pub fn is_sub_row_selected<'a, TData>(
20    row_model: &RowModel<'a, TData>,
21    selection: &RowSelectionState,
22    row: RowIndex,
23    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
24) -> SubRowSelection {
25    let Some(row) = row_model.row(row) else {
26        return SubRowSelection::None;
27    };
28    if row.sub_rows.is_empty() {
29        return SubRowSelection::None;
30    }
31
32    let mut all_children_selected = true;
33    let mut some_selected = false;
34
35    for &child in &row.sub_rows {
36        if some_selected && !all_children_selected {
37            break;
38        }
39
40        let Some(child_row) = row_model.row(child) else {
41            all_children_selected = false;
42            continue;
43        };
44
45        if row_can_select(child_row.key, child_row) {
46            if is_row_selected(child_row.key, selection) {
47                some_selected = true;
48            } else {
49                all_children_selected = false;
50            }
51        }
52
53        if !child_row.sub_rows.is_empty() {
54            match is_sub_row_selected(row_model, selection, child, row_can_select) {
55                SubRowSelection::All => {
56                    some_selected = true;
57                }
58                SubRowSelection::Some => {
59                    some_selected = true;
60                    all_children_selected = false;
61                }
62                SubRowSelection::None => {
63                    all_children_selected = false;
64                }
65            }
66        }
67    }
68
69    if all_children_selected {
70        SubRowSelection::All
71    } else if some_selected {
72        SubRowSelection::Some
73    } else {
74        SubRowSelection::None
75    }
76}
77
78pub fn row_is_some_selected<'a, TData>(
79    row_model: &RowModel<'a, TData>,
80    selection: &RowSelectionState,
81    row: RowIndex,
82    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
83) -> bool {
84    is_sub_row_selected(row_model, selection, row, row_can_select) == SubRowSelection::Some
85}
86
87pub fn row_is_all_sub_rows_selected<'a, TData>(
88    row_model: &RowModel<'a, TData>,
89    selection: &RowSelectionState,
90    row: RowIndex,
91    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
92) -> bool {
93    is_sub_row_selected(row_model, selection, row, row_can_select) == SubRowSelection::All
94}
95
96pub fn selected_flat_row_count<'a, TData>(
97    row_model: &RowModel<'a, TData>,
98    selection: &RowSelectionState,
99) -> usize {
100    row_model
101        .flat_rows()
102        .iter()
103        .filter_map(|&i| row_model.row(i).map(|r| r.key))
104        .filter(|k| is_row_selected(*k, selection))
105        .count()
106}
107
108pub fn selected_root_row_count<'a, TData>(
109    row_model: &RowModel<'a, TData>,
110    selection: &RowSelectionState,
111) -> usize {
112    row_model
113        .root_rows()
114        .iter()
115        .filter_map(|&i| row_model.row(i).map(|r| r.key))
116        .filter(|k| is_row_selected(*k, selection))
117        .count()
118}
119
120pub fn is_all_rows_selected<'a, TData>(
121    row_model: &RowModel<'a, TData>,
122    selection: &RowSelectionState,
123    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
124) -> bool {
125    if row_model.flat_rows().is_empty() {
126        return false;
127    }
128    if selection.is_empty() {
129        return false;
130    }
131
132    row_model.flat_rows().iter().all(|&i| {
133        let Some(row) = row_model.row(i) else {
134            return true;
135        };
136        !row_can_select(row.key, row) || is_row_selected(row.key, selection)
137    })
138}
139
140pub fn is_some_rows_selected<'a, TData>(
141    row_model: &RowModel<'a, TData>,
142    selection: &RowSelectionState,
143) -> bool {
144    let total = row_model.flat_rows().len();
145    let selected = selection.len();
146    selected > 0 && selected < total
147}
148
149pub fn toggle_all_rows_selected<'a, TData>(
150    row_model: &RowModel<'a, TData>,
151    selection: &RowSelectionState,
152    value: Option<bool>,
153    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
154) -> RowSelectionState {
155    let mut next = selection.clone();
156    let value =
157        value.unwrap_or_else(|| !is_all_rows_selected(row_model, selection, row_can_select));
158
159    if value {
160        for &i in row_model.flat_rows() {
161            let Some(row) = row_model.row(i) else {
162                continue;
163            };
164            if row_can_select(row.key, row) {
165                next.insert(row.key);
166            }
167        }
168    } else {
169        for &i in row_model.flat_rows() {
170            let Some(row) = row_model.row(i) else {
171                continue;
172            };
173            next.remove(&row.key);
174        }
175    }
176
177    next
178}
179
180pub fn toggle_all_page_rows_selected<'a, TData>(
181    page_row_model: &RowModel<'a, TData>,
182    core_row_model: &RowModel<'a, TData>,
183    selection: &RowSelectionState,
184    value: Option<bool>,
185    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
186    row_can_multi_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
187    row_can_select_sub_rows: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
188) -> RowSelectionState {
189    let value = value
190        .unwrap_or_else(|| !is_all_page_rows_selected(page_row_model, selection, row_can_select));
191
192    let mut next = selection.clone();
193    for &i in page_row_model.root_rows() {
194        let Some(row) = page_row_model.row(i) else {
195            continue;
196        };
197        mutate_row_is_selected(
198            core_row_model,
199            &mut next,
200            row.key,
201            value,
202            true,
203            row_can_select,
204            row_can_multi_select,
205            row_can_select_sub_rows,
206        );
207    }
208    next
209}
210
211pub fn toggle_row_selected<'a, TData>(
212    row_model: &RowModel<'a, TData>,
213    selection: &RowSelectionState,
214    row_key: RowKey,
215    value: Option<bool>,
216    select_children: bool,
217    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
218    row_can_multi_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
219    row_can_select_sub_rows: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
220) -> RowSelectionState {
221    let current = is_row_selected(row_key, selection);
222    let value = value.unwrap_or(!current);
223
224    let mut next = selection.clone();
225    if let Some(i) = row_model
226        .row_by_key(row_key)
227        .and_then(|i| row_model.row(i).map(|_| i))
228        && let Some(row) = row_model.row(i)
229        && row_can_select(row_key, row)
230        && current == value
231    {
232        return next;
233    }
234
235    mutate_row_is_selected(
236        row_model,
237        &mut next,
238        row_key,
239        value,
240        select_children,
241        row_can_select,
242        row_can_multi_select,
243        row_can_select_sub_rows,
244    );
245    next
246}
247
248/// TanStack `getIsAllPageRowsSelected` semantics.
249///
250/// Notes:
251/// - Only rows that can be selected are considered.
252/// - Unlike `getIsAllRowsSelected`, this does not require `rowSelection` to be non-empty.
253pub fn is_all_page_rows_selected<'a, TData>(
254    page_row_model: &RowModel<'a, TData>,
255    selection: &RowSelectionState,
256    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
257) -> bool {
258    let mut any_selectable = false;
259    for &i in page_row_model.flat_rows() {
260        let Some(row) = page_row_model.row(i) else {
261            continue;
262        };
263        if !row_can_select(row.key, row) {
264            continue;
265        }
266        any_selectable = true;
267        if !is_row_selected(row.key, selection) {
268            return false;
269        }
270    }
271    any_selectable
272}
273
274pub fn is_some_page_rows_selected<'a, TData>(
275    page_row_model: &RowModel<'a, TData>,
276    core_row_model: &RowModel<'a, TData>,
277    selection: &RowSelectionState,
278    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
279) -> bool {
280    if is_all_page_rows_selected(page_row_model, selection, row_can_select) {
281        return false;
282    }
283
284    page_row_model.flat_rows().iter().any(|&i| {
285        let Some(row) = page_row_model.row(i) else {
286            return false;
287        };
288        if !row_can_select(row.key, row) {
289            return false;
290        }
291        if is_row_selected(row.key, selection) {
292            return true;
293        }
294        core_row_model.row_by_key(row.key).is_some_and(|core_i| {
295            row_is_some_selected(core_row_model, selection, core_i, row_can_select)
296        })
297    })
298}
299
300fn mutate_row_is_selected<'a, TData>(
301    row_model: &RowModel<'a, TData>,
302    selection: &mut RowSelectionState,
303    row_key: RowKey,
304    value: bool,
305    include_children: bool,
306    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
307    row_can_multi_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
308    row_can_select_sub_rows: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
309) {
310    let Some(row_index) = row_model.row_by_key(row_key) else {
311        return;
312    };
313    mutate_row_is_selected_by_index(
314        row_model,
315        selection,
316        row_index,
317        value,
318        include_children,
319        row_can_select,
320        row_can_multi_select,
321        row_can_select_sub_rows,
322    );
323}
324
325fn mutate_row_is_selected_by_index<'a, TData>(
326    row_model: &RowModel<'a, TData>,
327    selection: &mut RowSelectionState,
328    row_index: RowIndex,
329    value: bool,
330    include_children: bool,
331    row_can_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
332    row_can_multi_select: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
333    row_can_select_sub_rows: &impl Fn(RowKey, &super::Row<'a, TData>) -> bool,
334) {
335    let Some(row) = row_model.row(row_index) else {
336        return;
337    };
338
339    if value {
340        if !row_can_multi_select(row.key, row) {
341            selection.clear();
342        }
343        if row_can_select(row.key, row) {
344            selection.insert(row.key);
345        }
346    } else {
347        selection.remove(&row.key);
348    }
349
350    if include_children && !row.sub_rows.is_empty() && row_can_select_sub_rows(row.key, row) {
351        for &child in &row.sub_rows {
352            mutate_row_is_selected_by_index(
353                row_model,
354                selection,
355                child,
356                value,
357                include_children,
358                row_can_select,
359                row_can_multi_select,
360                row_can_select_sub_rows,
361            );
362        }
363    }
364}
365
366/// TanStack-compatible `selectRowsFn`: returns a [`RowModel`] containing only selected rows in the
367/// `rows` tree, while keeping `flat_rows` and `rows_by_key` for all selected rows discovered during
368/// traversal (including selected sub-rows whose parents are not selected).
369pub fn select_rows_fn<'a, TData>(
370    row_model: &RowModel<'a, TData>,
371    selection: &RowSelectionState,
372) -> RowModel<'a, TData> {
373    let mut out_root_rows: Vec<RowIndex> = Vec::new();
374    let mut out_flat_rows: Vec<RowIndex> = Vec::new();
375    let mut out_rows_by_key: HashMap<RowKey, RowIndex> = HashMap::new();
376    let mut out_rows_by_id: HashMap<super::RowId, RowIndex> = HashMap::new();
377    let mut out_arena: Vec<super::Row<'a, TData>> = Vec::new();
378
379    fn recurse<'a, TData>(
380        source: &RowModel<'a, TData>,
381        selection: &RowSelectionState,
382        original: RowIndex,
383        out_root_rows: &mut Vec<RowIndex>,
384        out_flat_rows: &mut Vec<RowIndex>,
385        out_rows_by_key: &mut HashMap<RowKey, RowIndex>,
386        out_rows_by_id: &mut HashMap<super::RowId, RowIndex>,
387        out_arena: &mut Vec<super::Row<'a, TData>>,
388        parent_new: Option<RowIndex>,
389        is_root: bool,
390    ) -> Option<RowIndex> {
391        let row = source.row(original)?;
392        let selected = is_row_selected(row.key, selection);
393
394        if selected {
395            let new_index = out_arena.len();
396            out_arena.push(super::Row {
397                id: row.id.clone(),
398                key: row.key,
399                original: row.original,
400                index: row.index,
401                depth: row.depth,
402                parent: parent_new,
403                parent_key: row.parent_key,
404                sub_rows: Vec::new(),
405            });
406            out_flat_rows.push(new_index);
407            out_rows_by_key.insert(row.key, new_index);
408            out_rows_by_id.insert(row.id.clone(), new_index);
409            if is_root {
410                out_root_rows.push(new_index);
411            }
412
413            let mut selected_children: Vec<RowIndex> = Vec::new();
414            for child in &row.sub_rows {
415                if let Some(child_new) = recurse(
416                    source,
417                    selection,
418                    *child,
419                    out_root_rows,
420                    out_flat_rows,
421                    out_rows_by_key,
422                    out_rows_by_id,
423                    out_arena,
424                    Some(new_index),
425                    false,
426                ) {
427                    selected_children.push(child_new);
428                }
429            }
430            if let Some(new_row) = out_arena.get_mut(new_index) {
431                new_row.sub_rows = selected_children;
432            }
433            Some(new_index)
434        } else {
435            for child in &row.sub_rows {
436                let _ = recurse(
437                    source,
438                    selection,
439                    *child,
440                    out_root_rows,
441                    out_flat_rows,
442                    out_rows_by_key,
443                    out_rows_by_id,
444                    out_arena,
445                    None,
446                    false,
447                );
448            }
449            None
450        }
451    }
452
453    for &root in row_model.root_rows() {
454        let _ = recurse(
455            row_model,
456            selection,
457            root,
458            &mut out_root_rows,
459            &mut out_flat_rows,
460            &mut out_rows_by_key,
461            &mut out_rows_by_id,
462            &mut out_arena,
463            None,
464            true,
465        );
466    }
467
468    RowModel {
469        root_rows: out_root_rows,
470        flat_rows: out_flat_rows,
471        rows_by_key: out_rows_by_key,
472        rows_by_id: out_rows_by_id,
473        arena: out_arena,
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::super::Table;
    use super::*;

    #[derive(Debug, Clone)]
    struct Person {
        #[allow(dead_code)]
        name: String,
        sub_rows: Option<Vec<Person>>,
    }

    /// Builds `rows` root people, each with `sub_rows` leaf children (no children when
    /// `sub_rows == 0`).
    fn make_people(rows: usize, sub_rows: usize) -> Vec<Person> {
        (0..rows)
            .map(|i| Person {
                name: format!("Person {i}"),
                sub_rows: (sub_rows > 0).then(|| {
                    (0..sub_rows)
                        .map(|j| Person {
                            name: format!("Person {i}.{j}"),
                            sub_rows: None,
                        })
                        .collect()
                }),
            })
            .collect()
    }

    #[test]
    fn select_rows_fn_returns_only_selected_rows_in_tree() {
        let data = make_people(5, 0);
        let table = Table::builder(&data).build();
        let model = table.core_row_model();

        let selection: RowSelectionState = [RowKey::from_index(0), RowKey::from_index(2)]
            .into_iter()
            .collect();

        let selected = select_rows_fn(model, &selection);

        assert_eq!(selected.root_rows().len(), 2);
        assert_eq!(selected.flat_rows().len(), 2);
        assert!(selected.row_by_key(RowKey::from_index(0)).is_some());
        assert!(selected.row_by_key(RowKey::from_index(2)).is_some());
    }

    #[test]
    fn select_rows_fn_recurses_and_filters_sub_rows() {
        let data = make_people(3, 2);
        let table = Table::builder(&data)
            .get_sub_rows(|p, _| p.sub_rows.as_deref())
            .build();
        let model = table.core_row_model();

        let root_0 = model.row(model.root_rows()[0]).expect("root row 0");
        let child_0_key = model
            .row(root_0.sub_rows[0])
            .expect("root row 0 child 0")
            .key;
        let selection: RowSelectionState = [root_0.key, child_0_key].into_iter().collect();

        let selected = select_rows_fn(model, &selection);

        let root_0 = selected.row(selected.root_rows()[0]).expect("root row 0");
        assert_eq!(root_0.sub_rows.len(), 1);
        assert_eq!(selected.flat_rows().len(), 2);
        assert!(selected.row_by_key(root_0.key).is_some());
        assert!(selected.row_by_key(child_0_key).is_some());
    }

    /// A selected child of an unselected parent is indexed in the flat lists but is not part of
    /// the filtered tree.
    #[test]
    fn select_rows_fn_keeps_orphaned_selected_sub_rows_flat() {
        let data = make_people(2, 2);
        let table = Table::builder(&data)
            .get_sub_rows(|p, _| p.sub_rows.as_deref())
            .build();
        let model = table.core_row_model();

        let root = model.row(model.root_rows()[0]).expect("root row 0");
        let child_key = model
            .row(root.sub_rows[0])
            .expect("root row 0 child 0")
            .key;
        let selection: RowSelectionState = [child_key].into_iter().collect();

        let selected = select_rows_fn(model, &selection);

        assert_eq!(selected.root_rows().len(), 0);
        assert_eq!(selected.flat_rows().len(), 1);
        let orphan_index = selected
            .row_by_key(child_key)
            .expect("orphaned child is indexed by key");
        let orphan = selected.row(orphan_index).expect("orphaned child row");
        assert!(orphan.parent.is_none());
    }

    #[test]
    fn select_rows_fn_returns_empty_when_no_rows_selected() {
        let data = make_people(5, 0);
        let table = Table::builder(&data).build();
        let model = table.core_row_model();

        let selection: RowSelectionState = RowSelectionState::default();
        let selected = select_rows_fn(model, &selection);

        assert_eq!(selected.root_rows().len(), 0);
        assert_eq!(selected.flat_rows().len(), 0);
        assert_eq!(selected.arena().len(), 0);
        assert!(selected.rows_by_key().is_empty());
    }

    #[test]
    fn toggle_all_rows_selected_selects_and_deselects_flat_rows() {
        let data = make_people(5, 0);
        let table = Table::builder(&data).build();
        let model = table.core_row_model();

        let selection = RowSelectionState::default();
        let can_select = |_: RowKey, _row: &super::super::Row<'_, Person>| true;
        let selection = toggle_all_rows_selected(model, &selection, Some(true), &can_select);
        assert!(is_all_rows_selected(model, &selection, &can_select));

        let selection = toggle_all_rows_selected(model, &selection, Some(false), &can_select);
        assert!(selection.is_empty());
        assert!(!is_some_rows_selected(model, &selection));
    }

    #[test]
    fn sub_row_selection_reports_some_and_all() {
        let data = make_people(2, 2);
        let table = Table::builder(&data)
            .get_sub_rows(|p, _| p.sub_rows.as_deref())
            .build();
        let model = table.core_row_model();

        let root = model.root_rows()[0];
        let root_key = model.row(root).unwrap().key;
        let child0 = model.row(root).unwrap().sub_rows[0];
        let child1 = model.row(root).unwrap().sub_rows[1];
        let child0_key = model.row(child0).unwrap().key;
        let child1_key = model.row(child1).unwrap().key;

        let selection: RowSelectionState = [child0_key].into_iter().collect();
        let can_select = |_: RowKey, _row: &super::super::Row<'_, Person>| true;
        assert_eq!(
            is_sub_row_selected(model, &selection, root, &can_select),
            SubRowSelection::Some
        );
        assert!(row_is_some_selected(model, &selection, root, &can_select));
        assert!(!row_is_all_sub_rows_selected(
            model,
            &selection,
            root,
            &can_select
        ));
        assert!(!is_row_selected(root_key, &selection));

        let selection: RowSelectionState = [child0_key, child1_key].into_iter().collect();
        assert_eq!(
            is_sub_row_selected(model, &selection, root, &can_select),
            SubRowSelection::All
        );
        assert!(!row_is_some_selected(model, &selection, root, &can_select));
        assert!(row_is_all_sub_rows_selected(
            model,
            &selection,
            root,
            &can_select
        ));
    }

    #[test]
    fn toggle_row_selected_can_select_children() {
        let data = make_people(1, 2);
        let table = Table::builder(&data)
            .get_sub_rows(|p, _| p.sub_rows.as_deref())
            .build();
        let model = table.core_row_model();

        let root = model.root_rows()[0];
        let root_key = model.row(root).unwrap().key;
        let child0 = model.row(root).unwrap().sub_rows[0];
        let child1 = model.row(root).unwrap().sub_rows[1];
        let child0_key = model.row(child0).unwrap().key;
        let child1_key = model.row(child1).unwrap().key;

        let selection = RowSelectionState::default();
        let can_select = |_: RowKey, _row: &super::super::Row<'_, Person>| true;
        let selection = toggle_row_selected(
            model,
            &selection,
            root_key,
            Some(true),
            true,
            &can_select,
            &can_select,
            &can_select,
        );
        assert!(is_row_selected(root_key, &selection));
        assert!(is_row_selected(child0_key, &selection));
        assert!(is_row_selected(child1_key, &selection));

        let selection = toggle_row_selected(
            model,
            &selection,
            root_key,
            Some(false),
            true,
            &can_select,
            &can_select,
            &can_select,
        );
        assert!(!is_row_selected(root_key, &selection));
        assert!(!is_row_selected(child0_key, &selection));
        assert!(!is_row_selected(child1_key, &selection));
    }

    /// When multi-select is disallowed, selecting a row clears every previously selected row.
    #[test]
    fn toggle_row_selected_without_multi_select_clears_other_rows() {
        let data = make_people(3, 0);
        let table = Table::builder(&data).build();
        let model = table.core_row_model();

        let key0 = RowKey::from_index(0);
        let key1 = RowKey::from_index(1);
        let selection: RowSelectionState = [key0].into_iter().collect();
        let can_select = |_: RowKey, _row: &super::super::Row<'_, Person>| true;
        let single_select = |_: RowKey, _row: &super::super::Row<'_, Person>| false;

        let selection = toggle_row_selected(
            model,
            &selection,
            key1,
            Some(true),
            false,
            &can_select,
            &single_select,
            &can_select,
        );
        assert_eq!(selection.len(), 1);
        assert!(is_row_selected(key1, &selection));
        assert!(!is_row_selected(key0, &selection));
    }

    /// Requesting the state a selectable row is already in returns the selection unchanged, even
    /// when multi-select is disallowed. The single-select clearing must not run in that case.
    #[test]
    fn toggle_row_selected_is_noop_when_already_in_requested_state() {
        let data = make_people(3, 0);
        let table = Table::builder(&data).build();
        let model = table.core_row_model();

        let key0 = RowKey::from_index(0);
        let key1 = RowKey::from_index(1);
        let selection: RowSelectionState = [key0, key1].into_iter().collect();
        let can_select = |_: RowKey, _row: &super::super::Row<'_, Person>| true;
        let single_select = |_: RowKey, _row: &super::super::Row<'_, Person>| false;

        let next = toggle_row_selected(
            model,
            &selection,
            key0,
            Some(true),
            false,
            &can_select,
            &single_select,
            &can_select,
        );
        assert_eq!(next.len(), 2);
        assert!(is_row_selected(key0, &next));
        assert!(is_row_selected(key1, &next));
    }

    /// Toggling page selection with `None` selects every root and its descendants, and toggling
    /// again clears them.
    #[test]
    fn toggle_all_page_rows_selected_round_trips_with_sub_rows() {
        let data = make_people(2, 2);
        let table = Table::builder(&data)
            .get_sub_rows(|p, _| p.sub_rows.as_deref())
            .build();
        let model = table.core_row_model();
        let can_select = |_: RowKey, _row: &super::super::Row<'_, Person>| true;

        let mut all_keys: Vec<RowKey> = Vec::new();
        for &root in model.root_rows() {
            let row = model.row(root).expect("root row");
            all_keys.push(row.key);
            for &child in &row.sub_rows {
                all_keys.push(model.row(child).expect("child row").key);
            }
        }

        let empty = RowSelectionState::default();
        let selection = toggle_all_page_rows_selected(
            model,
            model,
            &empty,
            None,
            &can_select,
            &can_select,
            &can_select,
        );
        assert_eq!(selection.len(), all_keys.len());
        assert!(all_keys.iter().all(|&key| is_row_selected(key, &selection)));
        assert!(is_all_page_rows_selected(model, &selection, &can_select));
        assert!(!is_some_page_rows_selected(
            model,
            model,
            &selection,
            &can_select
        ));

        let selection = toggle_all_page_rows_selected(
            model,
            model,
            &selection,
            None,
            &can_select,
            &can_select,
            &can_select,
        );
        assert!(selection.is_empty());
    }

    /// A single selected sub-row makes the page "some" but not "all" selected.
    #[test]
    fn page_selection_reports_some_for_partially_selected_page() {
        let data = make_people(2, 2);
        let table = Table::builder(&data)
            .get_sub_rows(|p, _| p.sub_rows.as_deref())
            .build();
        let model = table.core_row_model();

        let root = model.row(model.root_rows()[0]).expect("root row 0");
        let child_key = model
            .row(root.sub_rows[0])
            .expect("root row 0 child 0")
            .key;
        let selection: RowSelectionState = [child_key].into_iter().collect();
        let can_select = |_: RowKey, _row: &super::super::Row<'_, Person>| true;

        assert!(!is_all_page_rows_selected(model, &selection, &can_select));
        assert!(is_some_page_rows_selected(
            model,
            model,
            &selection,
            &can_select
        ));
    }

    /// With no selectable rows, the page check reports `false`, while `is_all_rows_selected`
    /// skips unselectable rows and is vacuously `true`.
    #[test]
    fn page_rows_are_never_all_selected_when_nothing_is_selectable() {
        let data = make_people(3, 0);
        let table = Table::builder(&data).build();
        let model = table.core_row_model();
        let cannot_select = |_: RowKey, _row: &super::super::Row<'_, Person>| false;
        let selection: RowSelectionState = [
            RowKey::from_index(0),
            RowKey::from_index(1),
            RowKey::from_index(2),
        ]
        .into_iter()
        .collect();

        assert!(!is_all_page_rows_selected(model, &selection, &cannot_select));
        assert!(is_all_rows_selected(model, &selection, &cannot_select));
    }
}