Skip to main content

reifydb_sdk/operator/
diff.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::HashMap, thread};
5
6use postcard::to_allocvec;
7use reifydb_abi::data::column::ColumnTypeCode;
8use reifydb_type::value::{Value, decimal::Decimal, row_number::RowNumber, r#type::Type};
9
10use crate::{
11	error::FFIError,
12	operator::{
13		builder::{ColumnBuilder, ColumnsBuilder, CommittedColumn},
14		context::OperatorContext,
15	},
16};
17
18pub struct DiffStart<'a> {
19	inner: ColumnsBuilder<'a>,
20}
21
22impl<'a> DiffStart<'a> {
23	pub(crate) fn new(ctx: &'a mut OperatorContext) -> Self {
24		Self {
25			inner: ColumnsBuilder::new(ctx),
26		}
27	}
28
29	pub fn insert<S, I>(self, row_number: RowNumber, fields: I) -> InsertDiff<'a>
30	where
31		S: Into<String>,
32		I: IntoIterator<Item = (S, Value)>,
33	{
34		let mut diff = InsertDiff {
35			inner: self.inner,
36			schema: Vec::new(),
37			rows: Vec::new(),
38			disarmed: false,
39		};
40		let fields = collect_fields(fields);
41		validate_row_or_panic(&mut diff.schema, &fields, "InsertDiff::insert");
42		diff.rows.push(StagedRow {
43			row_number,
44			fields,
45		});
46		diff
47	}
48
49	pub fn update<S, I, J>(self, row_number: RowNumber, pre: I, post: J) -> UpdateDiff<'a>
50	where
51		S: Into<String>,
52		I: IntoIterator<Item = (S, Value)>,
53		J: IntoIterator<Item = (S, Value)>,
54	{
55		let mut diff = UpdateDiff {
56			inner: self.inner,
57			schema: Vec::new(),
58			rows: Vec::new(),
59			disarmed: false,
60		};
61		let pre = collect_fields(pre);
62		let post = collect_fields(post);
63		validate_row_or_panic(&mut diff.schema, &pre, "UpdateDiff::update pre");
64		validate_row_or_panic(&mut diff.schema, &post, "UpdateDiff::update post");
65		diff.rows.push(UpdateRow {
66			row_number,
67			pre,
68			post,
69		});
70		diff
71	}
72
73	pub fn remove<S, I>(self, row_number: RowNumber, fields: I) -> RemoveDiff<'a>
74	where
75		S: Into<String>,
76		I: IntoIterator<Item = (S, Value)>,
77	{
78		let mut diff = RemoveDiff {
79			inner: self.inner,
80			schema: Vec::new(),
81			rows: Vec::new(),
82			disarmed: false,
83		};
84		let fields = collect_fields(fields);
85		validate_row_or_panic(&mut diff.schema, &fields, "RemoveDiff::remove");
86		diff.rows.push(StagedRow {
87			row_number,
88			fields,
89		});
90		diff
91	}
92}
93
94struct StagedRow {
95	row_number: RowNumber,
96	fields: Vec<(String, Value)>,
97}
98
99struct UpdateRow {
100	row_number: RowNumber,
101	pre: Vec<(String, Value)>,
102	post: Vec<(String, Value)>,
103}
104
105pub struct InsertDiff<'a> {
106	inner: ColumnsBuilder<'a>,
107	schema: Vec<(String, ColumnTypeCode)>,
108	rows: Vec<StagedRow>,
109	disarmed: bool,
110}
111
112impl<'a> InsertDiff<'a> {
113	pub fn insert<S, I>(mut self, row_number: RowNumber, fields: I) -> Self
114	where
115		S: Into<String>,
116		I: IntoIterator<Item = (S, Value)>,
117	{
118		let fields = collect_fields(fields);
119		validate_row_or_panic(&mut self.schema, &fields, "InsertDiff::insert");
120		self.rows.push(StagedRow {
121			row_number,
122			fields,
123		});
124		self
125	}
126
127	pub fn try_finish(mut self) -> Result<(), FFIError> {
128		self.disarmed = true;
129		emit_insert(&mut self.inner, &self.schema, &self.rows)
130	}
131}
132
133impl<'a> Drop for InsertDiff<'a> {
134	fn drop(&mut self) {
135		if self.disarmed {
136			return;
137		}
138		if let Err(e) = emit_insert(&mut self.inner, &self.schema, &self.rows)
139			&& !thread::panicking()
140		{
141			panic!("InsertDiff drop failed: {:?}", e);
142		}
143	}
144}
145
146pub struct UpdateDiff<'a> {
147	inner: ColumnsBuilder<'a>,
148	schema: Vec<(String, ColumnTypeCode)>,
149	rows: Vec<UpdateRow>,
150	disarmed: bool,
151}
152
153impl<'a> UpdateDiff<'a> {
154	pub fn update<S, I, J>(mut self, row_number: RowNumber, pre: I, post: J) -> Self
155	where
156		S: Into<String>,
157		I: IntoIterator<Item = (S, Value)>,
158		J: IntoIterator<Item = (S, Value)>,
159	{
160		let pre = collect_fields(pre);
161		let post = collect_fields(post);
162		validate_row_or_panic(&mut self.schema, &pre, "UpdateDiff::update pre");
163		validate_row_or_panic(&mut self.schema, &post, "UpdateDiff::update post");
164		self.rows.push(UpdateRow {
165			row_number,
166			pre,
167			post,
168		});
169		self
170	}
171
172	pub fn try_finish(mut self) -> Result<(), FFIError> {
173		self.disarmed = true;
174		emit_update(&mut self.inner, &self.schema, &self.rows)
175	}
176}
177
178impl<'a> Drop for UpdateDiff<'a> {
179	fn drop(&mut self) {
180		if self.disarmed {
181			return;
182		}
183		if let Err(e) = emit_update(&mut self.inner, &self.schema, &self.rows)
184			&& !thread::panicking()
185		{
186			panic!("UpdateDiff drop failed: {:?}", e);
187		}
188	}
189}
190
191pub struct RemoveDiff<'a> {
192	inner: ColumnsBuilder<'a>,
193	schema: Vec<(String, ColumnTypeCode)>,
194	rows: Vec<StagedRow>,
195	disarmed: bool,
196}
197
198impl<'a> RemoveDiff<'a> {
199	pub fn remove<S, I>(mut self, row_number: RowNumber, fields: I) -> Self
200	where
201		S: Into<String>,
202		I: IntoIterator<Item = (S, Value)>,
203	{
204		let fields = collect_fields(fields);
205		validate_row_or_panic(&mut self.schema, &fields, "RemoveDiff::remove");
206		self.rows.push(StagedRow {
207			row_number,
208			fields,
209		});
210		self
211	}
212
213	pub fn try_finish(mut self) -> Result<(), FFIError> {
214		self.disarmed = true;
215		emit_remove(&mut self.inner, &self.schema, &self.rows)
216	}
217}
218
219impl<'a> Drop for RemoveDiff<'a> {
220	fn drop(&mut self) {
221		if self.disarmed {
222			return;
223		}
224		if let Err(e) = emit_remove(&mut self.inner, &self.schema, &self.rows)
225			&& !thread::panicking()
226		{
227			panic!("RemoveDiff drop failed: {:?}", e);
228		}
229	}
230}
231
232fn collect_fields<S, I>(fields: I) -> Vec<(String, Value)>
233where
234	S: Into<String>,
235	I: IntoIterator<Item = (S, Value)>,
236{
237	fields.into_iter().map(|(k, v)| (k.into(), v)).collect()
238}
239
240fn validate_row_or_panic(
241	schema: &mut Vec<(String, ColumnTypeCode)>,
242	fields: &[(String, Value)],
243	context: &'static str,
244) {
245	if schema.is_empty() {
246		schema.reserve(fields.len());
247		for (name, value) in fields {
248			let type_code = match value_to_type_code(value) {
249				Some(c) => c,
250				None => panic!("{}: column {:?} has unsupported value type {:?}", context, name, value),
251			};
252			if schema.iter().any(|(n, _)| n == name) {
253				panic!("{}: duplicate column name {:?}", context, name);
254			}
255			schema.push((name.clone(), type_code));
256		}
257		return;
258	}
259
260	if fields.len() != schema.len() {
261		panic!(
262			"{}: row has {} fields, schema expects {} (schema: {:?})",
263			context,
264			fields.len(),
265			schema.len(),
266			schema.iter().map(|(n, _)| n.as_str()).collect::<Vec<_>>()
267		);
268	}
269	let names: HashMap<&str, &Value> = fields.iter().map(|(n, v)| (n.as_str(), v)).collect();
270	if names.len() != fields.len() {
271		panic!("{}: duplicate column name within row", context);
272	}
273	for (name, expected) in schema.iter() {
274		match names.get(name.as_str()) {
275			None => panic!("{}: row missing column {:?}", context, name),
276			Some(value) => {
277				let observed = match value_to_type_code(value) {
278					Some(c) => c,
279					None => panic!(
280						"{}: column {:?} has unsupported value type {:?}",
281						context, name, value
282					),
283				};
284				if observed != *expected && !matches!(value, Value::None { .. }) {
285					panic!(
286						"{}: column {:?} type mismatch (expected {:?}, got {:?})",
287						context, name, expected, observed
288					);
289				}
290			}
291		}
292	}
293}
294
295fn emit_insert(
296	inner: &mut ColumnsBuilder<'_>,
297	schema: &[(String, ColumnTypeCode)],
298	rows: &[StagedRow],
299) -> Result<(), FFIError> {
300	if rows.is_empty() {
301		return Ok(());
302	}
303	let row_count = rows.len();
304	let row_numbers: Vec<RowNumber> = rows.iter().map(|r| r.row_number).collect();
305	let names: Vec<String> = schema.iter().map(|(n, _)| n.clone()).collect();
306	let names_ref: Vec<&str> = names.iter().map(|s| s.as_str()).collect();
307
308	let columns = transpose(schema, &rows.iter().map(|r| &r.fields).collect::<Vec<_>>())?;
309	let mut committed: Vec<CommittedColumn> = Vec::with_capacity(schema.len());
310	for (i, (_, type_code)) in schema.iter().enumerate() {
311		let col = inner.acquire(*type_code, row_count.max(1))?;
312		committed.push(write_column(col, *type_code, &columns[i])?);
313	}
314	inner.emit_insert(&committed, &names_ref, &row_numbers)
315}
316
317fn emit_update(
318	inner: &mut ColumnsBuilder<'_>,
319	schema: &[(String, ColumnTypeCode)],
320	rows: &[UpdateRow],
321) -> Result<(), FFIError> {
322	if rows.is_empty() {
323		return Ok(());
324	}
325	let row_count = rows.len();
326	let row_numbers: Vec<RowNumber> = rows.iter().map(|r| r.row_number).collect();
327	let names: Vec<String> = schema.iter().map(|(n, _)| n.clone()).collect();
328	let names_ref: Vec<&str> = names.iter().map(|s| s.as_str()).collect();
329
330	let pre_cols = transpose(schema, &rows.iter().map(|r| &r.pre).collect::<Vec<_>>())?;
331	let post_cols = transpose(schema, &rows.iter().map(|r| &r.post).collect::<Vec<_>>())?;
332	let mut pre_committed: Vec<CommittedColumn> = Vec::with_capacity(schema.len());
333	let mut post_committed: Vec<CommittedColumn> = Vec::with_capacity(schema.len());
334	for (i, (_, type_code)) in schema.iter().enumerate() {
335		let pre_col = inner.acquire(*type_code, row_count.max(1))?;
336		pre_committed.push(write_column(pre_col, *type_code, &pre_cols[i])?);
337		let post_col = inner.acquire(*type_code, row_count.max(1))?;
338		post_committed.push(write_column(post_col, *type_code, &post_cols[i])?);
339	}
340	inner.emit_update(
341		&pre_committed,
342		&names_ref,
343		row_count,
344		&row_numbers,
345		&post_committed,
346		&names_ref,
347		row_count,
348		&row_numbers,
349	)
350}
351
352fn emit_remove(
353	inner: &mut ColumnsBuilder<'_>,
354	schema: &[(String, ColumnTypeCode)],
355	rows: &[StagedRow],
356) -> Result<(), FFIError> {
357	if rows.is_empty() {
358		return Ok(());
359	}
360	let row_count = rows.len();
361	let row_numbers: Vec<RowNumber> = rows.iter().map(|r| r.row_number).collect();
362	let names: Vec<String> = schema.iter().map(|(n, _)| n.clone()).collect();
363	let names_ref: Vec<&str> = names.iter().map(|s| s.as_str()).collect();
364
365	let columns = transpose(schema, &rows.iter().map(|r| &r.fields).collect::<Vec<_>>())?;
366	let mut committed: Vec<CommittedColumn> = Vec::with_capacity(schema.len());
367	for (i, (_, type_code)) in schema.iter().enumerate() {
368		let col = inner.acquire(*type_code, row_count.max(1))?;
369		committed.push(write_column(col, *type_code, &columns[i])?);
370	}
371	inner.emit_remove(&committed, &names_ref, &row_numbers)
372}
373
374fn transpose(schema: &[(String, ColumnTypeCode)], rows: &[&Vec<(String, Value)>]) -> Result<Vec<Vec<Value>>, FFIError> {
375	let mut columns: Vec<Vec<Value>> = (0..schema.len()).map(|_| Vec::with_capacity(rows.len())).collect();
376	for row in rows {
377		let lookup: HashMap<&str, &Value> = row.iter().map(|(n, v)| (n.as_str(), v)).collect();
378		for (i, (name, _)) in schema.iter().enumerate() {
379			match lookup.get(name.as_str()) {
380				Some(value) => columns[i].push((*value).clone()),
381				None => {
382					return Err(FFIError::InvalidInput(format!(
383						"transpose: row missing column {:?}",
384						name
385					)));
386				}
387			}
388		}
389	}
390	Ok(columns)
391}
392
393fn write_column(
394	col: ColumnBuilder<'_>,
395	type_code: ColumnTypeCode,
396	values: &[Value],
397) -> Result<CommittedColumn, FFIError> {
398	let defined: Vec<bool> = values.iter().map(|v| !matches!(v, Value::None { .. })).collect();
399	let has_nulls = defined.iter().any(|d| !*d);
400	if has_nulls {
401		col.set_defined(&defined);
402	}
403	match type_code {
404		ColumnTypeCode::Bool => {
405			let buf: Vec<bool> = values.iter().map(value_to_bool).collect::<Result<_, _>>()?;
406			col.write_bool(&buf)
407		}
408		ColumnTypeCode::Uint1 => {
409			let buf: Vec<u8> = values.iter().map(value_to_u8).collect::<Result<_, _>>()?;
410			col.write_u8(&buf)
411		}
412		ColumnTypeCode::Uint2 => {
413			let buf: Vec<u16> = values.iter().map(value_to_u16).collect::<Result<_, _>>()?;
414			col.write_u16(&buf)
415		}
416		ColumnTypeCode::Uint4 => {
417			let buf: Vec<u32> = values.iter().map(value_to_u32).collect::<Result<_, _>>()?;
418			col.write_u32(&buf)
419		}
420		ColumnTypeCode::Uint8 => {
421			let buf: Vec<u64> = values.iter().map(value_to_u64).collect::<Result<_, _>>()?;
422			col.write_u64(&buf)
423		}
424		ColumnTypeCode::Uint16 => {
425			let buf: Vec<u128> = values.iter().map(value_to_u128).collect::<Result<_, _>>()?;
426			col.write_u128(&buf)
427		}
428		ColumnTypeCode::Int1 => {
429			let buf: Vec<i8> = values.iter().map(value_to_i8).collect::<Result<_, _>>()?;
430			col.write_i8(&buf)
431		}
432		ColumnTypeCode::Int2 => {
433			let buf: Vec<i16> = values.iter().map(value_to_i16).collect::<Result<_, _>>()?;
434			col.write_i16(&buf)
435		}
436		ColumnTypeCode::Int4 => {
437			let buf: Vec<i32> = values.iter().map(value_to_i32).collect::<Result<_, _>>()?;
438			col.write_i32(&buf)
439		}
440		ColumnTypeCode::Int8 => {
441			let buf: Vec<i64> = values.iter().map(value_to_i64).collect::<Result<_, _>>()?;
442			col.write_i64(&buf)
443		}
444		ColumnTypeCode::Int16 => {
445			let buf: Vec<i128> = values.iter().map(value_to_i128).collect::<Result<_, _>>()?;
446			col.write_i128(&buf)
447		}
448		ColumnTypeCode::Float4 => {
449			let buf: Vec<f32> = values.iter().map(value_to_f32).collect::<Result<_, _>>()?;
450			col.write_f32(&buf)
451		}
452		ColumnTypeCode::Float8 => {
453			let buf: Vec<f64> = values.iter().map(value_to_f64).collect::<Result<_, _>>()?;
454			col.write_f64(&buf)
455		}
456		ColumnTypeCode::Utf8 => {
457			let buf: Vec<String> = values.iter().map(value_to_utf8).collect::<Result<_, _>>()?;
458			col.write_utf8(&buf)
459		}
460		ColumnTypeCode::Blob => {
461			let buf: Vec<Vec<u8>> = values.iter().map(value_to_blob).collect::<Result<_, _>>()?;
462			col.write_blob(&buf)
463		}
464		ColumnTypeCode::Decimal => write_decimal_column(col, values),
465		other => Err(FFIError::NotImplemented(format!("emit: unsupported column type {:?}", other))),
466	}
467}
468
469fn write_decimal_column(col: ColumnBuilder<'_>, values: &[Value]) -> Result<CommittedColumn, FFIError> {
470	let mut serialized: Vec<Vec<u8>> = Vec::with_capacity(values.len());
471	for v in values {
472		let dec: Decimal = match v {
473			Value::Decimal(d) => d.clone(),
474			Value::Float4(f) => Decimal::from(f64::from(f32::from(*f))),
475			Value::Float8(f) => Decimal::from(f64::from(*f)),
476			Value::None {
477				..
478			} => Decimal::from_i64(0),
479			_ => {
480				return Err(FFIError::InvalidInput(format!(
481					"emit decimal: expected Decimal, got {:?}",
482					v
483				)));
484			}
485		};
486		let bytes =
487			to_allocvec(&dec).map_err(|e| FFIError::Serialization(format!("decimal serialize: {}", e)))?;
488		serialized.push(bytes);
489	}
490	col.write_blob(&serialized)
491}
492
493fn value_to_type_code(value: &Value) -> Option<ColumnTypeCode> {
494	let code = match value {
495		Value::Boolean(_) => ColumnTypeCode::Bool,
496		Value::Float4(_) => ColumnTypeCode::Float4,
497		Value::Float8(_) => ColumnTypeCode::Float8,
498		Value::Int1(_) => ColumnTypeCode::Int1,
499		Value::Int2(_) => ColumnTypeCode::Int2,
500		Value::Int4(_) => ColumnTypeCode::Int4,
501		Value::Int8(_) => ColumnTypeCode::Int8,
502		Value::Int16(_) => ColumnTypeCode::Int16,
503		Value::Uint1(_) => ColumnTypeCode::Uint1,
504		Value::Uint2(_) => ColumnTypeCode::Uint2,
505		Value::Uint4(_) => ColumnTypeCode::Uint4,
506		Value::Uint8(_) => ColumnTypeCode::Uint8,
507		Value::Uint16(_) => ColumnTypeCode::Uint16,
508		Value::Utf8(_) => ColumnTypeCode::Utf8,
509		Value::Decimal(_) => ColumnTypeCode::Decimal,
510		Value::Blob(_) => ColumnTypeCode::Blob,
511		Value::None {
512			inner,
513		} => return type_to_column_code(inner.clone()),
514		_ => return None,
515	};
516	Some(code)
517}
518
519fn type_to_column_code(ty: Type) -> Option<ColumnTypeCode> {
520	let code = match ty {
521		Type::Boolean => ColumnTypeCode::Bool,
522		Type::Float4 => ColumnTypeCode::Float4,
523		Type::Float8 => ColumnTypeCode::Float8,
524		Type::Int1 => ColumnTypeCode::Int1,
525		Type::Int2 => ColumnTypeCode::Int2,
526		Type::Int4 => ColumnTypeCode::Int4,
527		Type::Int8 => ColumnTypeCode::Int8,
528		Type::Int16 => ColumnTypeCode::Int16,
529		Type::Uint1 => ColumnTypeCode::Uint1,
530		Type::Uint2 => ColumnTypeCode::Uint2,
531		Type::Uint4 => ColumnTypeCode::Uint4,
532		Type::Uint8 => ColumnTypeCode::Uint8,
533		Type::Uint16 => ColumnTypeCode::Uint16,
534		Type::Utf8 => ColumnTypeCode::Utf8,
535		Type::Decimal => ColumnTypeCode::Decimal,
536		Type::Blob => ColumnTypeCode::Blob,
537		_ => return Option::None,
538	};
539	Some(code)
540}
541
542fn type_mismatch_err(name: &str, value: &Value) -> FFIError {
543	FFIError::InvalidInput(format!("emit: column {} type mismatch (got {:?})", name, value))
544}
545
546fn value_to_bool(v: &Value) -> Result<bool, FFIError> {
547	match v {
548		Value::Boolean(b) => Ok(*b),
549		Value::None {
550			..
551		} => Ok(false),
552		_ => Err(type_mismatch_err("bool", v)),
553	}
554}
555
556macro_rules! value_to_int {
557	($name:ident, $variant:ident, $ty:ty) => {
558		fn $name(v: &Value) -> Result<$ty, FFIError> {
559			match v {
560				Value::$variant(x) => Ok(*x),
561				Value::None {
562					..
563				} => Ok(<$ty as Default>::default()),
564				_ => Err(type_mismatch_err(stringify!($variant), v)),
565			}
566		}
567	};
568}
569
570value_to_int!(value_to_u8, Uint1, u8);
571value_to_int!(value_to_u16, Uint2, u16);
572value_to_int!(value_to_u32, Uint4, u32);
573value_to_int!(value_to_u64, Uint8, u64);
574value_to_int!(value_to_u128, Uint16, u128);
575value_to_int!(value_to_i8, Int1, i8);
576value_to_int!(value_to_i16, Int2, i16);
577value_to_int!(value_to_i32, Int4, i32);
578value_to_int!(value_to_i64, Int8, i64);
579value_to_int!(value_to_i128, Int16, i128);
580
581fn value_to_f32(v: &Value) -> Result<f32, FFIError> {
582	match v {
583		Value::Float4(f) => Ok(f32::from(*f)),
584		Value::None {
585			..
586		} => Ok(0.0),
587		_ => Err(type_mismatch_err("Float4", v)),
588	}
589}
590
591fn value_to_f64(v: &Value) -> Result<f64, FFIError> {
592	match v {
593		Value::Float8(f) => Ok(f64::from(*f)),
594		Value::None {
595			..
596		} => Ok(0.0),
597		_ => Err(type_mismatch_err("Float8", v)),
598	}
599}
600
601fn value_to_utf8(v: &Value) -> Result<String, FFIError> {
602	match v {
603		Value::Utf8(s) => Ok(s.clone()),
604		Value::None {
605			..
606		} => Ok(String::new()),
607		_ => Err(type_mismatch_err("Utf8", v)),
608	}
609}
610
611fn value_to_blob(v: &Value) -> Result<Vec<u8>, FFIError> {
612	match v {
613		Value::Blob(b) => Ok(b.as_ref().to_vec()),
614		Value::None {
615			..
616		} => Ok(Vec::new()),
617		_ => Err(type_mismatch_err("Blob", v)),
618	}
619}