// reifydb_sub_flow/transaction/state.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::{
5	encoded::{
6		encoded::EncodedValues,
7		key::{EncodedKey, EncodedKeyRange},
8		schema::Schema,
9	},
10	interface::{catalog::flow::FlowNodeId, store::MultiVersionBatch},
11	key::{EncodableKey, flow_node_state::FlowNodeStateKey},
12};
13use reifydb_type::Result;
14use tracing::{Span, field, instrument};
15
16use super::FlowTransaction;
17
18impl FlowTransaction {
19	/// Get state for a specific flow node and key
20	#[instrument(name = "flow::state::get", level = "trace", skip(self), fields(
21		node_id = id.0,
22		key_len = key.as_bytes().len(),
23		found = field::Empty
24	))]
25	pub fn state_get(&mut self, id: FlowNodeId, key: &EncodedKey) -> Result<Option<EncodedValues>> {
26		let state_key = FlowNodeStateKey::new(id, key.as_ref().to_vec());
27		let encoded_key = state_key.encode();
28		let result = self.get(&encoded_key)?;
29		Span::current().record("found", result.is_some());
30		Ok(result)
31	}
32
33	/// Set state for a specific flow node and key
34	#[instrument(name = "flow::state::set", level = "trace", skip(self, value), fields(
35		node_id = id.0,
36		key_len = key.as_bytes().len(),
37		value_len = value.as_ref().len()
38	))]
39	pub fn state_set(&mut self, id: FlowNodeId, key: &EncodedKey, value: EncodedValues) -> Result<()> {
40		let state_key = FlowNodeStateKey::new(id, key.as_ref().to_vec());
41		let encoded_key = state_key.encode();
42		self.set(&encoded_key, value)
43	}
44
45	/// Remove state for a specific flow node and key
46	#[instrument(name = "flow::state::remove", level = "trace", skip(self), fields(
47		node_id = id.0,
48		key_len = key.as_bytes().len()
49	))]
50	pub fn state_remove(&mut self, id: FlowNodeId, key: &EncodedKey) -> Result<()> {
51		let state_key = FlowNodeStateKey::new(id, key.as_ref().to_vec());
52		let encoded_key = state_key.encode();
53		self.remove(&encoded_key)
54	}
55
56	/// Scan all state for a specific flow node
57	#[instrument(name = "flow::state::scan", level = "debug", skip(self), fields(
58		node_id = id.0,
59		result_count = field::Empty
60	))]
61	pub fn state_scan(&mut self, id: FlowNodeId) -> Result<MultiVersionBatch> {
62		let range = FlowNodeStateKey::node_range(id);
63		let mut iter = self.range(range, 1024);
64		let mut items = Vec::new();
65		while let Some(result) = iter.next() {
66			items.push(result?);
67		}
68		Span::current().record("result_count", items.len());
69		Ok(MultiVersionBatch {
70			items,
71			has_more: false,
72		})
73	}
74
75	/// Range query on state for a specific flow node
76	#[instrument(name = "flow::state::range", level = "debug", skip(self, range), fields(
77		node_id = id.0
78	))]
79	pub fn state_range(&mut self, id: FlowNodeId, range: EncodedKeyRange) -> Result<MultiVersionBatch> {
80		let prefixed_range = range.with_prefix(FlowNodeStateKey::encoded(id, vec![]));
81		let mut iter = self.range(prefixed_range, 1024);
82		let mut items = Vec::new();
83		while let Some(result) = iter.next() {
84			items.push(result?);
85		}
86		Ok(MultiVersionBatch {
87			items,
88			has_more: false,
89		})
90	}
91
92	/// Clear all state for a specific flow node
93	#[instrument(name = "flow::state::clear", level = "trace", skip(self), fields(
94		node_id = id.0,
95		keys_removed = field::Empty
96	))]
97	pub fn state_clear(&mut self, id: FlowNodeId) -> Result<()> {
98		// Phase 1: Scan to collect all keys
99		let keys_to_remove = self.scan_keys_for_clear(id)?;
100
101		// Phase 2: Remove all collected keys
102		let count = keys_to_remove.len();
103		self.remove_keys(keys_to_remove)?;
104
105		Span::current().record("keys_removed", count);
106		Ok(())
107	}
108
109	/// Scan and collect all keys for a node (used by state_clear)
110	#[inline]
111	#[instrument(name = "flow::state::clear::scan", level = "trace", skip(self), fields(node_id = id.0))]
112	fn scan_keys_for_clear(&mut self, id: FlowNodeId) -> Result<Vec<EncodedKey>> {
113		let range = FlowNodeStateKey::node_range(id);
114		let mut iter = self.range(range, 1024);
115		let mut keys = Vec::new();
116		while let Some(result) = iter.next() {
117			let multi = result?;
118			keys.push(multi.key);
119		}
120		Ok(keys)
121	}
122
123	/// Remove a list of keys (used by state_clear)
124	#[inline]
125	#[instrument(name = "flow::state::clear::remove", level = "trace", skip(self, keys), fields(count = keys.len()))]
126	fn remove_keys(&mut self, keys: Vec<EncodedKey>) -> Result<()> {
127		for key in keys {
128			self.remove(&key)?;
129		}
130		Ok(())
131	}
132
133	/// Load state for a key, creating if not exists
134	#[instrument(name = "flow::state::load_or_create", level = "debug", skip(self, schema), fields(
135		node_id = id.0,
136		key_len = key.as_bytes().len(),
137		created
138	))]
139	pub fn load_or_create_row(
140		&mut self,
141		id: FlowNodeId,
142		key: &EncodedKey,
143		schema: &Schema,
144	) -> Result<EncodedValues> {
145		match self.state_get(id, key)? {
146			Some(row) => {
147				Span::current().record("created", false);
148				Ok(row)
149			}
150			None => {
151				Span::current().record("created", true);
152				Ok(schema.allocate())
153			}
154		}
155	}
156
157	/// Save state encoded
158	#[instrument(name = "flow::state::save", level = "trace", skip(self, row), fields(
159		node_id = id.0,
160		key_len = key.as_bytes().len()
161	))]
162	pub fn save_row(&mut self, id: FlowNodeId, key: &EncodedKey, row: EncodedValues) -> Result<()> {
163		self.state_set(id, key, row)
164	}
165}
166
#[cfg(test)]
pub mod tests {
	use std::collections::Bound;

	use reifydb_catalog::catalog::Catalog;
	use reifydb_core::{
		common::CommitVersion,
		encoded::{
			encoded::EncodedValues,
			key::{EncodedKey, EncodedKeyRange},
			schema::Schema,
		},
		interface::catalog::flow::FlowNodeId,
	};
	use reifydb_transaction::interceptor::interceptors::Interceptors;
	use reifydb_type::{util::cowvec::CowVec, value::r#type::Type};

	use super::*;
	use crate::operator::stateful::test_utils::test::create_test_transaction;

	/// Builds the deferred flow transaction every test runs against, on top of
	/// the given parent transaction.
	macro_rules! deferred_txn {
		($parent:expr) => {
			FlowTransaction::deferred(&$parent, CommitVersion(1), Catalog::testing(), Interceptors::new())
		};
	}

	/// Wraps the UTF-8 bytes of `text` in an [`EncodedKey`].
	fn make_key(text: &str) -> EncodedKey {
		let bytes = Vec::from(text.as_bytes());
		EncodedKey::new(bytes)
	}

	/// Wraps the UTF-8 bytes of `text` in an [`EncodedValues`].
	fn make_value(text: &str) -> EncodedValues {
		let bytes = Vec::from(text.as_bytes());
		EncodedValues(CowVec::new(bytes))
	}

	/// A value written with `state_set` can be read back with `state_get`.
	#[test]
	fn test_state_get_set() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let node = FlowNodeId(1);
		let state_key = make_key("state_key");
		let stored = make_value("state_value");

		txn.state_set(node, &state_key, stored.clone()).unwrap();

		let fetched = txn.state_get(node, &state_key).unwrap();
		assert_eq!(fetched, Some(stored));
	}

	/// Reading a key that was never written yields `None`.
	#[test]
	fn test_state_get_nonexistent() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let fetched = txn.state_get(FlowNodeId(1), &make_key("missing")).unwrap();
		assert_eq!(fetched, None);
	}

	/// A removed key is no longer readable.
	#[test]
	fn test_state_remove() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let node = FlowNodeId(1);
		let state_key = make_key("state_key");
		let stored = make_value("state_value");

		// Write, and confirm the write is visible before removing it.
		txn.state_set(node, &state_key, stored.clone()).unwrap();
		assert_eq!(txn.state_get(node, &state_key).unwrap(), Some(stored));

		txn.state_remove(node, &state_key).unwrap();
		assert_eq!(txn.state_get(node, &state_key).unwrap(), None);
	}

	/// Two nodes using the same key do not see each other's values.
	#[test]
	fn test_state_isolation_between_nodes() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let (first, second) = (FlowNodeId(1), FlowNodeId(2));
		let shared_key = make_key("same_key");

		txn.state_set(first, &shared_key, make_value("node1_value")).unwrap();
		txn.state_set(second, &shared_key, make_value("node2_value")).unwrap();

		assert_eq!(txn.state_get(first, &shared_key).unwrap(), Some(make_value("node1_value")));
		assert_eq!(txn.state_get(second, &shared_key).unwrap(), Some(make_value("node2_value")));
	}

	/// Scanning a node returns every entry written for it.
	#[test]
	fn test_state_scan() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let node = FlowNodeId(1);
		for &(name, payload) in &[("key1", "value1"), ("key2", "value2"), ("key3", "value3")] {
			txn.state_set(node, &make_key(name), make_value(payload)).unwrap();
		}

		let batch = txn.state_scan(node).unwrap();
		assert_eq!(batch.items.len(), 3);
	}

	/// Scanning one node never returns another node's entries.
	#[test]
	fn test_state_scan_only_own_node() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let (first, second) = (FlowNodeId(1), FlowNodeId(2));

		txn.state_set(first, &make_key("key1"), make_value("value1")).unwrap();
		txn.state_set(first, &make_key("key2"), make_value("value2")).unwrap();
		txn.state_set(second, &make_key("key3"), make_value("value3")).unwrap();

		// The first node owns two entries ...
		let first_batch = txn.state_scan(first).unwrap();
		assert_eq!(first_batch.items.len(), 2);

		// ... and the second node owns one.
		let second_batch = txn.state_scan(second).unwrap();
		assert_eq!(second_batch.items.len(), 1);
	}

	/// Scanning a node without state yields an empty batch.
	#[test]
	fn test_state_scan_empty() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let batch = txn.state_scan(FlowNodeId(1)).unwrap();
		assert!(batch.items.is_empty());
	}

	/// A bounded range query returns only the entries inside the bounds.
	#[test]
	fn test_state_range() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let node = FlowNodeId(1);
		for &(name, payload) in &[("a", "1"), ("b", "2"), ("c", "3"), ("d", "4")] {
			txn.state_set(node, &make_key(name), make_value(payload)).unwrap();
		}

		// ["b", "d"): includes "b" and "c", excludes "d".
		let lower = Bound::Included(make_key("b"));
		let upper = Bound::Excluded(make_key("d"));
		let batch = txn.state_range(node, EncodedKeyRange::new(lower, upper)).unwrap();

		assert_eq!(batch.items.len(), 2);
	}

	/// Clearing a node removes all of its entries.
	#[test]
	fn test_state_clear() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let node = FlowNodeId(1);
		for &(name, payload) in &[("key1", "value1"), ("key2", "value2"), ("key3", "value3")] {
			txn.state_set(node, &make_key(name), make_value(payload)).unwrap();
		}

		// Sanity check: the state is there before clearing.
		assert_eq!(txn.state_scan(node).unwrap().items.len(), 3);

		txn.state_clear(node).unwrap();

		assert_eq!(txn.state_scan(node).unwrap().items.len(), 0);
	}

	/// Clearing one node leaves other nodes' entries untouched.
	#[test]
	fn test_state_clear_only_own_node() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let (first, second) = (FlowNodeId(1), FlowNodeId(2));

		txn.state_set(first, &make_key("key1"), make_value("value1")).unwrap();
		txn.state_set(first, &make_key("key2"), make_value("value2")).unwrap();
		txn.state_set(second, &make_key("key3"), make_value("value3")).unwrap();

		txn.state_clear(first).unwrap();

		// The cleared node is empty; the other one keeps its single entry.
		assert_eq!(txn.state_scan(first).unwrap().items.len(), 0);
		assert_eq!(txn.state_scan(second).unwrap().items.len(), 1);
	}

	/// Clearing a node that has no state succeeds.
	#[test]
	fn test_state_clear_empty_node() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		txn.state_clear(FlowNodeId(1)).unwrap();
	}

	/// `load_or_create_row` returns the stored row when one exists.
	#[test]
	fn test_load_or_create_existing() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let node = FlowNodeId(1);
		let row_key = make_key("key1");
		let existing = make_value("existing");
		let schema = Schema::testing(&[Type::Int8, Type::Float8]);

		txn.state_set(node, &row_key, existing.clone()).unwrap();

		let loaded = txn.load_or_create_row(node, &row_key, &schema).unwrap();
		assert_eq!(loaded, existing);
	}

	/// `load_or_create_row` allocates a row from the schema when none is stored.
	#[test]
	fn test_load_or_create_new() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let schema = Schema::testing(&[Type::Int8, Type::Float8]);

		let allocated = txn.load_or_create_row(FlowNodeId(1), &make_key("key1"), &schema).unwrap();

		// A row allocated from the schema is expected to be non-empty.
		assert!(!allocated.as_ref().is_empty());
	}

	/// A row written with `save_row` is readable through `state_get`.
	#[test]
	fn test_save_row() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let node = FlowNodeId(1);
		let row_key = make_key("key1");
		let row = make_value("row_data");

		txn.save_row(node, &row_key, row.clone()).unwrap();

		assert_eq!(txn.state_get(node, &row_key).unwrap(), Some(row));
	}

	/// Several nodes can hold independent state, including overlapping keys.
	#[test]
	fn test_state_multiple_nodes() {
		let parent = create_test_transaction();
		let mut txn = deferred_txn!(parent);

		let (n1, n2, n3) = (FlowNodeId(1), FlowNodeId(2), FlowNodeId(3));

		txn.state_set(n1, &make_key("a"), make_value("n1_a")).unwrap();
		txn.state_set(n1, &make_key("b"), make_value("n1_b")).unwrap();
		txn.state_set(n2, &make_key("a"), make_value("n2_a")).unwrap();
		txn.state_set(n3, &make_key("c"), make_value("n3_c")).unwrap();

		// Every written entry is visible under its own node.
		assert_eq!(txn.state_get(n1, &make_key("a")).unwrap(), Some(make_value("n1_a")));
		assert_eq!(txn.state_get(n1, &make_key("b")).unwrap(), Some(make_value("n1_b")));
		assert_eq!(txn.state_get(n2, &make_key("a")).unwrap(), Some(make_value("n2_a")));
		assert_eq!(txn.state_get(n3, &make_key("c")).unwrap(), Some(make_value("n3_c")));

		// Keys written for one node are absent from the others.
		assert_eq!(txn.state_get(n2, &make_key("b")).unwrap(), None);
		assert_eq!(txn.state_get(n3, &make_key("a")).unwrap(), None);
	}
}