1use std::collections::HashMap;
14use std::sync::Arc;
15
16use async_trait::async_trait;
17use dashmap::DashMap;
18
19use crate::error::StorageError;
20use crate::key::StateKey;
21
22#[async_trait]
30pub trait StateStorage: Send + Sync + 'static {
31 async fn get_state(&self, key: StateKey) -> Result<Option<String>, StorageError>;
33
34 async fn set_state(&self, key: StateKey, state: String) -> Result<(), StorageError>;
36
37 async fn clear_state(&self, key: StateKey) -> Result<(), StorageError>;
39
40 async fn get_data(
42 &self,
43 key: StateKey,
44 field: &str,
45 ) -> Result<Option<serde_json::Value>, StorageError>;
46
47 async fn set_data(
49 &self,
50 key: StateKey,
51 field: &str,
52 value: serde_json::Value,
53 ) -> Result<(), StorageError>;
54
55 async fn get_all_data(
57 &self,
58 key: StateKey,
59 ) -> Result<HashMap<String, serde_json::Value>, StorageError>;
60
61 async fn clear_data(&self, key: StateKey) -> Result<(), StorageError>;
63
64 async fn clear_all(&self, key: StateKey) -> Result<(), StorageError>;
66}
67
68#[derive(Clone, Default)]
76pub struct MemoryStorage {
77 entries: Arc<DashMap<StateKey, StorageEntry>>,
78}
79
80#[derive(Clone, Default)]
81struct StorageEntry {
82 state: Option<String>,
83 data: HashMap<String, serde_json::Value>,
84}
85
86impl MemoryStorage {
87 pub fn new() -> Self {
89 Self::default()
90 }
91
92 pub fn len(&self) -> usize {
94 self.entries.len()
95 }
96
97 pub fn is_empty(&self) -> bool {
99 self.entries.is_empty()
100 }
101}
102
103#[async_trait]
104impl StateStorage for MemoryStorage {
105 async fn get_state(&self, key: StateKey) -> Result<Option<String>, StorageError> {
106 Ok(self.entries.get(&key).and_then(|e| e.state.clone()))
107 }
108
109 async fn set_state(&self, key: StateKey, state: String) -> Result<(), StorageError> {
110 self.entries.entry(key).or_default().state = Some(state);
111 Ok(())
112 }
113
114 async fn clear_state(&self, key: StateKey) -> Result<(), StorageError> {
115 if let Some(mut entry) = self.entries.get_mut(&key) {
116 entry.state = None;
117 if entry.data.is_empty() {
118 drop(entry);
119 self.entries.remove(&key);
120 }
121 }
122 Ok(())
123 }
124
125 async fn get_data(
126 &self,
127 key: StateKey,
128 field: &str,
129 ) -> Result<Option<serde_json::Value>, StorageError> {
130 Ok(self
131 .entries
132 .get(&key)
133 .and_then(|e| e.data.get(field).cloned()))
134 }
135
136 async fn set_data(
137 &self,
138 key: StateKey,
139 field: &str,
140 value: serde_json::Value,
141 ) -> Result<(), StorageError> {
142 self.entries
143 .entry(key)
144 .or_default()
145 .data
146 .insert(field.to_string(), value);
147 Ok(())
148 }
149
150 async fn get_all_data(
151 &self,
152 key: StateKey,
153 ) -> Result<HashMap<String, serde_json::Value>, StorageError> {
154 Ok(self
155 .entries
156 .get(&key)
157 .map(|e| e.data.clone())
158 .unwrap_or_default())
159 }
160
161 async fn clear_data(&self, key: StateKey) -> Result<(), StorageError> {
162 if let Some(mut entry) = self.entries.get_mut(&key) {
163 entry.data.clear();
164 if entry.state.is_none() {
165 drop(entry);
166 self.entries.remove(&key);
167 }
168 }
169 Ok(())
170 }
171
172 async fn clear_all(&self, key: StateKey) -> Result<(), StorageError> {
173 self.entries.remove(&key);
174 Ok(())
175 }
176}