1use std::collections::HashMap;
5
6use reifydb_core::encoded::{
7 key::{EncodedKey, IntoEncodedKey},
8 row::EncodedRow,
9 shape::RowShape,
10};
11use reifydb_type::{
12 util::cowvec::CowVec,
13 value::{Value, r#type::Type},
14};
15
16use super::helpers::get_values;
17
18pub struct SingleStatefulTestHelper {
19 shape: RowShape,
20 state: Option<Vec<u8>>,
21}
22
23impl SingleStatefulTestHelper {
24 pub fn new(shape: RowShape) -> Self {
25 Self {
26 shape,
27 state: None,
28 }
29 }
30
31 pub fn counter() -> Self {
32 Self::new(RowShape::testing(&[Type::Int8]))
33 }
34
35 pub fn set_state(&mut self, values: &[Value]) {
36 let mut encoded = self.shape.allocate();
37 self.shape.set_values(&mut encoded, values);
38 self.state = Some(encoded.0.to_vec());
39 }
40
41 pub fn get_state(&self) -> Option<Vec<Value>> {
42 self.state.as_ref().map(|bytes| {
43 let encoded = EncodedRow(CowVec::new(bytes.clone()));
44 get_values(&self.shape, &encoded)
45 })
46 }
47
48 pub fn assert_state(&self, expected: &[Value]) {
49 let actual = self.get_state().expect("No state set");
50 assert_eq!(actual, expected, "State mismatch");
51 }
52
53 pub fn clear(&mut self) {
54 self.state = None;
55 }
56
57 pub fn has_state(&self) -> bool {
58 self.state.is_some()
59 }
60}
61
62pub struct KeyedStatefulTestHelper {
63 shape: RowShape,
64 states: HashMap<EncodedKey, EncodedRow>,
65}
66
67impl KeyedStatefulTestHelper {
68 pub fn new(shape: RowShape) -> Self {
69 Self {
70 shape,
71 states: HashMap::new(),
72 }
73 }
74
75 pub fn counter() -> Self {
76 Self::new(RowShape::testing(&[Type::Int8]))
77 }
78
79 pub fn sum() -> Self {
80 Self::new(RowShape::testing(&[Type::Int4]))
81 }
82
83 pub fn set_state<K>(&mut self, key: K, values: &[Value])
84 where
85 K: IntoEncodedKey,
86 {
87 let mut encoded = self.shape.allocate();
88 self.shape.set_values(&mut encoded, values);
89 self.states.insert(key.into_encoded_key(), encoded);
90 }
91
92 pub fn get_state<K>(&self, key: K) -> Option<Vec<Value>>
93 where
94 K: IntoEncodedKey,
95 {
96 self.states.get(&key.into_encoded_key()).map(|encoded| get_values(&self.shape, encoded))
97 }
98
99 pub fn assert_state<K>(&self, key: K, expected: &[Value])
100 where
101 K: IntoEncodedKey,
102 {
103 let key_encoded = key.into_encoded_key();
104 let actual = self
105 .states
106 .get(&key_encoded)
107 .map(|encoded| get_values(&self.shape, encoded))
108 .expect("No state for key");
109 assert_eq!(actual, expected, "State mismatch for key");
110 }
111
112 pub fn remove_state<K>(&mut self, key: K) -> Option<Vec<Value>>
113 where
114 K: IntoEncodedKey,
115 {
116 self.states.remove(&key.into_encoded_key()).map(|encoded| get_values(&self.shape, &encoded))
117 }
118
119 pub fn has_state<K>(&self, key: K) -> bool
120 where
121 K: IntoEncodedKey,
122 {
123 self.states.contains_key(&key.into_encoded_key())
124 }
125
126 pub fn state_count(&self) -> usize {
127 self.states.len()
128 }
129
130 pub fn clear(&mut self) {
131 self.states.clear();
132 }
133
134 pub fn keys(&self) -> Vec<&EncodedKey> {
135 self.states.keys().collect()
136 }
137
138 pub fn assert_count(&self, expected: usize) {
139 assert_eq!(self.state_count(), expected, "Expected {} states, found {}", expected, self.state_count());
140 }
141}
142
143pub struct WindowStatefulTestHelper {
144 shape: RowShape,
145 windows: HashMap<i64, HashMap<EncodedKey, EncodedRow>>, window_size: i64,
147}
148
149impl WindowStatefulTestHelper {
150 pub fn new(shape: RowShape, window_size: i64) -> Self {
151 Self {
152 shape,
153 windows: HashMap::new(),
154 window_size,
155 }
156 }
157
158 pub fn time_window_counter(window_size_seconds: i64) -> Self {
159 Self::new(RowShape::testing(&[Type::Int8]), window_size_seconds)
160 }
161
162 pub fn count_window_sum(window_size_count: i64) -> Self {
163 Self::new(RowShape::testing(&[Type::Int4]), window_size_count)
164 }
165
166 pub fn set_window_state<K>(&mut self, window_id: i64, key: K, values: &[Value])
167 where
168 K: IntoEncodedKey,
169 {
170 let mut encoded = self.shape.allocate();
171 self.shape.set_values(&mut encoded, values);
172
173 self.windows.entry(window_id).or_default().insert(key.into_encoded_key(), encoded);
174 }
175
176 pub fn get_window_state<K>(&self, window_id: i64, key: K) -> Option<Vec<Value>>
177 where
178 K: IntoEncodedKey,
179 {
180 self.windows
181 .get(&window_id)
182 .and_then(|window| window.get(&key.into_encoded_key()))
183 .map(|encoded| get_values(&self.shape, encoded))
184 }
185
186 pub fn assert_window_state<K>(&self, window_id: i64, key: K, expected: &[Value])
187 where
188 K: IntoEncodedKey,
189 {
190 let key_encoded = key.into_encoded_key();
191 let actual = self
192 .windows
193 .get(&window_id)
194 .and_then(|window| window.get(&key_encoded))
195 .map(|encoded| get_values(&self.shape, encoded))
196 .expect("No state for window and key");
197 assert_eq!(actual, expected, "State mismatch for window {} and key", window_id);
198 }
199
200 pub fn get_window(&self, window_id: i64) -> Option<&HashMap<EncodedKey, EncodedRow>> {
201 self.windows.get(&window_id)
202 }
203
204 pub fn remove_window(&mut self, window_id: i64) -> Option<HashMap<EncodedKey, EncodedRow>> {
205 self.windows.remove(&window_id)
206 }
207
208 pub fn has_window(&self, window_id: i64) -> bool {
209 self.windows.contains_key(&window_id)
210 }
211
212 pub fn window_count(&self) -> usize {
213 self.windows.len()
214 }
215
216 pub fn window_key_count(&self, window_id: i64) -> usize {
217 self.windows.get(&window_id).map(|w| w.len()).unwrap_or(0)
218 }
219
220 pub fn clear(&mut self) {
221 self.windows.clear();
222 }
223
224 pub fn window_ids(&self) -> Vec<i64> {
225 let mut ids: Vec<_> = self.windows.keys().copied().collect();
226 ids.sort();
227 ids
228 }
229
230 pub fn assert_window_count(&self, expected: usize) {
231 assert_eq!(
232 self.window_count(),
233 expected,
234 "Expected {} windows, found {}",
235 expected,
236 self.window_count()
237 );
238 }
239
240 pub fn window_for_timestamp(&self, timestamp: i64) -> i64 {
241 timestamp / self.window_size
242 }
243}
244
245pub mod scenarios {
246 use reifydb_core::interface::change::Change;
247 use reifydb_type::value::row_number::RowNumber;
248
249 use super::*;
250 use crate::testing::builders::TestChangeBuilder;
251
252 pub fn counter_inserts(count: usize) -> Vec<Change> {
253 (0..count)
254 .map(|i| {
255 TestChangeBuilder::new()
256 .insert_row(RowNumber(i as u64), vec![Value::Int8(1i64)])
257 .build()
258 })
259 .collect()
260 }
261
262 pub fn grouped_inserts(groups: &[(&str, i32)]) -> Change {
263 let mut builder = TestChangeBuilder::new();
264 for (i, (key, value)) in groups.iter().enumerate() {
265 builder = builder
266 .insert_row(RowNumber(i as u64), vec![Value::Utf8((*key).into()), Value::Int4(*value)]);
267 }
268 builder.build()
269 }
270
271 pub fn state_updates(row_number: i64, old_value: i8, new_value: i8) -> Change {
272 TestChangeBuilder::new()
273 .update_row(
274 RowNumber(row_number as u64),
275 vec![Value::Int8(old_value as i64)],
276 vec![Value::Int8(new_value as i64)],
277 )
278 .build()
279 }
280
281 pub fn windowed_events(window_size: i64, events_per_window: usize, windows: usize) -> Vec<(i64, Change)> {
282 let mut result = Vec::new();
283
284 for window in 0..windows {
285 let base_time = window as i64 * window_size;
286
287 for event in 0..events_per_window {
288 let timestamp = base_time + (event as i64 * (window_size / events_per_window as i64));
289 let change = TestChangeBuilder::new()
290 .insert_row(
291 RowNumber(timestamp as u64),
292 vec![Value::Int8(1i64), Value::Int8(timestamp)],
293 )
294 .build();
295 result.push((timestamp, change));
296 }
297 }
298
299 result
300 }
301}
302
#[cfg(test)]
pub mod tests {
	use super::{scenarios::*, *};

	#[test]
	fn test_single_stateful_helper() {
		let mut helper = SingleStatefulTestHelper::counter();
		assert!(!helper.has_state());

		let answer = [Value::Int8(42i64)];
		helper.set_state(&answer);
		assert!(helper.has_state());
		helper.assert_state(&answer);

		helper.clear();
		assert!(!helper.has_state());
	}

	#[test]
	fn test_keyed_stateful_helper() {
		let mut helper = KeyedStatefulTestHelper::sum();
		let entries = [("key1", 100), ("key2", 200)];

		for (key, amount) in entries {
			helper.set_state(key, &[Value::Int4(amount)]);
		}
		helper.assert_count(2);
		for (key, amount) in entries {
			helper.assert_state(key, &[Value::Int4(amount)]);
		}

		assert!(helper.has_state("key1"));
		assert!(!helper.has_state("key3"));

		assert_eq!(helper.remove_state("key1"), Some(vec![Value::Int4(100)]));
		helper.assert_count(1);
	}

	#[test]
	fn test_window_stateful_helper() {
		let mut helper = WindowStatefulTestHelper::time_window_counter(60);
		let first = helper.window_for_timestamp(30);
		let second = helper.window_for_timestamp(90);

		helper.set_window_state(first, "key1", &[Value::Int8(10i64)]);
		helper.set_window_state(second, "key1", &[Value::Int8(20i64)]);

		helper.assert_window_count(2);
		helper.assert_window_state(first, "key1", &[Value::Int8(10i64)]);
		helper.assert_window_state(second, "key1", &[Value::Int8(20i64)]);

		assert_eq!(helper.window_ids(), vec![first, second]);
		assert_eq!(helper.window_key_count(first), 1);
	}

	#[test]
	fn test_scenarios() {
		assert_eq!(counter_inserts(3).len(), 3);

		let grouped = grouped_inserts(&[("a", 10), ("b", 20), ("a", 30)]);
		assert_eq!(grouped.diffs.len(), 3);

		let update = state_updates(1, 10, 20);
		assert_eq!(update.diffs.len(), 1);

		let windowed = windowed_events(60, 2, 2);
		assert_eq!(windowed.len(), 4);
	}

	#[test]
	fn test_into_encoded_key_with_strings() {
		let mut helper = KeyedStatefulTestHelper::sum();

		helper.set_state("string_key_1", &[Value::Int4(42)]);
		helper.set_state("string_key_2", &[Value::Int4(100)]);

		let owned_key = String::from("dynamic_key");
		helper.set_state(owned_key.clone(), &[Value::Int4(200)]);

		helper.set_state(123u32, &[Value::Int4(300)]);
		helper.set_state(456u64, &[Value::Int4(400)]);

		assert_eq!(helper.get_state("string_key_1"), Some(vec![Value::Int4(42)]));
		assert_eq!(helper.get_state("string_key_2"), Some(vec![Value::Int4(100)]));
		assert_eq!(helper.get_state(owned_key), Some(vec![Value::Int4(200)]));
		assert_eq!(helper.get_state(123u32), Some(vec![Value::Int4(300)]));
		assert_eq!(helper.get_state(456u64), Some(vec![Value::Int4(400)]));

		assert_eq!(helper.state_count(), 5);
	}
}