Skip to main content

reifydb_sdk/testing/
registry.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{cell::Cell, collections::HashMap, fmt, mem, ptr, slice, str, sync::Mutex};
5
6use postcard::from_bytes as postcard_decode;
7use reifydb_abi::{
8	callbacks::builder::{ColumnBufferHandle, EmitDiffKind},
9	context::context::ContextFFI,
10	data::column::ColumnTypeCode,
11};
12use reifydb_core::{
13	interface::change::{Diff, Diffs},
14	value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns},
15};
16use reifydb_type::{
17	fragment::Fragment,
18	util::{bitvec::BitVec, cowvec::CowVec},
19	value::{
20		Value,
21		constraint::{bytes::MaxBytes, precision::Precision, scale::Scale},
22		container::{
23			any::AnyContainer, blob::BlobContainer, bool::BoolContainer, dictionary::DictionaryContainer,
24			identity_id::IdentityIdContainer, number::NumberContainer, temporal::TemporalContainer,
25			utf8::Utf8Container, uuid::UuidContainer,
26		},
27		date::Date,
28		datetime::DateTime,
29		decimal::Decimal,
30		dictionary::DictionaryEntryId,
31		duration::Duration,
32		identity::IdentityId,
33		int::Int,
34		is::IsNumber,
35		row_number::RowNumber,
36		time::Time,
37		uint::Uint,
38		uuid::{Uuid4, Uuid7},
39	},
40};
41use serde::de::DeserializeOwned;
42
43pub struct TestBuilderRegistry {
44	inner: Mutex<RegistryInner>,
45}
46
47struct RegistryInner {
48	slots: HashMap<u64, Slot>,
49	accumulator: Vec<EmittedDiff>,
50	next_id: u64,
51}
52
53enum Slot {
54	Active(Active),
55	Committed(Committed),
56}
57
58pub struct Active {
59	pub type_code: ColumnTypeCode,
60	pub data: Vec<u8>,
61	pub offsets: Option<Vec<u64>>,
62	pub bitvec: Option<Vec<u8>>,
63	pub generation: u64,
64}
65
66pub struct Committed {
67	pub buffer: ColumnBuffer,
68	pub row_count: usize,
69}
70
71pub struct EmittedDiff {
72	pub kind: EmitDiffKind,
73	pub pre: Option<Columns>,
74	pub post: Option<Columns>,
75}
76
77impl Default for TestBuilderRegistry {
78	fn default() -> Self {
79		Self::new()
80	}
81}
82
83impl TestBuilderRegistry {
84	pub fn new() -> Self {
85		Self {
86			inner: Mutex::new(RegistryInner {
87				slots: HashMap::new(),
88				accumulator: Vec::new(),
89				next_id: 1,
90			}),
91		}
92	}
93
94	pub fn drain_diffs(&self) -> Vec<EmittedDiff> {
95		let mut inner = self.inner.lock().unwrap();
96		inner.slots.clear();
97		mem::take(&mut inner.accumulator)
98	}
99}
100
/// Decoded form of the opaque buffer handle passed across the FFI boundary.
#[derive(Clone, Copy)]
struct Handle {
	/// Key of the slot in the registry map.
	id: u64,
	/// Generation the slot is expected to carry.
	generation: u64,
}
106
107impl Handle {
108	fn encode(self) -> *mut ColumnBufferHandle {
109		assert!(self.id != 0 && self.id < (1 << 48));
110		assert!(self.generation < (1 << 16));
111		(self.id | (self.generation << 48)) as *mut ColumnBufferHandle
112	}
113
114	fn decode(ptr: *mut ColumnBufferHandle) -> Self {
115		let packed = ptr as u64;
116		Self {
117			id: packed & ((1 << 48) - 1),
118			generation: packed >> 48,
119		}
120	}
121}
122
123thread_local! {
124	static REGISTRY: Cell<Option<&'static TestBuilderRegistry>> = const { Cell::new(None) };
125}
126
127pub fn with_registry<R>(registry: &TestBuilderRegistry, f: impl FnOnce() -> R) -> R {
128	// SAFETY: we only hold the pointer for the duration of `f`. The
129
130	let extended: &'static TestBuilderRegistry = unsafe { mem::transmute(registry) };
131	let prev = REGISTRY.with(|cell| cell.replace(Some(extended)));
132	let r = f();
133	REGISTRY.with(|cell| cell.set(prev));
134	r
135}
136
137fn current() -> Option<&'static TestBuilderRegistry> {
138	REGISTRY.with(|cell| cell.get())
139}
140
141fn elem_size_for(type_code: ColumnTypeCode) -> usize {
142	match type_code {
143		ColumnTypeCode::Bool => 1,
144		ColumnTypeCode::Float4 | ColumnTypeCode::Int4 | ColumnTypeCode::Uint4 | ColumnTypeCode::Date => 4,
145		ColumnTypeCode::Int1 | ColumnTypeCode::Uint1 => 1,
146		ColumnTypeCode::Int2 | ColumnTypeCode::Uint2 => 2,
147		ColumnTypeCode::Float8
148		| ColumnTypeCode::Int8
149		| ColumnTypeCode::Uint8
150		| ColumnTypeCode::DateTime
151		| ColumnTypeCode::Time => 8,
152		ColumnTypeCode::Int16 | ColumnTypeCode::Uint16 => 16,
153		ColumnTypeCode::Duration
154		| ColumnTypeCode::IdentityId
155		| ColumnTypeCode::Uuid4
156		| ColumnTypeCode::Uuid7
157		| ColumnTypeCode::DictionaryId => 16,
158		ColumnTypeCode::Utf8 | ColumnTypeCode::Blob => 1,
159		ColumnTypeCode::Int | ColumnTypeCode::Uint | ColumnTypeCode::Decimal | ColumnTypeCode::Any => 1,
160		ColumnTypeCode::Undefined => 1,
161	}
162}
163
164fn is_var_len(type_code: ColumnTypeCode) -> bool {
165	matches!(
166		type_code,
167		ColumnTypeCode::Utf8
168			| ColumnTypeCode::Blob
169			| ColumnTypeCode::Int | ColumnTypeCode::Uint
170			| ColumnTypeCode::Decimal
171			| ColumnTypeCode::Any | ColumnTypeCode::DictionaryId
172	)
173}
174
175pub(crate) unsafe extern "C" fn test_acquire(
176	_ctx: *mut ContextFFI,
177	type_code: ColumnTypeCode,
178	capacity: usize,
179) -> *mut ColumnBufferHandle {
180	let Some(registry) = current() else {
181		return ptr::null_mut();
182	};
183	let mut inner = registry.inner.lock().unwrap();
184	let id = inner.next_id;
185	inner.next_id = inner.next_id.checked_add(1).unwrap_or(1);
186
187	let elem = elem_size_for(type_code);
188	let active = Active {
189		type_code,
190		data: Vec::with_capacity(capacity.saturating_mul(elem)),
191		offsets: if is_var_len(type_code) {
192			let mut o = Vec::with_capacity(capacity + 1);
193			o.push(0);
194			Some(o)
195		} else {
196			None
197		},
198		bitvec: None,
199		generation: 1,
200	};
201	inner.slots.insert(id, Slot::Active(active));
202	Handle {
203		id,
204		generation: 1,
205	}
206	.encode()
207}
208
209pub(crate) unsafe extern "C" fn test_data_ptr(handle: *mut ColumnBufferHandle) -> *mut u8 {
210	let Some(registry) = current() else {
211		return ptr::null_mut();
212	};
213	let h = Handle::decode(handle);
214	let mut inner = registry.inner.lock().unwrap();
215	match inner.slots.get_mut(&h.id) {
216		Some(Slot::Active(a)) if a.generation == h.generation => a.data.as_mut_ptr(),
217		_ => ptr::null_mut(),
218	}
219}
220
221pub(crate) unsafe extern "C" fn test_offsets_ptr(handle: *mut ColumnBufferHandle) -> *mut u64 {
222	let Some(registry) = current() else {
223		return ptr::null_mut();
224	};
225	let h = Handle::decode(handle);
226	let mut inner = registry.inner.lock().unwrap();
227	match inner.slots.get_mut(&h.id) {
228		Some(Slot::Active(a)) if a.generation == h.generation => match &mut a.offsets {
229			Some(o) => o.as_mut_ptr(),
230			None => ptr::null_mut(),
231		},
232		_ => ptr::null_mut(),
233	}
234}
235
236pub(crate) unsafe extern "C" fn test_bitvec_ptr(handle: *mut ColumnBufferHandle) -> *mut u8 {
237	let Some(registry) = current() else {
238		return ptr::null_mut();
239	};
240	let h = Handle::decode(handle);
241	let mut inner = registry.inner.lock().unwrap();
242	match inner.slots.get_mut(&h.id) {
243		Some(Slot::Active(a)) if a.generation == h.generation => {
244			if a.bitvec.is_none() {
245				let cap = a.data.capacity() / elem_size_for(a.type_code).max(1);
246				a.bitvec = Some(vec![0u8; cap.div_ceil(8)]);
247			}
248			a.bitvec.as_mut().unwrap().as_mut_ptr()
249		}
250		_ => ptr::null_mut(),
251	}
252}
253
254pub(crate) unsafe extern "C" fn test_grow(handle: *mut ColumnBufferHandle, additional: usize) -> i32 {
255	let Some(registry) = current() else {
256		return -1;
257	};
258	let h = Handle::decode(handle);
259	let mut inner = registry.inner.lock().unwrap();
260	match inner.slots.get_mut(&h.id) {
261		Some(Slot::Active(a)) if a.generation == h.generation => {
262			let elem = elem_size_for(a.type_code);
263			let extra_bytes = additional.saturating_mul(elem);
264			let old_cap = a.data.capacity();
265
266			unsafe { a.data.set_len(old_cap) };
267			a.data.reserve(extra_bytes);
268			unsafe { a.data.set_len(0) };
269			0
270		}
271		_ => -1,
272	}
273}
274
275pub(crate) unsafe extern "C" fn test_commit(handle: *mut ColumnBufferHandle, written_count: usize) -> i32 {
276	let Some(registry) = current() else {
277		return -1;
278	};
279	let h = Handle::decode(handle);
280	let mut inner = registry.inner.lock().unwrap();
281	let slot = match inner.slots.remove(&h.id) {
282		Some(s) => s,
283		None => return -1,
284	};
285	let mut active = match slot {
286		Slot::Active(a) if a.generation == h.generation => a,
287		other => {
288			inner.slots.insert(h.id, other);
289			return -1;
290		}
291	};
292
293	let elem = elem_size_for(active.type_code);
294
295	if let Some(offsets) = active.offsets.as_mut() {
296		let offsets_len = written_count + 1;
297		if offsets_len > offsets.capacity() {
298			return -1;
299		}
300		unsafe {
301			offsets.set_len(offsets_len);
302		}
303	}
304	let data_byte_len = if is_var_len(active.type_code) {
305		match active.offsets.as_ref() {
306			Some(o) if !o.is_empty() => *o.last().unwrap() as usize,
307			_ => 0,
308		}
309	} else {
310		written_count.saturating_mul(elem)
311	};
312	if data_byte_len > active.data.capacity() {
313		return -1;
314	}
315	unsafe {
316		active.data.set_len(data_byte_len);
317	}
318	if let Some(bitvec) = active.bitvec.as_mut() {
319		let needed = written_count.div_ceil(8);
320		if needed > bitvec.capacity() {
321			return -1;
322		}
323		unsafe {
324			bitvec.set_len(needed);
325		}
326	}
327
328	let buffer = match finalize_buffer(active.type_code, active.data, active.offsets, active.bitvec, written_count)
329	{
330		Some(b) => b,
331		None => return -1,
332	};
333	inner.slots.insert(
334		h.id,
335		Slot::Committed(Committed {
336			buffer,
337			row_count: written_count,
338		}),
339	);
340	0
341}
342
343pub(crate) unsafe extern "C" fn test_release(handle: *mut ColumnBufferHandle) {
344	let Some(registry) = current() else {
345		return;
346	};
347	let h = Handle::decode(handle);
348	let mut inner = registry.inner.lock().unwrap();
349	inner.slots.remove(&h.id);
350}
351
352pub(crate) unsafe extern "C" fn test_emit_diff(
353	_ctx: *mut ContextFFI,
354	kind: EmitDiffKind,
355	pre_handles_ptr: *const *mut ColumnBufferHandle,
356	pre_name_ptrs: *const *const u8,
357	pre_name_lens: *const usize,
358	pre_count: usize,
359	pre_row_count: usize,
360	pre_row_numbers_ptr: *const u64,
361	pre_row_numbers_len: usize,
362	post_handles_ptr: *const *mut ColumnBufferHandle,
363	post_name_ptrs: *const *const u8,
364	post_name_lens: *const usize,
365	post_count: usize,
366	post_row_count: usize,
367	post_row_numbers_ptr: *const u64,
368	post_row_numbers_len: usize,
369) -> i32 {
370	let Some(registry) = current() else {
371		return -1;
372	};
373	let mut inner = registry.inner.lock().unwrap();
374	let now = DateTime::default();
375
376	let pre = if pre_count > 0 {
377		let ptrs = ColumnsPtrs {
378			handles: pre_handles_ptr,
379			names: pre_name_ptrs,
380			name_lens: pre_name_lens,
381			count: pre_count,
382		};
383		match assemble(&mut inner, ptrs, pre_row_count, pre_row_numbers_ptr, pre_row_numbers_len, now) {
384			Ok(c) => Some(c),
385			Err(code) => return code,
386		}
387	} else {
388		None
389	};
390	let post = if post_count > 0 {
391		let ptrs = ColumnsPtrs {
392			handles: post_handles_ptr,
393			names: post_name_ptrs,
394			name_lens: post_name_lens,
395			count: post_count,
396		};
397		match assemble(&mut inner, ptrs, post_row_count, post_row_numbers_ptr, post_row_numbers_len, now) {
398			Ok(c) => Some(c),
399			Err(code) => return code,
400		}
401	} else {
402		None
403	};
404
405	inner.accumulator.push(EmittedDiff {
406		kind,
407		pre,
408		post,
409	});
410	0
411}
412
413struct ColumnsPtrs {
414	handles: *const *mut ColumnBufferHandle,
415	names: *const *const u8,
416	name_lens: *const usize,
417	count: usize,
418}
419
420fn assemble(
421	inner: &mut RegistryInner,
422	ptrs: ColumnsPtrs,
423	row_count: usize,
424	row_numbers_ptr: *const u64,
425	row_numbers_len: usize,
426	now: DateTime,
427) -> Result<Columns, i32> {
428	if ptrs.handles.is_null() || ptrs.names.is_null() || ptrs.name_lens.is_null() {
429		return Err(-1);
430	}
431	if row_numbers_len != row_count {
432		return Err(-1);
433	}
434	if row_count > 0 && row_numbers_ptr.is_null() {
435		return Err(-1);
436	}
437	let count = ptrs.count;
438	let handles = unsafe { slice::from_raw_parts(ptrs.handles, count) };
439	let names = unsafe { slice::from_raw_parts(ptrs.names, count) };
440	let lens = unsafe { slice::from_raw_parts(ptrs.name_lens, count) };
441
442	let mut cols: Vec<ColumnWithName> = Vec::with_capacity(count);
443	for i in 0..count {
444		let h = Handle::decode(handles[i]);
445		let slot = inner.slots.remove(&h.id).ok_or(-1)?;
446		let committed = match slot {
447			Slot::Committed(c) => c,
448			Slot::Active(a) => {
449				inner.slots.insert(h.id, Slot::Active(a));
450				return Err(-1);
451			}
452		};
453		let name = if names[i].is_null() || lens[i] == 0 {
454			""
455		} else {
456			let s = unsafe { slice::from_raw_parts(names[i], lens[i]) };
457			str::from_utf8(s).unwrap_or("")
458		};
459		cols.push(ColumnWithName::new(Fragment::internal(name), committed.buffer));
460	}
461	let row_numbers: Vec<RowNumber> = if row_count == 0 {
462		Vec::new()
463	} else {
464		let raw = unsafe { slice::from_raw_parts(row_numbers_ptr, row_count) };
465		raw.iter().copied().map(RowNumber).collect()
466	};
467	let timestamps: Vec<DateTime> = vec![now; row_count];
468	Ok(Columns::with_system_columns(cols, row_numbers, timestamps.clone(), timestamps))
469}
470
471pub(crate) fn finalize_buffer(
472	type_code: ColumnTypeCode,
473	mut data: Vec<u8>,
474	offsets: Option<Vec<u64>>,
475	bitvec: Option<Vec<u8>>,
476	written_count: usize,
477) -> Option<ColumnBuffer> {
478	let make_option_wrapped = |inner: ColumnBuffer| match bitvec {
479		Some(bytes) => {
480			let bv = BitVec::from_raw(bytes, written_count);
481			ColumnBuffer::Option {
482				inner: Box::new(inner),
483				bitvec: bv,
484			}
485		}
486		None => inner,
487	};
488
489	let inner = match type_code {
490		ColumnTypeCode::Bool => {
491			let bv = BitVec::from_raw(data, written_count);
492			ColumnBuffer::Bool(BoolContainer::from_parts(bv))
493		}
494		ColumnTypeCode::Float4 => to_numeric::<f32>(&data, written_count, ColumnBuffer::Float4)?,
495		ColumnTypeCode::Float8 => to_numeric::<f64>(&data, written_count, ColumnBuffer::Float8)?,
496		ColumnTypeCode::Int1 => to_numeric::<i8>(&data, written_count, ColumnBuffer::Int1)?,
497		ColumnTypeCode::Int2 => to_numeric::<i16>(&data, written_count, ColumnBuffer::Int2)?,
498		ColumnTypeCode::Int4 => to_numeric::<i32>(&data, written_count, ColumnBuffer::Int4)?,
499		ColumnTypeCode::Int8 => to_numeric::<i64>(&data, written_count, ColumnBuffer::Int8)?,
500		ColumnTypeCode::Int16 => to_numeric::<i128>(&data, written_count, ColumnBuffer::Int16)?,
501		ColumnTypeCode::Uint1 => to_numeric::<u8>(&data, written_count, ColumnBuffer::Uint1)?,
502		ColumnTypeCode::Uint2 => to_numeric::<u16>(&data, written_count, ColumnBuffer::Uint2)?,
503		ColumnTypeCode::Uint4 => to_numeric::<u32>(&data, written_count, ColumnBuffer::Uint4)?,
504		ColumnTypeCode::Uint8 => to_numeric::<u64>(&data, written_count, ColumnBuffer::Uint8)?,
505		ColumnTypeCode::Uint16 => to_numeric::<u128>(&data, written_count, ColumnBuffer::Uint16)?,
506		ColumnTypeCode::Date => {
507			let v = bytes_to_vec::<Date>(&data, written_count)?;
508			ColumnBuffer::Date(TemporalContainer::from_parts(CowVec::new(v)))
509		}
510		ColumnTypeCode::DateTime => {
511			let v = bytes_to_vec::<DateTime>(&data, written_count)?;
512			ColumnBuffer::DateTime(TemporalContainer::from_parts(CowVec::new(v)))
513		}
514		ColumnTypeCode::Time => {
515			let v = bytes_to_vec::<Time>(&data, written_count)?;
516			ColumnBuffer::Time(TemporalContainer::from_parts(CowVec::new(v)))
517		}
518		ColumnTypeCode::Duration => {
519			let v = bytes_to_vec::<Duration>(&data, written_count)?;
520			ColumnBuffer::Duration(TemporalContainer::from_parts(CowVec::new(v)))
521		}
522		ColumnTypeCode::IdentityId => {
523			let v = bytes_to_vec::<IdentityId>(&data, written_count)?;
524			ColumnBuffer::IdentityId(IdentityIdContainer::from_parts(CowVec::new(v)))
525		}
526		ColumnTypeCode::Uuid4 => {
527			let v = bytes_to_vec::<Uuid4>(&data, written_count)?;
528			ColumnBuffer::Uuid4(UuidContainer::from_parts(CowVec::new(v)))
529		}
530		ColumnTypeCode::Uuid7 => {
531			let v = bytes_to_vec::<Uuid7>(&data, written_count)?;
532			ColumnBuffer::Uuid7(UuidContainer::from_parts(CowVec::new(v)))
533		}
534		ColumnTypeCode::Utf8 => {
535			let offsets = offsets.unwrap_or_else(|| vec![0u64]);
536			let payload_len = *offsets.last().unwrap_or(&0) as usize;
537			data.truncate(payload_len);
538			ColumnBuffer::Utf8 {
539				container: Utf8Container::from_bytes_offsets(data, offsets),
540				max_bytes: MaxBytes::MAX,
541			}
542		}
543		ColumnTypeCode::Blob => {
544			let offsets = offsets.unwrap_or_else(|| vec![0u64]);
545			let payload_len = *offsets.last().unwrap_or(&0) as usize;
546			data.truncate(payload_len);
547			ColumnBuffer::Blob {
548				container: BlobContainer::from_bytes_offsets(data, offsets),
549				max_bytes: MaxBytes::MAX,
550			}
551		}
552		ColumnTypeCode::Int => {
553			let v = postcard_per_element::<Int>(&data, &offsets, written_count)?;
554			ColumnBuffer::Int {
555				container: NumberContainer::from_vec(v),
556				max_bytes: MaxBytes::MAX,
557			}
558		}
559		ColumnTypeCode::Uint => {
560			let v = postcard_per_element::<Uint>(&data, &offsets, written_count)?;
561			ColumnBuffer::Uint {
562				container: NumberContainer::from_vec(v),
563				max_bytes: MaxBytes::MAX,
564			}
565		}
566		ColumnTypeCode::Decimal => {
567			let v = postcard_per_element::<Decimal>(&data, &offsets, written_count)?;
568			ColumnBuffer::Decimal {
569				container: NumberContainer::from_vec(v),
570				precision: Precision::MAX,
571				scale: Scale::MIN,
572			}
573		}
574		ColumnTypeCode::Any => {
575			let values: Vec<Value> = postcard_per_element::<Value>(&data, &offsets, written_count)?;
576			let boxed: Vec<Box<Value>> = values.into_iter().map(Box::new).collect();
577			ColumnBuffer::Any(AnyContainer::from_vec(boxed))
578		}
579		ColumnTypeCode::DictionaryId => {
580			let entries: Vec<DictionaryEntryId> =
581				postcard_per_element::<DictionaryEntryId>(&data, &offsets, written_count)?;
582			ColumnBuffer::DictionaryId(DictionaryContainer::from_vec(entries))
583		}
584		_ => return None,
585	};
586	Some(make_option_wrapped(inner))
587}
588
589fn postcard_per_element<T: DeserializeOwned>(data: &[u8], offsets: &Option<Vec<u64>>, count: usize) -> Option<Vec<T>> {
590	let offsets = offsets.as_ref()?;
591	if offsets.len() < count + 1 {
592		return None;
593	}
594	let mut out: Vec<T> = Vec::with_capacity(count);
595	for i in 0..count {
596		let start = offsets[i] as usize;
597		let end = offsets[i + 1] as usize;
598		if end > data.len() || start > end {
599			return None;
600		}
601		let value: T = postcard_decode(&data[start..end]).ok()?;
602		out.push(value);
603	}
604	Some(out)
605}
606
/// Copies the first `count` values of type `T` out of a raw byte slice.
///
/// Returns `None` when `count * size_of::<T>()` overflows or `data` holds
/// fewer bytes than that. `data` may have any alignment: the copy is done
/// bytewise into the properly aligned destination vector.
///
/// NOTE(review): `T: Copy` does not guarantee that arbitrary bytes form a
/// valid `T`; callers are assumed to pass bytes that were written as `T`
/// values — confirm for every instantiation.
fn bytes_to_vec<T: Copy>(data: &[u8], count: usize) -> Option<Vec<T>> {
	let needed = count.checked_mul(mem::size_of::<T>())?;
	if data.len() < needed {
		return None;
	}
	let mut v: Vec<T> = Vec::with_capacity(count);
	// SAFETY: `data` is readable for `needed` bytes (checked above) and `v`
	// owns at least `count` elements of capacity, i.e. `needed` writable
	// bytes. The regions cannot overlap because `v` was just allocated.
	// Copying as `u8` imposes no alignment requirement on `data`. After the
	// copy the first `count` elements of `v` are initialized.
	unsafe {
		ptr::copy_nonoverlapping(data.as_ptr(), v.as_mut_ptr().cast::<u8>(), needed);
		v.set_len(count);
	}
	Some(v)
}
619
620fn to_numeric<T: Copy + IsNumber + fmt::Debug + Default>(
621	data: &[u8],
622	count: usize,
623	wrap: fn(NumberContainer<T>) -> ColumnBuffer,
624) -> Option<ColumnBuffer> {
625	let v = bytes_to_vec::<T>(data, count)?;
626	Some(wrap(NumberContainer::from_parts(CowVec::new(v))))
627}
628
629pub fn into_diffs(emitted: Vec<EmittedDiff>) -> Diffs {
630	emitted.into_iter()
631		.map(|d| match d.kind {
632			EmitDiffKind::Insert => Diff::insert(d.post.unwrap_or_else(Columns::empty)),
633			EmitDiffKind::Update => Diff::update(
634				d.pre.unwrap_or_else(Columns::empty),
635				d.post.unwrap_or_else(Columns::empty),
636			),
637			EmitDiffKind::Remove => Diff::remove(d.pre.unwrap_or_else(Columns::empty)),
638		})
639		.collect()
640}