mod replication_session_id;
mod response;
use std::fmt;
use std::io::SeekFrom;
use std::sync::Arc;
use anyerror::AnyError;
use futures::future::FutureExt;
pub(crate) use replication_session_id::ReplicationSessionId;
pub(crate) use response::Response;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncSeekExt;
use tokio::select;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tokio::time::sleep;
use tokio::time::timeout;
use tokio::time::Duration;
use tokio::time::Instant;
use tracing_futures::Instrument;
use crate::config::Config;
use crate::core::notify::Notify;
use crate::display_ext::DisplayOption;
use crate::display_ext::DisplayOptionExt;
use crate::error::HigherVote;
use crate::error::RPCError;
use crate::error::ReplicationClosed;
use crate::error::ReplicationError;
use crate::error::Timeout;
use crate::log_id::LogIdOptionExt;
use crate::log_id_range::LogIdRange;
use crate::network::Backoff;
use crate::network::RPCOption;
use crate::network::RPCTypes;
use crate::network::RaftNetwork;
use crate::network::RaftNetworkFactory;
use crate::raft::AppendEntriesRequest;
use crate::raft::AppendEntriesResponse;
use crate::raft::InstallSnapshotRequest;
use crate::storage::RaftLogReader;
use crate::storage::RaftLogStorage;
use crate::storage::Snapshot;
use crate::utime::UTime;
use crate::ErrorSubject;
use crate::ErrorVerb;
use crate::LogId;
use crate::MessageSummary;
use crate::NodeId;
use crate::RaftTypeConfig;
use crate::StorageError;
use crate::StorageIOError;
use crate::ToStorageResult;
/// Owner-side handle of a spawned replication task.
///
/// Holds the task's join handle together with the sender used to feed it `Replicate` events.
pub(crate) struct ReplicationHandle<C: RaftTypeConfig> {
    /// Completion handle of the spawned replication event loop.
    pub(crate) join_handle: JoinHandle<Result<(), ReplicationClosed>>,
    /// Sending half of the channel the replication task reads its events from.
    pub(crate) tx_repl: mpsc::UnboundedSender<Replicate<C>>,
}
/// State of one replication stream from this node to a single `target` node.
pub(crate) struct ReplicationCore<
    C: RaftTypeConfig,
    N: RaftNetworkFactory<C>,
    LS: RaftLogStorage<C>,
> {
    /// Id of the node being replicated to.
    target: C::NodeId,
    /// Identifies this replication session; its `vote` is sent with every request.
    session_id: ReplicationSessionId<C::NodeId>,
    /// Channel used to report progress and errors back to the core.
    #[allow(clippy::type_complexity)]
    tx_raft_core: mpsc::UnboundedSender<Notify<C>>,
    /// Receives the events that drive this replication stream.
    rx_repl: mpsc::UnboundedReceiver<Replicate<C>>,
    /// Network connection to `target`.
    network: N::Network,
    /// Set after an `Unreachable` RPC error; cleared after a successful action.
    backoff: Option<Backoff>,
    /// Reader used to load log entries to send.
    log_reader: LS::LogReader,
    /// Shared configuration (timeouts, chunk sizes, ...).
    config: Arc<Config>,
    /// Value sent to the target as `leader_commit`.
    committed: Option<LogId<C::NodeId>>,
    /// Last log id reported as matching on the target.
    matching: Option<LogId<C::NodeId>>,
    /// Action to execute in the next loop iteration, if any.
    next_action: Option<Data<C>>,
}
impl<C, N, LS> ReplicationCore<C, N, LS>
where
    C: RaftTypeConfig,
    N: RaftNetworkFactory<C>,
    LS: RaftLogStorage<C>,
{
    /// Creates the replication state for `target` and spawns its event loop with `tokio::spawn`.
    ///
    /// The spawned future is `main`, instrumented with `span`. The returned handle holds the
    /// task's `JoinHandle` and the sending half of the channel the task reads `Replicate`
    /// events from.
    ///
    /// * `committed` - initial value sent to the target as `leader_commit`.
    /// * `matching` - initial log id assumed to match on the target; it is the range used for
    ///   empty append-entries requests when no data action is pending.
    /// * `tx_raft_core` - channel that receives progress reports and errors.
    #[tracing::instrument(level = "trace", skip_all,fields(target=display(target), session_id=display(session_id)))]
    #[allow(clippy::type_complexity)]
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn spawn(
        target: C::NodeId,
        session_id: ReplicationSessionId<C::NodeId>,
        config: Arc<Config>,
        committed: Option<LogId<C::NodeId>>,
        matching: Option<LogId<C::NodeId>>,
        network: N::Network,
        log_reader: LS::LogReader,
        tx_raft_core: mpsc::UnboundedSender<Notify<C>>,
        span: tracing::Span,
    ) -> ReplicationHandle<C> {
        tracing::debug!(
            session_id = display(&session_id),
            target = display(&target),
            committed = display(committed.summary()),
            matching = debug(&matching),
            "spawn replication"
        );
        let (tx_repl, rx_repl) = mpsc::unbounded_channel();
        let this = Self {
            target,
            session_id,
            network,
            backoff: None,
            log_reader,
            config,
            committed,
            matching,
            tx_raft_core,
            rx_repl,
            next_action: None,
        };
        let join_handle = tokio::spawn(this.main().instrument(span));
        ReplicationHandle { join_handle, tx_repl }
    }

    /// Event loop of the replication task.
    ///
    /// Each iteration performs the following steps:
    /// 1. Execute the pending `next_action`, if any.
    /// 2. Report a failure to `tx_raft_core`.
    /// 3. Wait out the backoff period, if one is active.
    /// 4. Wait until there is a new action to run.
    ///
    /// Returns `Err(ReplicationClosed)` when the event channel is closed or an action fails with
    /// `ReplicationError::Closed`. Returns `Ok(())` after reporting a higher vote or a storage
    /// error, because neither is retried by this task.
    #[tracing::instrument(level="debug", skip(self), fields(session=%self.session_id, target=display(self.target), cluster=%self.config.cluster_name))]
    async fn main(mut self) -> Result<(), ReplicationClosed> {
        loop {
            let action = self.next_action.take();
            let mut repl_id = None;
            let res = match action {
                None => Ok(None),
                Some(d) => {
                    tracing::debug!(replication_data = display(&d), "{} send replication RPC", func_name!());
                    repl_id = d.request_id;
                    match d.payload {
                        Payload::Logs(log_id_range) => self.send_log_entries(d.request_id, log_id_range).await,
                        Payload::Snapshot(snapshot_rx) => self.stream_snapshot(d.request_id, snapshot_rx).await,
                    }
                }
            };
            tracing::debug!(res = debug(&res), "replication action done");
            match res {
                Ok(next) => {
                    self.backoff = None;
                    // A follow-up action (e.g. the rest of a partially accepted range) runs next.
                    if let Some(next) = next {
                        self.next_action = Some(next);
                    }
                }
                Err(err) => {
                    tracing::warn!(error=%err, "error replication to target={}", self.target);
                    match err {
                        ReplicationError::Closed(closed) => {
                            return Err(closed);
                        }
                        ReplicationError::HigherVote(h) => {
                            let _ = self.tx_raft_core.send(Notify::Network {
                                response: Response::HigherVote {
                                    target: self.target,
                                    higher: h.higher,
                                    vote: self.session_id.vote,
                                },
                            });
                            return Ok(());
                        }
                        ReplicationError::StorageError(error) => {
                            tracing::error!(error=%error, "error replication to target={}", self.target);
                            let _ = self.tx_raft_core.send(Notify::Network {
                                response: Response::StorageError { error },
                            });
                            return Ok(());
                        }
                        ReplicationError::RPCError(err) => {
                            // The failed action is not retried here: `next_action` was taken
                            // above and stays empty, so the loop waits for the next event.
                            tracing::error!(err = display(&err), "RPCError");
                            if let Some(request_id) = repl_id {
                                let _ = self.tx_raft_core.send(Notify::Network {
                                    response: Response::Progress {
                                        target: self.target,
                                        request_id,
                                        result: Err(err.to_string()),
                                        session_id: self.session_id,
                                    },
                                });
                            } else {
                                tracing::warn!(
                                    err = display(&err),
                                    "encountered RPCError but request_id is None, no response is sent"
                                );
                            }
                            // Only an unreachable target starts a backoff period.
                            if let RPCError::Unreachable(_unreachable) = err {
                                if self.backoff.is_none() {
                                    self.backoff = Some(self.network.backoff());
                                }
                            }
                        }
                    };
                }
            };
            if let Some(b) = &mut self.backoff {
                let duration = b.next().unwrap_or_else(|| {
                    tracing::warn!("backoff exhausted, using default");
                    Duration::from_millis(500)
                });
                self.backoff_drain_events(Instant::now() + duration).await?;
            }
            self.drain_events().await?;
        }
    }

    /// Sends one AppendEntries RPC with the entries after `log_id_range.prev_log_id` up to and
    /// including `log_id_range.last_log_id`.
    ///
    /// The RPC is bounded by `heartbeat_interval` milliseconds.
    ///
    /// Returns:
    /// * `Ok(None)` on success or conflict. The outcome is reported to `tx_raft_core` when
    ///   `request_id` is `Some`.
    /// * `Ok(Some(next))` when the target accepted only a prefix. `next` covers the remainder
    ///   and keeps the same `request_id`.
    /// * `Err(HigherVote)` when the target replied with a higher vote.
    /// * `Err(RPCError)` on timeout or network failure, or an error from reading the log.
    #[tracing::instrument(level = "debug", skip_all)]
    async fn send_log_entries(
        &mut self,
        request_id: Option<u64>,
        log_id_range: LogIdRange<C::NodeId>,
    ) -> Result<Option<Data<C>>, ReplicationError<C::NodeId, C::Node>> {
        tracing::debug!(
            request_id = display(request_id.display()),
            log_id_range = display(&log_id_range),
            "send_log_entries",
        );
        let start = log_id_range.prev_log_id.next_index();
        let end = log_id_range.last_log_id.next_index();
        // An empty range still sends a request (acts as a heartbeat carrying `leader_commit`).
        let logs = if start == end {
            vec![]
        } else {
            let logs = self.log_reader.try_get_log_entries(start..end).await?;
            debug_assert_eq!(
                logs.len(),
                (end - start) as usize,
                "expect logs {}..{} but got only {} entries",
                start,
                end,
                logs.len()
            );
            logs
        };
        // Captured before sending; reported alongside the result to RaftCore.
        let leader_time = Instant::now();
        let payload = AppendEntriesRequest {
            vote: self.session_id.vote,
            prev_log_id: log_id_range.prev_log_id,
            leader_commit: self.committed,
            entries: logs,
        };
        tracing::debug!(
            payload=%payload.summary(),
            now = debug(leader_time),
            "start sending append_entries, timeout: {:?}",
            self.config.heartbeat_interval
        );
        let the_timeout = Duration::from_millis(self.config.heartbeat_interval);
        let option = RPCOption::new(the_timeout);
        let res = timeout(the_timeout, self.network.append_entries(payload, option)).await;
        tracing::debug!("append_entries res: {:?}", res);
        let append_res = res.map_err(|_e| {
            let to = Timeout {
                action: RPCTypes::AppendEntries,
                id: self.session_id.vote.leader_id().voted_for().unwrap(),
                target: self.target,
                timeout: the_timeout,
            };
            RPCError::Timeout(to)
        })?;
        let append_resp = append_res?;
        tracing::debug!(
            req = display(&log_id_range),
            resp = display(&append_resp),
            "append_entries resp"
        );
        match append_resp {
            AppendEntriesResponse::Success => {
                self.update_matching(request_id, leader_time, log_id_range.last_log_id);
                Ok(None)
            }
            AppendEntriesResponse::PartialSuccess(matching) => {
                debug_assert!(
                    matching <= log_id_range.last_log_id,
                    "matching ({}) should be <= last_log_id ({})",
                    matching.display(),
                    log_id_range.last_log_id.display()
                );
                debug_assert!(
                    matching.index() <= log_id_range.last_log_id.index(),
                    "matching.index ({}) should be <= last_log_id.index ({})",
                    matching.index().display(),
                    log_id_range.last_log_id.index().display()
                );
                debug_assert!(
                    matching >= log_id_range.prev_log_id,
                    "matching ({}) should be >= prev_log_id ({})",
                    matching.display(),
                    log_id_range.prev_log_id.display()
                );
                debug_assert!(
                    matching.index() >= log_id_range.prev_log_id.index(),
                    "matching.index ({}) should be >= prev_log_id.index ({})",
                    matching.index().display(),
                    log_id_range.prev_log_id.index().display()
                );
                self.update_matching(request_id, leader_time, matching);
                // Continue with whatever part of the range was not accepted yet.
                if matching < log_id_range.last_log_id {
                    Ok(Some(Data::new_logs(
                        request_id,
                        LogIdRange::new(matching, log_id_range.last_log_id),
                    )))
                } else {
                    Ok(None)
                }
            }
            AppendEntriesResponse::HigherVote(vote) => {
                debug_assert!(
                    vote > self.session_id.vote,
                    "higher vote({}) should be greater than leader's vote({})",
                    vote,
                    self.session_id.vote,
                );
                tracing::debug!(%vote, "append entries failed. converting to follower");
                Err(ReplicationError::HigherVote(HigherVote {
                    higher: vote,
                    mine: self.session_id.vote,
                }))
            }
            AppendEntriesResponse::Conflict => {
                // A conflict is only possible against a concrete prev_log_id; state that
                // invariant in the panic message instead of a bare `unwrap()`.
                let conflict = log_id_range.prev_log_id.expect("prev_log_id=None never conflict");
                self.update_conflicting(request_id, leader_time, conflict);
                Ok(None)
            }
        }
    }

    /// Reports that the target's log conflicts at `conflict`.
    ///
    /// The report is sent to `tx_raft_core` as `ReplicationResult::Conflict`. Nothing is sent
    /// when `request_id` is `None`, only logged.
    fn update_conflicting(&mut self, request_id: Option<u64>, leader_time: Instant, conflict: LogId<C::NodeId>) {
        tracing::debug!(
            target = display(self.target),
            request_id = display(request_id.display()),
            conflict = display(&conflict),
            "update_conflicting"
        );
        if let Some(request_id) = request_id {
            let _ = self.tx_raft_core.send({
                Notify::Network {
                    response: Response::Progress {
                        session_id: self.session_id,
                        request_id,
                        target: self.target,
                        result: Ok(UTime::new(leader_time, ReplicationResult::Conflict(conflict))),
                    },
                }
            });
        } else {
            tracing::info!(
                target = display(self.target),
                request_id = display(request_id.display()),
                conflict = display(&conflict),
                "replication conflict, but request_id is None, no response is sent to RaftCore"
            );
        }
    }

    /// Records `new_matching` as the target's matching log id.
    ///
    /// When `request_id` is `Some`, it also reports the value to `tx_raft_core` as
    /// `ReplicationResult::Matching`. In debug builds it asserts that `matching` never moves
    /// backwards.
    #[tracing::instrument(level = "trace", skip(self))]
    fn update_matching(
        &mut self,
        request_id: Option<u64>,
        leader_time: Instant,
        new_matching: Option<LogId<C::NodeId>>,
    ) {
        tracing::debug!(
            request_id = display(request_id.display()),
            target = display(self.target),
            curr_matching = display(DisplayOption(&self.matching)),
            new_matching = display(DisplayOption(&new_matching)),
            "{}",
            func_name!()
        );
        debug_assert!(self.matching <= new_matching);
        self.matching = new_matching;
        if let Some(request_id) = request_id {
            let _ = self.tx_raft_core.send({
                Notify::Network {
                    response: Response::Progress {
                        session_id: self.session_id,
                        request_id,
                        target: self.target,
                        result: Ok(UTime::new(leader_time, ReplicationResult::Matching(new_matching))),
                    },
                }
            });
        }
    }

    /// Keeps consuming events until `until` is reached, without running any action.
    ///
    /// Received events still update state via `process_event`, e.g. `committed` or the pending
    /// `next_action`. Returns `Err(ReplicationClosed)` if the event channel closes meanwhile.
    #[tracing::instrument(level = "debug", skip(self))]
    pub async fn backoff_drain_events(&mut self, until: Instant) -> Result<(), ReplicationClosed> {
        let d = until - Instant::now();
        tracing::warn!(
            interval = debug(d),
            "{} backoff mode: drain events without processing them",
            func_name!()
        );
        loop {
            // Recomputed each iteration so the overall deadline stays `until`.
            let sleep_duration = until - Instant::now();
            let sleep = sleep(sleep_duration);
            let recv = self.rx_repl.recv();
            tracing::debug!("backoff timeout: {:?}", sleep_duration);
            select! {
                _ = sleep => {
                    tracing::debug!("backoff timeout");
                    return Ok(());
                }
                recv_res = recv => {
                    let event = recv_res.ok_or(ReplicationClosed{})?;
                    self.process_event(event);
                }
            }
        }
    }

    /// Makes sure there is an action to run next.
    ///
    /// If nothing is pending, it waits for at least one event. It then consumes every event that
    /// is already queued. If still no data action is pending, it schedules an empty
    /// append-entries for the range `matching..matching` with no request id.
    #[tracing::instrument(level = "trace", skip_all)]
    pub async fn drain_events(&mut self) -> Result<(), ReplicationClosed> {
        tracing::debug!("drain_events");
        if self.next_action.is_none() {
            let event = self.rx_repl.recv().await.ok_or(ReplicationClosed {})?;
            self.process_event(event);
        }
        self.try_drain_events().await?;
        if self.next_action.is_none() {
            let m = &self.matching;
            self.next_action = Some(Data {
                request_id: None,
                payload: Payload::Logs(LogIdRange {
                    prev_log_id: *m,
                    last_log_id: *m,
                }),
            });
        }
        Ok(())
    }

    /// Processes every event that is immediately available, without waiting.
    ///
    /// Returns `Err(ReplicationClosed)` once the channel is found closed and empty.
    #[tracing::instrument(level = "trace", skip(self))]
    pub async fn try_drain_events(&mut self) -> Result<(), ReplicationClosed> {
        tracing::debug!("{}", func_name!());
        loop {
            // `None` here means the receive would have to wait, i.e. nothing is queued.
            let maybe_res = self.rx_repl.recv().now_or_never();
            let recv_res = match maybe_res {
                None => {
                    return Ok(());
                }
                Some(x) => x,
            };
            let event = recv_res.ok_or(ReplicationClosed {})?;
            self.process_event(event);
        }
    }

    /// Applies one event to the replication state.
    ///
    /// `Committed` updates `committed`, `Data` becomes the pending action, and `Heartbeat`
    /// changes nothing.
    #[tracing::instrument(level = "trace", skip_all)]
    pub fn process_event(&mut self, event: Replicate<C>) {
        tracing::debug!(event=%event.summary(), "process_event");
        match event {
            Replicate::Committed(c) => {
                debug_assert!(
                    c >= self.committed,
                    "expect new committed {} >= self.committed {}",
                    c.summary(),
                    self.committed.summary()
                );
                self.committed = c;
            }
            Replicate::Heartbeat => {
                // No state to update: `drain_events` schedules an empty append-entries when no
                // data action is pending.
            }
            Replicate::Data(d) => {
                debug_assert!(self.next_action.is_none(), "there can not be two data action in flight");
                self.next_action = Some(d);
            }
        }
    }

    /// Waits for the snapshot on `rx` and streams it to the target as InstallSnapshot chunks.
    ///
    /// The read buffer is allocated with `snapshot_max_chunk_size` capacity. Non-final chunks
    /// use `send_snapshot_timeout()` and the final chunk uses `install_snapshot_timeout()`. A
    /// chunk whose send fails or times out is re-read and re-sent after draining queued events
    /// and sleeping 10ms.
    ///
    /// Returns `Ok(None)` after the final chunk succeeds. It returns an error in these cases:
    /// * `StorageError` when the snapshot sender is dropped or yields `None`.
    /// * `StorageError` when reading the snapshot fails or it ends before its reported size.
    /// * `HigherVote` when the target replies with a higher vote.
    /// * `Closed` when the event channel closes.
    #[tracing::instrument(level = "info", skip_all)]
    async fn stream_snapshot(
        &mut self,
        request_id: Option<u64>,
        rx: oneshot::Receiver<Option<Snapshot<C>>>,
    ) -> Result<Option<Data<C>>, ReplicationError<C::NodeId, C::Node>> {
        tracing::info!(request_id = display(request_id.display()), "{}", func_name!());
        let snapshot = rx.await.map_err(|e| {
            let io_err = StorageIOError::read_snapshot(None, AnyError::error(e));
            StorageError::IO { source: io_err }
        })?;
        tracing::info!(
            "received snapshot: request_id={}; meta:{}",
            request_id.display(),
            snapshot.as_ref().map(|x| &x.meta).summary()
        );
        let mut snapshot = match snapshot {
            None => {
                let io_err = StorageIOError::read_snapshot(None, AnyError::error("snapshot not found"));
                let sto_err = StorageError::IO { source: io_err };
                return Err(ReplicationError::StorageError(sto_err));
            }
            Some(x) => x,
        };
        let err_x = || (ErrorSubject::Snapshot(Some(snapshot.meta.signature())), ErrorVerb::Read);
        let mut offset = 0;
        let end = snapshot.snapshot.seek(SeekFrom::End(0)).await.sto_res(err_x)?;
        let mut buf = Vec::with_capacity(self.config.snapshot_max_chunk_size as usize);
        loop {
            // Seek every iteration: after a failed send the same chunk is read again.
            snapshot.snapshot.seek(SeekFrom::Start(offset)).await.sto_res(err_x)?;
            let n_read = snapshot.snapshot.read_buf(&mut buf).await.sto_res(err_x)?;
            let leader_time = Instant::now();
            let done = (offset + n_read as u64) == end;

            // EOF before `end`: the data is shorter than `seek(End)` reported. Without this
            // check `offset` never advances and `done` never becomes true, so empty chunks
            // would be sent forever.
            if n_read == 0 && !done {
                let io_err = StorageIOError::read_snapshot(
                    None,
                    AnyError::error(format!(
                        "unexpected end of snapshot data at offset {}, expected {} bytes",
                        offset, end
                    )),
                );
                let sto_err = StorageError::IO { source: io_err };
                return Err(ReplicationError::StorageError(sto_err));
            }

            let req = InstallSnapshotRequest {
                vote: self.session_id.vote,
                meta: snapshot.meta.clone(),
                offset,
                data: Vec::from(&buf[..n_read]),
                done,
            };
            buf.clear();
            tracing::debug!(
                snapshot_size = req.data.len(),
                req.offset,
                end,
                req.done,
                "sending snapshot chunk"
            );
            let snap_timeout = if done {
                self.config.install_snapshot_timeout()
            } else {
                self.config.send_snapshot_timeout()
            };
            let option = RPCOption::new(snap_timeout);
            let res = timeout(snap_timeout, self.network.install_snapshot(req, option)).await;
            let res = match res {
                Ok(outer_res) => match outer_res {
                    Ok(res) => res,
                    Err(err) => {
                        tracing::warn!(error=%err, "error sending InstallSnapshot RPC to target");
                        // Keep consuming events (and detect a closed channel) while retrying;
                        // the short sleep yields so the retry is not a busy loop.
                        self.try_drain_events().await?;
                        sleep(Duration::from_millis(10)).await;
                        continue;
                    }
                },
                Err(err) => {
                    tracing::warn!(error=%err, "timeout while sending InstallSnapshot RPC to target");
                    self.try_drain_events().await?;
                    sleep(Duration::from_millis(10)).await;
                    continue;
                }
            };
            if res.vote > self.session_id.vote {
                return Err(ReplicationError::HigherVote(HigherVote {
                    higher: res.vote,
                    mine: self.session_id.vote,
                }));
            }
            if done {
                tracing::debug!(
                    "done install snapshot: snapshot last_log_id: {:?}, matching: {}",
                    snapshot.meta.last_log_id,
                    self.matching.summary(),
                );
                self.update_matching(request_id, leader_time, snapshot.meta.last_log_id);
                return Ok(None);
            }
            offset += n_read as u64;
            self.try_drain_events().await?;
        }
    }
}
/// One replication action: what to send, plus the optional request id used when reporting
/// progress.
#[derive(Debug)]
pub(crate) struct Data<C: RaftTypeConfig> {
    /// Id echoed back in progress reports; `None` means no report is sent.
    request_id: Option<u64>,
    /// The logs or snapshot to replicate.
    payload: Payload<C>,
}
impl<C> fmt::Display for Data<C>
where C: RaftTypeConfig
{
    /// Formats as `{id: <request_id>, payload: <payload>}`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let id = self.request_id.display();
        let payload = &self.payload;
        write!(f, "{{id: {}, payload: {}}}", id, payload)
    }
}
impl<C: RaftTypeConfig> MessageSummary<Data<C>> for Data<C> {
    /// Short description naming the payload kind and the request id.
    fn summary(&self) -> String {
        let id = self.request_id.display();
        match &self.payload {
            Payload::Logs(range) => format!("Logs{{request_id={}, {}}}", id, range),
            Payload::Snapshot(_) => format!("Snapshot{{request_id={}}}", id),
        }
    }
}
impl<C> Data<C>
where C: RaftTypeConfig
{
    fn new_logs(request_id: Option<u64>, log_id_range: LogIdRange<C::NodeId>) -> Self {
        Self {
            request_id,
            payload: Payload::Logs(log_id_range),
        }
    }
    fn new_snapshot(request_id: Option<u64>, snapshot_rx: oneshot::Receiver<Option<Snapshot<C>>>) -> Self {
        Self {
            request_id,
            payload: Payload::Snapshot(snapshot_rx),
        }
    }
}
/// The content of a replication action.
pub(crate) enum Payload<C: RaftTypeConfig> {
    /// Replicate the log entries covered by this range.
    Logs(LogIdRange<C::NodeId>),
    /// Stream the snapshot received on this channel; a received `None` means no snapshot.
    Snapshot(oneshot::Receiver<Option<Snapshot<C>>>),
}
impl<C> fmt::Display for Payload<C>
where C: RaftTypeConfig
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Logs(log_id_range) => {
                write!(f, "Logs({})", log_id_range)
            }
            Self::Snapshot(_) => {
                write!(f, "Snapshot()")
            }
        }
    }
}
impl<C: RaftTypeConfig> fmt::Debug for Payload<C> {
    /// Debug output is the same as the `Display` form: `Logs(<range>)` or `Snapshot()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
/// Outcome of a replication request, reported back inside `Response::Progress`.
#[derive(Clone, Debug)]
pub(crate) enum ReplicationResult<NID>
where NID: NodeId
{
    /// The log id now recorded as matching on the target.
    Matching(Option<LogId<NID>>),
    /// The target reported a conflict at this log id (the request's `prev_log_id`).
    Conflict(LogId<NID>),
}
/// Events sent to a replication task through its channel.
pub(crate) enum Replicate<C: RaftTypeConfig> {
    /// New value to send to the target as `leader_commit`.
    Committed(Option<LogId<C::NodeId>>),
    /// Carries no data; it only wakes the task so that it sends a request.
    Heartbeat,
    /// A logs or snapshot action to execute next.
    Data(Data<C>),
}
impl<C> Replicate<C>
where C: RaftTypeConfig
{
    pub(crate) fn logs(id: Option<u64>, log_id_range: LogIdRange<C::NodeId>) -> Self {
        Self::Data(Data::new_logs(id, log_id_range))
    }
    pub(crate) fn snapshot(id: Option<u64>, snapshot_rx: oneshot::Receiver<Option<Snapshot<C>>>) -> Self {
        Self::Data(Data::new_snapshot(id, snapshot_rx))
    }
}
impl<C: RaftTypeConfig> MessageSummary<Replicate<C>> for Replicate<C> {
    /// Short, log-friendly description of the event.
    fn summary(&self) -> String {
        match self {
            Replicate::Committed(c) => format!("Replicate::Committed: {:?}", c),
            Replicate::Heartbeat => String::from("Replicate::Heartbeat"),
            Replicate::Data(d) => format!("Replicate::Data({})", d.summary()),
        }
    }
}