1use std::time::Duration;
24
25use crate::grid::protocol::GridMessage;
26
/// Transport backend used for replication traffic.
#[derive(Debug, Clone)]
pub enum TransportMode {
    /// In-process broadcast channel; nothing goes over the network.
    /// This is the default variant.
    InMemory,
    /// QUIC over the network (backed by `s2n-quic`).
    Quic {
        /// Local address the QUIC server binds to.
        bind_addr: String,
        /// Path to the PEM certificate presented by the server.
        cert_path: String,
        /// Path to the PEM private key matching `cert_path`.
        key_path: String,
    },
}

impl Default for TransportMode {
    fn default() -> Self {
        Self::InMemory
    }
}
47
48#[derive(Debug, Clone)]
52pub struct ReplicationConfig {
53 pub mode: TransportMode,
55 pub cluster_size: usize,
57 pub enabled: bool,
59 pub quorum_write_timeout: Duration,
63}
64
65impl Default for ReplicationConfig {
66 fn default() -> Self {
67 Self {
68 mode: TransportMode::InMemory,
69 cluster_size: 1,
70 enabled: false,
71 quorum_write_timeout: Duration::from_secs(5),
72 }
73 }
74}
75
76impl ReplicationConfig {
77 pub fn in_memory(cluster_size: usize) -> Self {
79 Self {
80 mode: TransportMode::InMemory,
81 cluster_size,
82 enabled: true,
83 quorum_write_timeout: Duration::from_secs(5),
84 }
85 }
86
87 pub fn quic(
89 bind_addr: impl Into<String>,
90 cert_path: impl Into<String>,
91 key_path: impl Into<String>,
92 cluster_size: usize,
93 ) -> Self {
94 Self {
95 mode: TransportMode::Quic {
96 bind_addr: bind_addr.into(),
97 cert_path: cert_path.into(),
98 key_path: key_path.into(),
99 },
100 cluster_size,
101 enabled: true,
102 quorum_write_timeout: Duration::from_secs(5),
103 }
104 }
105
106 pub fn mode_name(&self) -> &'static str {
108 match &self.mode {
109 TransportMode::InMemory => "InMemory",
110 TransportMode::Quic { .. } => "QUIC (s2n-quic)",
111 }
112 }
113
114 pub async fn build_transport_async(
119 &self,
120 broadcast_tx: tokio::sync::broadcast::Sender<GridMessage>,
121 ) -> Result<Box<dyn Transport>, quic::QuicError> {
122 match &self.mode {
123 TransportMode::InMemory => Ok(Box::new(InMemoryTransport::new(broadcast_tx))),
124 TransportMode::Quic {
125 bind_addr,
126 cert_path,
127 key_path,
128 } => {
129 tracing::info!(
130 "QUIC Transport 초기화: bind={} cert={}",
131 bind_addr,
132 cert_path
133 );
134 let (node, handle) = quic::QuicNode::server(
135 bind_addr,
136 std::path::Path::new(cert_path),
137 std::path::Path::new(key_path),
138 )
139 .await?;
140 tokio::spawn(handle);
142 Ok(Box::new(quic::QuicTransport::new(node)))
143 }
144 }
145 }
146
147 pub fn build_inmemory(
151 broadcast_tx: tokio::sync::broadcast::Sender<GridMessage>,
152 ) -> Box<dyn Transport> {
153 Box::new(InMemoryTransport::new(broadcast_tx))
154 }
155}
156
/// Message transport used between replication nodes.
///
/// Implementors must be `Send + Sync` so a transport can be shared across
/// threads and tasks, e.g. behind the `Box<dyn Transport>` returned by
/// `ReplicationConfig::build_transport_async`.
pub trait Transport: Send + Sync {
    /// Hands `msg` to the transport for delivery to node `target_node_id`.
    ///
    /// There is no return value, so delivery failures are not reported to the
    /// caller. Whether `target_node_id` is honoured is up to the implementation
    /// (`InMemoryTransport` ignores it and broadcasts).
    fn send(&self, target_node_id: u32, msg: GridMessage);

    /// Returns the next received message, or `None` if none is available.
    ///
    /// NOTE(review): the implementations in this file poll with `try_recv`
    /// instead of waiting, so `None` means "nothing queued right now", not
    /// "channel closed". Confirm that callers treat it that way.
    fn recv(&self) -> Option<GridMessage>;
}
172
173pub struct InMemoryTransport {
182 sender: tokio::sync::broadcast::Sender<GridMessage>,
183 receiver: std::sync::Mutex<tokio::sync::broadcast::Receiver<GridMessage>>,
184}
185
186impl InMemoryTransport {
187 pub fn new(sender: tokio::sync::broadcast::Sender<GridMessage>) -> Self {
188 let receiver = sender.subscribe();
189 Self {
190 sender,
191 receiver: std::sync::Mutex::new(receiver),
192 }
193 }
194}
195
196impl Transport for InMemoryTransport {
197 fn send(&self, _target_node_id: u32, msg: GridMessage) {
198 let _ = self.sender.send(msg);
199 }
200
201 fn recv(&self) -> Option<GridMessage> {
202 self.receiver.lock().unwrap().try_recv().ok()
203 }
204}
205
206pub mod quic {
211 use s2n_quic::provider::tls;
234 use s2n_quic::{Client, Server};
235 use std::path::Path;
236 use std::sync::Arc;
237 use tokio::sync::mpsc;
238
239 use crate::grid::protocol::GridMessage;
240
241 const LEN_PREFIX_BYTES: usize = 4;
243
244 #[derive(Debug, thiserror::Error)]
246 pub enum QuicError {
247 #[error("QUIC 연결 실패: {0}")]
248 ConnectionError(String),
249 #[error("직렬화 실패: {0}")]
250 SerializeError(String),
251 #[error("IO 오류: {0}")]
252 IoError(#[from] std::io::Error),
253 }
254
255 pub struct QuicNode {
257 rx: Arc<tokio::sync::Mutex<mpsc::Receiver<GridMessage>>>,
259 tx_out: mpsc::Sender<(GridMessage, String)>,
261 }
262
263 impl QuicNode {
264 pub async fn server(
269 bind_addr: &str,
270 cert_pem: &Path,
271 key_pem: &Path,
272 ) -> Result<(Self, tokio::task::JoinHandle<()>), QuicError> {
273 let (msg_tx, msg_rx) = mpsc::channel::<GridMessage>(256);
274 let (out_tx, _out_rx) = mpsc::channel::<(GridMessage, String)>(256);
275
276 let tls = tls::default::Server::builder()
277 .with_certificate(cert_pem, key_pem)
278 .map_err(|e| QuicError::ConnectionError(e.to_string()))?
279 .build()
280 .map_err(|e| QuicError::ConnectionError(e.to_string()))?;
281
282 let mut server = Server::builder()
283 .with_tls(tls)
284 .map_err(|e| QuicError::ConnectionError(e.to_string()))?
285 .with_io(bind_addr)
286 .map_err(|e| QuicError::ConnectionError(e.to_string()))?
287 .start()
288 .map_err(|e| QuicError::ConnectionError(e.to_string()))?;
289
290 let handle = tokio::spawn(async move {
292 while let Some(mut conn) = server.accept().await {
293 let msg_tx = msg_tx.clone();
294 tokio::spawn(async move {
295 while let Ok(Some(mut stream)) = conn.accept_bidirectional_stream().await {
296 use tokio::io::AsyncReadExt;
298 let mut len_buf = [0u8; LEN_PREFIX_BYTES];
299 if stream.read_exact(&mut len_buf).await.is_err() {
300 break;
301 }
302 let len = u32::from_le_bytes(len_buf) as usize;
303 let mut buf = vec![0u8; len];
304 if stream.read_exact(&mut buf).await.is_err() {
305 break;
306 }
307 if let Ok(msg) = bincode::deserialize::<GridMessage>(&buf) {
308 let _ = msg_tx.send(msg).await;
309 }
310 }
311 });
312 }
313 });
314
315 Ok((
316 Self {
317 rx: Arc::new(tokio::sync::Mutex::new(msg_rx)),
318 tx_out: out_tx,
319 },
320 handle,
321 ))
322 }
323
324 pub async fn client(
329 peer_addr: &str,
330 ca_cert_pem: &Path,
331 ) -> Result<(Self, tokio::task::JoinHandle<()>), QuicError> {
332 let (msg_tx, msg_rx) = mpsc::channel::<GridMessage>(256);
333 let (out_tx, mut out_rx) = mpsc::channel::<(GridMessage, String)>(256);
334
335 let tls = tls::default::Client::builder()
336 .with_certificate(ca_cert_pem)
337 .map_err(|e| QuicError::ConnectionError(e.to_string()))?
338 .build()
339 .map_err(|e| QuicError::ConnectionError(e.to_string()))?;
340
341 let client = Client::builder()
342 .with_tls(tls)
343 .map_err(|e| QuicError::ConnectionError(e.to_string()))?
344 .with_io("0.0.0.0:0") .map_err(|e| QuicError::ConnectionError(e.to_string()))?
346 .start()
347 .map_err(|e| QuicError::ConnectionError(e.to_string()))?;
348
349 let peer = peer_addr.to_string();
350
351 let handle = tokio::spawn(async move {
353 use tokio::io::AsyncWriteExt;
354 let connect =
355 s2n_quic::client::Connect::new(peer.parse::<std::net::SocketAddr>().unwrap())
356 .with_server_name("dbx-node");
357
358 if let Ok(mut conn) = client.connect(connect).await {
359 conn.keep_alive(true).ok();
360 while let Some((msg, _addr)) = out_rx.recv().await {
361 let Ok(bytes) = bincode::serialize(&msg) else {
362 continue;
363 };
364 let len = bytes.len() as u32;
365 if let Ok(mut stream) = conn.open_bidirectional_stream().await {
366 let _ = stream.write_all(&len.to_le_bytes()).await;
367 let _ = stream.write_all(&bytes).await;
368 }
369 }
370 }
371 drop(msg_tx);
373 });
374
375 Ok((
376 Self {
377 rx: Arc::new(tokio::sync::Mutex::new(msg_rx)),
378 tx_out: out_tx,
379 },
380 handle,
381 ))
382 }
383
384 pub async fn send_msg(&self, msg: GridMessage, peer_addr: String) {
386 let _ = self.tx_out.send((msg, peer_addr)).await;
387 }
388
389 pub async fn try_recv(&self) -> Option<GridMessage> {
391 self.rx.lock().await.try_recv().ok()
392 }
393 }
394
395 pub struct QuicTransport {
404 node: Arc<QuicNode>,
405 rt: tokio::runtime::Handle,
407 peer_addr: String,
409 }
410
411 impl QuicTransport {
412 pub fn new(node: QuicNode) -> Self {
413 Self {
414 node: Arc::new(node),
415 rt: tokio::runtime::Handle::current(),
416 peer_addr: String::new(),
417 }
418 }
419
420 pub fn with_peer(node: QuicNode, peer_addr: impl Into<String>) -> Self {
421 Self {
422 node: Arc::new(node),
423 rt: tokio::runtime::Handle::current(),
424 peer_addr: peer_addr.into(),
425 }
426 }
427 }
428
429 impl crate::replication::transport::Transport for QuicTransport {
430 fn send(&self, _target_node_id: u32, msg: GridMessage) {
431 let node = Arc::clone(&self.node);
432 let addr = self.peer_addr.clone();
433 self.rt.spawn(async move {
434 node.send_msg(msg, addr).await;
435 });
436 }
437
438 fn recv(&self) -> Option<GridMessage> {
439 let node = Arc::clone(&self.node);
440 self.rt.block_on(async move { node.try_recv().await })
441 }
442 }
443
444 pub fn generate_self_signed_cert(
452 output_dir: &Path,
453 ) -> Result<(std::path::PathBuf, std::path::PathBuf), QuicError> {
454 let cert_path = output_dir.join("cert.pem");
455 let key_path = output_dir.join("key.pem");
456
457 let status = std::process::Command::new("openssl")
458 .args([
459 "req",
460 "-x509",
461 "-newkey",
462 "rsa:2048",
463 "-keyout",
464 key_path.to_str().unwrap(),
465 "-out",
466 cert_path.to_str().unwrap(),
467 "-days",
468 "365",
469 "-nodes",
470 "-subj",
471 "/CN=dbx-node",
472 ])
473 .status()
474 .map_err(QuicError::IoError)?;
475
476 if !status.success() {
477 return Err(QuicError::ConnectionError(
478 "openssl 실행 실패 (openssl이 설치되어 있어야 함)".to_string(),
479 ));
480 }
481
482 Ok((cert_path, key_path))
483 }
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample message shared by the transport tests.
    fn heartbeat() -> GridMessage {
        GridMessage::Replication(
            crate::replication::protocol::ReplicationMessage::Heartbeat {
                node_id: 1,
                lsn: 42,
            },
        )
    }

    #[test]
    fn test_in_memory_transport_roundtrip() {
        let (tx, _) = tokio::sync::broadcast::channel(16);
        let t1 = InMemoryTransport::new(tx.clone());
        let t2 = InMemoryTransport::new(tx.clone());

        let msg = heartbeat();
        t1.send(2, msg.clone());

        // `broadcast::Sender::send` enqueues synchronously, so every subscriber
        // can see the message immediately. No sleep or runtime is needed.
        assert_eq!(t2.recv(), Some(msg));
        assert_eq!(t2.recv(), None);
    }

    #[test]
    fn test_transport_trait_object() {
        let (tx, _) = tokio::sync::broadcast::channel::<GridMessage>(16);
        let _t: Box<dyn Transport> = Box::new(InMemoryTransport::new(tx));
    }

    #[test]
    fn test_replication_config_constructors() {
        let default = ReplicationConfig::default();
        assert!(!default.enabled);
        assert_eq!(default.cluster_size, 1);
        assert_eq!(default.mode_name(), "InMemory");
        assert_eq!(default.quorum_write_timeout, Duration::from_secs(5));

        let mem = ReplicationConfig::in_memory(3);
        assert!(mem.enabled);
        assert_eq!(mem.cluster_size, 3);
        assert_eq!(mem.mode_name(), "InMemory");
        assert_eq!(mem.quorum_write_timeout, Duration::from_secs(5));

        let q = ReplicationConfig::quic("127.0.0.1:4433", "cert.pem", "key.pem", 5);
        assert!(q.enabled);
        assert_eq!(q.cluster_size, 5);
        assert_eq!(q.mode_name(), "QUIC (s2n-quic)");
        assert_eq!(q.quorum_write_timeout, Duration::from_secs(5));
    }
}