Skip to main content

reifydb_engine/vm/volcano/
assert.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::sync::Arc;
5
6use reifydb_core::value::column::{buffer::ColumnBuffer, columns::Columns, headers::ColumnHeaders};
7use reifydb_rql::expression::{Expression, name::display_label};
8use reifydb_transaction::transaction::Transaction;
9use tracing::instrument;
10
11use crate::{
12	Result,
13	error::EngineError,
14	expression::{context::EvalContext, eval::evaluate},
15	vm::volcano::query::{QueryContext, QueryNode},
16};
17
18pub(crate) struct AssertNode {
19	input: Box<dyn QueryNode>,
20	expressions: Vec<Expression>,
21	message: Option<String>,
22	context: Option<Arc<QueryContext>>,
23}
24
25impl AssertNode {
26	pub fn new(input: Box<dyn QueryNode>, expressions: Vec<Expression>, message: Option<String>) -> Self {
27		Self {
28			input,
29			expressions,
30			message,
31			context: None,
32		}
33	}
34}
35
36impl QueryNode for AssertNode {
37	#[instrument(level = "trace", skip_all, name = "volcano::assert::initialize")]
38	fn initialize<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
39		self.context = Some(Arc::new(ctx.clone()));
40		self.input.initialize(rx, ctx)?;
41		Ok(())
42	}
43
44	#[instrument(level = "trace", skip_all, name = "volcano::assert::next")]
45	fn next<'a>(&mut self, rx: &mut Transaction<'a>, ctx: &mut QueryContext) -> Result<Option<Columns>> {
46		debug_assert!(self.context.is_some(), "AssertNode::next() called before initialize()");
47		let stored_ctx = self.context.as_ref().unwrap();
48
49		if let Some(columns) = self.input.next(rx, ctx)? {
50			let row_count = columns.row_count();
51			let session = EvalContext::from_query(stored_ctx);
52
53			// Evaluate each assert expression
54			for assert_expr in &self.expressions {
55				let eval_ctx = session.with_eval(columns.clone(), row_count);
56
57				let result = evaluate(&eval_ctx, assert_expr)?;
58
59				let frag = assert_expr.full_fragment_owned();
60				let label = display_label(assert_expr);
61				match result.data() {
62					ColumnBuffer::Bool(container) => {
63						for i in 0..row_count {
64							let valid = container.is_defined(i);
65							let value = container.data().get(i);
66							if !valid || !value {
67								return Err(EngineError::AssertionFailed {
68									fragment: frag.clone(),
69									message: self
70										.message
71										.clone()
72										.unwrap_or_default(),
73									expression: Some(label.text().to_string()),
74								}
75								.into());
76							}
77						}
78					}
79					ColumnBuffer::Option {
80						inner,
81						bitvec,
82					} => match inner.as_ref() {
83						ColumnBuffer::Bool(container) => {
84							for i in 0..row_count {
85								let defined = i < bitvec.len() && bitvec.get(i);
86								let valid = defined && container.is_defined(i);
87								let value = valid && container.data().get(i);
88								if !value {
89									return Err(EngineError::AssertionFailed {
90										fragment: frag.clone(),
91										message: self
92											.message
93											.clone()
94											.unwrap_or_default(),
95										expression: Some(label
96											.text()
97											.to_string()),
98									}
99									.into());
100								}
101							}
102						}
103						_ => {
104							return Err(EngineError::AssertionFailed {
105								fragment: frag.clone(),
106								message: "assert expression must evaluate to a boolean"
107									.to_string(),
108								expression: Some(label.text().to_string()),
109							}
110							.into());
111						}
112					},
113					_ => {
114						return Err(EngineError::AssertionFailed {
115							fragment: frag.clone(),
116							message: "assert expression must evaluate to a boolean"
117								.to_string(),
118							expression: Some(label.text().to_string()),
119						}
120						.into());
121					}
122				}
123			}
124
125			// Passthrough: return the original columns unchanged
126			Ok(Some(columns))
127		} else {
128			Ok(None)
129		}
130	}
131
132	fn headers(&self) -> Option<ColumnHeaders> {
133		self.input.headers()
134	}
135}
136
137pub(crate) struct AssertWithoutInputNode {
138	expressions: Vec<Expression>,
139	message: Option<String>,
140	context: Option<Arc<QueryContext>>,
141	done: bool,
142}
143
144impl AssertWithoutInputNode {
145	pub fn new(expressions: Vec<Expression>, message: Option<String>) -> Self {
146		Self {
147			expressions,
148			message,
149			context: None,
150			done: false,
151		}
152	}
153}
154
155impl QueryNode for AssertWithoutInputNode {
156	#[instrument(level = "trace", skip_all, name = "volcano::assert::noinput::initialize")]
157	fn initialize<'a>(&mut self, _rx: &mut Transaction<'a>, ctx: &QueryContext) -> Result<()> {
158		self.context = Some(Arc::new(ctx.clone()));
159		Ok(())
160	}
161
162	#[instrument(level = "trace", skip_all, name = "volcano::assert::noinput::next")]
163	fn next<'a>(&mut self, _rx: &mut Transaction<'a>, _ctx: &mut QueryContext) -> Result<Option<Columns>> {
164		if self.done {
165			return Ok(None);
166		}
167		self.done = true;
168
169		debug_assert!(self.context.is_some(), "AssertWithoutInputNode::next() called before initialize()");
170		let stored_ctx = self.context.as_ref().unwrap();
171		let session = EvalContext::from_query(stored_ctx);
172
173		for assert_expr in &self.expressions {
174			let eval_ctx = session.with_eval_empty();
175
176			let result = evaluate(&eval_ctx, assert_expr)?;
177
178			let frag = assert_expr.full_fragment_owned();
179			let label = display_label(assert_expr);
180			match result.data() {
181				ColumnBuffer::Bool(container) => {
182					let valid = container.is_defined(0);
183					let value = container.data().get(0);
184					if !valid || !value {
185						return Err(EngineError::AssertionFailed {
186							fragment: frag.clone(),
187							message: self.message.clone().unwrap_or_default(),
188							expression: Some(label.text().to_string()),
189						}
190						.into());
191					}
192				}
193				ColumnBuffer::Option {
194					inner,
195					bitvec,
196				} => match inner.as_ref() {
197					ColumnBuffer::Bool(container) => {
198						let defined = !bitvec.is_empty() && bitvec.get(0);
199						let valid = defined && container.is_defined(0);
200						let value = valid && container.data().get(0);
201						if !value {
202							return Err(EngineError::AssertionFailed {
203								fragment: frag.clone(),
204								message: self.message.clone().unwrap_or_default(),
205								expression: Some(label.text().to_string()),
206							}
207							.into());
208						}
209					}
210					_ => {
211						return Err(EngineError::AssertionFailed {
212							fragment: frag.clone(),
213							message: "assert expression must evaluate to a boolean"
214								.to_string(),
215							expression: Some(label.text().to_string()),
216						}
217						.into());
218					}
219				},
220				_ => {
221					return Err(EngineError::AssertionFailed {
222						fragment: frag.clone(),
223						message: "assert expression must evaluate to a boolean".to_string(),
224						expression: Some(label.text().to_string()),
225					}
226					.into());
227				}
228			}
229		}
230
231		Ok(None)
232	}
233
234	fn headers(&self) -> Option<ColumnHeaders> {
235		None
236	}
237}