d_engine_core/replication/replication_handler.rs

use std::cmp;
use std::collections::HashMap;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::sync::Arc;

use bytes::BytesMut;
use d_engine_proto::client::WriteCommand;
use d_engine_proto::common::Entry;
use d_engine_proto::common::EntryPayload;
use d_engine_proto::common::LogId;
use d_engine_proto::common::NodeRole;
use d_engine_proto::common::entry_payload::Payload;
use d_engine_proto::server::replication::AppendEntriesRequest;
use d_engine_proto::server::replication::AppendEntriesResponse;
use d_engine_proto::server::replication::ConflictResult;
use d_engine_proto::server::replication::SuccessResult;
use d_engine_proto::server::replication::append_entries_response;
use dashmap::DashMap;
use prost::Message;
use tonic::async_trait;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::trace;
use tracing::warn;

use super::AppendResponseWithUpdates;
use super::ReplicationCore;
use crate::AppendResults;
use crate::IdAllocationError;
use crate::LeaderStateSnapshot;
use crate::PeerUpdate;
use crate::RaftContext;
use crate::RaftLog;
use crate::ReplicationError;
use crate::Result;
use crate::StateSnapshot;
use crate::Transport;
use crate::TypeConfig;
use crate::alias::ROF;
use crate::scoped_timer::ScopedTimer;
use crate::utils::cluster::is_majority;

#[derive(Clone)]
pub struct ReplicationHandler<T>
where
    T: TypeConfig,
{
    pub my_id: u32,
    _phantom: PhantomData<T>,
}

#[async_trait]
impl<T> ReplicationCore<T> for ReplicationHandler<T>
where
    T: TypeConfig,
{
    async fn handle_raft_request_in_batch(
        &self,
        entry_payloads: Vec<EntryPayload>,
        state_snapshot: StateSnapshot,
        leader_state_snapshot: LeaderStateSnapshot,
        cluster_metadata: &crate::raft_role::ClusterMetadata,
        ctx: &RaftContext<T>,
    ) -> Result<AppendResults> {
        let _timer = ScopedTimer::new("handle_raft_request_in_batch");

        debug!("-------- handle_raft_request_in_batch --------");

        // ----------------------
        // Phase 1: Pre-Checks and Cluster Topology Detection
        // ----------------------
        // Use cached replication targets from cluster metadata (zero-cost)
        let replication_targets = &cluster_metadata.replication_targets;

        // Separate voters and learners.
        // Use role (not status) to distinguish them: Followers/Candidates are voters, Learners are learners.
        // This is more robust than using status, which can be temporarily non-Active.
        let (voters, learners): (Vec<_>, Vec<_>) = replication_targets
            .iter()
            .partition(|node| node.role != NodeRole::Learner as i32);

        if !learners.is_empty() {
            trace!(
                "handle_raft_request_in_batch - voters: {:?}, learners: {:?}",
                voters, learners
            );
        }

        // ----------------------
        // Phase 2: Process Client Commands
        // ----------------------

        // Record the last index before inserting new entries, to avoid duplicated entries (bugfix #48)
        let raft_log = ctx.raft_log();
        let leader_last_index_before = raft_log.last_entry_id();

        let new_entries = self
            .generate_new_entries(entry_payloads, state_snapshot.current_term, raft_log)
            .await?;

        // ----------------------
        // Phase 3: Prepare Replication Data
        // ----------------------
        let replication_data = ReplicationData {
            leader_last_index_before,
            current_term: state_snapshot.current_term,
            commit_index: state_snapshot.commit_index,
            peer_next_indices: leader_state_snapshot.next_index,
        };

        let entries_per_peer = self.prepare_peer_entries(
            &new_entries,
            &replication_data,
            ctx.node_config.raft.replication.append_entries_max_entries_per_replication,
            raft_log,
        );

        // ----------------------
        // Phase 4: Build Requests
        // ----------------------
        let requests = replication_targets
            .iter()
            .map(|m| {
                self.build_append_request(raft_log, m.id, &entries_per_peer, &replication_data)
            })
            .collect();

        // ----------------------
        // Phase 5: Replication
        // ----------------------

        // No peers: logs were already written in Phase 2, so return immediately.
        // No replication is needed; quorum is automatically achieved (standalone node).
        if replication_targets.is_empty() {
            debug!(
                "Standalone node (leader={}): logs persisted, quorum automatically achieved",
                self.my_id
            );
            return Ok(AppendResults {
                commit_quorum_achieved: true,
                peer_updates: HashMap::new(),
                learner_progress: HashMap::new(),
            });
        }

        // Multi-node cluster: perform replication to peers
        let leader_current_term = state_snapshot.current_term;
        let mut successes = 1; // Include leader itself
        let mut peer_updates = HashMap::new();
        let mut learner_progress = HashMap::new();

        let membership = ctx.membership();
        match ctx
            .transport()
            .send_append_requests(
                requests,
                &ctx.node_config.retry,
                membership,
                ctx.node_config.raft.rpc_compression.replication_response,
            )
            .await
        {
            Ok(append_result) => {
                for response in append_result.responses {
                    match response {
                        Ok(append_response) => {
                            // Skip responses from stale terms
                            if append_response.term < leader_current_term {
                                info!(%append_response.term, %leader_current_term, "append_response.term < leader_current_term");
                                continue;
                            }

                            match append_response.result {
                                Some(append_entries_response::Result::Success(success_result)) => {
                                    // Only count successful responses from Voters
                                    if voters.iter().any(|n| n.id == append_response.node_id) {
                                        successes += 1;
                                    }

                                    let update = self.handle_success_response(
                                        append_response.node_id,
                                        append_response.term,
                                        success_result,
                                        leader_current_term,
                                    )?;

                                    // Record Learner progress
                                    if learners.iter().any(|n| n.id == append_response.node_id) {
                                        learner_progress
                                            .insert(append_response.node_id, update.match_index);
                                    }

                                    peer_updates.insert(append_response.node_id, update);
                                }

                                Some(append_entries_response::Result::Conflict(
                                    conflict_result,
                                )) => {
                                    let current_next_index = replication_data
                                        .peer_next_indices
                                        .get(&append_response.node_id)
                                        .copied()
                                        .unwrap_or(1);

                                    let update = self.handle_conflict_response(
                                        append_response.node_id,
                                        conflict_result,
                                        raft_log,
                                        current_next_index,
                                    )?;

                                    // Record Learner progress
                                    if learners.iter().any(|n| n.id == append_response.node_id) {
                                        learner_progress
                                            .insert(append_response.node_id, update.match_index);
                                    }

                                    peer_updates.insert(append_response.node_id, update);
                                }

                                Some(append_entries_response::Result::HigherTerm(higher_term)) => {
                                    // Only handle higher term if it's greater than current term
                                    if higher_term > leader_current_term {
                                        return Err(
                                            ReplicationError::HigherTerm(higher_term).into()
                                        );
                                    }
                                }

                                None => {
                                    error!("TODO: need to figure out the reason for this clause");
                                    unreachable!();
                                }
                            }
                        }
                        Err(e) => {
                            // Timeouts and network errors are logged but not added to peer_updates
                            warn!("Peer request failed: {:?}", e);
                        }
                    }
                }
                let peer_ids = append_result.peer_ids;
                debug!(
                    "send_append_requests to: {:?} with success count = {}",
                    &peer_ids, successes
                );

                let total_voters = voters.len() + 1; // Leader + voter peers
                let commit_quorum_achieved = is_majority(successes, total_voters);
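                // Example (assuming is_majority means a strict majority): with 2 voter peers
                // plus the leader, total_voters = 3, so the leader plus one successful voter
                // peer (successes = 2) reaches quorum.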
                Ok(AppendResults {
                    commit_quorum_achieved,
                    peer_updates,
                    learner_progress,
                })
            }
            Err(e) => Err(e),
        }
    }

    fn handle_success_response(
        &self,
        peer_id: u32,
        peer_term: u64,
        success_result: SuccessResult,
        leader_term: u64,
    ) -> Result<PeerUpdate> {
        let _timer = ScopedTimer::new("handle_success_response");

        debug!(
            ?success_result,
            "Received success response from peer {}", peer_id
        );

        let match_log = success_result.last_match.unwrap_or(LogId { term: 0, index: 0 });

        // Verify Term consistency
        if peer_term > leader_term {
            return Err(ReplicationError::HigherTerm(peer_term).into());
        }

        let peer_match_index = match_log.index;
        let peer_next_index = peer_match_index + 1;

        Ok(PeerUpdate {
            match_index: Some(peer_match_index),
            next_index: peer_next_index,
            success: true,
        })
    }

    fn handle_conflict_response(
        &self,
        peer_id: u32,
        conflict_result: ConflictResult,
        raft_log: &Arc<ROF<T>>,
        current_next_index: u64,
    ) -> Result<PeerUpdate> {
        let _timer = ScopedTimer::new("handle_conflict_response");

        debug!("Handling conflict from peer {}", peer_id);

        // Calculate next_index based on conflict information
        let next_index = match (
            conflict_result.conflict_term,
            conflict_result.conflict_index,
        ) {
            (Some(term), Some(index)) => {
                if let Some(last_index_for_term) = raft_log.last_index_for_term(term) {
                    last_index_for_term + 1
                } else {
                    // Term not found; fall back to the reported conflict index
                    index
                }
            }
            (None, Some(index)) => index,
            _ => current_next_index.saturating_sub(1), // No conflict details: back off by one
        };
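        // Example: if the follower reports conflict_term = 3 and the leader's last entry
        // for term 3 is at index 7, next_index becomes 8; with only conflict_index = 5
        // reported, next_index becomes 5.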
        // Make sure next_index is not less than 1
        let next_index = next_index.max(1);
        Ok(PeerUpdate {
            match_index: None, // Unknown after conflict
            next_index,
            success: false,
        })
    }

    fn retrieve_to_be_synced_logs_for_peers(
        &self,
        new_entries: Vec<Entry>,
        leader_last_index_before_inserting_new_entries: u64,
        max_legacy_entries_per_peer: u64, // Maximum number of legacy entries to send per peer
        peer_next_indices: &HashMap<u32, u64>,
        raft_log: &Arc<ROF<T>>,
    ) -> DashMap<u32, Vec<Entry>> {
        let _timer = ScopedTimer::new("retrieve_to_be_synced_logs_for_peers");

        let peer_entries: DashMap<u32, Vec<Entry>> = DashMap::new();
        trace!(
            "retrieve_to_be_synced_logs_for_peers::leader_last_index: {}",
            leader_last_index_before_inserting_new_entries
        );
        peer_next_indices.keys().for_each(|&id| {
            if id == self.my_id {
                return;
            }
            let peer_next_id = peer_next_indices.get(&id).copied().unwrap_or(1);

            debug!("peer: {} next: {}", id, peer_next_id);
            let mut entries = Vec::new();
            if leader_last_index_before_inserting_new_entries >= peer_next_id {
                let until_index = if (leader_last_index_before_inserting_new_entries - peer_next_id)
                    >= max_legacy_entries_per_peer
                {
                    peer_next_id + max_legacy_entries_per_peer - 1
                } else {
                    leader_last_index_before_inserting_new_entries
                };
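                // Example: leader_last = 100, peer_next = 90, max_legacy = 5:
                // 100 - 90 = 10 >= 5, so until_index = 90 + 5 - 1 = 94 (a window of 5 entries).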
                let legacy_entries = match raft_log.get_entries_range(peer_next_id..=until_index) {
                    Ok(entries) => entries,
                    Err(e) => {
                        error!("Failed to get legacy entries for peer {}: {:?}", id, e);
                        Vec::new()
                    }
                };

                if !legacy_entries.is_empty() {
                    trace!("legacy_entries: {:?}", &legacy_entries);
                    entries.extend(legacy_entries);
                }
            }

            if !new_entries.is_empty() {
                entries.extend(new_entries.clone()); // Add new entries
            }
            if !entries.is_empty() {
                peer_entries.insert(id, entries);
            }
        });

        peer_entries
    }

    /// Handles an AppendEntries request from the leader (Follower role only).
    async fn handle_append_entries(
        &self,
        request: AppendEntriesRequest,
        state_snapshot: &StateSnapshot,
        raft_log: &Arc<ROF<T>>,
    ) -> Result<AppendResponseWithUpdates> {
        let _timer = ScopedTimer::new("handle_append_entries");

        debug!(
            "[F-{:?}] >> receive leader append request {:?}",
            self.my_id, request
        );
        let current_term = state_snapshot.current_term;
        let mut last_log_id_option = raft_log.last_log_id();

        // If there are no new entries to insert, we simply return the last local log index.
        let mut commit_index_update = None;

        let response = self.check_append_entries_request_is_legal(current_term, &request, raft_log);

        // Handle illegal requests (return conflict or higher Term)
        if response.is_conflict() || response.is_higher_term() {
            debug!("Rejecting AppendEntries: {:?}", &response);

            return Ok(AppendResponseWithUpdates {
                response,
                commit_index_update,
            });
        }

        // Switch to follower listening state
        debug!("switch to follower listening state");

        let success = true;

        if !request.entries.is_empty() {
            last_log_id_option = raft_log
                .filter_out_conflicts_and_append(
                    request.prev_log_index,
                    request.prev_log_term,
                    request.entries.clone(),
                )
                .await?;
        }

        if let Some(new_commit_index) = Self::if_update_commit_index_as_follower(
            state_snapshot.commit_index,
            raft_log.last_entry_id(),
            request.leader_commit_index,
        ) {
            debug!("new commit index received: {:?}", new_commit_index);
            commit_index_update = Some(new_commit_index);
        }

        debug!(
            "success: {:?}, current_term: {:?}, last_matched_id: {:?}",
            success, current_term, last_log_id_option
        );

        Ok(AppendResponseWithUpdates {
            response: AppendEntriesResponse::success(self.my_id, current_term, last_log_id_option),
            commit_index_update,
        })
    }

    /// If leaderCommit > commitIndex, set commitIndex = min(leaderCommit, index of last new entry).
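    ///
    /// Example: leader_commit_index = 10, last_raft_log_id = 8, my_commit_index = 5
    /// yields Some(8); if my_commit_index were already 10, this returns None.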
    fn if_update_commit_index_as_follower(
        my_commit_index: u64,
        last_raft_log_id: u64,
        leader_commit_index: u64,
    ) -> Option<u64> {
        debug!(
            "Should I update my commit index? leader_commit_index:{:?} > state.commit_index:{:?} = {:?}",
            leader_commit_index,
            my_commit_index,
            leader_commit_index > my_commit_index
        );

        if leader_commit_index > my_commit_index {
            return Some(cmp::min(leader_commit_index, last_raft_log_id));
        }
        None
    }

    #[tracing::instrument(skip(self, raft_log))]
    fn check_append_entries_request_is_legal(
        &self,
        my_term: u64,
        request: &AppendEntriesRequest,
        raft_log: &Arc<ROF<T>>,
    ) -> AppendEntriesResponse {
        let _timer = ScopedTimer::new("check_append_entries_request_is_legal");

        // Rule 1: Term check
        if my_term > request.term {
            warn!("my_term({}) > req.term({})", my_term, request.term);
            return AppendEntriesResponse::higher_term(self.my_id, my_term);
        }

        let last_log_id_option = raft_log.last_log_id();
        let last_log_id = last_log_id_option.unwrap_or(LogId { term: 0, index: 0 }).index;

        // Rule 2: Special handling for the virtual log
        if request.prev_log_index == 0 && request.prev_log_term == 0 {
            // Accept virtual log request (regardless of whether the local log is empty)
            return AppendEntriesResponse::success(self.my_id, my_term, last_log_id_option);
        }

        // Rule 3: General log matching check
        match raft_log.entry_term(request.prev_log_index) {
            Some(term) if term == request.prev_log_term => AppendEntriesResponse::success(
                self.my_id,
                my_term,
                Some(LogId {
                    term: request.prev_log_term,
                    index: request.prev_log_index,
                }),
            ),
            Some(conflict_term) => {
                // Find the first index of the conflicting term
                // TODO: Upcoming feature #45 in v0.2.0
                // let conflict_index = raft_log.first_index_for_term(conflict_term);
                let conflict_index = if request.prev_log_index < last_log_id {
                    request.prev_log_index.saturating_sub(1)
                } else {
                    last_log_id + 1
                };
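                // Example: a term mismatch at prev_log_index = 10 with local last index 15
                // yields conflict_index = 9 (back off by one before the mismatch).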
                AppendEntriesResponse::conflict(
                    self.my_id,
                    my_term,
                    Some(conflict_term),
                    Some(conflict_index),
                )
            }
            None => {
                // prev_log_index does not exist locally; return the next expected index
                let conflict_index = last_log_id + 1;
                AppendEntriesResponse::conflict(self.my_id, my_term, None, Some(conflict_index))
            }
        }
    }
}

#[derive(Debug)]
pub struct ReplicationData {
    pub leader_last_index_before: u64,
    pub current_term: u64,
    pub commit_index: u64,
    pub peer_next_indices: HashMap<u32, u64>,
}

impl<T> ReplicationHandler<T>
where
    T: TypeConfig,
{
    pub fn new(my_id: u32) -> Self {
        Self {
            my_id,
            _phantom: PhantomData,
        }
    }

    /// Generates new log entries from the given payloads and inserts them into the local raft log.
    pub async fn generate_new_entries(
        &self,
        entry_payloads: Vec<EntryPayload>,
        current_term: u64,
        raft_log: &Arc<ROF<T>>,
    ) -> Result<Vec<Entry>> {
        let _timer = ScopedTimer::new("generate_new_entries");

        // Handle the empty case early
        if entry_payloads.is_empty() {
            return Ok(Vec::new());
        }

        // Pre-allocate the ID range in one atomic operation
        let id_range = raft_log.pre_allocate_id_range(entry_payloads.len() as u64);
        assert!(!id_range.is_empty());
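        // Example (assuming contiguous allocation): with 3 payloads and the last allocated
        // index at 10, the returned range would be 11..=13.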
        let mut next_index = *id_range.start();

        let mut entries = Vec::with_capacity(entry_payloads.len());

        for payload in entry_payloads {
            // Ensure we don't exceed the allocated range
            if next_index > *id_range.end() {
                return Err(IdAllocationError::Overflow {
                    start: next_index,
                    end: *id_range.end(),
                }
                .into());
            }

            entries.push(Entry {
                index: next_index,
                term: current_term,
                payload: Some(payload),
            });

            next_index += 1;
        }

        if !entries.is_empty() {
            trace!(
                "RaftLog insert_batch: {}..={}",
                entries[0].index,
                entries.last().unwrap().index
            );
            raft_log.insert_batch(entries.clone()).await?;
        }

        Ok(entries)
    }

    /// Prepare the entries that need to be synchronized for each peer
    pub fn prepare_peer_entries(
        &self,
        new_entries: &[Entry],
        data: &ReplicationData,
        max_legacy_entries: u64,
        raft_log: &Arc<ROF<T>>,
    ) -> DashMap<u32, Vec<Entry>> {
        self.retrieve_to_be_synced_logs_for_peers(
            new_entries.to_vec(),
            data.leader_last_index_before,
            max_legacy_entries,
            &data.peer_next_indices,
            raft_log,
        )
    }

    /// Build an append request for a single peer
    pub fn build_append_request(
        &self,
        raft_log: &Arc<ROF<T>>,
        peer_id: u32,
        entries_per_peer: &DashMap<u32, Vec<Entry>>,
        data: &ReplicationData,
    ) -> (u32, AppendEntriesRequest) {
        let _timer = ScopedTimer::new("build_append_request");
        // Calculate prev_log metadata
        let (prev_log_index, prev_log_term) =
            data.peer_next_indices.get(&peer_id).map_or((0, 0), |next_id| {
                let prev_index = next_id.saturating_sub(1);
                let term = raft_log.entry_term(prev_index).unwrap_or(0);
                (prev_index, term)
            });
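        // Example: if the peer's next_index is 7, prev_log_index = 6 and prev_log_term is
        // the term of entry 6 (both default to 0 when the peer or entry is unknown).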
        // Get the entries to be sent
        let entries = entries_per_peer.get(&peer_id).map(|e| e.clone()).unwrap_or_default();

        debug!(
            "[Leader {} -> Follower {}] Replicating {} entries",
            self.my_id,
            peer_id,
            entries.len()
        );

        (
            peer_id,
            AppendEntriesRequest {
                term: data.current_term,
                leader_id: self.my_id,
                prev_log_index,
                prev_log_term,
                entries,
                leader_commit_index: data.commit_index,
            },
        )
    }
}

impl<T> Debug for ReplicationHandler<T>
where
    T: TypeConfig,
{
    fn fmt(
        &self,
        f: &mut std::fmt::Formatter<'_>,
    ) -> std::fmt::Result {
        f.debug_struct("ReplicationHandler").field("my_id", &self.my_id).finish()
    }
}

/// Converts a vector of client WriteCommands into a vector of EntryPayloads.
/// Each WriteCommand is serialized into bytes and wrapped in an EntryPayload::Command variant.
///
/// # Arguments
/// * `commands` - A vector of WriteCommand to be converted
///
/// # Returns
/// A vector of EntryPayload containing the serialized commands
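///
/// # Example
///
/// A minimal sketch, assuming `WriteCommand` is a prost-generated message (and therefore
/// implements `Default`):
///
/// ```ignore
/// let payloads = client_command_to_entry_payloads(vec![WriteCommand::default()]);
/// assert_eq!(payloads.len(), 1);
/// ```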
pub fn client_command_to_entry_payloads(commands: Vec<WriteCommand>) -> Vec<EntryPayload> {
    commands
        .into_iter()
        .map(|cmd| {
            let mut buf = BytesMut::with_capacity(cmd.encoded_len());
            // Encoding into a BytesMut with sufficient capacity cannot fail
            cmd.encode(&mut buf).unwrap();

            EntryPayload {
                payload: Some(Payload::Command(buf.freeze())),
            }
        })
        .collect()
}