1use super::heartbeat::{HeartbeatConfig, HeartbeatMonitor};
7use super::protocol::{AckMessage, HelloMessage, ReplicationMessage, PROTOCOL_VERSION};
8use super::tls::{ReplicationTlsConfig, ReplicationTlsStream};
9use crate::state::integrity::{DecisionRecord, IntegrityError, StateIntegrity};
10use crate::state::validator::ValidatorState;
11use std::collections::HashMap;
12use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
13use std::sync::{Arc, RwLock};
14use std::time::Duration;
15use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
16use tokio::net::TcpStream;
17use tokio::sync::watch;
18use tokio_rustls::TlsConnector;
19use tracing::{debug, error, info, warn};
20
/// Configuration for a passive replication receiver.
#[derive(Debug, Clone)]
pub struct PassiveConfig {
    /// Address of the primary, handed to `TcpStream::connect` (for example `"127.0.0.1:26660"`).
    pub primary_addr: String,

    /// Identifier this node announces as `node_id` in its `Hello` message.
    pub node_id: String,

    /// Delay, in milliseconds, that `PassiveReceiver::run` waits between connection attempts.
    pub reconnect_delay_ms: u64,

    /// Settings used to construct the `HeartbeatMonitor`.
    pub heartbeat_config: HeartbeatConfig,

    /// Optional TLS settings; when `None` the receiver connects in plaintext.
    pub tls_config: Option<ReplicationTlsConfig>,

    /// Server name used for the TLS handshake; `"primary.nklave.local"` is used when `None`.
    pub tls_server_name: Option<String>,
}
42
43impl Default for PassiveConfig {
44 fn default() -> Self {
45 Self {
46 primary_addr: "127.0.0.1:26660".to_string(),
47 node_id: "passive".to_string(),
48 reconnect_delay_ms: 5000,
49 heartbeat_config: HeartbeatConfig::default(),
50 tls_config: None,
51 tls_server_name: None,
52 }
53 }
54}
55
/// Client side of the replication link: connects to the primary, performs the
/// `Hello` handshake and applies the messages it receives.
pub struct PassiveReceiver {
    /// Configuration supplied at construction time.
    config: PassiveConfig,

    /// Integrity state that received decision records are applied to;
    /// shared with the `PassiveHandle`.
    integrity: Arc<RwLock<StateIntegrity>>,

    /// Validator map shared with the `PassiveHandle`. Stored but not read by
    /// any method visible in this file (hence the leading underscore).
    _validator_states: Arc<RwLock<HashMap<[u8; 48], ValidatorState>>>,

    /// Records the primary's `Hello` and `Heartbeat` messages; shared with the handle.
    heartbeat_monitor: Arc<HeartbeatMonitor>,

    /// Fencing token copied from the most recent heartbeat.
    fencing_token: AtomicU64,

    /// Difference between the primary's sequence and ours, computed on each heartbeat.
    replication_lag: AtomicU64,

    /// Set once a connection (and TLS handshake, if enabled) is established.
    connected: AtomicBool,

    /// Initialised to `false` and not otherwise used in the visible code.
    _shutdown: AtomicBool,

    /// TLS connector built from `config.tls_config`; `None` means plaintext.
    tls_connector: Option<TlsConnector>,

    /// Server name override for the TLS handshake.
    tls_server_name: Option<String>,
}
87
/// Handle returned alongside a `PassiveReceiver` that exposes the state it
/// shares with the receiver.
pub struct PassiveHandle {
    /// Sender half of the shutdown channel created in `PassiveReceiver::new`.
    shutdown_tx: watch::Sender<bool>,

    /// Integrity state shared with the receiver.
    integrity: Arc<RwLock<StateIntegrity>>,

    /// Validator map shared with the receiver.
    validator_states: Arc<RwLock<HashMap<[u8; 48], ValidatorState>>>,

    /// Heartbeat monitor shared with the receiver; used for failover queries.
    heartbeat_monitor: Arc<HeartbeatMonitor>,
}
102
103impl PassiveHandle {
104 pub fn integrity(&self) -> StateIntegrity {
106 self.integrity.read().map(|i| i.clone()).unwrap_or_default()
107 }
108
109 pub fn validator_states(&self) -> HashMap<[u8; 48], ValidatorState> {
111 self.validator_states.read().map(|v| v.clone()).unwrap_or_default()
112 }
113
114 pub fn subscribe_failover(&self) -> watch::Receiver<bool> {
116 self.heartbeat_monitor.subscribe_failover()
117 }
118
119 pub fn should_failover(&self) -> bool {
121 !self.heartbeat_monitor.is_primary_alive()
122 }
123
124 pub fn replication_lag(&self) -> u64 {
126 0
128 }
129
130 pub fn shutdown(self) {
132 let _ = self.shutdown_tx.send(true);
133 }
134}
135
/// Errors produced by the passive receiver.
#[derive(Debug, thiserror::Error)]
pub enum PassiveError {
    /// The connection could not be established (used for TLS handshake failures).
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),

    /// An underlying socket read/write/connect failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A message could not be encoded to, or decoded from, JSON.
    #[error("Serialization error: {0}")]
    Serialization(String),

    /// The peer violated the protocol (unexpected message, version mismatch,
    /// oversized frame, error reported by the primary) or a lock was poisoned.
    #[error("Protocol error: {0}")]
    Protocol(String),

    /// Applying a decision record to the integrity state failed.
    #[error("Hash chain verification failed: {0}")]
    HashChainError(#[from] IntegrityError),

    /// Both sides advertised a genesis validators root and the two differ.
    #[error("Genesis root mismatch")]
    GenesisRootMismatch,
}
157
158impl PassiveReceiver {
159 pub fn new(
161 config: PassiveConfig,
162 initial_integrity: StateIntegrity,
163 initial_validators: HashMap<[u8; 48], ValidatorState>,
164 ) -> (Self, PassiveHandle) {
165 let (shutdown_tx, _shutdown_rx) = watch::channel(false);
166
167 let integrity = Arc::new(RwLock::new(initial_integrity));
168 let validator_states = Arc::new(RwLock::new(initial_validators));
169 let heartbeat_monitor = Arc::new(HeartbeatMonitor::new(config.heartbeat_config.clone()));
170
171 let tls_connector = if let Some(ref tls_config) = config.tls_config {
173 match tls_config.build_connector() {
174 Ok(connector) => {
175 info!("TLS connector configured for passive receiver");
176 Some(connector)
177 }
178 Err(e) => {
179 error!(error = %e, "Failed to build TLS connector, falling back to plaintext");
180 None
181 }
182 }
183 } else {
184 None
185 };
186
187 let tls_server_name = config.tls_server_name.clone();
188
189 let receiver = Self {
190 config: config.clone(),
191 integrity: integrity.clone(),
192 _validator_states: validator_states.clone(),
193 heartbeat_monitor: heartbeat_monitor.clone(),
194 fencing_token: AtomicU64::new(0),
195 replication_lag: AtomicU64::new(0),
196 connected: AtomicBool::new(false),
197 _shutdown: AtomicBool::new(false),
198 tls_connector,
199 tls_server_name,
200 };
201
202 let handle = PassiveHandle {
203 shutdown_tx,
204 integrity,
205 validator_states,
206 heartbeat_monitor,
207 };
208
209 (receiver, handle)
210 }
211
212 pub async fn run(self, mut shutdown_rx: watch::Receiver<bool>) -> Result<(), PassiveError> {
214 loop {
215 if *shutdown_rx.borrow() {
217 info!("Passive receiver shutting down");
218 break;
219 }
220
221 match self.connect_and_receive(&mut shutdown_rx).await {
223 Ok(()) => {
224 info!("Connection to primary closed normally");
225 }
226 Err(e) => {
227 warn!(error = %e, "Connection to primary failed");
228 self.connected.store(false, Ordering::Release);
229 }
230 }
231
232 tokio::select! {
234 _ = tokio::time::sleep(Duration::from_millis(self.config.reconnect_delay_ms)) => {}
235 _ = shutdown_rx.changed() => {
236 if *shutdown_rx.borrow() {
237 break;
238 }
239 }
240 }
241 }
242
243 Ok(())
244 }
245
246 async fn connect_and_receive(
248 &self,
249 shutdown_rx: &mut watch::Receiver<bool>,
250 ) -> Result<(), PassiveError> {
251 let tcp_stream = TcpStream::connect(&self.config.primary_addr).await?;
252 info!(addr = %self.config.primary_addr, "TCP connected to primary");
253
254 if let Some(ref connector) = self.tls_connector {
256 let server_name = self
257 .tls_server_name
258 .as_deref()
259 .unwrap_or("primary.nklave.local");
260
261 match ReplicationTlsStream::connect(connector, server_name, tcp_stream).await {
262 Ok(tls_stream) => {
263 info!("TLS handshake successful with primary");
264 self.connected.store(true, Ordering::Release);
265 self.receive_loop_generic(tls_stream, shutdown_rx).await
266 }
267 Err(e) => {
268 warn!(error = %e, "TLS handshake failed with primary");
269 Err(PassiveError::ConnectionFailed(format!(
270 "TLS handshake failed: {}",
271 e
272 )))
273 }
274 }
275 } else {
276 info!(addr = %self.config.primary_addr, "Connected to primary (plaintext - NOT RECOMMENDED FOR PRODUCTION)");
277 self.connected.store(true, Ordering::Release);
278 self.receive_loop_generic(ReplicationTlsStream::plain(tcp_stream), shutdown_rx)
279 .await
280 }
281 }
282
283 async fn receive_loop_generic<S>(
285 &self,
286 mut stream: S,
287 shutdown_rx: &mut watch::Receiver<bool>,
288 ) -> Result<(), PassiveError>
289 where
290 S: AsyncRead + AsyncWrite + Unpin,
291 {
292 let primary_hello = match read_message_generic(&mut stream).await? {
294 Some(ReplicationMessage::Hello(h)) => h,
295 Some(other) => {
296 return Err(PassiveError::Protocol(format!(
297 "Expected Hello, got {:?}",
298 std::mem::discriminant(&other)
299 )));
300 }
301 None => {
302 return Err(PassiveError::Protocol(
303 "Connection closed during handshake".to_string(),
304 ));
305 }
306 };
307
308 if primary_hello.version != PROTOCOL_VERSION {
310 return Err(PassiveError::Protocol(format!(
311 "Version mismatch: expected {}, got {}",
312 PROTOCOL_VERSION, primary_hello.version
313 )));
314 }
315
316 let (our_sequence, our_hash, our_genesis) = {
318 let guard = self
319 .integrity
320 .read()
321 .map_err(|_| PassiveError::Protocol("Lock poisoned".to_string()))?;
322 (
323 guard.sequence_number,
324 guard.current_hash,
325 guard.genesis_validators_root,
326 )
327 };
328
329 let our_hello = HelloMessage {
330 version: PROTOCOL_VERSION,
331 node_id: self.config.node_id.clone(),
332 role: "Passive".to_string(),
333 sequence: our_sequence,
334 state_hash: our_hash,
335 genesis_root: our_genesis,
336 };
337
338 send_message_generic(&mut stream, &ReplicationMessage::Hello(our_hello)).await?;
339
340 info!(
341 primary_id = %primary_hello.node_id,
342 primary_sequence = primary_hello.sequence,
343 our_sequence = our_sequence,
344 "Handshake complete with primary"
345 );
346
347 self.heartbeat_monitor
349 .record_heartbeat(primary_hello.sequence, primary_hello.state_hash);
350
351 if let (Some(primary_genesis), Some(our_genesis)) =
353 (primary_hello.genesis_root, our_genesis)
354 {
355 if primary_genesis != our_genesis {
356 return Err(PassiveError::GenesisRootMismatch);
357 }
358 }
359
360 loop {
362 tokio::select! {
363 msg_result = read_message_generic(&mut stream) => {
364 match msg_result {
365 Ok(Some(msg)) => {
366 self.handle_message_generic(msg, &mut stream).await?;
367 }
368 Ok(None) => {
369 info!("Primary disconnected");
370 break;
371 }
372 Err(e) => {
373 return Err(e);
374 }
375 }
376 }
377
378 _ = shutdown_rx.changed() => {
379 if *shutdown_rx.borrow() {
380 break;
381 }
382 }
383 }
384 }
385
386 Ok(())
387 }
388
/// `TcpStream`-typed convenience wrapper that forwards to
/// `handle_message_generic`.
///
/// Not called anywhere in the visible code, hence the `dead_code` allowance.
#[allow(dead_code)]
async fn handle_message(
    &self,
    msg: ReplicationMessage,
    stream: &mut TcpStream,
) -> Result<(), PassiveError> {
    self.handle_message_generic(msg, stream).await
}
398
399 async fn handle_message_generic<S>(
401 &self,
402 msg: ReplicationMessage,
403 stream: &mut S,
404 ) -> Result<(), PassiveError>
405 where
406 S: AsyncRead + AsyncWrite + Unpin,
407 {
408 match msg {
409 ReplicationMessage::Heartbeat(hb) => {
410 self.heartbeat_monitor
411 .record_heartbeat(hb.sequence, hb.state_hash);
412 self.fencing_token.store(hb.fencing_token, Ordering::Release);
413
414 let our_seq = self
416 .integrity
417 .read()
418 .map(|i| i.sequence_number)
419 .unwrap_or(0);
420 let lag = hb.sequence.saturating_sub(our_seq);
421 self.replication_lag.store(lag, Ordering::Release);
422
423 debug!(
424 primary_sequence = hb.sequence,
425 our_sequence = our_seq,
426 lag = lag,
427 "Heartbeat received"
428 );
429 }
430
431 ReplicationMessage::Decision(record) => {
432 self.apply_decision_record_generic(record, stream).await?;
433 }
434
435 ReplicationMessage::SyncResponse(response) => {
436 info!(
437 record_count = response.records.len(),
438 has_more = response.has_more,
439 "Received sync response"
440 );
441
442 for record in response.records {
443 self.apply_decision_record_generic(record, stream).await?;
444 }
445 }
446
447 ReplicationMessage::Error(err) => {
448 error!(
449 code = ?err.code,
450 description = %err.description,
451 "Received error from primary"
452 );
453 return Err(PassiveError::Protocol(err.description));
454 }
455
456 _ => {
457 debug!("Ignoring unexpected message type");
458 }
459 }
460
461 Ok(())
462 }
463
/// `TcpStream`-typed convenience wrapper that forwards to
/// `apply_decision_record_generic`.
///
/// Not called anywhere in the visible code, hence the `dead_code` allowance.
#[allow(dead_code)]
async fn apply_decision_record(
    &self,
    record: DecisionRecord,
    stream: &mut TcpStream,
) -> Result<(), PassiveError> {
    self.apply_decision_record_generic(record, stream).await
}
473
474 async fn apply_decision_record_generic<S>(
476 &self,
477 record: DecisionRecord,
478 stream: &mut S,
479 ) -> Result<(), PassiveError>
480 where
481 S: AsyncWrite + Unpin,
482 {
483 let new_hash = {
484 let mut integrity = self
485 .integrity
486 .write()
487 .map_err(|_| PassiveError::Protocol("Lock poisoned".to_string()))?;
488
489 let new_hash = integrity.record_decision(&record)?;
491
492 debug!(
493 sequence = record.sequence,
494 new_hash = hex::encode(new_hash),
495 "Applied decision record"
496 );
497
498 new_hash
499 };
500
501 let ack = AckMessage {
503 sequence: record.sequence,
504 state_hash: new_hash,
505 };
506 send_message_generic(stream, &ReplicationMessage::Ack(ack)).await?;
507
508 crate::metrics::set_state_sequence(record.sequence);
510
511 Ok(())
512 }
513
/// Returns the lag stored by the most recent heartbeat: the primary's
/// sequence minus ours, saturating at zero. It is `0` until a heartbeat
/// has been handled.
pub fn replication_lag(&self) -> u64 {
    self.replication_lag.load(Ordering::Acquire)
}
518
/// Returns the current value of the `connected` flag, which is set once a
/// connection to the primary has been established.
pub fn is_connected(&self) -> bool {
    self.connected.load(Ordering::Acquire)
}
523}
524
/// `TcpStream`-typed convenience wrapper that forwards to `send_message_generic`.
///
/// Not called anywhere in the visible code, hence the `dead_code` allowance.
#[allow(dead_code)]
async fn send_message(
    stream: &mut TcpStream,
    msg: &ReplicationMessage,
) -> Result<(), PassiveError> {
    send_message_generic(stream, msg).await
}
533
534async fn send_message_generic<S>(stream: &mut S, msg: &ReplicationMessage) -> Result<(), PassiveError>
536where
537 S: AsyncWrite + Unpin,
538{
539 let bytes =
540 serde_json::to_vec(msg).map_err(|e| PassiveError::Serialization(e.to_string()))?;
541
542 let len = bytes.len() as u32;
543 stream.write_all(&len.to_be_bytes()).await?;
544 stream.write_all(&bytes).await?;
545 stream.flush().await?;
546
547 Ok(())
548}
549
/// `TcpStream`-typed convenience wrapper that forwards to `read_message_generic`.
///
/// Not called anywhere in the visible code, hence the `dead_code` allowance.
#[allow(dead_code)]
async fn read_message(
    stream: &mut TcpStream,
) -> Result<Option<ReplicationMessage>, PassiveError> {
    read_message_generic(stream).await
}
557
558async fn read_message_generic<S>(stream: &mut S) -> Result<Option<ReplicationMessage>, PassiveError>
560where
561 S: AsyncRead + Unpin,
562{
563 let mut len_buf = [0u8; 4];
564 match stream.read_exact(&mut len_buf).await {
565 Ok(_) => {}
566 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
567 return Ok(None);
568 }
569 Err(e) => return Err(e.into()),
570 }
571
572 let len = u32::from_be_bytes(len_buf) as usize;
573 if len > super::protocol::MAX_MESSAGE_SIZE {
574 return Err(PassiveError::Protocol(format!(
575 "Message too large: {} bytes",
576 len
577 )));
578 }
579
580 let mut buf = vec![0u8; len];
581 stream.read_exact(&mut buf).await?;
582
583 let msg = serde_json::from_slice(&buf)
584 .map_err(|e| PassiveError::Serialization(e.to_string()))?;
585
586 Ok(Some(msg))
587}
588
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed receiver reports no connection and zero lag.
    #[test]
    fn test_passive_receiver_creation() {
        let (receiver, _handle) = PassiveReceiver::new(
            PassiveConfig::default(),
            StateIntegrity::new(),
            HashMap::new(),
        );

        let connected = receiver.is_connected();
        assert!(!connected);

        let lag = receiver.replication_lag();
        assert_eq!(lag, 0);
    }
}