Skip to main content

nklave_core/replication/
passive.rs

1//! Passive node state receiver
2//!
3//! The PassiveReceiver connects to the primary, receives decision records,
4//! verifies the hash chain, and maintains shadow state ready for failover.
5
6use super::heartbeat::{HeartbeatConfig, HeartbeatMonitor};
7use super::protocol::{AckMessage, HelloMessage, ReplicationMessage, PROTOCOL_VERSION};
8use super::tls::{ReplicationTlsConfig, ReplicationTlsStream};
9use crate::state::integrity::{DecisionRecord, IntegrityError, StateIntegrity};
10use crate::state::validator::ValidatorState;
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::{Arc, RwLock};
14use std::time::Duration;
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::watch;
18use tokio_rustls::TlsConnector;
19use tracing::{debug, error, info, warn};
20
21/// Configuration for the passive receiver
22#[derive(Debug, Clone)]
23pub struct PassiveConfig {
24    /// Address of the primary node
25    pub primary_addr: String,
26
27    /// Node identifier
28    pub node_id: String,
29
30    /// Reconnect delay on connection failure (milliseconds)
31    pub reconnect_delay_ms: u64,
32
33    /// Heartbeat configuration
34    pub heartbeat_config: HeartbeatConfig,
35
36    /// Optional TLS configuration for mTLS
37    pub tls_config: Option<ReplicationTlsConfig>,
38
39    /// Server name for TLS verification (e.g., "primary.nklave.local")
40    pub tls_server_name: Option<String>,
41}
42
43impl Default for PassiveConfig {
44    fn default() -> Self {
45        Self {
46            primary_addr: "127.0.0.1:26660".to_string(),
47            node_id: "passive".to_string(),
48            reconnect_delay_ms: 5000,
49            heartbeat_config: HeartbeatConfig::default(),
50            tls_config: None,
51            tls_server_name: None,
52        }
53    }
54}
55
56/// Passive receiver that maintains shadow state
57pub struct PassiveReceiver {
58    config: PassiveConfig,
59
60    /// Shadow copy of state integrity
61    integrity: Arc<RwLock<StateIntegrity>>,
62
63    /// Shadow copy of validator states
64    _validator_states: Arc<RwLock<HashMap<[u8; 48], ValidatorState>>>,
65
66    /// Heartbeat monitor
67    heartbeat_monitor: Arc<HeartbeatMonitor>,
68
69    /// Current fencing token from primary
70    fencing_token: AtomicU64,
71
72    /// Replication lag (sequences behind primary)
73    replication_lag: AtomicU64,
74
75    /// Whether we're connected to primary
76    connected: AtomicBool,
77
78    /// Shutdown flag
79    _shutdown: AtomicBool,
80
81    /// Optional TLS connector for secure connections
82    tls_connector: Option<TlsConnector>,
83
84    /// Server name for TLS verification
85    tls_server_name: Option<String>,
86}
87
88/// Handle for controlling the passive receiver
89pub struct PassiveHandle {
90    /// Shutdown signal
91    shutdown_tx: watch::Sender<bool>,
92
93    /// Reference to integrity for promotion
94    integrity: Arc<RwLock<StateIntegrity>>,
95
96    /// Reference to validator states for promotion
97    validator_states: Arc<RwLock<HashMap<[u8; 48], ValidatorState>>>,
98
99    /// Heartbeat monitor for failover detection
100    heartbeat_monitor: Arc<HeartbeatMonitor>,
101}
102
103impl PassiveHandle {
104    /// Get the current state integrity (for promotion)
105    pub fn integrity(&self) -> StateIntegrity {
106        self.integrity.read().map(|i| i.clone()).unwrap_or_default()
107    }
108
109    /// Get the current validator states (for promotion)
110    pub fn validator_states(&self) -> HashMap<[u8; 48], ValidatorState> {
111        self.validator_states.read().map(|v| v.clone()).unwrap_or_default()
112    }
113
114    /// Subscribe to failover notifications
115    pub fn subscribe_failover(&self) -> watch::Receiver<bool> {
116        self.heartbeat_monitor.subscribe_failover()
117    }
118
119    /// Check if we should trigger failover
120    pub fn should_failover(&self) -> bool {
121        !self.heartbeat_monitor.is_primary_alive()
122    }
123
124    /// Get replication lag
125    pub fn replication_lag(&self) -> u64 {
126        // This would need to be tracked
127        0
128    }
129
130    /// Signal shutdown
131    pub fn shutdown(self) {
132        let _ = self.shutdown_tx.send(true);
133    }
134}
135
136/// Errors from passive receiver
137#[derive(Debug, thiserror::Error)]
138pub enum PassiveError {
139    #[error("Connection failed: {0}")]
140    ConnectionFailed(String),
141
142    #[error("IO error: {0}")]
143    Io(#[from] std::io::Error),
144
145    #[error("Serialization error: {0}")]
146    Serialization(String),
147
148    #[error("Protocol error: {0}")]
149    Protocol(String),
150
151    #[error("Hash chain verification failed: {0}")]
152    HashChainError(#[from] IntegrityError),
153
154    #[error("Genesis root mismatch")]
155    GenesisRootMismatch,
156}
157
158impl PassiveReceiver {
159    /// Create a new passive receiver
160    pub fn new(
161        config: PassiveConfig,
162        initial_integrity: StateIntegrity,
163        initial_validators: HashMap<[u8; 48], ValidatorState>,
164    ) -> (Self, PassiveHandle) {
165        let (shutdown_tx, _shutdown_rx) = watch::channel(false);
166
167        let integrity = Arc::new(RwLock::new(initial_integrity));
168        let validator_states = Arc::new(RwLock::new(initial_validators));
169        let heartbeat_monitor = Arc::new(HeartbeatMonitor::new(config.heartbeat_config.clone()));
170
171        // Build TLS connector if configured
172        let tls_connector = if let Some(ref tls_config) = config.tls_config {
173            match tls_config.build_connector() {
174                Ok(connector) => {
175                    info!("TLS connector configured for passive receiver");
176                    Some(connector)
177                }
178                Err(e) => {
179                    error!(error = %e, "Failed to build TLS connector, falling back to plaintext");
180                    None
181                }
182            }
183        } else {
184            None
185        };
186
187        let tls_server_name = config.tls_server_name.clone();
188
189        let receiver = Self {
190            config: config.clone(),
191            integrity: integrity.clone(),
192            _validator_states: validator_states.clone(),
193            heartbeat_monitor: heartbeat_monitor.clone(),
194            fencing_token: AtomicU64::new(0),
195            replication_lag: AtomicU64::new(0),
196            connected: AtomicBool::new(false),
197            _shutdown: AtomicBool::new(false),
198            tls_connector,
199            tls_server_name,
200        };
201
202        let handle = PassiveHandle {
203            shutdown_tx,
204            integrity,
205            validator_states,
206            heartbeat_monitor,
207        };
208
209        (receiver, handle)
210    }
211
212    /// Run the passive receiver (connects and stays connected to primary)
213    pub async fn run(self, mut shutdown_rx: watch::Receiver<bool>) -> Result<(), PassiveError> {
214        loop {
215            // Check for shutdown
216            if *shutdown_rx.borrow() {
217                info!("Passive receiver shutting down");
218                break;
219            }
220
221            // Try to connect
222            match self.connect_and_receive(&mut shutdown_rx).await {
223                Ok(()) => {
224                    info!("Connection to primary closed normally");
225                }
226                Err(e) => {
227                    warn!(error = %e, "Connection to primary failed");
228                    self.connected.store(false, Ordering::Release);
229                }
230            }
231
232            // Wait before reconnecting
233            tokio::select! {
234                _ = tokio::time::sleep(Duration::from_millis(self.config.reconnect_delay_ms)) => {}
235                _ = shutdown_rx.changed() => {
236                    if *shutdown_rx.borrow() {
237                        break;
238                    }
239                }
240            }
241        }
242
243        Ok(())
244    }
245
246    /// Connect to primary and receive messages
247    async fn connect_and_receive(
248        &self,
249        shutdown_rx: &mut watch::Receiver<bool>,
250    ) -> Result<(), PassiveError> {
251        let tcp_stream = TcpStream::connect(&self.config.primary_addr).await?;
252        info!(addr = %self.config.primary_addr, "TCP connected to primary");
253
254        // Wrap with TLS if configured
255        if let Some(ref connector) = self.tls_connector {
256            let server_name = self
257                .tls_server_name
258                .as_deref()
259                .unwrap_or("primary.nklave.local");
260
261            match ReplicationTlsStream::connect(connector, server_name, tcp_stream).await {
262                Ok(tls_stream) => {
263                    info!("TLS handshake successful with primary");
264                    self.connected.store(true, Ordering::Release);
265                    self.receive_loop_generic(tls_stream, shutdown_rx).await
266                }
267                Err(e) => {
268                    warn!(error = %e, "TLS handshake failed with primary");
269                    Err(PassiveError::ConnectionFailed(format!(
270                        "TLS handshake failed: {}",
271                        e
272                    )))
273                }
274            }
275        } else {
276            info!(addr = %self.config.primary_addr, "Connected to primary (plaintext - NOT RECOMMENDED FOR PRODUCTION)");
277            self.connected.store(true, Ordering::Release);
278            self.receive_loop_generic(ReplicationTlsStream::plain(tcp_stream), shutdown_rx)
279                .await
280        }
281    }
282
283    /// Main receive loop with generic stream type
284    async fn receive_loop_generic<S>(
285        &self,
286        mut stream: S,
287        shutdown_rx: &mut watch::Receiver<bool>,
288    ) -> Result<(), PassiveError>
289    where
290        S: AsyncRead + AsyncWrite + Unpin,
291    {
292        // Receive hello from primary
293        let primary_hello = match read_message_generic(&mut stream).await? {
294            Some(ReplicationMessage::Hello(h)) => h,
295            Some(other) => {
296                return Err(PassiveError::Protocol(format!(
297                    "Expected Hello, got {:?}",
298                    std::mem::discriminant(&other)
299                )));
300            }
301            None => {
302                return Err(PassiveError::Protocol(
303                    "Connection closed during handshake".to_string(),
304                ));
305            }
306        };
307
308        // Version check
309        if primary_hello.version != PROTOCOL_VERSION {
310            return Err(PassiveError::Protocol(format!(
311                "Version mismatch: expected {}, got {}",
312                PROTOCOL_VERSION, primary_hello.version
313            )));
314        }
315
316        // Send our hello
317        let (our_sequence, our_hash, our_genesis) = {
318            let guard = self
319                .integrity
320                .read()
321                .map_err(|_| PassiveError::Protocol("Lock poisoned".to_string()))?;
322            (
323                guard.sequence_number,
324                guard.current_hash,
325                guard.genesis_validators_root,
326            )
327        };
328
329        let our_hello = HelloMessage {
330            version: PROTOCOL_VERSION,
331            node_id: self.config.node_id.clone(),
332            role: "Passive".to_string(),
333            sequence: our_sequence,
334            state_hash: our_hash,
335            genesis_root: our_genesis,
336        };
337
338        send_message_generic(&mut stream, &ReplicationMessage::Hello(our_hello)).await?;
339
340        info!(
341            primary_id = %primary_hello.node_id,
342            primary_sequence = primary_hello.sequence,
343            our_sequence = our_sequence,
344            "Handshake complete with primary"
345        );
346
347        // Record initial heartbeat
348        self.heartbeat_monitor
349            .record_heartbeat(primary_hello.sequence, primary_hello.state_hash);
350
351        // Check genesis root compatibility
352        if let (Some(primary_genesis), Some(our_genesis)) =
353            (primary_hello.genesis_root, our_genesis)
354        {
355            if primary_genesis != our_genesis {
356                return Err(PassiveError::GenesisRootMismatch);
357            }
358        }
359
360        // Main receive loop
361        loop {
362            tokio::select! {
363                msg_result = read_message_generic(&mut stream) => {
364                    match msg_result {
365                        Ok(Some(msg)) => {
366                            self.handle_message_generic(msg, &mut stream).await?;
367                        }
368                        Ok(None) => {
369                            info!("Primary disconnected");
370                            break;
371                        }
372                        Err(e) => {
373                            return Err(e);
374                        }
375                    }
376                }
377
378                _ = shutdown_rx.changed() => {
379                    if *shutdown_rx.borrow() {
380                        break;
381                    }
382                }
383            }
384        }
385
386        Ok(())
387    }
388
389    /// Handle a received message (legacy version for TcpStream)
390    #[allow(dead_code)]
391    async fn handle_message(
392        &self,
393        msg: ReplicationMessage,
394        stream: &mut TcpStream,
395    ) -> Result<(), PassiveError> {
396        self.handle_message_generic(msg, stream).await
397    }
398
399    /// Handle a received message (generic version)
400    async fn handle_message_generic<S>(
401        &self,
402        msg: ReplicationMessage,
403        stream: &mut S,
404    ) -> Result<(), PassiveError>
405    where
406        S: AsyncRead + AsyncWrite + Unpin,
407    {
408        match msg {
409            ReplicationMessage::Heartbeat(hb) => {
410                self.heartbeat_monitor
411                    .record_heartbeat(hb.sequence, hb.state_hash);
412                self.fencing_token.store(hb.fencing_token, Ordering::Release);
413
414                // Calculate replication lag
415                let our_seq = self
416                    .integrity
417                    .read()
418                    .map(|i| i.sequence_number)
419                    .unwrap_or(0);
420                let lag = hb.sequence.saturating_sub(our_seq);
421                self.replication_lag.store(lag, Ordering::Release);
422
423                debug!(
424                    primary_sequence = hb.sequence,
425                    our_sequence = our_seq,
426                    lag = lag,
427                    "Heartbeat received"
428                );
429            }
430
431            ReplicationMessage::Decision(record) => {
432                self.apply_decision_record_generic(record, stream).await?;
433            }
434
435            ReplicationMessage::SyncResponse(response) => {
436                info!(
437                    record_count = response.records.len(),
438                    has_more = response.has_more,
439                    "Received sync response"
440                );
441
442                for record in response.records {
443                    self.apply_decision_record_generic(record, stream).await?;
444                }
445            }
446
447            ReplicationMessage::Error(err) => {
448                error!(
449                    code = ?err.code,
450                    description = %err.description,
451                    "Received error from primary"
452                );
453                return Err(PassiveError::Protocol(err.description));
454            }
455
456            _ => {
457                debug!("Ignoring unexpected message type");
458            }
459        }
460
461        Ok(())
462    }
463
464    /// Apply a decision record to our shadow state (legacy version for TcpStream)
465    #[allow(dead_code)]
466    async fn apply_decision_record(
467        &self,
468        record: DecisionRecord,
469        stream: &mut TcpStream,
470    ) -> Result<(), PassiveError> {
471        self.apply_decision_record_generic(record, stream).await
472    }
473
474    /// Apply a decision record to our shadow state (generic version)
475    async fn apply_decision_record_generic<S>(
476        &self,
477        record: DecisionRecord,
478        stream: &mut S,
479    ) -> Result<(), PassiveError>
480    where
481        S: AsyncWrite + Unpin,
482    {
483        let new_hash = {
484            let mut integrity = self
485                .integrity
486                .write()
487                .map_err(|_| PassiveError::Protocol("Lock poisoned".to_string()))?;
488
489            // Verify and apply the record
490            let new_hash = integrity.record_decision(&record)?;
491
492            debug!(
493                sequence = record.sequence,
494                new_hash = hex::encode(new_hash),
495                "Applied decision record"
496            );
497
498            new_hash
499        };
500
501        // Send acknowledgment
502        let ack = AckMessage {
503            sequence: record.sequence,
504            state_hash: new_hash,
505        };
506        send_message_generic(stream, &ReplicationMessage::Ack(ack)).await?;
507
508        // Update metrics
509        crate::metrics::set_state_sequence(record.sequence);
510
511        Ok(())
512    }
513
514    /// Get current replication lag
515    pub fn replication_lag(&self) -> u64 {
516        self.replication_lag.load(Ordering::Acquire)
517    }
518
519    /// Check if connected to primary
520    pub fn is_connected(&self) -> bool {
521        self.connected.load(Ordering::Acquire)
522    }
523}
524
525/// Send a message with length prefix (legacy version for TcpStream)
526#[allow(dead_code)]
527async fn send_message(
528    stream: &mut TcpStream,
529    msg: &ReplicationMessage,
530) -> Result<(), PassiveError> {
531    send_message_generic(stream, msg).await
532}
533
534/// Send a message with length prefix (generic version)
535async fn send_message_generic<S>(stream: &mut S, msg: &ReplicationMessage) -> Result<(), PassiveError>
536where
537    S: AsyncWrite + Unpin,
538{
539    let bytes =
540        serde_json::to_vec(msg).map_err(|e| PassiveError::Serialization(e.to_string()))?;
541
542    let len = bytes.len() as u32;
543    stream.write_all(&len.to_be_bytes()).await?;
544    stream.write_all(&bytes).await?;
545    stream.flush().await?;
546
547    Ok(())
548}
549
550/// Read a message with length prefix (legacy version for TcpStream)
551#[allow(dead_code)]
552async fn read_message(
553    stream: &mut TcpStream,
554) -> Result<Option<ReplicationMessage>, PassiveError> {
555    read_message_generic(stream).await
556}
557
558/// Read a message with length prefix (generic version)
559async fn read_message_generic<S>(stream: &mut S) -> Result<Option<ReplicationMessage>, PassiveError>
560where
561    S: AsyncRead + Unpin,
562{
563    let mut len_buf = [0u8; 4];
564    match stream.read_exact(&mut len_buf).await {
565        Ok(_) => {}
566        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
567            return Ok(None);
568        }
569        Err(e) => return Err(e.into()),
570    }
571
572    let len = u32::from_be_bytes(len_buf) as usize;
573    if len > super::protocol::MAX_MESSAGE_SIZE {
574        return Err(PassiveError::Protocol(format!(
575            "Message too large: {} bytes",
576            len
577        )));
578    }
579
580    let mut buf = vec![0u8; len];
581    stream.read_exact(&mut buf).await?;
582
583    let msg = serde_json::from_slice(&buf)
584        .map_err(|e| PassiveError::Serialization(e.to_string()))?;
585
586    Ok(Some(msg))
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed receiver is disconnected and reports zero lag.
    #[test]
    fn test_passive_receiver_creation() {
        let (receiver, _handle) = PassiveReceiver::new(
            PassiveConfig::default(),
            StateIntegrity::new(),
            HashMap::new(),
        );

        assert!(!receiver.is_connected());
        assert_eq!(receiver.replication_lag(), 0);
    }
}