use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::Channel;
use tracing::{debug, warn};
use crate::proto::ReplicationItem;
use crate::proto::replicator_data_client::ReplicatorDataClient;
use crate::replicator::quorum::QuorumTracker;
use crate::types::{Epoch, Lsn, ReplicaId};
struct SecondaryConnection {
item_tx: mpsc::UnboundedSender<ReplicationItem>,
}
pub struct PrimarySender {
connections: HashMap<ReplicaId, SecondaryConnection>,
#[allow(dead_code)]
primary_id: ReplicaId,
epoch: Epoch,
}
impl PrimarySender {
pub fn new(primary_id: ReplicaId, epoch: Epoch) -> Self {
Self {
connections: HashMap::new(),
primary_id,
epoch,
}
}
pub fn set_epoch(&mut self, epoch: Epoch) {
self.epoch = epoch;
}
pub async fn add_secondary(
&mut self,
replica_id: ReplicaId,
address: String,
quorum_tracker: Arc<tokio::sync::Mutex<QuorumTracker>>,
partition_state: Arc<crate::handles::PartitionState>,
) -> crate::Result<()> {
if self.connections.contains_key(&replica_id) {
return Ok(()); }
let channel = Channel::from_shared(address)
.map_err(|e| crate::KubericError::Internal(Box::new(e)))?
.connect()
.await
.map_err(|e| crate::KubericError::Internal(Box::new(e)))?;
let mut client = ReplicatorDataClient::new(channel);
let (grpc_tx, grpc_rx) = mpsc::channel::<ReplicationItem>(256);
let outbound = ReceiverStream::new(grpc_rx);
let response = client
.replication_stream(outbound)
.await
.map_err(|e| crate::KubericError::Internal(Box::new(e)))?;
let mut ack_stream = response.into_inner();
let rid = replica_id;
let ps = partition_state;
tokio::spawn(async move {
while let Some(result) = ack_stream.next().await {
match result {
Ok(ack) => {
debug!(replica_id = rid, lsn = ack.lsn, "received ACK");
let mut tracker = quorum_tracker.lock().await;
tracker.ack(ack.lsn, rid);
ps.set_committed_lsn(tracker.committed_lsn());
}
Err(e) => {
warn!(replica_id = rid, error = %e, "ACK stream error");
break;
}
}
}
});
let (unbounded_tx, mut unbounded_rx) = mpsc::unbounded_channel::<ReplicationItem>();
tokio::spawn(async move {
while let Some(item) = unbounded_rx.recv().await {
if grpc_tx.send(item).await.is_err() {
warn!(replica_id = rid, "gRPC stream closed, drain task exiting");
break;
}
}
});
self.connections.insert(
replica_id,
SecondaryConnection {
item_tx: unbounded_tx,
},
);
Ok(())
}
pub fn send_to_one(
&self,
replica_id: ReplicaId,
lsn: Lsn,
data: &bytes::Bytes,
committed_lsn: Lsn,
) {
let item = ReplicationItem {
epoch_data_loss: self.epoch.data_loss_number,
epoch_config: self.epoch.configuration_number,
lsn,
data: data.to_vec(),
committed_lsn,
};
if let Some(conn) = self.connections.get(&replica_id)
&& conn.item_tx.send(item).is_err()
{
warn!(replica_id, lsn, "send_to_one: channel closed");
}
}
pub fn remove_secondary(&mut self, replica_id: ReplicaId) {
self.connections.remove(&replica_id);
}
pub fn send_to_all(&mut self, lsn: Lsn, data: &bytes::Bytes, committed_lsn: Lsn) {
let item = ReplicationItem {
epoch_data_loss: self.epoch.data_loss_number,
epoch_config: self.epoch.configuration_number,
lsn,
data: data.to_vec(),
committed_lsn,
};
let mut dead = Vec::new();
for (&rid, conn) in &self.connections {
if conn.item_tx.send(item.clone()).is_err() {
warn!(
replica_id = rid,
lsn, "secondary channel closed — removing connection"
);
dead.push(rid);
}
}
for rid in dead {
self.connections.remove(&rid);
}
}
pub fn connection_count(&self) -> usize {
self.connections.len()
}
pub fn has_connection(&self, replica_id: &ReplicaId) -> bool {
self.connections.contains_key(replica_id)
}
pub fn connected_ids(&self) -> Vec<ReplicaId> {
self.connections.keys().cloned().collect()
}
pub fn close_all(&mut self) {
self.connections.clear();
}
}