Skip to main content

reifydb_sub_flow/operator/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::BTreeMap, sync::Arc};
5
6use postcard::{from_bytes, to_stdvec};
7use reifydb_core::{
8	encoded::shape::RowShape,
9	interface::{
10		catalog::flow::FlowNodeId,
11		change::{Change, Diff},
12	},
13	internal,
14	value::column::columns::Columns,
15};
16use reifydb_type::{
17	Result,
18	error::Error,
19	value::{blob::Blob, row_number::RowNumber, r#type::Type},
20};
21use serde::{Deserialize, Serialize};
22
23use crate::{
24	operator::{
25		Operator, Operators,
26		stateful::{raw::RawStatefulOperator, single::SingleStateful, utils},
27	},
28	transaction::{FlowTransaction, slot::PersistFn},
29};
30
/// Bookkeeping for a `TakeOperator`, stored as a single postcard-serialized blob.
///
/// Both maps are keyed by `RowNumber`; the value is an occurrence count that is
/// incremented when an already-tracked row is inserted again and decremented
/// when it is removed.
///
/// NOTE(review): the serialized layout presumably follows field declaration
/// order, so reordering these fields would likely invalidate persisted state —
/// confirm before changing.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct TakeState {
	/// Rows that currently occupy one of the operator's `limit` output slots.
	active: BTreeMap<RowNumber, usize>,
	/// Rows held back by the limit, kept so they can be promoted when an
	/// active row is removed.
	candidates: BTreeMap<RowNumber, usize>,
}
36
/// Flow operator that keeps at most `limit` rows active at a time, preferring
/// the rows with the largest row numbers, and forwards only the changes that
/// affect those rows.
pub struct TakeOperator {
	/// Upstream operator; pulled from to fetch the columns of rows that are
	/// promoted or evicted, and to serve `pull` requests.
	parent: Arc<Operators>,
	/// Id of this operator's flow node.
	node: FlowNodeId,
	/// Maximum number of rows kept in the active set.
	limit: usize,
	/// Shape of the state row: a single blob column holding the serialized
	/// `TakeState`.
	shape: RowShape,
}
43
44impl TakeOperator {
45	pub fn new(parent: Arc<Operators>, node: FlowNodeId, limit: usize) -> Self {
46		Self {
47			parent,
48			node,
49			limit,
50			shape: RowShape::testing(&[Type::Blob]),
51		}
52	}
53
54	fn load_take_state(&self, txn: &mut FlowTransaction) -> Result<TakeState> {
55		let state_row = self.load_state(txn)?;
56
57		if state_row.is_empty() || !state_row.is_defined(0) {
58			return Ok(TakeState::default());
59		}
60
61		let blob = self.shape.get_blob(&state_row, 0);
62		if blob.is_empty() {
63			return Ok(TakeState::default());
64		}
65
66		from_bytes(blob.as_ref())
67			.map_err(|e| Error(Box::new(internal!("Failed to deserialize TakeState: {}", e))))
68	}
69
70	fn save_take_state(&self, txn: &mut FlowTransaction, state: &TakeState) -> Result<()> {
71		let serialized = to_stdvec(state)
72			.map_err(|e| Error(Box::new(internal!("Failed to serialize TakeState: {}", e))))?;
73		let blob = Blob::from(serialized);
74
75		self.update_state(txn, |shape, row| {
76			shape.set_blob(row, 0, &blob);
77			Ok(())
78		})?;
79		Ok(())
80	}
81
82	fn promote_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
83		let mut output_diffs = Vec::new();
84
85		while state.active.len() < self.limit && !state.candidates.is_empty() {
86			if let Some((&candidate_row, &count)) = state.candidates.iter().next_back() {
87				state.candidates.remove(&candidate_row);
88				state.active.insert(candidate_row, count);
89
90				let cols = self.parent.pull(txn, &[candidate_row])?;
91				if !cols.is_empty() {
92					output_diffs.push(Diff::insert(cols));
93				}
94			}
95		}
96
97		Ok(output_diffs)
98	}
99
100	fn evict_to_candidates(&self, state: &mut TakeState, txn: &mut FlowTransaction) -> Result<Vec<Diff>> {
101		let mut output_diffs = Vec::new();
102		let candidate_limit = self.limit * 4;
103
104		while state.active.len() > self.limit {
105			if let Some((&evicted_row, &count)) = state.active.iter().next() {
106				state.active.remove(&evicted_row);
107				state.candidates.insert(evicted_row, count);
108
109				let cols = self.parent.pull(txn, &[evicted_row])?;
110				if !cols.is_empty() {
111					output_diffs.push(Diff::remove(cols));
112				}
113			}
114		}
115
116		while state.candidates.len() > candidate_limit {
117			if let Some((&removed_row, _)) = state.candidates.iter().next() {
118				state.candidates.remove(&removed_row);
119			}
120		}
121
122		Ok(output_diffs)
123	}
124
125	#[inline]
126	fn acquire_take_state(&self, txn: &mut FlowTransaction) -> Result<(TakeState, PersistFn)> {
127		let node_id = self.node;
128		let shape_for_persist = self.shape.clone();
129		txn.take_operator_state::<TakeState, _>(node_id, |txn| {
130			let s = self.load_take_state(txn)?;
131			let shape = shape_for_persist.clone();
132			let persist: PersistFn = Box::new(move |txn, value| {
133				let state = value.downcast::<TakeState>().expect("TakeState slot type");
134				let serialized = to_stdvec(&*state).map_err(|e| {
135					Error(Box::new(internal!("Failed to serialize TakeState: {}", e)))
136				})?;
137				let blob = Blob::from(serialized);
138				let key = utils::empty_key();
139				let mut row = utils::load_or_create_row(node_id, txn, &key, &shape)?;
140				shape.set_blob(&mut row, 0, &blob);
141				utils::save_row(node_id, txn, &key, row)?;
142				Ok(())
143			});
144			Ok((s, persist))
145		})
146	}
147
148	#[inline]
149	fn admit_or_evict_new_row(
150		&self,
151		state: &mut TakeState,
152		txn: &mut FlowTransaction,
153		row_number: RowNumber,
154		single_row: Columns,
155		output_diffs: &mut Vec<Diff>,
156	) -> Result<()> {
157		if state.active.len() < self.limit {
158			state.active.insert(row_number, 1);
159			output_diffs.push(Diff::insert(single_row));
160			return Ok(());
161		}
162
163		let Some(smallest) = state.active.keys().next().copied() else {
164			return Ok(());
165		};
166
167		if row_number > smallest {
168			if let Some(count) = state.active.remove(&smallest) {
169				state.candidates.insert(smallest, count);
170				let cols = self.parent.pull(txn, &[smallest])?;
171				if !cols.is_empty() {
172					output_diffs.push(Diff::remove(cols));
173				}
174			}
175			state.active.insert(row_number, 1);
176			output_diffs.push(Diff::insert(single_row));
177		} else {
178			state.candidates.insert(row_number, 1);
179		}
180		prune_candidates(state, self.limit);
181		Ok(())
182	}
183
184	#[inline]
185	fn apply_insert_diff(
186		&self,
187		state: &mut TakeState,
188		txn: &mut FlowTransaction,
189		post: Arc<Columns>,
190		output_diffs: &mut Vec<Diff>,
191	) -> Result<()> {
192		let row_count = post.row_count();
193		for row_idx in 0..row_count {
194			let row_number = post.row_numbers[row_idx];
195
196			if state.active.contains_key(&row_number) {
197				*state.active.get_mut(&row_number).unwrap() += 1;
198				continue;
199			}
200			if state.candidates.contains_key(&row_number) {
201				*state.candidates.get_mut(&row_number).unwrap() += 1;
202				continue;
203			}
204
205			let single = post.extract_by_indices(&[row_idx]);
206			self.admit_or_evict_new_row(state, txn, row_number, single, output_diffs)?;
207		}
208		Ok(())
209	}
210
211	#[inline]
212	fn apply_update_diff(
213		&self,
214		state: &mut TakeState,
215		txn: &mut FlowTransaction,
216		pre: Arc<Columns>,
217		post: Arc<Columns>,
218		output_diffs: &mut Vec<Diff>,
219	) -> Result<()> {
220		let row_count = post.row_count();
221		let mut update_indices: Vec<usize> = Vec::new();
222
223		for row_idx in 0..row_count {
224			let row_number = post.row_numbers[row_idx];
225
226			if state.active.contains_key(&row_number) {
227				update_indices.push(row_idx);
228				continue;
229			}
230
231			// Row suppressed by the take limit (kept as a candidate for
232			// future promotion); subscriber is intentionally not receiving
233			// it, so the Update is also suppressed.
234			if state.candidates.contains_key(&row_number) {
235				continue;
236			}
237
238			// Row is unknown to TakeState because it existed before the
239			// subscription started (subscriptions have no backfill, so
240			// TakeState begins empty even when the upstream view is
241			// populated). The subscriber is seeing this row for the first
242			// time, so emit the post-image as an Insert and run the same
243			// admission/eviction policy that the Insert branch uses for
244			// genuinely new rows. Without this, every Update against a
245			// pre-existing row would be silently dropped.
246			let single = post.extract_by_indices(&[row_idx]);
247			self.admit_or_evict_new_row(state, txn, row_number, single, output_diffs)?;
248		}
249
250		if !update_indices.is_empty() {
251			output_diffs.push(Diff::update(
252				pre.extract_by_indices(&update_indices),
253				post.extract_by_indices(&update_indices),
254			));
255		}
256		Ok(())
257	}
258
259	#[inline]
260	fn apply_remove_diff(
261		&self,
262		state: &mut TakeState,
263		txn: &mut FlowTransaction,
264		pre: Arc<Columns>,
265		output_diffs: &mut Vec<Diff>,
266	) -> Result<()> {
267		let row_count = pre.row_count();
268		for row_idx in 0..row_count {
269			let row_number = pre.row_numbers[row_idx];
270
271			if let Some(count) = state.active.get_mut(&row_number) {
272				if *count > 1 {
273					*count -= 1;
274				} else {
275					state.active.remove(&row_number);
276					output_diffs.push(Diff::remove(pre.extract_by_indices(&[row_idx])));
277					let promoted = self.promote_candidates(state, txn)?;
278					output_diffs.extend(promoted);
279				}
280			} else if let Some(count) = state.candidates.get_mut(&row_number) {
281				if *count > 1 {
282					*count -= 1;
283				} else {
284					state.candidates.remove(&row_number);
285				}
286			}
287		}
288		Ok(())
289	}
290}
291
292#[inline]
293fn prune_candidates(state: &mut TakeState, limit: usize) {
294	let candidate_limit = limit * 4;
295	while state.candidates.len() > candidate_limit {
296		if let Some((&r, _)) = state.candidates.iter().next() {
297			state.candidates.remove(&r);
298		}
299	}
300}
301
// Empty impl: `TakeOperator` overrides nothing here, so it relies entirely on
// whatever `RawStatefulOperator` provides by default.
impl RawStatefulOperator for TakeOperator {}
303
304impl SingleStateful for TakeOperator {
305	fn layout(&self) -> RowShape {
306		self.shape.clone()
307	}
308}
309
310impl Operator for TakeOperator {
311	fn id(&self) -> FlowNodeId {
312		self.node
313	}
314
315	fn apply(&self, txn: &mut FlowTransaction, change: Change) -> Result<Change> {
316		let node_id = self.node;
317		let (mut state, persist) = self.acquire_take_state(txn)?;
318
319		let mut output_diffs = Vec::new();
320		let version = change.version;
321
322		for diff in change.diffs {
323			match diff {
324				Diff::Insert {
325					post,
326				} => self.apply_insert_diff(&mut state, txn, post, &mut output_diffs)?,
327				Diff::Update {
328					pre,
329					post,
330				} => self.apply_update_diff(&mut state, txn, pre, post, &mut output_diffs)?,
331				Diff::Remove {
332					pre,
333				} => self.apply_remove_diff(&mut state, txn, pre, &mut output_diffs)?,
334			}
335		}
336
337		// Restore the cached state for the next batch in this txn; the put
338		// marks the slot dirty so flush_operator_states will persist it.
339		txn.put_operator_state(node_id, state, persist);
340
341		Ok(Change::from_flow(self.node, version, output_diffs, change.changed_at))
342	}
343
344	fn pull(&self, txn: &mut FlowTransaction, rows: &[RowNumber]) -> Result<Columns> {
345		self.parent.pull(txn, rows)
346	}
347}