1use std::{collections::HashMap, fmt::Debug};
5
6use postcard::from_bytes;
7use reifydb_core::encoded::{key::EncodedKey, row::EncodedRow, shape::RowShape};
8use reifydb_type::value::Value;
9use serde::de::DeserializeOwned;
10
11use super::helpers::get_values;
12
13#[derive(Debug, Clone, Default)]
14pub struct TestStateStore {
15 data: HashMap<EncodedKey, EncodedRow>,
16}
17
18impl TestStateStore {
19 pub fn new() -> Self {
20 Self {
21 data: HashMap::new(),
22 }
23 }
24
25 pub fn get(&self, key: &EncodedKey) -> Option<&EncodedRow> {
26 self.data.get(key)
27 }
28
29 pub fn set(&mut self, key: EncodedKey, value: EncodedRow) {
30 self.data.insert(key, value);
31 }
32
33 pub fn remove(&mut self, key: &EncodedKey) -> Option<EncodedRow> {
34 self.data.remove(key)
35 }
36
37 pub fn contains(&self, key: &EncodedKey) -> bool {
38 self.data.contains_key(key)
39 }
40
41 pub fn len(&self) -> usize {
42 self.data.len()
43 }
44
45 pub fn is_empty(&self) -> bool {
46 self.data.is_empty()
47 }
48
49 pub fn clear(&mut self) {
50 self.data.clear();
51 }
52
53 pub fn keys(&self) -> Vec<&EncodedKey> {
54 self.data.keys().collect()
55 }
56
57 pub fn entries(&self) -> Vec<(&EncodedKey, &EncodedRow)> {
58 self.data.iter().collect()
59 }
60
61 pub fn decode_value(&self, key: &EncodedKey, shape: &RowShape) -> Option<Vec<Value>> {
62 self.get(key).map(|encoded| get_values(shape, encoded))
63 }
64
65 pub fn decode_named_value(&self, key: &EncodedKey, shape: &RowShape) -> Option<HashMap<String, Value>> {
66 self.get(key).map(|encoded| {
67 let values = get_values(shape, encoded);
68 shape.field_names().map(|n| n.to_string()).zip(values).collect()
69 })
70 }
71
72 pub fn set_value(&mut self, key: EncodedKey, values: &[Value], shape: &RowShape) {
73 let mut encoded = shape.allocate();
74 shape.set_values(&mut encoded, values);
75 self.set(key, encoded);
76 }
77
78 pub fn set_named_value(&mut self, key: EncodedKey, values: &HashMap<String, Value>, shape: &RowShape) {
79 let mut encoded = shape.allocate();
80
81 let ordered_values: Vec<Value> =
82 shape.field_names().map(|name| values.get(name).cloned().unwrap_or(Value::none())).collect();
83
84 shape.set_values(&mut encoded, &ordered_values);
85 self.set(key, encoded);
86 }
87
88 pub fn snapshot(&self) -> HashMap<EncodedKey, EncodedRow> {
89 self.data.clone()
90 }
91
92 pub fn restore(&mut self, snapshot: HashMap<EncodedKey, EncodedRow>) {
93 self.data = snapshot;
94 }
95
96 pub fn assert_value(&self, key: &EncodedKey, expected: &[Value], shape: &RowShape) {
97 let actual =
98 self.decode_value(key, shape).unwrap_or_else(|| panic!("Key {:?} not found in state", key));
99 assert_eq!(actual, expected, "State value mismatch for key {:?}", key);
100 }
101
102 pub fn decode_typed<T: DeserializeOwned>(&self, key: &EncodedKey) -> Option<T> {
103 let row = self.get(key)?;
104 let shape = RowShape::operator_state();
105 let blob = shape.get_blob(row, 0);
106 from_bytes(blob.as_bytes()).ok()
107 }
108
109 pub fn assert_typed_value<T: DeserializeOwned + PartialEq + Debug>(&self, key: &EncodedKey, expected: &T) {
110 let actual = self.decode_typed::<T>(key).unwrap_or_else(|| panic!("Key {:?} not found in state", key));
111 assert_eq!(&actual, expected, "Typed state value mismatch for key {:?}", key);
112 }
113
114 pub fn assert_exists(&self, key: &EncodedKey) {
115 assert!(self.contains(key), "Expected key {:?} to exist in state", key);
116 }
117
118 pub fn assert_not_exists(&self, key: &EncodedKey) {
119 assert!(!self.contains(key), "Expected key {:?} to not exist in state", key);
120 }
121
122 pub fn assert_count(&self, expected: usize) {
123 assert_eq!(self.len(), expected, "Expected {} entries in state, found {}", expected, self.len());
124 }
125}
126
#[cfg(test)]
pub mod tests {
	use reifydb_core::encoded::{
		row::EncodedRow,
		shape::{RowShape, RowShapeField},
	};
	use reifydb_type::{util::cowvec::CowVec, value::r#type::Type};

	use super::*;
	use crate::testing::helpers::encode_key;

	/// Builds a raw encoded row holding exactly `bytes`.
	fn raw_row(bytes: &[u8]) -> EncodedRow {
		EncodedRow(CowVec::new(bytes.to_vec()))
	}

	/// Set / get / contains / len / remove round-trip on a raw row.
	#[test]
	fn test_state_store_basic_operations() {
		let mut store = TestStateStore::new();
		assert!(store.is_empty());

		let key = encode_key("test_key");
		let row = raw_row(&[1, 2, 3, 4]);

		store.set(key.clone(), row.clone());
		assert!(store.contains(&key));
		assert_eq!(store.len(), 1);
		assert_eq!(store.get(&key), Some(&row));

		assert_eq!(store.remove(&key), Some(row));
		assert!(store.is_empty());
	}

	/// Positional values survive an encode/decode round-trip through a shape.
	#[test]
	fn test_state_store_with_shape() {
		let shape = RowShape::testing(&[Type::Int8, Type::Utf8]);
		let expected = vec![Value::Int8(42i64), Value::Utf8("hello".into())];
		let key = encode_key("test_key");

		let mut store = TestStateStore::new();
		store.set_value(key.clone(), &expected, &shape);

		assert_eq!(store.decode_value(&key, &shape), Some(expected));
	}

	/// Named values survive an encode/decode round-trip through a named shape.
	#[test]
	fn test_state_store_with_named_shape() {
		let shape = RowShape::new(vec![
			RowShapeField::unconstrained("count", Type::Int8),
			RowShapeField::unconstrained("name", Type::Utf8),
		]);
		let expected: HashMap<String, Value> = HashMap::from([
			("count".to_string(), Value::Int8(10i64)),
			("name".to_string(), Value::Utf8("test".into())),
		]);
		let key = encode_key("test_key");

		let mut store = TestStateStore::new();
		store.set_named_value(key.clone(), &expected, &shape);

		assert_eq!(store.decode_named_value(&key, &shape), Some(expected));
	}

	/// A snapshot taken before `clear` restores every entry afterwards.
	#[test]
	fn test_state_store_snapshot_and_restore() {
		let first = encode_key("key1");
		let second = encode_key("key2");

		let mut store = TestStateStore::new();
		store.set(first.clone(), raw_row(&[1]));
		store.set(second.clone(), raw_row(&[2]));

		let saved = store.snapshot();
		assert_eq!(saved.len(), 2);

		store.clear();
		assert!(store.is_empty());

		store.restore(saved);
		assert_eq!(store.len(), 2);
		assert_eq!(store.get(&first), Some(&raw_row(&[1])));
		assert_eq!(store.get(&second), Some(&raw_row(&[2])));
	}

	/// The assertion helpers accept a correctly populated store.
	#[test]
	fn test_state_store_assertions() {
		let shape = RowShape::testing(&[Type::Int8]);
		let present = encode_key("test_key");
		let absent = encode_key("missing");
		let expected = vec![Value::Int8(100i64)];

		let mut store = TestStateStore::new();
		store.set_value(present.clone(), &expected, &shape);

		store.assert_exists(&present);
		store.assert_value(&present, &expected, &shape);
		store.assert_count(1);
		store.assert_not_exists(&absent);
	}
}