Skip to main content

nklave_core/replication/
primary.rs

1//! Primary node state replication
2//!
3//! The StateReplicator streams decision records to passive nodes
4//! and sends periodic heartbeats.
5
6use super::protocol::{
7    ErrorCode, ErrorMessage, Heartbeat, HelloMessage, ReplicationMessage, SyncResponse,
8    PROTOCOL_VERSION,
9};
10use super::tls::{ReplicationTlsConfig, ReplicationTlsStream};
11use crate::state::integrity::{DecisionRecord, StateIntegrity};
12use std::collections::VecDeque;
13use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
14use std::sync::{Arc, RwLock};
15use std::time::Duration;
16use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
17use tokio::net::{TcpListener, TcpStream};
18use tokio::sync::{broadcast, mpsc, watch};
19use tokio_rustls::TlsAcceptor;
20use tracing::{debug, error, info, warn};
21
22/// Configuration for the state replicator
23#[derive(Debug, Clone)]
24pub struct ReplicatorConfig {
25    /// Address to listen for passive node connections
26    pub listen_addr: String,
27
28    /// Heartbeat interval in milliseconds
29    pub heartbeat_interval_ms: u64,
30
31    /// Maximum number of records to buffer for slow passives
32    pub max_buffer_size: usize,
33
34    /// Node identifier
35    pub node_id: String,
36
37    /// Optional TLS configuration for mTLS
38    pub tls_config: Option<ReplicationTlsConfig>,
39}
40
41impl Default for ReplicatorConfig {
42    fn default() -> Self {
43        Self {
44            listen_addr: "127.0.0.1:26660".to_string(),
45            heartbeat_interval_ms: 1000,
46            max_buffer_size: 10000,
47            node_id: "primary".to_string(),
48            tls_config: None,
49        }
50    }
51}
52
53/// State replicator for the primary node
54pub struct StateReplicator {
55    config: ReplicatorConfig,
56
57    /// Current fencing token
58    fencing_token: AtomicU64,
59
60    /// Channel to receive new decision records
61    decision_rx: mpsc::Receiver<DecisionRecord>,
62
63    /// Broadcast channel for sending to passives
64    broadcast_tx: broadcast::Sender<ReplicationMessage>,
65
66    /// Record buffer for catch-up sync
67    record_buffer: Arc<RwLock<VecDeque<DecisionRecord>>>,
68
69    /// Current state integrity (for heartbeats)
70    integrity: Arc<RwLock<StateIntegrity>>,
71
72    /// Shutdown flag
73    _shutdown: AtomicBool,
74}
75
76/// Handle for controlling the replicator
77pub struct ReplicatorHandle {
78    /// Sender for new decision records
79    decision_tx: mpsc::Sender<DecisionRecord>,
80
81    /// Shutdown signal
82    shutdown_tx: watch::Sender<bool>,
83
84    /// Reference to update integrity
85    integrity: Arc<RwLock<StateIntegrity>>,
86}
87
88impl ReplicatorHandle {
89    /// Send a new decision record to be replicated
90    pub async fn replicate(&self, record: DecisionRecord) -> Result<(), ReplicationError> {
91        self.decision_tx
92            .send(record)
93            .await
94            .map_err(|_| ReplicationError::ChannelClosed)
95    }
96
97    /// Update the current state integrity (call after each signing)
98    pub fn update_integrity(&self, integrity: StateIntegrity) {
99        if let Ok(mut guard) = self.integrity.write() {
100            *guard = integrity;
101        }
102    }
103
104    /// Signal the replicator to shutdown
105    pub fn shutdown(self) {
106        let _ = self.shutdown_tx.send(true);
107    }
108}
109
110/// Errors from replication
111#[derive(Debug, thiserror::Error)]
112pub enum ReplicationError {
113    #[error("Channel closed")]
114    ChannelClosed,
115
116    #[error("IO error: {0}")]
117    Io(#[from] std::io::Error),
118
119    #[error("Serialization error: {0}")]
120    Serialization(String),
121
122    #[error("Protocol error: {0}")]
123    Protocol(String),
124
125    #[error("Fencing token rejected")]
126    FencingRejected,
127}
128
129impl StateReplicator {
130    /// Create a new state replicator
131    pub fn new(
132        config: ReplicatorConfig,
133        initial_integrity: StateIntegrity,
134    ) -> (Self, ReplicatorHandle) {
135        let (decision_tx, decision_rx) = mpsc::channel(1000);
136        let (broadcast_tx, _) = broadcast::channel(1000);
137        let (shutdown_tx, _shutdown_rx) = watch::channel(false);
138
139        let integrity = Arc::new(RwLock::new(initial_integrity));
140
141        let replicator = Self {
142            config: config.clone(),
143            fencing_token: AtomicU64::new(1),
144            decision_rx,
145            broadcast_tx,
146            record_buffer: Arc::new(RwLock::new(VecDeque::with_capacity(config.max_buffer_size))),
147            integrity: integrity.clone(),
148            _shutdown: AtomicBool::new(false),
149        };
150
151        let handle = ReplicatorHandle {
152            decision_tx,
153            shutdown_tx,
154            integrity,
155        };
156
157        (replicator, handle)
158    }
159
160    /// Start the replicator (runs until shutdown)
161    pub async fn run(mut self, mut shutdown_rx: watch::Receiver<bool>) -> Result<(), ReplicationError> {
162        let listener = TcpListener::bind(&self.config.listen_addr).await?;
163
164        // Build TLS acceptor if configured
165        let tls_acceptor: Option<TlsAcceptor> = if let Some(ref tls_config) = self.config.tls_config {
166            match tls_config.build_acceptor() {
167                Ok(acceptor) => {
168                    info!(addr = %self.config.listen_addr, "State replicator listening with mTLS");
169                    Some(acceptor)
170                }
171                Err(e) => {
172                    error!(error = %e, "Failed to build TLS acceptor");
173                    return Err(ReplicationError::Protocol(format!("TLS setup failed: {}", e)));
174                }
175            }
176        } else {
177            info!(addr = %self.config.listen_addr, "State replicator listening (plaintext - NOT RECOMMENDED FOR PRODUCTION)");
178            None
179        };
180
181        let heartbeat_interval = Duration::from_millis(self.config.heartbeat_interval_ms);
182        let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);
183
184        loop {
185            tokio::select! {
186                // Accept new passive connections
187                accept_result = listener.accept() => {
188                    match accept_result {
189                        Ok((stream, addr)) => {
190                            info!(addr = %addr, "Passive node connected");
191                            let broadcast_rx = self.broadcast_tx.subscribe();
192                            let buffer = self.record_buffer.clone();
193                            let integrity = self.integrity.clone();
194                            let node_id = self.config.node_id.clone();
195                            let _fencing_token = self.fencing_token.load(Ordering::Acquire);
196                            let tls_acceptor_clone = tls_acceptor.clone();
197
198                            tokio::spawn(async move {
199                                // Wrap with TLS if configured
200                                let result = if let Some(acceptor) = tls_acceptor_clone {
201                                    match ReplicationTlsStream::accept(&acceptor, stream).await {
202                                        Ok(tls_stream) => {
203                                            info!(addr = %addr, "TLS handshake successful");
204                                            handle_passive_connection_generic(
205                                                tls_stream,
206                                                broadcast_rx,
207                                                buffer,
208                                                integrity,
209                                                node_id,
210                                                _fencing_token,
211                                            ).await
212                                        }
213                                        Err(e) => {
214                                            warn!(addr = %addr, error = %e, "TLS handshake failed");
215                                            return;
216                                        }
217                                    }
218                                } else {
219                                    handle_passive_connection_generic(
220                                        ReplicationTlsStream::plain(stream),
221                                        broadcast_rx,
222                                        buffer,
223                                        integrity,
224                                        node_id,
225                                        _fencing_token,
226                                    ).await
227                                };
228
229                                if let Err(e) = result {
230                                    warn!(error = %e, "Passive connection error");
231                                }
232                            });
233                        }
234                        Err(e) => {
235                            error!(error = %e, "Accept error");
236                        }
237                    }
238                }
239
240                // Receive new decision records to replicate
241                Some(record) = self.decision_rx.recv() => {
242                    self.buffer_record(record.clone());
243
244                    let msg = ReplicationMessage::Decision(record);
245                    let _ = self.broadcast_tx.send(msg);
246                }
247
248                // Send periodic heartbeats
249                _ = heartbeat_timer.tick() => {
250                    if let Ok(integrity) = self.integrity.read() {
251                        let heartbeat = Heartbeat::new(
252                            integrity.sequence_number,
253                            integrity.current_hash,
254                            self.fencing_token.load(Ordering::Acquire),
255                        );
256                        let msg = ReplicationMessage::Heartbeat(heartbeat);
257                        let _ = self.broadcast_tx.send(msg);
258                    }
259                }
260
261                // Check for shutdown
262                _ = shutdown_rx.changed() => {
263                    if *shutdown_rx.borrow() {
264                        info!("State replicator shutting down");
265                        break;
266                    }
267                }
268            }
269        }
270
271        Ok(())
272    }
273
274    /// Buffer a record for catch-up sync
275    fn buffer_record(&self, record: DecisionRecord) {
276        if let Ok(mut buffer) = self.record_buffer.write() {
277            if buffer.len() >= self.config.max_buffer_size {
278                buffer.pop_front();
279            }
280            buffer.push_back(record);
281        }
282    }
283
284    /// Increment and return the fencing token
285    pub fn next_fencing_token(&self) -> u64 {
286        self.fencing_token.fetch_add(1, Ordering::AcqRel) + 1
287    }
288}
289
290/// Handle a connected passive node (legacy function for TcpStream)
291#[allow(dead_code)]
292async fn handle_passive_connection(
293    stream: TcpStream,
294    broadcast_rx: broadcast::Receiver<ReplicationMessage>,
295    record_buffer: Arc<RwLock<VecDeque<DecisionRecord>>>,
296    integrity: Arc<RwLock<StateIntegrity>>,
297    node_id: String,
298    fencing_token: u64,
299) -> Result<(), ReplicationError> {
300    handle_passive_connection_generic(
301        ReplicationTlsStream::plain(stream),
302        broadcast_rx,
303        record_buffer,
304        integrity,
305        node_id,
306        fencing_token,
307    )
308    .await
309}
310
311/// Handle a connected passive node (generic version for any stream type)
312async fn handle_passive_connection_generic<S>(
313    mut stream: S,
314    mut broadcast_rx: broadcast::Receiver<ReplicationMessage>,
315    record_buffer: Arc<RwLock<VecDeque<DecisionRecord>>>,
316    integrity: Arc<RwLock<StateIntegrity>>,
317    node_id: String,
318    _fencing_token: u64,
319) -> Result<(), ReplicationError>
320where
321    S: AsyncRead + AsyncWrite + Unpin,
322{
323    // Send hello message
324    let (sequence, state_hash, genesis_root) = {
325        let guard = integrity
326            .read()
327            .map_err(|_| ReplicationError::Protocol("Lock poisoned".to_string()))?;
328        (
329            guard.sequence_number,
330            guard.current_hash,
331            guard.genesis_validators_root,
332        )
333    };
334
335    let hello = HelloMessage {
336        version: PROTOCOL_VERSION,
337        node_id: node_id.clone(),
338        role: "Primary".to_string(),
339        sequence,
340        state_hash,
341        genesis_root,
342    };
343
344    send_message_generic(&mut stream, &ReplicationMessage::Hello(hello)).await?;
345
346    // Read hello response from passive
347    let passive_hello = match read_message_generic(&mut stream).await? {
348        Some(ReplicationMessage::Hello(h)) => h,
349        Some(other) => {
350            return Err(ReplicationError::Protocol(format!(
351                "Expected Hello, got {:?}",
352                std::mem::discriminant(&other)
353            )));
354        }
355        None => {
356            return Err(ReplicationError::Protocol(
357                "Connection closed during handshake".to_string(),
358            ));
359        }
360    };
361
362    // Version check
363    if passive_hello.version != PROTOCOL_VERSION {
364        send_message_generic(
365            &mut stream,
366            &ReplicationMessage::Error(ErrorMessage {
367                code: ErrorCode::VersionMismatch,
368                description: format!(
369                    "Expected version {}, got {}",
370                    PROTOCOL_VERSION, passive_hello.version
371                ),
372                sequence: None,
373            }),
374        )
375        .await?;
376        return Err(ReplicationError::Protocol("Version mismatch".to_string()));
377    }
378
379    info!(
380        passive_id = %passive_hello.node_id,
381        passive_sequence = passive_hello.sequence,
382        "Passive node handshake complete"
383    );
384
385    // Handle sync if passive is behind
386    if passive_hello.sequence < sequence {
387        handle_sync_generic(&mut stream, &record_buffer, passive_hello.sequence, sequence).await?;
388    }
389
390    // Forward messages from broadcast channel
391    loop {
392        tokio::select! {
393            msg_result = broadcast_rx.recv() => {
394                match msg_result {
395                    Ok(msg) => {
396                        if let Err(e) = send_message_generic(&mut stream, &msg).await {
397                            warn!(error = %e, "Failed to send to passive");
398                            break;
399                        }
400                    }
401                    Err(broadcast::error::RecvError::Lagged(n)) => {
402                        warn!(lagged = n, "Passive fell behind, may need resync");
403                    }
404                    Err(broadcast::error::RecvError::Closed) => {
405                        break;
406                    }
407                }
408            }
409
410            // Read acknowledgments or sync requests from passive
411            read_result = read_message_generic(&mut stream) => {
412                match read_result {
413                    Ok(Some(ReplicationMessage::Ack(ack))) => {
414                        debug!(sequence = ack.sequence, "Received ACK from passive");
415                    }
416                    Ok(Some(ReplicationMessage::SyncRequest(req))) => {
417                        let current_seq = integrity.read()
418                            .map(|i| i.sequence_number)
419                            .unwrap_or(0);
420                        handle_sync_generic(&mut stream, &record_buffer, req.from_sequence, current_seq).await?;
421                    }
422                    Ok(Some(_)) => {
423                        // Ignore other messages
424                    }
425                    Ok(None) => {
426                        info!("Passive disconnected");
427                        break;
428                    }
429                    Err(e) => {
430                        warn!(error = %e, "Error reading from passive");
431                        break;
432                    }
433                }
434            }
435        }
436    }
437
438    Ok(())
439}
440
441/// Handle sync request from passive (legacy version for TcpStream)
442#[allow(dead_code)]
443async fn handle_sync(
444    stream: &mut TcpStream,
445    record_buffer: &Arc<RwLock<VecDeque<DecisionRecord>>>,
446    from_sequence: u64,
447    current_sequence: u64,
448) -> Result<(), ReplicationError> {
449    handle_sync_generic(stream, record_buffer, from_sequence, current_sequence).await
450}
451
452/// Handle sync request from passive (generic version)
453async fn handle_sync_generic<S>(
454    stream: &mut S,
455    record_buffer: &Arc<RwLock<VecDeque<DecisionRecord>>>,
456    from_sequence: u64,
457    current_sequence: u64,
458) -> Result<(), ReplicationError>
459where
460    S: AsyncRead + AsyncWrite + Unpin,
461{
462    let records = {
463        let buffer = record_buffer
464            .read()
465            .map_err(|_| ReplicationError::Protocol("Lock poisoned".to_string()))?;
466
467        buffer
468            .iter()
469            .filter(|r| r.sequence > from_sequence && r.sequence <= current_sequence)
470            .cloned()
471            .collect::<Vec<_>>()
472    };
473
474    let response = SyncResponse {
475        records,
476        has_more: false,
477        current_sequence,
478    };
479
480    send_message_generic(stream, &ReplicationMessage::SyncResponse(response)).await
481}
482
483/// Send a message with length prefix (legacy version for TcpStream)
484#[allow(dead_code)]
485async fn send_message(
486    stream: &mut TcpStream,
487    msg: &ReplicationMessage,
488) -> Result<(), ReplicationError> {
489    send_message_generic(stream, msg).await
490}
491
492/// Send a message with length prefix (generic version)
493async fn send_message_generic<S>(
494    stream: &mut S,
495    msg: &ReplicationMessage,
496) -> Result<(), ReplicationError>
497where
498    S: AsyncWrite + Unpin,
499{
500    let bytes =
501        serde_json::to_vec(msg).map_err(|e| ReplicationError::Serialization(e.to_string()))?;
502
503    let len = bytes.len() as u32;
504    stream.write_all(&len.to_be_bytes()).await?;
505    stream.write_all(&bytes).await?;
506    stream.flush().await?;
507
508    Ok(())
509}
510
511/// Read a message with length prefix (legacy version for TcpStream)
512#[allow(dead_code)]
513async fn read_message(
514    stream: &mut TcpStream,
515) -> Result<Option<ReplicationMessage>, ReplicationError> {
516    read_message_generic(stream).await
517}
518
519/// Read a message with length prefix (generic version)
520async fn read_message_generic<S>(
521    stream: &mut S,
522) -> Result<Option<ReplicationMessage>, ReplicationError>
523where
524    S: AsyncRead + Unpin,
525{
526    let mut len_buf = [0u8; 4];
527    match stream.read_exact(&mut len_buf).await {
528        Ok(_) => {}
529        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
530            return Ok(None);
531        }
532        Err(e) => return Err(e.into()),
533    }
534
535    let len = u32::from_be_bytes(len_buf) as usize;
536    if len > super::protocol::MAX_MESSAGE_SIZE {
537        return Err(ReplicationError::Protocol(format!(
538            "Message too large: {} bytes",
539            len
540        )));
541    }
542
543    let mut buf = vec![0u8; len];
544    stream.read_exact(&mut buf).await?;
545
546    let msg = serde_json::from_slice(&buf)
547        .map_err(|e| ReplicationError::Serialization(e.to_string()))?;
548
549    Ok(Some(msg))
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a replicator from the default config and a fresh integrity state.
    /// The handle is returned so callers can keep it alive for the whole test.
    fn build_default() -> (StateReplicator, ReplicatorHandle) {
        StateReplicator::new(ReplicatorConfig::default(), StateIntegrity::new())
    }

    #[test]
    fn test_replicator_creation() {
        let (replicator, _handle) = build_default();
        // A freshly created replicator starts at fencing token 1.
        assert_eq!(replicator.fencing_token.load(Ordering::Acquire), 1);
    }

    #[test]
    fn test_fencing_token_increment() {
        let (replicator, _handle) = build_default();
        // Each call bumps the token and returns the new value.
        for expected in [2, 3] {
            assert_eq!(replicator.next_fencing_token(), expected);
        }
    }
}