1use super::protocol::{
7 ErrorCode, ErrorMessage, Heartbeat, HelloMessage, ReplicationMessage, SyncResponse,
8 PROTOCOL_VERSION,
9};
10use super::tls::{ReplicationTlsConfig, ReplicationTlsStream};
11use crate::state::integrity::{DecisionRecord, StateIntegrity};
12use std::collections::VecDeque;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::{Arc, RwLock};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
17use tokio::net::{TcpListener, TcpStream};
18use tokio::sync::{broadcast, mpsc, watch};
19use tokio_rustls::TlsAcceptor;
20use tracing::{debug, error, info, warn};
21
22#[derive(Debug, Clone)]
24pub struct ReplicatorConfig {
25 pub listen_addr: String,
27
28 pub heartbeat_interval_ms: u64,
30
31 pub max_buffer_size: usize,
33
34 pub node_id: String,
36
37 pub tls_config: Option<ReplicationTlsConfig>,
39}
40
41impl Default for ReplicatorConfig {
42 fn default() -> Self {
43 Self {
44 listen_addr: "127.0.0.1:26660".to_string(),
45 heartbeat_interval_ms: 1000,
46 max_buffer_size: 10000,
47 node_id: "primary".to_string(),
48 tls_config: None,
49 }
50 }
51}
52
53pub struct StateReplicator {
55 config: ReplicatorConfig,
56
57 fencing_token: AtomicU64,
59
60 decision_rx: mpsc::Receiver<DecisionRecord>,
62
63 broadcast_tx: broadcast::Sender<ReplicationMessage>,
65
66 record_buffer: Arc<RwLock<VecDeque<DecisionRecord>>>,
68
69 integrity: Arc<RwLock<StateIntegrity>>,
71
72 _shutdown: AtomicBool,
74}
75
76pub struct ReplicatorHandle {
78 decision_tx: mpsc::Sender<DecisionRecord>,
80
81 shutdown_tx: watch::Sender<bool>,
83
84 integrity: Arc<RwLock<StateIntegrity>>,
86}
87
88impl ReplicatorHandle {
89 pub async fn replicate(&self, record: DecisionRecord) -> Result<(), ReplicationError> {
91 self.decision_tx
92 .send(record)
93 .await
94 .map_err(|_| ReplicationError::ChannelClosed)
95 }
96
97 pub fn update_integrity(&self, integrity: StateIntegrity) {
99 if let Ok(mut guard) = self.integrity.write() {
100 *guard = integrity;
101 }
102 }
103
104 pub fn shutdown(self) {
106 let _ = self.shutdown_tx.send(true);
107 }
108}
109
110#[derive(Debug, thiserror::Error)]
112pub enum ReplicationError {
113 #[error("Channel closed")]
114 ChannelClosed,
115
116 #[error("IO error: {0}")]
117 Io(#[from] std::io::Error),
118
119 #[error("Serialization error: {0}")]
120 Serialization(String),
121
122 #[error("Protocol error: {0}")]
123 Protocol(String),
124
125 #[error("Fencing token rejected")]
126 FencingRejected,
127}
128
129impl StateReplicator {
130 pub fn new(
132 config: ReplicatorConfig,
133 initial_integrity: StateIntegrity,
134 ) -> (Self, ReplicatorHandle) {
135 let (decision_tx, decision_rx) = mpsc::channel(1000);
136 let (broadcast_tx, _) = broadcast::channel(1000);
137 let (shutdown_tx, _shutdown_rx) = watch::channel(false);
138
139 let integrity = Arc::new(RwLock::new(initial_integrity));
140
141 let replicator = Self {
142 config: config.clone(),
143 fencing_token: AtomicU64::new(1),
144 decision_rx,
145 broadcast_tx,
146 record_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(config.max_buffer_size))),
147 integrity: integrity.clone(),
148 _shutdown: AtomicBool::new(false),
149 };
150
151 let handle = ReplicatorHandle {
152 decision_tx,
153 shutdown_tx,
154 integrity,
155 };
156
157 (replicator, handle)
158 }
159
160 pub async fn run(mut self, mut shutdown_rx: watch::Receiver<bool>) -> Result<(), ReplicationError> {
162 let listener = TcpListener::bind(&self.config.listen_addr).await?;
163
164 let tls_acceptor: Option<TlsAcceptor> = if let Some(ref tls_config) = self.config.tls_config {
166 match tls_config.build_acceptor() {
167 Ok(acceptor) => {
168 info!(addr = %self.config.listen_addr, "State replicator listening with mTLS");
169 Some(acceptor)
170 }
171 Err(e) => {
172 error!(error = %e, "Failed to build TLS acceptor");
173 return Err(ReplicationError::Protocol(format!("TLS setup failed: {}", e)));
174 }
175 }
176 } else {
177 info!(addr = %self.config.listen_addr, "State replicator listening (plaintext - NOT RECOMMENDED FOR PRODUCTION)");
178 None
179 };
180
181 let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
182 let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);
183
184 loop {
185 tokio::select! {
186 accept_result = listener.accept() => {
188 match accept_result {
189 Ok((stream, addr)) => {
190 info!(addr = %addr, "Passive node connected");
191 let broadcast_rx = self.broadcast_tx.subscribe();
192 let buffer = self.record_buffer.clone();
193 let integrity = self.integrity.clone();
194 let node_id = self.config.node_id.clone();
195 let _fencing_token = self.fencing_token.load(Ordering::Acquire);
196 let tls_acceptor_clone = tls_acceptor.clone();
197
198 tokio::spawn(async move {
199 let result = if let Some(acceptor) = tls_acceptor_clone {
201 match ReplicationTlsStream::accept(&acceptor, stream).await {
202 Ok(tls_stream) => {
203 info!(addr = %addr, "TLS handshake successful");
204 handle_passive_connection_generic(
205 tls_stream,
206 broadcast_rx,
207 buffer,
208 integrity,
209 node_id,
210 _fencing_token,
211 ).await
212 }
213 Err(e) => {
214 warn!(addr = %addr, error = %e, "TLS handshake failed");
215 return;
216 }
217 }
218 } else {
219 handle_passive_connection_generic(
220 ReplicationTlsStream::plain(stream),
221 broadcast_rx,
222 buffer,
223 integrity,
224 node_id,
225 _fencing_token,
226 ).await
227 };
228
229 if let Err(e) = result {
230 warn!(error = %e, "Passive connection error");
231 }
232 });
233 }
234 Err(e) => {
235 error!(error = %e, "Accept error");
236 }
237 }
238 }
239
240 Some(record) = self.decision_rx.recv() => {
242 self.buffer_record(record.clone());
243
244 let msg = ReplicationMessage::Decision(record);
245 let _ = self.broadcast_tx.send(msg);
246 }
247
248 _ = heartbeat_timer.tick() => {
250 if let Ok(integrity) = self.integrity.read() {
251 let heartbeat = Heartbeat::new(
252 integrity.sequence_number,
253 integrity.current_hash,
254 self.fencing_token.load(Ordering::Acquire),
255 );
256 let msg = ReplicationMessage::Heartbeat(heartbeat);
257 let _ = self.broadcast_tx.send(msg);
258 }
259 }
260
261 _ = shutdown_rx.changed() => {
263 if *shutdown_rx.borrow() {
264 info!("State replicator shutting down");
265 break;
266 }
267 }
268 }
269 }
270
271 Ok(())
272 }
273
274 fn buffer_record(&self, record: DecisionRecord) {
276 if let Ok(mut buffer) = self.record_buffer.write() {
277 if buffer.len() >= self.config.max_buffer_size {
278 buffer.pop_front();
279 }
280 buffer.push_back(record);
281 }
282 }
283
284 pub fn next_fencing_token(&self) -> u64 {
286 self.fencing_token.fetch_add(1, Ordering::AcqRel) + 1
287 }
288}
289
290#[allow(dead_code)]
292async fn handle_passive_connection(
293 stream: TcpStream,
294 broadcast_rx: broadcast::Receiver<ReplicationMessage>,
295 record_buffer: Arc<RwLock<VecDeque<DecisionRecord>>>,
296 integrity: Arc<RwLock<StateIntegrity>>,
297 node_id: String,
298 fencing_token: u64,
299) -> Result<(), ReplicationError> {
300 handle_passive_connection_generic(
301 ReplicationTlsStream::plain(stream),
302 broadcast_rx,
303 record_buffer,
304 integrity,
305 node_id,
306 fencing_token,
307 )
308 .await
309}
310
311async fn handle_passive_connection_generic<S>(
313 mut stream: S,
314 mut broadcast_rx: broadcast::Receiver<ReplicationMessage>,
315 record_buffer: Arc<RwLock<VecDeque<DecisionRecord>>>,
316 integrity: Arc<RwLock<StateIntegrity>>,
317 node_id: String,
318 _fencing_token: u64,
319) -> Result<(), ReplicationError>
320where
321 S: AsyncRead + AsyncWrite + Unpin,
322{
323 let (sequence, state_hash, genesis_root) = {
325 let guard = integrity
326 .read()
327 .map_err(|_| ReplicationError::Protocol("Lock poisoned".to_string()))?;
328 (
329 guard.sequence_number,
330 guard.current_hash,
331 guard.genesis_validators_root,
332 )
333 };
334
335 let hello = HelloMessage {
336 version: PROTOCOL_VERSION,
337 node_id: node_id.clone(),
338 role: "Primary".to_string(),
339 sequence,
340 state_hash,
341 genesis_root,
342 };
343
344 send_message_generic(&mut stream, &ReplicationMessage::Hello(hello)).await?;
345
346 let passive_hello = match read_message_generic(&mut stream).await? {
348 Some(ReplicationMessage::Hello(h)) => h,
349 Some(other) => {
350 return Err(ReplicationError::Protocol(format!(
351 "Expected Hello, got {:?}",
352 std::mem::discriminant(&other)
353 )));
354 }
355 None => {
356 return Err(ReplicationError::Protocol(
357 "Connection closed during handshake".to_string(),
358 ));
359 }
360 };
361
362 if passive_hello.version != PROTOCOL_VERSION {
364 send_message_generic(
365 &mut stream,
366 &ReplicationMessage::Error(ErrorMessage {
367 code: ErrorCode::VersionMismatch,
368 description: format!(
369 "Expected version {}, got {}",
370 PROTOCOL_VERSION, passive_hello.version
371 ),
372 sequence: None,
373 }),
374 )
375 .await?;
376 return Err(ReplicationError::Protocol("Version mismatch".to_string()));
377 }
378
379 info!(
380 passive_id = %passive_hello.node_id,
381 passive_sequence = passive_hello.sequence,
382 "Passive node handshake complete"
383 );
384
385 if passive_hello.sequence < sequence {
387 handle_sync_generic(&mut stream, &record_buffer, passive_hello.sequence, sequence).await?;
388 }
389
390 loop {
392 tokio::select! {
393 msg_result = broadcast_rx.recv() => {
394 match msg_result {
395 Ok(msg) => {
396 if let Err(e) = send_message_generic(&mut stream, &msg).await {
397 warn!(error = %e, "Failed to send to passive");
398 break;
399 }
400 }
401 Err(broadcast::error::RecvError::Lagged(n)) => {
402 warn!(lagged = n, "Passive fell behind, may need resync");
403 }
404 Err(broadcast::error::RecvError::Closed) => {
405 break;
406 }
407 }
408 }
409
410 read_result = read_message_generic(&mut stream) => {
412 match read_result {
413 Ok(Some(ReplicationMessage::Ack(ack))) => {
414 debug!(sequence = ack.sequence, "Received ACK from passive");
415 }
416 Ok(Some(ReplicationMessage::SyncRequest(req))) => {
417 let current_seq = integrity.read()
418 .map(|i| i.sequence_number)
419 .unwrap_or(0);
420 handle_sync_generic(&mut stream, &record_buffer, req.from_sequence, current_seq).await?;
421 }
422 Ok(Some(_)) => {
423 }
425 Ok(None) => {
426 info!("Passive disconnected");
427 break;
428 }
429 Err(e) => {
430 warn!(error = %e, "Error reading from passive");
431 break;
432 }
433 }
434 }
435 }
436 }
437
438 Ok(())
439}
440
441#[allow(dead_code)]
443async fn handle_sync(
444 stream: &mut TcpStream,
445 record_buffer: &Arc<RwLock<VecDeque<DecisionRecord>>>,
446 from_sequence: u64,
447 current_sequence: u64,
448) -> Result<(), ReplicationError> {
449 handle_sync_generic(stream, record_buffer, from_sequence, current_sequence).await
450}
451
452async fn handle_sync_generic<S>(
454 stream: &mut S,
455 record_buffer: &Arc<RwLock<VecDeque<DecisionRecord>>>,
456 from_sequence: u64,
457 current_sequence: u64,
458) -> Result<(), ReplicationError>
459where
460 S: AsyncRead + AsyncWrite + Unpin,
461{
462 let records = {
463 let buffer = record_buffer
464 .read()
465 .map_err(|_| ReplicationError::Protocol("Lock poisoned".to_string()))?;
466
467 buffer
468 .iter()
469 .filter(|r| r.sequence > from_sequence && r.sequence <= current_sequence)
470 .cloned()
471 .collect::<Vec<_>>()
472 };
473
474 let response = SyncResponse {
475 records,
476 has_more: false,
477 current_sequence,
478 };
479
480 send_message_generic(stream, &ReplicationMessage::SyncResponse(response)).await
481}
482
483#[allow(dead_code)]
485async fn send_message(
486 stream: &mut TcpStream,
487 msg: &ReplicationMessage,
488) -> Result<(), ReplicationError> {
489 send_message_generic(stream, msg).await
490}
491
492async fn send_message_generic<S>(
494 stream: &mut S,
495 msg: &ReplicationMessage,
496) -> Result<(), ReplicationError>
497where
498 S: AsyncWrite + Unpin,
499{
500 let bytes =
501 serde_json::to_vec(msg).map_err(|e| ReplicationError::Serialization(e.to_string()))?;
502
503 let len = bytes.len() as u32;
504 stream.write_all(&len.to_be_bytes()).await?;
505 stream.write_all(&bytes).await?;
506 stream.flush().await?;
507
508 Ok(())
509}
510
511#[allow(dead_code)]
513async fn read_message(
514 stream: &mut TcpStream,
515) -> Result<Option<ReplicationMessage>, ReplicationError> {
516 read_message_generic(stream).await
517}
518
519async fn read_message_generic<S>(
521 stream: &mut S,
522) -> Result<Option<ReplicationMessage>, ReplicationError>
523where
524 S: AsyncRead + Unpin,
525{
526 let mut len_buf = [0u8; 4];
527 match stream.read_exact(&mut len_buf).await {
528 Ok(_) => {}
529 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
530 return Ok(None);
531 }
532 Err(e) => return Err(e.into()),
533 }
534
535 let len = u32::from_be_bytes(len_buf) as usize;
536 if len > super::protocol::MAX_MESSAGE_SIZE {
537 return Err(ReplicationError::Protocol(format!(
538 "Message too large: {} bytes",
539 len
540 )));
541 }
542
543 let mut buf = vec![0u8; len];
544 stream.read_exact(&mut buf).await?;
545
546 let msg = serde_json::from_slice(&buf)
547 .map_err(|e| ReplicationError::Serialization(e.to_string()))?;
548
549 Ok(Some(msg))
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a replicator from the default config and a fresh integrity state.
    fn fresh_replicator() -> StateReplicator {
        let (replicator, _handle) =
            StateReplicator::new(ReplicatorConfig::default(), StateIntegrity::new());
        replicator
    }

    /// A new replicator starts with fencing token 1.
    #[test]
    fn test_replicator_creation() {
        let replicator = fresh_replicator();

        let token = replicator.fencing_token.load(Ordering::Acquire);
        assert_eq!(token, 1);
    }

    /// Each call to `next_fencing_token` returns the next integer.
    #[test]
    fn test_fencing_token_increment() {
        let replicator = fresh_replicator();

        for expected in [2u64, 3u64] {
            assert_eq!(replicator.next_fencing_token(), expected);
        }
    }
}