Skip to main content

reifydb_sub_flow/operator/
take.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::BTreeMap, sync::Arc};
5
6use postcard::{from_bytes, to_stdvec};
7use reifydb_core::{
8	encoded::schema::Schema,
9	interface::{
10		catalog::flow::FlowNodeId,
11		change::{Change, Diff},
12	},
13	internal,
14	value::column::columns::Columns,
15};
16use reifydb_type::{
17	Result,
18	error::Error,
19	value::{blob::Blob, row_number::RowNumber, r#type::Type},
20};
21use serde::{Deserialize, Serialize};
22
23use crate::{
24	operator::{
25		Operator, Operators,
26		stateful::{raw::RawStatefulOperator, single::SingleStateful},
27	},
28	transaction::FlowTransaction,
29};
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
32struct TakeState {
33	active: BTreeMap<RowNumber, usize>,
34	candidates: BTreeMap<RowNumber, usize>,
35}
36
37impl Default for TakeState {
38	fn default() -> Self {
39		Self {
40			active: BTreeMap::new(),
41			candidates: BTreeMap::new(),
42		}
43	}
44}
45
46pub struct TakeOperator {
47	parent: Arc<Operators>,
48	node: FlowNodeId,
49	limit: usize,
50	schema: Schema,
51}
52
53impl TakeOperator {
54	pub fn new(parent: Arc<Operators>, node: FlowNodeId, limit: usize) -> Self {
55		Self {
56			parent,
57			node,
58			limit,
59			schema: Schema::testing(&[Type::Blob]),
60		}
61	}
62
63	fn load_take_state(&self, txn: &mut FlowTransaction) -> Result<TakeState> {
64		let state_row = self.load_state(txn)?;
65
66		if state_row.is_empty() || !state_row.is_defined(0) {
67			return Ok(TakeState::default());
68		}
69
70		let blob = self.schema.get_blob(&state_row, 0);
71		if blob.is_empty() {
72			return Ok(TakeState::default());
73		}
74
75		from_bytes(blob.as_ref()).map_err(|e| Error(internal!("Failed to deserialize TakeState: {}", e)))
76	}
77
78	fn save_take_state(&self, txn: &mut FlowTransaction, state: &TakeState) -> Result<()> {
79		let serialized =
80			to_stdvec(state).map_err(|e| Error(internal!("Failed to serialize TakeState: {}", e)))?;
81
82		let mut state_row = self.schema.allocate();
83		let blob = Blob::from(serialized);
84		self.schema.set_blob(&mut state_row, 0, &blob);
85
86		self.save_state(txn, state_row)
87	}
88
89	fn promote_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
90		let mut output_diffs = Vec::new();
91
92		while state.active.len() < self.limit && !state.candidates.is_empty() {
93			if let Some((&candidate_row, &count)) = state.candidates.iter().next_back() {
94				state.candidates.remove(&candidate_row);
95				state.active.insert(candidate_row, count);
96
97				let cols = self.parent.pull(txn, &[candidate_row])?;
98				if !cols.is_empty() {
99					output_diffs.push(Diff::Insert {
100						post: cols,
101					});
102				}
103			}
104		}
105
106		Ok(output_diffs)
107	}
108
109	fn evict_to_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
110		let mut output_diffs = Vec::new();
111		let candidate_limit = self.limit * 4;
112
113		while state.active.len() > self.limit {
114			if let Some((&evicted_row, &count)) = state.active.iter().next() {
115				state.active.remove(&evicted_row);
116				state.candidates.insert(evicted_row, count);
117
118				let cols = self.parent.pull(txn, &[evicted_row])?;
119				if !cols.is_empty() {
120					output_diffs.push(Diff::Remove {
121						pre: cols,
122					});
123				}
124			}
125		}
126
127		while state.candidates.len() > candidate_limit {
128			if let Some((&removed_row, _)) = state.candidates.iter().next() {
129				state.candidates.remove(&removed_row);
130			}
131		}
132
133		Ok(output_diffs)
134	}
135}
136
137impl RawStatefulOperator for TakeOperator {}
138
139impl SingleStateful for TakeOperator {
140	fn layout(&self) -> Schema {
141		self.schema.clone()
142	}
143}
144
145impl Operator for TakeOperator {
146	fn id(&self) -> FlowNodeId {
147		self.node
148	}
149
150	fn apply(&self, txn: &mut FlowTransaction, change: Change) -> Result<Change> {
151		let mut state = self.load_take_state(txn)?;
152		let mut output_diffs = Vec::new();
153		let version = change.version;
154
155		for diff in change.diffs {
156			match diff {
157				Diff::Insert {
158					post,
159				} => {
160					let row_count = post.row_count();
161					for row_idx in 0..row_count {
162						let row_number = post.row_numbers[row_idx];
163
164						if state.active.contains_key(&row_number) {
165							*state.active.get_mut(&row_number).unwrap() += 1;
166							continue;
167						}
168						if state.candidates.contains_key(&row_number) {
169							*state.candidates.get_mut(&row_number).unwrap() += 1;
170							continue;
171						}
172
173						if state.active.len() < self.limit {
174							state.active.insert(row_number, 1);
175							output_diffs.push(Diff::Insert {
176								post: post.extract_by_indices(&[row_idx]),
177							});
178						} else {
179							let smallest_active = state.active.keys().next().copied();
180							if let Some(smallest) = smallest_active {
181								if row_number > smallest {
182									if let Some(count) =
183										state.active.remove(&smallest)
184									{
185										state.candidates
186											.insert(smallest, count);
187										let cols = self
188											.parent
189											.pull(txn, &[smallest])?;
190										if !cols.is_empty() {
191											output_diffs.push(
192												Diff::Remove {
193													pre: cols,
194												},
195											);
196										}
197									}
198									state.active.insert(row_number, 1);
199									output_diffs.push(Diff::Insert {
200										post: post
201											.extract_by_indices(&[row_idx]),
202									});
203									let candidate_limit = self.limit * 4;
204									while state.candidates.len() > candidate_limit {
205										if let Some((&r, _)) =
206											state.candidates.iter().next()
207										{
208											state.candidates.remove(&r);
209										}
210									}
211								} else {
212									state.candidates.insert(row_number, 1);
213									let candidate_limit = self.limit * 4;
214									while state.candidates.len() > candidate_limit {
215										if let Some((&r, _)) =
216											state.candidates.iter().next()
217										{
218											state.candidates.remove(&r);
219										}
220									}
221								}
222							}
223						}
224					}
225				}
226				Diff::Update {
227					pre,
228					post,
229				} => {
230					let row_count = post.row_count();
231					let mut update_indices: Vec<usize> = Vec::new();
232					for row_idx in 0..row_count {
233						let row_number = post.row_numbers[row_idx];
234						if state.active.contains_key(&row_number) {
235							update_indices.push(row_idx);
236						}
237					}
238					if !update_indices.is_empty() {
239						output_diffs.push(Diff::Update {
240							pre: pre.extract_by_indices(&update_indices),
241							post: post.extract_by_indices(&update_indices),
242						});
243					}
244				}
245				Diff::Remove {
246					pre,
247				} => {
248					let row_count = pre.row_count();
249					for row_idx in 0..row_count {
250						let row_number = pre.row_numbers[row_idx];
251
252						if let Some(count) = state.active.get_mut(&row_number) {
253							if *count > 1 {
254								*count -= 1;
255							} else {
256								state.active.remove(&row_number);
257								output_diffs.push(Diff::Remove {
258									pre: pre.extract_by_indices(&[row_idx]),
259								});
260								let promoted =
261									self.promote_candidates(&mut state, txn)?;
262								output_diffs.extend(promoted);
263							}
264						} else if let Some(count) = state.candidates.get_mut(&row_number) {
265							if *count > 1 {
266								*count -= 1;
267							} else {
268								state.candidates.remove(&row_number);
269							}
270						}
271					}
272				}
273			}
274		}
275
276		self.save_take_state(txn, &state)?;
277
278		Ok(Change::from_flow(self.node, version, output_diffs))
279	}
280
281	fn pull(&self, txn: &mut FlowTransaction, rows: &[RowNumber]) -> Result<Columns> {
282		self.parent.pull(txn, rows)
283	}
284}