Skip to main content

astra_core/
raft.rs

1use std::collections::{BTreeMap, HashMap, VecDeque};
2use std::fmt::Debug;
3use std::io::{Cursor, Read};
4use std::ops::{Bound, RangeBounds};
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
9use astra_proto::astraraftpb::internal_raft_client::InternalRaftClient;
10use astra_proto::astraraftpb::internal_raft_server::InternalRaft;
11use astra_proto::astraraftpb::RaftBytes;
12use openraft::entry::{Entry, EntryPayload};
13use openraft::error::{InstallSnapshotError, NetworkError, RPCError, RaftError};
14use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
15use openraft::raft::{
16    AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
17    VoteRequest, VoteResponse,
18};
19use openraft::storage::{LogFlushed, RaftLogStorage, RaftSnapshotBuilder, RaftStateMachine};
20use openraft::{
21    BasicNode, LogId, LogState, Raft, RaftLogReader, Snapshot, SnapshotMeta, StorageError,
22    StoredMembership, Vote,
23};
24use serde::{Deserialize, Serialize};
25use tokio::sync::{oneshot, Mutex as AsyncMutex, Notify, RwLock};
26use tonic::transport::Channel;
27use tonic::{Code, Request, Response, Status};
28use tracing::{debug, info, warn};
29
30use crate::config::{AstraConfig, WalIoEngine};
31use crate::errors::StoreError;
32use crate::metrics;
33use crate::store::{
34    DeleteOutput, KvStore, LeaseGrantOutput, LeaseRevokeOutput, PutOutput, RangeOutput,
35    SnapshotState, ValueEntry,
36};
37
// openraft type configuration for Astra: replicated application requests (`D`) are
// `AstraWriteRequest`, their results (`R`) are `AstraWriteResponse`, nodes are identified by
// `u64` and described by `BasicNode`, and snapshot data is an in-memory byte buffer
// (`Cursor<Vec<u8>>`). Runs on the tokio runtime.
38openraft::declare_raft_types!(
39    pub AstraTypeConfig:
40        D = AstraWriteRequest,
41        R = AstraWriteResponse,
42        NodeId = u64,
43        Node = BasicNode,
44        Entry = Entry<AstraTypeConfig>,
45        SnapshotData = Cursor<Vec<u8>>,
46        AsyncRuntime = openraft::TokioRuntime,
47);
48
/// One put operation inside an `AstraWriteRequest::PutBatch`, carrying its value inline.
///
/// Fields mirror `AstraWriteRequest::Put`. `prev_kv` is `#[serde(default)]`, so it decodes as
/// `false` when absent from the serialized input.
49#[derive(Debug, Clone, Serialize, Deserialize)]
50pub struct AstraBatchPutOp {
51    pub key: Vec<u8>,
52    pub value: Vec<u8>,
53    pub lease: i64,
54    pub ignore_value: bool,
55    pub ignore_lease: bool,
56    #[serde(default)]
57    pub prev_kv: bool,
58}
59
/// One put operation inside an `AstraWriteRequest::PutBatchRef`.
///
/// Instead of an inline value it carries `value_idx`. That is presumably an index into the
/// enclosing request's `values` vector (NOTE(review): confirm against the apply path).
/// `prev_kv` decodes as `false` when absent from the serialized input.
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct AstraBatchPutRefOp {
62    pub key: Vec<u8>,
63    pub value_idx: u32,
64    pub lease: i64,
65    pub ignore_value: bool,
66    pub ignore_lease: bool,
67    #[serde(default)]
68    pub prev_kv: bool,
69}
70
/// Dictionary entry pairing a `token_id` with its value bytes.
///
/// Carried in the `dict_additions` field of `AstraWriteRequest::PutBatchTokenized`.
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct AstraTokenValue {
73    pub token_id: u32,
74    pub value: Vec<u8>,
75}
76
/// One put operation inside an `AstraWriteRequest::PutBatchTokenized`.
///
/// The value is referenced by `token_id`. It is presumably resolved through the token dictionary
/// that `dict_additions` extends (NOTE(review): confirm). `prev_kv` decodes as `false` when
/// absent from the serialized input.
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct AstraBatchPutTokenOp {
79    pub key: Vec<u8>,
80    pub token_id: u32,
81    pub lease: i64,
82    pub ignore_value: bool,
83    pub ignore_lease: bool,
84    #[serde(default)]
85    pub prev_kv: bool,
86}
87
/// Per-operation result reported in `AstraWriteResponse::PutBatch`.
///
/// Holds the revision of the put and an optional previous value. Presumably `prev` is only
/// populated when the op asked for `prev_kv` — TODO confirm in the apply path.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct AstraBatchPutResult {
90    pub revision: i64,
91    pub prev: Option<ValueEntry>,
92}
93
/// Comparison operator of an `AstraTxnCompare` predicate.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95pub enum AstraTxnCmpResult {
96    Equal,
97    Greater,
98    Less,
99    NotEqual,
100}
101
/// Which attribute of the compared key(s) an `AstraTxnCompare` inspects.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub enum AstraTxnCmpTarget {
104    Version,
105    CreateRevision,
106    ModRevision,
107    Value,
108    Lease,
109}
110
/// Right-hand operand of an `AstraTxnCompare`.
///
/// Presumably `I64` pairs with the numeric targets (version, revisions, lease) and `Bytes` with
/// `AstraTxnCmpTarget::Value`. NOTE(review): confirm against the compare evaluator.
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub enum AstraTxnCmpValue {
113    I64(i64),
114    Bytes(Vec<u8>),
115}
116
/// One transaction guard.
///
/// Compares `target` of `key` against `target_value` using the `result` operator. `range_end`
/// presumably widens the compare to a key range when non-empty — TODO confirm.
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct AstraTxnCompare {
119    pub result: AstraTxnCmpResult,
120    pub target: AstraTxnCmpTarget,
121    pub key: Vec<u8>,
122    pub range_end: Vec<u8>,
123    pub target_value: AstraTxnCmpValue,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub enum AstraTxnOp {
128    Range {
129        key: Vec<u8>,
130        range_end: Vec<u8>,
131        limit: i64,
132        revision: i64,
133        keys_only: bool,
134        count_only: bool,
135    },
136    Put {
137        key: Vec<u8>,
138        value: Vec<u8>,
139        lease: i64,
140        ignore_value: bool,
141        ignore_lease: bool,
142        prev_kv: bool,
143    },
144    DeleteRange {
145        key: Vec<u8>,
146        range_end: Vec<u8>,
147        prev_kv: bool,
148    },
149    Txn {
150        compare: Vec<AstraTxnCompare>,
151        success: Vec<AstraTxnOp>,
152        failure: Vec<AstraTxnOp>,
153    },
154}
155
/// Result of one `AstraTxnOp`.
///
/// Variants mirror the op kinds: `Range`, `Put`, `Delete` (for `DeleteRange`) and a nested
/// `Txn`. Responses are presumably returned in the order the branch's ops executed — TODO
/// confirm with the executor.
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub enum AstraTxnOpResponse {
158    Range {
159        revision: i64,
160        count: i64,
161        more: bool,
162        kvs: Vec<(Vec<u8>, ValueEntry)>,
163    },
164    Put {
165        revision: i64,
166        prev: Option<ValueEntry>,
167    },
168    Delete {
169        revision: i64,
170        deleted: i64,
171        prev_kvs: Vec<(Vec<u8>, ValueEntry)>,
172    },
173    Txn {
174        revision: i64,
175        succeeded: bool,
176        responses: Vec<AstraTxnOpResponse>,
177    },
178}
179
/// Application command replicated through raft (the `D` type of `AstraTypeConfig`).
///
/// Fields marked `#[serde(default)]` take their default value (`false`, `0`, or an empty `Vec`)
/// when absent from the serialized input. On `PutBatchRef` and `PutBatchTokenized`, `AstraLogStore`
/// picks up nonzero `batch_id` / `submit_ts_micros` as timeline markers.
180#[derive(Debug, Clone, Serialize, Deserialize)]
181pub enum AstraWriteRequest {
    /// Writes one key.
182    Put {
183        key: Vec<u8>,
184        value: Vec<u8>,
185        lease: i64,
186        ignore_value: bool,
187        ignore_lease: bool,
188        #[serde(default)]
189        prev_kv: bool,
190    },
    /// Batch of puts with inline values.
191    PutBatch {
192        ops: Vec<AstraBatchPutOp>,
193    },
    /// Batch of puts whose ops reference values by index (presumably into `values`).
194    PutBatchRef {
195        #[serde(default)]
196        batch_id: u64,
197        #[serde(default)]
198        submit_ts_micros: u64,
199        values: Vec<Vec<u8>>,
200        ops: Vec<AstraBatchPutRefOp>,
201    },
    /// Batch of puts whose ops reference values by dictionary token. `dict_additions`
    /// presumably extends the dictionary identified by `dict_epoch` — TODO confirm.
202    PutBatchTokenized {
203        #[serde(default)]
204        batch_id: u64,
205        #[serde(default)]
206        submit_ts_micros: u64,
207        #[serde(default)]
208        dict_epoch: u64,
209        #[serde(default)]
210        dict_additions: Vec<AstraTokenValue>,
211        ops: Vec<AstraBatchPutTokenOp>,
212    },
    /// Deletes `key`, or a key range bounded by `range_end`.
213    DeleteRange {
214        key: Vec<u8>,
215        range_end: Vec<u8>,
216        prev_kv: bool,
217    },
    /// Guarded transaction with `success` and `failure` branches selected by `compare`.
218    Txn {
219        compare: Vec<AstraTxnCompare>,
220        success: Vec<AstraTxnOp>,
221        failure: Vec<AstraTxnOp>,
222    },
    /// Requests compaction at `revision`.
223    Compact {
224        revision: i64,
225    },
    /// Lease management commands.
226    LeaseGrant {
227        id: i64,
228        ttl: i64,
229    },
230    LeaseRevoke {
231        id: i64,
232    },
233    LeaseKeepAlive {
234        id: i64,
235    },
236}
237
/// Result of applying an `AstraWriteRequest` (the `R` type of `AstraTypeConfig`).
///
/// Every variant except `PutBatch` and `Empty` carries a store `revision`. `LeaseTtl` and
/// `LeaseLeases` have no matching `AstraWriteRequest` variant in this file. NOTE(review): they
/// are presumably produced by paths elsewhere — confirm.
238#[derive(Debug, Clone, Serialize, Deserialize)]
239pub enum AstraWriteResponse {
240    Put {
241        revision: i64,
242        prev: Option<ValueEntry>,
243    },
    /// One result per op of a batched put request.
244    PutBatch {
245        results: Vec<AstraBatchPutResult>,
246    },
247    Delete {
248        revision: i64,
249        deleted: i64,
250        prev_kvs: Vec<(Vec<u8>, ValueEntry)>,
251    },
252    Txn {
253        revision: i64,
254        succeeded: bool,
255        responses: Vec<AstraTxnOpResponse>,
256    },
257    Compact {
258        revision: i64,
259        compact_revision: i64,
260    },
261    LeaseGrant {
262        revision: i64,
263        id: i64,
264        ttl: i64,
265    },
266    LeaseRevoke {
267        revision: i64,
268        deleted: i64,
269    },
270    LeaseKeepAlive {
271        revision: i64,
272        id: i64,
273        ttl: i64,
274    },
275    LeaseTtl {
276        revision: i64,
277        id: i64,
278        ttl: i64,
279        granted_ttl: i64,
280        keys: Vec<Vec<u8>>,
281    },
282    LeaseLeases {
283        revision: i64,
284        leases: Vec<i64>,
285    },
    /// No payload.
286    Empty,
287}
288
/// Tuning knobs for the batched WAL writer of `AstraLogStore`.
///
/// `AstraLogStore::open` clamps `max_batch_requests`, `pending_limit` and
/// `low_concurrency_threshold` to at least 1, and `max_batch_bytes` to at least 4 KiB.
289#[derive(Debug, Clone)]
290pub struct WalBatchConfig {
    /// Maximum number of queued writes drained into one flush.
291    pub max_batch_requests: usize,
    /// Payload byte budget for one flush. The first drained item is always taken, even if larger.
292    pub max_batch_bytes: usize,
    /// Linger before flushing a partially filled batch while enough puts are in flight.
293    pub max_linger: Duration,
    /// Below this many in-flight puts, `low_linger` replaces `max_linger`.
294    pub low_concurrency_threshold: usize,
    /// Linger under low concurrency. `Duration::ZERO` skips the wait entirely.
295    pub low_linger: Duration,
    /// Not read by the code visible here. NOTE(review): confirm where it is used.
296    pub channel_capacity: usize,
    /// Queue limit read by `enqueue_payload`, presumably for backpressure — TODO confirm.
297    pub pending_limit: usize,
    /// Step by which the WAL file is preallocated / grown.
298    pub segment_bytes: u64,
    /// Requested I/O engine. Only POSIX is implemented; `IoUring` falls back to it with a warning.
299    pub io_engine: WalIoEngine,
300}
301
302impl Default for WalBatchConfig {
303    fn default() -> Self {
304        Self {
305            max_batch_requests: 1_000,
306            max_batch_bytes: 8 * 1024 * 1024,
307            max_linger: Duration::from_millis(2),
308            low_concurrency_threshold: 5,
309            low_linger: Duration::ZERO,
310            channel_capacity: 8_192,
311            pending_limit: 2_000,
312            segment_bytes: 64 * 1024 * 1024,
313            io_engine: WalIoEngine::Auto,
314        }
315    }
316}
317
/// Validated openraft configuration for this node, built by `RaftBootstrap::new`.
318#[derive(Debug, Clone)]
319pub struct RaftBootstrap {
320    pub config: Arc<openraft::Config>,
321}
322
323impl RaftBootstrap {
324    pub fn new(cfg: &AstraConfig) -> anyhow::Result<Self> {
325        let cfg = openraft::Config {
326            cluster_name: "astra".to_string(),
327            election_timeout_min: cfg.raft_election_timeout_min_ms,
328            election_timeout_max: cfg.raft_election_timeout_max_ms,
329            heartbeat_interval: cfg.raft_heartbeat_interval_ms,
330            snapshot_max_chunk_size: 2 * 1024 * 1024,
331            max_payload_entries: cfg.raft_max_payload_entries,
332            snapshot_policy: openraft::SnapshotPolicy::LogsSinceLast(512),
333            replication_lag_threshold: cfg.raft_replication_lag_threshold,
334            max_in_snapshot_log_to_keep: 64,
335            purge_batch_size: 256,
336            ..Default::default()
337        };
338
339        let cfg = cfg.validate()?;
340        Ok(Self {
341            config: Arc::new(cfg),
342        })
343    }
344}
345
/// One mutation of the raft log store, as persisted in the WAL.
///
/// Serialized as an internally tagged object whose `"kind"` field holds the snake_case variant
/// name (`"append"`, `"truncate"`, `"purge"`, `"vote"`, `"committed"`). Replay presumably
/// re-applies records in order to rebuild `LogStoreState` — confirm in `replay_wal`.
346#[derive(Debug, Clone, Serialize, Deserialize)]
347#[serde(tag = "kind", rename_all = "snake_case")]
348enum DurableRecord {
    /// Entries appended to the log.
349    Append {
350        entries: Vec<Entry<AstraTypeConfig>>,
351    },
    /// Log truncated at `since`. Presumably mirrors openraft `truncate` (inclusive) — confirm.
352    Truncate {
353        since: LogId<u64>,
354    },
    /// Log purged up to `upto`. Presumably mirrors openraft `purge` (inclusive) — confirm.
355    Purge {
356        upto: LogId<u64>,
357    },
    /// Persisted vote.
358    Vote {
359        vote: Vote<u64>,
360    },
    /// Persisted committed log id (may be `None`).
361    Committed {
362        committed: Option<LogId<u64>>,
363    },
364}
365
366#[derive(Debug)]
367struct LogStoreState {
368    logs: BTreeMap<u64, Entry<AstraTypeConfig>>,
369    last_purged_log_id: Option<LogId<u64>>,
370    vote: Option<Vote<u64>>,
371    committed: Option<LogId<u64>>,
372}
373
374impl Default for LogStoreState {
375    fn default() -> Self {
376        Self {
377            logs: BTreeMap::new(),
378            last_purged_log_id: None,
379            vote: None,
380            committed: None,
381        }
382    }
383}
384
/// Rounds `value` up to the nearest multiple of `align`. Values already aligned are returned
/// unchanged.
///
/// # Panics
/// Panics if `align` is 0, and on overflow when overflow checks are enabled.
fn align_up(value: usize, align: usize) -> usize {
    value.next_multiple_of(align)
}
392
/// Current wall-clock time in microseconds since the Unix epoch.
///
/// Returns 0 when the system clock reads earlier than the epoch. The `u128` microsecond count
/// is narrowed to `u64` with `as`.
fn now_micros() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_micros() as u64,
        Err(_) => 0,
    }
}
399
400fn encode_wal_frame(payload: &[u8], block_size: usize) -> Vec<u8> {
401    let mut header = Vec::with_capacity(8);
402    header.extend_from_slice(b"ASTR");
403    header.extend_from_slice(&(payload.len() as u32).to_le_bytes());
404
405    let total = header.len() + payload.len();
406    let padded = align_up(total, block_size);
407    let mut out = Vec::with_capacity(padded);
408    out.extend_from_slice(&header);
409    out.extend_from_slice(payload);
410    out.resize(padded, 0_u8);
411    out
412}
413
/// WAL file written with positional writes at an explicit offset.
414#[derive(Debug)]
415struct PosixBlockDevice {
416    file: std::fs::File,
    /// Byte position where the next frame is written.
417    offset: u64,
    /// Bytes allocated for the file so far (its length after preallocation).
418    allocated: u64,
    /// Step by which the file is preallocated / grown.
419    segment_bytes: u64,
    /// Frame alignment; `open` sets it to 4096.
420    block_size: usize,
421}
422
423impl PosixBlockDevice {
424    fn open(path: &Path, offset: u64, segment_bytes: u64) -> std::io::Result<Self> {
425        std::fs::create_dir_all(path.parent().unwrap_or_else(|| Path::new(".")))?;
426
427        let file = std::fs::OpenOptions::new()
428            .create(true)
429            .read(true)
430            .append(false)
431            .write(true)
432            .open(path)?;
433
434        let mut this = Self {
435            allocated: file.metadata()?.len(),
436            file,
437            offset,
438            segment_bytes,
439            block_size: 4096,
440        };
441
442        if this.allocated == 0 {
443            this.preallocate(this.segment_bytes)?;
444        }
445        this.ensure_capacity(this.offset)?;
446        Ok(this)
447    }
448
449    fn preallocate(&mut self, bytes: u64) -> std::io::Result<()> {
450        if bytes == 0 {
451            return Ok(());
452        }
453        #[cfg(target_os = "linux")]
454        {
455            use std::os::unix::io::AsRawFd;
456            let fd = self.file.as_raw_fd();
457            let rc = unsafe {
458                libc::fallocate(fd, 0, self.allocated as libc::off_t, bytes as libc::off_t)
459            };
460            if rc != 0 {
461                return Err(std::io::Error::last_os_error());
462            }
463            self.allocated = self.allocated.saturating_add(bytes);
464            info!(
465                allocated_bytes = self.allocated,
466                segment_bytes = self.segment_bytes,
467                "wal fallocate preallocation complete"
468            );
469            return Ok(());
470        }
471
472        #[cfg(not(target_os = "linux"))]
473        {
474            let next = self.allocated.saturating_add(bytes);
475            self.file.set_len(next)?;
476            self.allocated = next;
477            return Ok(());
478        }
479    }
480
481    fn ensure_capacity(&mut self, required_end: u64) -> std::io::Result<()> {
482        while required_end > self.allocated {
483            self.preallocate(self.segment_bytes)?;
484        }
485        Ok(())
486    }
487
488    fn append_payload(&mut self, payload: &[u8]) -> std::io::Result<()> {
489        use std::os::unix::io::AsRawFd;
490
491        let frame = encode_wal_frame(payload, self.block_size);
492        let required_end = self.offset.saturating_add(frame.len() as u64);
493        self.ensure_capacity(required_end)?;
494
495        let fd = self.file.as_raw_fd();
496        let wrote = unsafe {
497            libc::pwrite(
498                fd,
499                frame.as_ptr() as *const libc::c_void,
500                frame.len(),
501                self.offset as libc::off_t,
502            )
503        };
504        if wrote < 0 || wrote as usize != frame.len() {
505            return Err(std::io::Error::last_os_error());
506        }
507        self.offset = required_end;
508        Ok(())
509    }
510
511    fn sync_data(&mut self) -> std::io::Result<()> {
512        #[cfg(target_os = "linux")]
513        {
514            use std::os::unix::io::AsRawFd;
515            let fd = self.file.as_raw_fd();
516            let rc = unsafe { libc::fdatasync(fd) };
517            if rc != 0 {
518                return Err(std::io::Error::last_os_error());
519            }
520            return Ok(());
521        }
522
523        #[cfg(not(target_os = "linux"))]
524        {
525            self.file.sync_data()
526        }
527    }
528}
529
/// WAL backend chosen when the store is opened. Only a POSIX backend exists.
530#[derive(Debug)]
531enum WalDevice {
532    Posix(PosixBlockDevice),
533}
534
535impl WalDevice {
536    fn open(
537        path: &Path,
538        offset: u64,
539        cfg: &WalBatchConfig,
540    ) -> std::io::Result<(Self, &'static str)> {
541        if cfg.io_engine == WalIoEngine::IoUring {
542            warn!(
543                "io_uring requested but direct append path runs on tokio runtime; using posix wal engine"
544            );
545        }
546
547        let d = PosixBlockDevice::open(path, offset, cfg.segment_bytes)?;
548        info!("wal io engine selected: posix");
549        Ok((WalDevice::Posix(d), "posix"))
550    }
551
552    fn append_payload(&mut self, payload: &[u8]) -> std::io::Result<()> {
553        match self {
554            WalDevice::Posix(d) => d.append_payload(payload),
555        }
556    }
557
558    fn sync_data(&mut self) -> std::io::Result<()> {
559        match self {
560            WalDevice::Posix(d) => d.sync_data(),
561        }
562    }
563}
564
/// Who to notify once a queued WAL write has been synced or has failed.
///
/// `Append` reports back to openraft through its `LogFlushed` callback. `Waiter` answers an
/// internal caller over a oneshot channel.
565enum WalWriteCompletion {
566    Append(LogFlushed<AstraTypeConfig>),
567    Waiter(oneshot::Sender<std::io::Result<()>>),
568}
569
570impl WalWriteCompletion {
571    fn complete_ok(self) {
572        match self {
573            WalWriteCompletion::Append(callback) => callback.log_io_completed(Ok(())),
574            WalWriteCompletion::Waiter(tx) => {
575                let _ = tx.send(Ok(()));
576            }
577        }
578    }
579
580    fn complete_err(self, err_msg: &str) {
581        match self {
582            WalWriteCompletion::Append(callback) => {
583                callback.log_io_completed(Err(std::io::Error::other(err_msg.to_string())))
584            }
585            WalWriteCompletion::Waiter(tx) => {
586                let _ = tx.send(Err(std::io::Error::other(err_msg.to_string())));
587            }
588        }
589    }
590}
591
/// One pending WAL write waiting in `WalQueueState::queue`.
592struct WalQueueItem {
    /// Encoded bytes appended to the WAL (concatenated with other items of the same flush).
593    payload: Vec<u8>,
    /// Entry count reported in flush logs.
594    append_entries: usize,
    /// Optional batch id / submit timestamp, used only for flush latency logging.
595    timeline_batch_id: Option<u64>,
596    timeline_submit_ts_micros: Option<u64>,
    /// Fired with the flush outcome.
597    completion: WalWriteCompletion,
598}
599
/// Batch id and submit timestamp recovered from a batched put entry. The timestamp is presumably
/// in microseconds since the Unix epoch, since flush logs compare it with `now_micros()`.
600#[derive(Debug, Clone, Copy)]
601struct TimelineMarker {
602    batch_id: u64,
603    submit_ts_micros: u64,
604}
605
/// Shared state of the WAL write queue, kept behind an async mutex.
606#[derive(Default)]
607struct WalQueueState {
608    queue: VecDeque<WalQueueItem>,
    /// Sum of `payload.len()` over `queue`.
609    queued_bytes: usize,
    /// Cleared by the flush worker when it finds the queue empty and exits. Presumably set by
    /// the enqueue path when it starts a worker — confirm.
610    flushing: bool,
611}
612
/// Raft log store backed by a single WAL file (`unified-raft.wal`).
///
/// Log state lives in memory behind `inner`. Durable records are queued in `wal_queue`, and a
/// background worker appends and syncs them to `wal_device` in batches.
613pub struct AstraLogStore {
614    inner: Arc<RwLock<LogStoreState>>,
    // std mutex: the device is only used inside `spawn_blocking` closures.
615    wal_device: Arc<Mutex<WalDevice>>,
616    wal_cfg: WalBatchConfig,
617    wal_queue: Arc<AsyncMutex<WalQueueState>>,
    // Notified by the flush worker after each drain.
618    wal_queue_space: Arc<Notify>,
    // Presumably track throttled persistence of the committed log id — confirm.
619    last_flushed_committed: Option<LogId<u64>>,
620    last_committed_flush_at: Instant,
    // Timeline markers of batched puts, keyed by log index.
621    timeline_by_log_index: BTreeMap<u64, TimelineMarker>,
622}
623
/// Cloneable read handle over log state of the same type `AstraLogStore` keeps in `inner`.
/// It is presumably built by sharing that `Arc` — confirm.
624#[derive(Debug, Clone)]
625pub struct AstraLogReader {
626    inner: Arc<RwLock<LogStoreState>>,
627}
628
629impl AstraLogStore {
    // NOTE(review): these constants are used outside this view. Given the names and the
    // `last_flushed_committed` / `last_committed_flush_at` fields, they presumably throttle how
    // often the committed log id is persisted (every 25 ms or every 64 indexes) — confirm.
630    const COMMITTED_FLUSH_INTERVAL: Duration = Duration::from_millis(25);
631    const COMMITTED_FLUSH_STEP: u64 = 64;
632
633    pub async fn open(data_dir: &Path, mut cfg: WalBatchConfig) -> Result<Self, StorageError<u64>> {
634        cfg.max_batch_requests = cfg.max_batch_requests.max(1);
635        cfg.max_batch_bytes = cfg.max_batch_bytes.max(4 * 1024);
636        cfg.pending_limit = cfg.pending_limit.max(1);
637        cfg.low_concurrency_threshold = cfg.low_concurrency_threshold.max(1);
638
639        std::fs::create_dir_all(data_dir).map_err(|e| {
640            StorageError::from_io_error(
641                openraft::ErrorSubject::Store,
642                openraft::ErrorVerb::Write,
643                e,
644            )
645        })?;
646
647        let wal_path = data_dir.join("unified-raft.wal");
648        let mut state = LogStoreState::default();
649        let wal_offset = replay_wal(&wal_path, &mut state)?;
650        let (device, engine) = WalDevice::open(&wal_path, wal_offset, &cfg).map_err(|e| {
651            StorageError::from_io_error(
652                openraft::ErrorSubject::Store,
653                openraft::ErrorVerb::Write,
654                e,
655            )
656        })?;
657
658        let last_flushed_committed = state.committed.clone();
659
660        let inner = Arc::new(RwLock::new(state));
661
662        info!(
663            wal_path = %data_dir.join("unified-raft.wal").display(),
664            max_batch_requests = cfg.max_batch_requests,
665            max_batch_bytes = cfg.max_batch_bytes,
666            max_linger_us = cfg.max_linger.as_micros(),
667            low_concurrency_threshold = cfg.low_concurrency_threshold,
668            low_linger_us = cfg.low_linger.as_micros(),
669            pending_limit = cfg.pending_limit,
670            segment_bytes = cfg.segment_bytes,
671            io_engine = %cfg.io_engine.as_str(),
672            selected_engine = engine,
673            "wal direct append path initialized"
674        );
675
676        let mut timeline_by_log_index = BTreeMap::new();
677        {
678            let guard = inner.read().await;
679            for (log_index, entry) in &guard.logs {
680                if let EntryPayload::Normal(req) = &entry.payload {
681                    match req {
682                        AstraWriteRequest::PutBatchRef {
683                            batch_id,
684                            submit_ts_micros,
685                            ..
686                        }
687                        | AstraWriteRequest::PutBatchTokenized {
688                            batch_id,
689                            submit_ts_micros,
690                            ..
691                        } if *batch_id > 0 && *submit_ts_micros > 0 => {
692                            timeline_by_log_index.insert(
693                                *log_index,
694                                TimelineMarker {
695                                    batch_id: *batch_id,
696                                    submit_ts_micros: *submit_ts_micros,
697                                },
698                            );
699                        }
700                        _ => {}
701                    }
702                }
703            }
704        }
705
706        let store = Self {
707            inner,
708            wal_device: Arc::new(Mutex::new(device)),
709            wal_cfg: cfg,
710            wal_queue: Arc::new(AsyncMutex::new(WalQueueState::default())),
711            wal_queue_space: Arc::new(Notify::new()),
712            last_flushed_committed,
713            last_committed_flush_at: Instant::now(),
714            timeline_by_log_index,
715        };
716        metrics::set_wal_queue_depth(0);
717        metrics::set_wal_queue_bytes(0);
718        Ok(store)
719    }
720
721    pub async fn wal_queue_snapshot(&self) -> (usize, usize) {
722        let guard = self.wal_queue.lock().await;
723        (guard.queue.len(), guard.queued_bytes)
724    }
725
726    fn encode_records_payload(records: Vec<DurableRecord>) -> Result<Vec<u8>, StorageError<u64>> {
727        encode_durable_records(records).map_err(|e| {
728            StorageError::from_io_error(openraft::ErrorSubject::Logs, openraft::ErrorVerb::Write, e)
729        })
730    }
731
732    fn logical_write_ops(entry: &Entry<AstraTypeConfig>) -> usize {
733        match &entry.payload {
734            EntryPayload::Normal(AstraWriteRequest::PutBatch { ops }) => ops.len().max(1),
735            EntryPayload::Normal(AstraWriteRequest::PutBatchRef { ops, .. }) => ops.len().max(1),
736            EntryPayload::Normal(AstraWriteRequest::PutBatchTokenized { ops, .. }) => {
737                ops.len().max(1)
738            }
739            EntryPayload::Normal(_) => 1,
740            EntryPayload::Blank | EntryPayload::Membership(_) => 0,
741        }
742    }
743
744    fn extract_timeline_meta(entries: &[Entry<AstraTypeConfig>]) -> (Option<u64>, Option<u64>) {
745        for entry in entries {
746            if let EntryPayload::Normal(req) = &entry.payload {
747                match req {
748                    AstraWriteRequest::PutBatchRef {
749                        batch_id,
750                        submit_ts_micros,
751                        ..
752                    }
753                    | AstraWriteRequest::PutBatchTokenized {
754                        batch_id,
755                        submit_ts_micros,
756                        ..
757                    } => {
758                        return (
759                            if *batch_id > 0 { Some(*batch_id) } else { None },
760                            if *submit_ts_micros > 0 {
761                                Some(*submit_ts_micros)
762                            } else {
763                                None
764                            },
765                        );
766                    }
767                    _ => {}
768                }
769            }
770        }
771        (None, None)
772    }
773
774    fn extract_timeline_markers(entries: &[Entry<AstraTypeConfig>]) -> Vec<(u64, TimelineMarker)> {
775        let mut out = Vec::new();
776        for entry in entries {
777            if let EntryPayload::Normal(req) = &entry.payload {
778                match req {
779                    AstraWriteRequest::PutBatchRef {
780                        batch_id,
781                        submit_ts_micros,
782                        ..
783                    }
784                    | AstraWriteRequest::PutBatchTokenized {
785                        batch_id,
786                        submit_ts_micros,
787                        ..
788                    } if *batch_id > 0 && *submit_ts_micros > 0 => {
789                        out.push((
790                            entry.log_id.index,
791                            TimelineMarker {
792                                batch_id: *batch_id,
793                                submit_ts_micros: *submit_ts_micros,
794                            },
795                        ));
796                    }
797                    _ => {}
798                }
799            }
800        }
801        out
802    }
803
804    async fn await_wal_flush(
805        rx: oneshot::Receiver<std::io::Result<()>>,
806    ) -> Result<(), StorageError<u64>> {
807        match rx.await {
808            Ok(result) => result.map_err(|e| {
809                StorageError::from_io_error(
810                    openraft::ErrorSubject::Logs,
811                    openraft::ErrorVerb::Write,
812                    e,
813                )
814            }),
815            Err(e) => Err(StorageError::from_io_error(
816                openraft::ErrorSubject::Logs,
817                openraft::ErrorVerb::Write,
818                std::io::Error::other(format!("wal flush callback dropped: {e}")),
819            )),
820        }
821    }
822
823    fn spawn_wal_flush_worker(&self) {
824        let wal_device = self.wal_device.clone();
825        let wal_cfg = self.wal_cfg.clone();
826        let wal_queue = self.wal_queue.clone();
827        let wal_queue_space = self.wal_queue_space.clone();
828
829        tokio::spawn(async move {
830            loop {
831                let put_inflight = metrics::current_put_inflight_requests();
832                let mut effective_linger = wal_cfg.max_linger;
833                if put_inflight < wal_cfg.low_concurrency_threshold as u64 {
834                    effective_linger = wal_cfg.low_linger;
835                }
836                metrics::set_wal_effective_linger_us(effective_linger.as_micros() as u64);
837
838                if effective_linger > Duration::ZERO {
839                    let should_linger = {
840                        let guard = wal_queue.lock().await;
841                        !guard.queue.is_empty()
842                            && guard.queue.len() < wal_cfg.max_batch_requests
843                            && guard.queued_bytes < wal_cfg.max_batch_bytes
844                    };
845                    if should_linger {
846                        tokio::time::sleep(effective_linger).await;
847                    }
848                }
849
850                let batch = {
851                    let mut guard = wal_queue.lock().await;
852                    if guard.queue.is_empty() {
853                        guard.flushing = false;
854                        metrics::set_wal_queue_depth(0);
855                        metrics::set_wal_queue_bytes(0);
856                        wal_queue_space.notify_waiters();
857                        break;
858                    }
859
860                    let mut drained = Vec::new();
861                    let mut batch_bytes = 0_usize;
862                    while let Some(front) = guard.queue.front() {
863                        let front_bytes = front.payload.len();
864                        let reached_request_limit = drained.len() >= wal_cfg.max_batch_requests;
865                        let reached_byte_limit = !drained.is_empty()
866                            && batch_bytes.saturating_add(front_bytes) > wal_cfg.max_batch_bytes;
867                        if reached_request_limit || reached_byte_limit {
868                            break;
869                        }
870
871                        let item = guard.queue.pop_front().expect("queue front exists");
872                        guard.queued_bytes = guard.queued_bytes.saturating_sub(item.payload.len());
873                        batch_bytes = batch_bytes.saturating_add(item.payload.len());
874                        drained.push(item);
875                    }
876
877                    if drained.is_empty() {
878                        if let Some(item) = guard.queue.pop_front() {
879                            guard.queued_bytes =
880                                guard.queued_bytes.saturating_sub(item.payload.len());
881                            drained.push(item);
882                        }
883                    }
884                    metrics::set_wal_queue_depth(guard.queue.len());
885                    metrics::set_wal_queue_bytes(guard.queued_bytes);
886                    wal_queue_space.notify_waiters();
887                    drained
888                };
889
890                if batch.is_empty() {
891                    continue;
892                }
893
894                let requests = batch.len();
895                let append_entries = batch.iter().map(|item| item.append_entries).sum::<usize>();
896                let payload_bytes = batch.iter().map(|item| item.payload.len()).sum::<usize>();
897                let timeline_batch_id = batch.iter().find_map(|item| item.timeline_batch_id);
898                let timeline_submit_ts_micros =
899                    batch.iter().find_map(|item| item.timeline_submit_ts_micros);
900                let mut payload = Vec::with_capacity(payload_bytes);
901                for item in &batch {
902                    payload.extend_from_slice(&item.payload);
903                }
904
905                let started = Instant::now();
906                let io_result = tokio::task::spawn_blocking({
907                    let wal_device = wal_device.clone();
908                    move || -> std::io::Result<()> {
909                        let mut wal_device = wal_device
910                            .lock()
911                            .map_err(|_| std::io::Error::other("wal device lock poisoned"))?;
912                        wal_device.append_payload(&payload)?;
913                        wal_device.sync_data()?;
914                        Ok(())
915                    }
916                })
917                .await;
918
919                let wal_sync_duration_ms = started.elapsed().as_millis() as u64;
920
921                let flush_outcome = match io_result {
922                    Ok(result) => result,
923                    Err(e) => Err(std::io::Error::other(format!(
924                        "wal sync worker join error: {e}"
925                    ))),
926                };
927
928                match flush_outcome {
929                    Ok(()) => {
930                        let since_submit_ms = timeline_submit_ts_micros
931                            .and_then(|submit| now_micros().checked_sub(submit))
932                            .map(|delta| delta / 1_000);
933                        if append_entries >= 50 {
934                            info!(
935                                requests,
936                                append_entries,
937                                payload_bytes,
938                                wal_sync_duration_ms,
939                                batch_id = timeline_batch_id,
940                                since_submit_ms,
941                                fdatasync_calls = 1,
942                                "wal vectorized append flush complete"
943                            );
944                        } else {
945                            debug!(
946                                requests,
947                                append_entries,
948                                payload_bytes,
949                                wal_sync_duration_ms,
950                                batch_id = timeline_batch_id,
951                                since_submit_ms,
952                                fdatasync_calls = 1,
953                                "wal direct flush complete"
954                            );
955                        }
956                        for item in batch {
957                            item.completion.complete_ok();
958                        }
959                    }
960                    Err(err) => {
961                        let err_msg = err.to_string();
962                        warn!(
963                            requests,
964                            append_entries,
965                            payload_bytes,
966                            wal_sync_duration_ms,
967                            error = %err_msg,
968                            "wal batch flush failed"
969                        );
970                        for item in batch {
971                            item.completion.complete_err(&err_msg);
972                        }
973                    }
974                }
975            }
976        });
977    }
978
979    async fn enqueue_payload(
980        &self,
981        payload: Vec<u8>,
982        append_entries: usize,
983        timeline_batch_id: Option<u64>,
984        timeline_submit_ts_micros: Option<u64>,
985        completion: WalWriteCompletion,
986    ) -> Result<(), StorageError<u64>> {
987        if payload.is_empty() {
988            completion.complete_ok();
989            return Ok(());
990        }
991
992        let payload_len = payload.len();
993        let pending_limit = self.wal_cfg.pending_limit.max(1);
994        let mut maybe_item = Some(WalQueueItem {
995            payload,
996            append_entries,
997            timeline_batch_id,
998            timeline_submit_ts_micros,
999            completion,
1000        });
1001
1002        loop {
1003            let mut start_worker = false;
1004            {
1005                let mut guard = self.wal_queue.lock().await;
1006                if guard.queue.len() < pending_limit {
1007                    let item = maybe_item.take().expect("wal queue item available");
1008                    guard.queued_bytes = guard.queued_bytes.saturating_add(payload_len);
1009                    guard.queue.push_back(item);
1010                    metrics::set_wal_queue_depth(guard.queue.len());
1011                    metrics::set_wal_queue_bytes(guard.queued_bytes);
1012
1013                    if !guard.flushing {
1014                        guard.flushing = true;
1015                        start_worker = true;
1016                    }
1017                }
1018            }
1019
1020            if maybe_item.is_none() {
1021                if start_worker {
1022                    self.spawn_wal_flush_worker();
1023                }
1024                return Ok(());
1025            }
1026
1027            self.wal_queue_space.notified().await;
1028        }
1029    }
1030
1031    async fn append_records(
1032        &self,
1033        records: Vec<DurableRecord>,
1034        append_entries: usize,
1035    ) -> Result<(), StorageError<u64>> {
1036        let payload = Self::encode_records_payload(records)?;
1037        let (tx, rx) = oneshot::channel::<std::io::Result<()>>();
1038        self.enqueue_payload(
1039            payload,
1040            append_entries,
1041            None,
1042            None,
1043            WalWriteCompletion::Waiter(tx),
1044        )
1045        .await?;
1046        Self::await_wal_flush(rx).await
1047    }
1048
1049    async fn append_records_with_callback(
1050        &self,
1051        records: Vec<DurableRecord>,
1052        append_entries: usize,
1053        timeline_batch_id: Option<u64>,
1054        timeline_submit_ts_micros: Option<u64>,
1055        callback: LogFlushed<AstraTypeConfig>,
1056    ) -> Result<(), StorageError<u64>> {
1057        let payload = match Self::encode_records_payload(records) {
1058            Ok(payload) => payload,
1059            Err(err) => {
1060                callback.log_io_completed(Err(std::io::Error::other(err.to_string())));
1061                return Err(err);
1062            }
1063        };
1064
1065        if let Err(err) = self
1066            .enqueue_payload(
1067                payload,
1068                append_entries,
1069                timeline_batch_id,
1070                timeline_submit_ts_micros,
1071                WalWriteCompletion::Append(callback),
1072            )
1073            .await
1074        {
1075            // Callback is delivered on flush completion; if enqueue itself fails, report here.
1076            // This path should be rare and indicates shutdown or internal queue failure.
1077            warn!(error = %err, "wal enqueue failed for append callback path");
1078            return Err(err);
1079        }
1080        Ok(())
1081    }
1082}
1083
1084fn replay_wal(path: &Path, state: &mut LogStoreState) -> Result<u64, StorageError<u64>> {
1085    if !path.exists() {
1086        return Ok(0);
1087    }
1088
1089    let bytes = std::fs::read(path).map_err(|e| {
1090        StorageError::from_io_error(openraft::ErrorSubject::Store, openraft::ErrorVerb::Read, e)
1091    })?;
1092
1093    let mut offset = 0_usize;
1094    let mut durable_end = 0_usize;
1095    while offset + 8 <= bytes.len() {
1096        let header = &bytes[offset..offset + 8];
1097        if header[..4] == [0, 0, 0, 0] {
1098            offset = align_up(offset + 1, 4096);
1099            continue;
1100        }
1101
1102        if &header[..4] != b"ASTR" {
1103            break;
1104        }
1105
1106        let payload_len = u32::from_le_bytes(header[4..8].try_into().expect("slice len")) as usize;
1107        let start = offset + 8;
1108        let end = start + payload_len;
1109        if end > bytes.len() {
1110            break;
1111        }
1112
1113        let payload = &bytes[start..end];
1114        let mut cursor = Cursor::new(payload);
1115        while (cursor.position() as usize) < payload.len() {
1116            let mut len_buf = [0_u8; 4];
1117            if cursor.read_exact(&mut len_buf).is_err() {
1118                break;
1119            }
1120            let rec_len = u32::from_le_bytes(len_buf) as usize;
1121            if rec_len == 0 {
1122                break;
1123            }
1124
1125            let mut rec_buf = vec![0_u8; rec_len];
1126            if cursor.read_exact(&mut rec_buf).is_err() {
1127                break;
1128            }
1129
1130            let rec: DurableRecord = match bincode::deserialize(&rec_buf) {
1131                Ok(rec) => rec,
1132                Err(_) => {
1133                    // Treat undecodable trailing data as EOF to keep startup resilient
1134                    // across WAL record-format changes.
1135                    break;
1136                }
1137            };
1138
1139            apply_durable_record(state, rec);
1140        }
1141
1142        offset = align_up(end, 4096);
1143        durable_end = offset;
1144    }
1145
1146    Ok(durable_end as u64)
1147}
1148
1149fn apply_durable_record(state: &mut LogStoreState, rec: DurableRecord) {
1150    match rec {
1151        DurableRecord::Append { entries } => {
1152            for entry in entries {
1153                state.logs.insert(entry.log_id.index, entry);
1154            }
1155        }
1156        DurableRecord::Truncate { since } => {
1157            let keys = state
1158                .logs
1159                .range(since.index..)
1160                .map(|(k, _)| *k)
1161                .collect::<Vec<_>>();
1162            for k in keys {
1163                state.logs.remove(&k);
1164            }
1165        }
1166        DurableRecord::Purge { upto } => {
1167            let keys = state
1168                .logs
1169                .range(..=upto.index)
1170                .map(|(k, _)| *k)
1171                .collect::<Vec<_>>();
1172            for k in keys {
1173                state.logs.remove(&k);
1174            }
1175            state.last_purged_log_id = Some(upto);
1176        }
1177        DurableRecord::Vote { vote } => {
1178            state.vote = Some(vote);
1179        }
1180        DurableRecord::Committed { committed } => {
1181            state.committed = committed;
1182        }
1183    }
1184}
1185
1186fn encode_durable_records(records: Vec<DurableRecord>) -> std::io::Result<Vec<u8>> {
1187    let mut encoded = Vec::new();
1188    for rec in records {
1189        let rec_bytes = bincode::serialize(&rec)
1190            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string()))?;
1191        encoded.extend_from_slice(&(rec_bytes.len() as u32).to_le_bytes());
1192        encoded.extend_from_slice(&rec_bytes);
1193    }
1194    Ok(encoded)
1195}
1196
1197impl RaftLogReader<AstraTypeConfig> for AstraLogReader {
1198    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + openraft::OptionalSend>(
1199        &mut self,
1200        range: RB,
1201    ) -> Result<Vec<Entry<AstraTypeConfig>>, StorageError<u64>> {
1202        let (start, end) = range_to_bounds(range);
1203        let guard = self.inner.read().await;
1204        let out = guard
1205            .logs
1206            .range(start..end)
1207            .map(|(_, v)| v.clone())
1208            .collect::<Vec<_>>();
1209        Ok(out)
1210    }
1211}
1212
1213impl RaftLogReader<AstraTypeConfig> for AstraLogStore {
1214    async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + openraft::OptionalSend>(
1215        &mut self,
1216        range: RB,
1217    ) -> Result<Vec<Entry<AstraTypeConfig>>, StorageError<u64>> {
1218        let mut reader = AstraLogReader {
1219            inner: self.inner.clone(),
1220        };
1221        reader.try_get_log_entries(range).await
1222    }
1223}
1224
/// Converts any `RangeBounds<u64>` into a half-open `(start, end)` pair.
///
/// An unbounded start becomes `0` and an unbounded end becomes `u64::MAX`.
/// An excluded start and an included end are shifted up by one using
/// saturating arithmetic, so `..=u64::MAX` maps to an end of `u64::MAX`.
fn range_to_bounds<RB: RangeBounds<u64>>(range: RB) -> (u64, u64) {
    let lower = match range.start_bound() {
        Bound::Unbounded => 0,
        Bound::Included(&first) => first,
        Bound::Excluded(&before) => before.saturating_add(1),
    };
    let upper = match range.end_bound() {
        Bound::Unbounded => u64::MAX,
        Bound::Excluded(&stop) => stop,
        Bound::Included(&last) => last.saturating_add(1),
    };
    (lower, upper)
}
1238
1239impl RaftLogStorage<AstraTypeConfig> for AstraLogStore {
1240    type LogReader = AstraLogReader;
1241
1242    async fn get_log_state(&mut self) -> Result<LogState<AstraTypeConfig>, StorageError<u64>> {
1243        let guard = self.inner.read().await;
1244        let last_log_id = guard
1245            .logs
1246            .values()
1247            .last()
1248            .map(|e| e.log_id.clone())
1249            .or_else(|| guard.last_purged_log_id.clone());
1250
1251        Ok(LogState {
1252            last_purged_log_id: guard.last_purged_log_id.clone(),
1253            last_log_id,
1254        })
1255    }
1256
1257    async fn get_log_reader(&mut self) -> Self::LogReader {
1258        AstraLogReader {
1259            inner: self.inner.clone(),
1260        }
1261    }
1262
1263    async fn save_vote(&mut self, vote: &Vote<u64>) -> Result<(), StorageError<u64>> {
1264        {
1265            let mut guard = self.inner.write().await;
1266            guard.vote = Some(vote.clone());
1267        }
1268        self.append_records(vec![DurableRecord::Vote { vote: vote.clone() }], 0)
1269            .await
1270    }
1271
1272    async fn read_vote(&mut self) -> Result<Option<Vote<u64>>, StorageError<u64>> {
1273        let guard = self.inner.read().await;
1274        Ok(guard.vote.clone())
1275    }
1276
1277    async fn save_committed(
1278        &mut self,
1279        committed: Option<LogId<u64>>,
1280    ) -> Result<(), StorageError<u64>> {
1281        {
1282            let mut guard = self.inner.write().await;
1283            guard.committed = committed.clone();
1284        }
1285
1286        if let Some(committed_log_id) = committed.as_ref() {
1287            if let Some((tracked_log_index, marker)) = self
1288                .timeline_by_log_index
1289                .range(..=committed_log_id.index)
1290                .next_back()
1291                .map(|(idx, marker)| (*idx, *marker))
1292            {
1293                let since_submit_ms = now_micros()
1294                    .checked_sub(marker.submit_ts_micros)
1295                    .map(|delta| delta / 1_000)
1296                    .unwrap_or(0);
1297                info!(
1298                    stage = "quorum_ack_commit_advanced",
1299                    batch_id = marker.batch_id,
1300                    committed_index = committed_log_id.index,
1301                    tracked_log_index,
1302                    since_submit_ms,
1303                    "raft timeline"
1304                );
1305                metrics::observe_put_quorum_ack_ms(since_submit_ms);
1306            }
1307
1308            let keep_from = committed_log_id.index.saturating_add(1);
1309            self.timeline_by_log_index = self.timeline_by_log_index.split_off(&keep_from);
1310        }
1311
1312        let should_flush = match (&self.last_flushed_committed, &committed) {
1313            (None, None) => false,
1314            (None, Some(_)) => true,
1315            (Some(_), None) => true,
1316            (Some(prev), Some(next)) => {
1317                let step_reached =
1318                    next.index >= prev.index.saturating_add(Self::COMMITTED_FLUSH_STEP);
1319                step_reached
1320                    || self.last_committed_flush_at.elapsed() >= Self::COMMITTED_FLUSH_INTERVAL
1321            }
1322        };
1323
1324        if should_flush {
1325            self.append_records(
1326                vec![DurableRecord::Committed {
1327                    committed: committed.clone(),
1328                }],
1329                0,
1330            )
1331            .await?;
1332            self.last_flushed_committed = committed;
1333            self.last_committed_flush_at = Instant::now();
1334        }
1335
1336        Ok(())
1337    }
1338
1339    async fn read_committed(&mut self) -> Result<Option<LogId<u64>>, StorageError<u64>> {
1340        let guard = self.inner.read().await;
1341        Ok(guard.committed.clone())
1342    }
1343
1344    async fn append<I>(
1345        &mut self,
1346        entries: I,
1347        callback: LogFlushed<AstraTypeConfig>,
1348    ) -> Result<(), StorageError<u64>>
1349    where
1350        I: IntoIterator<Item = Entry<AstraTypeConfig>> + openraft::OptionalSend,
1351        I::IntoIter: openraft::OptionalSend,
1352    {
1353        let entries_vec = entries.into_iter().collect::<Vec<_>>();
1354        if entries_vec.is_empty() {
1355            callback.log_io_completed(Ok(()));
1356            return Ok(());
1357        }
1358
1359        let timeline_markers = Self::extract_timeline_markers(&entries_vec);
1360
1361        {
1362            let mut guard = self.inner.write().await;
1363            for ent in &entries_vec {
1364                guard.logs.insert(ent.log_id.index, ent.clone());
1365            }
1366        }
1367        for (log_index, marker) in timeline_markers {
1368            self.timeline_by_log_index.insert(log_index, marker);
1369        }
1370
1371        let append_entries = entries_vec
1372            .iter()
1373            .map(Self::logical_write_ops)
1374            .sum::<usize>()
1375            .max(1);
1376        let (timeline_batch_id, timeline_submit_ts_micros) =
1377            Self::extract_timeline_meta(&entries_vec);
1378        if let Some(batch_id) = timeline_batch_id {
1379            let since_submit_ms = timeline_submit_ts_micros
1380                .and_then(|submit| now_micros().checked_sub(submit))
1381                .map(|delta| delta / 1_000);
1382            debug!(
1383                stage = "raft_log_append_in_memory",
1384                batch_id,
1385                append_entries,
1386                log_entries = entries_vec.len(),
1387                since_submit_ms,
1388                "raft timeline"
1389            );
1390        }
1391        self.append_records_with_callback(
1392            vec![DurableRecord::Append {
1393                entries: entries_vec,
1394            }],
1395            append_entries,
1396            timeline_batch_id,
1397            timeline_submit_ts_micros,
1398            callback,
1399        )
1400        .await
1401    }
1402
1403    async fn truncate(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
1404        {
1405            let mut guard = self.inner.write().await;
1406            let keys = guard
1407                .logs
1408                .range(log_id.index..)
1409                .map(|(k, _)| *k)
1410                .collect::<Vec<_>>();
1411            for k in keys {
1412                guard.logs.remove(&k);
1413            }
1414        }
1415        self.timeline_by_log_index = self.timeline_by_log_index.split_off(&log_id.index);
1416
1417        self.append_records(vec![DurableRecord::Truncate { since: log_id }], 0)
1418            .await
1419    }
1420
1421    async fn purge(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
1422        {
1423            let mut guard = self.inner.write().await;
1424            let keys = guard
1425                .logs
1426                .range(..=log_id.index)
1427                .map(|(k, _)| *k)
1428                .collect::<Vec<_>>();
1429            for k in keys {
1430                guard.logs.remove(&k);
1431            }
1432            guard.last_purged_log_id = Some(log_id.clone());
1433        }
1434        let keep_from = log_id.index.saturating_add(1);
1435        self.timeline_by_log_index = self.timeline_by_log_index.split_off(&keep_from);
1436
1437        self.append_records(vec![DurableRecord::Purge { upto: log_id }], 0)
1438            .await
1439    }
1440}
1441
/// A built snapshot cached in memory: its metadata plus the serialized
/// payload bytes (the bincode encoding produced by `build_snapshot`).
#[derive(Debug, Clone)]
struct SnapshotBlob {
    /// Snapshot metadata: last included log id, membership and snapshot id.
    meta: SnapshotMeta<u64, BasicNode>,
    /// Serialized snapshot payload.
    bytes: Vec<u8>,
}
1447
/// Serialized form of a state-machine snapshot (encoded with bincode in
/// `build_snapshot`).
///
/// bincode encodes struct fields positionally, in declaration order, so
/// reordering, inserting or removing fields changes the snapshot format and
/// breaks decoding of existing snapshots.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SnapshotPayloadV2 {
    /// Key-value store state captured via `KvStore::snapshot_state`.
    store: SnapshotState,
    /// Token dictionary entries as `(u32 key, bytes)` pairs, collected from a
    /// `HashMap`, so their order is unspecified.
    token_dict: Vec<(u32, Vec<u8>)>,
    /// Epoch of the token dictionary at capture time.
    token_dict_epoch: u64,
}
1454
/// Mutable state machine data shared behind a single `RwLock` by
/// `AstraStateMachine` and `AstraSnapshotBuilder`.
#[derive(Debug)]
struct AstraStateMachineShared {
    /// Handle to the key-value store that log entries are applied to.
    store: Arc<KvStore>,
    /// Id of the last applied log entry, if any.
    last_applied_log: Option<LogId<u64>>,
    /// Membership config last recorded by the state machine.
    last_membership: StoredMembership<u64, BasicNode>,
    /// Most recent snapshot produced by `build_snapshot`, if any.
    current_snapshot: Option<SnapshotBlob>,
    /// Token dictionary captured in snapshots. NOTE(review): presumably used to
    /// resolve value references in batched writes; confirm.
    token_dict: HashMap<u32, Vec<u8>>,
    /// Epoch of `token_dict`, also captured in snapshots.
    token_dict_epoch: u64,
}
1464
/// Raft state machine for Astra.
///
/// All state lives behind `Arc<RwLock<..>>`, so clones share the same
/// underlying state rather than copying it.
#[derive(Debug, Clone)]
pub struct AstraStateMachine {
    shared: Arc<RwLock<AstraStateMachineShared>>,
}
1469
1470impl AstraStateMachine {
1471    pub fn new(store: Arc<KvStore>) -> Self {
1472        Self {
1473            shared: Arc::new(RwLock::new(AstraStateMachineShared {
1474                store,
1475                last_applied_log: None,
1476                last_membership: StoredMembership::default(),
1477                current_snapshot: None,
1478                token_dict: HashMap::new(),
1479                token_dict_epoch: 0,
1480            })),
1481        }
1482    }
1483}
1484
1485fn txn_compare_i64(lhs: i64, rhs: i64, result: &AstraTxnCmpResult) -> bool {
1486    match result {
1487        AstraTxnCmpResult::Equal => lhs == rhs,
1488        AstraTxnCmpResult::Greater => lhs > rhs,
1489        AstraTxnCmpResult::Less => lhs < rhs,
1490        AstraTxnCmpResult::NotEqual => lhs != rhs,
1491    }
1492}
1493
1494fn txn_compare_bytes(lhs: &[u8], rhs: &[u8], result: &AstraTxnCmpResult) -> bool {
1495    match result {
1496        AstraTxnCmpResult::Equal => lhs == rhs,
1497        AstraTxnCmpResult::Greater => lhs > rhs,
1498        AstraTxnCmpResult::Less => lhs < rhs,
1499        AstraTxnCmpResult::NotEqual => lhs != rhs,
1500    }
1501}
1502
1503fn txn_ops_have_write(ops: &[AstraTxnOp]) -> bool {
1504    for op in ops {
1505        match op {
1506            AstraTxnOp::Put { .. } | AstraTxnOp::DeleteRange { .. } => return true,
1507            AstraTxnOp::Txn {
1508                success, failure, ..
1509            } => {
1510                if txn_ops_have_write(success) || txn_ops_have_write(failure) {
1511                    return true;
1512                }
1513            }
1514            AstraTxnOp::Range { .. } => {}
1515        }
1516    }
1517    false
1518}
1519
1520fn eval_txn_compare(store: &KvStore, cmp: &AstraTxnCompare) -> Result<bool, StoreError> {
1521    let out = store.range(&cmp.key, &cmp.range_end, 0, 0, false, false)?;
1522
1523    if out.kvs.is_empty() {
1524        return match (&cmp.target, &cmp.target_value) {
1525            (AstraTxnCmpTarget::Value, AstraTxnCmpValue::Bytes(rhs)) => {
1526                Ok(txn_compare_bytes(&[], rhs, &cmp.result))
1527            }
1528            (
1529                AstraTxnCmpTarget::Version
1530                | AstraTxnCmpTarget::CreateRevision
1531                | AstraTxnCmpTarget::ModRevision
1532                | AstraTxnCmpTarget::Lease,
1533                AstraTxnCmpValue::I64(rhs),
1534            ) => Ok(txn_compare_i64(0, *rhs, &cmp.result)),
1535            _ => Err(StoreError::InvalidArgument(
1536                "txn compare target and value type mismatch".to_string(),
1537            )),
1538        };
1539    }
1540
1541    for (_, kv) in out.kvs {
1542        let ok = match (&cmp.target, &cmp.target_value) {
1543            (AstraTxnCmpTarget::Version, AstraTxnCmpValue::I64(rhs)) => {
1544                txn_compare_i64(kv.version, *rhs, &cmp.result)
1545            }
1546            (AstraTxnCmpTarget::CreateRevision, AstraTxnCmpValue::I64(rhs)) => {
1547                txn_compare_i64(kv.create_revision, *rhs, &cmp.result)
1548            }
1549            (AstraTxnCmpTarget::ModRevision, AstraTxnCmpValue::I64(rhs)) => {
1550                txn_compare_i64(kv.mod_revision, *rhs, &cmp.result)
1551            }
1552            (AstraTxnCmpTarget::Lease, AstraTxnCmpValue::I64(rhs)) => {
1553                txn_compare_i64(kv.lease, *rhs, &cmp.result)
1554            }
1555            (AstraTxnCmpTarget::Value, AstraTxnCmpValue::Bytes(rhs)) => {
1556                txn_compare_bytes(&kv.value, rhs, &cmp.result)
1557            }
1558            _ => {
1559                return Err(StoreError::InvalidArgument(
1560                    "txn compare target and value type mismatch".to_string(),
1561                ));
1562            }
1563        };
1564        if !ok {
1565            return Ok(false);
1566        }
1567    }
1568
1569    Ok(true)
1570}
1571
1572fn eval_txn_compares(store: &KvStore, compares: &[AstraTxnCompare]) -> Result<bool, StoreError> {
1573    for cmp in compares {
1574        if !eval_txn_compare(store, cmp)? {
1575            return Ok(false);
1576        }
1577    }
1578    Ok(true)
1579}
1580
1581fn apply_txn_ops(
1582    store: &KvStore,
1583    ops: &[AstraTxnOp],
1584    forced_revision: Option<i64>,
1585) -> Result<Vec<AstraTxnOpResponse>, StoreError> {
1586    let mut out = Vec::with_capacity(ops.len());
1587
1588    for op in ops {
1589        match op {
1590            AstraTxnOp::Range {
1591                key,
1592                range_end,
1593                limit,
1594                revision,
1595                keys_only,
1596                count_only,
1597            } => {
1598                let RangeOutput {
1599                    revision,
1600                    count,
1601                    more,
1602                    kvs,
1603                } = store.range(key, range_end, *limit, *revision, *keys_only, *count_only)?;
1604                out.push(AstraTxnOpResponse::Range {
1605                    revision,
1606                    count,
1607                    more,
1608                    kvs,
1609                });
1610            }
1611            AstraTxnOp::Put {
1612                key,
1613                value,
1614                lease,
1615                ignore_value,
1616                ignore_lease,
1617                prev_kv,
1618            } => {
1619                let PutOutput {
1620                    revision,
1621                    prev,
1622                    current: _,
1623                } = if let Some(rev) = forced_revision {
1624                    store.apply_put_at_revision(
1625                        key.clone(),
1626                        value.clone(),
1627                        *lease,
1628                        *ignore_value,
1629                        *ignore_lease,
1630                        rev,
1631                    )?
1632                } else {
1633                    store.put(
1634                        key.clone(),
1635                        value.clone(),
1636                        *lease,
1637                        *ignore_value,
1638                        *ignore_lease,
1639                    )?
1640                };
1641                out.push(AstraTxnOpResponse::Put {
1642                    revision,
1643                    prev: if *prev_kv { prev } else { None },
1644                });
1645            }
1646            AstraTxnOp::DeleteRange {
1647                key,
1648                range_end,
1649                prev_kv,
1650            } => {
1651                let DeleteOutput {
1652                    revision,
1653                    deleted,
1654                    prev_kvs,
1655                } = if let Some(rev) = forced_revision {
1656                    store.apply_delete_at_revision(key, range_end, *prev_kv, rev)?
1657                } else {
1658                    store.delete_range(key, range_end, *prev_kv)?
1659                };
1660                out.push(AstraTxnOpResponse::Delete {
1661                    revision,
1662                    deleted,
1663                    prev_kvs,
1664                });
1665            }
1666            AstraTxnOp::Txn {
1667                compare,
1668                success,
1669                failure,
1670            } => {
1671                let succeeded = eval_txn_compares(store, compare)?;
1672                let branch = if succeeded { success } else { failure };
1673                let nested_forced = if txn_ops_have_write(branch) {
1674                    forced_revision
1675                } else {
1676                    None
1677                };
1678                let responses = apply_txn_ops(store, branch, nested_forced)?;
1679                out.push(AstraTxnOpResponse::Txn {
1680                    revision: forced_revision.unwrap_or_else(|| store.current_revision()),
1681                    succeeded,
1682                    responses,
1683                });
1684            }
1685        }
1686    }
1687
1688    Ok(out)
1689}
1690
/// Builds snapshots from state-machine data held behind
/// `Arc<RwLock<AstraStateMachineShared>>`. NOTE(review): presumably the same
/// shared state as the owning `AstraStateMachine`; confirm at the
/// construction site.
#[derive(Debug, Clone)]
pub struct AstraSnapshotBuilder {
    shared: Arc<RwLock<AstraStateMachineShared>>,
}
1695
1696impl RaftSnapshotBuilder<AstraTypeConfig> for AstraSnapshotBuilder {
1697    async fn build_snapshot(&mut self) -> Result<Snapshot<AstraTypeConfig>, StorageError<u64>> {
1698        let (snapshot_state, token_dict, token_dict_epoch, last_log_id, last_membership) = {
1699            let guard = self.shared.read().await;
1700            (
1701                guard.store.snapshot_state(),
1702                guard.token_dict.clone(),
1703                guard.token_dict_epoch,
1704                guard.last_applied_log.clone(),
1705                guard.last_membership.clone(),
1706            )
1707        };
1708
1709        let payload = SnapshotPayloadV2 {
1710            store: snapshot_state,
1711            token_dict: token_dict.into_iter().collect::<Vec<_>>(),
1712            token_dict_epoch,
1713        };
1714
1715        let bytes = bincode::serialize(&payload).map_err(|e| {
1716            StorageError::from_io_error(
1717                openraft::ErrorSubject::Snapshot(None),
1718                openraft::ErrorVerb::Write,
1719                std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
1720            )
1721        })?;
1722
1723        let snapshot_id = format!(
1724            "snapshot-{}-{}",
1725            last_log_id.as_ref().map(|v| v.index).unwrap_or_default(),
1726            SystemTime::now()
1727                .duration_since(UNIX_EPOCH)
1728                .map(|d| d.as_millis())
1729                .unwrap_or_default()
1730        );
1731
1732        let meta = SnapshotMeta {
1733            last_log_id: last_log_id.clone(),
1734            last_membership: last_membership.clone(),
1735            snapshot_id,
1736        };
1737
1738        {
1739            let mut guard = self.shared.write().await;
1740            guard.current_snapshot = Some(SnapshotBlob {
1741                meta: meta.clone(),
1742                bytes: bytes.clone(),
1743            });
1744        }
1745
1746        Ok(Snapshot {
1747            meta,
1748            snapshot: Box::new(Cursor::new(bytes)),
1749        })
1750    }
1751}
1752
1753impl RaftStateMachine<AstraTypeConfig> for AstraStateMachine {
1754    type SnapshotBuilder = AstraSnapshotBuilder;
1755
1756    async fn applied_state(
1757        &mut self,
1758    ) -> Result<(Option<LogId<u64>>, StoredMembership<u64, BasicNode>), StorageError<u64>> {
1759        let guard = self.shared.read().await;
1760        Ok((
1761            guard.last_applied_log.clone(),
1762            guard.last_membership.clone(),
1763        ))
1764    }
1765
1766    async fn apply<I>(&mut self, entries: I) -> Result<Vec<AstraWriteResponse>, StorageError<u64>>
1767    where
1768        I: IntoIterator<Item = Entry<AstraTypeConfig>> + openraft::OptionalSend,
1769        I::IntoIter: openraft::OptionalSend,
1770    {
1771        let entries = entries.into_iter().collect::<Vec<_>>();
1772
1773        let (store, mut token_dict, mut token_dict_epoch) = {
1774            let guard = self.shared.read().await;
1775            (
1776                guard.store.clone(),
1777                guard.token_dict.clone(),
1778                guard.token_dict_epoch,
1779            )
1780        };
1781
1782        let mut responses = Vec::with_capacity(entries.len());
1783
1784        let mut last_applied = None;
1785        let mut last_membership = None;
1786
1787        for ent in entries {
1788            let log_id = ent.log_id.clone();
1789
1790            match ent.payload {
1791                EntryPayload::Blank => {
1792                    responses.push(AstraWriteResponse::Empty);
1793                }
1794                EntryPayload::Membership(membership) => {
1795                    last_membership = Some(StoredMembership::new(Some(log_id.clone()), membership));
1796                    responses.push(AstraWriteResponse::Empty);
1797                }
1798                EntryPayload::Normal(req) => match req {
1799                    AstraWriteRequest::Put {
1800                        key,
1801                        value,
1802                        lease,
1803                        ignore_value,
1804                        ignore_lease,
1805                        prev_kv,
1806                    } => {
1807                        let PutOutput {
1808                            revision,
1809                            prev,
1810                            current: _,
1811                        } = store
1812                            .put(key, value, lease, ignore_value, ignore_lease)
1813                            .map_err(|e| {
1814                                StorageError::from_io_error(
1815                                    openraft::ErrorSubject::StateMachine,
1816                                    openraft::ErrorVerb::Write,
1817                                    std::io::Error::other(e.to_string()),
1818                                )
1819                            })?;
1820
1821                        responses.push(AstraWriteResponse::Put {
1822                            revision,
1823                            prev: if prev_kv { prev } else { None },
1824                        });
1825                    }
1826                    AstraWriteRequest::PutBatch { ops } => {
1827                        let mut results = Vec::with_capacity(ops.len());
1828                        for op in ops {
1829                            let PutOutput {
1830                                revision,
1831                                prev,
1832                                current: _,
1833                            } = store
1834                                .put(op.key, op.value, op.lease, op.ignore_value, op.ignore_lease)
1835                                .map_err(|e| {
1836                                    StorageError::from_io_error(
1837                                        openraft::ErrorSubject::StateMachine,
1838                                        openraft::ErrorVerb::Write,
1839                                        std::io::Error::other(e.to_string()),
1840                                    )
1841                                })?;
1842
1843                            results.push(AstraBatchPutResult {
1844                                revision,
1845                                prev: if op.prev_kv { prev } else { None },
1846                            });
1847                        }
1848                        responses.push(AstraWriteResponse::PutBatch { results });
1849                    }
1850                    AstraWriteRequest::PutBatchRef {
1851                        batch_id,
1852                        submit_ts_micros,
1853                        values,
1854                        ops,
1855                    } => {
1856                        let op_count = ops.len();
1857                        let unique_values = values.len();
1858                        let mut results = Vec::with_capacity(ops.len());
1859                        for op in ops {
1860                            let value =
1861                                values.get(op.value_idx as usize).cloned().ok_or_else(|| {
1862                                    StorageError::from_io_error(
1863                                        openraft::ErrorSubject::StateMachine,
1864                                        openraft::ErrorVerb::Write,
1865                                        std::io::Error::new(
1866                                            std::io::ErrorKind::InvalidData,
1867                                            format!(
1868                                                "put batch reference index out of bounds: {}",
1869                                                op.value_idx
1870                                            ),
1871                                        ),
1872                                    )
1873                                })?;
1874
1875                            let PutOutput {
1876                                revision,
1877                                prev,
1878                                current: _,
1879                            } = store
1880                                .put(op.key, value, op.lease, op.ignore_value, op.ignore_lease)
1881                                .map_err(|e| {
1882                                    StorageError::from_io_error(
1883                                        openraft::ErrorSubject::StateMachine,
1884                                        openraft::ErrorVerb::Write,
1885                                        std::io::Error::other(e.to_string()),
1886                                    )
1887                                })?;
1888
1889                            results.push(AstraBatchPutResult {
1890                                revision,
1891                                prev: if op.prev_kv { prev } else { None },
1892                            });
1893                        }
1894                        if batch_id > 0 {
1895                            let since_submit_ms = now_micros()
1896                                .checked_sub(submit_ts_micros)
1897                                .map(|delta| delta / 1_000);
1898                            info!(
1899                                stage = "sm_apply_done",
1900                                batch_id,
1901                                op_count,
1902                                unique_values,
1903                                tokenized = false,
1904                                since_submit_ms,
1905                                "raft timeline"
1906                            );
1907                        }
1908                        responses.push(AstraWriteResponse::PutBatch { results });
1909                    }
1910                    AstraWriteRequest::PutBatchTokenized {
1911                        batch_id,
1912                        submit_ts_micros,
1913                        dict_epoch,
1914                        dict_additions,
1915                        ops,
1916                    } => {
1917                        if dict_epoch > token_dict_epoch {
1918                            token_dict_epoch = dict_epoch;
1919                        }
1920                        for add in &dict_additions {
1921                            token_dict.insert(add.token_id, add.value.clone());
1922                        }
1923
1924                        let op_count = ops.len();
1925                        let dict_add_count = dict_additions.len();
1926                        let mut results = Vec::with_capacity(ops.len());
1927                        for op in ops {
1928                            let value = token_dict.get(&op.token_id).cloned().ok_or_else(|| {
1929                                StorageError::from_io_error(
1930                                    openraft::ErrorSubject::StateMachine,
1931                                    openraft::ErrorVerb::Write,
1932                                    std::io::Error::new(
1933                                        std::io::ErrorKind::InvalidData,
1934                                        format!(
1935                                            "missing token in tokenized batch: {}",
1936                                            op.token_id
1937                                        ),
1938                                    ),
1939                                )
1940                            })?;
1941
1942                            let PutOutput {
1943                                revision,
1944                                prev,
1945                                current: _,
1946                            } = store
1947                                .put(op.key, value, op.lease, op.ignore_value, op.ignore_lease)
1948                                .map_err(|e| {
1949                                    StorageError::from_io_error(
1950                                        openraft::ErrorSubject::StateMachine,
1951                                        openraft::ErrorVerb::Write,
1952                                        std::io::Error::other(e.to_string()),
1953                                    )
1954                                })?;
1955
1956                            results.push(AstraBatchPutResult {
1957                                revision,
1958                                prev: if op.prev_kv { prev } else { None },
1959                            });
1960                        }
1961
1962                        if batch_id > 0 {
1963                            let since_submit_ms = now_micros()
1964                                .checked_sub(submit_ts_micros)
1965                                .map(|delta| delta / 1_000);
1966                            info!(
1967                                stage = "sm_apply_done",
1968                                batch_id,
1969                                op_count,
1970                                dict_add_count,
1971                                dict_size = token_dict.len(),
1972                                tokenized = true,
1973                                since_submit_ms,
1974                                "raft timeline"
1975                            );
1976                        }
1977                        responses.push(AstraWriteResponse::PutBatch { results });
1978                    }
1979                    AstraWriteRequest::DeleteRange {
1980                        key,
1981                        range_end,
1982                        prev_kv,
1983                    } => {
1984                        let DeleteOutput {
1985                            revision,
1986                            deleted,
1987                            prev_kvs,
1988                        } = store.delete_range(&key, &range_end, prev_kv).map_err(|e| {
1989                            StorageError::from_io_error(
1990                                openraft::ErrorSubject::StateMachine,
1991                                openraft::ErrorVerb::Write,
1992                                std::io::Error::other(e.to_string()),
1993                            )
1994                        })?;
1995
1996                        responses.push(AstraWriteResponse::Delete {
1997                            revision,
1998                            deleted,
1999                            prev_kvs,
2000                        });
2001                    }
2002                    AstraWriteRequest::Txn {
2003                        compare,
2004                        success,
2005                        failure,
2006                    } => {
2007                        let succeeded = eval_txn_compares(&store, &compare).map_err(|e| {
2008                            StorageError::from_io_error(
2009                                openraft::ErrorSubject::StateMachine,
2010                                openraft::ErrorVerb::Write,
2011                                std::io::Error::other(e.to_string()),
2012                            )
2013                        })?;
2014                        let branch = if succeeded { &success } else { &failure };
2015                        let forced_revision = if txn_ops_have_write(branch) {
2016                            Some(store.reserve_revision())
2017                        } else {
2018                            None
2019                        };
2020                        let op_responses =
2021                            apply_txn_ops(&store, branch, forced_revision).map_err(|e| {
2022                                StorageError::from_io_error(
2023                                    openraft::ErrorSubject::StateMachine,
2024                                    openraft::ErrorVerb::Write,
2025                                    std::io::Error::other(e.to_string()),
2026                                )
2027                            })?;
2028                        responses.push(AstraWriteResponse::Txn {
2029                            revision: forced_revision.unwrap_or_else(|| store.current_revision()),
2030                            succeeded,
2031                            responses: op_responses,
2032                        });
2033                    }
2034                    AstraWriteRequest::Compact { revision } => {
2035                        let compact_revision = store.compact_to(revision).map_err(|e| {
2036                            StorageError::from_io_error(
2037                                openraft::ErrorSubject::StateMachine,
2038                                openraft::ErrorVerb::Write,
2039                                std::io::Error::other(e.to_string()),
2040                            )
2041                        })?;
2042                        responses.push(AstraWriteResponse::Compact {
2043                            revision: store.current_revision(),
2044                            compact_revision,
2045                        });
2046                    }
2047                    AstraWriteRequest::LeaseGrant { id, ttl } => {
2048                        let LeaseGrantOutput { revision, id, ttl } =
2049                            store.lease_grant(id, ttl).map_err(|e| {
2050                                StorageError::from_io_error(
2051                                    openraft::ErrorSubject::StateMachine,
2052                                    openraft::ErrorVerb::Write,
2053                                    std::io::Error::other(e.to_string()),
2054                                )
2055                            })?;
2056                        responses.push(AstraWriteResponse::LeaseGrant { revision, id, ttl });
2057                    }
2058                    AstraWriteRequest::LeaseRevoke { id } => {
2059                        let LeaseRevokeOutput { revision, deleted } =
2060                            store.lease_revoke(id).map_err(|e| {
2061                                StorageError::from_io_error(
2062                                    openraft::ErrorSubject::StateMachine,
2063                                    openraft::ErrorVerb::Write,
2064                                    std::io::Error::other(e.to_string()),
2065                                )
2066                            })?;
2067                        responses.push(AstraWriteResponse::LeaseRevoke { revision, deleted });
2068                    }
2069                    AstraWriteRequest::LeaseKeepAlive { id } => {
2070                        let LeaseGrantOutput { revision, id, ttl } =
2071                            store.lease_keep_alive(id).map_err(|e| {
2072                                StorageError::from_io_error(
2073                                    openraft::ErrorSubject::StateMachine,
2074                                    openraft::ErrorVerb::Write,
2075                                    std::io::Error::other(e.to_string()),
2076                                )
2077                            })?;
2078                        responses.push(AstraWriteResponse::LeaseKeepAlive { revision, id, ttl });
2079                    }
2080                },
2081            }
2082
2083            last_applied = Some(log_id);
2084        }
2085
2086        let mut guard = self.shared.write().await;
2087        if let Some(last) = last_applied {
2088            guard.last_applied_log = Some(last);
2089        }
2090        if let Some(membership) = last_membership {
2091            guard.last_membership = membership;
2092        }
2093        guard.token_dict = token_dict;
2094        guard.token_dict_epoch = token_dict_epoch;
2095
2096        Ok(responses)
2097    }
2098
2099    async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
2100        AstraSnapshotBuilder {
2101            shared: self.shared.clone(),
2102        }
2103    }
2104
2105    async fn begin_receiving_snapshot(
2106        &mut self,
2107    ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<u64>> {
2108        Ok(Box::new(Cursor::new(Vec::new())))
2109    }
2110
2111    async fn install_snapshot(
2112        &mut self,
2113        meta: &SnapshotMeta<u64, BasicNode>,
2114        snapshot: Box<Cursor<Vec<u8>>>,
2115    ) -> Result<(), StorageError<u64>> {
2116        let cursor = *snapshot;
2117        let bytes = cursor.into_inner();
2118
2119        let (snapshot_state, token_dict, token_dict_epoch) =
2120            match bincode::deserialize::<SnapshotPayloadV2>(&bytes) {
2121                Ok(v2) => (
2122                    v2.store,
2123                    v2.token_dict.into_iter().collect::<HashMap<_, _>>(),
2124                    v2.token_dict_epoch,
2125                ),
2126                Err(_) => {
2127                    // Backward compatibility for v1 snapshots that only stored KvStore state.
2128                    let snapshot_state: SnapshotState =
2129                        bincode::deserialize(&bytes).map_err(|e| {
2130                            StorageError::from_io_error(
2131                                openraft::ErrorSubject::Snapshot(Some(meta.signature())),
2132                                openraft::ErrorVerb::Read,
2133                                std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
2134                            )
2135                        })?;
2136                    (snapshot_state, HashMap::new(), 0)
2137                }
2138            };
2139
2140        let store = {
2141            let guard = self.shared.read().await;
2142            guard.store.clone()
2143        };
2144
2145        store.load_snapshot_state(snapshot_state).map_err(|e| {
2146            StorageError::from_io_error(
2147                openraft::ErrorSubject::StateMachine,
2148                openraft::ErrorVerb::Write,
2149                std::io::Error::other(e.to_string()),
2150            )
2151        })?;
2152
2153        let mut guard = self.shared.write().await;
2154        guard.last_applied_log = meta.last_log_id.clone();
2155        guard.last_membership = meta.last_membership.clone();
2156        guard.token_dict = token_dict;
2157        guard.token_dict_epoch = token_dict_epoch;
2158        guard.current_snapshot = Some(SnapshotBlob {
2159            meta: meta.clone(),
2160            bytes,
2161        });
2162
2163        Ok(())
2164    }
2165
2166    async fn get_current_snapshot(
2167        &mut self,
2168    ) -> Result<Option<Snapshot<AstraTypeConfig>>, StorageError<u64>> {
2169        let guard = self.shared.read().await;
2170        Ok(guard.current_snapshot.as_ref().map(|blob| Snapshot {
2171            meta: blob.meta.clone(),
2172            snapshot: Box::new(Cursor::new(blob.bytes.clone())),
2173        }))
2174    }
2175}
2176
/// Stateless factory that openraft uses to create one [`AstraNetwork`] client per peer.
#[derive(Debug, Clone)]
pub struct AstraNetworkFactory;
2179
/// Client side of the internal Raft gRPC transport for a single peer.
#[derive(Debug, Clone)]
pub struct AstraNetwork {
    /// Peer address copied from `BasicNode::addr`. If it has no `http://`/`https://`
    /// scheme, `http://` is prepended when connecting.
    addr: String,
    /// Lazily created gRPC client. It is reset to `None` after a failed RPC so the next
    /// call reconnects.
    client: Option<InternalRaftClient<Channel>>,
}
2185
2186impl RaftNetworkFactory<AstraTypeConfig> for AstraNetworkFactory {
2187    type Network = AstraNetwork;
2188
2189    async fn new_client(&mut self, _target: u64, node: &BasicNode) -> Self::Network {
2190        Self::Network {
2191            addr: node.addr.clone(),
2192            client: None,
2193        }
2194    }
2195}
2196
2197impl AstraNetwork {
2198    async fn ensure_client(
2199        &mut self,
2200    ) -> Result<&mut InternalRaftClient<Channel>, tonic::transport::Error> {
2201        let addr = if self.addr.starts_with("http://") || self.addr.starts_with("https://") {
2202            self.addr.clone()
2203        } else {
2204            format!("http://{}", self.addr)
2205        };
2206        if self.client.is_none() {
2207            self.client = Some(InternalRaftClient::connect(addr).await?);
2208        }
2209        Ok(self.client.as_mut().expect("client initialized"))
2210    }
2211
2212    fn invalidate_client(&mut self) {
2213        self.client = None;
2214    }
2215
2216    fn should_retry_status(status: &Status) -> bool {
2217        matches!(
2218            status.code(),
2219            Code::Unavailable | Code::Cancelled | Code::DeadlineExceeded | Code::Unknown
2220        )
2221    }
2222}
2223
2224impl RaftNetwork<AstraTypeConfig> for AstraNetwork {
2225    async fn append_entries(
2226        &mut self,
2227        rpc: AppendEntriesRequest<AstraTypeConfig>,
2228        option: RPCOption,
2229    ) -> Result<AppendEntriesResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
2230        let entry_count = rpc.entries.len();
2231        let first_log_index = rpc.entries.first().map(|e| e.log_id.index);
2232        let last_log_index = rpc.entries.last().map(|e| e.log_id.index);
2233        let payload = serde_json::to_vec(&rpc).map_err(|e| {
2234            RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2235        })?;
2236        let payload_bytes = payload.len();
2237
2238        let timeout = option.hard_ttl();
2239        for attempt in 0..2 {
2240            let started = Instant::now();
2241            let mut req = Request::new(RaftBytes {
2242                payload: payload.clone(),
2243            });
2244            req.set_timeout(timeout);
2245            let resp = self
2246                .ensure_client()
2247                .await
2248                .map_err(|e| RPCError::Network(NetworkError::new(&e)))?
2249                .append_entries(req)
2250                .await;
2251
2252            match resp {
2253                Ok(resp) => {
2254                    let elapsed_ms = started.elapsed().as_millis() as u64;
2255                    debug!(
2256                        stage = "append_entries_rpc_send_done",
2257                        entry_count,
2258                        first_log_index,
2259                        last_log_index,
2260                        payload_bytes,
2261                        elapsed_ms,
2262                        attempt,
2263                        "raft timeline"
2264                    );
2265                    return serde_json::from_slice::<AppendEntriesResponse<u64>>(
2266                        &resp.into_inner().payload,
2267                    )
2268                    .map_err(|e| {
2269                        RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2270                    });
2271                }
2272                Err(status) => {
2273                    let elapsed_ms = started.elapsed().as_millis() as u64;
2274                    debug!(
2275                        stage = "append_entries_rpc_send_done",
2276                        entry_count,
2277                        first_log_index,
2278                        last_log_index,
2279                        payload_bytes,
2280                        elapsed_ms,
2281                        attempt,
2282                        code = %status.code(),
2283                        success = false,
2284                        "raft timeline"
2285                    );
2286                    self.invalidate_client();
2287                    if attempt == 0 && Self::should_retry_status(&status) {
2288                        continue;
2289                    }
2290                    return Err(RPCError::Network(NetworkError::new(&status)));
2291                }
2292            }
2293        }
2294
2295        Err(RPCError::Network(NetworkError::new(
2296            &std::io::Error::other("append_entries exhausted retries"),
2297        )))
2298    }
2299
2300    async fn install_snapshot(
2301        &mut self,
2302        rpc: InstallSnapshotRequest<AstraTypeConfig>,
2303        option: RPCOption,
2304    ) -> Result<
2305        InstallSnapshotResponse<u64>,
2306        RPCError<u64, BasicNode, RaftError<u64, InstallSnapshotError>>,
2307    > {
2308        let payload = serde_json::to_vec(&rpc).map_err(|e| {
2309            RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2310        })?;
2311
2312        let timeout = option.hard_ttl();
2313        for attempt in 0..2 {
2314            let mut req = Request::new(RaftBytes {
2315                payload: payload.clone(),
2316            });
2317            req.set_timeout(timeout);
2318            let resp = self
2319                .ensure_client()
2320                .await
2321                .map_err(|e| RPCError::Network(NetworkError::new(&e)))?
2322                .install_snapshot(req)
2323                .await;
2324
2325            match resp {
2326                Ok(resp) => {
2327                    return serde_json::from_slice::<InstallSnapshotResponse<u64>>(
2328                        &resp.into_inner().payload,
2329                    )
2330                    .map_err(|e| {
2331                        RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2332                    });
2333                }
2334                Err(status) => {
2335                    self.invalidate_client();
2336                    if attempt == 0 && Self::should_retry_status(&status) {
2337                        continue;
2338                    }
2339                    return Err(RPCError::Network(NetworkError::new(&status)));
2340                }
2341            }
2342        }
2343
2344        Err(RPCError::Network(NetworkError::new(
2345            &std::io::Error::other("install_snapshot exhausted retries"),
2346        )))
2347    }
2348
2349    async fn vote(
2350        &mut self,
2351        rpc: VoteRequest<u64>,
2352        option: RPCOption,
2353    ) -> Result<VoteResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
2354        let payload = serde_json::to_vec(&rpc).map_err(|e| {
2355            RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2356        })?;
2357
2358        let timeout = option.hard_ttl();
2359        for attempt in 0..2 {
2360            let mut req = Request::new(RaftBytes {
2361                payload: payload.clone(),
2362            });
2363            req.set_timeout(timeout);
2364            let resp = self
2365                .ensure_client()
2366                .await
2367                .map_err(|e| RPCError::Network(NetworkError::new(&e)))?
2368                .vote(req)
2369                .await;
2370
2371            match resp {
2372                Ok(resp) => {
2373                    return serde_json::from_slice::<VoteResponse<u64>>(&resp.into_inner().payload)
2374                        .map_err(|e| {
2375                            RPCError::Network(NetworkError::new(&std::io::Error::other(
2376                                e.to_string(),
2377                            )))
2378                        });
2379                }
2380                Err(status) => {
2381                    self.invalidate_client();
2382                    if attempt == 0 && Self::should_retry_status(&status) {
2383                        continue;
2384                    }
2385                    return Err(RPCError::Network(NetworkError::new(&status)));
2386                }
2387            }
2388        }
2389
2390        Err(RPCError::Network(NetworkError::new(
2391            &std::io::Error::other("vote exhausted retries"),
2392        )))
2393    }
2394}
2395
/// Server side of the internal Raft gRPC transport.
///
/// It decodes JSON `RaftBytes` requests, forwards them to the local openraft instance and
/// JSON-encodes the replies. The `chaos_append_ack_delay_*` fields configure an optional
/// artificial delay before non-empty `AppendEntries` requests are handed to Raft.
#[derive(Clone)]
pub struct AstraRaftService {
    /// Local Raft instance that handles decoded RPCs.
    pub raft: Raft<AstraTypeConfig>,
    /// This node's id; compared against `chaos_append_ack_delay_node_id`.
    pub node_id: u64,
    /// Master switch for the append-ack delay.
    pub chaos_append_ack_delay_enabled: bool,
    /// Lower bound of the delay (used at millisecond granularity).
    pub chaos_append_ack_delay_min: Duration,
    /// Upper bound of the delay. A value below the minimum behaves like the minimum.
    pub chaos_append_ack_delay_max: Duration,
    /// Node the delay applies to; `0` means every node.
    pub chaos_append_ack_delay_node_id: u64,
}
2405
2406impl AstraRaftService {
2407    fn append_ack_chaos_delay(&self, entry_count: usize) -> Option<Duration> {
2408        if !self.chaos_append_ack_delay_enabled || entry_count == 0 {
2409            return None;
2410        }
2411        if self.chaos_append_ack_delay_node_id != 0
2412            && self.chaos_append_ack_delay_node_id != self.node_id
2413        {
2414            return None;
2415        }
2416
2417        let min_ms = self.chaos_append_ack_delay_min.as_millis() as u64;
2418        let max_ms = self
2419            .chaos_append_ack_delay_max
2420            .as_millis()
2421            .max(min_ms as u128) as u64;
2422        let span = max_ms.saturating_sub(min_ms);
2423        let jitter = if span == 0 {
2424            0
2425        } else {
2426            let now_ns = SystemTime::now()
2427                .duration_since(UNIX_EPOCH)
2428                .map(|d| d.as_nanos() as u64)
2429                .unwrap_or_default();
2430            now_ns % (span + 1)
2431        };
2432        Some(Duration::from_millis(min_ms.saturating_add(jitter)))
2433    }
2434}
2435
2436#[tonic::async_trait]
2437impl InternalRaft for AstraRaftService {
2438    async fn append_entries(
2439        &self,
2440        request: Request<RaftBytes>,
2441    ) -> Result<Response<RaftBytes>, Status> {
2442        let started = Instant::now();
2443        let rpc: AppendEntriesRequest<AstraTypeConfig> =
2444            serde_json::from_slice(&request.into_inner().payload)
2445                .map_err(|e| Status::invalid_argument(e.to_string()))?;
2446        let entry_count = rpc.entries.len();
2447        let first_log_index = rpc.entries.first().map(|e| e.log_id.index);
2448        let last_log_index = rpc.entries.last().map(|e| e.log_id.index);
2449
2450        if let Some(delay) = self.append_ack_chaos_delay(entry_count) {
2451            tokio::time::sleep(delay).await;
2452            debug!(
2453                stage = "append_entries_chaos_delay",
2454                node_id = self.node_id,
2455                delay_ms = delay.as_millis() as u64,
2456                entry_count,
2457                first_log_index,
2458                last_log_index,
2459                "raft timeline"
2460            );
2461        }
2462
2463        let resp = self
2464            .raft
2465            .append_entries(rpc)
2466            .await
2467            .map_err(|e| Status::internal(e.to_string()))?;
2468        let elapsed_ms = started.elapsed().as_millis() as u64;
2469        debug!(
2470            stage = "append_entries_rpc_recv_done",
2471            entry_count, first_log_index, last_log_index, elapsed_ms, "raft timeline"
2472        );
2473
2474        let payload = serde_json::to_vec(&resp).map_err(|e| Status::internal(e.to_string()))?;
2475        Ok(Response::new(RaftBytes { payload }))
2476    }
2477
2478    async fn vote(&self, request: Request<RaftBytes>) -> Result<Response<RaftBytes>, Status> {
2479        let rpc: VoteRequest<u64> = serde_json::from_slice(&request.into_inner().payload)
2480            .map_err(|e| Status::invalid_argument(e.to_string()))?;
2481
2482        let resp = self
2483            .raft
2484            .vote(rpc)
2485            .await
2486            .map_err(|e| Status::internal(e.to_string()))?;
2487
2488        let payload = serde_json::to_vec(&resp).map_err(|e| Status::internal(e.to_string()))?;
2489        Ok(Response::new(RaftBytes { payload }))
2490    }
2491
2492    async fn install_snapshot(
2493        &self,
2494        request: Request<RaftBytes>,
2495    ) -> Result<Response<RaftBytes>, Status> {
2496        let rpc: InstallSnapshotRequest<AstraTypeConfig> =
2497            serde_json::from_slice(&request.into_inner().payload)
2498                .map_err(|e| Status::invalid_argument(e.to_string()))?;
2499
2500        let resp = self
2501            .raft
2502            .install_snapshot(rpc)
2503            .await
2504            .map_err(|e| Status::internal(e.to_string()))?;
2505
2506        let payload = serde_json::to_vec(&resp).map_err(|e| Status::internal(e.to_string()))?;
2507        Ok(Response::new(RaftBytes { payload }))
2508    }
2509}
2510
2511pub fn parse_raft_nodes(peers: &[String]) -> HashMap<u64, BasicNode> {
2512    peers
2513        .iter()
2514        .enumerate()
2515        .map(|(idx, p)| {
2516            let addr = p
2517                .strip_prefix("http://")
2518                .or_else(|| p.strip_prefix("https://"))
2519                .unwrap_or(p)
2520                .to_string();
2521            ((idx as u64) + 1, BasicNode { addr })
2522        })
2523        .collect::<HashMap<_, _>>()
2524}
2525
2526pub async fn maybe_initialize(
2527    raft: &Raft<AstraTypeConfig>,
2528    nodes: HashMap<u64, BasicNode>,
2529) -> Result<(), RaftError<u64, openraft::error::InitializeError<u64, BasicNode>>> {
2530    if raft.is_initialized().await.unwrap_or(false) {
2531        return Ok(());
2532    }
2533
2534    let members = nodes.into_iter().collect::<BTreeMap<_, _>>();
2535    raft.initialize(members).await
2536}