1use std::iter::once;
4
5use reifydb_core::{
6 EncodedKey,
7 interface::FlowNodeId,
8 key::{EncodableKey, FlowNodeInternalStateKey},
9 util::{CowVec, encoding::keycode::KeySerializer},
10 value::encoded::{EncodedKeyRange, EncodedValues},
11};
12use reifydb_type::RowNumber;
13
14use crate::{
15 operator::stateful::utils::{internal_state_get, internal_state_set},
16 transaction::FlowTransaction,
17};
18
19pub struct RowNumberProvider {
29 node: FlowNodeId,
30}
31
32impl RowNumberProvider {
33 pub fn new(node: FlowNodeId) -> Self {
35 Self {
36 node,
37 }
38 }
39
40 pub async fn get_or_create_row_numbers_batch<'a, I>(
44 &self,
45 txn: &mut FlowTransaction,
46 keys: I,
47 ) -> crate::Result<Vec<(RowNumber, bool)>>
48 where
49 I: IntoIterator<Item = &'a EncodedKey>,
50 {
51 let mut results = Vec::new();
52 let mut counter = self.load_counter(txn).await?;
53 let initial_counter = counter;
54
55 for key in keys {
56 let map_key = self.make_map_key(key);
57
58 if let Some(existing_row) = internal_state_get(self.node, txn, &map_key).await? {
59 let bytes = existing_row.as_ref();
60 if bytes.len() >= 8 {
61 let row_num = u64::from_be_bytes([
62 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
63 bytes[7],
64 ]);
65 results.push((RowNumber(row_num), false));
66 continue;
67 }
68 }
69
70 let new_row_number = RowNumber(counter);
71
72 let row_num_bytes = counter.to_be_bytes().to_vec();
74 internal_state_set(self.node, txn, &map_key, EncodedValues(CowVec::new(row_num_bytes)))?;
75
76 let reverse_key = self.make_reverse_map_key(new_row_number);
78 internal_state_set(
79 self.node,
80 txn,
81 &reverse_key,
82 EncodedValues(CowVec::new(key.as_ref().to_vec())),
83 )?;
84
85 results.push((new_row_number, true));
86 counter += 1;
87 }
88
89 if counter != initial_counter {
91 self.save_counter(txn, counter)?;
92 }
93
94 Ok(results)
95 }
96
97 pub async fn get_or_create_row_number(
101 &self,
102 txn: &mut FlowTransaction,
103 key: &EncodedKey,
104 ) -> crate::Result<(RowNumber, bool)> {
105 Ok(self.get_or_create_row_numbers_batch(txn, once(key)).await?.into_iter().next().unwrap())
106 }
107
108 pub async fn get_key_for_row_number(
110 &self,
111 txn: &mut FlowTransaction,
112 row_number: RowNumber,
113 ) -> crate::Result<Option<EncodedKey>> {
114 let reverse_key = self.make_reverse_map_key(row_number);
115 if let Some(key_bytes) = internal_state_get(self.node, txn, &reverse_key).await? {
116 Ok(Some(EncodedKey::new(key_bytes.as_ref().to_vec())))
117 } else {
118 Ok(None)
119 }
120 }
121
122 async fn load_counter(&self, txn: &mut FlowTransaction) -> crate::Result<u64> {
124 let key = self.make_counter_key();
125 match internal_state_get(self.node, txn, &key).await? {
126 None => Ok(1), Some(state_row) => {
128 let bytes = state_row.as_ref();
130 if bytes.len() >= 8 {
131 Ok(u64::from_be_bytes([
132 bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6],
133 bytes[7],
134 ]))
135 } else {
136 Ok(1)
137 }
138 }
139 }
140 }
141
142 fn save_counter(&self, txn: &mut FlowTransaction, counter: u64) -> crate::Result<()> {
144 let key = self.make_counter_key();
145 let value = EncodedValues(CowVec::new(counter.to_be_bytes().to_vec()));
146 internal_state_set(self.node, txn, &key, value)?;
147 Ok(())
148 }
149
150 fn make_counter_key(&self) -> EncodedKey {
152 let mut serializer = KeySerializer::new();
153 serializer.extend_u8(b'C'); EncodedKey::new(serializer.finish())
155 }
156
157 fn make_map_key(&self, key: &EncodedKey) -> EncodedKey {
159 let mut serializer = KeySerializer::new();
160 serializer.extend_u8(b'M'); serializer.extend_bytes(key.as_ref());
162 EncodedKey::new(serializer.finish())
163 }
164
165 fn make_reverse_map_key(&self, row_number: RowNumber) -> EncodedKey {
167 let mut serializer = KeySerializer::new();
168 serializer.extend_u8(b'R'); serializer.extend_u64(row_number.0);
170 EncodedKey::new(serializer.finish())
171 }
172
173 pub async fn remove_by_prefix(&self, txn: &mut FlowTransaction, key_prefix: &[u8]) -> crate::Result<()> {
176 let mut prefix = Vec::new();
178 let mut serializer = KeySerializer::new();
179 serializer.extend_u8(b'M'); prefix.extend_from_slice(&serializer.finish());
181 prefix.extend_from_slice(key_prefix);
182
183 let state_prefix = FlowNodeInternalStateKey::new(self.node, prefix.clone());
184 let full_range = EncodedKeyRange::prefix(&state_prefix.encode());
185
186 let batch = txn.range(full_range).await?;
187 let keys_to_remove: Vec<_> = batch.items.into_iter().map(|multi| multi.key).collect();
188
189 for key in keys_to_remove {
190 txn.remove(&key)?;
191 }
192
193 Ok(())
194 }
195}
196
#[cfg(test)]
mod tests {
    use reifydb_core::CommitVersion;

    use super::*;
    use crate::operator::stateful::test_utils::test::*;

    /// A fresh provider assigns row number 1 to the first key it sees.
    #[tokio::test]
    async fn test_first_row_number() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        let provider = RowNumberProvider::new(FlowNodeId(1));

        let key = test_key("first");
        let (row_num, is_new) = provider.get_or_create_row_number(&mut txn, &key).await.unwrap();

        assert_eq!((row_num.0, is_new), (1, true));
    }

    /// Asking twice for the same key yields the same number; only the first
    /// call reports it as new.
    #[tokio::test]
    async fn test_duplicate_key_same_row_number() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        let provider = RowNumberProvider::new(FlowNodeId(1));

        let key = test_key("duplicate");

        let (first, first_is_new) = provider.get_or_create_row_number(&mut txn, &key).await.unwrap();
        assert_eq!((first.0, first_is_new), (1, true));

        let (second, second_is_new) = provider.get_or_create_row_number(&mut txn, &key).await.unwrap();
        assert_eq!((second.0, second_is_new), (1, false));

        assert_eq!(first, second);
    }

    /// Distinct keys receive consecutive row numbers starting at 1.
    #[tokio::test]
    async fn test_sequential_row_numbers() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        let provider = RowNumberProvider::new(FlowNodeId(1));

        for expected in 1..=5u64 {
            let key = test_key(&format!("key_{expected}"));
            let (row_num, is_new) = provider.get_or_create_row_number(&mut txn, &key).await.unwrap();

            assert_eq!(row_num.0, expected);
            assert!(is_new);
        }
    }

    /// Interleaving known and unknown keys keeps existing numbers stable and
    /// continues the sequence for new keys.
    #[tokio::test]
    async fn test_mixed_new_and_existing() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        let provider = RowNumberProvider::new(FlowNodeId(1));

        let key1 = test_key("mixed_1");
        let key2 = test_key("mixed_2");
        let key3 = test_key("mixed_3");

        let (rn1, new1) = provider.get_or_create_row_number(&mut txn, &key1).await.unwrap();
        let (rn2, new2) = provider.get_or_create_row_number(&mut txn, &key2).await.unwrap();
        let (rn3, new3) = provider.get_or_create_row_number(&mut txn, &key3).await.unwrap();

        assert_eq!((rn1.0, new1), (1, true));
        assert_eq!((rn2.0, new2), (2, true));
        assert_eq!((rn3.0, new3), (3, true));

        let key4 = test_key("mixed_4");
        let (rn2_again, new2_again) = provider.get_or_create_row_number(&mut txn, &key2).await.unwrap();
        let (rn4, new4) = provider.get_or_create_row_number(&mut txn, &key4).await.unwrap();
        let (rn1_again, new1_again) = provider.get_or_create_row_number(&mut txn, &key1).await.unwrap();

        assert_eq!((rn2_again.0, new2_again), (2, false));
        assert_eq!((rn4.0, new4), (4, true));
        assert_eq!((rn1_again.0, new1_again), (1, false));
    }

    /// Providers bound to different flow nodes number the same keys
    /// independently of each other.
    #[tokio::test]
    async fn test_multiple_providers_isolated() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        let provider1 = RowNumberProvider::new(FlowNodeId(1));
        let provider2 = RowNumberProvider::new(FlowNodeId(2));

        let shared = test_key("shared_key");

        let (rn1, _) = provider1.get_or_create_row_number(&mut txn, &shared).await.unwrap();
        let (rn2, _) = provider2.get_or_create_row_number(&mut txn, &shared).await.unwrap();

        assert_eq!(rn1.0, 1);
        assert_eq!(rn2.0, 1);

        let other = test_key("key2");

        let (rn1_2, _) = provider1.get_or_create_row_number(&mut txn, &other).await.unwrap();
        assert_eq!(rn1_2.0, 2);

        let (rn2_2, _) = provider2.get_or_create_row_number(&mut txn, &other).await.unwrap();
        assert_eq!(rn2_2.0, 2);
    }

    /// The counter carries on from where earlier calls left off.
    #[tokio::test]
    async fn test_counter_persistence() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        // Constructed as in the original test but never read afterwards.
        let _operator = TestOperator::simple(FlowNodeId(1));
        let provider = RowNumberProvider::new(FlowNodeId(1));

        for expected in 1..=3u64 {
            let key = test_key(&format!("persist_{expected}"));
            let (rn, _) = provider.get_or_create_row_number(&mut txn, &key).await.unwrap();
            assert_eq!(rn.0, expected);
        }

        let new_key = test_key("persist_new");
        let (rn, is_new) = provider.get_or_create_row_number(&mut txn, &new_key).await.unwrap();

        assert_eq!(rn.0, 4);
        assert!(is_new);
    }

    /// Numbering stays sequential over a thousand keys, and earlier keys
    /// keep their numbers afterwards.
    #[tokio::test]
    async fn test_large_row_numbers() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        // Constructed as in the original test but never read afterwards.
        let _operator = TestOperator::simple(FlowNodeId(1));
        let provider = RowNumberProvider::new(FlowNodeId(1));

        for expected in 1..=1000u64 {
            let key = test_key(&format!("large_{expected}"));
            let (rn, is_new) = provider.get_or_create_row_number(&mut txn, &key).await.unwrap();
            assert_eq!(rn.0, expected);
            assert!(is_new);
        }

        let known = test_key("large_1");
        let (rn, is_new) = provider.get_or_create_row_number(&mut txn, &known).await.unwrap();
        assert_eq!((rn.0, is_new), (1, false));

        let fresh = test_key("large_1001");
        let (rn, is_new) = provider.get_or_create_row_number(&mut txn, &fresh).await.unwrap();
        assert_eq!((rn.0, is_new), (1001, true));
    }

    /// A batch mixing known and unknown keys returns results in input order,
    /// assigns new numbers only to unknown keys, and records reverse entries.
    #[tokio::test]
    async fn test_batch_mixed_existing_and_new_keys() {
        let mut parent = create_test_transaction().await;
        let mut txn = FlowTransaction::new(&mut parent, CommitVersion(1)).await;
        // Constructed as in the original test but never read afterwards.
        let _operator = TestOperator::simple(FlowNodeId(1));
        let provider = RowNumberProvider::new(FlowNodeId(1));

        let key1 = test_key("batch_key_1");
        let key2 = test_key("batch_key_2");
        let key3 = test_key("batch_key_3");

        // Seed three keys one at a time so they own row numbers 1..=3.
        let (rn1, _) = provider.get_or_create_row_number(&mut txn, &key1).await.unwrap();
        assert_eq!(rn1.0, 1);

        let (rn2, _) = provider.get_or_create_row_number(&mut txn, &key2).await.unwrap();
        assert_eq!(rn2.0, 2);

        let (rn3, _) = provider.get_or_create_row_number(&mut txn, &key3).await.unwrap();
        assert_eq!(rn3.0, 3);

        let key4 = test_key("batch_key_4");
        let key5 = test_key("batch_key_5");

        // Known and unknown keys interleaved: 2 (old), 4 (new), 1 (old),
        // 5 (new), 3 (old).
        let results = provider
            .get_or_create_row_numbers_batch(&mut txn, [&key2, &key4, &key1, &key5, &key3])
            .await
            .unwrap();

        assert_eq!(
            results,
            vec![
                (RowNumber(2), false),
                (RowNumber(4), true),
                (RowNumber(1), false),
                (RowNumber(5), true),
                (RowNumber(3), false),
            ]
        );

        // The counter must continue after the numbers assigned by the batch.
        let key6 = test_key("batch_key_6");
        let (rn6, is_new6) = provider.get_or_create_row_number(&mut txn, &key6).await.unwrap();
        assert_eq!((rn6.0, is_new6), (6, true));

        // Keys introduced by the batch are now known.
        let (check_rn4, is_new4) = provider.get_or_create_row_number(&mut txn, &key4).await.unwrap();
        assert_eq!((check_rn4.0, is_new4), (4, false));

        let (check_rn5, is_new5) = provider.get_or_create_row_number(&mut txn, &key5).await.unwrap();
        assert_eq!((check_rn5.0, is_new5), (5, false));

        // Reverse lookups resolve back to the original keys.
        for (row, expected_key) in [(4, key4), (5, key5), (1, key1), (2, key2)] {
            let found = provider.get_key_for_row_number(&mut txn, RowNumber(row)).await.unwrap();
            assert_eq!(found, Some(expected_key));
        }
    }
}