1use crate::error::DbxResult;
7use std::net::SocketAddr;
8use tokio::sync::mpsc;
9use tracing::info;
10
11pub struct QuicChannel {
13 pub local_addr: SocketAddr,
14 client: s2n_quic::Client,
15}
16
17impl QuicChannel {
18 pub async fn new(
21 local_addr: SocketAddr,
22 cert_pem: &str,
23 key_pem: &str,
24 tx: mpsc::Sender<GridMessageWrapper>,
25 ) -> DbxResult<Self> {
26 info!("Initializing s2n-quic channel on {}", local_addr);
27
28 let tls_builder = s2n_quic::provider::tls::default::Server::builder()
29 .with_certificate(cert_pem, key_pem)
30 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
31 let tls = tls_builder
32 .build()
33 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
34
35 let mut server = s2n_quic::Server::builder()
36 .with_tls(tls)
37 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
38 .with_io(local_addr)
39 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
40 .start()
41 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
42
43 let client_tls_builder = s2n_quic::provider::tls::default::Client::builder()
45 .with_certificate(cert_pem)
46 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
47 let client_tls = client_tls_builder
48 .build()
49 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
50
51 let client = s2n_quic::Client::builder()
52 .with_tls(client_tls)
53 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
54 .with_io("0.0.0.0:0")
55 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
56 .start()
57 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
58
59 tokio::spawn(async move {
61 while let Some(mut connection) = server.accept().await {
62 let tx_clone = tx.clone();
63 tokio::spawn(async move {
64 while let Ok(Some(mut stream)) = connection.accept_bidirectional_stream().await
65 {
66 let tx2 = tx_clone.clone();
67 tokio::spawn(async move {
68 use tokio::io::AsyncReadExt;
69 let mut len_buf = [0u8; 4];
70 if stream.read_exact(&mut len_buf).await.is_ok() {
71 let len = u32::from_be_bytes(len_buf) as usize;
72 let mut data_buf = vec![0u8; len];
73 if stream.read_exact(&mut data_buf).await.is_ok()
74 && let Ok(msg) =
75 crate::grid::protocol::GridMessage::deserialize(&data_buf)
76 {
77 let _ = tx2
78 .send(GridMessageWrapper {
79 msg,
80 stream: Some(stream),
81 })
82 .await;
83 }
84 }
85 });
86 }
87 });
88 }
89 });
90
91 Ok(Self { local_addr, client })
92 }
93
94 pub async fn send_message(
96 &self,
97 peer_addr: SocketAddr,
98 msg: crate::grid::protocol::GridMessage,
99 ) -> DbxResult<()> {
100 info!("Sending GridMessage to {}", peer_addr);
101
102 let connect_config =
104 s2n_quic::client::Connect::new(peer_addr).with_server_name("localhost");
105 let mut connection = match self.client.connect(connect_config).await {
106 Ok(c) => c,
107 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
108 };
109
110 match connection.keep_alive(true) {
111 Ok(_) => {}
112 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
113 }
114
115 let mut stream = match connection.open_bidirectional_stream().await {
116 Ok(s) => s,
117 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
118 };
119
120 let bytes = msg.serialize()?;
121 let len = (bytes.len() as u32).to_be_bytes();
122
123 use tokio::io::AsyncWriteExt;
124 if let Err(e) = stream.write_all(&len).await {
125 return Err(crate::error::DbxError::Network(e.to_string()));
126 }
127 if let Err(e) = stream.write_all(&bytes).await {
128 return Err(crate::error::DbxError::Network(e.to_string()));
129 }
130
131 if let Err(e) = stream.flush().await {
133 return Err(crate::error::DbxError::Network(e.to_string()));
134 }
135 if let Err(e) = stream.shutdown().await {
136 return Err(crate::error::DbxError::Network(e.to_string()));
137 }
138
139 Ok(())
140 }
141
142 pub async fn send_request_and_wait(
144 &self,
145 peer_addr: SocketAddr,
146 msg: crate::grid::protocol::GridMessage,
147 ) -> DbxResult<crate::grid::protocol::GridMessage> {
148 let connect_config =
149 s2n_quic::client::Connect::new(peer_addr).with_server_name("localhost");
150 let mut connection = match self.client.connect(connect_config).await {
151 Ok(c) => c,
152 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
153 };
154
155 match connection.keep_alive(true) {
156 Ok(_) => {}
157 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
158 }
159
160 let mut stream = match connection.open_bidirectional_stream().await {
161 Ok(s) => s,
162 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
163 };
164
165 let bytes = msg.serialize()?;
166 let len = (bytes.len() as u32).to_be_bytes();
167
168 match stream.send(bytes::Bytes::copy_from_slice(&len)).await {
169 Ok(_) => {}
170 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
171 }
172 match stream.send(bytes::Bytes::from(bytes)).await {
173 Ok(_) => {}
174 Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
175 }
176
177 use tokio::io::AsyncReadExt;
179 let mut len_buf = [0u8; 4];
180 stream
181 .read_exact(&mut len_buf)
182 .await
183 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
184
185 let reply_len = u32::from_be_bytes(len_buf) as usize;
186 let mut reply_buf = vec![0u8; reply_len];
187
188 stream
189 .read_exact(&mut reply_buf)
190 .await
191 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
192
193 crate::grid::protocol::GridMessage::deserialize(&reply_buf)
194 }
195
196 pub async fn send_response(
198 stream: &mut s2n_quic::stream::BidirectionalStream,
199 msg: crate::grid::protocol::GridMessage,
200 ) -> DbxResult<()> {
201 let bytes = msg.serialize()?;
202 let len = (bytes.len() as u32).to_be_bytes();
203
204 use tokio::io::AsyncWriteExt;
205 stream
206 .write_all(&len)
207 .await
208 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
209 stream
210 .write_all(&bytes)
211 .await
212 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
213
214 stream
216 .flush()
217 .await
218 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
219 stream
220 .shutdown()
221 .await
222 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
223
224 Ok(())
225 }
226}
227
228pub struct GridMessageWrapper {
230 pub msg: crate::grid::protocol::GridMessage,
231 pub stream: Option<s2n_quic::stream::BidirectionalStream>,
232}
233
234impl GridMessageWrapper {
235 pub async fn send_reply(&mut self, reply: crate::grid::protocol::GridMessage) -> DbxResult<()> {
237 if let Some(stream) = &mut self.stream {
238 let bytes = reply.serialize()?;
239 let len = (bytes.len() as u32).to_be_bytes();
240
241 use tokio::io::AsyncWriteExt;
243 stream
244 .write_all(&len)
245 .await
246 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
247 stream
248 .write_all(&bytes)
249 .await
250 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
251
252 stream
253 .flush()
254 .await
255 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
256 stream
257 .shutdown()
258 .await
259 .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
260
261 Ok(())
262 } else {
263 Err(crate::error::DbxError::Network(
264 "No stream available for reply".into(),
265 ))
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::protocol::StorageMessage;
    use rcgen::generate_simple_self_signed;

    /// End-to-end check: a message sent by one channel is decoded and delivered
    /// on the receiving channel's inbound `mpsc` queue.
    #[tokio::test]
    async fn test_quic_channel_send_and_receive() {
        let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
        let cert = generate_simple_self_signed(subject_alt_names).unwrap();
        let cert_pem = cert.cert.pem();
        let key_pem = cert.key_pair.serialize_pem();

        // NOTE(review): fixed ports can collide with other tests or processes
        // running concurrently on the same host.
        let node1_addr: SocketAddr = "127.0.0.1:15682".parse().unwrap();
        let (tx1, mut rx1) = mpsc::channel(100);
        let _channel1 = QuicChannel::new(node1_addr, &cert_pem, &key_pem, tx1)
            .await
            .unwrap();

        let node2_addr: SocketAddr = "127.0.0.1:15683".parse().unwrap();
        // `_rx2` stays bound so node 2's inbound queue remains open.
        let (tx2, _rx2) = mpsc::channel(100);
        let channel2 = QuicChannel::new(node2_addr, &cert_pem, &key_pem, tx2)
            .await
            .unwrap();

        let test_msg = crate::grid::protocol::GridMessage::Storage(StorageMessage::FetchShard {
            key: "test_key".to_string(),
            shard_id: 42,
        });

        channel2.send_message(node1_addr, test_msg).await.unwrap();

        let wrapper = rx1.recv().await.expect("No message received");
        match wrapper.msg {
            crate::grid::protocol::GridMessage::Storage(StorageMessage::FetchShard {
                key,
                shard_id,
            }) => {
                assert_eq!(key, "test_key");
                assert_eq!(shard_id, 42);
            }
            _ => panic!("Unexpected message type received"),
        }
    }
}