citadel_sync/
memory_transport.rs1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::mpsc;
3use std::sync::Mutex;
4
5use crate::protocol::SyncMessage;
6use crate::transport::{SyncError, SyncTransport};
7
8pub struct MemoryTransport {
16 tx: Mutex<mpsc::Sender<Vec<u8>>>,
17 rx: Mutex<mpsc::Receiver<Vec<u8>>>,
18 closed: AtomicBool,
19}
20
21impl MemoryTransport {
22 pub fn pair() -> (Self, Self) {
26 let (tx_a, rx_b) = mpsc::channel();
27 let (tx_b, rx_a) = mpsc::channel();
28
29 let a = MemoryTransport {
30 tx: Mutex::new(tx_a),
31 rx: Mutex::new(rx_a),
32 closed: AtomicBool::new(false),
33 };
34 let b = MemoryTransport {
35 tx: Mutex::new(tx_b),
36 rx: Mutex::new(rx_b),
37 closed: AtomicBool::new(false),
38 };
39
40 (a, b)
41 }
42}
43
44impl SyncTransport for MemoryTransport {
45 fn send(&self, msg: &SyncMessage) -> std::result::Result<(), SyncError> {
46 if self.closed.load(Ordering::Relaxed) {
47 return Err(SyncError::Closed);
48 }
49 let data = msg.serialize();
50 let tx = self.tx.lock().unwrap();
51 tx.send(data).map_err(|_| SyncError::Closed)
52 }
53
54 fn recv(&self) -> std::result::Result<SyncMessage, SyncError> {
55 if self.closed.load(Ordering::Relaxed) {
56 return Err(SyncError::Closed);
57 }
58 let rx = self.rx.lock().unwrap();
59 let data = rx.recv().map_err(|_| SyncError::Closed)?;
60 Ok(SyncMessage::deserialize(&data)?)
61 }
62
63 fn close(&self) -> std::result::Result<(), SyncError> {
64 self.closed.store(true, Ordering::Relaxed);
65 Ok(())
66 }
67}
68
#[cfg(test)]
mod tests {
    // Unit tests for the in-memory transport pair.
    use super::*;
    use crate::node_id::NodeId;
    use citadel_core::types::PageId;
    use citadel_core::MERKLE_HASH_SIZE;

    /// Builds a `Hello` message for node `id` with a zero root page and hash.
    fn hello(id: u64) -> SyncMessage {
        SyncMessage::Hello {
            node_id: NodeId::from_u64(id),
            root_page: PageId(0),
            root_hash: [0u8; MERKLE_HASH_SIZE],
        }
    }

    /// Panics unless `msg` is a `Hello` carrying the node id `expected`.
    fn assert_hello_from(msg: SyncMessage, expected: u64) {
        if let SyncMessage::Hello { node_id, .. } = msg {
            assert_eq!(node_id, NodeId::from_u64(expected));
        } else {
            panic!("wrong variant");
        }
    }

    #[test]
    fn pair_send_recv() {
        let (sender, receiver) = MemoryTransport::pair();
        sender.send(&hello(1)).unwrap();
        let incoming = receiver.recv().unwrap();
        assert_hello_from(incoming, 1);
    }

    #[test]
    fn bidirectional() {
        let (left, right) = MemoryTransport::pair();
        left.send(&SyncMessage::Done).unwrap();
        right.send(&SyncMessage::Done).unwrap();

        let at_right = right.recv().unwrap();
        assert!(matches!(at_right, SyncMessage::Done));
        let at_left = left.recv().unwrap();
        assert!(matches!(at_left, SyncMessage::Done));
    }

    #[test]
    fn ordering_preserved() {
        let (sender, receiver) = MemoryTransport::pair();
        (0..10u64).for_each(|id| sender.send(&hello(id)).unwrap());
        (0..10u64).for_each(|id| assert_hello_from(receiver.recv().unwrap(), id));
    }

    #[test]
    fn close_prevents_send() {
        // `_peer` must stay bound so the other endpoint outlives the check.
        let (endpoint, _peer) = MemoryTransport::pair();
        endpoint.close().unwrap();
        let outcome = endpoint.send(&SyncMessage::Done);
        assert!(matches!(outcome.unwrap_err(), SyncError::Closed));
    }

    #[test]
    fn close_prevents_recv() {
        let (endpoint, _peer) = MemoryTransport::pair();
        endpoint.close().unwrap();
        let outcome = endpoint.recv();
        assert!(matches!(outcome.unwrap_err(), SyncError::Closed));
    }

    #[test]
    fn dropped_sender_causes_recv_error() {
        let (doomed, survivor) = MemoryTransport::pair();
        drop(doomed);
        let outcome = survivor.recv();
        assert!(matches!(outcome.unwrap_err(), SyncError::Closed));
    }
}