use crate::RaftTypeConfig;
use crate::async_runtime::MpscReceiver;
use crate::async_runtime::TryRecvError;
use crate::core::raft_msg::RaftMsg;
use crate::errors::Fatal;
use crate::type_config::alias::MpscReceiverOf;
/// Wraps the `RaftMsg` channel receiver and merges consecutive
/// `RaftMsg::ClientWrite` messages into a single batched message.
///
/// At most one message is held back in `buffered`; everything else stays in
/// the underlying channel until it is pulled by one of the receive methods.
pub(crate) struct BatchRaftMsgReceiver<C>
where C: RaftTypeConfig
{
    /// A message that has been taken off the channel but not yet handed to
    /// the caller: either fetched by `ensure_buffered()`, or set aside by
    /// `merge_client_writes()` because it could not be merged.
    buffered: Option<RaftMsg<C>>,
    /// The underlying channel receiver that messages are pulled from.
    inner: MpscReceiverOf<C, RaftMsg<C>>,
}
impl<C> BatchRaftMsgReceiver<C>
where C: RaftTypeConfig
{
    /// Upper bound on how many follow-up messages a single `try_recv()` call
    /// pulls from the channel while merging client writes.
    ///
    /// The bound counts messages, not payloads: each merged message may carry
    /// more than one payload.
    const MAX_MERGED_MSGS: usize = 4096;

    /// Create a receiver on top of `receiver` with nothing buffered yet.
    pub(crate) fn new(receiver: MpscReceiverOf<C, RaftMsg<C>>) -> Self {
        Self {
            buffered: None,
            inner: receiver,
        }
    }

    /// Wait until at least one message is available for `try_recv()`.
    ///
    /// Returns immediately if a message is already buffered; otherwise awaits
    /// the next message from the channel and stores it in the buffer.
    ///
    /// # Errors
    ///
    /// Returns `Fatal::Stopped` if the channel yields no more messages, i.e.
    /// all senders are dropped.
    pub(crate) async fn ensure_buffered(&mut self) -> Result<(), Fatal<C>> {
        if self.buffered.is_none() {
            self.buffered = Some(self.inner_recv().await?);
        }
        Ok(())
    }

    /// Take the next message without waiting.
    ///
    /// Returns `Ok(None)` if no message is currently available. If the message
    /// is a `RaftMsg::ClientWrite`, the client writes that directly follow it in
    /// the channel and share the same `expected_leader` are merged into it
    /// before it is returned.
    ///
    /// # Errors
    ///
    /// Returns `Fatal::Stopped` if nothing is buffered and the channel is
    /// disconnected. A disconnect noticed only while looking ahead for more
    /// client writes does not fail this call: the message already taken is
    /// returned, and the error surfaces on the following call.
    pub(crate) fn try_recv(&mut self) -> Result<Option<RaftMsg<C>>, Fatal<C>> {
        let Some(mut msg) = self.buffered_try_recv()? else {
            return Ok(None);
        };

        self.merge_client_writes(&mut msg);
        Ok(Some(msg))
    }

    /// Take the buffered message if there is one, otherwise poll the channel
    /// without waiting.
    fn buffered_try_recv(&mut self) -> Result<Option<RaftMsg<C>>, Fatal<C>> {
        match self.buffered.take() {
            Some(msg) => Ok(Some(msg)),
            None => self.inner_try_recv(),
        }
    }

    /// Await the next message from the channel, mapping the end of the channel
    /// to `Fatal::Stopped`.
    async fn inner_recv(&mut self) -> Result<RaftMsg<C>, Fatal<C>> {
        match self.inner.recv().await {
            Some(msg) => Ok(msg),
            None => {
                tracing::info!("all rx_api senders are dropped");
                Err(Fatal::Stopped)
            }
        }
    }

    /// Poll the channel without waiting.
    ///
    /// Returns `Ok(None)` when the channel is empty and `Fatal::Stopped` when
    /// it is disconnected.
    fn inner_try_recv(&mut self) -> Result<Option<RaftMsg<C>>, Fatal<C>> {
        match self.inner.try_recv() {
            Ok(msg) => Ok(Some(msg)),
            Err(TryRecvError::Empty) => {
                tracing::debug!("all RaftMsg are processed, wait for more");
                Ok(None)
            }
            Err(TryRecvError::Disconnected) => {
                tracing::debug!("rx_api is disconnected, quit");
                Err(Fatal::Stopped)
            }
        }
    }

    /// Merge the client writes that directly follow `msg` in the channel into
    /// `msg` itself.
    ///
    /// Does nothing unless `msg` is a `RaftMsg::ClientWrite`. Merging stops at
    /// the first of:
    /// - the channel being empty or disconnected;
    /// - a message that is not a client write, or whose `expected_leader`
    ///   differs from that of `msg`; this message is kept in `self.buffered`
    ///   so that it is returned by the next receive call;
    /// - `MAX_MERGED_MSGS` messages having been pulled.
    ///
    /// Must be called with an empty buffer, since the buffer is where an
    /// unmergeable message is stored.
    fn merge_client_writes(&mut self, msg: &mut RaftMsg<C>) {
        debug_assert!(self.buffered.is_none());

        let RaftMsg::ClientWrite {
            payloads: batch_payloads,
            responders: batch_responders,
            expected_leader: batch_leader,
            ..
        } = msg
        else {
            return;
        };

        for _ in 0..Self::MAX_MERGED_MSGS {
            let next = match self.inner_try_recv() {
                Ok(Some(next)) => next,
                Ok(None) => break,
                // The channel is disconnected, but `msg` has already been taken off
                // it. Propagating the error here would make the caller drop `msg`
                // together with its responders, so stop merging instead and let the
                // next receive call report the disconnect.
                // NOTE(review): relies on the receiver reporting the disconnect again
                // on later calls — confirm for every supported async runtime.
                Err(_stopped) => break,
            };

            match next {
                RaftMsg::ClientWrite {
                    payloads,
                    responders,
                    expected_leader,
                    ..
                } if expected_leader == *batch_leader => {
                    batch_payloads.extend(payloads);
                    batch_responders.extend(responders);
                }
                // Not mergeable: keep it for the next receive call so that message
                // order is preserved.
                other => {
                    self.buffered = Some(other);
                    break;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::async_runtime::MpscSender;
    use crate::batch::Batch;
    use crate::engine::testing::UTConfig;
    use crate::engine::testing::log_id;
    use crate::entry::EntryPayload;
    use crate::type_config::TypeConfigExt;
    use crate::type_config::alias::BatchOf;
    use crate::type_config::alias::CommittedLeaderIdOf;
    use crate::type_config::alias::EntryPayloadOf;

    type C = UTConfig<()>;

    /// Committed leader id for `(term, node_id)`, taken from a log id at index 0.
    fn committed_leader_id(term: u64, node_id: u64) -> CommittedLeaderIdOf<C> {
        let id = log_id(term, node_id, 0);
        *id.committed_leader_id()
    }

    /// A client write carrying one `Normal(data)` payload, a single `None`
    /// responder, and the given expected leader.
    fn client_write(data: u64, leader: Option<CommittedLeaderIdOf<C>>) -> RaftMsg<C> {
        RaftMsg::ClientWrite {
            payloads: Batch::of([EntryPayload::Normal(data)]),
            responders: Batch::of([None]),
            expected_leader: leader,
            #[cfg(feature = "runtime-stats")]
            proposed_at: C::now(),
        }
    }

    /// Collect the data of every `Normal` payload; panics on any other payload kind.
    fn extract_payload_data(payloads: &BatchOf<C, EntryPayloadOf<C>>) -> Vec<u64> {
        payloads
            .as_ref()
            .iter()
            .map(|payload| {
                let EntryPayload::Normal(data) = payload else {
                    panic!("expected Normal payload")
                };
                *data
            })
            .collect()
    }

    /// Assert that `msg` is a client write whose payload data equals `expected`
    /// and that it carries one responder per payload.
    #[track_caller]
    fn assert_client_write(msg: RaftMsg<C>, expected: Vec<u64>) {
        let RaftMsg::ClientWrite {
            payloads, responders, ..
        } = msg
        else {
            panic!("expected ClientWrite");
        };

        assert_eq!(extract_payload_data(&payloads), expected);
        assert_eq!(responders.len(), payloads.len());
    }

    /// Three client writes with the same expected leader come out as one batch.
    #[test]
    fn test_merge_consecutive_client_writes_with_same_leader() {
        C::run(async {
            let (tx, rx) = C::mpsc(100);
            let mut receiver: BatchRaftMsgReceiver<C> = BatchRaftMsgReceiver::new(rx);
            let leader = Some(committed_leader_id(1, 1));

            for data in [1, 2, 3] {
                tx.send(client_write(data, leader)).await.unwrap();
            }

            receiver.ensure_buffered().await.unwrap();
            let merged = receiver.try_recv().unwrap().unwrap();

            assert_client_write(merged, vec![1, 2, 3]);
        });
    }

    /// Client writes with different expected leaders are returned separately.
    #[test]
    fn test_no_merge_when_expected_leader_differs() {
        C::run(async {
            let (tx, rx) = C::mpsc(100);
            let mut receiver: BatchRaftMsgReceiver<C> = BatchRaftMsgReceiver::new(rx);
            let leader1 = Some(committed_leader_id(1, 1));
            let leader2 = Some(committed_leader_id(2, 1));

            tx.send(client_write(1, leader1)).await.unwrap();
            tx.send(client_write(2, leader2)).await.unwrap();

            receiver.ensure_buffered().await.unwrap();

            let first = receiver.try_recv().unwrap().unwrap();
            assert_client_write(first, vec![1]);

            let second = receiver.try_recv().unwrap().unwrap();
            assert_client_write(second, vec![2]);
        });
    }

    /// A non-client-write message between two client writes keeps them apart
    /// and is itself delivered in order.
    #[test]
    fn test_non_client_write_stops_merging() {
        C::run(async {
            let (tx, rx) = C::mpsc(100);
            let mut receiver: BatchRaftMsgReceiver<C> = BatchRaftMsgReceiver::new(rx);
            let leader = Some(committed_leader_id(1, 1));

            tx.send(client_write(1, leader)).await.unwrap();
            tx.send(RaftMsg::WithRaftState { req: Box::new(|_| {}) }).await.unwrap();
            tx.send(client_write(2, leader)).await.unwrap();

            receiver.ensure_buffered().await.unwrap();

            let first = receiver.try_recv().unwrap().unwrap();
            assert_client_write(first, vec![1]);

            let second = receiver.try_recv().unwrap().unwrap();
            assert!(matches!(second, RaftMsg::WithRaftState { .. }));

            let third = receiver.try_recv().unwrap().unwrap();
            assert_client_write(third, vec![2]);
        });
    }

    /// With a live sender and nothing sent, `try_recv()` yields `None`.
    #[test]
    fn test_try_recv_returns_none_when_empty() {
        C::run(async {
            let (_tx, rx) = C::mpsc::<RaftMsg<C>>(100);
            let mut receiver: BatchRaftMsgReceiver<C> = BatchRaftMsgReceiver::new(rx);

            assert!(receiver.try_recv().unwrap().is_none());
        });
    }

    /// `ensure_buffered()` makes a sent message available to `try_recv()`.
    #[test]
    fn test_ensure_buffered_waits_for_message() {
        C::run(async {
            let (tx, rx) = C::mpsc(100);
            let mut receiver: BatchRaftMsgReceiver<C> = BatchRaftMsgReceiver::new(rx);

            tx.send(client_write(42, None)).await.unwrap();

            receiver.ensure_buffered().await.unwrap();
            let received = receiver.try_recv().unwrap().unwrap();

            assert_client_write(received, vec![42]);
        });
    }

    /// After an unmergeable message has been set aside, `ensure_buffered()`
    /// succeeds without new input and `try_recv()` returns that message.
    #[test]
    fn test_ensure_buffered_returns_immediately_if_already_buffered() {
        C::run(async {
            let (tx, rx) = C::mpsc(100);
            let mut receiver: BatchRaftMsgReceiver<C> = BatchRaftMsgReceiver::new(rx);
            let leader = Some(committed_leader_id(1, 1));

            tx.send(client_write(1, None)).await.unwrap();
            tx.send(client_write(2, leader)).await.unwrap();

            receiver.ensure_buffered().await.unwrap();
            let _first = receiver.try_recv().unwrap().unwrap();

            receiver.ensure_buffered().await.unwrap();
            let second = receiver.try_recv().unwrap().unwrap();

            assert_client_write(second, vec![2]);
        });
    }
}