Skip to main content

reifydb_sdk/state/
cache.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use std::{collections::HashMap, hash::Hash, mem, sync::Arc};
5
6use reifydb_core::{encoded::key::IntoEncodedKey, util::lru::LruCache};
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{error::Result, operator::context::OperatorContext};
10
/// Selects which operator state store a [`StateCache`] reads from and flushes to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StateBackend {
	/// The operator's data state, accessed through `OperatorContext::state()`.
	Data,

	/// The operator's internal state, accessed through
	/// `OperatorContext::internal_state()`.
	Internal,
}
17
18pub struct StateCache<K, V> {
19	cache: LruCache<K, Arc<V>>,
20	dirty: HashMap<K, Option<Arc<V>>>,
21	backend: StateBackend,
22}
23
24impl<K, V> StateCache<K, V>
25where
26	K: Hash + Eq + Clone,
27	for<'a> &'a K: IntoEncodedKey,
28	V: Clone + Serialize + DeserializeOwned,
29{
30	pub fn new(capacity: usize) -> Self {
31		Self::with_backend(capacity, StateBackend::Data)
32	}
33
34	pub fn new_internal(capacity: usize) -> Self {
35		Self::with_backend(capacity, StateBackend::Internal)
36	}
37
38	fn with_backend(capacity: usize, backend: StateBackend) -> Self {
39		Self {
40			cache: LruCache::new(capacity),
41			dirty: HashMap::new(),
42			backend,
43		}
44	}
45
46	pub fn get_arc(&mut self, ctx: &mut OperatorContext, key: &K) -> Result<Option<Arc<V>>> {
47		if let Some(cached) = self.cache.get(key) {
48			return Ok(Some(cached));
49		}
50
51		if let Some(slot) = self.dirty.get(key) {
52			return Ok(slot.clone());
53		}
54
55		let encoded_key = key.into_encoded_key();
56		let loaded = match self.backend {
57			StateBackend::Data => ctx.state().get::<V>(&encoded_key)?,
58			StateBackend::Internal => ctx.internal_state().get::<V>(&encoded_key)?,
59		};
60		match loaded {
61			Some(value) => {
62				let arc = Arc::new(value);
63				self.cache.put(key.clone(), arc.clone());
64				Ok(Some(arc))
65			}
66			None => Ok(None),
67		}
68	}
69
70	pub fn get(&mut self, ctx: &mut OperatorContext, key: &K) -> Result<Option<V>> {
71		Ok(self.get_arc(ctx, key)?.map(|arc| (*arc).clone()))
72	}
73
74	pub fn set(&mut self, _ctx: &mut OperatorContext, key: &K, value: &V) -> Result<()> {
75		let arc = Arc::new(value.clone());
76		self.cache.put(key.clone(), arc.clone());
77		self.dirty.insert(key.clone(), Some(arc));
78		Ok(())
79	}
80
81	pub fn put(&mut self, _ctx: &mut OperatorContext, key: &K, value: V) -> Result<()> {
82		let arc = Arc::new(value);
83		self.cache.put(key.clone(), arc.clone());
84		self.dirty.insert(key.clone(), Some(arc));
85		Ok(())
86	}
87
88	pub fn put_arc(&mut self, _ctx: &mut OperatorContext, key: &K, value: Arc<V>) -> Result<()> {
89		self.cache.put(key.clone(), value.clone());
90		self.dirty.insert(key.clone(), Some(value));
91		Ok(())
92	}
93
94	pub fn modify<F>(&mut self, ctx: &mut OperatorContext, key: &K, f: F) -> Result<()>
95	where
96		F: FnOnce(&mut V) -> Result<()>,
97		V: Default,
98	{
99		let mut arc = self.get_arc(ctx, key)?.unwrap_or_else(|| Arc::new(V::default()));
100		f(Arc::make_mut(&mut arc))?;
101		self.put_arc(ctx, key, arc)
102	}
103
104	pub fn remove(&mut self, _ctx: &mut OperatorContext, key: &K) -> Result<()> {
105		self.cache.remove(key);
106		self.dirty.insert(key.clone(), None);
107		Ok(())
108	}
109
110	pub fn flush(&mut self, ctx: &mut OperatorContext) -> Result<()> {
111		let dirty = mem::take(&mut self.dirty);
112		for (key, slot) in dirty {
113			let encoded_key = (&key).into_encoded_key();
114			match (slot, self.backend) {
115				(Some(value), StateBackend::Data) => ctx.state().set(&encoded_key, value.as_ref())?,
116				(Some(value), StateBackend::Internal) => {
117					ctx.internal_state().set(&encoded_key, value.as_ref())?
118				}
119				(None, StateBackend::Data) => ctx.state().remove(&encoded_key)?,
120				(None, StateBackend::Internal) => ctx.internal_state().remove(&encoded_key)?,
121			}
122		}
123		Ok(())
124	}
125
126	pub fn clear_cache(&mut self) {
127		self.cache.clear();
128	}
129
130	pub fn invalidate(&mut self, key: &K) {
131		self.cache.remove(key);
132	}
133
134	pub fn is_cached(&self, key: &K) -> bool {
135		self.cache.contains_key(key)
136	}
137
138	pub fn len(&self) -> usize {
139		self.cache.len()
140	}
141
142	pub fn is_empty(&self) -> bool {
143		self.cache.is_empty()
144	}
145
146	pub fn capacity(&self) -> usize {
147		self.cache.capacity()
148	}
149}
150
151impl<K, V> StateCache<K, V>
152where
153	K: Hash + Eq + Clone,
154	for<'a> &'a K: IntoEncodedKey,
155	V: Clone + Default + Serialize + DeserializeOwned,
156{
157	pub fn get_or_default(&mut self, ctx: &mut OperatorContext, key: &K) -> Result<V> {
158		match self.get(ctx, key)? {
159			Some(value) => Ok(value),
160			None => Ok(V::default()),
161		}
162	}
163
164	pub fn update<U>(&mut self, ctx: &mut OperatorContext, key: &K, updater: U) -> Result<V>
165	where
166		U: FnOnce(&mut V) -> Result<()>,
167	{
168		let mut value = self.get_or_default(ctx, key)?;
169		updater(&mut value)?;
170		self.set(ctx, key, &value)?;
171		Ok(value)
172	}
173}
174
#[cfg(test)]
pub mod tests {
	use reifydb_core::encoded::key::IntoEncodedKey;

	use super::*;

	#[test]
	fn test_cache_capacity() {
		let cache: StateCache<String, i32> = StateCache::new(100);
		assert_eq!(cache.capacity(), 100);
		assert!(cache.is_empty());
		assert_eq!(cache.len(), 0);
	}

	#[test]
	fn test_internal_cache_capacity() {
		let cache: StateCache<String, i32> = StateCache::new_internal(8);
		assert_eq!(cache.capacity(), 8);
		assert!(cache.is_empty());
		assert_eq!(cache.len(), 0);
	}

	#[test]
	fn test_fresh_cache_has_nothing_cached() {
		let mut cache: StateCache<String, i32> = StateCache::new(4);
		let key = "missing".to_string();
		assert!(!cache.is_cached(&key));
		// Invalidating or clearing an empty cache must be a harmless no-op.
		cache.invalidate(&key);
		cache.clear_cache();
		assert!(!cache.is_cached(&key));
		assert!(cache.is_empty());
		assert_eq!(cache.len(), 0);
	}

	#[test]
	#[should_panic(expected = "capacity must be greater than 0")]
	fn test_zero_capacity_panics() {
		let _cache: StateCache<String, i32> = StateCache::new(0);
	}

	#[test]
	fn test_into_encoded_key_string() {
		let key = "test_key".to_string();
		let encoded = (&key).into_encoded_key();
		assert!(!encoded.as_bytes().is_empty());
	}

	#[test]
	fn test_into_encoded_key_str() {
		let key = "test_key";
		let encoded = key.into_encoded_key();
		assert!(!encoded.as_bytes().is_empty());
	}

	#[test]
	fn test_into_encoded_key_tuple2() {
		let key = ("base".to_string(), "quote".to_string());
		let encoded = (&key).into_encoded_key();
		assert!(!encoded.as_bytes().is_empty());
	}

	#[test]
	fn test_into_encoded_key_tuple3() {
		let key = ("a".to_string(), "b".to_string(), "c".to_string());
		let encoded = (&key).into_encoded_key();
		assert!(!encoded.as_bytes().is_empty());
	}

	#[test]
	fn test_into_encoded_key_consistency() {
		let key1 = ("base".to_string(), "quote".to_string());
		let key2 = ("base".to_string(), "quote".to_string());
		assert_eq!((&key1).into_encoded_key().as_bytes(), (&key2).into_encoded_key().as_bytes());
	}

	#[test]
	fn test_into_encoded_key_different_keys() {
		let key1 = ("a".to_string(), "b".to_string());
		let key2 = ("c".to_string(), "d".to_string());
		assert_ne!((&key1).into_encoded_key().as_bytes(), (&key2).into_encoded_key().as_bytes());
	}
}