Skip to main content

mpp_br/
store.rs

1//! Pluggable key-value store abstraction.
2//!
3//! Modeled after the TypeScript SDK's Store interface (Cloudflare KV API style).
4//! Implementations handle serialization internally.
5
6use std::future::Future;
7use std::pin::Pin;
8
9/// Async key-value store interface.
10///
11/// Simple `get`/`put`/`delete` API compatible with various backends:
12/// - In-memory (for development/testing)
13/// - File-system (for simple persistence)
14/// - Redis, SQLite, etc. (for production)
15pub trait Store: Send + Sync {
16    /// Get a value by key. Returns None if not found.
17    fn get(
18        &self,
19        key: &str,
20    ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>;
21
22    /// Put a value by key.
23    fn put(
24        &self,
25        key: &str,
26        value: serde_json::Value,
27    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>>;
28
29    /// Delete a value by key.
30    fn delete(
31        &self,
32        key: &str,
33    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>>;
34}
35
36#[derive(Debug, thiserror::Error)]
37pub enum StoreError {
38    #[error("Store error: {0}")]
39    Internal(String),
40    #[error("Serialization error: {0}")]
41    Serialization(String),
42}
43
44// ==================== MemoryStore ====================
45
/// Process-local store backed by a mutex-guarded `HashMap`.
///
/// Values are kept as JSON *text* rather than as `serde_json::Value`, so every
/// write and read goes through a serialize/parse round trip, the same way it
/// would with a persistent backend. Contents live only as long as the store.
#[derive(Default)]
pub struct MemoryStore {
    data: std::sync::Mutex<std::collections::HashMap<String, String>>,
}

impl MemoryStore {
    /// Create an empty store; equivalent to [`MemoryStore::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
64
65impl Store for MemoryStore {
66    fn get(
67        &self,
68        key: &str,
69    ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
70    {
71        let result = self.data.lock().unwrap().get(key).cloned();
72        Box::pin(async move {
73            match result {
74                Some(raw) => {
75                    let value = serde_json::from_str(&raw)
76                        .map_err(|e| StoreError::Serialization(e.to_string()))?;
77                    Ok(Some(value))
78                }
79                None => Ok(None),
80            }
81        })
82    }
83
84    fn put(
85        &self,
86        key: &str,
87        value: serde_json::Value,
88    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
89        let key = key.to_string();
90        let serialized =
91            serde_json::to_string(&value).map_err(|e| StoreError::Serialization(e.to_string()));
92        Box::pin(async move {
93            let serialized = serialized?;
94            self.data.lock().unwrap().insert(key, serialized);
95            Ok(())
96        })
97    }
98
99    fn delete(
100        &self,
101        key: &str,
102    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
103        self.data.lock().unwrap().remove(key);
104        Box::pin(async { Ok(()) })
105    }
106}
107
108// ==================== FileStore ====================
109
110/// File-system backed store. Each key is stored as a JSON file.
111///
112/// Useful for development and simple deployments where a database is overkill.
113pub struct FileStore {
114    dir: std::path::PathBuf,
115}
116
117impl FileStore {
118    /// Create a new FileStore that persists data in the given directory.
119    ///
120    /// Creates the directory if it does not exist.
121    pub fn new(dir: impl Into<std::path::PathBuf>) -> Result<Self, StoreError> {
122        let dir = dir.into();
123        std::fs::create_dir_all(&dir)
124            .map_err(|e| StoreError::Internal(format!("Failed to create store dir: {}", e)))?;
125        Ok(Self { dir })
126    }
127
128    fn key_path(&self, key: &str) -> std::path::PathBuf {
129        // Sanitize key: replace path separators and special chars
130        let safe_key: String = key
131            .chars()
132            .map(|c| {
133                if c.is_alphanumeric() || c == '-' || c == '_' {
134                    c
135                } else {
136                    '_'
137                }
138            })
139            .collect();
140        self.dir.join(format!("{}.json", safe_key))
141    }
142}
143
144impl Store for FileStore {
145    fn get(
146        &self,
147        key: &str,
148    ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
149    {
150        let path = self.key_path(key);
151        Box::pin(async move {
152            match std::fs::read_to_string(&path) {
153                Ok(raw) => {
154                    let value = serde_json::from_str(&raw)
155                        .map_err(|e| StoreError::Serialization(e.to_string()))?;
156                    Ok(Some(value))
157                }
158                Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
159                Err(e) => Err(StoreError::Internal(e.to_string())),
160            }
161        })
162    }
163
164    fn put(
165        &self,
166        key: &str,
167        value: serde_json::Value,
168    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
169        let path = self.key_path(key);
170        Box::pin(async move {
171            let serialized = serde_json::to_string_pretty(&value)
172                .map_err(|e| StoreError::Serialization(e.to_string()))?;
173            std::fs::write(&path, serialized).map_err(|e| StoreError::Internal(e.to_string()))?;
174            Ok(())
175        })
176    }
177
178    fn delete(
179        &self,
180        key: &str,
181    ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
182        let path = self.key_path(key);
183        Box::pin(async move {
184            match std::fs::remove_file(&path) {
185                Ok(()) => Ok(()),
186                Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
187                Err(e) => Err(StoreError::Internal(e.to_string())),
188            }
189        })
190    }
191}
192
193// ==================== ChannelStoreAdapter ====================
194
/// Async lock used to serialize `update_channel` calls for one channel key.
#[cfg(all(feature = "server", feature = "tempo"))]
type ChannelUpdateLock = std::sync::Arc<tokio::sync::Mutex<()>>;

/// [`ChannelStore`](crate::protocol::methods::tempo::session_method::ChannelStore)
/// implementation layered on any generic [`Store`] backend.
///
/// `Store` only offers independent `get` / `put` / `delete` calls, whereas
/// `ChannelStore`'s `update_channel` promises an atomic read-modify-write. The
/// adapter bridges that gap by holding a per-channel async mutex around each
/// update, so updates to the same channel made through one adapter within one
/// Rust process run one at a time.
///
/// These locks are purely in-process and per adapter instance. They do **not**
/// coordinate several processes (or several adapters) sharing one backend. For
/// that, use a store implementation with native compare-and-swap or
/// transactional update semantics.
#[cfg(all(feature = "server", feature = "tempo"))]
pub struct ChannelStoreAdapter {
    /// Backend holding the serialized channel states.
    store: std::sync::Arc<dyn Store>,
    /// Prepended to each channel id to form its store key.
    prefix: String,
    /// Update locks keyed by full store key, created on first update.
    channel_locks: std::sync::Mutex<std::collections::HashMap<String, ChannelUpdateLock>>,
}
211
#[cfg(all(feature = "server", feature = "tempo"))]
impl ChannelStoreAdapter {
    /// Create a new adapter with the given store and key prefix.
    ///
    /// Channel `id` is stored under the key `"{prefix}{id}"`.
    pub fn new(store: std::sync::Arc<dyn Store>, prefix: impl Into<String>) -> Self {
        Self {
            store,
            prefix: prefix.into(),
            channel_locks: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Store key for `channel_id`: the configured prefix followed by the id.
    fn channel_key(&self, channel_id: &str) -> String {
        format!("{}{}", self.prefix, channel_id)
    }

    /// Return the update lock registered for `key`, registering a new one if
    /// none exists.
    ///
    /// A poisoned registry mutex is recovered instead of panicking. The map only
    /// sees single `entry` / `get` / `remove` calls, so a panic elsewhere cannot
    /// leave it half-updated.
    fn channel_lock(&self, key: &str) -> std::sync::Arc<tokio::sync::Mutex<()>> {
        self.channel_locks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .entry(key.to_string())
            .or_insert_with(|| std::sync::Arc::new(tokio::sync::Mutex::new(())))
            .clone()
    }

    /// Unregister `lock` from `key` if no other `update_channel` call holds it.
    ///
    /// Call this while still holding `lock`'s guard.
    ///
    /// New clones of a registered lock are only handed out by `channel_lock`,
    /// under the same registry mutex taken here, so the strong count cannot rise
    /// while it is inspected. A count of exactly 2 (the registry's copy plus the
    /// caller's) means nobody else is queued on the lock.
    ///
    /// With a higher count, another caller may be waiting on this lock.
    /// Removing the entry then would let a later caller register a *fresh* lock
    /// and run its update concurrently with that waiter.
    fn release_channel_lock(&self, key: &str, lock: &std::sync::Arc<tokio::sync::Mutex<()>>) {
        let mut locks = self
            .channel_locks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let idle = locks.get(key).is_some_and(|registered| {
            std::sync::Arc::ptr_eq(registered, lock) && std::sync::Arc::strong_count(lock) == 2
        });
        if idle {
            locks.remove(key);
        }
    }

    /// Load and deserialize the channel state stored under `key`; `Ok(None)`
    /// when nothing is stored.
    async fn read_state(
        &self,
        key: &str,
    ) -> Result<
        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
        crate::protocol::traits::VerificationError,
    > {
        let stored = self
            .store
            .get(key)
            .await
            .map_err(|e| crate::protocol::traits::VerificationError::new(e.to_string()))?;
        let Some(value) = stored else {
            return Ok(None);
        };
        serde_json::from_value(value).map(Some).map_err(|e| {
            crate::protocol::traits::VerificationError::new(format!(
                "Failed to deserialize channel state: {}",
                e
            ))
        })
    }

    /// Persist `state` under `key`, or delete `key` when `state` is `None`.
    async fn write_state(
        &self,
        key: &str,
        state: Option<&crate::protocol::methods::tempo::session_method::ChannelState>,
    ) -> Result<(), crate::protocol::traits::VerificationError> {
        let stored = match state {
            Some(state) => {
                let value = serde_json::to_value(state).map_err(|e| {
                    crate::protocol::traits::VerificationError::new(format!(
                        "Failed to serialize channel state: {}",
                        e
                    ))
                })?;
                self.store.put(key, value).await
            }
            None => self.store.delete(key).await,
        };
        stored.map_err(|e| crate::protocol::traits::VerificationError::new(e.to_string()))
    }
}

#[cfg(all(feature = "server", feature = "tempo"))]
impl crate::protocol::methods::tempo::session_method::ChannelStore for ChannelStoreAdapter {
    /// Load the channel's current state, or `None` if it is not stored.
    fn get_channel(
        &self,
        channel_id: &str,
    ) -> Pin<
        Box<
            dyn Future<
                    Output = Result<
                        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                        crate::protocol::traits::VerificationError,
                    >,
                > + Send
                + '_,
        >,
    > {
        let key = self.channel_key(channel_id);
        Box::pin(async move { self.read_state(&key).await })
    }

    /// Apply `updater` to the channel as one read-modify-write step.
    ///
    /// The step is atomic with respect to other updates through this adapter.
    /// While holding the channel's update lock, this:
    /// 1. reads the current state,
    /// 2. passes it to `updater`,
    /// 3. persists the returned state, or deletes the channel when `updater`
    ///    returns `Ok(None)`.
    ///
    /// An `Err` from `updater` leaves the store untouched and is returned as-is.
    fn update_channel(
        &self,
        channel_id: &str,
        updater: Box<
            dyn FnOnce(
                    Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                ) -> Result<
                    Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                    crate::protocol::traits::VerificationError,
                > + Send,
        >,
    ) -> Pin<
        Box<
            dyn Future<
                    Output = Result<
                        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                        crate::protocol::traits::VerificationError,
                    >,
                > + Send
                + '_,
        >,
    > {
        let key = self.channel_key(channel_id);
        let channel_lock = self.channel_lock(&key);
        Box::pin(async move {
            let _guard = channel_lock.lock().await;

            // Stays false if the current state could not be read.
            let mut existed = false;
            let outcome = match self.read_state(&key).await {
                Ok(current) => {
                    existed = current.is_some();
                    match updater(current) {
                        Ok(next) => {
                            let written = self.write_state(&key, next.as_ref()).await;
                            written.map(|()| next)
                        }
                        Err(e) => Err(e),
                    }
                }
                Err(e) => Err(e),
            };

            // Drop the registry entry when the channel is not (known to be)
            // stored: it was deleted, never created, or the update failed
            // without the channel having existed. Otherwise, updates rejected
            // for unknown channel ids would grow the map without bound. This
            // is safe even if the guess is wrong: `release_channel_lock` only
            // removes an entry nobody else is queued on, and we still hold
            // `_guard` here.
            let absent = match &outcome {
                Ok(next) => next.is_none(),
                Err(_) => !existed,
            };
            if absent {
                self.release_channel_lock(&key, &channel_lock);
            }

            outcome
        })
    }
}
347
348// ==================== Tests ====================
349
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory under the system temp dir.
    ///
    /// It is wiped on creation and again on drop. Cleanup also runs when a
    /// failing assertion unwinds, so failed runs do not leave directories
    /// behind.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(label: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{}", label, std::process::id()));
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn memory_store_get_put_delete() {
        let store = MemoryStore::new();

        // Missing key returns None
        assert!(store.get("missing").await.unwrap().is_none());

        // Put and get
        let value = serde_json::json!({"name": "alice", "balance": 100});
        store.put("user:1", value.clone()).await.unwrap();
        assert_eq!(store.get("user:1").await.unwrap(), Some(value));

        // Delete
        store.delete("user:1").await.unwrap();
        assert!(store.get("user:1").await.unwrap().is_none());

        // Delete missing key is a no-op
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn memory_store_overwrite() {
        let store = MemoryStore::new();
        store.put("k", serde_json::json!("first")).await.unwrap();
        store.put("k", serde_json::json!("second")).await.unwrap();
        assert_eq!(
            store.get("k").await.unwrap(),
            Some(serde_json::json!("second"))
        );
    }

    #[tokio::test]
    async fn file_store_get_put_delete() {
        let scratch = ScratchDir::new("mpp_file_store_test");
        let store = FileStore::new(&scratch.0).unwrap();

        // Missing key returns None
        assert!(store.get("missing").await.unwrap().is_none());

        // Put and get
        let value = serde_json::json!({"name": "bob", "items": [1, 2, 3]});
        store.put("data:1", value.clone()).await.unwrap();
        assert_eq!(store.get("data:1").await.unwrap(), Some(value));

        // Delete
        store.delete("data:1").await.unwrap();
        assert!(store.get("data:1").await.unwrap().is_none());

        // Delete missing key is a no-op
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn file_store_overwrite() {
        let scratch = ScratchDir::new("mpp_file_store_overwrite_test");
        let store = FileStore::new(&scratch.0).unwrap();

        store.put("k", serde_json::json!("first")).await.unwrap();
        store.put("k", serde_json::json!("second")).await.unwrap();
        assert_eq!(
            store.get("k").await.unwrap(),
            Some(serde_json::json!("second"))
        );
    }

    #[tokio::test]
    async fn file_store_keeps_hostile_keys_inside_dir() {
        let scratch = ScratchDir::new("mpp_file_store_key_sanitize_test");
        let store = FileStore::new(&scratch.0).unwrap();

        // Path separators and `..` in a key must not address files outside
        // the store directory.
        let value = serde_json::json!({"escaped": false});
        store.put("../outside/evil", value.clone()).await.unwrap();
        assert_eq!(store.get("../outside/evil").await.unwrap(), Some(value));

        // Exactly one regular file was written, directly inside the directory.
        let entries: Vec<_> = std::fs::read_dir(&scratch.0)
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .collect();
        assert_eq!(entries.len(), 1, "unexpected files: {entries:?}");
        assert!(entries[0].is_file());
    }
}
429
#[cfg(all(test, feature = "server", feature = "tempo"))]
mod adapter_tests {
    use super::*;
    use crate::protocol::methods::tempo::session_method::deduct_from_channel;
    use crate::protocol::methods::tempo::session_method::{ChannelState, ChannelStore};
    use alloy::primitives::Address;
    use std::sync::Arc;
    use std::time::Duration;

    /// `MemoryStore` wrapper that sleeps before every operation.
    ///
    /// The delay widens the window in which unsynchronized read-modify-write
    /// cycles would interleave.
    struct SlowMemoryStore {
        inner: MemoryStore,
        delay: Duration,
    }

    impl SlowMemoryStore {
        fn new(delay: Duration) -> Self {
            let inner = MemoryStore::new();
            Self { inner, delay }
        }
    }

    impl Store for SlowMemoryStore {
        fn get(
            &self,
            key: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
        {
            let (owned_key, pause) = (key.to_owned(), self.delay);
            Box::pin(async move {
                tokio::time::sleep(pause).await;
                self.inner.get(&owned_key).await
            })
        }

        fn put(
            &self,
            key: &str,
            value: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
            let (owned_key, pause) = (key.to_owned(), self.delay);
            Box::pin(async move {
                tokio::time::sleep(pause).await;
                self.inner.put(&owned_key, value).await
            })
        }

        fn delete(
            &self,
            key: &str,
        ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
            let (owned_key, pause) = (key.to_owned(), self.delay);
            Box::pin(async move {
                tokio::time::sleep(pause).await;
                self.inner.delete(&owned_key).await
            })
        }
    }

    /// Open, unsettled channel `channel_id` with a deposit of 1000 and
    /// all-zero addresses.
    fn test_channel_state(channel_id: &str) -> ChannelState {
        ChannelState {
            channel_id: channel_id.to_owned(),
            chain_id: 42431,
            escrow_contract: Address::ZERO,
            payer: Address::ZERO,
            payee: Address::ZERO,
            token: Address::ZERO,
            authorized_signer: Address::ZERO,
            deposit: 1000,
            settled_on_chain: 0,
            highest_voucher_amount: 0,
            highest_voucher_signature: None,
            spent: 0,
            units: 0,
            finalized: false,
            close_requested_at: 0,
            created_at: "2025-01-01T00:00:00Z".to_string(),
        }
    }

    #[tokio::test]
    async fn channel_store_adapter_get_and_update() {
        let adapter = ChannelStoreAdapter::new(Arc::new(MemoryStore::new()), "channels:");

        // Nothing stored yet.
        let missing = adapter.get_channel("ch1").await.unwrap();
        assert!(missing.is_none());

        // The first update inserts the channel.
        let initial = test_channel_state("ch1");
        let inserted = adapter
            .update_channel("ch1", Box::new(move |_current| Ok(Some(initial))))
            .await
            .unwrap()
            .expect("insert should return the new state");
        assert_eq!(inserted.channel_id, "ch1");

        // It can be read back.
        let fetched = adapter.get_channel("ch1").await.unwrap().unwrap();
        assert_eq!(fetched.channel_id, "ch1");
        assert_eq!(fetched.deposit, 1000);

        // A second update sees the stored state and modifies it.
        let updated = adapter
            .update_channel(
                "ch1",
                Box::new(|current| {
                    let mut state = current.unwrap();
                    state.spent = 500;
                    state.units = 10;
                    Ok(Some(state))
                }),
            )
            .await
            .unwrap()
            .unwrap();
        assert_eq!(updated.spent, 500);
        assert_eq!(updated.units, 10);

        // Returning `None` from the updater deletes the channel.
        let deleted = adapter
            .update_channel("ch1", Box::new(|_| Ok(None)))
            .await
            .unwrap();
        assert!(deleted.is_none());
        assert!(adapter.get_channel("ch1").await.unwrap().is_none());
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn channel_store_adapter_same_channel_deduction_race() {
        let slow = Arc::new(SlowMemoryStore::new(Duration::from_millis(25)));
        let adapter = Arc::new(ChannelStoreAdapter::new(slow, "channels:"));

        let mut seeded = test_channel_state("ch1");
        seeded.highest_voucher_amount = 10_000;
        adapter
            .update_channel("ch1", Box::new(move |_| Ok(Some(seeded))))
            .await
            .unwrap();

        // Two deduction tasks plus this one must all arrive before either
        // deduction starts.
        let gate = Arc::new(tokio::sync::Barrier::new(3));
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let adapter = Arc::clone(&adapter);
                let gate = Arc::clone(&gate);
                tokio::spawn(async move {
                    gate.wait().await;
                    deduct_from_channel(&*adapter, "ch1", 7_000).await
                })
            })
            .collect();

        // Original repro path: two concurrent deductions against the same
        // adapter-backed channel. Before the fix, both calls could report
        // success even though only one update was persisted.
        gate.wait().await;

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        let successes = results.iter().filter(|outcome| outcome.is_ok()).count();
        assert_eq!(
            successes, 1,
            "the repro must not allow both concurrent deductions to succeed"
        );

        let error = results.into_iter().find_map(Result::err).unwrap();
        assert!(
            error.to_string().contains("available 3000"),
            "expected insufficient balance after the first deduction, got: {error}"
        );

        let stored = adapter.get_channel("ch1").await.unwrap().unwrap();
        assert_eq!(stored.spent, 7_000);
        assert_eq!(stored.units, 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn channel_store_adapter_serializes_same_channel_update_channel_calls() {
        let slow = Arc::new(SlowMemoryStore::new(Duration::from_millis(25)));
        let adapter = Arc::new(ChannelStoreAdapter::new(slow, "channels:"));

        let seeded = test_channel_state("ch1");
        adapter
            .update_channel("ch1", Box::new(move |_| Ok(Some(seeded))))
            .await
            .unwrap();

        let gate = Arc::new(tokio::sync::Barrier::new(3));
        let handles: Vec<_> = (0..2)
            .map(|_| {
                let adapter = Arc::clone(&adapter);
                let gate = Arc::clone(&gate);
                tokio::spawn(async move {
                    gate.wait().await;
                    adapter
                        .update_channel(
                            "ch1",
                            Box::new(|current| {
                                let mut state = current.unwrap();
                                state.spent += 1;
                                state.units += 1;
                                Ok(Some(state))
                            }),
                        )
                        .await
                })
            })
            .collect();

        gate.wait().await;

        // Serialized increments must observe each other: one sees 1, the other 2.
        let mut returned_spent = Vec::with_capacity(handles.len());
        for handle in handles {
            let state = handle.await.unwrap().unwrap().unwrap();
            returned_spent.push(state.spent);
        }
        returned_spent.sort_unstable();
        assert_eq!(returned_spent, [1, 2]);

        let stored = adapter.get_channel("ch1").await.unwrap().unwrap();
        assert_eq!(stored.spent, 2);
        assert_eq!(stored.units, 2);
    }

    #[tokio::test]
    async fn channel_store_adapter_evicts_lock_on_channel_delete() {
        let adapter = ChannelStoreAdapter::new(Arc::new(MemoryStore::new()), "channels:");

        // Each created channel registers one update lock.
        for id in ["ch1", "ch2", "ch3"] {
            let state = test_channel_state(id);
            adapter
                .update_channel(id, Box::new(move |_| Ok(Some(state))))
                .await
                .unwrap();
        }
        assert_eq!(adapter.channel_locks.lock().unwrap().len(), 3);

        // Deleting ch2 (updater returns None) must drop its lock entry only.
        adapter
            .update_channel("ch2", Box::new(|_| Ok(None)))
            .await
            .unwrap();

        let locks = adapter.channel_locks.lock().unwrap();
        let mut remaining: Vec<&str> = locks.keys().map(String::as_str).collect();
        remaining.sort_unstable();
        assert_eq!(remaining, ["channels:ch1", "channels:ch3"]);
    }
}
705}