use crate::config::ReplicationConfig;
use crate::error::{ClusterError, Result};
use crate::node::NodeId;
use crate::partition::PartitionId;
use crate::protocol::Acks;
use dashmap::DashMap;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{oneshot, Mutex, RwLock};
use tracing::{debug, error, info, warn};
/// Replication state for a single partition: leader/follower role, log
/// offsets, the in-sync replica (ISR) set, and producers parked waiting
/// for replication acks.
#[derive(Debug)]
pub struct PartitionReplication {
    /// Identity of the partition this state belongs to.
    pub partition_id: PartitionId,
    // The node this replication state lives on.
    local_node: NodeId,
    // True when this node is the partition leader
    // (toggled by become_leader / become_follower).
    is_leader: bool,
    // Leader epoch, updated on every leadership change.
    leader_epoch: AtomicU64,
    // Local log end offset: last appended offset + 1 (see record_appended).
    log_end_offset: AtomicU64,
    // High watermark, advanced to the minimum log end offset across the ISR
    // (see maybe_advance_hwm).
    high_watermark: AtomicU64,
    // Leader-side progress of each follower replica, keyed by node id.
    replicas: RwLock<HashMap<NodeId, ReplicaProgress>>,
    // Current in-sync replica set; includes the local node when leader.
    isr: RwLock<HashSet<NodeId>>,
    // Producers blocked in wait_for_acks, keyed by the offset they wait on.
    pending_acks: DashMap<u64, PendingAck>,
    // Lag thresholds, min ISR size, fetch tuning.
    config: ReplicationConfig,
}
/// Leader-side view of a single follower's replication progress.
#[derive(Debug, Clone)]
pub struct ReplicaProgress {
    /// The follower node being tracked.
    pub node_id: NodeId,
    /// Follower's log end offset, as reported by its most recent fetch.
    pub log_end_offset: u64,
    /// When the follower last fetched; drives time-based ISR eviction.
    pub last_fetch: Instant,
    /// Whether the follower is currently a member of the ISR.
    pub in_sync: bool,
    /// Messages behind the leader (leader LEO minus follower LEO, saturating).
    pub lag: u64,
}
/// A producer request parked in `wait_for_acks` until the high watermark
/// reaches its offset, or it is timed out / cleaned up.
#[derive(Debug)]
#[allow(dead_code)]
struct PendingAck {
    // Offset the waiter needs replicated.
    offset: u64,
    // Nodes known to have acknowledged this offset (seeded with the leader).
    acked_nodes: HashSet<NodeId>,
    // ISR size captured at registration time.
    required_acks: usize,
    // Completed with Ok(()) once the high watermark covers the offset.
    completion: oneshot::Sender<Result<()>>,
    // Registration time, consulted by cleanup_stale_pending_acks.
    created: Instant,
}
impl PartitionReplication {
/// Creates replication state for `partition_id` starting from offset zero,
/// with an empty replica map and an empty ISR.
pub fn new(
    partition_id: PartitionId,
    local_node: NodeId,
    is_leader: bool,
    config: ReplicationConfig,
) -> Self {
    Self {
        partition_id,
        local_node,
        is_leader,
        config,
        leader_epoch: AtomicU64::new(0),
        log_end_offset: AtomicU64::new(0),
        high_watermark: AtomicU64::new(0),
        replicas: RwLock::new(HashMap::default()),
        isr: RwLock::new(HashSet::default()),
        pending_acks: DashMap::default(),
    }
}
/// Transitions this node to partition leader for `epoch`.
///
/// Rebuilds progress tracking for every replica other than the local node
/// (all start out-of-sync with unknown lag) and resets the ISR to contain
/// only the leader itself; followers re-join as they catch up.
pub async fn become_leader(&mut self, epoch: u64, replicas: Vec<NodeId>) {
    self.is_leader = true;
    self.leader_epoch.store(epoch, Ordering::SeqCst);

    {
        let mut tracked = self.replicas.write().await;
        tracked.clear();
        tracked.extend(
            replicas
                .iter()
                .filter(|id| **id != self.local_node)
                .map(|id| {
                    let progress = ReplicaProgress {
                        node_id: id.clone(),
                        log_end_offset: 0,
                        last_fetch: Instant::now(),
                        in_sync: false,
                        lag: u64::MAX,
                    };
                    (id.clone(), progress)
                }),
        );
    }

    let mut isr = self.isr.write().await;
    isr.clear();
    isr.insert(self.local_node.clone());

    info!(
        partition = %self.partition_id,
        epoch = epoch,
        replicas = replicas.len(),
        "Became partition leader"
    );
}
/// Demotes this node to follower for the given `epoch`.
pub fn become_follower(&mut self, epoch: u64) {
    self.leader_epoch.store(epoch, Ordering::SeqCst);
    self.is_leader = false;
    info!(
        partition = %self.partition_id,
        epoch = epoch,
        "Became partition follower"
    );
}
/// Records that a message was appended locally at `offset`, advancing the
/// log end offset to `offset + 1`.
///
/// Fix over the original: `fetch_max` instead of a plain `store` keeps the
/// LEO monotonic — with `store`, a racing or out-of-order call with a lower
/// offset would move the log end offset backwards.
///
/// On the leader this re-evaluates the high watermark, which can advance
/// immediately when the ISR contains only the leader.
pub async fn record_appended(&self, offset: u64) -> Result<()> {
    self.log_end_offset.fetch_max(offset + 1, Ordering::SeqCst);
    if self.is_leader {
        self.maybe_advance_hwm().await;
    }
    Ok(())
}
/// Leader-side handler for a follower fetch request.
///
/// Updates the follower's progress (`last_fetch`, log end offset, lag) and
/// moves it into or out of the ISR based on `replica_lag_max_messages`.
/// Finally re-evaluates the high watermark, since the follower's new offset
/// may allow it to advance.
///
/// Returns `Ok(true)` when ISR membership changed (caller should propagate
/// the change), `Ok(false)` otherwise, and `NotLeader` on a follower.
pub async fn handle_replica_fetch(
    &self,
    replica_id: &NodeId,
    fetch_offset: u64,
) -> Result<bool> {
    if !self.is_leader {
        return Err(ClusterError::NotLeader { leader: None });
    }
    let mut isr_changed = false;
    // Lock order: replicas before isr — same order as check_replica_health
    // and maybe_advance_hwm, so the nested acquisition below cannot deadlock.
    let mut replicas = self.replicas.write().await;
    if let Some(progress) = replicas.get_mut(replica_id) {
        progress.last_fetch = Instant::now();
        // A fetch at `fetch_offset` means the follower has everything below it.
        progress.log_end_offset = fetch_offset;
        let leader_leo = self.log_end_offset.load(Ordering::SeqCst);
        progress.lag = leader_leo.saturating_sub(fetch_offset);
        // Membership only flips when the replica crosses the lag threshold,
        // avoiding ISR churn while it stays on one side of it.
        let should_be_in_sync = progress.lag <= self.config.replica_lag_max_messages;
        if should_be_in_sync != progress.in_sync {
            progress.in_sync = should_be_in_sync;
            isr_changed = true;
            let mut isr = self.isr.write().await;
            if should_be_in_sync {
                isr.insert(replica_id.clone());
                info!(
                    partition = %self.partition_id,
                    replica = %replica_id,
                    "Replica joined ISR"
                );
            } else {
                isr.remove(replica_id);
                warn!(
                    partition = %self.partition_id,
                    replica = %replica_id,
                    lag = progress.lag,
                    "Replica removed from ISR due to lag"
                );
            }
        }
    }
    // Release the write lock before recomputing the HWM (which takes read locks).
    drop(replicas);
    self.maybe_advance_hwm().await;
    Ok(isr_changed)
}
/// Evicts in-sync replicas that have not fetched within
/// `replica_lag_max_time` from the ISR.
///
/// Leader-only: followers return an empty list. Returns the evicted node
/// ids so the caller can propagate the ISR change.
pub async fn check_replica_health(&self) -> Vec<NodeId> {
    if !self.is_leader {
        return Vec::new();
    }
    let now = Instant::now();
    // Lock order (replicas, then isr) matches the other methods.
    let mut replicas = self.replicas.write().await;
    let mut isr = self.isr.write().await;
    let mut evicted = Vec::new();
    for (node_id, progress) in replicas.iter_mut() {
        if !progress.in_sync {
            continue;
        }
        let since_fetch = now.duration_since(progress.last_fetch);
        if since_fetch <= self.config.replica_lag_max_time {
            continue;
        }
        progress.in_sync = false;
        isr.remove(node_id);
        evicted.push(node_id.clone());
        warn!(
            partition = %self.partition_id,
            replica = %node_id,
            lag_time = ?since_fetch,
            "Replica removed from ISR due to time lag"
        );
    }
    evicted
}
/// Recomputes the high watermark as the minimum log end offset across the
/// ISR (the leader's own LEO included) and completes any producer acks the
/// new watermark covers.
///
/// Fix over the original: `fetch_max` replaces the separate load / compare /
/// store. With the non-atomic sequence, two concurrent callers could both
/// pass the `min_leo > current_hwm` check and the slower one — holding a
/// smaller minimum — could overwrite a larger watermark just published by
/// the other, regressing the HWM.
async fn maybe_advance_hwm(&self) {
    let replicas = self.replicas.read().await;
    let isr = self.isr.read().await;
    let mut min_leo = self.log_end_offset.load(Ordering::SeqCst);
    for node_id in isr.iter() {
        if node_id == &self.local_node {
            continue;
        }
        // ISR members without a progress entry are skipped rather than
        // treated as offset 0 (keeps a transient inconsistency from
        // pinning the HWM at zero).
        if let Some(progress) = replicas.get(node_id) {
            min_leo = min_leo.min(progress.log_end_offset);
        }
    }
    drop(isr);
    drop(replicas);
    // Atomically advance; fetch_max returns the previous value, so acks are
    // only completed when this call actually moved the watermark forward.
    let prev_hwm = self.high_watermark.fetch_max(min_leo, Ordering::SeqCst);
    if min_leo > prev_hwm {
        self.complete_pending_acks(min_leo).await;
    }
}
pub async fn wait_for_acks(&self, offset: u64, acks: Acks) -> Result<()> {
match acks {
Acks::None => Ok(()),
Acks::Leader => {
Ok(())
}
Acks::All => {
let isr = self.isr.read().await;
let required = isr.len();
if required <= 1 {
return Ok(());
}
let (tx, rx) = oneshot::channel();
let mut acked = HashSet::new();
acked.insert(self.local_node.clone());
self.pending_acks.insert(
offset,
PendingAck {
offset,
acked_nodes: acked,
required_acks: required,
completion: tx,
created: Instant::now(),
},
);
drop(isr);
match tokio::time::timeout(Duration::from_secs(30), rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => Err(ClusterError::ChannelClosed),
Err(_) => {
self.pending_acks.remove(&offset);
Err(ClusterError::Timeout)
}
}
}
}
}
/// Completes every pending ack whose offset is covered by `up_to_offset`
/// (the new high watermark), sending `Ok(())` to the waiting producer.
async fn complete_pending_acks(&self, up_to_offset: u64) {
    // Snapshot the ready keys first: entries cannot be removed while
    // iterating the DashMap.
    let ready: Vec<u64> = self
        .pending_acks
        .iter()
        .map(|entry| *entry.key())
        .filter(|offset| *offset <= up_to_offset)
        .collect();
    for offset in ready {
        // A racing caller may have removed the entry already; the send
        // result is ignored because the waiter may have timed out.
        if let Some((_, pending)) = self.pending_acks.remove(&offset) {
            let _ = pending.completion.send(Ok(()));
        }
    }
}
/// Current high watermark (minimum log end offset across the ISR).
pub fn high_watermark(&self) -> u64 {
    self.high_watermark.load(Ordering::SeqCst)
}
/// Local log end offset: last appended offset + 1.
pub fn log_end_offset(&self) -> u64 {
    self.log_end_offset.load(Ordering::SeqCst)
}
/// Current leader epoch.
pub fn leader_epoch(&self) -> u64 {
    self.leader_epoch.load(Ordering::SeqCst)
}
/// Snapshot of the current in-sync replica set.
pub async fn get_isr(&self) -> HashSet<NodeId> {
    self.isr.read().await.clone()
}
/// True when the ISR size meets the configured `min_isr`.
pub async fn has_min_isr(&self) -> bool {
    let isr = self.isr.read().await;
    isr.len() >= self.config.min_isr as usize
}
pub fn cleanup_stale_pending_acks(&self, timeout: Duration) -> usize {
let now = Instant::now();
let mut cleaned = 0;
self.pending_acks.retain(|_, pending| {
let is_stale = now.duration_since(pending.created) >= timeout;
if is_stale {
cleaned += 1;
}
!is_stale
});
if cleaned > 0 {
debug!(
partition = %self.partition_id,
cleaned = cleaned,
"Cleaned up stale pending acks"
);
}
cleaned
}
}
/// Tracks replication state for every partition hosted on this node and
/// propagates ISR changes cluster-wide via Raft when a Raft node is attached.
pub struct ReplicationManager {
    // This node's identity, passed into each PartitionReplication.
    local_node: NodeId,
    // Per-partition replication state, shared with callers via Arc.
    partitions: DashMap<PartitionId, Arc<PartitionReplication>>,
    // Template config cloned into each newly created partition.
    config: ReplicationConfig,
    // Optional Raft handle for persisting ISR changes; None means ISR
    // changes stay local-only (see propagate_isr_change).
    raft_node: Option<Arc<RwLock<crate::raft::RaftNode>>>,
}
impl ReplicationManager {
/// Creates a manager with no tracked partitions and no Raft handle attached.
pub fn new(local_node: NodeId, config: ReplicationConfig) -> Self {
    Self {
        raft_node: None,
        partitions: DashMap::default(),
        config,
        local_node,
    }
}
/// Attaches the Raft node used to propagate ISR changes cluster-wide.
pub fn set_raft_node(&mut self, raft_node: Arc<RwLock<crate::raft::RaftNode>>) {
    self.raft_node = Some(raft_node);
}
/// Returns the replication state for `partition_id`, creating it on first
/// access. Note that `is_leader` only takes effect when the entry is
/// created; an existing entry keeps its current role.
pub fn get_or_create(
    &self,
    partition_id: PartitionId,
    is_leader: bool,
) -> Arc<PartitionReplication> {
    let entry = self
        .partitions
        .entry(partition_id.clone())
        .or_insert_with(|| {
            let state = PartitionReplication::new(
                partition_id,
                self.local_node.clone(),
                is_leader,
                self.config.clone(),
            );
            Arc::new(state)
        });
    entry.value().clone()
}
/// Looks up the replication state for a partition, if tracked.
pub fn get(&self, partition_id: &PartitionId) -> Option<Arc<PartitionReplication>> {
    self.partitions.get(partition_id).map(|e| e.value().clone())
}
/// Stops tracking a partition, returning its state if it was present.
pub fn remove(&self, partition_id: &PartitionId) -> Option<Arc<PartitionReplication>> {
    self.partitions.remove(partition_id).map(|(_, v)| v)
}
/// Ids of all partitions this node currently leads.
pub fn leading_partitions(&self) -> Vec<PartitionId> {
    self.partitions
        .iter()
        .filter(|e| e.value().is_leader)
        .map(|e| e.key().clone())
        .collect()
}
/// Routes a follower fetch to the partition's replication state and, if the
/// fetch changed ISR membership, attempts to propagate the new ISR via
/// Raft. Propagation failures are logged and swallowed — the periodic
/// health check retries them.
pub async fn handle_replica_fetch(
    &self,
    partition_id: &PartitionId,
    replica_id: &NodeId,
    fetch_offset: u64,
) -> Result<()> {
    let partition = match self.get(partition_id) {
        Some(p) => p,
        None => {
            return Err(ClusterError::PartitionNotFound {
                topic: partition_id.topic.clone(),
                partition: partition_id.partition,
            })
        }
    };
    let isr_changed = partition
        .handle_replica_fetch(replica_id, fetch_offset)
        .await?;
    if !isr_changed {
        return Ok(());
    }
    if let Err(e) = self.propagate_isr_change(partition_id).await {
        warn!(
            partition = %partition_id,
            error = %e,
            "Failed to propagate ISR change (will retry on next health check)"
        );
    }
    Ok(())
}
pub async fn run_health_checks(&self) {
for entry in self.partitions.iter() {
let partition = entry.value();
if partition.is_leader {
let removed = partition.check_replica_health().await;
if !removed.is_empty() {
warn!(
partition = %partition.partition_id,
removed = ?removed,
"Removed replicas from ISR - propagating via Raft"
);
if let Err(e) = self.propagate_isr_change(&partition.partition_id).await {
error!(
partition = %partition.partition_id,
error = %e,
"Failed to propagate ISR change via Raft"
);
}
}
}
}
}
/// Proposes the partition's current ISR to the Raft metadata log so every
/// node converges on the same ISR view.
///
/// A partition that is no longer tracked is treated as success (nothing to
/// propagate). Without a configured Raft node the change stays local-only,
/// logged at debug level.
async fn propagate_isr_change(&self, partition_id: &PartitionId) -> Result<()> {
    let partition = match self.get(partition_id) {
        Some(p) => p,
        // Partition was removed concurrently; nothing to do.
        None => return Ok(()), };
    // Snapshot the ISR and release the lock before awaiting on Raft.
    let isr = partition.isr.read().await;
    let isr_vec: Vec<NodeId> = isr.iter().cloned().collect();
    drop(isr);
    if let Some(raft_node) = &self.raft_node {
        let node = raft_node.read().await;
        let cmd = crate::metadata::MetadataCommand::UpdatePartitionIsr {
            partition: partition_id.clone(),
            isr: isr_vec.clone(),
        };
        match node.propose(cmd).await {
            Ok(_response) => {
                info!(
                    partition = %partition_id,
                    isr = ?isr_vec,
                    "ISR change propagated via Raft"
                );
                Ok(())
            }
            Err(e) => {
                // Propagation failure is surfaced to the caller, which
                // decides whether to retry (health check) or just log.
                error!(
                    partition = %partition_id,
                    error = %e,
                    "Failed to propose ISR change to Raft"
                );
                Err(e)
            }
        }
    } else {
        debug!(
            partition = %partition_id,
            "No Raft node configured - ISR change local only"
        );
        Ok(())
    }
}
}
/// Background task state for a follower replica: periodically fetches
/// records from the partition leader, tracks its own progress, and reports
/// it back to the leader.
pub struct FollowerFetcher {
    // This follower's node id, sent in request headers.
    local_node: NodeId,
    // Partition being replicated.
    partition_id: PartitionId,
    // Node currently believed to be the partition leader.
    leader_id: NodeId,
    // Next offset to request from the leader.
    fetch_offset: u64,
    // Leader's high watermark as of the last successful fetch.
    high_watermark: u64,
    // Shared transport used to talk to the leader.
    transport: Arc<Mutex<crate::Transport>>,
    // Fetch interval and size limits.
    config: ReplicationConfig,
    // Signals the fetch loop in run() to exit.
    shutdown: tokio::sync::broadcast::Receiver<()>,
    // Time of the last successful fetch (see fetch_age).
    last_fetch: std::time::Instant,
}
impl FollowerFetcher {
/// Builds a fetcher that will start replicating from `start_offset`; the
/// high watermark is unknown (zero) until the first fetch succeeds.
pub fn new(
    local_node: NodeId,
    partition_id: PartitionId,
    leader_id: NodeId,
    start_offset: u64,
    transport: Arc<Mutex<crate::Transport>>,
    config: ReplicationConfig,
    shutdown: tokio::sync::broadcast::Receiver<()>,
) -> Self {
    let created_at = std::time::Instant::now();
    Self {
        local_node,
        partition_id,
        leader_id,
        transport,
        config,
        shutdown,
        fetch_offset: start_offset,
        high_watermark: 0,
        last_fetch: created_at,
    }
}
/// Drives the fetch loop until shutdown is signalled.
///
/// Each `config.fetch_interval` tick issues one fetch to the leader.
/// Fetch errors are logged and the loop continues: transient leader
/// unavailability should not kill the fetcher.
pub async fn run(mut self) -> Result<()> {
    let mut interval = tokio::time::interval(self.config.fetch_interval);
    loop {
        tokio::select! {
            _ = interval.tick() => {
                if let Err(e) = self.fetch_from_leader().await {
                    error!(
                        partition = %self.partition_id,
                        error = %e,
                        "Fetch from leader failed"
                    );
                }
            }
            _ = self.shutdown.recv() => {
                info!(partition = %self.partition_id, "Follower fetcher shutting down");
                break;
            }
        }
    }
    Ok(())
}
/// Issues one fetch request to the leader and processes the response:
/// applies any returned records, adopts the leader's high watermark, and
/// reports this replica's updated state back to the leader.
async fn fetch_from_leader(&mut self) -> Result<()> {
    use crate::protocol::{ClusterRequest, ClusterResponse, RequestHeader};
    // Random correlation id for the request header.
    let header = RequestHeader::new(
        rand::random(), self.local_node.clone(),
    );
    let request = ClusterRequest::Fetch {
        header,
        partition: self.partition_id.clone(),
        offset: self.fetch_offset,
        max_bytes: self.config.fetch_max_bytes,
    };
    // Hold the transport lock only for the duration of the round-trip.
    let response = {
        let transport = self.transport.lock().await;
        transport.send(&self.leader_id, request).await?
    };
    match response {
        ClusterResponse::Fetch {
            header,
            partition,
            high_watermark,
            log_start_offset: _,
            records,
        } => {
            if !header.is_success() {
                return Err(ClusterError::Network(format!(
                    "Fetch failed: {}",
                    header.error_message.unwrap_or_default()
                )));
            }
            // Guard against a mis-routed response.
            if partition != self.partition_id {
                return Err(ClusterError::Network(format!(
                    "Partition mismatch: expected {}, got {}",
                    self.partition_id, partition
                )));
            }
            if !records.is_empty() {
                // apply_records advances self.fetch_offset as a side effect.
                self.apply_records(&records).await?;
                debug!(
                    partition = %self.partition_id,
                    records_bytes = records.len(),
                    new_offset = self.fetch_offset,
                    hwm = high_watermark,
                    "Applied records from leader"
                );
            }
            self.high_watermark = high_watermark;
            self.last_fetch = std::time::Instant::now();
            // Best-effort progress report; transport errors are ignored inside.
            self.report_replica_state().await?;
            Ok(())
        }
        _ => Err(ClusterError::Network(format!(
            "Unexpected response type: {:?}",
            response
        ))),
    }
}
/// Scans a batch of records fetched from the leader and advances
/// `fetch_offset` past the last record that deserializes cleanly.
///
/// Wire format: repeated `[u32 big-endian length][length bytes of Message]`.
/// A truncated frame or an undecodable message stops the scan early with a
/// warning, so the next fetch retries from the last good offset.
async fn apply_records(&mut self, records: &[u8]) -> Result<()> {
    let mut cursor = 0usize;
    let mut last_offset: Option<u64> = None;
    while cursor + 4 <= records.len() {
        let mut len_buf = [0u8; 4];
        len_buf.copy_from_slice(&records[cursor..cursor + 4]);
        let len = u32::from_be_bytes(len_buf) as usize;
        cursor += 4;
        let remaining = records.len() - cursor;
        if len > remaining {
            tracing::warn!(
                "Truncated record at byte {} in apply_records (expected {} bytes, have {})",
                cursor,
                len,
                remaining
            );
            break;
        }
        let frame = &records[cursor..cursor + len];
        match rivven_core::Message::from_bytes(frame) {
            Ok(msg) => last_offset = Some(msg.offset),
            Err(e) => {
                tracing::warn!("Failed to deserialize record at byte {}: {}", cursor, e);
                break;
            }
        }
        cursor += len;
    }
    // Resume the next fetch just past the last successfully parsed record.
    if let Some(offset) = last_offset {
        self.fetch_offset = offset + 1;
    }
    Ok(())
}
/// Best-effort report of this follower's progress to the leader: the log
/// end offset (next fetch offset) and the last observed high watermark.
/// Transport errors are deliberately ignored — the next report supersedes
/// this one.
async fn report_replica_state(&mut self) -> Result<()> {
    use crate::protocol::{ClusterRequest, RequestHeader};
    // Random correlation id for the request header.
    let header = RequestHeader::new(rand::random(), self.local_node.clone());
    let request = ClusterRequest::ReplicaState {
        header,
        partition: self.partition_id.clone(),
        log_end_offset: self.fetch_offset,
        high_watermark: self.high_watermark,
    };
    let transport = self.transport.lock().await;
    // Fire-and-forget: a failed report is not fatal to the fetch loop.
    let _ = transport.send(&self.leader_id, request).await;
    Ok(())
}
/// How many messages this follower is behind the leader's last reported
/// high watermark (saturating at zero).
pub fn lag(&self) -> u64 {
    self.high_watermark.saturating_sub(self.fetch_offset)
}
/// Time elapsed since the last successful fetch from the leader.
pub fn fetch_age(&self) -> std::time::Duration {
    self.last_fetch.elapsed()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Promoting a follower via become_leader flips the role and adopts
    // the new epoch.
    #[tokio::test]
    async fn test_partition_replication_leader() {
        let config = ReplicationConfig::default();
        let partition_id = PartitionId::new("test", 0);
        let mut replication =
            PartitionReplication::new(partition_id.clone(), "node-1".to_string(), false, config);
        replication
            .become_leader(1, vec!["node-1".to_string(), "node-2".to_string()])
            .await;
        assert!(replication.is_leader);
        assert_eq!(replication.leader_epoch(), 1);
    }

    // After become_leader the ISR contains only the leader, so the HWM
    // tracks the leader's own LEO (3 after three appends). A follower
    // fetch within the lag threshold then joins the ISR.
    #[tokio::test]
    async fn test_hwm_advancement() {
        let config = ReplicationConfig::default();
        let partition_id = PartitionId::new("test", 0);
        let mut replication = PartitionReplication::new(
            partition_id.clone(),
            "node-1".to_string(),
            false,
            config.clone(),
        );
        replication
            .become_leader(1, vec!["node-1".to_string(), "node-2".to_string()])
            .await;
        replication.record_appended(0).await.unwrap();
        replication.record_appended(1).await.unwrap();
        replication.record_appended(2).await.unwrap();
        assert_eq!(replication.high_watermark(), 3);
        replication
            .handle_replica_fetch(&"node-2".to_string(), 2)
            .await
            .unwrap();
        let isr = replication.get_isr().await;
        assert!(isr.contains("node-2"));
    }

    // Acks::None never blocks, regardless of replication state.
    #[tokio::test]
    async fn test_acks_none() {
        let config = ReplicationConfig::default();
        let partition_id = PartitionId::new("test", 0);
        let replication =
            PartitionReplication::new(partition_id, "node-1".to_string(), true, config);
        let result = replication.wait_for_acks(100, Acks::None).await;
        assert!(result.is_ok());
    }

    // A freshly built fetcher starts at the requested offset with an
    // unknown (zero) high watermark, so its lag is zero.
    #[tokio::test]
    async fn test_follower_fetcher_lag_tracking() {
        use crate::Transport;
        let config = ReplicationConfig::default();
        let partition_id = PartitionId::new("test", 0);
        let (shutdown_tx, shutdown_rx) = tokio::sync::broadcast::channel(1);
        let transport_config = crate::TransportConfig::default();
        let transport = Transport::new(
            "follower-1".into(),
            "127.0.0.1:9093".parse().unwrap(),
            transport_config,
        );
        let fetcher = FollowerFetcher::new(
            "follower-1".to_string(),
            partition_id,
            "leader-1".to_string(),
            0,
            Arc::new(Mutex::new(transport)),
            config,
            shutdown_rx,
        );
        assert_eq!(fetcher.fetch_offset, 0);
        assert_eq!(fetcher.high_watermark, 0);
        assert_eq!(fetcher.lag(), 0);
        drop(shutdown_tx);
    }

    // The manager tracks partitions by id, reports only led partitions,
    // and forgets removed ones.
    #[tokio::test]
    async fn test_replication_manager_partition_tracking() {
        let config = ReplicationConfig::default();
        let manager = ReplicationManager::new("node-1".to_string(), config);
        let partition_id1 = PartitionId::new("topic-1", 0);
        let partition_id2 = PartitionId::new("topic-1", 1);
        manager.get_or_create(partition_id1.clone(), true);
        manager.get_or_create(partition_id2.clone(), false);
        assert_eq!(manager.leading_partitions().len(), 1);
        assert!(manager.get(&partition_id1).is_some());
        assert!(manager.get(&partition_id2).is_some());
        manager.remove(&partition_id1);
        assert!(manager.get(&partition_id1).is_none());
        assert_eq!(manager.leading_partitions().len(), 0);
    }
}