dag_executor/storage/
file_storage.rs1use crate::error::StorageError;
4use async_trait::async_trait;
5use parking_lot::Mutex;
6use serde::{Deserialize, Serialize};
7use sha2::{Digest, Sha256};
8use std::collections::HashMap;
9use std::io;
10use std::path::{Path, PathBuf};
11
12pub type StorageResult<T> = std::result::Result<T, StorageError>;
14
15#[async_trait]
20pub trait Storage: Send + Sync {
21 async fn save(&self, key: &str, value: &serde_json::Value) -> StorageResult<()>;
23 async fn load(&self, key: &str) -> StorageResult<Option<serde_json::Value>>;
25 async fn delete(&self, key: &str) -> StorageResult<()>;
27 async fn list(&self) -> StorageResult<Vec<String>>;
29}
30
31#[derive(Serialize, Deserialize)]
36struct Envelope {
37 key: String,
38 checksum: String,
39 data: serde_json::Value,
40}
41
42fn checksum(value: &serde_json::Value) -> StorageResult<String> {
43 let bytes = serde_json::to_vec(value)?;
46 let mut hasher = Sha256::new();
47 hasher.update(&bytes);
48 Ok(format!("{:x}", hasher.finalize()))
49}
50
51fn safe_filename(key: &str) -> StorageResult<String> {
58 if key.is_empty() || key.len() > 1024 {
59 return Err(StorageError::InvalidKey(key.to_string()));
60 }
61 let mut hasher = Sha256::new();
62 hasher.update(key.as_bytes());
63 Ok(format!("{:x}.json", hasher.finalize()))
64}
65
/// Best-effort fsync of the directory that contains `path`.
///
/// Syncing the directory persists directory-entry changes such as a rename
/// into it. Errors from opening or syncing the directory are ignored.
#[cfg(unix)]
fn sync_parent_dir(path: &Path) {
    // `Path::parent` yields `Some("")` for a bare relative name like "x.json"
    // (e.g. when the storage root is ""). Opening "" fails, so use "." instead.
    let dir = match path.parent() {
        Some(parent) if parent.as_os_str().is_empty() => Path::new("."),
        Some(parent) => parent,
        None => return,
    };
    if let Ok(handle) = std::fs::File::open(dir) {
        let _ = handle.sync_all();
    }
}
77
/// Directory syncing is only implemented on Unix; on other targets this does nothing.
#[cfg(not(unix))]
fn sync_parent_dir(_: &Path) {}
82
/// How `FileStorage` writes a value to disk on `save`.
///
/// `FileStorage::open` uses `Durability::Fast`, which is also the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Durability {
    /// Truncate the destination file and write the new bytes in place.
    ///
    /// Cheapest mode, but not atomic: an interrupted write can leave a
    /// partially written file behind.
    #[default]
    Fast,
    /// Write to a temporary file in the storage directory, then rename it
    /// over the destination. Where rename is atomic, readers see either the
    /// old or the new contents.
    Atomic,
    /// Like `Atomic`, but the temporary file is fsynced before the rename and
    /// the containing directory is synced (best-effort) afterwards.
    Durable,
}
110
111pub struct FileStorage {
118 root: PathBuf,
119 durability: Durability,
120}
121
122impl FileStorage {
123 pub fn open(dir: impl AsRef<Path>) -> StorageResult<Self> {
125 Self::open_with(dir, Durability::Fast)
126 }
127
128 pub fn open_with(dir: impl AsRef<Path>, durability: Durability) -> StorageResult<Self> {
130 let root = dir.as_ref().to_path_buf();
131 std::fs::create_dir_all(&root)?;
132 Ok(FileStorage { root, durability })
133 }
134
135 pub fn durability(&self) -> Durability {
137 self.durability
138 }
139
140 fn path_for(&self, key: &str) -> StorageResult<PathBuf> {
141 Ok(self.root.join(safe_filename(key)?))
142 }
143}
144
145#[async_trait]
146impl Storage for FileStorage {
147 async fn save(&self, key: &str, value: &serde_json::Value) -> StorageResult<()> {
148 let path = self.path_for(key)?;
149 let envelope = Envelope {
150 key: key.to_string(),
151 checksum: checksum(value)?,
152 data: value.clone(),
153 };
154 let bytes = serde_json::to_vec(&envelope)?;
155 let root = self.root.clone();
156 let durability = self.durability;
157
158 tokio::task::spawn_blocking(move || -> StorageResult<()> {
160 match durability {
161 Durability::Fast => {
162 use std::io::Write;
163 let mut f = std::fs::OpenOptions::new()
165 .create(true)
166 .write(true)
167 .truncate(true)
168 .open(&path)?;
169 f.write_all(&bytes)?;
170 Ok(())
171 }
172 Durability::Atomic => {
173 let tmp = root.join(format!(".tmp-{}", uuid::Uuid::new_v4()));
174 std::fs::write(&tmp, &bytes)?;
175 std::fs::rename(&tmp, &path)?;
177 Ok(())
178 }
179 Durability::Durable => {
180 use std::io::Write;
181 let tmp = root.join(format!(".tmp-{}", uuid::Uuid::new_v4()));
182 {
183 let mut f = std::fs::OpenOptions::new()
184 .create(true)
185 .write(true)
186 .truncate(true)
187 .open(&tmp)?;
188 f.write_all(&bytes)?;
189 f.sync_all()?;
193 }
194 std::fs::rename(&tmp, &path)?;
195 sync_parent_dir(&path);
197 Ok(())
198 }
199 }
200 })
201 .await
202 .map_err(|e| StorageError::Io(io::Error::other(e)))?
203 }
204
205 async fn load(&self, key: &str) -> StorageResult<Option<serde_json::Value>> {
206 let path = self.path_for(key)?;
207 let bytes = match tokio::fs::read(&path).await {
208 Ok(b) => b,
209 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
210 Err(e) => return Err(e.into()),
211 };
212 let envelope: Envelope = serde_json::from_slice(&bytes)?;
213 if checksum(&envelope.data)? != envelope.checksum {
214 return Err(StorageError::ChecksumMismatch(key.to_string()));
215 }
216 Ok(Some(envelope.data))
217 }
218
219 async fn delete(&self, key: &str) -> StorageResult<()> {
220 let path = self.path_for(key)?;
221 match tokio::fs::remove_file(&path).await {
222 Ok(()) => Ok(()),
223 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
224 Err(e) => Err(e.into()),
225 }
226 }
227
228 async fn list(&self) -> StorageResult<Vec<String>> {
229 let mut keys = Vec::new();
230 let mut dir = tokio::fs::read_dir(&self.root).await?;
231 while let Some(entry) = dir.next_entry().await? {
232 let name = entry.file_name();
233 let name = name.to_string_lossy();
234 if name.ends_with(".json") && !name.starts_with(".tmp-") {
236 if let Ok(bytes) = tokio::fs::read(entry.path()).await {
237 if let Ok(envelope) = serde_json::from_slice::<Envelope>(&bytes) {
238 keys.push(envelope.key);
239 }
240 }
241 }
242 }
243 Ok(keys)
244 }
245}
246
247#[derive(Default)]
249pub struct MemoryStorage {
250 map: Mutex<HashMap<String, serde_json::Value>>,
251}
252
253impl MemoryStorage {
254 pub fn new() -> Self {
256 Self::default()
257 }
258}
259
260#[async_trait]
261impl Storage for MemoryStorage {
262 async fn save(&self, key: &str, value: &serde_json::Value) -> StorageResult<()> {
263 self.map.lock().insert(key.to_string(), value.clone());
264 Ok(())
265 }
266
267 async fn load(&self, key: &str) -> StorageResult<Option<serde_json::Value>> {
268 Ok(self.map.lock().get(key).cloned())
269 }
270
271 async fn delete(&self, key: &str) -> StorageResult<()> {
272 self.map.lock().remove(key);
273 Ok(())
274 }
275
276 async fn list(&self) -> StorageResult<Vec<String>> {
277 Ok(self.map.lock().keys().cloned().collect())
278 }
279}