use std::io::SeekFrom;
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use crate::{AppData, AppDataResponse, RaftNetwork, RaftStorage};
use crate::core::{State, RaftCore, SnapshotState, UpdateCurrentLeader};
use crate::error::RaftResult;
use crate::raft::{InstallSnapshotRequest, InstallSnapshotResponse};
impl<D: AppData, R: AppDataResponse, N: RaftNetwork<D>, S: RaftStorage<D, R>> RaftCore<D, R, N, S> {
#[tracing::instrument(level="trace", skip(self, req))]
pub(super) async fn handle_install_snapshot_request(&mut self, req: InstallSnapshotRequest) -> RaftResult<InstallSnapshotResponse> {
if &req.term < &self.current_term {
return Ok(InstallSnapshotResponse{term: self.current_term});
}
self.update_next_election_timeout();
let mut report_metrics = false;
if &self.current_term != &req.term {
self.update_current_term(req.term, None);
self.save_hard_state().await?;
report_metrics = true;
}
if self.current_leader.as_ref() != Some(&req.leader_id) {
self.update_current_leader(UpdateCurrentLeader::OtherNode(req.leader_id));
report_metrics = true;
}
if !self.target_state.is_follower() && !self.target_state.is_non_voter() {
self.set_target_state(State::Follower); }
if report_metrics {
self.report_metrics();
}
match self.snapshot_state.take() {
None => Ok(self.begin_installing_snapshot(req).await?),
Some(SnapshotState::Snapshotting{handle, ..}) => {
handle.abort(); Ok(self.begin_installing_snapshot(req).await?)
}
Some(SnapshotState::Streaming{snapshot, id, offset}) => Ok(self.continue_installing_snapshot(req, offset, id, snapshot).await?),
}
}
#[tracing::instrument(level="trace", skip(self, req))]
async fn begin_installing_snapshot(&mut self, req: InstallSnapshotRequest) -> RaftResult<InstallSnapshotResponse> {
let (id, mut snapshot) = self.storage.create_snapshot().await
.map_err(|err| self.map_fatal_storage_error(err))?;
snapshot.as_mut().write_all(&req.data).await?;
if req.done {
self.finalize_snapshot_installation(req, id, snapshot).await?;
return Ok(InstallSnapshotResponse{term: self.current_term});
}
self.snapshot_state = Some(SnapshotState::Streaming{
offset: req.data.len() as u64,
id, snapshot,
});
return Ok(InstallSnapshotResponse{term: self.current_term});
}
#[tracing::instrument(level="trace", skip(self, req, offset, snapshot))]
async fn continue_installing_snapshot(
&mut self, req: InstallSnapshotRequest, mut offset: u64, id: String, mut snapshot: Box<S::Snapshot>,
) -> RaftResult<InstallSnapshotResponse> {
if &req.offset != &offset {
if let Err(err) = snapshot.as_mut().seek(SeekFrom::Start(req.offset)).await {
self.snapshot_state = Some(SnapshotState::Streaming{offset, id, snapshot});
return Err(err.into());
}
offset = req.offset;
}
if let Err(err) = snapshot.as_mut().write_all(&req.data).await {
self.snapshot_state = Some(SnapshotState::Streaming{offset, id, snapshot});
return Err(err.into());
}
offset += req.data.len() as u64;
if req.done {
self.finalize_snapshot_installation(req, id, snapshot).await?;
} else {
self.snapshot_state = Some(SnapshotState::Streaming{offset, id, snapshot});
}
return Ok(InstallSnapshotResponse{term: self.current_term});
}
#[tracing::instrument(level="trace", skip(self, req, snapshot))]
async fn finalize_snapshot_installation(&mut self, req: InstallSnapshotRequest, id: String, mut snapshot: Box<S::Snapshot>) -> RaftResult<()> {
snapshot.as_mut().shutdown().await.map_err(|err| self.map_fatal_storage_error(err.into()))?;
let delete_through = if &self.last_log_index > &req.last_included_index {
Some(req.last_included_index)
} else {
None
};
self.storage.finalize_snapshot_installation(req.last_included_index, req.last_included_term, delete_through, id, snapshot).await
.map_err(|err| self.map_fatal_storage_error(err))?;
let membership = self.storage.get_membership_config().await.map_err(|err| self.map_fatal_storage_error(err))?;
self.update_membership(membership)?;
self.last_log_index = req.last_included_index;
self.last_log_term = req.last_included_term;
self.last_applied = req.last_included_index;
self.snapshot_index = req.last_included_index;
Ok(())
}
}