Skip to main content

fret_ui_headless/table/row_pinning.rs

1use std::collections::HashSet;
2
3use super::{RowIndex, RowKey, RowModel};
4
/// The edge of the table a pinned row is attached to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RowPinPosition {
    /// Pinned to the top edge.
    Top,
    /// Pinned to the bottom edge.
    Bottom,
}
10
11#[derive(Debug, Clone, PartialEq, Eq, Default)]
12pub struct RowPinningState {
13    pub top: Vec<RowKey>,
14    pub bottom: Vec<RowKey>,
15}
16
17pub fn is_row_pinned(row_key: RowKey, state: &RowPinningState) -> Option<RowPinPosition> {
18    if state.top.contains(&row_key) {
19        return Some(RowPinPosition::Top);
20    }
21    if state.bottom.contains(&row_key) {
22        return Some(RowPinPosition::Bottom);
23    }
24    None
25}
26
27pub fn is_some_rows_pinned(state: &RowPinningState, position: Option<RowPinPosition>) -> bool {
28    match position {
29        None => !(state.top.is_empty() && state.bottom.is_empty()),
30        Some(RowPinPosition::Top) => !state.top.is_empty(),
31        Some(RowPinPosition::Bottom) => !state.bottom.is_empty(),
32    }
33}
34
35pub fn pin_rows(
36    state: &mut RowPinningState,
37    position: Option<RowPinPosition>,
38    rows: impl IntoIterator<Item = RowKey>,
39) {
40    let mut row_keys: Vec<RowKey> = Vec::new();
41    let mut row_key_set: HashSet<RowKey> = HashSet::new();
42    for row_key in rows {
43        if row_key_set.insert(row_key) {
44            row_keys.push(row_key);
45        }
46    }
47
48    state.top.retain(|k| !row_key_set.contains(k));
49    state.bottom.retain(|k| !row_key_set.contains(k));
50
51    match position {
52        None => {}
53        Some(RowPinPosition::Top) => state.top.extend(row_keys),
54        Some(RowPinPosition::Bottom) => state.bottom.extend(row_keys),
55    }
56}
57
58/// TanStack-compatible helper: pin one row and optionally include its leaf and/or parent rows.
59pub fn pin_row<'a, TData>(
60    state: &mut RowPinningState,
61    position: Option<RowPinPosition>,
62    row_model: &RowModel<'a, TData>,
63    row_key: RowKey,
64    include_leaf_rows: bool,
65    include_parent_rows: bool,
66) {
67    let keys = pin_row_keys(row_model, row_key, include_leaf_rows, include_parent_rows);
68    pin_rows(state, position, keys);
69}
70
71pub fn pin_row_keys<'a, TData>(
72    row_model: &RowModel<'a, TData>,
73    row_key: RowKey,
74    include_leaf_rows: bool,
75    include_parent_rows: bool,
76) -> Vec<RowKey> {
77    let Some(row_index) = row_model.row_by_key(row_key) else {
78        return vec![row_key];
79    };
80
81    let mut keys: Vec<RowKey> = Vec::new();
82
83    if include_parent_rows {
84        let mut parents_rev: Vec<RowKey> = Vec::new();
85        let mut current = row_model.row(row_index);
86        while let Some(row) = current {
87            let Some(parent) = row.parent else {
88                break;
89            };
90            let Some(parent_row) = row_model.row(parent) else {
91                break;
92            };
93            parents_rev.push(parent_row.key);
94            current = Some(parent_row);
95        }
96        parents_rev.reverse();
97        keys.extend(parents_rev);
98    }
99
100    keys.push(row_key);
101
102    if include_leaf_rows {
103        fn push_descendant_keys<'a, TData>(
104            row_model: &RowModel<'a, TData>,
105            row: RowIndex,
106            out: &mut Vec<RowKey>,
107        ) {
108            let Some(r) = row_model.row(row) else {
109                return;
110            };
111            for &child in &r.sub_rows {
112                let Some(child_row) = row_model.row(child) else {
113                    continue;
114                };
115                out.push(child_row.key);
116                push_descendant_keys(row_model, child, out);
117            }
118        }
119
120        push_descendant_keys(row_model, row_index, &mut keys);
121    }
122
123    keys
124}
125
126pub fn center_row_keys<'a, TData>(
127    visible_root_rows: &[RowIndex],
128    row_model: &RowModel<'a, TData>,
129    state: &RowPinningState,
130) -> Vec<RowKey> {
131    let mut pinned = HashSet::<RowKey>::new();
132    pinned.extend(state.top.iter().copied());
133    pinned.extend(state.bottom.iter().copied());
134
135    visible_root_rows
136        .iter()
137        .filter_map(|&i| row_model.row(i))
138        .map(|r| r.key)
139        .filter(|k| !pinned.contains(k))
140        .collect()
141}
142
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pin_rows_preserves_input_order_and_dedupes() {
        let mut state = RowPinningState::default();

        pin_rows(
            &mut state,
            Some(RowPinPosition::Top),
            [RowKey(3), RowKey(2), RowKey(3), RowKey(1)],
        );
        assert_eq!(state.top, vec![RowKey(3), RowKey(2), RowKey(1)]);
        assert!(state.bottom.is_empty());

        pin_rows(&mut state, Some(RowPinPosition::Bottom), [RowKey(2)]);
        assert_eq!(state.top, vec![RowKey(3), RowKey(1)]);
        assert_eq!(state.bottom, vec![RowKey(2)]);

        pin_rows(&mut state, None, [RowKey(3)]);
        assert_eq!(state.top, vec![RowKey(1)]);
        assert_eq!(state.bottom, vec![RowKey(2)]);
    }

    /// Re-pinning a key to the side it is already on removes it first, so it moves to the end.
    #[test]
    fn pin_rows_repinning_same_side_moves_key_to_end() {
        let mut state = RowPinningState::default();

        pin_rows(
            &mut state,
            Some(RowPinPosition::Top),
            [RowKey(1), RowKey(2), RowKey(3)],
        );
        pin_rows(&mut state, Some(RowPinPosition::Top), [RowKey(1)]);
        assert_eq!(state.top, vec![RowKey(2), RowKey(3), RowKey(1)]);
        assert!(state.bottom.is_empty());
    }

    #[test]
    fn is_row_pinned_reports_side() {
        let state = RowPinningState {
            top: vec![RowKey(1)],
            bottom: vec![RowKey(2)],
        };
        assert_eq!(is_row_pinned(RowKey(1), &state), Some(RowPinPosition::Top));
        assert_eq!(is_row_pinned(RowKey(2), &state), Some(RowPinPosition::Bottom));
        assert_eq!(is_row_pinned(RowKey(3), &state), None);
    }

    #[test]
    fn is_some_rows_pinned_respects_position_filter() {
        let mut state = RowPinningState::default();
        assert!(!is_some_rows_pinned(&state, None));
        assert!(!is_some_rows_pinned(&state, Some(RowPinPosition::Top)));
        assert!(!is_some_rows_pinned(&state, Some(RowPinPosition::Bottom)));

        state.bottom.push(RowKey(7));
        assert!(is_some_rows_pinned(&state, None));
        assert!(!is_some_rows_pinned(&state, Some(RowPinPosition::Top)));
        assert!(is_some_rows_pinned(&state, Some(RowPinPosition::Bottom)));
    }
}
167}