// dbx_core/grid/quic.rs

//! Grid network interface over s2n-quic (stub/foundation).
//!
//! Provides an s2n-quic based peer-to-peer communication stub for high-speed data
//! transfer between nodes and for distributing erasure-coding parity blocks.

use std::net::SocketAddr;

use tokio::sync::mpsc;
use tracing::{info, warn};

use crate::error::DbxResult;

11/// QUIC 통신을 담당하는 채널 인스턴스
12pub struct QuicChannel {
13    pub local_addr: SocketAddr,
14    client: s2n_quic::Client,
15}
16
17impl QuicChannel {
18    /// 새로운 QUIC 채널 개설
19    /// TODO: 상용화 단계에서는 cert_pem과 key_pem 경로를 Config에서 주입받아야 합니다.
20    pub async fn new(
21        local_addr: SocketAddr,
22        cert_pem: &str,
23        key_pem: &str,
24        tx: mpsc::Sender<GridMessageWrapper>,
25    ) -> DbxResult<Self> {
26        info!("Initializing s2n-quic channel on {}", local_addr);
27
28        let tls_builder = s2n_quic::provider::tls::default::Server::builder()
29            .with_certificate(cert_pem, key_pem)
30            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
31        let tls = tls_builder
32            .build()
33            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
34
35        let mut server = s2n_quic::Server::builder()
36            .with_tls(tls)
37            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
38            .with_io(local_addr)
39            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
40            .start()
41            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
42
43        // 인증 기관 체크 무시 (테스트 및 동일 그리드 허용) - 실무에선 인증서 검증 필수
44        let client_tls_builder = s2n_quic::provider::tls::default::Client::builder()
45            .with_certificate(cert_pem)
46            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
47        let client_tls = client_tls_builder
48            .build()
49            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
50
51        let client = s2n_quic::Client::builder()
52            .with_tls(client_tls)
53            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
54            .with_io("0.0.0.0:0")
55            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?
56            .start()
57            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
58
59        // 백그라운드 서버 수신 루프
60        tokio::spawn(async move {
61            while let Some(mut connection) = server.accept().await {
62                let tx_clone = tx.clone();
63                tokio::spawn(async move {
64                    while let Ok(Some(mut stream)) = connection.accept_bidirectional_stream().await
65                    {
66                        let tx2 = tx_clone.clone();
67                        tokio::spawn(async move {
68                            use tokio::io::AsyncReadExt;
69                            let mut len_buf = [0u8; 4];
70                            if stream.read_exact(&mut len_buf).await.is_ok() {
71                                let len = u32::from_be_bytes(len_buf) as usize;
72                                let mut data_buf = vec![0u8; len];
73                                if stream.read_exact(&mut data_buf).await.is_ok()
74                                    && let Ok(msg) =
75                                        crate::grid::protocol::GridMessage::deserialize(&data_buf)
76                                {
77                                    let _ = tx2
78                                        .send(GridMessageWrapper {
79                                            msg,
80                                            stream: Some(stream),
81                                        })
82                                        .await;
83                                }
84                            }
85                        });
86                    }
87                });
88            }
89        });
90
91        Ok(Self { local_addr, client })
92    }
93
94    /// 다른 Grid 노드에 데이터 단방향 전송 (방치)
95    pub async fn send_message(
96        &self,
97        peer_addr: SocketAddr,
98        msg: crate::grid::protocol::GridMessage,
99    ) -> DbxResult<()> {
100        info!("Sending GridMessage to {}", peer_addr);
101
102        // Connect
103        let connect_config =
104            s2n_quic::client::Connect::new(peer_addr).with_server_name("localhost");
105        let mut connection = match self.client.connect(connect_config).await {
106            Ok(c) => c,
107            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
108        };
109
110        match connection.keep_alive(true) {
111            Ok(_) => {}
112            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
113        }
114
115        let mut stream = match connection.open_bidirectional_stream().await {
116            Ok(s) => s,
117            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
118        };
119
120        let bytes = msg.serialize()?;
121        let len = (bytes.len() as u32).to_be_bytes();
122
123        use tokio::io::AsyncWriteExt;
124        if let Err(e) = stream.write_all(&len).await {
125            return Err(crate::error::DbxError::Network(e.to_string()));
126        }
127        if let Err(e) = stream.write_all(&bytes).await {
128            return Err(crate::error::DbxError::Network(e.to_string()));
129        }
130
131        // 데이터 전송 완료 후 플러시 및 스트림 정상 종료(FIN) 대기
132        if let Err(e) = stream.flush().await {
133            return Err(crate::error::DbxError::Network(e.to_string()));
134        }
135        if let Err(e) = stream.shutdown().await {
136            return Err(crate::error::DbxError::Network(e.to_string()));
137        }
138
139        Ok(())
140    }
141
142    /// Request-Response 방식 비동기 통신 (FetchShard 등)
143    pub async fn send_request_and_wait(
144        &self,
145        peer_addr: SocketAddr,
146        msg: crate::grid::protocol::GridMessage,
147    ) -> DbxResult<crate::grid::protocol::GridMessage> {
148        let connect_config =
149            s2n_quic::client::Connect::new(peer_addr).with_server_name("localhost");
150        let mut connection = match self.client.connect(connect_config).await {
151            Ok(c) => c,
152            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
153        };
154
155        match connection.keep_alive(true) {
156            Ok(_) => {}
157            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
158        }
159
160        let mut stream = match connection.open_bidirectional_stream().await {
161            Ok(s) => s,
162            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
163        };
164
165        let bytes = msg.serialize()?;
166        let len = (bytes.len() as u32).to_be_bytes();
167
168        match stream.send(bytes::Bytes::copy_from_slice(&len)).await {
169            Ok(_) => {}
170            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
171        }
172        match stream.send(bytes::Bytes::from(bytes)).await {
173            Ok(_) => {}
174            Err(e) => return Err(crate::error::DbxError::Network(e.to_string())),
175        }
176
177        // Receive Length & Data using AsyncReadExt
178        use tokio::io::AsyncReadExt;
179        let mut len_buf = [0u8; 4];
180        stream
181            .read_exact(&mut len_buf)
182            .await
183            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
184
185        let reply_len = u32::from_be_bytes(len_buf) as usize;
186        let mut reply_buf = vec![0u8; reply_len];
187
188        stream
189            .read_exact(&mut reply_buf)
190            .await
191            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
192
193        crate::grid::protocol::GridMessage::deserialize(&reply_buf)
194    }
195
196    /// 정적 메서드로 스트림에 메시지 응답 전송
197    pub async fn send_response(
198        stream: &mut s2n_quic::stream::BidirectionalStream,
199        msg: crate::grid::protocol::GridMessage,
200    ) -> DbxResult<()> {
201        let bytes = msg.serialize()?;
202        let len = (bytes.len() as u32).to_be_bytes();
203
204        use tokio::io::AsyncWriteExt;
205        stream
206            .write_all(&len)
207            .await
208            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
209        stream
210            .write_all(&bytes)
211            .await
212            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
213
214        // 데이터 전송 완료 후 플러시 및 스트림 정상 종료(FIN) 대기
215        stream
216            .flush()
217            .await
218            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
219        stream
220            .shutdown()
221            .await
222            .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
223
224        Ok(())
225    }
226}
227
228/// 스트림의 소유권을 유지하며 채널로 메시지를 보내기 위한 래퍼
229pub struct GridMessageWrapper {
230    pub msg: crate::grid::protocol::GridMessage,
231    pub stream: Option<s2n_quic::stream::BidirectionalStream>,
232}
233
234impl GridMessageWrapper {
235    /// 요청을 보낸 스트림을 통해 응답을 전송합니다.
236    pub async fn send_reply(&mut self, reply: crate::grid::protocol::GridMessage) -> DbxResult<()> {
237        if let Some(stream) = &mut self.stream {
238            let bytes = reply.serialize()?;
239            let len = (bytes.len() as u32).to_be_bytes();
240
241            // s2n_quic::stream::BidirectionalStream implements tokio::io::AsyncWrite
242            use tokio::io::AsyncWriteExt;
243            stream
244                .write_all(&len)
245                .await
246                .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
247            stream
248                .write_all(&bytes)
249                .await
250                .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
251
252            stream
253                .flush()
254                .await
255                .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
256            stream
257                .shutdown()
258                .await
259                .map_err(|e| crate::error::DbxError::Network(e.to_string()))?;
260
261            Ok(())
262        } else {
263            Err(crate::error::DbxError::Network(
264                "No stream available for reply".into(),
265            ))
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::protocol::{GridMessage, StorageMessage};
    use rcgen::generate_simple_self_signed;

    /// Builds a self-signed certificate for `localhost` / `127.0.0.1` and returns
    /// `(cert_pem, key_pem)`. All test nodes share it, so each node's client trusts the
    /// others' servers.
    fn self_signed_pem() -> (String, String) {
        let subject_alt_names = vec!["localhost".to_string(), "127.0.0.1".to_string()];
        let cert = generate_simple_self_signed(subject_alt_names).unwrap();
        (cert.cert.pem(), cert.key_pair.serialize_pem())
    }

    /// Starts a channel on a fixed loopback port and returns it, its address and its
    /// inbound receiver.
    ///
    /// Fixed ports (rather than port 0) keep the tests independent of whether
    /// `QuicChannel::local_addr` reports the OS-assigned port. Each test uses distinct ports
    /// so tests can run in parallel.
    async fn start_node(
        port: u16,
        cert_pem: &str,
        key_pem: &str,
    ) -> (QuicChannel, SocketAddr, mpsc::Receiver<GridMessageWrapper>) {
        let addr = SocketAddr::from(([127, 0, 0, 1], port));
        let (tx, rx) = mpsc::channel(100);
        let channel = QuicChannel::new(addr, cert_pem, key_pem, tx).await.unwrap();
        (channel, addr, rx)
    }

    #[tokio::test]
    async fn test_quic_channel_send_and_receive() {
        let (cert_pem, key_pem) = self_signed_pem();
        let (_node1, node1_addr, mut rx1) = start_node(15682, &cert_pem, &key_pem).await;
        let (node2, _node2_addr, _rx2) = start_node(15683, &cert_pem, &key_pem).await;

        let test_msg = GridMessage::Storage(StorageMessage::FetchShard {
            key: "test_key".to_string(),
            shard_id: 42,
        });

        // node2 sends to node1 without waiting for a reply.
        node2.send_message(node1_addr, test_msg).await.unwrap();

        // node1 receives the message.
        let wrapper = rx1.recv().await.expect("No message received");
        match wrapper.msg {
            GridMessage::Storage(StorageMessage::FetchShard { key, shard_id }) => {
                assert_eq!(key, "test_key");
                assert_eq!(shard_id, 42);
            }
            _ => panic!("Unexpected message type received"),
        }
    }

    #[tokio::test]
    async fn test_quic_channel_request_and_reply() {
        let (cert_pem, key_pem) = self_signed_pem();
        let (_server, server_addr, mut server_rx) =
            start_node(15684, &cert_pem, &key_pem).await;
        let (client, _client_addr, _client_rx) = start_node(15685, &cert_pem, &key_pem).await;

        // Answer exactly one request, on the stream it arrived on.
        let responder = tokio::spawn(async move {
            let mut wrapper = server_rx.recv().await.expect("No request received");
            wrapper
                .send_reply(GridMessage::Storage(StorageMessage::FetchShard {
                    key: "reply_key".to_string(),
                    shard_id: 7,
                }))
                .await
                .unwrap();
        });

        let request = GridMessage::Storage(StorageMessage::FetchShard {
            key: "test_key".to_string(),
            shard_id: 42,
        });
        let reply = client
            .send_request_and_wait(server_addr, request)
            .await
            .unwrap();
        responder.await.unwrap();

        match reply {
            GridMessage::Storage(StorageMessage::FetchShard { key, shard_id }) => {
                assert_eq!(key, "reply_key");
                assert_eq!(shard_id, 7);
            }
            _ => panic!("Unexpected reply type received"),
        }
    }
}