Skip to main content

reifydb_sub_flow/operator/stateful/
counter.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::{
5	encoded::{key::EncodedKey, row::EncodedRow},
6	interface::catalog::flow::FlowNodeId,
7	util::encoding::keycode::serializer::KeySerializer,
8};
9use reifydb_type::{Result, util::cowvec::CowVec, value::row_number::RowNumber};
10
11use crate::{
12	operator::stateful::utils::{internal_state_get, internal_state_set},
13	transaction::FlowTransaction,
14};
15
/// Which way a counter moves each time it hands out a value.
///
/// The default is [`CounterDirection::Ascending`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum CounterDirection {
	/// Increase on every step: 1, 2, 3, ...
	#[default]
	Ascending,
	/// Decrease on every step: MAX, MAX-1, MAX-2, ...
	Descending,
}
25
26pub struct Counter {
27	node: FlowNodeId,
28	key: EncodedKey,
29	direction: CounterDirection,
30}
31
32impl Counter {
33	/// Create counter with single-byte prefix key
34	pub fn with_prefix(node: FlowNodeId, prefix: u8, direction: CounterDirection) -> Self {
35		let mut serializer = KeySerializer::new();
36		serializer.extend_u8(prefix);
37		let key = EncodedKey::new(serializer.finish());
38		Self {
39			node,
40			key,
41			direction,
42		}
43	}
44
45	/// Create counter with custom key (e.g., subscription ID)
46	pub fn with_key(node: FlowNodeId, key: EncodedKey, direction: CounterDirection) -> Self {
47		Self {
48			node,
49			key,
50			direction,
51		}
52	}
53
54	/// Get next counter value (atomically: returns current, then increments/decrements)
55	pub fn next(&self, txn: &mut FlowTransaction) -> Result<RowNumber> {
56		let current = self.load(txn)?;
57		let next_value = self.compute_next(current);
58		self.save(txn, next_value)?;
59		Ok(RowNumber(current))
60	}
61
62	/// Get current value without modifying
63	pub fn current(&self, txn: &mut FlowTransaction) -> Result<u64> {
64		self.load(txn)
65	}
66
67	/// Set to specific value
68	pub fn set(&self, txn: &mut FlowTransaction, value: u64) -> Result<()> {
69		self.save(txn, value)
70	}
71
72	// Internal methods
73	fn load(&self, txn: &mut FlowTransaction) -> Result<u64> {
74		match internal_state_get(self.node, txn, &self.key)? {
75			None => Ok(self.default_value()),
76			Some(encoded) => {
77				let bytes = encoded.as_slice();
78				if bytes.len() >= 8 {
79					Ok(u64::from_be_bytes([
80						bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
81						bytes[7],
82					]))
83				} else {
84					Ok(self.default_value())
85				}
86			}
87		}
88	}
89
90	fn save(&self, txn: &mut FlowTransaction, value: u64) -> Result<()> {
91		let bytes = value.to_be_bytes().to_vec();
92		internal_state_set(self.node, txn, &self.key, EncodedRow(CowVec::new(bytes)))?;
93		Ok(())
94	}
95
96	fn default_value(&self) -> u64 {
97		match self.direction {
98			CounterDirection::Ascending => 1,
99			CounterDirection::Descending => u64::MAX,
100		}
101	}
102
103	fn compute_next(&self, current: u64) -> u64 {
104		match self.direction {
105			CounterDirection::Ascending => current.wrapping_add(1),
106			CounterDirection::Descending => current.wrapping_sub(1),
107		}
108	}
109}
110
#[cfg(test)]
mod tests {
	use reifydb_catalog::catalog::Catalog;
	use reifydb_core::common::CommitVersion;
	use reifydb_runtime::context::clock::{Clock, MockClock};
	use reifydb_transaction::interceptor::interceptors::Interceptors;

	use super::*;
	use crate::operator::stateful::test_utils::test::*;

	/// Runs `body` against a deferred flow transaction built on a fresh
	/// `create_test_transaction()`, pinned at commit version 1 with a mock clock
	/// fixed at 1000 ms.
	fn with_txn(body: impl FnOnce(&mut FlowTransaction)) {
		let mut parent = create_test_transaction();
		let mut txn = FlowTransaction::deferred(
			&mut parent,
			CommitVersion(1),
			Catalog::testing(),
			Interceptors::new(),
			Clock::Mock(MockClock::from_millis(1000)),
		);
		body(&mut txn);
	}

	/// Draws `count` successive values from `counter`, panicking on any error.
	fn draw(counter: &Counter, txn: &mut FlowTransaction, count: usize) -> Vec<u64> {
		(0..count).map(|_| counter.next(txn).unwrap().0).collect()
	}

	#[test]
	fn test_counter_starts_at_one() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);
			assert_eq!(counter.next(txn).unwrap().0, 1);
		});
	}

	#[test]
	fn test_counter_increments() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);
			assert_eq!(draw(&counter, txn, 3), [1, 2, 3]);
		});
	}

	#[test]
	fn test_counter_persistence() {
		with_txn(|txn| {
			let node = FlowNodeId(1);
			{
				let first = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
				draw(&first, txn, 2);
			}
			// A fresh instance on the same node and prefix resumes from the stored state.
			let second = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
			assert_eq!(second.next(txn).unwrap().0, 3);
		});
	}

	#[test]
	fn test_counter_current() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);
			// Nothing stored yet, so the ascending starting value is reported.
			assert_eq!(counter.current(txn).unwrap(), 1);
			// next() hands out 1 and stores 2.
			counter.next(txn).unwrap();
			assert_eq!(counter.current(txn).unwrap(), 2);
			// Reading a second time must not advance the counter.
			assert_eq!(counter.current(txn).unwrap(), 2);
		});
	}

	#[test]
	fn test_counter_set() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);
			counter.set(txn, 100).unwrap();
			// The value that was set is handed out first, then the counter moves on.
			assert_eq!(draw(&counter, txn, 2), [100, 101]);
		});
	}

	#[test]
	fn test_counter_with_custom_key() {
		with_txn(|txn| {
			let mut serializer = KeySerializer::new();
			serializer.extend_bytes(b"subscription-id-123");
			let key = EncodedKey::new(serializer.finish());

			let counter = Counter::with_key(FlowNodeId(1), key, CounterDirection::Ascending);
			assert_eq!(draw(&counter, txn, 2), [1, 2]);
		});
	}

	#[test]
	fn test_multiple_counters_isolated() {
		with_txn(|txn| {
			let node = FlowNodeId(1);
			let a = Counter::with_prefix(node, b'A', CounterDirection::Ascending);
			let b = Counter::with_prefix(node, b'B', CounterDirection::Ascending);

			// Interleave calls. Each prefix must keep its own independent sequence.
			let seen: Vec<_> = [&a, &b, &a, &b].iter().map(|c| c.next(txn).unwrap().0).collect();
			assert_eq!(seen, [1, 1, 2, 2]);
		});
	}

	#[test]
	fn test_different_nodes_isolated() {
		with_txn(|txn| {
			// Same prefix on different nodes must not share state.
			let first = Counter::with_prefix(FlowNodeId(1), b'X', CounterDirection::Ascending);
			let second = Counter::with_prefix(FlowNodeId(2), b'X', CounterDirection::Ascending);

			let from_first = first.next(txn).unwrap().0;
			let from_second = second.next(txn).unwrap().0;
			assert_eq!(from_first, 1);
			assert_eq!(from_second, 1);
		});
	}

	#[test]
	fn test_wrapping_behavior() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Ascending);
			counter.set(txn, u64::MAX).unwrap();
			// MAX is handed out, after which the counter wraps to 0.
			assert_eq!(draw(&counter, txn, 2), [u64::MAX, 0]);
		});
	}

	#[test]
	fn test_encoded_keys_sort_descending() {
		// Checks that KeySerializer encodes a larger u64 as a smaller key, i.e.
		// encoded counter values sort in descending order.
		let encode = |value: u64| {
			let mut serializer = KeySerializer::new();
			serializer.extend_u64(value);
			serializer.finish()
		};
		assert!(encode(1) > encode(2), "encode(1) > encode(2) for descending order");
	}

	#[test]
	fn test_counter_descending_starts_at_max() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);
			assert_eq!(counter.next(txn).unwrap().0, u64::MAX);
		});
	}

	#[test]
	fn test_counter_descending_decrements() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);
			assert_eq!(draw(&counter, txn, 3), [u64::MAX, u64::MAX - 1, u64::MAX - 2]);
		});
	}

	#[test]
	fn test_counter_descending_wrapping() {
		with_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Descending);
			counter.set(txn, 1).unwrap();
			// 1 is handed out, then 0, then the counter wraps around to MAX.
			assert_eq!(draw(&counter, txn, 3), [1, 0, u64::MAX]);
		});
	}
}