reifydb_sub_flow/operator/
take.rs1use std::{collections::BTreeMap, sync::Arc};
5
6use postcard::{from_bytes, to_stdvec};
7use reifydb_core::{
8 encoded::shape::RowShape,
9 interface::{
10 catalog::flow::FlowNodeId,
11 change::{Change, Diff},
12 },
13 internal,
14 value::column::columns::Columns,
15};
16use reifydb_type::{
17 Result,
18 error::Error,
19 value::{blob::Blob, row_number::RowNumber, r#type::Type},
20};
21use serde::{Deserialize, Serialize};
22
23use crate::{
24 operator::{
25 Operator, Operators,
26 stateful::{raw::RawStatefulOperator, single::SingleStateful, utils},
27 },
28 transaction::{FlowTransaction, slot::PersistFn},
29};
30
31#[derive(Debug, Clone, Serialize, Deserialize, Default)]
32struct TakeState {
33 active: BTreeMap<RowNumber, usize>,
34 candidates: BTreeMap<RowNumber, usize>,
35}
36
37pub struct TakeOperator {
38 parent: Arc<Operators>,
39 node: FlowNodeId,
40 limit: usize,
41 shape: RowShape,
42}
43
44impl TakeOperator {
45 pub fn new(parent: Arc<Operators>, node: FlowNodeId, limit: usize) -> Self {
46 Self {
47 parent,
48 node,
49 limit,
50 shape: RowShape::testing(&[Type::Blob]),
51 }
52 }
53
54 fn load_take_state(&self, txn: &mut FlowTransaction) -> Result<TakeState> {
55 let state_row = self.load_state(txn)?;
56
57 if state_row.is_empty() || !state_row.is_defined(0) {
58 return Ok(TakeState::default());
59 }
60
61 let blob = self.shape.get_blob(&state_row, 0);
62 if blob.is_empty() {
63 return Ok(TakeState::default());
64 }
65
66 from_bytes(blob.as_ref())
67 .map_err(|e| Error(Box::new(internal!("Failed to deserialize TakeState: {}", e))))
68 }
69
70 fn save_take_state(&self, txn: &mut FlowTransaction, state: &TakeState) -> Result<()> {
71 let serialized = to_stdvec(state)
72 .map_err(|e| Error(Box::new(internal!("Failed to serialize TakeState: {}", e))))?;
73 let blob = Blob::from(serialized);
74
75 self.update_state(txn, |shape, row| {
76 shape.set_blob(row, 0, &blob);
77 Ok(())
78 })?;
79 Ok(())
80 }
81
82 fn promote_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
83 let mut output_diffs = Vec::new();
84
85 while state.active.len() < self.limit && !state.candidates.is_empty() {
86 if let Some((&candidate_row, &count)) = state.candidates.iter().next_back() {
87 state.candidates.remove(&candidate_row);
88 state.active.insert(candidate_row, count);
89
90 let cols = self.parent.pull(txn, &[candidate_row])?;
91 if !cols.is_empty() {
92 output_diffs.push(Diff::insert(cols));
93 }
94 }
95 }
96
97 Ok(output_diffs)
98 }
99
100 fn evict_to_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
101 let mut output_diffs = Vec::new();
102 let candidate_limit = self.limit * 4;
103
104 while state.active.len() > self.limit {
105 if let Some((&evicted_row, &count)) = state.active.iter().next() {
106 state.active.remove(&evicted_row);
107 state.candidates.insert(evicted_row, count);
108
109 let cols = self.parent.pull(txn, &[evicted_row])?;
110 if !cols.is_empty() {
111 output_diffs.push(Diff::remove(cols));
112 }
113 }
114 }
115
116 while state.candidates.len() > candidate_limit {
117 if let Some((&removed_row, _)) = state.candidates.iter().next() {
118 state.candidates.remove(&removed_row);
119 }
120 }
121
122 Ok(output_diffs)
123 }
124
125 #[inline]
126 fn acquire_take_state(&self, txn: &mut FlowTransaction) -> Result<(TakeState, PersistFn)> {
127 let node_id = self.node;
128 let shape_for_persist = self.shape.clone();
129 txn.take_operator_state::<TakeState, _>(node_id, |txn| {
130 let s = self.load_take_state(txn)?;
131 let shape = shape_for_persist.clone();
132 let persist: PersistFn = Box::new(move |txn, value| {
133 let state = value.downcast::<TakeState>().expect("TakeState slot type");
134 let serialized = to_stdvec(&*state).map_err(|e| {
135 Error(Box::new(internal!("Failed to serialize TakeState: {}", e)))
136 })?;
137 let blob = Blob::from(serialized);
138 let key = utils::empty_key();
139 let mut row = utils::load_or_create_row(node_id, txn, &key, &shape)?;
140 shape.set_blob(&mut row, 0, &blob);
141 utils::save_row(node_id, txn, &key, row)?;
142 Ok(())
143 });
144 Ok((s, persist))
145 })
146 }
147
148 #[inline]
149 fn admit_or_evict_new_row(
150 &self,
151 state: &mut TakeState,
152 txn: &mut FlowTransaction,
153 row_number: RowNumber,
154 single_row: Columns,
155 output_diffs: &mut Vec<Diff>,
156 ) -> Result<()> {
157 if state.active.len() < self.limit {
158 state.active.insert(row_number, 1);
159 output_diffs.push(Diff::insert(single_row));
160 return Ok(());
161 }
162
163 let Some(smallest) = state.active.keys().next().copied() else {
164 return Ok(());
165 };
166
167 if row_number > smallest {
168 if let Some(count) = state.active.remove(&smallest) {
169 state.candidates.insert(smallest, count);
170 let cols = self.parent.pull(txn, &[smallest])?;
171 if !cols.is_empty() {
172 output_diffs.push(Diff::remove(cols));
173 }
174 }
175 state.active.insert(row_number, 1);
176 output_diffs.push(Diff::insert(single_row));
177 } else {
178 state.candidates.insert(row_number, 1);
179 }
180 prune_candidates(state, self.limit);
181 Ok(())
182 }
183
184 #[inline]
185 fn apply_insert_diff(
186 &self,
187 state: &mut TakeState,
188 txn: &mut FlowTransaction,
189 post: Arc<Columns>,
190 output_diffs: &mut Vec<Diff>,
191 ) -> Result<()> {
192 let row_count = post.row_count();
193 for row_idx in 0..row_count {
194 let row_number = post.row_numbers[row_idx];
195
196 if state.active.contains_key(&row_number) {
197 *state.active.get_mut(&row_number).unwrap() += 1;
198 continue;
199 }
200 if state.candidates.contains_key(&row_number) {
201 *state.candidates.get_mut(&row_number).unwrap() += 1;
202 continue;
203 }
204
205 let single = post.extract_by_indices(&[row_idx]);
206 self.admit_or_evict_new_row(state, txn, row_number, single, output_diffs)?;
207 }
208 Ok(())
209 }
210
211 #[inline]
212 fn apply_update_diff(
213 &self,
214 state: &mut TakeState,
215 txn: &mut FlowTransaction,
216 pre: Arc<Columns>,
217 post: Arc<Columns>,
218 output_diffs: &mut Vec<Diff>,
219 ) -> Result<()> {
220 let row_count = post.row_count();
221 let mut update_indices: Vec<usize> = Vec::new();
222
223 for row_idx in 0..row_count {
224 let row_number = post.row_numbers[row_idx];
225
226 if state.active.contains_key(&row_number) {
227 update_indices.push(row_idx);
228 continue;
229 }
230
231 if state.candidates.contains_key(&row_number) {
235 continue;
236 }
237
238 let single = post.extract_by_indices(&[row_idx]);
247 self.admit_or_evict_new_row(state, txn, row_number, single, output_diffs)?;
248 }
249
250 if !update_indices.is_empty() {
251 output_diffs.push(Diff::update(
252 pre.extract_by_indices(&update_indices),
253 post.extract_by_indices(&update_indices),
254 ));
255 }
256 Ok(())
257 }
258
259 #[inline]
260 fn apply_remove_diff(
261 &self,
262 state: &mut TakeState,
263 txn: &mut FlowTransaction,
264 pre: Arc<Columns>,
265 output_diffs: &mut Vec<Diff>,
266 ) -> Result<()> {
267 let row_count = pre.row_count();
268 for row_idx in 0..row_count {
269 let row_number = pre.row_numbers[row_idx];
270
271 if let Some(count) = state.active.get_mut(&row_number) {
272 if *count > 1 {
273 *count -= 1;
274 } else {
275 state.active.remove(&row_number);
276 output_diffs.push(Diff::remove(pre.extract_by_indices(&[row_idx])));
277 let promoted = self.promote_candidates(state, txn)?;
278 output_diffs.extend(promoted);
279 }
280 } else if let Some(count) = state.candidates.get_mut(&row_number) {
281 if *count > 1 {
282 *count -= 1;
283 } else {
284 state.candidates.remove(&row_number);
285 }
286 }
287 }
288 Ok(())
289 }
290}
291
292#[inline]
293fn prune_candidates(state: &mut TakeState, limit: usize) {
294 let candidate_limit = limit * 4;
295 while state.candidates.len() > candidate_limit {
296 if let Some((&r, _)) = state.candidates.iter().next() {
297 state.candidates.remove(&r);
298 }
299 }
300}
301
302impl RawStatefulOperator for TakeOperator {}
303
304impl SingleStateful for TakeOperator {
305 fn layout(&self) -> RowShape {
306 self.shape.clone()
307 }
308}
309
310impl Operator for TakeOperator {
311 fn id(&self) -> FlowNodeId {
312 self.node
313 }
314
315 fn apply(&self, txn: &mut FlowTransaction, change: Change) -> Result<Change> {
316 let node_id = self.node;
317 let (mut state, persist) = self.acquire_take_state(txn)?;
318
319 let mut output_diffs = Vec::new();
320 let version = change.version;
321
322 for diff in change.diffs {
323 match diff {
324 Diff::Insert {
325 post,
326 } => self.apply_insert_diff(&mut state, txn, post, &mut output_diffs)?,
327 Diff::Update {
328 pre,
329 post,
330 } => self.apply_update_diff(&mut state, txn, pre, post, &mut output_diffs)?,
331 Diff::Remove {
332 pre,
333 } => self.apply_remove_diff(&mut state, txn, pre, &mut output_diffs)?,
334 }
335 }
336
337 txn.put_operator_state(node_id, state, persist);
340
341 Ok(Change::from_flow(self.node, version, output_diffs, change.changed_at))
342 }
343
344 fn pull(&self, txn: &mut FlowTransaction, rows: &[RowNumber]) -> Result<Columns> {
345 self.parent.pull(txn, rows)
346 }
347}