reifydb_sub_flow/transaction/state.rs

1// Copyright (c) reifydb.com 2025
2// This file is licensed under the AGPL-3.0-or-later, see license.md file
3
4use reifydb_core::{
5	EncodedKey, EncodedKeyRange,
6	interface::{FlowNodeId, MultiVersionBatch},
7	key::{EncodableKey, FlowNodeStateKey},
8	value::encoded::{EncodedValues, EncodedValuesLayout},
9};
10use tracing::instrument;
11
12use super::FlowTransaction;
13
14impl FlowTransaction {
15	/// Get state for a specific flow node and key
16	#[instrument(name = "flow::state::get", level = "trace", skip(self), fields(
17		node_id = id.0,
18		key_len = key.as_bytes().len(),
19		found
20	))]
21	pub async fn state_get(&mut self, id: FlowNodeId, key: &EncodedKey) -> crate::Result<Option<EncodedValues>> {
22		self.metrics.increment_state_operations();
23		let state_key = FlowNodeStateKey::new(id, key.as_ref().to_vec());
24		let encoded_key = state_key.encode();
25		let result = self.get(&encoded_key).await?;
26		tracing::Span::current().record("found", result.is_some());
27		Ok(result)
28	}
29
30	/// Set state for a specific flow node and key
31	#[instrument(name = "flow::state::set", level = "trace", skip(self, value), fields(
32		node_id = id.0,
33		key_len = key.as_bytes().len(),
34		value_len = value.as_ref().len()
35	))]
36	pub fn state_set(&mut self, id: FlowNodeId, key: &EncodedKey, value: EncodedValues) -> crate::Result<()> {
37		self.metrics.increment_state_operations();
38		let state_key = FlowNodeStateKey::new(id, key.as_ref().to_vec());
39		let encoded_key = state_key.encode();
40		self.set(&encoded_key, value)
41	}
42
43	/// Remove state for a specific flow node and key
44	#[instrument(name = "flow::state::remove", level = "trace", skip(self), fields(
45		node_id = id.0,
46		key_len = key.as_bytes().len()
47	))]
48	pub fn state_remove(&mut self, id: FlowNodeId, key: &EncodedKey) -> crate::Result<()> {
49		self.metrics.increment_state_operations();
50		let state_key = FlowNodeStateKey::new(id, key.as_ref().to_vec());
51		let encoded_key = state_key.encode();
52		self.remove(&encoded_key)
53	}
54
55	/// Scan all state for a specific flow node
56	#[instrument(name = "flow::state::scan", level = "debug", skip(self), fields(
57		node_id = id.0
58	))]
59	pub async fn state_scan(&mut self, id: FlowNodeId) -> crate::Result<MultiVersionBatch> {
60		self.metrics.increment_state_operations();
61		let range = FlowNodeStateKey::node_range(id);
62		self.range(range).await
63	}
64
65	/// Range query on state for a specific flow node
66	#[instrument(name = "flow::state::range", level = "debug", skip(self, range), fields(
67		node_id = id.0
68	))]
69	pub async fn state_range(
70		&mut self,
71		id: FlowNodeId,
72		range: EncodedKeyRange,
73	) -> crate::Result<MultiVersionBatch> {
74		self.metrics.increment_state_operations();
75		let prefixed_range = range.with_prefix(FlowNodeStateKey::encoded(id, vec![]));
76		self.range(prefixed_range).await
77	}
78
79	/// Clear all state for a specific flow node
80	#[instrument(name = "flow::state::clear", level = "debug", skip(self), fields(
81		node_id = id.0,
82		removed_count
83	))]
84	pub async fn state_clear(&mut self, id: FlowNodeId) -> crate::Result<()> {
85		self.metrics.increment_state_operations();
86		let range = FlowNodeStateKey::node_range(id);
87		let batch = self.range(range).await?;
88		let keys_to_remove: Vec<_> = batch.items.into_iter().map(|multi| multi.key).collect();
89
90		let count = keys_to_remove.len();
91		for key in keys_to_remove {
92			self.remove(&key)?;
93		}
94
95		tracing::Span::current().record("removed_count", count);
96		Ok(())
97	}
98
99	/// Load state for a key, creating if not exists
100	#[instrument(name = "flow::state::load_or_create", level = "debug", skip(self, layout), fields(
101		node_id = id.0,
102		key_len = key.as_bytes().len(),
103		created
104	))]
105	pub async fn load_or_create_row(
106		&mut self,
107		id: FlowNodeId,
108		key: &EncodedKey,
109		layout: &EncodedValuesLayout,
110	) -> crate::Result<EncodedValues> {
111		match self.state_get(id, key).await? {
112			Some(row) => {
113				tracing::Span::current().record("created", false);
114				Ok(row)
115			}
116			None => {
117				tracing::Span::current().record("created", true);
118				Ok(layout.allocate())
119			}
120		}
121	}
122
123	/// Save state encoded
124	#[instrument(name = "flow::state::save", level = "trace", skip(self, row), fields(
125		node_id = id.0,
126		key_len = key.as_bytes().len()
127	))]
128	pub fn save_row(&mut self, id: FlowNodeId, key: &EncodedKey, row: EncodedValues) -> crate::Result<()> {
129		self.state_set(id, key, row)
130	}
131}
132
#[cfg(test)]
mod tests {
	use std::ops::Bound;

	use reifydb_core::{
		CommitVersion, CowVec, EncodedKey, EncodedKeyRange, interface::FlowNodeId,
		value::encoded::EncodedValues,
	};
	use reifydb_type::Type;

	use super::*;
	use crate::operator::stateful::test_utils::test::create_test_transaction;

	/// Build a raw user-level key from a string.
	fn make_key(s: &str) -> EncodedKey {
		EncodedKey::new(s.as_bytes().to_vec())
	}

	/// Build a raw value from a string.
	fn make_value(s: &str) -> EncodedValues {
		EncodedValues(CowVec::new(s.as_bytes().to_vec()))
	}

	#[tokio::test]
	async fn test_state_get_set() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("state_key");
		let value = make_value("state_value");

		// Set state
		txn.state_set(node_id, &key, value.clone()).unwrap();

		// Get state back
		let result = txn.state_get(node_id, &key).await.unwrap();
		assert_eq!(result, Some(value));
	}

	#[tokio::test]
	async fn test_state_get_nonexistent() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("missing");

		let result = txn.state_get(node_id, &key).await.unwrap();
		assert_eq!(result, None);
	}

	#[tokio::test]
	async fn test_state_remove() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("state_key");
		let value = make_value("state_value");

		// Set then remove
		txn.state_set(node_id, &key, value.clone()).unwrap();
		assert_eq!(txn.state_get(node_id, &key).await.unwrap(), Some(value));

		txn.state_remove(node_id, &key).unwrap();
		assert_eq!(txn.state_get(node_id, &key).await.unwrap(), None);
	}

	#[tokio::test]
	async fn test_state_isolation_between_nodes() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node1 = FlowNodeId(1);
		let node2 = FlowNodeId(2);
		let key = make_key("same_key");

		txn.state_set(node1, &key, make_value("node1_value")).unwrap();
		txn.state_set(node2, &key, make_value("node2_value")).unwrap();

		// Each node should have its own value
		assert_eq!(txn.state_get(node1, &key).await.unwrap(), Some(make_value("node1_value")));
		assert_eq!(txn.state_get(node2, &key).await.unwrap(), Some(make_value("node2_value")));
	}

	#[tokio::test]
	async fn test_state_scan() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);

		txn.state_set(node_id, &make_key("key1"), make_value("value1")).unwrap();
		txn.state_set(node_id, &make_key("key2"), make_value("value2")).unwrap();
		txn.state_set(node_id, &make_key("key3"), make_value("value3")).unwrap();

		let iter = txn.state_scan(node_id).await.unwrap();
		let items: Vec<_> = iter.items.into_iter().collect();

		assert_eq!(items.len(), 3);
	}

	#[tokio::test]
	async fn test_state_scan_only_own_node() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node1 = FlowNodeId(1);
		let node2 = FlowNodeId(2);

		txn.state_set(node1, &make_key("key1"), make_value("value1")).unwrap();
		txn.state_set(node1, &make_key("key2"), make_value("value2")).unwrap();
		txn.state_set(node2, &make_key("key3"), make_value("value3")).unwrap();

		// Scan node1 should only return node1's state
		let items: Vec<_> = txn.state_scan(node1).await.unwrap().items.into_iter().collect();
		assert_eq!(items.len(), 2);

		// Scan node2 should only return node2's state
		let items: Vec<_> = txn.state_scan(node2).await.unwrap().items.into_iter().collect();
		assert_eq!(items.len(), 1);
	}

	#[tokio::test]
	async fn test_state_scan_empty() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);

		let iter = txn.state_scan(node_id).await.unwrap();
		assert!(iter.items.into_iter().next().is_none());
	}

	#[tokio::test]
	async fn test_state_range() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);

		txn.state_set(node_id, &make_key("a"), make_value("1")).unwrap();
		txn.state_set(node_id, &make_key("b"), make_value("2")).unwrap();
		txn.state_set(node_id, &make_key("c"), make_value("3")).unwrap();
		txn.state_set(node_id, &make_key("d"), make_value("4")).unwrap();

		// Range query from "b" (inclusive) to "d" (exclusive)
		let range = EncodedKeyRange::new(Bound::Included(make_key("b")), Bound::Excluded(make_key("d")));
		let iter = txn.state_range(node_id, range).await.unwrap();
		let items: Vec<_> = iter.items.into_iter().collect();

		// Should only include "b" and "c"
		assert_eq!(items.len(), 2);
	}

	#[tokio::test]
	async fn test_state_clear() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);

		txn.state_set(node_id, &make_key("key1"), make_value("value1")).unwrap();
		txn.state_set(node_id, &make_key("key2"), make_value("value2")).unwrap();
		txn.state_set(node_id, &make_key("key3"), make_value("value3")).unwrap();

		// Verify state exists
		assert_eq!(txn.state_scan(node_id).await.unwrap().items.into_iter().count(), 3);

		// Clear all state
		txn.state_clear(node_id).await.unwrap();

		// Verify state is empty
		assert_eq!(txn.state_scan(node_id).await.unwrap().items.into_iter().count(), 0);
	}

	#[tokio::test]
	async fn test_state_clear_only_own_node() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node1 = FlowNodeId(1);
		let node2 = FlowNodeId(2);

		txn.state_set(node1, &make_key("key1"), make_value("value1")).unwrap();
		txn.state_set(node1, &make_key("key2"), make_value("value2")).unwrap();
		txn.state_set(node2, &make_key("key3"), make_value("value3")).unwrap();

		// Clear node1
		txn.state_clear(node1).await.unwrap();

		// Node1 should be empty
		assert_eq!(txn.state_scan(node1).await.unwrap().items.into_iter().count(), 0);

		// Node2 should still have state
		assert_eq!(txn.state_scan(node2).await.unwrap().items.into_iter().count(), 1);
	}

	#[tokio::test]
	async fn test_state_clear_empty_node() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);

		// Clear on empty node should not error
		txn.state_clear(node_id).await.unwrap();
	}

	#[tokio::test]
	async fn test_load_or_create_existing() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("key1");
		let value = make_value("existing");
		let layout = EncodedValuesLayout::new(&[Type::Int8, Type::Float8]);

		// Set existing state
		txn.state_set(node_id, &key, value.clone()).unwrap();

		// load_or_create should return existing value
		let result = txn.load_or_create_row(node_id, &key, &layout).await.unwrap();
		assert_eq!(result, value);
	}

	#[tokio::test]
	async fn test_load_or_create_new() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("key1");
		let layout = EncodedValuesLayout::new(&[Type::Int8, Type::Float8]);

		// load_or_create should allocate new row
		let result = txn.load_or_create_row(node_id, &key, &layout).await.unwrap();

		// Result should be a newly allocated row (layout.allocate())
		assert!(!result.as_ref().is_empty());

		// The freshly allocated row is only returned, not persisted: nothing is
		// stored until the caller saves it explicitly.
		assert_eq!(txn.state_get(node_id, &key).await.unwrap(), None);
	}

	#[tokio::test]
	async fn test_save_row() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("key1");
		let row = make_value("row_data");

		txn.save_row(node_id, &key, row.clone()).unwrap();

		// Verify saved
		let result = txn.state_get(node_id, &key).await.unwrap();
		assert_eq!(result, Some(row));
	}

	#[tokio::test]
	async fn test_state_operations_increment_metrics() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("key1");

		assert_eq!(txn.metrics().state_operations, 0);

		txn.state_set(node_id, &key, make_value("value")).unwrap();
		assert_eq!(txn.metrics().state_operations, 1);

		txn.state_get(node_id, &key).await.unwrap();
		assert_eq!(txn.metrics().state_operations, 2);

		txn.state_remove(node_id, &key).unwrap();
		assert_eq!(txn.metrics().state_operations, 3);

		let _ = txn.state_scan(node_id).await.unwrap();
		assert_eq!(txn.metrics().state_operations, 4);

		let range = EncodedKeyRange::start_end(Some(make_key("a")), Some(make_key("z")));
		let _ = txn.state_range(node_id, range).await.unwrap();
		assert_eq!(txn.metrics().state_operations, 5);

		txn.state_clear(node_id).await.unwrap();
		// state_clear counts as exactly one state operation: it reads through the
		// underlying `range` directly (not via state_scan), and the raw
		// range/remove calls do not touch the state-operation counter (see the
		// steps above, each of which adds exactly one).
		assert_eq!(txn.metrics().state_operations, 6);
	}

	#[tokio::test]
	async fn test_state_multiple_nodes() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node1 = FlowNodeId(1);
		let node2 = FlowNodeId(2);
		let node3 = FlowNodeId(3);

		txn.state_set(node1, &make_key("a"), make_value("n1_a")).unwrap();
		txn.state_set(node1, &make_key("b"), make_value("n1_b")).unwrap();
		txn.state_set(node2, &make_key("a"), make_value("n2_a")).unwrap();
		txn.state_set(node3, &make_key("c"), make_value("n3_c")).unwrap();

		// Verify each node has correct state
		assert_eq!(txn.state_get(node1, &make_key("a")).await.unwrap(), Some(make_value("n1_a")));
		assert_eq!(txn.state_get(node1, &make_key("b")).await.unwrap(), Some(make_value("n1_b")));
		assert_eq!(txn.state_get(node2, &make_key("a")).await.unwrap(), Some(make_value("n2_a")));
		assert_eq!(txn.state_get(node3, &make_key("c")).await.unwrap(), Some(make_value("n3_c")));

		// Cross-node keys should not exist
		assert_eq!(txn.state_get(node2, &make_key("b")).await.unwrap(), None);
		assert_eq!(txn.state_get(node3, &make_key("a")).await.unwrap(), None);
	}

	#[tokio::test]
	async fn test_load_or_create_increments_state_operations() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("key1");
		let layout = EncodedValuesLayout::new(&[Type::Int8]);

		let initial_count = txn.metrics().state_operations;

		txn.load_or_create_row(node_id, &key, &layout).await.unwrap();

		// load_or_create calls state_get internally
		assert!(txn.metrics().state_operations > initial_count);
	}

	#[tokio::test]
	async fn test_save_row_increments_state_operations() {
		let parent = create_test_transaction().await;
		let mut txn = FlowTransaction::new(&parent, CommitVersion(1)).await;

		let node_id = FlowNodeId(1);
		let key = make_key("key1");

		let initial_count = txn.metrics().state_operations;

		txn.save_row(node_id, &key, make_value("data")).unwrap();

		// save_row calls state_set internally
		assert!(txn.metrics().state_operations > initial_count);
	}
}