Skip to main content

reifydb_engine/vm/volcano/
assert.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::sync::Arc;
5
6use reifydb_core::value::column::{buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders};
7use reifydb_rql::expression::{Expression, name::display_label};
8use reifydb_transaction::transaction::Transaction;
9use tracing::instrument;
10
11use crate::{
12	Result,
13	error::EngineError,
14	expression::{context::EvalContext, eval::evaluate},
15	vm::volcano::query::{QueryContext, QueryNode},
16};
17
18pub(crate) struct AssertNode {
19	input: Box<dyn QueryNode>,
20	expressions: Vec<Expression>,
21	message: Option<String>,
22	context: Option<Arc<QueryContext>>,
23}
24
25impl AssertNode {
26	pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>, message: Option<String>) -> Self {
27		Self {
28			input,
29			expressions,
30			message,
31			context: None,
32		}
33	}
34}
35
36impl QueryNode for AssertNode {
37	#[instrument(level = "trace", skip_all, name = "volcano::assert::initialize")]
38	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
39		self.context = Some(Arc::new(ctx.clone()));
40		self.input.initialize(rx, ctx)?;
41		Ok(())
42	}
43
44	#[instrument(level = "trace", skip_all, name = "volcano::assert::next")]
45	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
46		debug_assert!(self.context.is_some(), "AssertNode::next() called before initialize()");
47		let stored_ctx = self.context.as_ref().unwrap();
48
49		if let Some(columns) = self.input.next(rx, ctx)? {
50			let row_count = columns.row_count();
51			let session = EvalContext::from_query(stored_ctx);
52
53			for assert_expr in &self.expressions {
54				let eval_ctx = session.with_eval(columns.clone(), row_count);
55
56				let result = evaluate(&eval_ctx, assert_expr)?;
57
58				let frag = assert_expr.full_fragment_owned();
59				let label = display_label(assert_expr);
60				match result.data() {
61					ColumnBuffer::Bool(container) => {
62						for i in 0..row_count {
63							let valid = container.is_defined(i);
64							let value = container.data().get(i);
65							if !valid || !value {
66								return Err(EngineError::AssertionFailed {
67									fragment: frag.clone(),
68									message: self
69										.message
70										.clone()
71										.unwrap_or_default(),
72									expression: Some(label.text().to_string()),
73								}
74								.into());
75							}
76						}
77					}
78					ColumnBuffer::Option {
79						inner,
80						bitvec,
81					} => match inner.as_ref() {
82						ColumnBuffer::Bool(container) => {
83							for i in 0..row_count {
84								let defined = i < bitvec.len() && bitvec.get(i);
85								let valid = defined && container.is_defined(i);
86								let value = valid && container.data().get(i);
87								if !value {
88									return Err(EngineError::AssertionFailed {
89										fragment: frag.clone(),
90										message: self
91											.message
92											.clone()
93											.unwrap_or_default(),
94										expression: Some(label
95											.text()
96											.to_string()),
97									}
98									.into());
99								}
100							}
101						}
102						_ => {
103							return Err(EngineError::AssertionFailed {
104								fragment: frag.clone(),
105								message: "assert expression must evaluate to a boolean"
106									.to_string(),
107								expression: Some(label.text().to_string()),
108							}
109							.into());
110						}
111					},
112					_ => {
113						return Err(EngineError::AssertionFailed {
114							fragment: frag.clone(),
115							message: "assert expression must evaluate to a boolean"
116								.to_string(),
117							expression: Some(label.text().to_string()),
118						}
119						.into());
120					}
121				}
122			}
123
124			Ok(Some(columns))
125		} else {
126			Ok(None)
127		}
128	}
129
130	fn headers(&self) -> Option<ColumnHeaders> {
131		self.input.headers()
132	}
133}
134
135pub(crate) struct AssertWithoutInputNode {
136	expressions: Vec<Expression>,
137	message: Option<String>,
138	context: Option<Arc<QueryContext>>,
139	done: bool,
140}
141
142impl AssertWithoutInputNode {
143	pub fn new(expressions: Vec<Expression>, message: Option<String>) -> Self {
144		Self {
145			expressions,
146			message,
147			context: None,
148			done: false,
149		}
150	}
151}
152
153impl QueryNode for AssertWithoutInputNode {
154	#[instrument(level = "trace", skip_all, name = "volcano::assert::noinput::initialize")]
155	fn initialize<'a>(&mut self, _rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
156		self.context = Some(Arc::new(ctx.clone()));
157		Ok(())
158	}
159
160	#[instrument(level = "trace", skip_all, name = "volcano::assert::noinput::next")]
161	fn next<'a>(&mut self, _rx: &mut Transaction<'a>, _ctx: &mut QueryContext) -> Result<Option<Columns>> {
162		if self.done {
163			return Ok(None);
164		}
165		self.done = true;
166
167		debug_assert!(self.context.is_some(), "AssertWithoutInputNode::next() called before initialize()");
168		let stored_ctx = self.context.as_ref().unwrap();
169		let session = EvalContext::from_query(stored_ctx);
170
171		for assert_expr in &self.expressions {
172			let eval_ctx = session.with_eval_empty();
173
174			let result = evaluate(&eval_ctx, assert_expr)?;
175
176			let frag = assert_expr.full_fragment_owned();
177			let label = display_label(assert_expr);
178			match result.data() {
179				ColumnBuffer::Bool(container) => {
180					let valid = container.is_defined(0);
181					let value = container.data().get(0);
182					if !valid || !value {
183						return Err(EngineError::AssertionFailed {
184							fragment: frag.clone(),
185							message: self.message.clone().unwrap_or_default(),
186							expression: Some(label.text().to_string()),
187						}
188						.into());
189					}
190				}
191				ColumnBuffer::Option {
192					inner,
193					bitvec,
194				} => match inner.as_ref() {
195					ColumnBuffer::Bool(container) => {
196						let defined = !bitvec.is_empty() && bitvec.get(0);
197						let valid = defined && container.is_defined(0);
198						let value = valid && container.data().get(0);
199						if !value {
200							return Err(EngineError::AssertionFailed {
201								fragment: frag.clone(),
202								message: self.message.clone().unwrap_or_default(),
203								expression: Some(label.text().to_string()),
204							}
205							.into());
206						}
207					}
208					_ => {
209						return Err(EngineError::AssertionFailed {
210							fragment: frag.clone(),
211							message: "assert expression must evaluate to a boolean"
212								.to_string(),
213							expression: Some(label.text().to_string()),
214						}
215						.into());
216					}
217				},
218				_ => {
219					return Err(EngineError::AssertionFailed {
220						fragment: frag.clone(),
221						message: "assert expression must evaluate to a boolean".to_string(),
222						expression: Some(label.text().to_string()),
223					}
224					.into());
225				}
226			}
227		}
228
229		Ok(None)
230	}
231
232	fn headers(&self) -> Option<ColumnHeaders> {
233		None
234	}
235}