Skip to main content

mpp_br/
store.rs

1//! Pluggable key-value store abstraction.
2//!
3//! Modeled after the TypeScript SDK's Store interface (Cloudflare KV API style).
4//! Implementations handle serialization internally.
5
6use std::future::Future;
7use std::pin::Pin;
8
9/// Async key-value store interface.
10///
11/// Simple `get`/`put`/`delete` API compatible with various backends:
12/// - In-memory (for development/testing)
13/// - File-system (for simple persistence)
14/// - Redis, SQLite, etc. (for production)
15pub trait Store: Send + Sync {
16    /// Get a value by key. Returns None if not found.
17    fn get(
18        &self,
19        key: &str,
20    ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>;
21
22    /// Put a value by key.
23    fn put(
24        &self,
25        key: &str,
26        value: serde_json::Value,
27    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>>;
28
29    /// Delete a value by key.
30    fn delete(
31        &self,
32        key: &str,
33    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>>;
34}
35
36#[derive(Debug, thiserror::Error)]
37pub enum StoreError {
38    #[error("Store error: {0}")]
39    Internal(String),
40    #[error("Serialization error: {0}")]
41    Serialization(String),
42}
43
44// ==================== MemoryStore ====================
45
/// In-memory store backed by a `HashMap`.
///
/// Values are kept as serialized JSON strings, so every `put`/`get` pair goes
/// through a JSON round trip, mirroring what a persistent backend does.
#[derive(Default)]
pub struct MemoryStore {
    data: std::sync::Mutex<std::collections::HashMap<String, String>>,
}

impl MemoryStore {
    /// Creates an empty store (equivalent to `MemoryStore::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
64
65impl Store for MemoryStore {
66    fn get(
67        &self,
68        key: &str,
69    ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
70    {
71        let result = self.data.lock().unwrap().get(key).cloned();
72        Box::pin(async move {
73            match result {
74                Some(raw) => {
75                    let value = serde_json::from_str(&raw)
76                        .map_err(|e| StoreError::Serialization(e.to_string()))?;
77                    Ok(Some(value))
78                }
79                None => Ok(None),
80            }
81        })
82    }
83
84    fn put(
85        &self,
86        key: &str,
87        value: serde_json::Value,
88    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
89        let key = key.to_string();
90        let serialized =
91            serde_json::to_string(&value).map_err(|e| StoreError::Serialization(e.to_string()));
92        Box::pin(async move {
93            let serialized = serialized?;
94            self.data.lock().unwrap().insert(key, serialized);
95            Ok(())
96        })
97    }
98
99    fn delete(
100        &self,
101        key: &str,
102    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
103        self.data.lock().unwrap().remove(key);
104        Box::pin(async { Ok(()) })
105    }
106}
107
108// ==================== FileStore ====================
109
110/// File-system backed store. Each key is stored as a JSON file.
111///
112/// Useful for development and simple deployments where a database is overkill.
113pub struct FileStore {
114    dir: std::path::PathBuf,
115}
116
117impl FileStore {
118    /// Create a new FileStore that persists data in the given directory.
119    ///
120    /// Creates the directory if it does not exist.
121    pub fn new(dir: impl Into<std::path::PathBuf>) -> Result<Self, StoreError> {
122        let dir = dir.into();
123        std::fs::create_dir_all(&dir)
124            .map_err(|e| StoreError::Internal(format!("Failed to create store dir: {}", e)))?;
125        Ok(Self { dir })
126    }
127
128    fn key_path(&self, key: &str) -> std::path::PathBuf {
129        // Sanitize key: replace path separators and special chars
130        let safe_key: String = key
131            .chars()
132            .map(|c| {
133                if c.is_alphanumeric() || c == '-' || c == '_' {
134                    c
135                } else {
136                    '_'
137                }
138            })
139            .collect();
140        self.dir.join(format!("{}.json", safe_key))
141    }
142}
143
144impl Store for FileStore {
145    fn get(
146        &self,
147        key: &str,
148    ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
149    {
150        let path = self.key_path(key);
151        Box::pin(async move {
152            match std::fs::read_to_string(&path) {
153                Ok(raw) => {
154                    let value = serde_json::from_str(&raw)
155                        .map_err(|e| StoreError::Serialization(e.to_string()))?;
156                    Ok(Some(value))
157                }
158                Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
159                Err(e) => Err(StoreError::Internal(e.to_string())),
160            }
161        })
162    }
163
164    fn put(
165        &self,
166        key: &str,
167        value: serde_json::Value,
168    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
169        let path = self.key_path(key);
170        Box::pin(async move {
171            let serialized = serde_json::to_string_pretty(&value)
172                .map_err(|e| StoreError::Serialization(e.to_string()))?;
173            std::fs::write(&path, serialized).map_err(|e| StoreError::Internal(e.to_string()))?;
174            Ok(())
175        })
176    }
177
178    fn delete(
179        &self,
180        key: &str,
181    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
182        let path = self.key_path(key);
183        Box::pin(async move {
184            match std::fs::remove_file(&path) {
185                Ok(()) => Ok(()),
186                Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
187                Err(e) => Err(StoreError::Internal(e.to_string())),
188            }
189        })
190    }
191}
192
193// ==================== ChannelStoreAdapter ====================
194
/// `ChannelStore` implementation layered on top of any [`Store`] backend.
///
/// Each channel is kept under the backend key `prefix + channel_id`, which lets a
/// persistent store (file, Redis, …) hold channel state.
#[cfg(all(feature = "server", feature = "tempo"))]
pub struct ChannelStoreAdapter {
    store: std::sync::Arc<dyn Store>,
    prefix: String,
}

#[cfg(all(feature = "server", feature = "tempo"))]
impl ChannelStoreAdapter {
    /// Wraps `store`, namespacing every channel key with `prefix`.
    pub fn new(store: std::sync::Arc<dyn Store>, prefix: impl Into<String>) -> Self {
        let prefix = prefix.into();
        ChannelStoreAdapter { store, prefix }
    }

    /// Backend key for `channel_id`: the prefix immediately followed by the id.
    fn channel_key(&self, channel_id: &str) -> String {
        let mut key = String::with_capacity(self.prefix.len() + channel_id.len());
        key.push_str(&self.prefix);
        key.push_str(channel_id);
        key
    }
}
218
#[cfg(all(feature = "server", feature = "tempo"))]
impl crate::protocol::methods::tempo::session_method::ChannelStore for ChannelStoreAdapter {
    /// Looks up the channel stored under `prefix + channel_id`.
    ///
    /// Backend failures and stored JSON that cannot be decoded are both reported
    /// as a `VerificationError`.
    fn get_channel(
        &self,
        channel_id: &str,
    ) -> Pin<
        Box<
            dyn Future<
                    Output = Result<
                        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                        crate::protocol::traits::VerificationError,
                    >,
                > + Send
                + '_,
        >,
    > {
        let key = self.channel_key(channel_id);
        Box::pin(async move { self.read_channel_state(&key).await })
    }

    /// Passes the channel's current state to `updater` and persists what it returns.
    ///
    /// `updater` receives the stored state, or `None` if there is none.
    /// - `Some(state)` is written back to the store.
    /// - `None` deletes the entry.
    /// - An `Err` is returned immediately, without touching the store.
    ///
    /// On success the updater's result is returned to the caller.
    ///
    /// NOTE(review): the read, the `updater` call, and the write are separate
    /// `Store` operations with no lock held between them. Concurrent updates to the
    /// same channel can therefore interleave, and the later write wins.
    fn update_channel(
        &self,
        channel_id: &str,
        updater: Box<
            dyn FnOnce(
                    Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                ) -> Result<
                    Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                    crate::protocol::traits::VerificationError,
                > + Send,
        >,
    ) -> Pin<
        Box<
            dyn Future<
                    Output = Result<
                        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                        crate::protocol::traits::VerificationError,
                    >,
                > + Send
                + '_,
        >,
    > {
        let key = self.channel_key(channel_id);
        Box::pin(async move {
            let current = self.read_channel_state(&key).await?;
            let next = updater(current)?;
            self.write_channel_state(&key, next.as_ref()).await?;
            Ok(next)
        })
    }
}

#[cfg(all(feature = "server", feature = "tempo"))]
impl ChannelStoreAdapter {
    /// Reads the backend entry at `key` and decodes it as a channel state.
    async fn read_channel_state(
        &self,
        key: &str,
    ) -> Result<
        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
        crate::protocol::traits::VerificationError,
    > {
        use crate::protocol::traits::VerificationError;

        let stored = self
            .store
            .get(key)
            .await
            .map_err(|e| VerificationError::new(e.to_string()))?;
        match stored {
            None => Ok(None),
            Some(raw) => serde_json::from_value(raw).map(Some).map_err(|e| {
                VerificationError::new(format!("Failed to deserialize channel state: {}", e))
            }),
        }
    }

    /// Writes `state` to the backend entry at `key`, or deletes the entry when `state` is `None`.
    async fn write_channel_state(
        &self,
        key: &str,
        state: Option<&crate::protocol::methods::tempo::session_method::ChannelState>,
    ) -> Result<(), crate::protocol::traits::VerificationError> {
        use crate::protocol::traits::VerificationError;

        let outcome = match state {
            Some(state) => {
                let encoded = serde_json::to_value(state).map_err(|e| {
                    VerificationError::new(format!("Failed to serialize channel state: {}", e))
                })?;
                self.store.put(key, encoded).await
            }
            None => self.store.delete(key).await,
        };
        outcome.map_err(|e| VerificationError::new(e.to_string()))
    }
}
323
324// ==================== Tests ====================
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory under the system temp dir, deleted on drop.
    ///
    /// Because cleanup happens on drop, the directory is removed even when an
    /// assertion fails partway through a test.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        /// Reserves `<temp_dir>/<name>_<pid>` after clearing any leftovers.
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{}", name, std::process::id()));
            let _ = std::fs::remove_dir_all(&path);
            ScratchDir(path)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test's own result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn memory_store_get_put_delete() {
        let store = MemoryStore::new();

        // Missing key returns None
        assert!(store.get("missing").await.unwrap().is_none());

        // Put and get
        let value = serde_json::json!({"name": "alice", "balance": 100});
        store.put("user:1", value.clone()).await.unwrap();
        assert_eq!(store.get("user:1").await.unwrap(), Some(value));

        // Delete
        store.delete("user:1").await.unwrap();
        assert!(store.get("user:1").await.unwrap().is_none());

        // Delete missing key is a no-op
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn memory_store_overwrite() {
        let store = MemoryStore::new();
        store.put("k", serde_json::json!("first")).await.unwrap();
        store.put("k", serde_json::json!("second")).await.unwrap();
        assert_eq!(
            store.get("k").await.unwrap(),
            Some(serde_json::json!("second"))
        );
    }

    #[tokio::test]
    async fn file_store_get_put_delete() {
        let tmp = ScratchDir::new("mpp_file_store_test");
        let store = FileStore::new(tmp.path()).unwrap();

        // Missing key returns None
        assert!(store.get("missing").await.unwrap().is_none());

        // Put and get
        let value = serde_json::json!({"name": "bob", "items": [1, 2, 3]});
        store.put("data:1", value.clone()).await.unwrap();
        assert_eq!(store.get("data:1").await.unwrap(), Some(value));

        // Delete
        store.delete("data:1").await.unwrap();
        assert!(store.get("data:1").await.unwrap().is_none());

        // Delete missing key is a no-op
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn file_store_overwrite() {
        let tmp = ScratchDir::new("mpp_file_store_overwrite_test");
        let store = FileStore::new(tmp.path()).unwrap();

        store.put("k", serde_json::json!("first")).await.unwrap();
        store.put("k", serde_json::json!("second")).await.unwrap();
        assert_eq!(
            store.get("k").await.unwrap(),
            Some(serde_json::json!("second"))
        );
    }
}
405
#[cfg(all(test, feature = "server", feature = "tempo"))]
mod adapter_tests {
    use super::*;
    use crate::protocol::methods::tempo::session_method::{ChannelState, ChannelStore};
    use alloy::primitives::Address;
    use std::sync::Arc;

    /// Fixture: an all-zero-address channel with a deposit of 1000 and nothing spent.
    fn test_channel_state(channel_id: &str) -> ChannelState {
        ChannelState {
            channel_id: channel_id.to_string(),
            chain_id: 42431,
            escrow_contract: Address::ZERO,
            payer: Address::ZERO,
            payee: Address::ZERO,
            token: Address::ZERO,
            authorized_signer: Address::ZERO,
            deposit: 1000,
            settled_on_chain: 0,
            highest_voucher_amount: 0,
            highest_voucher_signature: None,
            spent: 0,
            units: 0,
            finalized: false,
            close_requested_at: 0,
            created_at: "2025-01-01T00:00:00Z".to_string(),
        }
    }

    #[tokio::test]
    async fn channel_store_adapter_get_and_update() {
        let backend: Arc<dyn Store> = Arc::new(MemoryStore::new());
        let adapter = ChannelStoreAdapter::new(backend, "channels:");

        // Nothing is stored yet.
        let missing = adapter.get_channel("ch1").await.unwrap();
        assert!(missing.is_none());

        // Insert through the updater.
        let initial = test_channel_state("ch1");
        let inserted = adapter
            .update_channel("ch1", Box::new(move |_current| Ok(Some(initial))))
            .await
            .unwrap()
            .expect("insert should return the new state");
        assert_eq!(inserted.channel_id, "ch1");

        // The inserted state can be read back.
        let stored = adapter.get_channel("ch1").await.unwrap().unwrap();
        assert_eq!(stored.channel_id, "ch1");
        assert_eq!(stored.deposit, 1000);

        // Modify the existing state.
        let bumped = adapter
            .update_channel(
                "ch1",
                Box::new(|current| {
                    let mut state = current.expect("channel should exist");
                    state.spent = 500;
                    state.units = 10;
                    Ok(Some(state))
                }),
            )
            .await
            .unwrap()
            .expect("update should return the new state");
        assert_eq!(bumped.spent, 500);
        assert_eq!(bumped.units, 10);

        // Returning None from the updater deletes the channel.
        let removed = adapter
            .update_channel("ch1", Box::new(|_| Ok(None)))
            .await
            .unwrap();
        assert!(removed.is_none());
        assert!(adapter.get_channel("ch1").await.unwrap().is_none());
    }
}