use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tracing::{debug, error, warn};
use yggr_core::{
Action, ConfigChange, Engine, EngineConfig, Env, Event, LogEntry, LogIndex, LogPayload, NodeId,
RandomizedEnv, RoleState, Term,
};
use crate::state_machine::StateMachine;
use crate::storage::{RecoveredState, Storage, StoredHardState, StoredSnapshot};
use crate::transport::Transport;
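/// Clonable handle to a running consensus node.
///
/// Every method forwards a request to the background driver task over a
/// channel, so a `Node` can be cloned freely and shared across tasks; all
/// clones address the same underlying node.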
pub struct Node<S: StateMachine> {
inputs: mpsc::Sender<DriverInput<S>>,
background: Arc<Mutex<Option<BackgroundTasks>>>,
}
impl<S: StateMachine> std::fmt::Debug for Node<S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Node").finish_non_exhaustive()
}
}
impl<S: StateMachine> Clone for Node<S> {
fn clone(&self) -> Self {
Self {
inputs: self.inputs.clone(),
background: Arc::clone(&self.background),
}
}
}
struct BackgroundTasks {
ticker: JoinHandle<()>,
driver: JoinHandle<()>,
apply: JoinHandle<()>,
}
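/// How a node establishes its initial cluster membership.
///
/// `NewCluster` seeds a fresh cluster with the given member set, `Join`
/// starts with an empty peer set (expecting to be added via a config
/// change), and `Recover` reuses the configured peers alongside whatever
/// state the storage layer recovers.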
#[derive(Debug, Clone)]
pub enum Bootstrap {
NewCluster {
members: Vec<NodeId>,
},
Join,
Recover,
}
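/// Runtime configuration for a [`Node`].
///
/// All timing knobs are expressed in ticks of [`Config::tick_interval`].
/// A minimal construction sketch (the `NodeId::new` constructor is
/// illustrative; use whatever `yggr_core` actually provides):
///
/// ```ignore
/// let mut config = Config::new(NodeId::new(1), [NodeId::new(2), NodeId::new(3)]);
/// config.tick_interval = Duration::from_millis(100);
/// config.lease_duration_ticks = 5;
/// config.validate()?;
/// ```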
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct Config {
pub node_id: NodeId,
pub peers: Vec<NodeId>,
pub election_timeout_min_ticks: u64,
pub election_timeout_max_ticks: u64,
pub heartbeat_interval_ticks: u64,
pub tick_interval: Duration,
pub bootstrap: Bootstrap,
pub max_pending_proposals: usize,
pub max_pending_applies: usize,
pub max_batch_delay_ticks: u64,
pub max_batch_entries: usize,
pub snapshot_hint_threshold_entries: u64,
pub max_log_entries: u64,
pub snapshot_chunk_size_bytes: usize,
pub pre_vote: bool,
pub lease_duration_ticks: u64,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum ConfigError {
HeartbeatNotLessThanElectionMin {
heartbeat_interval_ticks: u64,
election_timeout_min_ticks: u64,
},
InvalidElectionTimeoutRange {
election_timeout_min_ticks: u64,
election_timeout_max_ticks: u64,
},
PeersContainSelf { node_id: NodeId },
InvalidSnapshotChunkSize { snapshot_chunk_size_bytes: usize },
LeaseDurationTooLarge {
lease_duration_ticks: u64,
election_timeout_min_ticks: u64,
heartbeat_interval_ticks: u64,
},
}
impl std::fmt::Display for ConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::HeartbeatNotLessThanElectionMin {
heartbeat_interval_ticks,
election_timeout_min_ticks,
} => write!(
f,
"heartbeat interval {heartbeat_interval_ticks} must be less than election timeout min {election_timeout_min_ticks}"
),
Self::InvalidElectionTimeoutRange {
election_timeout_min_ticks,
election_timeout_max_ticks,
} => write!(
f,
"election timeout range [{election_timeout_min_ticks}, {election_timeout_max_ticks}) is empty"
),
Self::PeersContainSelf { node_id } => {
write!(f, "peer set must not contain self ({node_id})")
}
Self::InvalidSnapshotChunkSize {
snapshot_chunk_size_bytes,
} => write!(
f,
"snapshot chunk size must be greater than zero (got {snapshot_chunk_size_bytes})"
),
Self::LeaseDurationTooLarge {
lease_duration_ticks,
election_timeout_min_ticks,
heartbeat_interval_ticks,
} => write!(
f,
"lease_duration_ticks {lease_duration_ticks} must be strictly less than \
election_timeout_min_ticks {election_timeout_min_ticks} - \
heartbeat_interval_ticks {heartbeat_interval_ticks}"
),
}
}
}
impl std::error::Error for ConfigError {}
impl Config {
pub fn new(node_id: NodeId, peers: impl IntoIterator<Item = NodeId>) -> Self {
let peers: Vec<NodeId> = peers.into_iter().collect();
let bootstrap = if peers.is_empty() {
Bootstrap::Recover
} else {
let mut members = peers.clone();
members.push(node_id);
Bootstrap::NewCluster { members }
};
Self {
node_id,
peers,
election_timeout_min_ticks: 10,
election_timeout_max_ticks: 20,
heartbeat_interval_ticks: 3,
tick_interval: Duration::from_millis(50),
bootstrap,
max_pending_proposals: 1024,
max_pending_applies: 4096,
max_batch_delay_ticks: 0,
max_batch_entries: 64,
snapshot_hint_threshold_entries: 1024,
max_log_entries: 0,
snapshot_chunk_size_bytes: 64 * 1024,
pre_vote: true,
lease_duration_ticks: 0,
}
}
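    /// Checks the invariants the driver relies on:
    ///
    /// - the heartbeat interval is strictly below the election timeout
    ///   minimum,
    /// - the election timeout range `[min, max)` is non-empty,
    /// - the peer set does not contain this node,
    /// - the snapshot chunk size is non-zero, and
    /// - any leader lease fits strictly inside
    ///   `election_timeout_min_ticks - heartbeat_interval_ticks`.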
pub fn validate(&self) -> Result<(), ConfigError> {
if self.heartbeat_interval_ticks >= self.election_timeout_min_ticks {
return Err(ConfigError::HeartbeatNotLessThanElectionMin {
heartbeat_interval_ticks: self.heartbeat_interval_ticks,
election_timeout_min_ticks: self.election_timeout_min_ticks,
});
}
if self.election_timeout_min_ticks >= self.election_timeout_max_ticks {
return Err(ConfigError::InvalidElectionTimeoutRange {
election_timeout_min_ticks: self.election_timeout_min_ticks,
election_timeout_max_ticks: self.election_timeout_max_ticks,
});
}
if self.peers.contains(&self.node_id) {
return Err(ConfigError::PeersContainSelf {
node_id: self.node_id,
});
}
if self.snapshot_chunk_size_bytes == 0 {
return Err(ConfigError::InvalidSnapshotChunkSize {
snapshot_chunk_size_bytes: self.snapshot_chunk_size_bytes,
});
}
if self.lease_duration_ticks > 0 {
let safe_max = self
.election_timeout_min_ticks
.saturating_sub(self.heartbeat_interval_ticks);
if self.lease_duration_ticks >= safe_max {
return Err(ConfigError::LeaseDurationTooLarge {
lease_duration_ticks: self.lease_duration_ticks,
election_timeout_min_ticks: self.election_timeout_min_ticks,
heartbeat_interval_ticks: self.heartbeat_interval_ticks,
});
}
}
Ok(())
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
Follower,
PreCandidate,
Candidate,
Leader,
}
impl std::fmt::Display for Role {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Follower => f.write_str("follower"),
Self::PreCandidate => f.write_str("pre-candidate"),
Self::Candidate => f.write_str("candidate"),
Self::Leader => f.write_str("leader"),
}
}
}
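/// Point-in-time view of a node's consensus state, as returned by
/// [`Node::status`].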
#[derive(Debug, Clone)]
pub struct NodeStatus {
pub node_id: NodeId,
pub role: Role,
pub current_term: Term,
pub commit_index: LogIndex,
pub last_applied: LogIndex,
pub leader_hint: Option<NodeId>,
pub peers: Vec<NodeId>,
}
#[non_exhaustive]
#[derive(Debug)]
pub enum ProposeError {
NotLeader { leader_hint: NodeId },
NoLeader,
Shutdown,
DriverDead,
Busy,
Fatal { reason: &'static str },
}
impl std::fmt::Display for ProposeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotLeader { leader_hint } => write!(f, "not leader; try {leader_hint}"),
Self::NoLeader => write!(f, "no leader known yet"),
Self::Shutdown => write!(f, "node is shutting down"),
Self::DriverDead => write!(f, "node driver task died"),
Self::Busy => write!(f, "too many in-flight proposals; retry later"),
Self::Fatal { reason } => write!(f, "fatal runtime error: {reason}"),
}
}
}
impl std::error::Error for ProposeError {}
#[non_exhaustive]
#[derive(Debug)]
pub enum TransferLeadershipError {
NotLeader { leader_hint: NodeId },
NoLeader,
InvalidTarget { target: NodeId },
Shutdown,
DriverDead,
Fatal { reason: &'static str },
}
impl std::fmt::Display for TransferLeadershipError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotLeader { leader_hint } => write!(f, "not leader; try {leader_hint}"),
Self::NoLeader => write!(f, "no leader known yet"),
Self::InvalidTarget { target } => write!(f, "invalid transfer target: {target}"),
Self::Shutdown => write!(f, "node is shutting down"),
Self::DriverDead => write!(f, "node driver task died"),
Self::Fatal { reason } => write!(f, "fatal runtime error: {reason}"),
}
}
}
impl std::error::Error for TransferLeadershipError {}
#[non_exhaustive]
#[derive(Debug)]
pub enum ReadError {
NotLeader { leader_hint: NodeId },
NotReady,
SteppedDown,
Shutdown,
DriverDead,
Fatal { reason: &'static str },
}
impl std::fmt::Display for ReadError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NotLeader { leader_hint } => write!(f, "not leader; try {leader_hint}"),
Self::NotReady => write!(f, "leader not ready to serve linearizable reads"),
Self::SteppedDown => write!(f, "leader stepped down before read completed"),
Self::Shutdown => write!(f, "node is shutting down"),
Self::DriverDead => write!(f, "node driver task died"),
Self::Fatal { reason } => write!(f, "fatal runtime error: {reason}"),
}
}
}
impl std::error::Error for ReadError {}
#[non_exhaustive]
#[derive(Debug)]
pub enum NodeStartError<E> {
Config(ConfigError),
Storage(E),
}
impl<E: std::fmt::Display> std::fmt::Display for NodeStartError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Config(e) => write!(f, "invalid config: {e}"),
Self::Storage(e) => write!(f, "storage recover failed: {e}"),
}
}
}
impl<E: std::error::Error + 'static> std::error::Error for NodeStartError<E> {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Config(e) => Some(e),
Self::Storage(e) => Some(e),
}
}
}
impl<S: StateMachine> Node<S> {
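    /// Validates the config, recovers persisted state, and spawns the three
    /// background tasks (ticker, driver, apply loop) that run the node.
    ///
    /// A minimal sketch, assuming a `KvStore: StateMachine`, a `MemStorage:
    /// Storage<Vec<u8>>`, and a `transport: impl Transport<Vec<u8>>` (the
    /// concrete types are illustrative, not part of this module):
    ///
    /// ```ignore
    /// let config = Config::new(NodeId::new(1), [NodeId::new(2), NodeId::new(3)]);
    /// let node = Node::start(config, KvStore::default(), MemStorage::default(), transport).await?;
    /// let status = node.status().await?;
    /// println!("started as {} in term {:?}", status.role, status.current_term);
    /// ```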
pub async fn start<St, Tr>(
config: Config,
state_machine: S,
mut storage: St,
transport: Tr,
) -> Result<Self, NodeStartError<St::Error>>
where
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
config.validate().map_err(NodeStartError::Config)?;
let recovered = storage.recover().await.map_err(NodeStartError::Storage)?;
let initial_peers: Vec<NodeId> = match &config.bootstrap {
Bootstrap::Join => Vec::new(),
Bootstrap::NewCluster { .. } | Bootstrap::Recover => config.peers.clone(),
};
let seed_nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map_or(0, |d| u64::try_from(d.as_nanos()).unwrap_or(0));
let seed = seed_nanos.wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ config.node_id.get();
let env: Box<dyn Env> = Box::new(RandomizedEnv::new(
seed,
config.election_timeout_min_ticks,
config.election_timeout_max_ticks,
));
let engine_cfg = EngineConfig::new(
config.snapshot_chunk_size_bytes,
config.snapshot_hint_threshold_entries,
)
.with_pre_vote(config.pre_vote)
.with_max_log_entries(config.max_log_entries)
.with_lease_duration_ticks(config.lease_duration_ticks);
let engine: Engine<Vec<u8>> = Engine::with_config(
config.node_id,
initial_peers.iter().copied(),
env,
config.heartbeat_interval_ticks,
engine_cfg,
);
let max_pending_proposals = config.max_pending_proposals;
let (apply_tx, apply_rx) = mpsc::channel::<ApplyRequest<S>>(config.max_pending_applies);
let apply: JoinHandle<()> = tokio::spawn(apply_loop(state_machine, apply_rx));
let (inputs_tx, inputs_rx) = mpsc::channel::<DriverInput<S>>(1024);
let tick_inputs = inputs_tx.clone();
let tick_interval = config.tick_interval;
let ticker: JoinHandle<()> = tokio::spawn(async move {
let mut interval = tokio::time::interval(tick_interval);
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
interval.tick().await;
if tick_inputs.send(DriverInput::Tick).await.is_err() {
return;
}
}
});
let driver_state = Driver {
node_id: config.node_id,
engine,
apply_tx,
storage,
transport,
inputs: inputs_rx,
inputs_tx: inputs_tx.clone(),
pending_proposals: HashMap::new(),
pending_config_changes: HashMap::new(),
max_pending_proposals,
pending_reads: HashMap::new(),
next_read_id: 0,
batch_buffer: Vec::new(),
batch_delay_remaining: 0,
max_batch_delay_ticks: config.max_batch_delay_ticks,
max_batch_entries: config.max_batch_entries,
last_applied: LogIndex::ZERO,
snapshot_in_flight: false,
};
let driver: JoinHandle<()> = tokio::spawn(driver_loop(driver_state, recovered));
Ok(Self {
inputs: inputs_tx,
background: Arc::new(Mutex::new(Some(BackgroundTasks {
ticker,
driver,
apply,
}))),
})
}
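    /// Proposes a command and waits for it to be committed and applied,
    /// returning the state machine's response.
    ///
    /// Only the leader accepts proposals; callers should be prepared to
    /// retry against the node named by [`ProposeError::NotLeader`]. A sketch
    /// (`KvCommand` and `retry_against` are hypothetical):
    ///
    /// ```ignore
    /// match node.propose(KvCommand::Put { key, value }).await {
    ///     Ok(response) => println!("applied: {response:?}"),
    ///     Err(ProposeError::NotLeader { leader_hint }) => retry_against(leader_hint),
    ///     Err(other) => return Err(other.into()),
    /// }
    /// ```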
pub async fn propose(&self, command: S::Command) -> Result<S::Response, ProposeError> {
let (tx, rx) = oneshot::channel();
if self
.inputs
.send(DriverInput::Propose { command, reply: tx })
.await
.is_err()
{
return Err(ProposeError::Shutdown);
}
match rx.await {
Ok(r) => r,
Err(_) => Err(ProposeError::DriverDead),
}
}
pub async fn add_peer(&self, peer: NodeId) -> Result<(), ProposeError> {
self.config_change(ConfigChange::AddPeer(peer)).await
}
pub async fn remove_peer(&self, peer: NodeId) -> Result<(), ProposeError> {
self.config_change(ConfigChange::RemovePeer(peer)).await
}
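    /// Runs a read-only closure against the state machine once the engine
    /// confirms the read can be served linearizably. The closure executes on
    /// the apply task, behind everything committed before it, so it observes
    /// all previously applied commands. A sketch (the `KvStore` state
    /// machine and its `get` method are hypothetical):
    ///
    /// ```ignore
    /// let value = node.read_linearizable(|kv: &KvStore| kv.get("answer").cloned()).await?;
    /// ```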
pub async fn read_linearizable<R, F>(&self, reader: F) -> Result<R, ReadError>
where
R: Send + 'static,
F: FnOnce(&S) -> R + Send + 'static,
{
let (ok_tx, ok_rx) = oneshot::channel::<R>();
let (err_tx, err_rx) = oneshot::channel::<ReadError>();
let boxed_reader: Box<dyn FnOnce(&S) + Send> = Box::new(move |sm: &S| {
let value = reader(sm);
let _ = ok_tx.send(value);
});
let boxed_on_failure: Box<dyn FnOnce(ReadError) + Send> = Box::new(move |e: ReadError| {
let _ = err_tx.send(e);
});
if self
.inputs
.send(DriverInput::Read {
reader: boxed_reader,
on_failure: boxed_on_failure,
})
.await
.is_err()
{
return Err(ReadError::Shutdown);
}
        // Exactly one of `reader`/`on_failure` ever runs in the driver, so at
        // most one channel is fulfilled; the other sender is dropped, which
        // disables its branch. `else` fires only when both senders were
        // dropped without sending, i.e. the driver died.
        tokio::select! {
            biased;
            Ok(err) = err_rx => Err(err),
            Ok(val) = ok_rx => Ok(val),
            else => Err(ReadError::DriverDead),
        }
}
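    /// Asks the leader to hand leadership to `peer`. `Ok(())` means the
    /// transfer request was accepted and dispatched, not that the transfer
    /// has completed; the actual handover happens asynchronously.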
pub async fn transfer_leadership_to(
&self,
peer: NodeId,
) -> Result<(), TransferLeadershipError> {
let (tx, rx) = oneshot::channel();
if self
.inputs
.send(DriverInput::TransferLeadership {
target: peer,
reply: tx,
})
.await
.is_err()
{
return Err(TransferLeadershipError::Shutdown);
}
match rx.await {
Ok(r) => r,
Err(_) => Err(TransferLeadershipError::DriverDead),
}
}
async fn config_change(&self, change: ConfigChange) -> Result<(), ProposeError> {
let (tx, rx) = oneshot::channel();
if self
.inputs
.send(DriverInput::ConfigChange { change, reply: tx })
.await
.is_err()
{
return Err(ProposeError::Shutdown);
}
match rx.await {
Ok(r) => r,
Err(_) => Err(ProposeError::DriverDead),
}
}
pub async fn status(&self) -> Result<NodeStatus, ProposeError> {
let (tx, rx) = oneshot::channel();
if self
.inputs
.send(DriverInput::Status { reply: tx })
.await
.is_err()
{
return Err(ProposeError::Shutdown);
}
match rx.await {
Ok(s) => Ok(s),
Err(_) => Err(ProposeError::DriverDead),
}
}
pub async fn metrics(&self) -> Result<yggr_core::engine::metrics::EngineMetrics, ProposeError> {
let (tx, rx) = oneshot::channel();
if self
.inputs
.send(DriverInput::Metrics { reply: tx })
.await
.is_err()
{
return Err(ProposeError::Shutdown);
}
match rx.await {
Ok(m) => Ok(m),
Err(_) => Err(ProposeError::DriverDead),
}
}
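    /// Gracefully shuts the node down: flushes any batched proposals, shuts
    /// down the transport, then waits for the ticker, driver, and apply
    /// tasks to finish. Remaining clones of this handle observe the shutdown
    /// as a closed channel.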
pub async fn shutdown(self) -> Result<(), ProposeError> {
let (tx, rx) = oneshot::channel();
let shutdown_requested = self
.inputs
.send(DriverInput::Shutdown { reply: tx })
.await
.is_ok();
let background = self
.background
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner)
.take();
if shutdown_requested {
match rx.await {
Ok(()) => {}
Err(_) => return Err(ProposeError::DriverDead),
}
}
let Some(background) = background else {
return Ok(());
};
let BackgroundTasks {
ticker,
driver,
apply,
} = background;
ticker.abort();
match ticker.await {
Ok(()) => {}
Err(e) if e.is_cancelled() => {}
Err(_) => return Err(ProposeError::DriverDead),
}
match driver.await {
Ok(()) => {}
Err(_) => return Err(ProposeError::DriverDead),
}
match apply.await {
Ok(()) => Ok(()),
Err(_) => Err(ProposeError::DriverDead),
}
}
}
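/// Single-threaded core of the node: owns the engine, storage, and
/// transport, and serializes every input (ticks, client requests, incoming
/// messages) through one event loop so the engine needs no internal locking.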
struct Driver<S: StateMachine, St, Tr> {
node_id: NodeId,
engine: Engine<Vec<u8>>,
apply_tx: mpsc::Sender<ApplyRequest<S>>,
storage: St,
transport: Tr,
inputs: mpsc::Receiver<DriverInput<S>>,
inputs_tx: mpsc::Sender<DriverInput<S>>,
pending_proposals: HashMap<LogIndex, oneshot::Sender<Result<S::Response, ProposeError>>>,
pending_config_changes: HashMap<LogIndex, oneshot::Sender<Result<(), ProposeError>>>,
max_pending_proposals: usize,
pending_reads: HashMap<u64, PendingRead<S>>,
next_read_id: u64,
batch_buffer: Vec<BatchEntry<S>>,
batch_delay_remaining: u64,
max_batch_delay_ticks: u64,
max_batch_entries: usize,
last_applied: LogIndex,
snapshot_in_flight: bool,
}
struct PendingRead<S: StateMachine> {
reader: Box<dyn FnOnce(&S) + Send>,
on_failure: Box<dyn FnOnce(ReadError) + Send>,
}
type BatchEntry<S> = (
Vec<u8>,
oneshot::Sender<Result<<S as StateMachine>::Response, ProposeError>>,
);
enum ApplyRequest<S: StateMachine> {
Command {
command: S::Command,
reply: Option<oneshot::Sender<Result<S::Response, ProposeError>>>,
},
Read { reader: Box<dyn FnOnce(&S) + Send> },
Restore { bytes: Vec<u8> },
TakeSnapshot {
reply: oneshot::Sender<Result<Vec<u8>, crate::state_machine::SnapshotError>>,
},
}
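/// Sole owner of the state machine. Requests arrive over a FIFO channel and
/// are processed strictly in order, so a read enqueued after a command is
/// guaranteed to observe that command's effects.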
async fn apply_loop<S: StateMachine>(
mut state_machine: S,
mut rx: mpsc::Receiver<ApplyRequest<S>>,
) {
while let Some(req) = rx.recv().await {
match req {
ApplyRequest::Command { command, reply } => {
let response = state_machine.apply(command);
if let Some(reply) = reply {
let _ = reply.send(Ok(response));
}
}
ApplyRequest::Read { reader } => {
reader(&state_machine);
}
ApplyRequest::Restore { bytes } => {
state_machine.restore(bytes);
}
ApplyRequest::TakeSnapshot { reply } => {
let result = state_machine.snapshot();
let _ = reply.send(result);
}
}
}
}
enum DriverInput<S: StateMachine> {
Tick,
Propose {
command: S::Command,
reply: oneshot::Sender<Result<S::Response, ProposeError>>,
},
ConfigChange {
change: ConfigChange,
reply: oneshot::Sender<Result<(), ProposeError>>,
},
TransferLeadership {
target: NodeId,
reply: oneshot::Sender<Result<(), TransferLeadershipError>>,
},
Read {
reader: Box<dyn FnOnce(&S) + Send>,
on_failure: Box<dyn FnOnce(ReadError) + Send>,
},
Status {
reply: oneshot::Sender<NodeStatus>,
},
Metrics {
reply: oneshot::Sender<yggr_core::engine::metrics::EngineMetrics>,
},
Shutdown {
reply: oneshot::Sender<()>,
},
SnapshotReady {
last_included_index: LogIndex,
bytes: Option<Vec<u8>>,
},
}
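/// Main event loop: multiplexes driver inputs and incoming transport
/// messages, feeds each into the engine, and dispatches the resulting
/// actions. A fatal error (failed persistence, dead apply task) fails all
/// pending requests and terminates the loop.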
async fn driver_loop<S, St, Tr>(mut d: Driver<S, St, Tr>, recovered: RecoveredState<Vec<u8>>)
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
hydrate_engine(&mut d, recovered).await;
loop {
let result: Result<(), Fatal> = tokio::select! {
input = d.inputs.recv() => {
let Some(input) = input else { return };
match input {
DriverInput::Tick => handle_tick(&mut d).await,
DriverInput::Propose { command, reply } => {
handle_propose(&mut d, command, reply).await
}
DriverInput::ConfigChange { change, reply } => {
handle_config_change(&mut d, change, reply).await
}
DriverInput::Read { reader, on_failure } => {
handle_read(&mut d, reader, on_failure).await
}
DriverInput::TransferLeadership { target, reply } => {
handle_transfer_leadership(&mut d, target, reply).await
}
DriverInput::Status { reply } => {
let _ = reply.send(build_status(&d));
Ok(())
}
DriverInput::Metrics { reply } => {
let _ = reply.send(d.engine.metrics());
Ok(())
}
DriverInput::SnapshotReady { last_included_index, bytes } => {
handle_snapshot_ready(&mut d, last_included_index, bytes).await
}
DriverInput::Shutdown { reply } => {
debug!(target = "yggr::node", "shutdown requested");
let _ = flush_batch(&mut d).await;
d.transport.shutdown().await;
let _ = reply.send(());
return;
}
}
}
incoming = d.transport.recv() => {
let Some(incoming) = incoming else {
warn!(target = "yggr::node", "transport recv returned None; shutting down");
return;
};
step_and_dispatch(&mut d, Event::Incoming(incoming)).await
}
};
if let Err(Fatal { reason }) = result {
fail_all_pending(&mut d, reason);
return;
}
}
}
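/// Replays recovered state into the fresh engine: hard state, the latest
/// snapshot (whose bytes are also forwarded to the state machine), and the
/// post-snapshot log suffix. A completely empty recovery is a no-op.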
async fn hydrate_engine<S, St, Tr>(d: &mut Driver<S, St, Tr>, recovered: RecoveredState<Vec<u8>>)
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
use yggr_core::{RecoveredHardState, RecoveredSnapshot};
let RecoveredState {
hard_state,
snapshot,
log,
} = recovered;
if hard_state.is_none() && snapshot.is_none() && log.is_empty() {
return;
}
let snapshot_bytes_for_sm = snapshot.as_ref().map(|s| s.bytes.clone());
let (current_term, voted_for) =
hard_state.map_or((Term::ZERO, None), |hs| (hs.current_term, hs.voted_for));
let recovered = RecoveredHardState {
current_term,
voted_for,
snapshot: snapshot.map(|s| RecoveredSnapshot {
last_included_index: s.last_included_index,
last_included_term: s.last_included_term,
peers: s.peers,
bytes: s.bytes,
}),
post_snapshot_log: log,
};
d.engine.recover_from(recovered);
    if let Some(bytes) = snapshot_bytes_for_sm {
        // If the apply task is already gone, the driver loop will surface the
        // failure on its first apply attempt, so the send error is ignored.
        let _ = d.apply_tx.send(ApplyRequest::Restore { bytes }).await;
    }
d.last_applied = d.engine.commit_index();
}
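/// Proposes a single client command. Whether the engine actually appended an
/// entry is detected by comparing the last log index before and after the
/// step; on success the reply is parked in `pending_proposals` keyed by that
/// index, otherwise the caller is redirected or told there is no leader.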
async fn handle_propose<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
command: S::Command,
reply: oneshot::Sender<Result<S::Response, ProposeError>>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
if d.pending_proposals.len() + d.pending_config_changes.len() >= d.max_pending_proposals {
let _ = reply.send(Err(ProposeError::Busy));
return Ok(());
}
let bytes = S::encode_command(&command);
if d.max_batch_delay_ticks > 0 && matches!(d.engine.role(), RoleState::Leader(_)) {
d.batch_buffer.push((bytes, reply));
if d.batch_delay_remaining == 0 {
d.batch_delay_remaining = d.max_batch_delay_ticks;
}
if d.batch_buffer.len() >= d.max_batch_entries {
return flush_batch(d).await;
}
return Ok(());
}
let last_before = d
.engine
.log()
.last_log_id()
.map_or(LogIndex::ZERO, |l| l.index);
let actions = d.engine.step(Event::ClientProposal(bytes));
let last_after = d
.engine
.log()
.last_log_id()
.map_or(LogIndex::ZERO, |l| l.index);
if last_after > last_before {
d.pending_proposals.insert(last_after, reply);
} else {
let leader_hint = actions.iter().find_map(|a| match a {
Action::Redirect { leader_hint } => Some(*leader_hint),
_ => None,
});
let err = leader_hint.map_or(ProposeError::NoLeader, |h| ProposeError::NotLeader {
leader_hint: h,
});
let _ = reply.send(Err(err));
}
dispatch_actions(d, actions).await
}
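/// Flushes the buffered proposal batch as one engine step. The engine is
/// expected to append either all entries or none: on a full append each
/// reply is keyed to its contiguous log index, otherwise every buffered
/// proposal fails with a redirect (or `NoLeader`).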
async fn flush_batch<S, St, Tr>(d: &mut Driver<S, St, Tr>) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
if d.batch_buffer.is_empty() {
d.batch_delay_remaining = 0;
return Ok(());
}
let drained: Vec<BatchEntry<S>> = std::mem::take(&mut d.batch_buffer);
d.batch_delay_remaining = 0;
let (commands, replies): (Vec<Vec<u8>>, Vec<_>) = drained.into_iter().unzip();
let last_before = d
.engine
.log()
.last_log_id()
.map_or(LogIndex::ZERO, |l| l.index);
let actions = d.engine.step(Event::ClientProposalBatch(commands));
let last_after = d
.engine
.log()
.last_log_id()
.map_or(LogIndex::ZERO, |l| l.index);
let appended = last_after.get().saturating_sub(last_before.get());
if appended == u64::try_from(replies.len()).unwrap_or(0) && appended > 0 {
for (offset, reply) in replies.into_iter().enumerate() {
let idx = LogIndex::new(
last_before
.get()
.saturating_add(1)
.saturating_add(u64::try_from(offset).unwrap_or(0)),
);
d.pending_proposals.insert(idx, reply);
}
} else {
let leader_hint = actions.iter().find_map(|a| match a {
Action::Redirect { leader_hint } => Some(*leader_hint),
_ => None,
});
for reply in replies {
let err = leader_hint.map_or(ProposeError::NoLeader, |h| ProposeError::NotLeader {
leader_hint: h,
});
let _ = reply.send(Err(err));
}
}
dispatch_actions(d, actions).await
}
async fn handle_config_change<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
change: ConfigChange,
reply: oneshot::Sender<Result<(), ProposeError>>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
if d.pending_proposals.len() + d.pending_config_changes.len() >= d.max_pending_proposals {
let _ = reply.send(Err(ProposeError::Busy));
return Ok(());
}
let last_before = d
.engine
.log()
.last_log_id()
.map_or(LogIndex::ZERO, |l| l.index);
let actions = d.engine.step(Event::ProposeConfigChange(change));
let last_after = d
.engine
.log()
.last_log_id()
.map_or(LogIndex::ZERO, |l| l.index);
if last_after > last_before {
d.pending_config_changes.insert(last_after, reply);
} else {
let leader_hint = actions.iter().find_map(|a| match a {
Action::Redirect { leader_hint } => Some(*leader_hint),
_ => None,
});
let err = leader_hint.map_or(ProposeError::NoLeader, |h| ProposeError::NotLeader {
leader_hint: h,
});
let _ = reply.send(Err(err));
}
dispatch_actions(d, actions).await
}
async fn handle_transfer_leadership<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
target: NodeId,
reply: oneshot::Sender<Result<(), TransferLeadershipError>>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
match d.engine.role() {
RoleState::Leader(_) => {}
RoleState::Follower(f) => {
let err = f
.leader_id()
.map_or(TransferLeadershipError::NoLeader, |leader_hint| {
TransferLeadershipError::NotLeader { leader_hint }
});
let _ = reply.send(Err(err));
return Ok(());
}
RoleState::PreCandidate(_) | RoleState::Candidate(_) => {
let _ = reply.send(Err(TransferLeadershipError::NoLeader));
return Ok(());
}
}
if target == d.node_id || !d.engine.peers().contains(&target) {
let _ = reply.send(Err(TransferLeadershipError::InvalidTarget { target }));
return Ok(());
}
let actions = d.engine.step(Event::TransferLeadership { target });
match dispatch_actions(d, actions).await {
Ok(()) => {
let _ = reply.send(Ok(()));
Ok(())
}
Err(fatal) => {
let _ = reply.send(Err(TransferLeadershipError::Fatal {
reason: fatal.reason,
}));
Err(fatal)
}
}
}
async fn handle_snapshot_ready<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
last_included_index: LogIndex,
bytes: Option<Vec<u8>>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
    d.snapshot_in_flight = false;
    let Some(bytes) = bytes else {
        // `None` covers both a declined snapshot (already logged by the
        // forwarder) and an apply task torn down mid-request; either way the
        // hint is dropped and the engine will re-hint later.
        warn!(
            target = "yggr::node",
            "snapshot hint dropped: no snapshot bytes were produced"
        );
        return Ok(());
    };
    // An empty snapshot is treated as a refusal and ignored.
    if bytes.is_empty() {
        return Ok(());
    }
let actions = d.engine.step(Event::SnapshotTaken {
last_included_index,
bytes,
});
dispatch_actions(d, actions).await
}
async fn handle_read<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
reader: Box<dyn FnOnce(&S) + Send>,
on_failure: Box<dyn FnOnce(ReadError) + Send>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
let id = d.next_read_id;
d.next_read_id = d.next_read_id.wrapping_add(1);
d.pending_reads
.insert(id, PendingRead { reader, on_failure });
let actions = d.engine.step(Event::ProposeRead { id });
match dispatch_actions(d, actions).await {
Ok(()) => Ok(()),
Err(fatal) => {
if let Some(pending) = d.pending_reads.remove(&id) {
(pending.on_failure)(ReadError::Fatal {
reason: fatal.reason,
});
}
Err(fatal)
}
}
}
#[derive(Debug, Clone, Copy)]
struct Fatal {
reason: &'static str,
}
async fn step_and_dispatch<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
event: Event<Vec<u8>>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
let actions = d.engine.step(event);
dispatch_actions(d, actions).await
}
async fn handle_tick<S, St, Tr>(d: &mut Driver<S, St, Tr>) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
if d.batch_delay_remaining > 0 {
d.batch_delay_remaining -= 1;
if d.batch_delay_remaining == 0 && !d.batch_buffer.is_empty() {
flush_batch(d).await?;
}
}
step_and_dispatch(d, Event::Tick).await
}
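/// Drains the engine's action queue in order. Persistence failures and a
/// dead apply task are fatal, since the log or state machine can no longer
/// be trusted; a transport send failure is merely logged and the message is
/// dropped.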
#[allow(clippy::too_many_lines)]
async fn dispatch_actions<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
actions: Vec<Action<Vec<u8>>>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
let mut pending: VecDeque<Action<Vec<u8>>> = actions.into();
while let Some(action) = pending.pop_front() {
match action {
Action::PersistHardState {
current_term,
voted_for,
} => {
if let Err(e) = d
.storage
.persist_hard_state(StoredHardState {
current_term,
voted_for,
})
.await
{
error!(target = "yggr::node", error = %e, "fatal: persist_hard_state failed");
return Err(Fatal {
reason: "persist_hard_state failed",
});
}
}
Action::PersistLogEntries(entries) => {
if let Err(e) = d.storage.append_log(entries).await {
error!(target = "yggr::node", error = %e, "fatal: append_log failed");
return Err(Fatal {
reason: "append_log failed",
});
}
}
Action::PersistSnapshot {
last_included_index,
last_included_term,
peers,
bytes,
} => {
if let Err(e) = d
.storage
.persist_snapshot(StoredSnapshot {
last_included_index,
last_included_term,
peers,
bytes,
})
.await
{
error!(target = "yggr::node", error = %e, "fatal: persist_snapshot failed");
return Err(Fatal {
reason: "persist_snapshot failed",
});
}
}
Action::Send { to, message } => {
if let Err(e) = d.transport.send(to, message).await {
debug!(target = "yggr::node", peer = %to, error = %e, "send failed");
}
}
Action::Apply(entries) => {
apply_entries(d, entries).await?;
}
Action::ApplySnapshot { bytes } => {
if d.apply_tx
.send(ApplyRequest::Restore { bytes })
.await
.is_err()
{
return Err(Fatal {
reason: "apply task died",
});
}
}
            Action::SnapshotHint {
                last_included_index,
            } => {
                // At most one snapshot request may be outstanding; further
                // hints are ignored until `SnapshotReady` clears the flag.
                if d.snapshot_in_flight {
                    continue;
                }
let (reply_tx, reply_rx) = oneshot::channel();
if d.apply_tx
.send(ApplyRequest::TakeSnapshot { reply: reply_tx })
.await
.is_err()
{
return Err(Fatal {
reason: "apply task died",
});
}
                d.snapshot_in_flight = true;
                // Snapshotting may be slow, so a detached forwarder waits for
                // the result and feeds it back through the input channel
                // instead of blocking the driver loop on the state machine.
                let inputs_tx = d.inputs_tx.clone();
                tokio::spawn(async move {
let bytes = match reply_rx.await {
Ok(Ok(bytes)) => Some(bytes),
Ok(Err(err)) => {
                        warn!(
                            target = "yggr::node",
                            error = %err,
                            "state machine declined to produce a snapshot; \
                             dropping this hint, engine will re-hint later"
                        );
None
}
Err(_) => None,
};
let _ = inputs_tx
.send(DriverInput::SnapshotReady {
last_included_index,
bytes,
})
.await;
});
}
            // Redirect hints are consumed inline by the proposal paths before
            // dispatch; by the time one reaches here there is nothing to do.
            Action::Redirect { .. } => {}
Action::ReadReady { id } => {
if let Some(pending) = d.pending_reads.remove(&id)
&& d.apply_tx
.send(ApplyRequest::Read {
reader: pending.reader,
})
.await
.is_err()
{
return Err(Fatal {
reason: "apply task died",
});
}
}
Action::ReadFailed { id, reason } => {
if let Some(pending) = d.pending_reads.remove(&id) {
let err = match reason {
yggr_core::ReadFailure::NotLeader { leader_hint } => {
ReadError::NotLeader { leader_hint }
}
yggr_core::ReadFailure::SteppedDown => ReadError::SteppedDown,
_ => ReadError::NotReady,
};
(pending.on_failure)(err);
}
}
}
}
    // The apply channel is FIFO and the apply loop is strictly ordered, so
    // advancing `last_applied` to the commit index once entries have been
    // handed off is sound for status reporting.
    if d.engine.commit_index() > d.last_applied {
        d.last_applied = d.engine.commit_index();
    }
Ok(())
}
fn build_status<S, St, Tr>(d: &Driver<S, St, Tr>) -> NodeStatus
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
let (role, leader_hint) = match d.engine.role() {
RoleState::Follower(f) => (Role::Follower, f.leader_id()),
RoleState::PreCandidate(_) => (Role::PreCandidate, None),
RoleState::Candidate(_) => (Role::Candidate, None),
RoleState::Leader(_) => (Role::Leader, None),
};
NodeStatus {
node_id: d.node_id,
role,
current_term: d.engine.current_term(),
commit_index: d.engine.commit_index(),
last_applied: d.last_applied,
leader_hint,
peers: d.engine.peers().iter().copied().collect(),
}
}
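/// Ships committed command entries to the apply task (pairing each with any
/// pending proposal reply), acknowledges config changes whose index is now
/// committed, and fails all parked proposals with a redirect once this node
/// is no longer leader.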
async fn apply_entries<S, St, Tr>(
d: &mut Driver<S, St, Tr>,
entries: Vec<LogEntry<Vec<u8>>>,
) -> Result<(), Fatal>
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
for entry in entries {
if let LogPayload::Command(bytes) = entry.payload {
let cmd = match S::decode_command(&bytes) {
Ok(c) => c,
Err(e) => {
error!(target = "yggr::node", error = %e, index = ?entry.id.index,
"fatal: failed to decode committed command");
return Err(Fatal {
reason: "failed to decode committed command",
});
}
};
let reply = d.pending_proposals.remove(&entry.id.index);
if d.apply_tx
.send(ApplyRequest::Command {
command: cmd,
reply,
})
.await
.is_err()
{
return Err(Fatal {
reason: "apply task died",
});
}
}
}
let commit = d.engine.commit_index();
let to_fire: Vec<LogIndex> = d
.pending_config_changes
.keys()
.copied()
.filter(|&idx| idx <= commit)
.collect();
for idx in to_fire {
if let Some(reply) = d.pending_config_changes.remove(&idx) {
let _ = reply.send(Ok(()));
}
}
if !matches!(d.engine.role(), RoleState::Leader(_)) {
let leader_hint = match d.engine.role() {
RoleState::Follower(f) => f.leader_id(),
_ => None,
};
let stale: Vec<LogIndex> = d.pending_proposals.keys().copied().collect();
for idx in stale {
if let Some(reply) = d.pending_proposals.remove(&idx) {
let err = leader_hint.map_or(ProposeError::NoLeader, |h| ProposeError::NotLeader {
leader_hint: h,
});
let _ = reply.send(Err(err));
}
}
let stale_cc: Vec<LogIndex> = d.pending_config_changes.keys().copied().collect();
for idx in stale_cc {
if let Some(reply) = d.pending_config_changes.remove(&idx) {
let err = leader_hint.map_or(ProposeError::NoLeader, |h| ProposeError::NotLeader {
leader_hint: h,
});
let _ = reply.send(Err(err));
}
}
}
Ok(())
}
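/// Fails every in-flight proposal, config change, and read with a fatal
/// error. Called once, just before the driver loop exits on an
/// unrecoverable condition.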
fn fail_all_pending<S, St, Tr>(d: &mut Driver<S, St, Tr>, reason: &'static str)
where
S: StateMachine,
St: Storage<Vec<u8>>,
Tr: Transport<Vec<u8>>,
{
let indices: Vec<LogIndex> = d.pending_proposals.keys().copied().collect();
for idx in indices {
if let Some(reply) = d.pending_proposals.remove(&idx) {
let _ = reply.send(Err(ProposeError::Fatal { reason }));
}
}
let indices: Vec<LogIndex> = d.pending_config_changes.keys().copied().collect();
for idx in indices {
if let Some(reply) = d.pending_config_changes.remove(&idx) {
let _ = reply.send(Err(ProposeError::Fatal { reason }));
}
}
let read_ids: Vec<u64> = d.pending_reads.keys().copied().collect();
for id in read_ids {
if let Some(pending) = d.pending_reads.remove(&id) {
(pending.on_failure)(ReadError::Fatal { reason });
}
}
}