Skip to main content

reifydb_routine/procedure/testing/changed.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_catalog::catalog::Catalog;
5use reifydb_core::{
6	interface::{catalog::shape::ShapeId, change::Diff},
7	internal_error,
8	value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns},
9};
10use reifydb_transaction::transaction::Transaction;
11use reifydb_type::{
12	error::Error,
13	params::Params,
14	value::{Value, r#type::Type},
15};
16
17use crate::routine::{Routine, RoutineInfo, context::ProcedureContext, error::RoutineError};
18
19pub struct TestingChanged {
20	pub shape_type: &'static str,
21	info: RoutineInfo,
22}
23
24impl TestingChanged {
25	pub fn new(shape_type: &'static str) -> Self {
26		Self {
27			shape_type,
28			info: RoutineInfo::new(&format!("testing::{}::changed", shape_type)),
29		}
30	}
31}
32
33impl<'a, 'tx> Routine<ProcedureContext<'a, 'tx>> for TestingChanged {
34	fn info(&self) -> &RoutineInfo {
35		&self.info
36	}
37
38	fn return_type(&self, _input_types: &[Type]) -> Type {
39		Type::Any
40	}
41
42	fn execute(&self, ctx: &mut ProcedureContext<'a, 'tx>, _args: &Columns) -> Result<Columns, RoutineError> {
43		let t = match ctx.tx {
44			Transaction::Test(t) => t,
45			_ => {
46				return Err(internal_error!("testing::*::changed() requires a test transaction").into());
47			}
48		};
49
50		let filter_arg = extract_optional_string_param(ctx.params);
51
52		if self.shape_type == "views" {
53			let _ = t.capture_testing_pre_commit();
54		}
55
56		let entries: Vec<_> =
57			t.accumulator_entries_from().iter().map(|(id, diff)| (*id, diff.clone())).collect();
58
59		let mut mutations: Vec<MutationEntry> = Vec::new();
60
61		for (shape_id, diff) in &entries {
62			let type_matches = matches!(
63				(&shape_id, self.shape_type),
64				(ShapeId::Table(_), "tables")
65					| (ShapeId::View(_), "views") | (ShapeId::RingBuffer(_), "ringbuffers")
66					| (ShapeId::Series(_), "series") | (ShapeId::Dictionary(_), "dictionaries")
67			);
68			if !type_matches {
69				continue;
70			}
71
72			let catalog: &Catalog = ctx.catalog;
73			let name = match resolve_shape_name(
74				catalog,
75				&mut Transaction::Test(Box::new(t.reborrow())),
76				shape_id,
77			) {
78				Ok(n) => n,
79				Err(_) => continue,
80			};
81
82			if let Some(filter) = filter_arg.as_deref()
83				&& name != filter
84			{
85				continue;
86			}
87
88			mutations.push(MutationEntry {
89				target: name,
90				diff: diff.clone(),
91			});
92		}
93
94		mutations.sort_by(|a, b| a.target.cmp(&b.target));
95		Ok(build_output_columns(&mutations)?)
96	}
97}
98
99fn extract_optional_string_param(params: &Params) -> Option<String> {
100	match params {
101		Params::Positional(args) if !args.is_empty() => match &args[0] {
102			Value::Utf8(s) => Some(s.clone()),
103			_ => None,
104		},
105		_ => None,
106	}
107}
108
109struct MutationEntry {
110	target: String,
111	diff: Diff,
112}
113
114fn resolve_shape_name(catalog: &Catalog, txn: &mut Transaction<'_>, id: &ShapeId) -> Result<String, Error> {
115	match id {
116		ShapeId::Table(table_id) => {
117			let table = catalog
118				.find_table(txn, *table_id)?
119				.ok_or_else(|| internal_error!("table not found for id {:?}", table_id))?;
120			let ns = catalog
121				.find_namespace(txn, table.namespace)?
122				.ok_or_else(|| internal_error!("namespace not found"))?;
123			Ok(format!("{}::{}", ns.name(), table.name))
124		}
125		ShapeId::View(view_id) => {
126			let view = catalog
127				.find_view(txn, *view_id)?
128				.ok_or_else(|| internal_error!("view not found for id {:?}", view_id))?;
129			let ns = catalog
130				.find_namespace(txn, view.namespace())?
131				.ok_or_else(|| internal_error!("namespace not found"))?;
132			Ok(format!("{}::{}", ns.name(), view.name()))
133		}
134		ShapeId::RingBuffer(rb_id) => {
135			let rb = catalog
136				.find_ringbuffer(txn, *rb_id)?
137				.ok_or_else(|| internal_error!("ringbuffer not found for id {:?}", rb_id))?;
138			let ns = catalog
139				.find_namespace(txn, rb.namespace)?
140				.ok_or_else(|| internal_error!("namespace not found"))?;
141			Ok(format!("{}::{}", ns.name(), rb.name))
142		}
143		ShapeId::Series(series_id) => {
144			let series = catalog
145				.find_series(txn, *series_id)?
146				.ok_or_else(|| internal_error!("series not found for id {:?}", series_id))?;
147			let ns = catalog
148				.find_namespace(txn, series.namespace)?
149				.ok_or_else(|| internal_error!("namespace not found"))?;
150			Ok(format!("{}::{}", ns.name(), series.name))
151		}
152		ShapeId::Dictionary(dict_id) => {
153			let dict = catalog
154				.find_dictionary(txn, *dict_id)?
155				.ok_or_else(|| internal_error!("dictionary not found for id {:?}", dict_id))?;
156			let ns = catalog
157				.find_namespace(txn, dict.namespace)?
158				.ok_or_else(|| internal_error!("namespace not found"))?;
159			Ok(format!("{}::{}", ns.name(), dict.name))
160		}
161		_ => Err(internal_error!("unsupported primitive type {:?}", id)),
162	}
163}
164
165fn build_output_columns(entries: &[MutationEntry]) -> Result<Columns, Error> {
166	if entries.is_empty() {
167		return Ok(Columns::empty());
168	}
169
170	let mut op_data = ColumnBuffer::utf8_with_capacity(entries.len());
171	let mut target_data = ColumnBuffer::utf8_with_capacity(entries.len());
172
173	let mut field_names: Vec<String> = Vec::new();
174	for entry in entries {
175		match &entry.diff {
176			Diff::Insert {
177				post,
178				..
179			}
180			| Diff::Remove {
181				pre: post,
182				..
183			} => {
184				for col in post.iter() {
185					let name = col.name().text().to_string();
186					if !field_names.contains(&name) {
187						field_names.push(name);
188					}
189				}
190			}
191			Diff::Update {
192				pre,
193				post,
194				..
195			} => {
196				for col in pre.iter() {
197					let name = col.name().text().to_string();
198					if !field_names.contains(&name) {
199						field_names.push(name);
200					}
201				}
202				for col in post.iter() {
203					let name = col.name().text().to_string();
204					if !field_names.contains(&name) {
205						field_names.push(name);
206					}
207				}
208			}
209		}
210	}
211
212	let mut old_columns: Vec<Vec<Value>> = vec![Vec::with_capacity(entries.len()); field_names.len()];
213	let mut new_columns: Vec<Vec<Value>> = vec![Vec::with_capacity(entries.len()); field_names.len()];
214
215	for entry in entries {
216		let empty = Columns::empty();
217		let (op, old_cols, new_cols): (&str, &Columns, &Columns) = match &entry.diff {
218			Diff::Insert {
219				post,
220				..
221			} => ("insert", &empty, post.as_ref()),
222			Diff::Update {
223				pre,
224				post,
225				..
226			} => ("update", pre.as_ref(), post.as_ref()),
227			Diff::Remove {
228				pre,
229				..
230			} => ("delete", pre.as_ref(), &empty),
231		};
232
233		let row_count = match &entry.diff {
234			Diff::Insert {
235				post,
236				..
237			} => post.row_count(),
238			Diff::Update {
239				post,
240				..
241			} => post.row_count(),
242			Diff::Remove {
243				pre,
244				..
245			} => pre.row_count(),
246		};
247
248		for row_idx in 0..row_count {
249			op_data.push(op);
250			target_data.push(entry.target.as_str());
251
252			for (i, field_name) in field_names.iter().enumerate() {
253				let old_val = old_cols
254					.column(field_name)
255					.map(|col| col.data().get_value(row_idx))
256					.unwrap_or(Value::none());
257				old_columns[i].push(old_val);
258
259				let new_val = new_cols
260					.column(field_name)
261					.map(|col| col.data().get_value(row_idx))
262					.unwrap_or(Value::none());
263				new_columns[i].push(new_val);
264			}
265		}
266	}
267
268	let mut columns = vec![ColumnWithName::new("op", op_data), ColumnWithName::new("target", target_data)];
269
270	for (i, name) in field_names.iter().enumerate() {
271		let mut old_data = column_for_values(&old_columns[i]);
272		for val in &old_columns[i] {
273			old_data.push_value(val.clone());
274		}
275		columns.push(ColumnWithName::new(format!("old_{}", name), old_data));
276
277		let mut new_data = column_for_values(&new_columns[i]);
278		for val in &new_columns[i] {
279			new_data.push_value(val.clone());
280		}
281		columns.push(ColumnWithName::new(format!("new_{}", name), new_data));
282	}
283
284	Ok(Columns::new(columns))
285}
286
287fn column_for_values(values: &[Value]) -> ColumnBuffer {
288	let first_type = values.iter().find_map(|v| {
289		if matches!(v, Value::None { .. }) {
290			None
291		} else {
292			Some(v.get_type())
293		}
294	});
295	match first_type {
296		Some(ty) => ColumnBuffer::with_capacity(ty, values.len()),
297		None => ColumnBuffer::none_typed(Type::Boolean, 0),
298	}
299}