Skip to main content

citadel_sync/
memory_transport.rs

1use std::sync::atomic::{AtomicBool, Ordering};
2use std::sync::mpsc;
3use std::sync::Mutex;
4
5use crate::protocol::SyncMessage;
6use crate::transport::{SyncError, SyncTransport};
7
8/// In-memory transport for testing sync sessions.
9///
10/// Uses `mpsc` channels for bidirectional communication.
11/// Thread-safe (`Send + Sync`) so each side can be shared with a
12/// scoped thread via `&self`.
13///
14/// Create a connected pair with `MemoryTransport::pair()`.
15pub struct MemoryTransport {
16    tx: Mutex<mpsc::Sender<Vec<u8>>>,
17    rx: Mutex<mpsc::Receiver<Vec<u8>>>,
18    closed: AtomicBool,
19}
20
21impl MemoryTransport {
22    /// Create a connected pair of transports.
23    ///
24    /// Messages sent on one side are received by the other.
25    pub fn pair() -> (Self, Self) {
26        let (tx_a, rx_b) = mpsc::channel();
27        let (tx_b, rx_a) = mpsc::channel();
28
29        let a = MemoryTransport {
30            tx: Mutex::new(tx_a),
31            rx: Mutex::new(rx_a),
32            closed: AtomicBool::new(false),
33        };
34        let b = MemoryTransport {
35            tx: Mutex::new(tx_b),
36            rx: Mutex::new(rx_b),
37            closed: AtomicBool::new(false),
38        };
39
40        (a, b)
41    }
42}
43
44impl SyncTransport for MemoryTransport {
45    fn send(&self, msg: &SyncMessage) -> std::result::Result<(), SyncError> {
46        if self.closed.load(Ordering::Relaxed) {
47            return Err(SyncError::Closed);
48        }
49        let data = msg.serialize();
50        let tx = self.tx.lock().unwrap();
51        tx.send(data).map_err(|_| SyncError::Closed)
52    }
53
54    fn recv(&self) -> std::result::Result<SyncMessage, SyncError> {
55        if self.closed.load(Ordering::Relaxed) {
56            return Err(SyncError::Closed);
57        }
58        let rx = self.rx.lock().unwrap();
59        let data = rx.recv().map_err(|_| SyncError::Closed)?;
60        Ok(SyncMessage::deserialize(&data)?)
61    }
62
63    fn close(&self) -> std::result::Result<(), SyncError> {
64        self.closed.store(true, Ordering::Relaxed);
65        Ok(())
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use crate::node_id::NodeId;
    use citadel_core::types::PageId;
    use citadel_core::MERKLE_HASH_SIZE;

    /// Build a `Hello` for node `id` with page 0 and an all-zero root hash.
    fn hello(id: u64) -> SyncMessage {
        SyncMessage::Hello {
            node_id: NodeId::from_u64(id),
            root_page: PageId(0),
            root_hash: [0u8; MERKLE_HASH_SIZE],
        }
    }

    /// Pull the node id out of a `Hello`, panicking on any other variant.
    fn hello_node_id(msg: SyncMessage) -> NodeId {
        match msg {
            SyncMessage::Hello { node_id, .. } => node_id,
            _ => panic!("wrong variant"),
        }
    }

    #[test]
    fn pair_send_recv() {
        let (left, right) = MemoryTransport::pair();
        let outgoing = hello(1);
        left.send(&outgoing).unwrap();

        let got = hello_node_id(right.recv().unwrap());
        assert_eq!(got, NodeId::from_u64(1));
    }

    #[test]
    fn bidirectional() {
        let (left, right) = MemoryTransport::pair();
        left.send(&SyncMessage::Done).unwrap();
        right.send(&SyncMessage::Done).unwrap();

        assert!(matches!(right.recv().unwrap(), SyncMessage::Done));
        assert!(matches!(left.recv().unwrap(), SyncMessage::Done));
    }

    #[test]
    fn ordering_preserved() {
        let (left, right) = MemoryTransport::pair();
        (0..10u64).for_each(|id| left.send(&hello(id)).unwrap());

        for expected in 0..10u64 {
            let got = hello_node_id(right.recv().unwrap());
            assert_eq!(got, NodeId::from_u64(expected));
        }
    }

    #[test]
    fn close_prevents_send() {
        // `_right` (not `_`) keeps the peer alive for the whole test.
        let (left, _right) = MemoryTransport::pair();
        left.close().unwrap();

        let failure = left.send(&SyncMessage::Done).unwrap_err();
        assert!(matches!(failure, SyncError::Closed));
    }

    #[test]
    fn close_prevents_recv() {
        let (left, _right) = MemoryTransport::pair();
        left.close().unwrap();

        let failure = left.recv().unwrap_err();
        assert!(matches!(failure, SyncError::Closed));
    }

    #[test]
    fn dropped_sender_causes_recv_error() {
        let (left, right) = MemoryTransport::pair();
        drop(left);

        let failure = right.recv().unwrap_err();
        assert!(matches!(failure, SyncError::Closed));
    }
}