Skip to main content

reifydb_sdk/testing/
state.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::HashMap, fmt::Debug};
5
6use postcard::from_bytes;
7use reifydb_core::encoded::{key::EncodedKey, row::EncodedRow, shape::RowShape};
8use reifydb_type::value::Value;
9use serde::de::DeserializeOwned;
10
11use super::helpers::get_values;
12
13#[derive(Debug, Clone, Default)]
14pub struct TestStateStore {
15	data: HashMap<EncodedKey, EncodedRow>,
16}
17
18impl TestStateStore {
19	pub fn new() -> Self {
20		Self {
21			data: HashMap::new(),
22		}
23	}
24
25	pub fn get(&self, key: &EncodedKey) -> Option<&EncodedRow> {
26		self.data.get(key)
27	}
28
29	pub fn set(&mut self, key: EncodedKey, value: EncodedRow) {
30		self.data.insert(key, value);
31	}
32
33	pub fn remove(&mut self, key: &EncodedKey) -> Option<EncodedRow> {
34		self.data.remove(key)
35	}
36
37	pub fn contains(&self, key: &EncodedKey) -> bool {
38		self.data.contains_key(key)
39	}
40
41	pub fn len(&self) -> usize {
42		self.data.len()
43	}
44
45	pub fn is_empty(&self) -> bool {
46		self.data.is_empty()
47	}
48
49	pub fn clear(&mut self) {
50		self.data.clear();
51	}
52
53	pub fn keys(&self) -> Vec<&EncodedKey> {
54		self.data.keys().collect()
55	}
56
57	pub fn entries(&self) -> Vec<(&EncodedKey, &EncodedRow)> {
58		self.data.iter().collect()
59	}
60
61	pub fn decode_value(&self, key: &EncodedKey, shape: &RowShape) -> Option<Vec<Value>> {
62		self.get(key).map(|encoded| get_values(shape, encoded))
63	}
64
65	pub fn decode_named_value(&self, key: &EncodedKey, shape: &RowShape) -> Option<HashMap<String, Value>> {
66		self.get(key).map(|encoded| {
67			let values = get_values(shape, encoded);
68			shape.field_names().map(|n| n.to_string()).zip(values).collect()
69		})
70	}
71
72	pub fn set_value(&mut self, key: EncodedKey, values: &[Value], shape: &RowShape) {
73		let mut encoded = shape.allocate();
74		shape.set_values(&mut encoded, values);
75		self.set(key, encoded);
76	}
77
78	pub fn set_named_value(&mut self, key: EncodedKey, values: &HashMap<String, Value>, shape: &RowShape) {
79		let mut encoded = shape.allocate();
80
81		let ordered_values: Vec<Value> =
82			shape.field_names().map(|name| values.get(name).cloned().unwrap_or(Value::none())).collect();
83
84		shape.set_values(&mut encoded, &ordered_values);
85		self.set(key, encoded);
86	}
87
88	pub fn snapshot(&self) -> HashMap<EncodedKey, EncodedRow> {
89		self.data.clone()
90	}
91
92	pub fn restore(&mut self, snapshot: HashMap<EncodedKey, EncodedRow>) {
93		self.data = snapshot;
94	}
95
96	pub fn assert_value(&self, key: &EncodedKey, expected: &[Value], shape: &RowShape) {
97		let actual =
98			self.decode_value(key, shape).unwrap_or_else(|| panic!("Key {:?} not found in state", key));
99		assert_eq!(actual, expected, "State value mismatch for key {:?}", key);
100	}
101
102	pub fn decode_typed<T: DeserializeOwned>(&self, key: &EncodedKey) -> Option<T> {
103		let row = self.get(key)?;
104		let shape = RowShape::operator_state();
105		let blob = shape.get_blob(row, 0);
106		from_bytes(blob.as_bytes()).ok()
107	}
108
109	pub fn assert_typed_value<T: DeserializeOwned + PartialEq + Debug>(&self, key: &EncodedKey, expected: &T) {
110		let actual = self.decode_typed::<T>(key).unwrap_or_else(|| panic!("Key {:?} not found in state", key));
111		assert_eq!(&actual, expected, "Typed state value mismatch for key {:?}", key);
112	}
113
114	pub fn assert_exists(&self, key: &EncodedKey) {
115		assert!(self.contains(key), "Expected key {:?} to exist in state", key);
116	}
117
118	pub fn assert_not_exists(&self, key: &EncodedKey) {
119		assert!(!self.contains(key), "Expected key {:?} to not exist in state", key);
120	}
121
122	pub fn assert_count(&self, expected: usize) {
123		assert_eq!(self.len(), expected, "Expected {} entries in state, found {}", expected, self.len());
124	}
125}
126
#[cfg(test)]
pub mod tests {
	use reifydb_core::encoded::{
		row::EncodedRow,
		shape::{RowShape, RowShapeField},
	};
	use reifydb_type::{util::cowvec::CowVec, value::r#type::Type};

	use super::*;
	use crate::testing::helpers::encode_key;

	/// set / get / contains / len / remove on raw encoded rows.
	#[test]
	fn test_state_store_basic_operations() {
		let mut state = TestStateStore::new();
		let k = encode_key("test_key");
		let row = EncodedRow(CowVec::new(vec![1, 2, 3, 4]));

		assert!(state.is_empty());

		state.set(k.clone(), row.clone());
		assert_eq!(state.get(&k), Some(&row));
		assert!(state.contains(&k));
		assert_eq!(state.len(), 1);

		assert_eq!(state.remove(&k), Some(row));
		assert!(state.is_empty());
	}

	/// Positional values survive an encode/decode round trip through a shape.
	#[test]
	fn test_state_store_with_shape() {
		let shape = RowShape::testing(&[Type::Int8, Type::Utf8]);
		let k = encode_key("test_key");
		let expected = vec![Value::Int8(42i64), Value::Utf8("hello".into())];

		let mut state = TestStateStore::new();
		state.set_value(k.clone(), &expected, &shape);

		let roundtrip = state.decode_value(&k, &shape).unwrap();
		assert_eq!(roundtrip, expected);
	}

	/// Name-keyed values survive an encode/decode round trip through a named shape.
	#[test]
	fn test_state_store_with_named_shape() {
		let shape = RowShape::new(vec![
			RowShapeField::unconstrained("count", Type::Int8),
			RowShapeField::unconstrained("name", Type::Utf8),
		]);
		let k = encode_key("test_key");
		let expected: HashMap<String, Value> = [
			("count".to_string(), Value::Int8(10i64)),
			("name".to_string(), Value::Utf8("test".into())),
		]
		.into_iter()
		.collect();

		let mut state = TestStateStore::new();
		state.set_named_value(k.clone(), &expected, &shape);

		let roundtrip = state.decode_named_value(&k, &shape).unwrap();
		assert_eq!(roundtrip, expected);
	}

	/// A snapshot taken before `clear` brings back every entry on `restore`.
	#[test]
	fn test_state_store_snapshot_and_restore() {
		let first = encode_key("key1");
		let second = encode_key("key2");
		let row_of = |byte| EncodedRow(CowVec::new(vec![byte]));

		let mut state = TestStateStore::new();
		state.set(first.clone(), row_of(1));
		state.set(second.clone(), row_of(2));

		let saved = state.snapshot();
		assert_eq!(saved.len(), 2);

		state.clear();
		assert!(state.is_empty());

		state.restore(saved);
		assert_eq!(state.len(), 2);
		assert_eq!(state.get(&first), Some(&row_of(1)));
		assert_eq!(state.get(&second), Some(&row_of(2)));
	}

	/// The assertion helpers pass on a store holding one known row.
	#[test]
	fn test_state_store_assertions() {
		let shape = RowShape::testing(&[Type::Int8]);
		let present = encode_key("test_key");
		let values = vec![Value::Int8(100i64)];

		let mut state = TestStateStore::new();
		state.set_value(present.clone(), &values, &shape);

		state.assert_exists(&present);
		state.assert_value(&present, &values, &shape);
		state.assert_count(1);

		let absent = encode_key("missing");
		state.assert_not_exists(&absent);
	}
}