Skip to main content

reifydb_sub_flow/operator/stateful/
counter.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::{
5	encoded::{encoded::EncodedValues, key::EncodedKey},
6	interface::catalog::flow::FlowNodeId,
7	util::encoding::keycode::serializer::KeySerializer,
8};
9use reifydb_type::{util::cowvec::CowVec, value::row_number::RowNumber};
10
11use crate::{
12	operator::stateful::utils::{internal_state_get, internal_state_set},
13	transaction::FlowTransaction,
14};
15
/// Direction in which a [`Counter`] moves after each value it hands out.
///
/// The default is [`CounterDirection::Ascending`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum CounterDirection {
	/// Count upwards, starting at 1: 1, 2, 3, ...
	#[default]
	Ascending,
	/// Count downwards, starting at `u64::MAX`: MAX, MAX-1, MAX-2, ...
	Descending,
}
30
31pub struct Counter {
32	node: FlowNodeId,
33	key: EncodedKey,
34	direction: CounterDirection,
35}
36
37impl Counter {
38	/// Create counter with single-byte prefix key
39	pub fn with_prefix(node: FlowNodeId, prefix: u8, direction: CounterDirection) -> Self {
40		let mut serializer = KeySerializer::new();
41		serializer.extend_u8(prefix);
42		let key = EncodedKey::new(serializer.finish());
43		Self {
44			node,
45			key,
46			direction,
47		}
48	}
49
50	/// Create counter with custom key (e.g., subscription ID)
51	pub fn with_key(node: FlowNodeId, key: EncodedKey, direction: CounterDirection) -> Self {
52		Self {
53			node,
54			key,
55			direction,
56		}
57	}
58
59	/// Get next counter value (atomically: returns current, then increments/decrements)
60	pub fn next(&self, txn: &mut FlowTransaction) -> reifydb_type::Result<RowNumber> {
61		let current = self.load(txn)?;
62		let next_value = self.compute_next(current);
63		self.save(txn, next_value)?;
64		Ok(RowNumber(current))
65	}
66
67	/// Get current value without modifying
68	pub fn current(&self, txn: &mut FlowTransaction) -> reifydb_type::Result<u64> {
69		self.load(txn)
70	}
71
72	/// Set to specific value
73	pub fn set(&self, txn: &mut FlowTransaction, value: u64) -> reifydb_type::Result<()> {
74		self.save(txn, value)
75	}
76
77	// Internal methods
78	fn load(&self, txn: &mut FlowTransaction) -> reifydb_type::Result<u64> {
79		match internal_state_get(self.node, txn, &self.key)? {
80			None => Ok(self.default_value()),
81			Some(encoded) => {
82				let bytes = encoded.as_ref();
83				if bytes.len() >= 8 {
84					Ok(u64::from_be_bytes([
85						bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
86						bytes[7],
87					]))
88				} else {
89					Ok(self.default_value())
90				}
91			}
92		}
93	}
94
95	fn save(&self, txn: &mut FlowTransaction, value: u64) -> reifydb_type::Result<()> {
96		let bytes = value.to_be_bytes().to_vec();
97		internal_state_set(self.node, txn, &self.key, EncodedValues(CowVec::new(bytes)))?;
98		Ok(())
99	}
100
101	fn default_value(&self) -> u64 {
102		match self.direction {
103			CounterDirection::Ascending => 1,
104			CounterDirection::Descending => u64::MAX,
105		}
106	}
107
108	fn compute_next(&self, current: u64) -> u64 {
109		match self.direction {
110			CounterDirection::Ascending => current.wrapping_add(1),
111			CounterDirection::Descending => current.wrapping_sub(1),
112		}
113	}
114}
115
#[cfg(test)]
mod tests {
	use reifydb_catalog::catalog::Catalog;
	use reifydb_core::common::CommitVersion;
	use reifydb_transaction::interceptor::interceptors::Interceptors;

	use super::*;
	use crate::operator::stateful::test_utils::test::*;

	/// Runs `body` against a `FlowTransaction` at commit version 1.
	///
	/// The `FlowTransaction` wraps a freshly created test transaction. Both
	/// are dropped when `body` returns.
	fn with_flow_txn(body: impl FnOnce(&mut FlowTransaction)) {
		let mut parent = create_test_transaction();
		let mut flow = FlowTransaction::new(&mut parent, CommitVersion(1), Catalog::testing(), Interceptors::new());
		body(&mut flow);
	}

	/// Draws `count` consecutive values from `counter`, unwrapping each result.
	fn draw(counter: &Counter, txn: &mut FlowTransaction, count: usize) -> Vec<u64> {
		(0..count).map(|_| counter.next(txn).unwrap().0).collect()
	}

	#[test]
	fn test_counter_starts_at_one() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);
			let first = counter.next(txn).unwrap();
			assert_eq!(first.0, 1);
		});
	}

	#[test]
	fn test_counter_increments() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);
			assert_eq!(draw(&counter, txn, 3), vec![1, 2, 3]);
		});
	}

	#[test]
	fn test_counter_persistence() {
		with_flow_txn(|txn| {
			let node = FlowNodeId(1);

			// Advance the sequence twice through a first instance, then let it go.
			{
				let writer = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
				writer.next(txn).unwrap();
				writer.next(txn).unwrap();
			}

			// A new instance with the same node and prefix resumes where the first stopped.
			let reader = Counter::with_prefix(node, b'P', CounterDirection::Ascending);
			assert_eq!(reader.next(txn).unwrap().0, 3);
		});
	}

	#[test]
	fn test_counter_current() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

			// Nothing stored yet, so current() reports the ascending default.
			assert_eq!(counter.current(txn).unwrap(), 1);

			// After one next(), the stored successor is visible...
			counter.next(txn).unwrap();
			assert_eq!(counter.current(txn).unwrap(), 2);

			// ...and reading it again does not advance it.
			assert_eq!(counter.current(txn).unwrap(), 2);
		});
	}

	#[test]
	fn test_counter_set() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Ascending);

			// next() hands out the explicitly set value first, then continues from it.
			counter.set(txn, 100).unwrap();
			assert_eq!(draw(&counter, txn, 2), vec![100, 101]);
		});
	}

	#[test]
	fn test_counter_with_custom_key() {
		with_flow_txn(|txn| {
			let mut serializer = KeySerializer::new();
			serializer.extend_bytes(b"subscription-id-123");
			let key = EncodedKey::new(serializer.finish());

			let counter = Counter::with_key(FlowNodeId(1), key, CounterDirection::Ascending);
			assert_eq!(draw(&counter, txn, 2), vec![1, 2]);
		});
	}

	#[test]
	fn test_multiple_counters_isolated() {
		with_flow_txn(|txn| {
			let node = FlowNodeId(1);
			let alpha = Counter::with_prefix(node, b'A', CounterDirection::Ascending);
			let beta = Counter::with_prefix(node, b'B', CounterDirection::Ascending);

			// Interleave the calls so each counter is advanced between the other's reads.
			let alpha_first = alpha.next(txn).unwrap().0;
			let beta_first = beta.next(txn).unwrap().0;
			let alpha_second = alpha.next(txn).unwrap().0;
			let beta_second = beta.next(txn).unwrap().0;

			// Different prefixes on one node keep independent sequences.
			assert_eq!((alpha_first, beta_first), (1, 1));
			assert_eq!((alpha_second, beta_second), (2, 2));
		});
	}

	#[test]
	fn test_different_nodes_isolated() {
		with_flow_txn(|txn| {
			// Same prefix on two different nodes: each node has its own counter.
			let on_first_node = Counter::with_prefix(FlowNodeId(1), b'X', CounterDirection::Ascending);
			let on_second_node = Counter::with_prefix(FlowNodeId(2), b'X', CounterDirection::Ascending);

			let from_first = on_first_node.next(txn).unwrap().0;
			let from_second = on_second_node.next(txn).unwrap().0;

			assert_eq!((from_first, from_second), (1, 1));
		});
	}

	#[test]
	fn test_wrapping_behavior() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Ascending);

			// Ascending arithmetic wraps: MAX is handed out, then the sequence restarts at 0.
			counter.set(txn, u64::MAX).unwrap();
			assert_eq!(draw(&counter, txn, 2), vec![u64::MAX, 0]);
		});
	}

	#[test]
	fn test_encoded_keys_sort_descending() {
		// Checks that KeySerializer's u64 encoding orders keys in reverse
		// numeric order: the key for 1 compares greater than the key for 2.
		let encode = |value: u64| {
			let mut serializer = KeySerializer::new();
			serializer.extend_u64(value);
			serializer.finish()
		};

		let (key1, key2) = (encode(1), encode(2));
		assert!(key1 > key2, "encode(1) > encode(2) for descending order");
	}

	#[test]
	fn test_counter_descending_starts_at_max() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);
			assert_eq!(counter.next(txn).unwrap().0, u64::MAX);
		});
	}

	#[test]
	fn test_counter_descending_decrements() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'T', CounterDirection::Descending);
			assert_eq!(draw(&counter, txn, 3), vec![u64::MAX, u64::MAX - 1, u64::MAX - 2]);
		});
	}

	#[test]
	fn test_counter_descending_wrapping() {
		with_flow_txn(|txn| {
			let counter = Counter::with_prefix(FlowNodeId(1), b'W', CounterDirection::Descending);

			// Start just above zero: the sequence runs 1, 0, then wraps around to MAX.
			counter.set(txn, 1).unwrap();
			assert_eq!(draw(&counter, txn, 3), vec![1, 0, u64::MAX]);
		});
	}
}