use std::time::Duration;
use crate::grid::protocol::GridMessage;
/// Transport used to exchange replication messages between nodes.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (e.g. in tests or when deciding whether a transport must be rebuilt).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum TransportMode {
    /// In-process `tokio::sync::broadcast` channel; the default mode.
    #[default]
    InMemory,
    /// QUIC (s2n-quic) server endpoint.
    Quic {
        /// Local socket address string the QUIC server binds to.
        bind_addr: String,
        /// Path to the PEM certificate presented by the server.
        cert_path: String,
        /// Path to the PEM private key matching `cert_path`.
        key_path: String,
    },
}
/// Replication settings: which transport to use, how many nodes participate,
/// whether replication is on, and the quorum write timeout.
#[derive(Debug, Clone)]
pub struct ReplicationConfig {
/// Transport used for replication traffic.
pub mode: TransportMode,
/// Number of nodes in the cluster (1 in the `Default` config).
/// NOTE(review): presumably includes the local node — confirm with callers.
pub cluster_size: usize,
/// `false` in the `Default` config; `true` from the `in_memory` / `quic` constructors.
pub enabled: bool,
/// 5 seconds from every constructor in this file.
/// NOTE(review): presumably how long a write waits for quorum acknowledgement;
/// it is not read anywhere in this file — confirm against the replication layer.
pub quorum_write_timeout: Duration,
}
impl Default for ReplicationConfig {
    /// Single-node configuration with replication disabled, the in-memory
    /// transport, and a 5-second quorum write timeout.
    fn default() -> Self {
        let quorum_write_timeout = Duration::from_secs(5);
        Self {
            mode: TransportMode::default(),
            cluster_size: 1,
            enabled: false,
            quorum_write_timeout,
        }
    }
}
impl ReplicationConfig {
pub fn in_memory(cluster_size: usize) -> Self {
Self {
mode: TransportMode::InMemory,
cluster_size,
enabled: true,
quorum_write_timeout: Duration::from_secs(5),
}
}
pub fn quic(
bind_addr: impl Into<String>,
cert_path: impl Into<String>,
key_path: impl Into<String>,
cluster_size: usize,
) -> Self {
Self {
mode: TransportMode::Quic {
bind_addr: bind_addr.into(),
cert_path: cert_path.into(),
key_path: key_path.into(),
},
cluster_size,
enabled: true,
quorum_write_timeout: Duration::from_secs(5),
}
}
pub fn mode_name(&self) -> &'static str {
match &self.mode {
TransportMode::InMemory => "InMemory",
TransportMode::Quic { .. } => "QUIC (s2n-quic)",
}
}
pub async fn build_transport_async(
&self,
broadcast_tx: tokio::sync::broadcast::Sender<GridMessage>,
) -> Result<Box<dyn Transport>, quic::QuicError> {
match &self.mode {
TransportMode::InMemory => Ok(Box::new(InMemoryTransport::new(broadcast_tx))),
TransportMode::Quic {
bind_addr,
cert_path,
key_path,
} => {
tracing::info!(
"QUIC Transport 초기화: bind={} cert={}",
bind_addr,
cert_path
);
let (node, handle) = quic::QuicNode::server(
bind_addr,
std::path::Path::new(cert_path),
std::path::Path::new(key_path),
)
.await?;
tokio::spawn(handle);
Ok(Box::new(quic::QuicTransport::new(node)))
}
}
}
pub fn build_inmemory(
broadcast_tx: tokio::sync::broadcast::Sender<GridMessage>,
) -> Box<dyn Transport> {
Box::new(InMemoryTransport::new(broadcast_tx))
}
}
/// Synchronous interface for sending and polling replication messages.
///
/// Neither implementation in this file waits for a message to arrive in `recv`.
pub trait Transport: Send + Sync {
/// Sends `msg` addressed to `target_node_id`.
///
/// NOTE(review): both implementations in this file ignore `target_node_id`:
/// the in-memory transport broadcasts to every subscriber, and the QUIC transport
/// forwards with its stored peer address.
fn send(&self, target_node_id: u32, msg: GridMessage);
/// Returns the next pending message, or `None` if nothing is available right now.
fn recv(&self) -> Option<GridMessage>;
}
/// In-process transport backed by a shared `tokio::sync::broadcast` channel.
///
/// Every instance subscribes to the channel, so a message sent by one instance is
/// observed by all instances on that channel, including the sender itself.
pub struct InMemoryTransport {
/// Sending half of the shared channel, as passed to `new`.
sender: tokio::sync::broadcast::Sender<GridMessage>,
/// This instance's subscription. It is behind a `Mutex` because `recv` takes `&self`
/// while `broadcast::Receiver::try_recv` needs `&mut self`.
receiver: std::sync::Mutex<tokio::sync::broadcast::Receiver<GridMessage>>,
}
impl InMemoryTransport {
    /// Creates a transport that publishes through `sender` and listens on a fresh
    /// subscription to the same channel.
    ///
    /// Only messages sent after this call are visible to `recv` on the new instance.
    pub fn new(sender: tokio::sync::broadcast::Sender<GridMessage>) -> Self {
        let subscription = sender.subscribe();
        let receiver = std::sync::Mutex::new(subscription);
        Self { sender, receiver }
    }
}
impl Transport for InMemoryTransport {
fn send(&self, _target_node_id: u32, msg: GridMessage) {
let _ = self.sender.send(msg);
}
fn recv(&self) -> Option<GridMessage> {
self.receiver.lock().unwrap().try_recv().ok()
}
}
pub mod quic {
use s2n_quic::provider::tls;
use s2n_quic::{Client, Server};
use std::path::Path;
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::grid::protocol::GridMessage;
/// Size of the little-endian `u32` length prefix that precedes each bincode-encoded
/// `GridMessage` frame on a QUIC stream.
const LEN_PREFIX_BYTES: usize = 4;
/// Errors produced while setting up the QUIC transport or its certificates.
#[derive(Debug, thiserror::Error)]
pub enum QuicError {
/// TLS, endpoint, or connection setup failure. It is also used when `openssl` exits
/// unsuccessfully. Displays as "QUIC connection failed: …" (Korean).
#[error("QUIC 연결 실패: {0}")]
ConnectionError(String),
/// Serialization failure; displays as "serialization failed: …" (Korean).
/// NOTE(review): not constructed anywhere in the code shown.
#[error("직렬화 실패: {0}")]
SerializeError(String),
/// Underlying I/O error (e.g. failure to spawn `openssl`); displays as "IO error: …".
#[error("IO 오류: {0}")]
IoError(#[from] std::io::Error),
}
/// Channel handles for a QUIC endpoint. The network I/O itself runs in background
/// tasks spawned by `QuicNode::server` / `QuicNode::client`.
pub struct QuicNode {
/// Inbound messages decoded by the background task (channel capacity 256).
rx: Arc<tokio::sync::Mutex<mpsc::Receiver<GridMessage>>>,
/// Outbound `(message, peer address)` pairs consumed by the background task.
/// NOTE(review): for nodes created by `server`, the receiving half is dropped when
/// `server` returns, so messages queued here are never delivered.
tx_out: mpsc::Sender<(GridMessage, String)>,
}
impl QuicNode {
pub async fn server(
bind_addr: &str,
cert_pem: &Path,
key_pem: &Path,
) -> Result<(Self, tokio::task::JoinHandle<()>), QuicError> {
let (msg_tx, msg_rx) = mpsc::channel::<GridMessage>(256);
let (out_tx, _out_rx) = mpsc::channel::<(GridMessage, String)>(256);
let tls = tls::default::Server::builder()
.with_certificate(cert_pem, key_pem)
.map_err(|e| QuicError::ConnectionError(e.to_string()))?
.build()
.map_err(|e| QuicError::ConnectionError(e.to_string()))?;
let mut server = Server::builder()
.with_tls(tls)
.map_err(|e| QuicError::ConnectionError(e.to_string()))?
.with_io(bind_addr)
.map_err(|e| QuicError::ConnectionError(e.to_string()))?
.start()
.map_err(|e| QuicError::ConnectionError(e.to_string()))?;
let handle = tokio::spawn(async move {
while let Some(mut conn) = server.accept().await {
let msg_tx = msg_tx.clone();
tokio::spawn(async move {
while let Ok(Some(mut stream)) = conn.accept_bidirectional_stream().await {
use tokio::io::AsyncReadExt;
let mut len_buf = [0u8; LEN_PREFIX_BYTES];
if stream.read_exact(&mut len_buf).await.is_err() {
break;
}
let len = u32::from_le_bytes(len_buf) as usize;
let mut buf = vec![0u8; len];
if stream.read_exact(&mut buf).await.is_err() {
break;
}
if let Ok(msg) = bincode::deserialize::<GridMessage>(&buf) {
let _ = msg_tx.send(msg).await;
}
}
});
}
});
Ok((
Self {
rx: Arc::new(tokio::sync::Mutex::new(msg_rx)),
tx_out: out_tx,
},
handle,
))
}
pub async fn client(
peer_addr: &str,
ca_cert_pem: &Path,
) -> Result<(Self, tokio::task::JoinHandle<()>), QuicError> {
let (msg_tx, msg_rx) = mpsc::channel::<GridMessage>(256);
let (out_tx, mut out_rx) = mpsc::channel::<(GridMessage, String)>(256);
let tls = tls::default::Client::builder()
.with_certificate(ca_cert_pem)
.map_err(|e| QuicError::ConnectionError(e.to_string()))?
.build()
.map_err(|e| QuicError::ConnectionError(e.to_string()))?;
let client = Client::builder()
.with_tls(tls)
.map_err(|e| QuicError::ConnectionError(e.to_string()))?
.with_io("0.0.0.0:0") .map_err(|e| QuicError::ConnectionError(e.to_string()))?
.start()
.map_err(|e| QuicError::ConnectionError(e.to_string()))?;
let peer = peer_addr.to_string();
let handle = tokio::spawn(async move {
use tokio::io::AsyncWriteExt;
let connect =
s2n_quic::client::Connect::new(peer.parse::<std::net::SocketAddr>().unwrap())
.with_server_name("dbx-node");
if let Ok(mut conn) = client.connect(connect).await {
conn.keep_alive(true).ok();
while let Some((msg, _addr)) = out_rx.recv().await {
let Ok(bytes) = bincode::serialize(&msg) else {
continue;
};
let len = bytes.len() as u32;
if let Ok(mut stream) = conn.open_bidirectional_stream().await {
let _ = stream.write_all(&len.to_le_bytes()).await;
let _ = stream.write_all(&bytes).await;
}
}
}
drop(msg_tx);
});
Ok((
Self {
rx: Arc::new(tokio::sync::Mutex::new(msg_rx)),
tx_out: out_tx,
},
handle,
))
}
pub async fn send_msg(&self, msg: GridMessage, peer_addr: String) {
let _ = self.tx_out.send((msg, peer_addr)).await;
}
pub async fn try_recv(&self) -> Option<GridMessage> {
self.rx.lock().await.try_recv().ok()
}
}
/// Adapts a `QuicNode` to the synchronous `Transport` trait.
pub struct QuicTransport {
/// Shared so each spawned send task can hold its own reference.
node: Arc<QuicNode>,
/// Tokio runtime captured at construction; used to spawn sends from synchronous code.
rt: tokio::runtime::Handle,
/// Address passed along with every outgoing message; empty when built via `new`.
peer_addr: String,
}
impl QuicTransport {
    /// Wraps `node` with an empty peer address.
    ///
    /// # Panics
    /// Panics if called outside the context of a Tokio runtime (`Handle::current`).
    pub fn new(node: QuicNode) -> Self {
        Self::with_peer(node, String::new())
    }

    /// Wraps `node`, passing `peer_addr` along with every outgoing message.
    ///
    /// # Panics
    /// Panics if called outside the context of a Tokio runtime (`Handle::current`).
    pub fn with_peer(node: QuicNode, peer_addr: impl Into<String>) -> Self {
        let shared_node = Arc::new(node);
        let runtime = tokio::runtime::Handle::current();
        let peer = peer_addr.into();
        Self {
            node: shared_node,
            rt: runtime,
            peer_addr: peer,
        }
    }
}
impl crate::replication::transport::Transport for QuicTransport {
    /// Fire-and-forget send. It spawns a task on the captured runtime that queues
    /// `msg` via `QuicNode::send_msg` together with `peer_addr`.
    ///
    /// `_target_node_id` is ignored.
    fn send(&self, _target_node_id: u32, msg: GridMessage) {
        let node = Arc::clone(&self.node);
        let addr = self.peer_addr.clone();
        self.rt.spawn(async move {
            node.send_msg(msg, addr).await;
        });
    }

    /// Non-blocking poll for the next inbound message.
    ///
    /// Returns `None` if no message is queued, the inbound channel is closed, or
    /// another caller is holding the receiver at this instant.
    fn recv(&self) -> Option<GridMessage> {
        // Deliberately avoids `Handle::block_on`, which panics when called from
        // inside an async execution context (e.g. from a Tokio task). Both
        // `try_lock` and `try_recv` are synchronous, so this is safe anywhere.
        let mut inbound = self.node.rx.try_lock().ok()?;
        inbound.try_recv().ok()
    }
}
/// Generates a self-signed RSA-2048 certificate (subject `CN=dbx-node`, valid for
/// 365 days) and an unencrypted private key by running the external `openssl` binary.
///
/// The files are written to `output_dir/cert.pem` and `output_dir/key.pem`, and the
/// function returns `(cert_path, key_path)`. The CN matches the server name that
/// `QuicNode::client` connects with.
///
/// # Errors
/// * [`QuicError::IoError`] if `openssl` cannot be spawned (e.g. it is not installed).
/// * [`QuicError::ConnectionError`] if `openssl` exits with a non-success status.
pub fn generate_self_signed_cert(
    output_dir: &Path,
) -> Result<(std::path::PathBuf, std::path::PathBuf), QuicError> {
    let cert_path = output_dir.join("cert.pem");
    let key_path = output_dir.join("key.pem");
    // The paths are passed as `OsStr` via `.arg`, so non-UTF-8 directories work
    // instead of panicking on `to_str().unwrap()`.
    let status = std::process::Command::new("openssl")
        .args(["req", "-x509", "-newkey", "rsa:2048", "-keyout"])
        .arg(&key_path)
        .arg("-out")
        .arg(&cert_path)
        .args(["-days", "365", "-nodes", "-subj", "/CN=dbx-node"])
        .status()?;
    if !status.success() {
        return Err(QuicError::ConnectionError(
            "openssl 실행 실패 (openssl이 설치되어 있어야 함)".to_string(),
        ));
    }
    Ok((cert_path, key_path))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A message sent through one in-memory transport is visible to another
    /// transport on the same channel.
    #[test]
    fn test_in_memory_transport_roundtrip() {
        let (tx, _) = tokio::sync::broadcast::channel(16);
        let t1 = InMemoryTransport::new(tx.clone());
        let t2 = InMemoryTransport::new(tx);
        let msg = GridMessage::Replication(
            crate::replication::protocol::ReplicationMessage::Heartbeat {
                node_id: 1,
                lsn: 42,
            },
        );
        // `broadcast::Sender::send` is synchronous: the message is already visible
        // to every subscriber when it returns, so no sleep is needed.
        t1.send(2, msg.clone());
        assert_eq!(t2.recv(), Some(msg));
        assert_eq!(t2.recv(), None);
    }

    #[test]
    fn test_transport_trait_object() {
        let (tx, _) = tokio::sync::broadcast::channel::<GridMessage>(16);
        let _t: Box<dyn Transport> = Box::new(InMemoryTransport::new(tx));
    }

    #[test]
    fn test_default_config_is_disabled_single_node_in_memory() {
        let cfg = ReplicationConfig::default();
        assert!(!cfg.enabled);
        assert_eq!(cfg.cluster_size, 1);
        assert_eq!(cfg.mode_name(), "InMemory");
        assert_eq!(cfg.quorum_write_timeout, std::time::Duration::from_secs(5));
    }

    #[test]
    fn test_constructors_enable_replication() {
        let mem_cfg = ReplicationConfig::in_memory(3);
        assert!(mem_cfg.enabled);
        assert_eq!(mem_cfg.cluster_size, 3);
        assert_eq!(mem_cfg.mode_name(), "InMemory");
        assert_eq!(mem_cfg.quorum_write_timeout, std::time::Duration::from_secs(5));

        let quic_cfg = ReplicationConfig::quic("127.0.0.1:4433", "cert.pem", "key.pem", 3);
        assert!(quic_cfg.enabled);
        assert_eq!(quic_cfg.cluster_size, 3);
        assert_eq!(quic_cfg.mode_name(), "QUIC (s2n-quic)");
        assert_eq!(quic_cfg.quorum_write_timeout, std::time::Duration::from_secs(5));
    }
}