Skip to main content

reifydb_routine/procedure/testing/
changed.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_catalog::catalog::Catalog;
5use reifydb_core::{
6	interface::{catalog::shape::ShapeId, change::Diff},
7	internal_error,
8	value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns},
9};
10use reifydb_transaction::transaction::Transaction;
11use reifydb_type::{
12	error::Error,
13	params::Params,
14	value::{Value, r#type::Type},
15};
16
17use crate::routine::{Routine, RoutineInfo, context::ProcedureContext, error::RoutineError};
18
19pub struct TestingChanged {
20	pub shape_type: &'static str,
21	info: RoutineInfo,
22}
23
24impl TestingChanged {
25	pub fn new(shape_type: &'static str) -> Self {
26		Self {
27			shape_type,
28			info: RoutineInfo::new(&format!("testing::{}::changed", shape_type)),
29		}
30	}
31}
32
33impl<'a, 'tx> Routine<ProcedureContext<'a, 'tx>> for TestingChanged {
34	fn info(&self) -> &RoutineInfo {
35		&self.info
36	}
37
38	fn return_type(&self, _input_types: &[Type]) -> Type {
39		Type::Any
40	}
41
42	fn execute(&self, ctx: &mut ProcedureContext<'a, 'tx>, _args: &Columns) -> Result<Columns, RoutineError> {
43		let t = match ctx.tx {
44			Transaction::Test(t) => t,
45			_ => {
46				return Err(internal_error!("testing::*::changed() requires a test transaction").into());
47			}
48		};
49
50		let filter_arg = extract_optional_string_param(ctx.params);
51
52		if self.shape_type == "views" {
53			let _ = t.capture_testing_pre_commit();
54		}
55
56		let entries: Vec<_> =
57			t.accumulator_entries_from().iter().map(|(id, diff)| (*id, diff.clone())).collect();
58
59		let mut mutations: Vec<MutationEntry> = Vec::new();
60
61		for (shape_id, diff) in &entries {
62			let type_matches = matches!(
63				(&shape_id, self.shape_type),
64				(ShapeId::Table(_), "tables")
65					| (ShapeId::View(_), "views") | (ShapeId::RingBuffer(_), "ringbuffers")
66					| (ShapeId::Series(_), "series") | (ShapeId::Dictionary(_), "dictionaries")
67			);
68			if !type_matches {
69				continue;
70			}
71
72			let catalog: &Catalog = ctx.catalog;
73			let name = match resolve_shape_name(
74				catalog,
75				&mut Transaction::Test(Box::new(t.reborrow())),
76				shape_id,
77			) {
78				Ok(n) => n,
79				Err(_) => continue,
80			};
81
82			if let Some(filter) = filter_arg.as_deref()
83				&& name != filter
84			{
85				continue;
86			}
87
88			mutations.push(MutationEntry {
89				target: name,
90				diff: diff.clone(),
91			});
92		}
93
94		mutations.sort_by(|a, b| a.target.cmp(&b.target));
95		Ok(build_output_columns(&mutations)?)
96	}
97}
98
99fn extract_optional_string_param(params: &Params) -> Option<String> {
100	match params {
101		Params::Positional(args) if !args.is_empty() => match &args[0] {
102			Value::Utf8(s) => Some(s.clone()),
103			_ => None,
104		},
105		_ => None,
106	}
107}
108
/// One accumulated change to report: a diff together with the resolved name
/// of the shape it belongs to.
struct MutationEntry {
	/// Qualified shape name (`namespace::name`) produced by `resolve_shape_name`.
	target: String,
	/// The accumulated change recorded for that shape.
	diff: Diff,
}
113
114fn resolve_shape_name(catalog: &Catalog, txn: &mut Transaction<'_>, id: &ShapeId) -> Result<String, Error> {
115	match id {
116		ShapeId::Table(table_id) => {
117			let table = catalog
118				.find_table(txn, *table_id)?
119				.ok_or_else(|| internal_error!("table not found for id {:?}", table_id))?;
120			let ns = catalog
121				.find_namespace(txn, table.namespace)?
122				.ok_or_else(|| internal_error!("namespace not found"))?;
123			Ok(format!("{}::{}", ns.name(), table.name))
124		}
125		ShapeId::View(view_id) => {
126			let view = catalog
127				.find_view(txn, *view_id)?
128				.ok_or_else(|| internal_error!("view not found for id {:?}", view_id))?;
129			let ns = catalog
130				.find_namespace(txn, view.namespace())?
131				.ok_or_else(|| internal_error!("namespace not found"))?;
132			Ok(format!("{}::{}", ns.name(), view.name()))
133		}
134		ShapeId::RingBuffer(rb_id) => {
135			let rb = catalog
136				.find_ringbuffer(txn, *rb_id)?
137				.ok_or_else(|| internal_error!("ringbuffer not found for id {:?}", rb_id))?;
138			let ns = catalog
139				.find_namespace(txn, rb.namespace)?
140				.ok_or_else(|| internal_error!("namespace not found"))?;
141			Ok(format!("{}::{}", ns.name(), rb.name))
142		}
143		ShapeId::Series(series_id) => {
144			let series = catalog
145				.find_series(txn, *series_id)?
146				.ok_or_else(|| internal_error!("series not found for id {:?}", series_id))?;
147			let ns = catalog
148				.find_namespace(txn, series.namespace)?
149				.ok_or_else(|| internal_error!("namespace not found"))?;
150			Ok(format!("{}::{}", ns.name(), series.name))
151		}
152		ShapeId::Dictionary(dict_id) => {
153			let dict = catalog
154				.find_dictionary(txn, *dict_id)?
155				.ok_or_else(|| internal_error!("dictionary not found for id {:?}", dict_id))?;
156			let ns = catalog
157				.find_namespace(txn, dict.namespace)?
158				.ok_or_else(|| internal_error!("namespace not found"))?;
159			Ok(format!("{}::{}", ns.name(), dict.name))
160		}
161		_ => Err(internal_error!("unsupported primitive type {:?}", id)),
162	}
163}
164
165fn build_output_columns(entries: &[MutationEntry]) -> Result<Columns, Error> {
166	if entries.is_empty() {
167		return Ok(Columns::empty());
168	}
169
170	let mut op_data = ColumnBuffer::utf8_with_capacity(entries.len());
171	let mut target_data = ColumnBuffer::utf8_with_capacity(entries.len());
172
173	let mut field_names: Vec<String> = Vec::new();
174	for entry in entries {
175		match &entry.diff {
176			Diff::Insert {
177				post,
178			}
179			| Diff::Remove {
180				pre: post,
181			} => {
182				for col in post.iter() {
183					let name = col.name().text().to_string();
184					if !field_names.contains(&name) {
185						field_names.push(name);
186					}
187				}
188			}
189			Diff::Update {
190				pre,
191				post,
192			} => {
193				for col in pre.iter() {
194					let name = col.name().text().to_string();
195					if !field_names.contains(&name) {
196						field_names.push(name);
197					}
198				}
199				for col in post.iter() {
200					let name = col.name().text().to_string();
201					if !field_names.contains(&name) {
202						field_names.push(name);
203					}
204				}
205			}
206		}
207	}
208
209	let mut old_columns: Vec<Vec<Value>> = vec![Vec::with_capacity(entries.len()); field_names.len()];
210	let mut new_columns: Vec<Vec<Value>> = vec![Vec::with_capacity(entries.len()); field_names.len()];
211
212	for entry in entries {
213		let empty = Columns::empty();
214		let (op, old_cols, new_cols): (&str, &Columns, &Columns) = match &entry.diff {
215			Diff::Insert {
216				post,
217			} => ("insert", &empty, post.as_ref()),
218			Diff::Update {
219				pre,
220				post,
221			} => ("update", pre.as_ref(), post.as_ref()),
222			Diff::Remove {
223				pre,
224			} => ("delete", pre.as_ref(), &empty),
225		};
226
227		let row_count = match &entry.diff {
228			Diff::Insert {
229				post,
230			} => post.row_count(),
231			Diff::Update {
232				post,
233				..
234			} => post.row_count(),
235			Diff::Remove {
236				pre,
237			} => pre.row_count(),
238		};
239
240		for row_idx in 0..row_count {
241			op_data.push(op);
242			target_data.push(entry.target.as_str());
243
244			for (i, field_name) in field_names.iter().enumerate() {
245				let old_val = old_cols
246					.column(field_name)
247					.map(|col| col.data().get_value(row_idx))
248					.unwrap_or(Value::none());
249				old_columns[i].push(old_val);
250
251				let new_val = new_cols
252					.column(field_name)
253					.map(|col| col.data().get_value(row_idx))
254					.unwrap_or(Value::none());
255				new_columns[i].push(new_val);
256			}
257		}
258	}
259
260	let mut columns = vec![ColumnWithName::new("op", op_data), ColumnWithName::new("target", target_data)];
261
262	for (i, name) in field_names.iter().enumerate() {
263		let mut old_data = column_for_values(&old_columns[i]);
264		for val in &old_columns[i] {
265			old_data.push_value(val.clone());
266		}
267		columns.push(ColumnWithName::new(format!("old_{}", name), old_data));
268
269		let mut new_data = column_for_values(&new_columns[i]);
270		for val in &new_columns[i] {
271			new_data.push_value(val.clone());
272		}
273		columns.push(ColumnWithName::new(format!("new_{}", name), new_data));
274	}
275
276	Ok(Columns::new(columns))
277}
278
279fn column_for_values(values: &[Value]) -> ColumnBuffer {
280	let first_type = values.iter().find_map(|v| {
281		if matches!(v, Value::None { .. }) {
282			None
283		} else {
284			Some(v.get_type())
285		}
286	});
287	match first_type {
288		Some(ty) => ColumnBuffer::with_capacity(ty, values.len()),
289		None => ColumnBuffer::none_typed(Type::Boolean, 0),
290	}
291}