use crate::error::{Error, Result};
use crate::storage::Database;
use crate::types::{Entry, LogIndex};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;
/// Tuning knobs for the batching sequencer.
#[derive(Debug, Clone)]
pub struct SequencerConfig {
    /// Maximum number of entries accumulated before a batch is flushed.
    pub batch_max_size: usize,
    /// Maximum time a non-empty batch may wait before being flushed.
    pub batch_max_age: Duration,
    /// Capacity of the request channel between the handle and the worker.
    pub channel_size: usize,
}

impl Default for SequencerConfig {
    /// Baseline configuration: flush at 256 entries or after one second,
    /// with room for 1024 queued requests.
    fn default() -> Self {
        SequencerConfig {
            batch_max_size: 256,
            batch_max_age: Duration::from_secs(1),
            channel_size: 1024,
        }
    }
}
/// A single entry queued for sequencing, paired with the one-shot channel
/// used to deliver its assigned log index (or an error) back to the caller.
struct SequenceRequest {
    entry: Entry,
    // If this sender is dropped without sending, the waiting caller in
    // `Sequencer::add` sees an Internal "sequencer response dropped" error.
    response: oneshot::Sender<Result<LogIndex>>,
}
/// Cheaply cloneable handle that submits entries to the background
/// sequencing worker over an mpsc channel.
#[derive(Clone)]
pub struct Sequencer {
    sender: mpsc::Sender<SequenceRequest>,
}
impl Sequencer {
    /// Creates a sequencer handle together with its worker future.
    ///
    /// The returned future drives the batching loop and must be spawned
    /// (or otherwise polled) by the caller; the handle does nothing useful
    /// until it is.
    pub fn new(
        db: Database,
        config: SequencerConfig,
    ) -> (Self, impl std::future::Future<Output = ()>) {
        let (tx, rx) = mpsc::channel(config.channel_size);
        let worker = SequencerWorker {
            db,
            config,
            receiver: rx,
        };
        (Self { sender: tx }, worker.run())
    }

    /// Submits `entry` for sequencing and waits for its assigned log index.
    ///
    /// # Errors
    /// Returns `Error::Internal` when the worker has shut down (request
    /// channel closed or response sender dropped), or propagates whatever
    /// error the batch flush produced.
    pub async fn add(&self, entry: Entry) -> Result<LogIndex> {
        let (reply_tx, reply_rx) = oneshot::channel();
        self.sender
            .send(SequenceRequest {
                entry,
                response: reply_tx,
            })
            .await
            .map_err(|_| Error::Internal("sequencer channel closed".into()))?;
        reply_rx
            .await
            .map_err(|_| Error::Internal("sequencer response dropped".into()))?
    }
}
/// Background task state: drains `receiver`, accumulates requests into
/// batches, and writes them to `db` according to `config`.
struct SequencerWorker {
    db: Database,
    config: SequencerConfig,
    receiver: mpsc::Receiver<SequenceRequest>,
}
impl SequencerWorker {
async fn run(mut self) {
let mut batch: Vec<SequenceRequest> = Vec::with_capacity(self.config.batch_max_size);
let mut batch_deadline: Option<Instant> = None;
loop {
let timeout = batch_deadline
.map(|d| d.saturating_duration_since(Instant::now()))
.unwrap_or(Duration::MAX);
tokio::select! {
request = self.receiver.recv() => {
match request {
Some(req) => {
if batch.is_empty() {
batch_deadline = Some(Instant::now() + self.config.batch_max_age);
}
batch.push(req);
if batch.len() >= self.config.batch_max_size {
self.flush_batch(&mut batch).await;
batch_deadline = None;
}
}
None => {
if !batch.is_empty() {
self.flush_batch(&mut batch).await;
}
return;
}
}
}
_ = tokio::time::sleep(timeout), if !batch.is_empty() => {
self.flush_batch(&mut batch).await;
batch_deadline = None;
}
}
}
}
async fn flush_batch(&self, batch: &mut Vec<SequenceRequest>) {
if batch.is_empty() {
return;
}
tracing::debug!("Flushing batch of {} entries", batch.len());
let entries: Vec<Entry> = batch.iter().map(|r| r.entry.clone()).collect();
let result = self.db.sequence_entries(entries).await;
let requests: Vec<SequenceRequest> = std::mem::take(batch);
match result {
Ok(sequenced) => {
for (req, seq_entry) in requests.into_iter().zip(sequenced.into_iter()) {
let _ = req.response.send(Ok(seq_entry.index()));
}
}
Err(e) => {
let err_msg = e.to_string();
for req in requests {
let _ = req.response.send(Err(Error::Internal(err_msg.clone())));
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented default sizes/ages must not drift silently.
    #[test]
    fn test_config_defaults() {
        let SequencerConfig {
            batch_max_size,
            batch_max_age,
            ..
        } = SequencerConfig::default();
        assert_eq!(batch_max_size, 256);
        assert_eq!(batch_max_age, Duration::from_secs(1));
    }
}