// forge_orchestration/storage.rs
use std::collections::{BTreeMap, HashMap};
use std::path::{Path, PathBuf};
use std::sync::Arc;

use async_trait::async_trait;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, info};

use crate::error::{ForgeError, Result};
21#[async_trait]
23pub trait StateStore: Send + Sync {
24 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>>;
26
27 async fn set(&self, key: &str, value: Vec<u8>) -> Result<()>;
29
30 async fn delete(&self, key: &str) -> Result<()>;
32
33 async fn list_prefix(&self, prefix: &str) -> Result<Vec<String>>;
35
36 fn name(&self) -> &str;
38}
39
40pub async fn store_exists(store: &dyn StateStore, key: &str) -> Result<bool> {
44 Ok(store.get(key).await?.is_some())
45}
46
47pub async fn store_get_json<T: DeserializeOwned>(store: &dyn StateStore, key: &str) -> Result<Option<T>> {
49 match store.get(key).await? {
50 Some(bytes) => {
51 let value = serde_json::from_slice(&bytes)?;
52 Ok(Some(value))
53 }
54 None => Ok(None),
55 }
56}
57
58pub async fn store_set_json<T: Serialize>(store: &dyn StateStore, key: &str, value: &T) -> Result<()> {
60 let bytes = serde_json::to_vec(value)?;
61 store.set(key, bytes).await
62}
63
64#[derive(Debug, Default)]
66pub struct MemoryStore {
67 data: RwLock<HashMap<String, Vec<u8>>>,
68}
69
70impl MemoryStore {
71 pub fn new() -> Self {
73 Self {
74 data: RwLock::new(HashMap::new()),
75 }
76 }
77}
78
79#[async_trait]
80impl StateStore for MemoryStore {
81 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
82 let data = self.data.read().await;
83 Ok(data.get(key).cloned())
84 }
85
86 async fn set(&self, key: &str, value: Vec<u8>) -> Result<()> {
87 let mut data = self.data.write().await;
88 data.insert(key.to_string(), value);
89 Ok(())
90 }
91
92 async fn delete(&self, key: &str) -> Result<()> {
93 let mut data = self.data.write().await;
94 data.remove(key);
95 Ok(())
96 }
97
98 async fn list_prefix(&self, prefix: &str) -> Result<Vec<String>> {
99 let data = self.data.read().await;
100 Ok(data
101 .keys()
102 .filter(|k| k.starts_with(prefix))
103 .cloned()
104 .collect())
105 }
106
107 fn name(&self) -> &str {
108 "memory"
109 }
110}
111
112pub struct FileStore {
116 path: PathBuf,
117 data: RwLock<HashMap<String, Vec<u8>>>,
118}
119
120impl FileStore {
121 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
123 let path = path.as_ref().to_path_buf();
124
125 let data = if path.exists() {
126 let contents = std::fs::read_to_string(&path)
127 .map_err(|e| ForgeError::storage(format!("Failed to read store: {}", e)))?;
128 serde_json::from_str(&contents).unwrap_or_default()
129 } else {
130 HashMap::new()
131 };
132
133 info!(path = %path.display(), "File store opened");
134
135 Ok(Self {
136 path,
137 data: RwLock::new(data),
138 })
139 }
140
141 pub async fn flush(&self) -> Result<()> {
143 let data = self.data.read().await;
144 let contents = serde_json::to_string_pretty(&*data)?;
145
146 if let Some(parent) = self.path.parent() {
147 std::fs::create_dir_all(parent)
148 .map_err(|e| ForgeError::storage(format!("Failed to create dir: {}", e)))?;
149 }
150
151 std::fs::write(&self.path, contents)
152 .map_err(|e| ForgeError::storage(format!("Failed to write store: {}", e)))?;
153
154 debug!(path = %self.path.display(), "File store flushed");
155 Ok(())
156 }
157}
158
159#[async_trait]
160impl StateStore for FileStore {
161 async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> {
162 let data = self.data.read().await;
163 Ok(data.get(key).cloned())
164 }
165
166 async fn set(&self, key: &str, value: Vec<u8>) -> Result<()> {
167 let mut data = self.data.write().await;
168 data.insert(key.to_string(), value);
169 Ok(())
170 }
171
172 async fn delete(&self, key: &str) -> Result<()> {
173 let mut data = self.data.write().await;
174 data.remove(key);
175 Ok(())
176 }
177
178 async fn list_prefix(&self, prefix: &str) -> Result<Vec<String>> {
179 let data = self.data.read().await;
180 Ok(data
181 .keys()
182 .filter(|k| k.starts_with(prefix))
183 .cloned()
184 .collect())
185 }
186
187 fn name(&self) -> &str {
188 "file"
189 }
190}
191
192pub type BoxedStateStore = Arc<dyn StateStore>;
194
195pub fn memory_store() -> BoxedStateStore {
197 Arc::new(MemoryStore::new()) as BoxedStateStore
198}
199
/// Well-known key namespaces and builders for per-entity keys.
///
/// The builders produce keys of the form `<namespace>/<id>`.
pub mod keys {
    /// Namespace for keys built by [`job`].
    pub const JOBS: &str = "forge/jobs";
    /// Namespace for keys built by [`shard`].
    pub const SHARDS: &str = "forge/shards";
    /// Namespace for keys built by [`expert`].
    pub const EXPERTS: &str = "forge/experts";
    /// Namespace for keys built by [`node`].
    pub const NODES: &str = "forge/nodes";
    /// Configuration namespace (no builder in this module uses it).
    pub const CONFIG: &str = "forge/config";
    /// Metrics namespace (no builder in this module uses it).
    pub const METRICS: &str = "forge/metrics";

    /// Joins `namespace` and `leaf` with a single `/`.
    fn under(namespace: &str, leaf: impl std::fmt::Display) -> String {
        format!("{}/{}", namespace, leaf)
    }

    /// Key for the job identified by `id`.
    pub fn job(id: &str) -> String {
        under(JOBS, id)
    }

    /// Key for the shard with numeric `id`.
    pub fn shard(id: u64) -> String {
        under(SHARDS, id)
    }

    /// Key for the expert at `index`.
    pub fn expert(index: usize) -> String {
        under(EXPERTS, index)
    }

    /// Key for the node identified by `id`.
    pub fn node(id: &str) -> String {
        under(NODES, id)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_memory_store_basic() {
        let store = MemoryStore::new();

        store.set("key1", b"value1".to_vec()).await.unwrap();
        let value = store.get("key1").await.unwrap();
        assert_eq!(value, Some(b"value1".to_vec()));
        assert!(store_exists(&store, "key1").await.unwrap());

        store.delete("key1").await.unwrap();
        let value = store.get("key1").await.unwrap();
        assert!(value.is_none());
        assert!(!store_exists(&store, "key1").await.unwrap());
    }

    #[tokio::test]
    async fn test_memory_store_prefix() {
        let store = MemoryStore::new();

        store.set("prefix/a", b"1".to_vec()).await.unwrap();
        store.set("prefix/b", b"2".to_vec()).await.unwrap();
        store.set("other/c", b"3".to_vec()).await.unwrap();

        let keys = store.list_prefix("prefix/").await.unwrap();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&"prefix/a".to_string()));
        assert!(keys.contains(&"prefix/b".to_string()));
    }

    #[tokio::test]
    async fn test_memory_store_json() {
        let store = MemoryStore::new();

        #[derive(Debug, PartialEq, Serialize, serde::Deserialize)]
        struct TestData {
            name: String,
            value: i32,
        }

        let data = TestData {
            name: "test".to_string(),
            value: 42,
        };

        store_set_json(&store, "json_key", &data).await.unwrap();
        let loaded: Option<TestData> = store_get_json(&store, "json_key").await.unwrap();
        assert_eq!(loaded, Some(data));
    }

    /// Entries written and flushed by one `FileStore` are visible after reopening.
    #[tokio::test]
    async fn test_file_store_roundtrip() {
        let path = std::env::temp_dir().join(format!(
            "forge_storage_test_{}_roundtrip.json",
            std::process::id()
        ));
        let _ = std::fs::remove_file(&path);

        let store = FileStore::open(&path).unwrap();
        store.set("a", b"1".to_vec()).await.unwrap();
        store.set("b", b"2".to_vec()).await.unwrap();
        store.flush().await.unwrap();

        let reopened = FileStore::open(&path).unwrap();
        assert_eq!(reopened.get("a").await.unwrap(), Some(b"1".to_vec()));
        assert_eq!(reopened.get("b").await.unwrap(), Some(b"2".to_vec()));
        assert!(reopened.get("missing").await.unwrap().is_none());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn test_key_builders() {
        assert_eq!(keys::job("my-job"), "forge/jobs/my-job");
        assert_eq!(keys::shard(42), "forge/shards/42");
        assert_eq!(keys::expert(0), "forge/experts/0");
        assert_eq!(keys::node("worker-1"), "forge/nodes/worker-1");
    }
}