Skip to main content

reifydb_engine/vm/volcano/
take.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{columns::Columns, headers::ColumnHeaders};
5use reifydb_extension::transform::{Transform, context::TransformContext};
6use reifydb_transaction::transaction::Transaction;
7use tracing::instrument;
8
9use crate::{
10	Result,
11	vm::volcano::query::{QueryContext, QueryNode},
12};
13
14pub(crate) struct TakeNode {
15	input: Box<dyn QueryNode>,
16	remaining: usize,
17	initialized: Option<()>,
18}
19
20impl TakeNode {
21	pub(crate) fn new(input: Box<dyn QueryNode>, take: usize) -> Self {
22		Self {
23			input,
24			remaining: take,
25			initialized: None,
26		}
27	}
28}
29
30impl QueryNode for TakeNode {
31	#[instrument(name = "volcano::take::initialize", level = "trace", skip_all)]
32	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
33		self.input.initialize(rx, ctx)?;
34		self.initialized = Some(());
35		Ok(())
36	}
37
38	#[instrument(name = "volcano::take::next", level = "trace", skip_all)]
39	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
40		debug_assert!(self.initialized.is_some(), "TakeNode::next() called before initialize()");
41
42		if self.remaining == 0 {
43			return Ok(None);
44		}
45
46		while let Some(columns) = self.input.next(rx, ctx)? {
47			if columns.row_count() == 0 {
48				continue;
49			}
50			let transform_ctx = TransformContext {
51				routines: &ctx.services.routines,
52				runtime_context: &ctx.services.runtime_context,
53				params: &ctx.params,
54			};
55			let result = self.apply(&transform_ctx, columns)?;
56			self.remaining -= result.row_count();
57			return Ok(Some(result));
58		}
59		Ok(None)
60	}
61
62	fn headers(&self) -> Option<ColumnHeaders> {
63		self.input.headers()
64	}
65}
66
67impl Transform for TakeNode {
68	fn apply(&self, _ctx: &TransformContext, mut input: Columns) -> Result<Columns> {
69		let row_count = input.row_count();
70		if row_count > self.remaining {
71			input.take(self.remaining)?;
72		}
73		Ok(input)
74	}
75}