Skip to main content

reifydb_sub_flow/operator/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::BTreeMap, sync::Arc};
5
6use postcard::{from_bytes, to_stdvec};
7use reifydb_core::{
8	encoded::schema::RowSchema,
9	interface::{
10		catalog::flow::FlowNodeId,
11		change::{Change, Diff},
12	},
13	internal,
14	value::column::columns::Columns,
15};
16use reifydb_type::{
17	Result,
18	error::Error,
19	value::{blob::Blob, row_number::RowNumber, r#type::Type},
20};
21use serde::{Deserialize, Serialize};
22
23use crate::{
24	operator::{
25		Operator, Operators,
26		stateful::{raw::RawStatefulOperator, single::SingleStateful},
27	},
28	transaction::FlowTransaction,
29};
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32struct TakeState {
33	active: BTreeMap<RowNumber, usize>,
34	candidates: BTreeMap<RowNumber, usize>,
35}
36
37impl Default for TakeState {
38	fn default() -> Self {
39		Self {
40			active: BTreeMap::new(),
41			candidates: BTreeMap::new(),
42		}
43	}
44}
45
46pub struct TakeOperator {
47	parent: Arc<Operators>,
48	node: FlowNodeId,
49	limit: usize,
50	schema: RowSchema,
51}
52
53impl TakeOperator {
54	pub fn new(parent: Arc<Operators>, node: FlowNodeId, limit: usize) -> Self {
55		Self {
56			parent,
57			node,
58			limit,
59			schema: RowSchema::testing(&[Type::Blob]),
60		}
61	}
62
63	fn load_take_state(&self, txn: &mut FlowTransaction) -> Result<TakeState> {
64		let state_row = self.load_state(txn)?;
65
66		if state_row.is_empty() || !state_row.is_defined(0) {
67			return Ok(TakeState::default());
68		}
69
70		let blob = self.schema.get_blob(&state_row, 0);
71		if blob.is_empty() {
72			return Ok(TakeState::default());
73		}
74
75		from_bytes(blob.as_ref()).map_err(|e| Error(internal!("Failed to deserialize TakeState: {}", e)))
76	}
77
78	fn save_take_state(&self, txn: &mut FlowTransaction, state: &TakeState) -> Result<()> {
79		let serialized =
80			to_stdvec(state).map_err(|e| Error(internal!("Failed to serialize TakeState: {}", e)))?;
81		let blob = Blob::from(serialized);
82
83		self.update_state(txn, |schema, row| {
84			schema.set_blob(row, 0, &blob);
85			Ok(())
86		})?;
87		Ok(())
88	}
89
90	fn promote_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
91		let mut output_diffs = Vec::new();
92
93		while state.active.len() < self.limit && !state.candidates.is_empty() {
94			if let Some((&candidate_row, &count)) = state.candidates.iter().next_back() {
95				state.candidates.remove(&candidate_row);
96				state.active.insert(candidate_row, count);
97
98				let cols = self.parent.pull(txn, &[candidate_row])?;
99				if !cols.is_empty() {
100					output_diffs.push(Diff::Insert {
101						post: cols,
102					});
103				}
104			}
105		}
106
107		Ok(output_diffs)
108	}
109
110	fn evict_to_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
111		let mut output_diffs = Vec::new();
112		let candidate_limit = self.limit * 4;
113
114		while state.active.len() > self.limit {
115			if let Some((&evicted_row, &count)) = state.active.iter().next() {
116				state.active.remove(&evicted_row);
117				state.candidates.insert(evicted_row, count);
118
119				let cols = self.parent.pull(txn, &[evicted_row])?;
120				if !cols.is_empty() {
121					output_diffs.push(Diff::Remove {
122						pre: cols,
123					});
124				}
125			}
126		}
127
128		while state.candidates.len() > candidate_limit {
129			if let Some((&removed_row, _)) = state.candidates.iter().next() {
130				state.candidates.remove(&removed_row);
131			}
132		}
133
134		Ok(output_diffs)
135	}
136}
137
138impl RawStatefulOperator for TakeOperator {}
139
140impl SingleStateful for TakeOperator {
141	fn layout(&self) -> RowSchema {
142		self.schema.clone()
143	}
144}
145
146impl Operator for TakeOperator {
147	fn id(&self) -> FlowNodeId {
148		self.node
149	}
150
151	fn apply(&self, txn: &mut FlowTransaction, change: Change) -> Result<Change> {
152		let mut state = self.load_take_state(txn)?;
153		let mut output_diffs = Vec::new();
154		let version = change.version;
155
156		for diff in change.diffs {
157			match diff {
158				Diff::Insert {
159					post,
160				} => {
161					let row_count = post.row_count();
162					for row_idx in 0..row_count {
163						let row_number = post.row_numbers[row_idx];
164
165						if state.active.contains_key(&row_number) {
166							*state.active.get_mut(&row_number).unwrap() += 1;
167							continue;
168						}
169						if state.candidates.contains_key(&row_number) {
170							*state.candidates.get_mut(&row_number).unwrap() += 1;
171							continue;
172						}
173
174						if state.active.len() < self.limit {
175							state.active.insert(row_number, 1);
176							output_diffs.push(Diff::Insert {
177								post: post.extract_by_indices(&[row_idx]),
178							});
179						} else {
180							let smallest_active = state.active.keys().next().copied();
181							if let Some(smallest) = smallest_active {
182								if row_number > smallest {
183									if let Some(count) =
184										state.active.remove(&smallest)
185									{
186										state.candidates
187											.insert(smallest, count);
188										let cols = self
189											.parent
190											.pull(txn, &[smallest])?;
191										if !cols.is_empty() {
192											output_diffs.push(
193												Diff::Remove {
194													pre: cols,
195												},
196											);
197										}
198									}
199									state.active.insert(row_number, 1);
200									output_diffs.push(Diff::Insert {
201										post: post
202											.extract_by_indices(&[row_idx]),
203									});
204									let candidate_limit = self.limit * 4;
205									while state.candidates.len() > candidate_limit {
206										if let Some((&r, _)) =
207											state.candidates.iter().next()
208										{
209											state.candidates.remove(&r);
210										}
211									}
212								} else {
213									state.candidates.insert(row_number, 1);
214									let candidate_limit = self.limit * 4;
215									while state.candidates.len() > candidate_limit {
216										if let Some((&r, _)) =
217											state.candidates.iter().next()
218										{
219											state.candidates.remove(&r);
220										}
221									}
222								}
223							}
224						}
225					}
226				}
227				Diff::Update {
228					pre,
229					post,
230				} => {
231					let row_count = post.row_count();
232					let mut update_indices: Vec<usize> = Vec::new();
233					for row_idx in 0..row_count {
234						let row_number = post.row_numbers[row_idx];
235						if state.active.contains_key(&row_number) {
236							update_indices.push(row_idx);
237						}
238					}
239					if !update_indices.is_empty() {
240						output_diffs.push(Diff::Update {
241							pre: pre.extract_by_indices(&update_indices),
242							post: post.extract_by_indices(&update_indices),
243						});
244					}
245				}
246				Diff::Remove {
247					pre,
248				} => {
249					let row_count = pre.row_count();
250					for row_idx in 0..row_count {
251						let row_number = pre.row_numbers[row_idx];
252
253						if let Some(count) = state.active.get_mut(&row_number) {
254							if *count > 1 {
255								*count -= 1;
256							} else {
257								state.active.remove(&row_number);
258								output_diffs.push(Diff::Remove {
259									pre: pre.extract_by_indices(&[row_idx]),
260								});
261								let promoted =
262									self.promote_candidates(&mut state, txn)?;
263								output_diffs.extend(promoted);
264							}
265						} else if let Some(count) = state.candidates.get_mut(&row_number) {
266							if *count > 1 {
267								*count -= 1;
268							} else {
269								state.candidates.remove(&row_number);
270							}
271						}
272					}
273				}
274			}
275		}
276
277		self.save_take_state(txn, &state)?;
278
279		Ok(Change::from_flow(self.node, version, output_diffs))
280	}
281
282	fn pull(&self, txn: &mut FlowTransaction, rows: &[RowNumber]) -> Result<Columns> {
283		self.parent.pull(txn, rows)
284	}
285}