1use std::{collections::HashMap, hash::Hash, mem, sync::Arc};
5
6use reifydb_core::{encoded::key::IntoEncodedKey, util::lru::LruCache};
7use serde::{Serialize, de::DeserializeOwned};
8
9use crate::{error::Result, operator::context::OperatorContext};
10
/// Selects which operator state store a `StateCache` reads misses from and flushes to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StateBackend {
	/// The operator's data state, accessed through `OperatorContext::state()`.
	Data,

	/// The operator's internal state, accessed through `OperatorContext::internal_state()`.
	Internal,
}
17
18pub struct StateCache<K, V> {
19 cache: LruCache<K, Arc<V>>,
20 dirty: HashMap<K, Option<Arc<V>>>,
21 backend: StateBackend,
22}
23
24impl<K, V> StateCache<K, V>
25where
26 K: Hash + Eq + Clone,
27 for<'a> &'a K: IntoEncodedKey,
28 V: Clone + Serialize + DeserializeOwned,
29{
30 pub fn new(capacity: usize) -> Self {
31 Self::with_backend(capacity, StateBackend::Data)
32 }
33
34 pub fn new_internal(capacity: usize) -> Self {
35 Self::with_backend(capacity, StateBackend::Internal)
36 }
37
38 fn with_backend(capacity: usize, backend: StateBackend) -> Self {
39 Self {
40 cache: LruCache::new(capacity),
41 dirty: HashMap::new(),
42 backend,
43 }
44 }
45
46 pub fn get_arc(&mut self, ctx: &mut OperatorContext, key: &K) -> Result<Option<Arc<V>>> {
47 if let Some(cached) = self.cache.get(key) {
48 return Ok(Some(cached));
49 }
50
51 if let Some(slot) = self.dirty.get(key) {
52 return Ok(slot.clone());
53 }
54
55 let encoded_key = key.into_encoded_key();
56 let loaded = match self.backend {
57 StateBackend::Data => ctx.state().get::<V>(&encoded_key)?,
58 StateBackend::Internal => ctx.internal_state().get::<V>(&encoded_key)?,
59 };
60 match loaded {
61 Some(value) => {
62 let arc = Arc::new(value);
63 self.cache.put(key.clone(), arc.clone());
64 Ok(Some(arc))
65 }
66 None => Ok(None),
67 }
68 }
69
70 pub fn get(&mut self, ctx: &mut OperatorContext, key: &K) -> Result<Option<V>> {
71 Ok(self.get_arc(ctx, key)?.map(|arc| (*arc).clone()))
72 }
73
74 pub fn set(&mut self, _ctx: &mut OperatorContext, key: &K, value: &V) -> Result<()> {
75 let arc = Arc::new(value.clone());
76 self.cache.put(key.clone(), arc.clone());
77 self.dirty.insert(key.clone(), Some(arc));
78 Ok(())
79 }
80
81 pub fn put(&mut self, _ctx: &mut OperatorContext, key: &K, value: V) -> Result<()> {
82 let arc = Arc::new(value);
83 self.cache.put(key.clone(), arc.clone());
84 self.dirty.insert(key.clone(), Some(arc));
85 Ok(())
86 }
87
88 pub fn put_arc(&mut self, _ctx: &mut OperatorContext, key: &K, value: Arc<V>) -> Result<()> {
89 self.cache.put(key.clone(), value.clone());
90 self.dirty.insert(key.clone(), Some(value));
91 Ok(())
92 }
93
94 pub fn modify<F>(&mut self, ctx: &mut OperatorContext, key: &K, f: F) -> Result<()>
95 where
96 F: FnOnce(&mut V) -> Result<()>,
97 V: Default,
98 {
99 let mut arc = self.get_arc(ctx, key)?.unwrap_or_else(|| Arc::new(V::default()));
100 f(Arc::make_mut(&mut arc))?;
101 self.put_arc(ctx, key, arc)
102 }
103
104 pub fn remove(&mut self, _ctx: &mut OperatorContext, key: &K) -> Result<()> {
105 self.cache.remove(key);
106 self.dirty.insert(key.clone(), None);
107 Ok(())
108 }
109
110 pub fn flush(&mut self, ctx: &mut OperatorContext) -> Result<()> {
111 let dirty = mem::take(&mut self.dirty);
112 for (key, slot) in dirty {
113 let encoded_key = (&key).into_encoded_key();
114 match (slot, self.backend) {
115 (Some(value), StateBackend::Data) => ctx.state().set(&encoded_key, value.as_ref())?,
116 (Some(value), StateBackend::Internal) => {
117 ctx.internal_state().set(&encoded_key, value.as_ref())?
118 }
119 (None, StateBackend::Data) => ctx.state().remove(&encoded_key)?,
120 (None, StateBackend::Internal) => ctx.internal_state().remove(&encoded_key)?,
121 }
122 }
123 Ok(())
124 }
125
126 pub fn clear_cache(&mut self) {
127 self.cache.clear();
128 }
129
130 pub fn invalidate(&mut self, key: &K) {
131 self.cache.remove(key);
132 }
133
134 pub fn is_cached(&self, key: &K) -> bool {
135 self.cache.contains_key(key)
136 }
137
138 pub fn len(&self) -> usize {
139 self.cache.len()
140 }
141
142 pub fn is_empty(&self) -> bool {
143 self.cache.is_empty()
144 }
145
146 pub fn capacity(&self) -> usize {
147 self.cache.capacity()
148 }
149}
150
151impl<K, V> StateCache<K, V>
152where
153 K: Hash + Eq + Clone,
154 for<'a> &'a K: IntoEncodedKey,
155 V: Clone + Default + Serialize + DeserializeOwned,
156{
157 pub fn get_or_default(&mut self, ctx: &mut OperatorContext, key: &K) -> Result<V> {
158 match self.get(ctx, key)? {
159 Some(value) => Ok(value),
160 None => Ok(V::default()),
161 }
162 }
163
164 pub fn update<U>(&mut self, ctx: &mut OperatorContext, key: &K, updater: U) -> Result<V>
165 where
166 U: FnOnce(&mut V) -> Result<()>,
167 {
168 let mut value = self.get_or_default(ctx, key)?;
169 updater(&mut value)?;
170 self.set(ctx, key, &value)?;
171 Ok(value)
172 }
173}
174
#[cfg(test)]
pub mod tests {
	use reifydb_core::encoded::key::IntoEncodedKey;

	use super::*;

	/// A freshly built cache reports the requested capacity and holds nothing yet.
	#[test]
	fn test_cache_capacity() {
		let fresh: StateCache<String, i32> = StateCache::new(100);
		assert!(fresh.is_empty());
		assert_eq!(fresh.len(), 0);
		assert_eq!(fresh.capacity(), 100);
	}

	/// A zero capacity is expected to panic (presumably inside `LruCache::new`).
	#[test]
	#[should_panic(expected = "capacity must be greater than 0")]
	fn test_zero_capacity_panics() {
		let _rejected = StateCache::<String, i32>::new(0);
	}

	/// Owned `String` keys encode to a non-empty byte sequence.
	#[test]
	fn test_into_encoded_key_string() {
		let owned = String::from("test_key");
		let bytes_empty = (&owned).into_encoded_key().as_bytes().is_empty();
		assert!(!bytes_empty);
	}

	/// String slices encode to a non-empty byte sequence.
	#[test]
	fn test_into_encoded_key_str() {
		let encoded = "test_key".into_encoded_key();
		assert!(!encoded.as_bytes().is_empty());
	}

	/// Two-element tuple keys encode to a non-empty byte sequence.
	#[test]
	fn test_into_encoded_key_tuple2() {
		let pair = (String::from("base"), String::from("quote"));
		let encoded = (&pair).into_encoded_key();
		assert!(!encoded.as_bytes().is_empty());
	}

	/// Three-element tuple keys encode to a non-empty byte sequence.
	#[test]
	fn test_into_encoded_key_tuple3() {
		let triple = (String::from("a"), String::from("b"), String::from("c"));
		let encoded = (&triple).into_encoded_key();
		assert!(!encoded.as_bytes().is_empty());
	}

	/// Equal keys always produce identical encodings.
	#[test]
	fn test_into_encoded_key_consistency() {
		let build = || (String::from("base"), String::from("quote"));
		let (left, right) = (build(), build());
		assert_eq!((&left).into_encoded_key().as_bytes(), (&right).into_encoded_key().as_bytes());
	}

	/// Distinct keys produce distinct encodings.
	#[test]
	fn test_into_encoded_key_different_keys() {
		let first = (String::from("a"), String::from("b"));
		let second = (String::from("c"), String::from("d"));
		assert_ne!((&first).into_encoded_key().as_bytes(), (&second).into_encoded_key().as_bytes());
	}
}