1use reifydb_core::{
5 encoded::{key::EncodedKey, row::EncodedRow},
6 interface::catalog::flow::FlowNodeId,
7 util::encoding::keycode::serializer::KeySerializer,
8};
9use reifydb_type::{Result, util::cowvec::CowVec, value::row_number::RowNumber};
10
11use crate::{
12 operator::stateful::utils::{internal_state_get, internal_state_set},
13 transaction::FlowTransaction,
14};
15
/// Direction in which a `Counter` hands out values.
///
/// `Ascending` (the default) begins at `1` and steps upward; `Descending` begins at
/// `u64::MAX` and steps downward. In both directions the counter wraps around at the
/// `u64` boundaries instead of failing.
///
/// NOTE(review): `Descending` presumably exists because keycode serialization orders
/// larger `u64`s first; confirm the intended pairing with row-key ordering.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CounterDirection {
    /// Start at `1`; each issued value is followed by its successor.
    #[default]
    Ascending,
    /// Start at `u64::MAX`; each issued value is followed by its predecessor.
    Descending,
}
25
26pub struct Counter {
27 node: FlowNodeId,
28 key: EncodedKey,
29 direction: CounterDirection,
30}
31
32impl Counter {
33 pub fn with_prefix(node: FlowNodeId, prefix: u8, direction: CounterDirection) -> Self {
35 let mut serializer = KeySerializer::new();
36 serializer.extend_u8(prefix);
37 let key = EncodedKey::new(serializer.finish());
38 Self {
39 node,
40 key,
41 direction,
42 }
43 }
44
45 pub fn with_key(node: FlowNodeId, key: EncodedKey, direction: CounterDirection) -> Self {
47 Self {
48 node,
49 key,
50 direction,
51 }
52 }
53
54 pub fn next(&self, txn: &mut FlowTransaction) -> Result<RowNumber> {
56 let current = self.load(txn)?;
57 let next_value = self.compute_next(current);
58 self.save(txn, next_value)?;
59 Ok(RowNumber(current))
60 }
61
62 pub fn current(&self, txn: &mut FlowTransaction) -> Result<u64> {
64 self.load(txn)
65 }
66
67 pub fn set(&self, txn: &mut FlowTransaction, value: u64) -> Result<()> {
69 self.save(txn, value)
70 }
71
72 fn load(&self, txn: &mut FlowTransaction) -> Result<u64> {
74 match internal_state_get(self.node, txn, &self.key)? {
75 None => Ok(self.default_value()),
76 Some(encoded) => {
77 let bytes = encoded.as_slice();
78 if bytes.len() >= 8 {
79 Ok(u64::from_be_bytes([
80 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
81 bytes[7],
82 ]))
83 } else {
84 Ok(self.default_value())
85 }
86 }
87 }
88 }
89
90 fn save(&self, txn: &mut FlowTransaction, value: u64) -> Result<()> {
91 let bytes = value.to_be_bytes().to_vec();
92 internal_state_set(self.node, txn, &self.key, EncodedRow(CowVec::new(bytes)))?;
93 Ok(())
94 }
95
96 fn default_value(&self) -> u64 {
97 match self.direction {
98 CounterDirection::Ascending => 1,
99 CounterDirection::Descending => u64::MAX,
100 }
101 }
102
103 fn compute_next(&self, current: u64) -> u64 {
104 match self.direction {
105 CounterDirection::Ascending => current.wrapping_add(1),
106 CounterDirection::Descending => current.wrapping_sub(1),
107 }
108 }
109}
110
#[cfg(test)]
mod tests {
    use reifydb_catalog::catalog::Catalog;
    use reifydb_core::common::CommitVersion;
    use reifydb_runtime::context::clock::{Clock, MockClock};
    use reifydb_transaction::interceptor::interceptors::Interceptors;

    use super::*;
    use crate::operator::stateful::test_utils::test::*;

    /// Runs `body` against a deferred [`FlowTransaction`] built on a new
    /// `create_test_transaction()`, at commit version 1 with a mock clock at 1000 ms.
    ///
    /// Centralizes the setup that every stateful test below needs.
    fn with_flow_txn(body: impl FnOnce(&mut FlowTransaction)) {
        let mut parent = create_test_transaction();
        let mut txn = FlowTransaction::deferred(
            &mut parent,
            CommitVersion(1),
            Catalog::testing(),
            Interceptors::new(),
            Clock::Mock(MockClock::from_millis(1000)),
        );
        body(&mut txn);
    }

    #[test]
    fn test_counter_starts_at_one() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

            let value = counter.next(txn).unwrap();
            assert_eq!(value.0, 1);
        });
    }

    #[test]
    fn test_counter_increments() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

            let v1 = counter.next(txn).unwrap();
            let v2 = counter.next(txn).unwrap();
            let v3 = counter.next(txn).unwrap();

            assert_eq!(v1.0, 1);
            assert_eq!(v2.0, 2);
            assert_eq!(v3.0, 3);
        });
    }

    #[test]
    fn test_counter_persistence() {
        with_flow_txn(|txn| {
            let node = FlowNodeId(1);

            {
                // The first instance issues 1 and 2, then is dropped.
                let counter = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
                counter.next(txn).unwrap();
                counter.next(txn).unwrap();
            }

            {
                // A fresh instance over the same node and prefix resumes from stored state.
                let counter = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
                let value = counter.next(txn).unwrap();
                assert_eq!(value.0, 3);
            }
        });
    }

    #[test]
    fn test_counter_current() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

            // Before anything is issued, current() reports the starting value.
            let current = counter.current(txn).unwrap();
            assert_eq!(current, 1);

            counter.next(txn).unwrap();
            let current = counter.current(txn).unwrap();
            assert_eq!(current, 2);

            // current() must not advance the counter.
            let current_again = counter.current(txn).unwrap();
            assert_eq!(current_again, 2);
        });
    }

    #[test]
    fn test_counter_set() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

            counter.set(txn, 100).unwrap();

            // The value passed to set() is the next one issued.
            let value = counter.next(txn).unwrap();
            assert_eq!(value.0, 100);

            let value = counter.next(txn).unwrap();
            assert_eq!(value.0, 101);
        });
    }

    #[test]
    fn test_counter_with_custom_key() {
        with_flow_txn(|txn| {
            let custom_key = {
                let mut serializer = KeySerializer::new();
                serializer.extend_bytes(b"subscription-id-123");
                EncodedKey::new(serializer.finish())
            };

            let counter = Counter::with_key(FlowNodeId(1), custom_key, CounterDirection::Ascending);

            let v1 = counter.next(txn).unwrap();
            let v2 = counter.next(txn).unwrap();

            assert_eq!(v1.0, 1);
            assert_eq!(v2.0, 2);
        });
    }

    #[test]
    fn test_multiple_counters_isolated() {
        with_flow_txn(|txn| {
            let node = FlowNodeId(1);

            // Same node, different prefixes: each counter has its own state.
            let counter1 = Counter::with_prefix(node, b'A', CounterDirection::Ascending);
            let counter2 = Counter::with_prefix(node, b'B', CounterDirection::Ascending);

            let v1a = counter1.next(txn).unwrap();
            let v2a = counter2.next(txn).unwrap();
            let v1b = counter1.next(txn).unwrap();
            let v2b = counter2.next(txn).unwrap();

            assert_eq!(v1a.0, 1);
            assert_eq!(v2a.0, 1);
            assert_eq!(v1b.0, 2);
            assert_eq!(v2b.0, 2);
        });
    }

    #[test]
    fn test_different_nodes_isolated() {
        with_flow_txn(|txn| {
            // Same prefix, different nodes: state is scoped per node.
            let counter1 = Counter::with_prefix(FlowNodeId(1), b'X', CounterDirection::Ascending);
            let counter2 = Counter::with_prefix(FlowNodeId(2), b'X', CounterDirection::Ascending);

            let v1 = counter1.next(txn).unwrap();
            let v2 = counter2.next(txn).unwrap();

            assert_eq!(v1.0, 1);
            assert_eq!(v2.0, 1);
        });
    }

    #[test]
    fn test_wrapping_behavior() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Ascending);

            // Ascending counters wrap from u64::MAX to 0 rather than failing.
            counter.set(txn, u64::MAX).unwrap();
            let v1 = counter.next(txn).unwrap();
            let v2 = counter.next(txn).unwrap();
            assert_eq!(v1.0, u64::MAX);
            assert_eq!(v2.0, 0);
        });
    }

    #[test]
    fn test_encoded_keys_sort_descending() {
        let mut serializer1 = KeySerializer::new();
        serializer1.extend_u64(1u64);
        let key1 = serializer1.finish();

        let mut serializer2 = KeySerializer::new();
        serializer2.extend_u64(2u64);
        let key2 = serializer2.finish();

        assert!(key1 > key2, "encode(1) > encode(2) for descending order");
    }

    #[test]
    fn test_counter_descending_starts_at_max() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);

            let value = counter.next(txn).unwrap();
            assert_eq!(value.0, u64::MAX);
        });
    }

    #[test]
    fn test_counter_descending_decrements() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);

            let v1 = counter.next(txn).unwrap();
            let v2 = counter.next(txn).unwrap();
            let v3 = counter.next(txn).unwrap();

            assert_eq!(v1.0, u64::MAX);
            assert_eq!(v2.0, u64::MAX - 1);
            assert_eq!(v3.0, u64::MAX - 2);
        });
    }

    #[test]
    fn test_counter_descending_wrapping() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Descending);

            // Descending counters wrap from 0 back to u64::MAX rather than failing.
            counter.set(txn, 1).unwrap();
            let v1 = counter.next(txn).unwrap();
            let v2 = counter.next(txn).unwrap();
            assert_eq!(v1.0, 1);
            assert_eq!(v2.0, 0);
            let v3 = counter.next(txn).unwrap();
            assert_eq!(v3.0, u64::MAX);
        });
    }

    #[test]
    fn test_counter_descending_current_and_set() {
        with_flow_txn(|txn| {
            let counter = Counter::with_prefix(FlowNodeId(1), b'D', CounterDirection::Descending);

            // An unwritten descending counter reports u64::MAX without advancing.
            assert_eq!(counter.current(txn).unwrap(), u64::MAX);

            counter.set(txn, 10).unwrap();
            assert_eq!(counter.next(txn).unwrap().0, 10);
            assert_eq!(counter.current(txn).unwrap(), 9);
        });
    }

    #[test]
    fn test_counter_truncated_state_falls_back_to_default() {
        with_flow_txn(|txn| {
            let node = FlowNodeId(1);
            let key = {
                let mut serializer = KeySerializer::new();
                serializer.extend_u8(b'S');
                EncodedKey::new(serializer.finish())
            };

            // Fewer than 8 stored bytes cannot be decoded, so load() treats the entry as unset.
            internal_state_set(node, txn, &key, EncodedRow(CowVec::new(vec![1, 2, 3]))).unwrap();

            let counter = Counter::with_key(node, key, CounterDirection::Ascending);
            assert_eq!(counter.current(txn).unwrap(), 1);
        });
    }
}