Skip to main content

reifydb_sub_flow/operator/stateful/
counter.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::{
5	encoded::{key::EncodedKey, row::EncodedRow},
6	interface::catalog::flow::FlowNodeId,
7	util::encoding::keycode::serializer::KeySerializer,
8};
9use reifydb_type::{Result, util::cowvec::CowVec, value::row_number::RowNumber};
10
11use crate::{
12	operator::stateful::utils::{internal_state_get, internal_state_set},
13	transaction::FlowTransaction,
14};
15
/// Direction in which a counter moves each time a value is handed out.
///
/// The default direction is [`CounterDirection::Ascending`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CounterDirection {
	/// Count upwards: 1, 2, 3, ...
	#[default]
	Ascending,
	/// Count downwards: MAX, MAX-1, MAX-2, ...
	Descending,
}
30
31pub struct Counter {
32	node: FlowNodeId,
33	key: EncodedKey,
34	direction: CounterDirection,
35}
36
37impl Counter {
38	/// Create counter with single-byte prefix key
39	pub fn with_prefix(node: FlowNodeId, prefix: u8, direction: CounterDirection) -> Self {
40		let mut serializer = KeySerializer::new();
41		serializer.extend_u8(prefix);
42		let key = EncodedKey::new(serializer.finish());
43		Self {
44			node,
45			key,
46			direction,
47		}
48	}
49
50	/// Create counter with custom key (e.g., subscription ID)
51	pub fn with_key(node: FlowNodeId, key: EncodedKey, direction: CounterDirection) -> Self {
52		Self {
53			node,
54			key,
55			direction,
56		}
57	}
58
59	/// Get next counter value (atomically: returns current, then increments/decrements)
60	pub fn next(&self, txn: &mut FlowTransaction) -> Result<RowNumber> {
61		let current = self.load(txn)?;
62		let next_value = self.compute_next(current);
63		self.save(txn, next_value)?;
64		Ok(RowNumber(current))
65	}
66
67	/// Get current value without modifying
68	pub fn current(&self, txn: &mut FlowTransaction) -> Result<u64> {
69		self.load(txn)
70	}
71
72	/// Set to specific value
73	pub fn set(&self, txn: &mut FlowTransaction, value: u64) -> Result<()> {
74		self.save(txn, value)
75	}
76
77	// Internal methods
78	fn load(&self, txn: &mut FlowTransaction) -> Result<u64> {
79		match internal_state_get(self.node, txn, &self.key)? {
80			None => Ok(self.default_value()),
81			Some(encoded) => {
82				let bytes = encoded.as_slice();
83				if bytes.len() >= 8 {
84					Ok(u64::from_be_bytes([
85						bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
86						bytes[7],
87					]))
88				} else {
89					Ok(self.default_value())
90				}
91			}
92		}
93	}
94
95	fn save(&self, txn: &mut FlowTransaction, value: u64) -> Result<()> {
96		let bytes = value.to_be_bytes().to_vec();
97		internal_state_set(self.node, txn, &self.key, EncodedRow(CowVec::new(bytes)))?;
98		Ok(())
99	}
100
101	fn default_value(&self) -> u64 {
102		match self.direction {
103			CounterDirection::Ascending => 1,
104			CounterDirection::Descending => u64::MAX,
105		}
106	}
107
108	fn compute_next(&self, current: u64) -> u64 {
109		match self.direction {
110			CounterDirection::Ascending => current.wrapping_add(1),
111			CounterDirection::Descending => current.wrapping_sub(1),
112		}
113	}
114}
115
#[cfg(test)]
mod tests {
	use reifydb_catalog::catalog::Catalog;
	use reifydb_core::common::CommitVersion;
	use reifydb_transaction::interceptor::interceptors::Interceptors;

	use super::*;
	use crate::operator::stateful::test_utils::test::*;

	/// Runs `body` against a new deferred `FlowTransaction` (commit version 1)
	/// built on top of a new test transaction, so every test starts from
	/// its own state.
	fn with_flow_txn(body: impl FnOnce(&mut FlowTransaction)) {
		let mut parent = create_test_transaction();
		let mut txn =
			FlowTransaction::deferred(&mut parent, CommitVersion(1), Catalog::testing(), Interceptors::new());
		body(&mut txn);
	}

	#[test]
	fn test_counter_starts_at_one() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

			let value = counter.next(txn).unwrap();
			assert_eq!(value.0, 1);
		});
	}

	#[test]
	fn test_counter_increments() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

			let v1 = counter.next(txn).unwrap();
			let v2 = counter.next(txn).unwrap();
			let v3 = counter.next(txn).unwrap();

			assert_eq!(v1.0, 1);
			assert_eq!(v2.0, 2);
			assert_eq!(v3.0, 3);
		});
	}

	#[test]
	fn test_counter_persistence() {
		with_flow_txn(|txn| {
			let node = FlowNodeId(1);

			// First counter instance
			{
				let counter = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
				counter.next(txn).unwrap();
				counter.next(txn).unwrap();
			}

			// Second counter instance with same node and prefix
			{
				let counter = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
				let value = counter.next(txn).unwrap();
				// Should continue from where we left off
				assert_eq!(value.0, 3);
			}
		});
	}

	#[test]
	fn test_counter_current() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

			// First call returns default (1)
			let current = counter.current(txn).unwrap();
			assert_eq!(current, 1);

			// After next(), current should reflect the saved value
			counter.next(txn).unwrap();
			let current = counter.current(txn).unwrap();
			assert_eq!(current, 2);

			// current() should not modify the counter
			let current_again = counter.current(txn).unwrap();
			assert_eq!(current_again, 2);
		});
	}

	#[test]
	fn test_counter_set() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

			// Set to a specific value
			counter.set(txn, 100).unwrap();

			// Next should return 100 and advance to 101
			let value = counter.next(txn).unwrap();
			assert_eq!(value.0, 100);

			let value = counter.next(txn).unwrap();
			assert_eq!(value.0, 101);
		});
	}

	#[test]
	fn test_counter_with_custom_key() {
		with_flow_txn(|txn| {
			// Create a custom key
			let custom_key = {
				let mut serializer = KeySerializer::new();
				serializer.extend_bytes(b"subscription-id-123");
				EncodedKey::new(serializer.finish())
			};

			let counter = Counter::with_key(FlowNodeId(1), custom_key, CounterDirection::Ascending);

			let v1 = counter.next(txn).unwrap();
			let v2 = counter.next(txn).unwrap();

			assert_eq!(v1.0, 1);
			assert_eq!(v2.0, 2);
		});
	}

	#[test]
	fn test_multiple_counters_isolated() {
		with_flow_txn(|txn| {
			let node = FlowNodeId(1);

			// Different prefixes should be isolated
			let counter1 = Counter::with_prefix(node, b'A', CounterDirection::Ascending);
			let counter2 = Counter::with_prefix(node, b'B', CounterDirection::Ascending);

			let v1a = counter1.next(txn).unwrap();
			let v2a = counter2.next(txn).unwrap();
			let v1b = counter1.next(txn).unwrap();
			let v2b = counter2.next(txn).unwrap();

			// Each counter should maintain its own sequence
			assert_eq!(v1a.0, 1);
			assert_eq!(v2a.0, 1);
			assert_eq!(v1b.0, 2);
			assert_eq!(v2b.0, 2);
		});
	}

	#[test]
	fn test_different_nodes_isolated() {
		with_flow_txn(|txn| {
			// Same prefix, different nodes should be isolated
			let counter1 = Counter::with_prefix(FlowNodeId(1), b'X', CounterDirection::Ascending);
			let counter2 = Counter::with_prefix(FlowNodeId(2), b'X', CounterDirection::Ascending);

			let v1 = counter1.next(txn).unwrap();
			let v2 = counter2.next(txn).unwrap();

			// Each node should have its own counter
			assert_eq!(v1.0, 1);
			assert_eq!(v2.0, 1);
		});
	}

	#[test]
	fn test_wrapping_behavior() {
		with_flow_txn(|txn| {
			// Test wrapping from MAX to 0
			let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Ascending);
			counter.set(txn, u64::MAX).unwrap();
			let v1 = counter.next(txn).unwrap();
			let v2 = counter.next(txn).unwrap();
			assert_eq!(v1.0, u64::MAX);
			assert_eq!(v2.0, 0); // Wraps to 0
		});
	}

	#[test]
	fn test_encoded_keys_sort_descending() {
		// Verify that when counter values are encoded as keys,
		// they sort in descending order
		let mut serializer1 = KeySerializer::new();
		serializer1.extend_u64(1u64);
		let key1 = serializer1.finish();

		let mut serializer2 = KeySerializer::new();
		serializer2.extend_u64(2u64);
		let key2 = serializer2.finish();

		// Key from value 1 should be > key from value 2
		// (descending order in key space)
		assert!(key1 > key2, "encode(1) > encode(2) for descending order");
	}

	#[test]
	fn test_counter_descending_starts_at_max() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);

			let value = counter.next(txn).unwrap();
			assert_eq!(value.0, u64::MAX);
		});
	}

	#[test]
	fn test_counter_descending_decrements() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);

			let v1 = counter.next(txn).unwrap();
			let v2 = counter.next(txn).unwrap();
			let v3 = counter.next(txn).unwrap();

			assert_eq!(v1.0, u64::MAX);
			assert_eq!(v2.0, u64::MAX - 1);
			assert_eq!(v3.0, u64::MAX - 2);
		});
	}

	#[test]
	fn test_counter_descending_wrapping() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Descending);

			// Starting from 1: next() yields 1, then 0 (an ordinary decrement),
			// and only the step after 0 wraps around to MAX
			counter.set(txn, 1).unwrap();
			let v1 = counter.next(txn).unwrap();
			let v2 = counter.next(txn).unwrap();
			assert_eq!(v1.0, 1);
			assert_eq!(v2.0, 0);
			let v3 = counter.next(txn).unwrap();
			assert_eq!(v3.0, u64::MAX); // Wraps from 0 to MAX
		});
	}
}