use nanoid::nanoid;
use std::sync::Arc;
use d_engine_proto::client::ClientResponse;
use d_engine_proto::common::EntryPayload;
use tokio::time::Instant;
use tonic::Status;
use crate::MaybeCloneOneshotSender;
use crate::RaftRequestWithSignal;
/// Sender half used to hand a single proposal's outcome — a `ClientResponse`
/// or a gRPC `Status` error — back to whoever submitted it.
type ProposeSender = MaybeCloneOneshotSender<std::result::Result<ClientResponse, Status>>;

/// Accumulates client proposals so they can be flushed as one batched request.
///
/// `payloads[i]` and `senders[i]` describe the same proposal: the two vectors
/// are pushed to and drained together so they stay index-aligned.
pub struct ProposeBatchBuffer {
    /// Set at construction and refreshed on every flush that produced a request.
    pub last_flush: Instant,
    /// Entries waiting to be proposed, in arrival order.
    payloads: Vec<EntryPayload>,
    /// Response channels, index-aligned with `payloads`.
    senders: Vec<ProposeSender>,
    /// When `true`, the `batch.buffer_length` gauge is updated.
    metrics_enabled: bool,
    /// Labels attached to the length gauge; `None` until configured.
    metrics_labels: Option<Arc<[(String, String)]>>,
}
impl ProposeBatchBuffer {
    /// Creates an empty buffer with room for `initial_capacity` proposals.
    ///
    /// Gauge reporting starts disabled; see [`Self::with_length_gauge`].
    pub fn new(initial_capacity: usize) -> Self {
        Self {
            payloads: Vec::with_capacity(initial_capacity),
            senders: Vec::with_capacity(initial_capacity),
            last_flush: Instant::now(),
            metrics_labels: None,
            metrics_enabled: false,
        }
    }

    /// Configures the `batch.buffer_length` gauge and returns the updated buffer.
    ///
    /// The gauge is labelled with `node_id` and `buffer` (= `buffer_name`).
    /// The labels are stored regardless of `enabled`, but the gauge is only
    /// written while `enabled` is `true`.
    #[must_use]
    pub fn with_length_gauge(
        mut self,
        node_id: u32,
        buffer_name: &'static str,
        enabled: bool,
    ) -> Self {
        self.metrics_labels = Some(Arc::from(vec![
            ("node_id".to_string(), node_id.to_string()),
            ("buffer".to_string(), buffer_name.to_string()),
        ]));
        self.metrics_enabled = enabled;
        self
    }

    /// Queues one proposal together with the sender that will receive its response,
    /// then reports the new buffer length (if the gauge is enabled).
    pub fn push(
        &mut self,
        payload: EntryPayload,
        sender: ProposeSender,
    ) {
        self.payloads.push(payload);
        self.senders.push(sender);
        self.report_length();
    }

    /// Drains every buffered proposal into a single [`RaftRequestWithSignal`].
    ///
    /// Returns `None` when the buffer is empty; in that case `last_flush` is not
    /// updated. Otherwise it does the following:
    /// - resets `last_flush`;
    /// - empties the buffer and reports a length of 0;
    /// - returns a request carrying a fresh `nanoid!()` id with
    ///   `wait_for_apply_event` set to `true`.
    ///
    /// The emptied buffer is re-allocated with capacity equal to the size of the
    /// batch just flushed.
    #[must_use = "dropping the flushed request discards the batched proposals and their response senders"]
    pub fn flush(&mut self) -> Option<RaftRequestWithSignal> {
        if self.payloads.is_empty() {
            return None;
        }
        self.last_flush = Instant::now();
        let n = self.payloads.len();
        // `replace` rather than `take`: the next batch starts pre-sized to this one.
        let payloads = std::mem::replace(&mut self.payloads, Vec::with_capacity(n));
        let senders = std::mem::replace(&mut self.senders, Vec::with_capacity(n));
        // The buffer is now empty, so this reports 0.
        self.report_length();
        Some(RaftRequestWithSignal {
            id: nanoid!(),
            payloads,
            senders,
            wait_for_apply_event: true,
        })
    }

    /// Returns `true` when no proposals are queued.
    pub fn is_empty(&self) -> bool {
        self.payloads.is_empty()
    }

    /// Number of proposals currently queued.
    pub fn len(&self) -> usize {
        self.payloads.len()
    }

    /// Writes the current queue length to the `batch.buffer_length` gauge when
    /// reporting is enabled and labels have been configured.
    fn report_length(&self) {
        if !self.metrics_enabled {
            return;
        }
        if let Some(ref labels) = self.metrics_labels {
            metrics::gauge!("batch.buffer_length", labels.as_ref())
                .set(self.payloads.len() as f64);
        }
    }
}