Skip to main content

reifydb_engine/vm/volcano/
sort.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::cmp::Ordering::Equal;
5
6use reifydb_core::{
7	error::diagnostic::query,
8	sort::{
9		SortDirection::{Asc, Desc},
10		SortKey,
11	},
12	value::column::{buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders},
13};
14use reifydb_extension::transform::{Transform, context::TransformContext};
15use reifydb_transaction::transaction::Transaction;
16use reifydb_type::{
17	error,
18	error::Error,
19	util::cowvec::CowVec,
20	value::{
21		datetime::{CREATED_AT_COLUMN_NAME, UPDATED_AT_COLUMN_NAME},
22		row_number::ROW_NUMBER_COLUMN_NAME,
23	},
24};
25use tracing::instrument;
26
27use crate::{
28	Result,
29	vm::volcano::query::{QueryContext, QueryNode},
30};
31
/// Volcano query node that sorts the rows produced by an upstream node.
pub(crate) struct SortNode {
	/// Upstream node whose batches are collected and sorted.
	input: Box<dyn QueryNode>,
	/// Sort keys (column plus direction), applied in order.
	by: Vec<SortKey>,
	/// `Some(())` once `initialize()` has run; `next()` checks it with a `debug_assert!`.
	initialized: Option<()>,
}
37
38impl SortNode {
39	pub(crate) fn new(input: Box<dyn QueryNode>, by: Vec<SortKey>) -> Self {
40		Self {
41			input,
42			by,
43			initialized: None,
44		}
45	}
46}
47
48impl QueryNode for SortNode {
49	#[instrument(level = "trace", skip_all, name = "volcano::sort::initialize")]
50	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
51		self.input.initialize(rx, ctx)?;
52		self.initialized = Some(());
53		Ok(())
54	}
55
56	#[instrument(level = "trace", skip_all, name = "volcano::sort::next")]
57	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
58		debug_assert!(self.initialized.is_some(), "SortNode::next() called before initialize()");
59
60		let mut columns_opt: Option<Columns> = None;
61
62		while let Some(columns) = self.input.next(rx, ctx)? {
63			if let Some(existing_columns) = &mut columns_opt {
64				existing_columns.row_numbers.make_mut().extend(columns.row_numbers.iter().copied());
65				existing_columns.created_at.make_mut().extend(columns.created_at.iter().copied());
66				existing_columns.updated_at.make_mut().extend(columns.updated_at.iter().copied());
67				for (i, col) in columns.columns.iter().enumerate() {
68					existing_columns[i].extend(col.clone())?;
69				}
70			} else {
71				columns_opt = Some(columns);
72			}
73		}
74
75		let columns = match columns_opt {
76			Some(f) => f,
77			None => return Ok(None),
78		};
79
80		let transform_ctx = TransformContext {
81			routines: &ctx.services.routines,
82			runtime_context: &ctx.services.runtime_context,
83			params: &ctx.params,
84		};
85		Ok(Some(self.apply(&transform_ctx, columns)?))
86	}
87
88	fn headers(&self) -> Option<ColumnHeaders> {
89		self.input.headers()
90	}
91}
92
93impl Transform for SortNode {
94	fn apply(&self, _ctx: &TransformContext, mut columns: Columns) -> Result<Columns> {
95		let key_refs = self
96			.by
97			.iter()
98			.map(|key| {
99				let name = key.column.fragment();
100				let stripped = name.strip_prefix('#').unwrap_or(name);
101
102				if stripped == ROW_NUMBER_COLUMN_NAME && !columns.row_numbers.is_empty() {
103					let data: Vec<u64> = columns.row_numbers.iter().map(|r| r.value()).collect();
104					return Ok::<_, Error>((ColumnBuffer::uint8(data), key.direction.clone()));
105				}
106				if stripped == CREATED_AT_COLUMN_NAME && !columns.created_at.is_empty() {
107					return Ok((
108						ColumnBuffer::datetime(columns.created_at.to_vec()),
109						key.direction.clone(),
110					));
111				}
112				if stripped == UPDATED_AT_COLUMN_NAME && !columns.updated_at.is_empty() {
113					return Ok((
114						ColumnBuffer::datetime(columns.updated_at.to_vec()),
115						key.direction.clone(),
116					));
117				}
118
119				let col = columns
120					.iter()
121					.find(|c| c.name() == name)
122					.ok_or_else(|| error!(query::column_not_found(key.column.clone())))?;
123				Ok((col.data().clone(), key.direction.clone()))
124			})
125			.collect::<Result<Vec<_>>>()?;
126
127		let row_count = columns.row_count();
128		let mut indices: Vec<usize> = (0..row_count).collect();
129
130		indices.sort_unstable_by(|&l, &r| {
131			for (col, dir) in &key_refs {
132				let vl = col.get_value(l);
133				let vr = col.get_value(r);
134				let ord = vl.partial_cmp(&vr).unwrap_or(Equal);
135				let ord = match dir {
136					Asc => ord,
137					Desc => ord.reverse(),
138				};
139				if ord != Equal {
140					return ord;
141				}
142			}
143			Equal
144		});
145
146		if !columns.row_numbers.is_empty() {
147			let reordered: Vec<_> = indices.iter().map(|&i| columns.row_numbers[i]).collect();
148			columns.row_numbers = CowVec::new(reordered);
149		}
150		if !columns.created_at.is_empty() {
151			let reordered: Vec<_> = indices.iter().map(|&i| columns.created_at[i]).collect();
152			columns.created_at = CowVec::new(reordered);
153		}
154		if !columns.updated_at.is_empty() {
155			let reordered: Vec<_> = indices.iter().map(|&i| columns.updated_at[i]).collect();
156			columns.updated_at = CowVec::new(reordered);
157		}
158
159		let cols = columns.columns.make_mut();
160		for col in cols.iter_mut() {
161			col.reorder(&indices);
162		}
163
164		Ok(columns)
165	}
166}