1use std::future::Future;
7use std::pin::Pin;
8
9pub trait Store: Send + Sync {
16 fn get(
18 &self,
19 key: &str,
20 ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>;
21
22 fn put(
24 &self,
25 key: &str,
26 value: serde_json::Value,
27 ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>>;
28
29 fn delete(
31 &self,
32 key: &str,
33 ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>>;
34}
35
36#[derive(Debug, thiserror::Error)]
37pub enum StoreError {
38 #[error("Store error: {0}")]
39 Internal(String),
40 #[error("Serialization error: {0}")]
41 Serialization(String),
42}
43
/// An in-process [`Store`] backed by a mutex-guarded `HashMap`.
///
/// Values are kept as serialized JSON text, so every read decodes an independent copy.
/// Contents live only as long as the store itself.
#[derive(Debug, Default)]
pub struct MemoryStore {
    /// Key -> JSON-encoded value.
    data: std::sync::Mutex<std::collections::HashMap<String, String>>,
}

impl MemoryStore {
    /// Creates an empty store; equivalent to [`MemoryStore::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
64
65impl Store for MemoryStore {
66 fn get(
67 &self,
68 key: &str,
69 ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
70 {
71 let result = self.data.lock().unwrap().get(key).cloned();
72 Box::pin(async move {
73 match result {
74 Some(raw) => {
75 let value = serde_json::from_str(&raw)
76 .map_err(|e| StoreError::Serialization(e.to_string()))?;
77 Ok(Some(value))
78 }
79 None => Ok(None),
80 }
81 })
82 }
83
84 fn put(
85 &self,
86 key: &str,
87 value: serde_json::Value,
88 ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
89 let key = key.to_string();
90 let serialized =
91 serde_json::to_string(&value).map_err(|e| StoreError::Serialization(e.to_string()));
92 Box::pin(async move {
93 let serialized = serialized?;
94 self.data.lock().unwrap().insert(key, serialized);
95 Ok(())
96 })
97 }
98
99 fn delete(
100 &self,
101 key: &str,
102 ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
103 self.data.lock().unwrap().remove(key);
104 Box::pin(async { Ok(()) })
105 }
106}
107
108pub struct FileStore {
114 dir: std::path::PathBuf,
115}
116
117impl FileStore {
118 pub fn new(dir: impl Into<std::path::PathBuf>) -> Result<Self, StoreError> {
122 let dir = dir.into();
123 std::fs::create_dir_all(&dir)
124 .map_err(|e| StoreError::Internal(format!("Failed to create store dir: {}", e)))?;
125 Ok(Self { dir })
126 }
127
128 fn key_path(&self, key: &str) -> std::path::PathBuf {
129 let safe_key: String = key
131 .chars()
132 .map(|c| {
133 if c.is_alphanumeric() || c == '-' || c == '_' {
134 c
135 } else {
136 '_'
137 }
138 })
139 .collect();
140 self.dir.join(format!("{}.json", safe_key))
141 }
142}
143
144impl Store for FileStore {
145 fn get(
146 &self,
147 key: &str,
148 ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
149 {
150 let path = self.key_path(key);
151 Box::pin(async move {
152 match std::fs::read_to_string(&path) {
153 Ok(raw) => {
154 let value = serde_json::from_str(&raw)
155 .map_err(|e| StoreError::Serialization(e.to_string()))?;
156 Ok(Some(value))
157 }
158 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
159 Err(e) => Err(StoreError::Internal(e.to_string())),
160 }
161 })
162 }
163
164 fn put(
165 &self,
166 key: &str,
167 value: serde_json::Value,
168 ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
169 let path = self.key_path(key);
170 Box::pin(async move {
171 let serialized = serde_json::to_string_pretty(&value)
172 .map_err(|e| StoreError::Serialization(e.to_string()))?;
173 std::fs::write(&path, serialized).map_err(|e| StoreError::Internal(e.to_string()))?;
174 Ok(())
175 })
176 }
177
178 fn delete(
179 &self,
180 key: &str,
181 ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
182 let path = self.key_path(key);
183 Box::pin(async move {
184 match std::fs::remove_file(&path) {
185 Ok(()) => Ok(()),
186 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
187 Err(e) => Err(StoreError::Internal(e.to_string())),
188 }
189 })
190 }
191}
192
/// Adapts a generic [`Store`] to the tempo session `ChannelStore` interface.
///
/// Each channel state is persisted as JSON under the key `prefix + channel_id`. Updates to
/// the same key are serialized through a per-key async mutex kept in `channel_locks`.
#[cfg(all(feature = "server", feature = "tempo"))]
pub struct ChannelStoreAdapter {
    /// Backing key/value store (shared through an `Arc`).
    store: std::sync::Arc<dyn Store>,
    /// Prepended to every channel id to build its storage key.
    prefix: String,
    /// Storage key -> async mutex serializing read-modify-write cycles on that key.
    /// The outer `std::sync::Mutex` only guards map lookups and insertions.
    channel_locks: std::sync::Mutex<
        std::collections::HashMap<String, std::sync::Arc<tokio::sync::Mutex<()>>>,
    >,
}
211
#[cfg(all(feature = "server", feature = "tempo"))]
impl ChannelStoreAdapter {
    /// Wraps `store`, namespacing every channel key under `prefix` (e.g. `"channels:"`).
    pub fn new(store: std::sync::Arc<dyn Store>, prefix: impl Into<String>) -> Self {
        Self {
            store,
            prefix: prefix.into(),
            channel_locks: std::sync::Mutex::new(std::collections::HashMap::new()),
        }
    }

    /// Builds the storage key for `channel_id` by prepending the configured prefix.
    fn channel_key(&self, channel_id: &str) -> String {
        format!("{}{}", self.prefix, channel_id)
    }

    /// Returns the async mutex guarding storage key `key`, creating it on first use.
    ///
    /// The map's `std::sync::Mutex` is held only for this lookup/insert, never across an
    /// `.await`. A poisoned map lock is recovered instead of panicking: the critical section
    /// is a single entry lookup/insert, so the map cannot be left half-updated.
    fn channel_lock(&self, key: &str) -> std::sync::Arc<tokio::sync::Mutex<()>> {
        let mut locks = self
            .channel_locks
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        std::sync::Arc::clone(
            locks
                .entry(key.to_string())
                .or_insert_with(|| std::sync::Arc::new(tokio::sync::Mutex::new(()))),
        )
    }
}
236
#[cfg(all(feature = "server", feature = "tempo"))]
impl crate::protocol::methods::tempo::session_method::ChannelStore for ChannelStoreAdapter {
    /// Loads and decodes the channel stored under `prefix + channel_id`, if any.
    ///
    /// This read does not take the per-channel lock, so it is not serialized with concurrent
    /// `update_channel` calls on the same channel.
    fn get_channel(
        &self,
        channel_id: &str,
    ) -> Pin<
        Box<
            dyn Future<
                    Output = Result<
                        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                        crate::protocol::traits::VerificationError,
                    >,
                > + Send
                + '_,
        >,
    > {
        let key = self.channel_key(channel_id);
        Box::pin(async move {
            let value = self
                .store
                .get(&key)
                .await
                .map_err(|e| crate::protocol::traits::VerificationError::new(e.to_string()))?;
            match value {
                Some(v) => {
                    let state = serde_json::from_value(v).map_err(|e| {
                        crate::protocol::traits::VerificationError::new(format!(
                            "Failed to deserialize channel state: {}",
                            e
                        ))
                    })?;
                    Ok(Some(state))
                }
                None => Ok(None),
            }
        })
    }

    /// Atomically (with respect to other `update_channel` calls on this adapter) applies
    /// `updater` to the channel's current state.
    ///
    /// `Some(state)` from the updater is written back; `None` deletes the channel. When the
    /// channel is deleted, the per-channel lock entry is evicted from `channel_locks` only
    /// if no other task holds a clone of it. Otherwise a task already parked on the old
    /// mutex could run concurrently with a newcomer that inserted a fresh one.
    ///
    /// Returns the updater's result. An updater error aborts without touching the store.
    fn update_channel(
        &self,
        channel_id: &str,
        updater: Box<
            dyn FnOnce(
                    Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                ) -> Result<
                    Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                    crate::protocol::traits::VerificationError,
                > + Send,
        >,
    ) -> Pin<
        Box<
            dyn Future<
                    Output = Result<
                        Option<crate::protocol::methods::tempo::session_method::ChannelState>,
                        crate::protocol::traits::VerificationError,
                    >,
                > + Send
                + '_,
        >,
    > {
        let key = self.channel_key(channel_id);
        let channel_lock = self.channel_lock(&key);
        Box::pin(async move {
            // Held for the whole read-modify-write cycle below.
            let _guard = channel_lock.lock().await;
            let current_value = self
                .store
                .get(&key)
                .await
                .map_err(|e| crate::protocol::traits::VerificationError::new(e.to_string()))?;
            let current_state: Option<
                crate::protocol::methods::tempo::session_method::ChannelState,
            > = match current_value {
                Some(v) => Some(serde_json::from_value(v).map_err(|e| {
                    crate::protocol::traits::VerificationError::new(format!(
                        "Failed to deserialize channel state: {}",
                        e
                    ))
                })?),
                None => None,
            };

            let result = updater(current_state)?;

            match &result {
                Some(state) => {
                    let value = serde_json::to_value(state).map_err(|e| {
                        crate::protocol::traits::VerificationError::new(format!(
                            "Failed to serialize channel state: {}",
                            e
                        ))
                    })?;
                    self.store.put(&key, value).await.map_err(|e| {
                        crate::protocol::traits::VerificationError::new(e.to_string())
                    })?;
                }
                None => {
                    self.store.delete(&key).await.map_err(|e| {
                        crate::protocol::traits::VerificationError::new(e.to_string())
                    })?;
                    // Every clone of a per-channel lock is taken under this map mutex (see
                    // `channel_lock`), so while we hold it the strong count cannot grow.
                    // A count of 2 (the map entry + our `channel_lock`) means no other task
                    // is waiting on this mutex and the entry can be dropped safely.
                    let mut locks = self
                        .channel_locks
                        .lock()
                        .unwrap_or_else(std::sync::PoisonError::into_inner);
                    let unshared = matches!(
                        locks.get(&key),
                        Some(entry) if std::sync::Arc::ptr_eq(entry, &channel_lock)
                    ) && std::sync::Arc::strong_count(&channel_lock) == 2;
                    if unshared {
                        locks.remove(&key);
                    }
                }
            }

            Ok(result)
        })
    }
}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory under the system temp dir, unique per test name and process.
    ///
    /// Leftovers from an earlier run are removed on creation. The directory is removed again
    /// on drop, including when an assertion fails and the test unwinds.
    struct TempDir(std::path::PathBuf);

    impl TempDir {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{}_{}", name, std::process::id()));
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best effort: a cleanup failure must not mask the test's own outcome.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[tokio::test]
    async fn memory_store_get_put_delete() {
        let store = MemoryStore::new();

        assert!(store.get("missing").await.unwrap().is_none());

        let value = serde_json::json!({"name": "alice", "balance": 100});
        store.put("user:1", value.clone()).await.unwrap();
        assert_eq!(store.get("user:1").await.unwrap(), Some(value));

        store.delete("user:1").await.unwrap();
        assert!(store.get("user:1").await.unwrap().is_none());

        // Deleting an absent key is not an error.
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn memory_store_overwrite() {
        let store = MemoryStore::new();
        store.put("k", serde_json::json!("first")).await.unwrap();
        store.put("k", serde_json::json!("second")).await.unwrap();
        assert_eq!(
            store.get("k").await.unwrap(),
            Some(serde_json::json!("second"))
        );
    }

    #[tokio::test]
    async fn file_store_get_put_delete() {
        let tmp = TempDir::new("mpp_file_store_test");
        let store = FileStore::new(tmp.path()).unwrap();

        assert!(store.get("missing").await.unwrap().is_none());

        let value = serde_json::json!({"name": "bob", "items": [1, 2, 3]});
        store.put("data:1", value.clone()).await.unwrap();
        assert_eq!(store.get("data:1").await.unwrap(), Some(value));

        store.delete("data:1").await.unwrap();
        assert!(store.get("data:1").await.unwrap().is_none());

        // Deleting an absent key is not an error.
        store.delete("nonexistent").await.unwrap();
    }

    #[tokio::test]
    async fn file_store_overwrite() {
        let tmp = TempDir::new("mpp_file_store_overwrite_test");
        let store = FileStore::new(tmp.path()).unwrap();

        store.put("k", serde_json::json!("first")).await.unwrap();
        store.put("k", serde_json::json!("second")).await.unwrap();
        assert_eq!(
            store.get("k").await.unwrap(),
            Some(serde_json::json!("second"))
        );
    }
}
429
/// Tests for `ChannelStoreAdapter`, focused on serialization of same-channel updates.
#[cfg(all(test, feature = "server", feature = "tempo"))]
mod adapter_tests {
    use super::*;
    use crate::protocol::methods::tempo::session_method::deduct_from_channel;
    use crate::protocol::methods::tempo::session_method::{ChannelState, ChannelStore};
    use alloy::primitives::Address;
    use std::sync::Arc;
    use std::time::Duration;

    /// A `MemoryStore` wrapper that sleeps for `delay` before every operation.
    ///
    /// The sleep widens the window between an update's read and its write, so an
    /// unsynchronized read-modify-write would reliably interleave in the race tests.
    struct SlowMemoryStore {
        inner: MemoryStore,
        delay: Duration,
    }

    impl SlowMemoryStore {
        fn new(delay: Duration) -> Self {
            Self {
                inner: MemoryStore::new(),
                delay,
            }
        }
    }

    impl Store for SlowMemoryStore {
        fn get(
            &self,
            key: &str,
        ) -> Pin<Box<dyn Future<Output = Result<Option<serde_json::Value>, StoreError>> + Send + '_>>
        {
            // Owned copy: the returned future may only borrow `self`.
            let key = key.to_string();
            let delay = self.delay;
            Box::pin(async move {
                tokio::time::sleep(delay).await;
                self.inner.get(&key).await
            })
        }

        fn put(
            &self,
            key: &str,
            value: serde_json::Value,
        ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
            let key = key.to_string();
            let delay = self.delay;
            Box::pin(async move {
                tokio::time::sleep(delay).await;
                self.inner.put(&key, value).await
            })
        }

        fn delete(
            &self,
            key: &str,
        ) -> Pin<Box<dyn Future<Output = Result<(), StoreError>> + Send + '_>> {
            let key = key.to_string();
            let delay = self.delay;
            Box::pin(async move {
                tokio::time::sleep(delay).await;
                self.inner.delete(&key).await
            })
        }
    }

    /// Builds a fresh channel with a 1000 deposit, nothing spent, and zero addresses.
    fn test_channel_state(channel_id: &str) -> ChannelState {
        ChannelState {
            channel_id: channel_id.to_string(),
            chain_id: 42431,
            escrow_contract: Address::ZERO,
            payer: Address::ZERO,
            payee: Address::ZERO,
            token: Address::ZERO,
            authorized_signer: Address::ZERO,
            deposit: 1000,
            settled_on_chain: 0,
            highest_voucher_amount: 0,
            highest_voucher_signature: None,
            spent: 0,
            units: 0,
            finalized: false,
            close_requested_at: 0,
            created_at: "2025-01-01T00:00:00Z".to_string(),
        }
    }

    /// Create, read, modify, and delete a channel through the adapter.
    #[tokio::test]
    async fn channel_store_adapter_get_and_update() {
        let store = Arc::new(MemoryStore::new());
        let adapter = ChannelStoreAdapter::new(store, "channels:");

        assert!(adapter.get_channel("ch1").await.unwrap().is_none());

        // Insert: the updater ignores the (absent) current state.
        let state = test_channel_state("ch1");
        let result = adapter
            .update_channel("ch1", Box::new(move |_current| Ok(Some(state))))
            .await
            .unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().channel_id, "ch1");

        let fetched = adapter.get_channel("ch1").await.unwrap().unwrap();
        assert_eq!(fetched.channel_id, "ch1");
        assert_eq!(fetched.deposit, 1000);

        // Modify: the updater receives the stored state.
        let result = adapter
            .update_channel(
                "ch1",
                Box::new(|current| {
                    let mut s = current.unwrap();
                    s.spent = 500;
                    s.units = 10;
                    Ok(Some(s))
                }),
            )
            .await
            .unwrap();
        let updated = result.unwrap();
        assert_eq!(updated.spent, 500);
        assert_eq!(updated.units, 10);

        // Delete: returning `None` removes the channel.
        let result = adapter
            .update_channel("ch1", Box::new(|_| Ok(None)))
            .await
            .unwrap();
        assert!(result.is_none());
        assert!(adapter.get_channel("ch1").await.unwrap().is_none());
    }

    /// Two concurrent 7_000 deductions against a 10_000 voucher: exactly one may succeed.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn channel_store_adapter_same_channel_deduction_race() {
        let store = Arc::new(SlowMemoryStore::new(Duration::from_millis(25)));
        let adapter = Arc::new(ChannelStoreAdapter::new(store, "channels:"));

        let mut state = test_channel_state("ch1");
        state.highest_voucher_amount = 10_000;
        adapter
            .update_channel("ch1", Box::new(move |_| Ok(Some(state))))
            .await
            .unwrap();

        // 3 parties: the two spawned tasks plus this task, so both deductions start together.
        let start = Arc::new(tokio::sync::Barrier::new(3));

        let adapter1 = adapter.clone();
        let start1 = start.clone();
        let task1 = tokio::spawn(async move {
            start1.wait().await;
            deduct_from_channel(&*adapter1, "ch1", 7_000).await
        });

        let adapter2 = adapter.clone();
        let start2 = start.clone();
        let task2 = tokio::spawn(async move {
            start2.wait().await;
            deduct_from_channel(&*adapter2, "ch1", 7_000).await
        });

        start.wait().await;

        let result1 = task1.await.unwrap();
        let result2 = task2.await.unwrap();
        let successes = [result1.is_ok(), result2.is_ok()]
            .into_iter()
            .filter(|ok| *ok)
            .count();
        assert_eq!(
            successes, 1,
            "the repro must not allow both concurrent deductions to succeed"
        );

        // The loser must have observed the winner's write. This presumes deduct_from_channel
        // reports voucher amount minus spent (10_000 - 7_000) as "available" — confirm there.
        let error = result1.err().or_else(|| result2.err()).unwrap();
        assert!(
            error.to_string().contains("available 3000"),
            "expected insufficient balance after the first deduction, got: {error}"
        );

        let stored = adapter.get_channel("ch1").await.unwrap().unwrap();
        assert_eq!(stored.spent, 7_000);
        assert_eq!(stored.units, 1);
    }

    /// Two concurrent increment updates on one channel must not lose either increment.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn channel_store_adapter_serializes_same_channel_update_channel_calls() {
        let store = Arc::new(SlowMemoryStore::new(Duration::from_millis(25)));
        let adapter = Arc::new(ChannelStoreAdapter::new(store, "channels:"));

        let state = test_channel_state("ch1");
        adapter
            .update_channel("ch1", Box::new(move |_| Ok(Some(state))))
            .await
            .unwrap();

        // Releases both updaters and this task at the same time.
        let start = Arc::new(tokio::sync::Barrier::new(3));

        let adapter1 = adapter.clone();
        let start1 = start.clone();
        let task1 = tokio::spawn(async move {
            start1.wait().await;
            adapter1
                .update_channel(
                    "ch1",
                    Box::new(|current| {
                        let mut state = current.unwrap();
                        state.spent += 1;
                        state.units += 1;
                        Ok(Some(state))
                    }),
                )
                .await
        });

        let adapter2 = adapter.clone();
        let start2 = start.clone();
        let task2 = tokio::spawn(async move {
            start2.wait().await;
            adapter2
                .update_channel(
                    "ch1",
                    Box::new(|current| {
                        let mut state = current.unwrap();
                        state.spent += 1;
                        state.units += 1;
                        Ok(Some(state))
                    }),
                )
                .await
        });

        start.wait().await;

        // Serialized updates return 1 and 2 (in either order); a lost update would give [1, 1].
        let result1 = task1.await.unwrap().unwrap().unwrap();
        let result2 = task2.await.unwrap().unwrap().unwrap();
        let mut returned_spent = [result1.spent, result2.spent];
        returned_spent.sort_unstable();
        assert_eq!(returned_spent, [1, 2]);

        let stored = adapter.get_channel("ch1").await.unwrap().unwrap();
        assert_eq!(stored.spent, 2);
        assert_eq!(stored.units, 2);
    }

    /// Deleting a channel (updater returns `None`) drops only that channel's lock entry.
    #[tokio::test]
    async fn channel_store_adapter_evicts_lock_on_channel_delete() {
        let store = Arc::new(MemoryStore::new());
        let adapter = ChannelStoreAdapter::new(store, "channels:");

        for id in ["ch1", "ch2", "ch3"] {
            let state = test_channel_state(id);
            adapter
                .update_channel(id, Box::new(move |_| Ok(Some(state))))
                .await
                .unwrap();
        }
        // One lock per channel, keyed by the prefixed storage key.
        assert_eq!(adapter.channel_locks.lock().unwrap().len(), 3);

        adapter
            .update_channel("ch2", Box::new(|_| Ok(None)))
            .await
            .unwrap();

        let locks = adapter.channel_locks.lock().unwrap();
        assert_eq!(locks.len(), 2);
        assert!(!locks.contains_key("channels:ch2"));
        assert!(locks.contains_key("channels:ch1"));
        assert!(locks.contains_key("channels:ch3"));
    }
}