Skip to main content

reifydb_routine/procedure/testing/
changed.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_catalog::catalog::Catalog;
5use reifydb_core::{
6	interface::{catalog::shape::ShapeId, change::Diff},
7	internal_error,
8	value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns},
9};
10use reifydb_transaction::transaction::Transaction;
11use reifydb_type::{
12	error::Error,
13	params::Params,
14	value::{Value, r#type::Type},
15};
16
17use crate::routine::{Routine, RoutineInfo, context::ProcedureContext, error::RoutineError};
18
19/// Identifies the primitive type category for a `testing::*::changed()` procedure.
20pub struct TestingChanged {
21	pub shape_type: &'static str,
22	info: RoutineInfo,
23}
24
25impl TestingChanged {
26	pub fn new(shape_type: &'static str) -> Self {
27		Self {
28			shape_type,
29			info: RoutineInfo::new(&format!("testing::{}::changed", shape_type)),
30		}
31	}
32}
33
34impl<'a, 'tx> Routine<ProcedureContext<'a, 'tx>> for TestingChanged {
35	fn info(&self) -> &RoutineInfo {
36		&self.info
37	}
38
39	fn return_type(&self, _input_types: &[Type]) -> Type {
40		Type::Any
41	}
42
43	fn execute(&self, ctx: &mut ProcedureContext<'a, 'tx>, _args: &Columns) -> Result<Columns, RoutineError> {
44		let t = match ctx.tx {
45			Transaction::Test(t) => t,
46			_ => {
47				return Err(internal_error!("testing::*::changed() requires a test transaction").into());
48			}
49		};
50
51		let filter_arg = extract_optional_string_param(ctx.params);
52
53		// Materialize view rows from pending source changes so that
54		// changed() sees transactional view mutations.
55		if self.shape_type == "views" {
56			let _ = t.capture_testing_pre_commit();
57		}
58
59		// Read individual mutations from the accumulator
60		let entries: Vec<_> =
61			t.accumulator_entries_from().iter().map(|(id, diff)| (*id, diff.clone())).collect();
62
63		let mut mutations: Vec<MutationEntry> = Vec::new();
64
65		for (shape_id, diff) in &entries {
66			let type_matches = matches!(
67				(&shape_id, self.shape_type),
68				(ShapeId::Table(_), "tables")
69					| (ShapeId::View(_), "views") | (ShapeId::RingBuffer(_), "ringbuffers")
70					| (ShapeId::Series(_), "series") | (ShapeId::Dictionary(_), "dictionaries")
71			);
72			if !type_matches {
73				continue;
74			}
75
76			let catalog: &Catalog = ctx.catalog;
77			let name = match resolve_shape_name(
78				catalog,
79				&mut Transaction::Test(Box::new(t.reborrow())),
80				shape_id,
81			) {
82				Ok(n) => n,
83				Err(_) => continue,
84			};
85
86			if let Some(filter) = filter_arg.as_deref()
87				&& name != filter
88			{
89				continue;
90			}
91
92			mutations.push(MutationEntry {
93				target: name,
94				diff: diff.clone(),
95			});
96		}
97
98		mutations.sort_by(|a, b| a.target.cmp(&b.target));
99		Ok(build_output_columns(&mutations)?)
100	}
101}
102
103fn extract_optional_string_param(params: &Params) -> Option<String> {
104	match params {
105		Params::Positional(args) if !args.is_empty() => match &args[0] {
106			Value::Utf8(s) => Some(s.clone()),
107			_ => None,
108		},
109		_ => None,
110	}
111}
112
113struct MutationEntry {
114	target: String,
115	diff: Diff,
116}
117
118fn resolve_shape_name(catalog: &Catalog, txn: &mut Transaction<'_>, id: &ShapeId) -> Result<String, Error> {
119	match id {
120		ShapeId::Table(table_id) => {
121			let table = catalog
122				.find_table(txn, *table_id)?
123				.ok_or_else(|| internal_error!("table not found for id {:?}", table_id))?;
124			let ns = catalog
125				.find_namespace(txn, table.namespace)?
126				.ok_or_else(|| internal_error!("namespace not found"))?;
127			Ok(format!("{}::{}", ns.name(), table.name))
128		}
129		ShapeId::View(view_id) => {
130			let view = catalog
131				.find_view(txn, *view_id)?
132				.ok_or_else(|| internal_error!("view not found for id {:?}", view_id))?;
133			let ns = catalog
134				.find_namespace(txn, view.namespace())?
135				.ok_or_else(|| internal_error!("namespace not found"))?;
136			Ok(format!("{}::{}", ns.name(), view.name()))
137		}
138		ShapeId::RingBuffer(rb_id) => {
139			let rb = catalog
140				.find_ringbuffer(txn, *rb_id)?
141				.ok_or_else(|| internal_error!("ringbuffer not found for id {:?}", rb_id))?;
142			let ns = catalog
143				.find_namespace(txn, rb.namespace)?
144				.ok_or_else(|| internal_error!("namespace not found"))?;
145			Ok(format!("{}::{}", ns.name(), rb.name))
146		}
147		ShapeId::Series(series_id) => {
148			let series = catalog
149				.find_series(txn, *series_id)?
150				.ok_or_else(|| internal_error!("series not found for id {:?}", series_id))?;
151			let ns = catalog
152				.find_namespace(txn, series.namespace)?
153				.ok_or_else(|| internal_error!("namespace not found"))?;
154			Ok(format!("{}::{}", ns.name(), series.name))
155		}
156		ShapeId::Dictionary(dict_id) => {
157			let dict = catalog
158				.find_dictionary(txn, *dict_id)?
159				.ok_or_else(|| internal_error!("dictionary not found for id {:?}", dict_id))?;
160			let ns = catalog
161				.find_namespace(txn, dict.namespace)?
162				.ok_or_else(|| internal_error!("namespace not found"))?;
163			Ok(format!("{}::{}", ns.name(), dict.name))
164		}
165		_ => Err(internal_error!("unsupported primitive type {:?}", id)),
166	}
167}
168
169fn build_output_columns(entries: &[MutationEntry]) -> Result<Columns, Error> {
170	if entries.is_empty() {
171		return Ok(Columns::empty());
172	}
173
174	let mut op_data = ColumnBuffer::utf8_with_capacity(entries.len());
175	let mut target_data = ColumnBuffer::utf8_with_capacity(entries.len());
176
177	let mut field_names: Vec<String> = Vec::new();
178	for entry in entries {
179		match &entry.diff {
180			Diff::Insert {
181				post,
182			}
183			| Diff::Remove {
184				pre: post,
185			} => {
186				for col in post.iter() {
187					let name = col.name().text().to_string();
188					if !field_names.contains(&name) {
189						field_names.push(name);
190					}
191				}
192			}
193			Diff::Update {
194				pre,
195				post,
196			} => {
197				for col in pre.iter() {
198					let name = col.name().text().to_string();
199					if !field_names.contains(&name) {
200						field_names.push(name);
201					}
202				}
203				for col in post.iter() {
204					let name = col.name().text().to_string();
205					if !field_names.contains(&name) {
206						field_names.push(name);
207					}
208				}
209			}
210		}
211	}
212
213	let mut old_columns: Vec<Vec<Value>> = vec![Vec::with_capacity(entries.len()); field_names.len()];
214	let mut new_columns: Vec<Vec<Value>> = vec![Vec::with_capacity(entries.len()); field_names.len()];
215
216	for entry in entries {
217		let empty = Columns::empty();
218		let (op, old_cols, new_cols): (&str, &Columns, &Columns) = match &entry.diff {
219			Diff::Insert {
220				post,
221			} => ("insert", &empty, post.as_ref()),
222			Diff::Update {
223				pre,
224				post,
225			} => ("update", pre.as_ref(), post.as_ref()),
226			Diff::Remove {
227				pre,
228			} => ("delete", pre.as_ref(), &empty),
229		};
230
231		op_data.push(op);
232		target_data.push(entry.target.as_str());
233
234		for (i, field_name) in field_names.iter().enumerate() {
235			let old_val =
236				old_cols.column(field_name).map(|col| col.data().get_value(0)).unwrap_or(Value::none());
237			old_columns[i].push(old_val);
238
239			let new_val =
240				new_cols.column(field_name).map(|col| col.data().get_value(0)).unwrap_or(Value::none());
241			new_columns[i].push(new_val);
242		}
243	}
244
245	let mut columns = vec![ColumnWithName::new("op", op_data), ColumnWithName::new("target", target_data)];
246
247	for (i, name) in field_names.iter().enumerate() {
248		let mut old_data = column_for_values(&old_columns[i]);
249		for val in &old_columns[i] {
250			old_data.push_value(val.clone());
251		}
252		columns.push(ColumnWithName::new(format!("old_{}", name), old_data));
253
254		let mut new_data = column_for_values(&new_columns[i]);
255		for val in &new_columns[i] {
256			new_data.push_value(val.clone());
257		}
258		columns.push(ColumnWithName::new(format!("new_{}", name), new_data));
259	}
260
261	Ok(Columns::new(columns))
262}
263
264fn column_for_values(values: &[Value]) -> ColumnBuffer {
265	let first_type = values.iter().find_map(|v| {
266		if matches!(v, Value::None { .. }) {
267			None
268		} else {
269			Some(v.get_type())
270		}
271	});
272	match first_type {
273		Some(ty) => ColumnBuffer::with_capacity(ty, values.len()),
274		None => ColumnBuffer::none_typed(Type::Boolean, 0),
275	}
276}