use std::time::Duration;
use display_more::DisplayOptionExt;
use futures_util::FutureExt;
use crate::LogIdOptionExt;
use crate::RaftLogReader;
use crate::RaftTypeConfig;
use crate::StorageError;
use crate::async_runtime::MpscSender;
use crate::async_runtime::watch::WatchReceiver;
use crate::core::notification::Notification;
use crate::entry::RaftEntry;
use crate::entry::raft_entry_ext::RaftEntryExt;
use crate::errors::StorageIOResult;
use crate::log_id_range::LogIdRange;
use crate::progress::inflight_id::InflightId;
use crate::raft::AppendEntriesRequest;
use crate::raft_state::IOId;
use crate::replication::backoff_consumer::BackoffConsumer;
use crate::replication::event_watcher::EventWatcher;
use crate::replication::payload::Payload;
use crate::replication::replication_context::ReplicationContext;
use crate::storage::RaftLogStorage;
use crate::type_config::TypeConfigExt;
use crate::type_config::alias::EntryOf;
use crate::type_config::alias::LogIdOf;
use crate::vote::RaftVote;
/// Mutable state used to produce a sequence of `AppendEntries` requests for one
/// replication stream.
///
/// `next_request` is the entry point: each call yields the next request to send,
/// or `None` when the stream has nothing more to produce.
pub(crate) struct StreamState<C, LS>
where
C: RaftTypeConfig,
LS: RaftLogStorage<C>,
{
/// Per-stream context. The methods below read the replication target, the
/// leader vote, the config (`max_payload_entries`), the notification sender
/// (`tx_notify`), the cancel receiver (`cancel_rx`) and the `replicate_batch`
/// recorder from it.
pub(crate) replication_context: ReplicationContext<C>,
/// Watch receivers observed by this stream: accepted IO, submitted IO,
/// committed log id, and the replicate payload.
pub(crate) event_watcher: EventWatcher<C>,
/// Reader used to fetch the log entries that are put into a request.
pub(crate) log_reader: LS::LogReader,
/// What is left to replicate. When `None`, no log id range can be built and
/// `next_request` returns `None`. It is reset to `None` once its `len()`
/// reports `Some(0)`.
pub(crate) payload: Option<Payload<C>>,
/// Inflight id this stream was started for. Replicate data carrying a
/// different inflight id makes the stream quit.
pub(crate) inflight_id: Option<InflightId>,
/// The committed log id recorded the last time a log id range was produced;
/// a greater value on the committed watch triggers a new range.
pub(crate) leader_committed: Option<LogIdOf<C>>,
/// Yields the optional delay to wait before a built request is returned.
pub(crate) backoff_consumer: BackoffConsumer,
}
impl<C, LS> StreamState<C, LS>
where
C: RaftTypeConfig,
LS: RaftLogStorage<C>,
{
pub(crate) async fn next_request(&mut self) -> Option<AppendEntriesRequest<C>> {
let log_id_range = self.get_log_id_range().await?;
tracing::debug!("{}: log_id_range: {}", func_name!(), log_id_range);
let res = self.read_log_entries(log_id_range).await;
let (entries, sending_range) = match res {
Ok(x) => x,
Err(sto_err) => {
tracing::error!("{} replication to target={}", sto_err, self.replication_context.target);
self.replication_context.tx_notify.send(Notification::StorageError { error: sto_err }).await.ok();
return None;
}
};
let belonging_leader = self.replication_context.leader_vote.leader_id().clone();
let accepted_io: IOId<C> = self.event_watcher.io_accepted_rx.borrow_watched().clone();
let current_leader = accepted_io.leader_id().clone();
if current_leader != belonging_leader {
tracing::info!(
"Leader changed from {} to {}, quit replication",
belonging_leader,
current_leader
);
return None;
}
self.update_log_id_range(sending_range.last);
let payload = AppendEntriesRequest {
vote: self.replication_context.leader_vote.clone().into_vote(),
prev_log_id: sending_range.prev.clone(),
leader_commit: self.event_watcher.committed_rx.borrow_watched().clone(),
entries,
};
let entry_count = payload.entries.len() as u64;
self.replication_context.replicate_batch.record(entry_count);
tracing::debug!("next_request: AppendEntries: {}", payload);
self.backoff_if_enabled().await;
Some(payload)
}
/// Determines the log id range the next request should cover.
///
/// Returns `None` if there is no payload, if replicate data with a different
/// inflight id is observed, or if the cancel receiver fires.
///
/// For `Payload::LogIdRange` the stored range is returned as is. For
/// `Payload::LogsSince` the range runs from `prev` to the last submitted log id;
/// if there is nothing new to send, this waits for one of the watched values to
/// change.
async fn get_log_id_range(&mut self) -> Option<LogIdRange<C>> {
// No payload: nothing to replicate.
let payload = self.payload.as_ref()?;
tracing::debug!("pipeline stream payload: {}", payload);
let prev = match payload {
// A fixed range is returned immediately, without waiting for any change.
Payload::LogIdRange { log_id_range } => return Some(log_id_range.clone()),
Payload::LogsSince { prev } => prev.clone(),
};
loop {
// Snapshot the last submitted log id and the committed log id.
let current: IOId<C> = self.event_watcher.io_submitted_rx.borrow_watched().clone();
let last_log_id = current.last_log_id().cloned();
let committed: Option<LogIdOf<C>> = self.event_watcher.committed_rx.borrow_watched().clone();
tracing::debug!(
"building next entries range to replicate: current last_log_id: {}, current committed: {}",
last_log_id.display(),
committed.display()
);
// Produce a range when there are logs beyond `prev`, or when the committed
// log id is greater than the one recorded for the previous range.
if last_log_id > prev || committed > self.leader_committed {
self.leader_committed = committed;
return Some(LogIdRange::new(prev, last_log_id));
} else {
// Nothing new: wait for whichever of these fires first.
let data_change = self.event_watcher.replicate_rx.changed();
let io_change = self.event_watcher.io_submitted_rx.changed();
let committed_change = self.event_watcher.committed_rx.changed();
let cancel = self.replication_context.cancel_rx.changed();
futures_util::select! {
// The result of `changed()` is not inspected in this branch.
_data_changed = data_change.fuse() => {
let new_data = self.event_watcher.replicate_rx.borrow_watched().clone();
// Replicate data for another inflight id: stop this stream.
// Otherwise fall through and let the loop re-evaluate.
if Some(new_data.inflight_id) != self.inflight_id {
tracing::info!("current inflight_id: {} received payload with new inflight_id: {}, quit", self.inflight_id.display(), new_data.inflight_id);
return None;
}
}
// Submitted IO changed: fall through so the loop re-reads the snapshot.
_io_changed = io_change.fuse() => {
tracing::debug!("io_submitted_rx changed");
}
// NOTE(review): this branch returns right away using the `last_log_id`
// captured before the wait instead of looping to re-read it, and it does not
// check that the new committed value is greater than the previous one.
// Confirm this is intended.
_committed_change = committed_change.fuse() => {
tracing::debug!("committed_rx changed");
self.leader_committed = self.event_watcher.committed_rx.borrow_watched().clone();
return Some(LogIdRange::new(prev, last_log_id));
}
// Cancel receiver fired (with either result): stop producing ranges.
cancel_res = cancel.fuse() => {
tracing::info!("Replication Stream is canceled, res: {:?}, when:(get_log_id_range:wait-for-changed)", cancel_res);
return None;
}
}
}
}
}
/// Waits for the next backoff delay, if the backoff consumer yields one.
///
/// Returns immediately when `next_delay()` gives `None`. Otherwise waits until
/// the delay elapses or the cancel receiver fires, whichever happens first.
/// Cancellation only cuts the wait short; this function returns normally in
/// both cases.
async fn backoff_if_enabled(&mut self) {
    let sleep_duration = match self.backoff_consumer.next_delay() {
        Some(delay) => delay,
        None => return,
    };

    let timer = C::sleep(sleep_duration);
    let canceled = self.replication_context.cancel_rx.changed();
    tracing::debug!("backoff timeout: {:?}", sleep_duration);

    futures_util::select! {
        _ = timer.fuse() => {
            tracing::debug!("backoff timeout");
        }
        cancel_res = canceled.fuse() => {
            tracing::info!("Replication Stream is canceled, res: {:?}, when:(backoff_if_enabled:wait-for-changed)", cancel_res);
        }
    }
}
/// Forwards `matching` to the current payload via `update_matching` and drops
/// the payload once its `len()` reports `Some(0)`.
///
/// Does nothing when there is no payload.
fn update_log_id_range(&mut self, matching: Option<LogIdOf<C>>) {
    let exhausted = match self.payload.as_mut() {
        Some(pending) => {
            pending.update_matching(matching);
            pending.len() == Some(0)
        }
        None => false,
    };

    if exhausted {
        self.payload = None;
    }
}
/// Reads the log entries covered by `log_id_range`, limited to at most
/// `max_payload_entries` entries.
///
/// Returns the entries together with the range that was actually read:
/// `prev` is unchanged and `last` is the log id of the last returned entry.
/// When the requested range is empty, or the reader returns no entries, an
/// empty vector and the empty range `(prev, prev)` are returned, which the
/// caller sends as a heartbeat.
///
/// # Errors
///
/// Returns the `StorageError` produced when reading from the log reader fails.
async fn read_log_entries(
    &mut self,
    log_id_range: LogIdRange<C>,
) -> Result<(Vec<EntryOf<C>>, LogIdRange<C>), StorageError<C>> {
    tracing::debug!("read_log_entries: log_id_range: {}", log_id_range);

    let rng = &log_id_range;
    let start = rng.prev.next_index();
    let end = rng.last.next_index();

    if start == end {
        // Nothing to read: report an empty range anchored at `prev`.
        let r = LogIdRange::new(rng.prev.clone(), rng.prev.clone());
        return Ok((vec![], r));
    }

    let max_entries = self.replication_context.config.max_payload_entries;
    // Saturate instead of overflowing: a very large `max_payload_entries` would
    // otherwise panic in debug builds or wrap in release builds, giving `end < start`.
    let end = std::cmp::min(end, start.saturating_add(max_entries));

    let logs = self.log_reader.limited_get_log_entries(start, end).await.sto_read_logs()?;

    if logs.is_empty() {
        let sleep_duration = Duration::from_millis(10);
        tracing::warn!(
            "limited_get_log_entries({}, {}) returned empty; \
             this violates the API contract but is handled gracefully as a heartbeat. \
             Sleeping {:?} to avoid tight loop.",
            start,
            end,
            sleep_duration
        );
        C::sleep(sleep_duration).await;
        let r = LogIdRange::new(rng.prev.clone(), rng.prev.clone());
        return Ok((vec![], r));
    }

    let first = logs
        .first()
        .map(|ent| ent.ref_log_id())
        .expect("logs is non-empty: checked above");
    let last = logs
        .last()
        .map(|ent| ent.log_id())
        .expect("logs is non-empty: checked above");

    debug_assert!(
        logs.len() <= (end - start) as usize,
        "expect logs ⊆ [{}..{}) but got {} entries, first: {}, last: {}",
        start,
        end,
        logs.len(),
        first,
        last
    );

    // The reader may return fewer entries than requested; report what was actually read.
    let r = LogIdRange::new(rng.prev.clone(), Some(last));
    Ok((logs, r))
}
}