1use std::collections::{BTreeMap, HashMap, VecDeque};
2use std::fmt::Debug;
3use std::io::{Cursor, Read};
4use std::ops::{Bound, RangeBounds};
5use std::path::Path;
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
8
9use astra_proto::astraraftpb::internal_raft_client::InternalRaftClient;
10use astra_proto::astraraftpb::internal_raft_server::InternalRaft;
11use astra_proto::astraraftpb::RaftBytes;
12use openraft::entry::{Entry, EntryPayload};
13use openraft::error::{InstallSnapshotError, NetworkError, RPCError, RaftError};
14use openraft::network::{RPCOption, RaftNetwork, RaftNetworkFactory};
15use openraft::raft::{
16 AppendEntriesRequest, AppendEntriesResponse, InstallSnapshotRequest, InstallSnapshotResponse,
17 VoteRequest, VoteResponse,
18};
19use openraft::storage::{LogFlushed, RaftLogStorage, RaftSnapshotBuilder, RaftStateMachine};
20use openraft::{
21 BasicNode, LogId, LogState, Raft, RaftLogReader, Snapshot, SnapshotMeta, StorageError,
22 StoredMembership, Vote,
23};
24use serde::{Deserialize, Serialize};
25use tokio::sync::{oneshot, Mutex as AsyncMutex, Notify, RwLock};
26use tonic::transport::Channel;
27use tonic::{Code, Request, Response, Status};
28use tracing::{debug, info, warn};
29
30use crate::config::{AstraConfig, WalIoEngine};
31use crate::errors::StoreError;
32use crate::metrics;
33use crate::store::{
34 DeleteOutput, KvStore, LeaseGrantOutput, LeaseRevokeOutput, PutOutput, RangeOutput,
35 SnapshotState, ValueEntry,
36};
37
// Declares `AstraTypeConfig`, the openraft type configuration used by the log store and
// state machine in this file: application requests are `AstraWriteRequest`, responses are
// `AstraWriteResponse`, node ids are `u64`, nodes are `BasicNode`, and snapshot data is an
// in-memory `Cursor<Vec<u8>>` driven by the Tokio runtime.
openraft::declare_raft_types!(
    pub AstraTypeConfig:
        D = AstraWriteRequest,
        R = AstraWriteResponse,
        NodeId = u64,
        Node = BasicNode,
        Entry = Entry<AstraTypeConfig>,
        SnapshotData = Cursor<Vec<u8>>,
        AsyncRuntime = openraft::TokioRuntime,
);
48
/// One put operation carried by [`AstraWriteRequest::PutBatch`], with its value stored inline.
///
/// Field order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstraBatchPutOp {
    /// Key to write.
    pub key: Vec<u8>,
    /// Value bytes for `key`.
    pub value: Vec<u8>,
    /// Lease id to associate with the key.
    /// NOTE(review): presumably 0 means "no lease" — confirm against `KvStore`.
    pub lease: i64,
    /// NOTE(review): presumably "keep the existing value" (etcd-style) — confirm.
    pub ignore_value: bool,
    /// NOTE(review): presumably "keep the existing lease" (etcd-style) — confirm.
    pub ignore_lease: bool,
    /// Whether the previous entry should be reported back; falls back to `false` when a
    /// deserializer that supports missing fields does not find it.
    #[serde(default)]
    pub prev_kv: bool,
}
59
/// One put operation carried by [`AstraWriteRequest::PutBatchRef`].
///
/// Instead of an inline value it carries `value_idx`; the enclosing request ships a `values`
/// table alongside the ops.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstraBatchPutRefOp {
    /// Key to write.
    pub key: Vec<u8>,
    /// NOTE(review): presumably an index into the request's `values` vector — confirm.
    pub value_idx: u32,
    /// Lease id to associate with the key.
    pub lease: i64,
    /// NOTE(review): presumably "keep the existing value" — confirm.
    pub ignore_value: bool,
    /// NOTE(review): presumably "keep the existing lease" — confirm.
    pub ignore_lease: bool,
    /// Whether the previous entry should be reported back; `#[serde(default)]` yields `false`
    /// when the field is absent in a format that supports missing fields.
    #[serde(default)]
    pub prev_kv: bool,
}
70
/// A `(token_id, value)` pair; used as the element type of `dict_additions` in
/// [`AstraWriteRequest::PutBatchTokenized`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstraTokenValue {
    /// Identifier that ops in the same request can reference via their `token_id`.
    pub token_id: u32,
    /// Value bytes associated with `token_id`.
    pub value: Vec<u8>,
}
76
/// One put operation carried by [`AstraWriteRequest::PutBatchTokenized`]; the value is
/// referenced by `token_id` rather than stored inline.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstraBatchPutTokenOp {
    /// Key to write.
    pub key: Vec<u8>,
    /// NOTE(review): presumably resolved through the value dictionary (`dict_additions` /
    /// earlier epochs) — confirm against the state machine.
    pub token_id: u32,
    /// Lease id to associate with the key.
    pub lease: i64,
    /// NOTE(review): presumably "keep the existing value" — confirm.
    pub ignore_value: bool,
    /// NOTE(review): presumably "keep the existing lease" — confirm.
    pub ignore_lease: bool,
    /// Whether the previous entry should be reported back; `#[serde(default)]` yields `false`
    /// when the field is absent in a format that supports missing fields.
    #[serde(default)]
    pub prev_kv: bool,
}
87
/// Per-operation result reported in [`AstraWriteResponse::PutBatch`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstraBatchPutResult {
    /// Store revision associated with this put.
    pub revision: i64,
    /// Previous entry for the key, when one is reported.
    pub prev: Option<ValueEntry>,
}
93
/// Comparison operator of a transaction guard ([`AstraTxnCompare::result`]).
///
/// Variant order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstraTxnCmpResult {
    Equal,
    Greater,
    Less,
    NotEqual,
}
101
/// Which attribute of a key a transaction guard inspects ([`AstraTxnCompare::target`]).
///
/// Variant order is part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstraTxnCmpTarget {
    Version,
    CreateRevision,
    ModRevision,
    Value,
    Lease,
}
110
/// Right-hand operand of a transaction guard: either an integer or raw bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstraTxnCmpValue {
    /// Integer operand.
    I64(i64),
    /// Byte-string operand.
    Bytes(Vec<u8>),
}
116
/// A single guard of a transaction: compares `target` of `key` against `target_value` using
/// the operator in `result`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AstraTxnCompare {
    /// Comparison operator.
    pub result: AstraTxnCmpResult,
    /// Attribute of the key being compared.
    pub target: AstraTxnCmpTarget,
    /// Key the guard applies to.
    pub key: Vec<u8>,
    /// NOTE(review): presumably an exclusive range end, empty meaning "single key" — confirm.
    pub range_end: Vec<u8>,
    /// Operand the selected attribute is compared with.
    pub target_value: AstraTxnCmpValue,
}
125
/// An operation that can appear in the `success` / `failure` branch of a transaction.
///
/// The `Txn` variant nests further compares and branches, so the type is recursive.
/// Variant and field order are part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstraTxnOp {
    /// Read of a key or key range.
    Range {
        key: Vec<u8>,
        range_end: Vec<u8>,
        limit: i64,
        revision: i64,
        keys_only: bool,
        count_only: bool,
    },
    /// Write of a single key.
    Put {
        key: Vec<u8>,
        value: Vec<u8>,
        lease: i64,
        ignore_value: bool,
        ignore_lease: bool,
        prev_kv: bool,
    },
    /// Deletion of a key or key range.
    DeleteRange {
        key: Vec<u8>,
        range_end: Vec<u8>,
        prev_kv: bool,
    },
    /// Nested transaction.
    Txn {
        compare: Vec<AstraTxnCompare>,
        success: Vec<AstraTxnOp>,
        failure: Vec<AstraTxnOp>,
    },
}
155
/// Result of one [`AstraTxnOp`]; the variants mirror the operation kinds.
///
/// Variant and field order are part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstraTxnOpResponse {
    /// Result of a range read: matching key/entry pairs plus count information.
    Range {
        revision: i64,
        count: i64,
        more: bool,
        kvs: Vec<(Vec<u8>, ValueEntry)>,
    },
    /// Result of a put.
    Put {
        revision: i64,
        prev: Option<ValueEntry>,
    },
    /// Result of a delete: number of deleted keys and, optionally, their previous entries.
    Delete {
        revision: i64,
        deleted: i64,
        prev_kvs: Vec<(Vec<u8>, ValueEntry)>,
    },
    /// Result of a nested transaction.
    Txn {
        revision: i64,
        succeeded: bool,
        responses: Vec<AstraTxnOpResponse>,
    },
}
179
/// Application command replicated through Raft (the `D` type of [`AstraTypeConfig`]).
///
/// Variant and field order are part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstraWriteRequest {
    /// Write of a single key.
    Put {
        key: Vec<u8>,
        value: Vec<u8>,
        lease: i64,
        ignore_value: bool,
        ignore_lease: bool,
        #[serde(default)]
        prev_kv: bool,
    },
    /// Batch of puts with inline values.
    PutBatch {
        ops: Vec<AstraBatchPutOp>,
    },
    /// Batch of puts whose ops carry a `value_idx` alongside a shared `values` table.
    PutBatchRef {
        /// Batch identifier; the log store only tracks timeline markers for ids > 0.
        #[serde(default)]
        batch_id: u64,
        /// Submission timestamp in microseconds; only values > 0 are used for latency
        /// tracking by the log store.
        #[serde(default)]
        submit_ts_micros: u64,
        values: Vec<Vec<u8>>,
        ops: Vec<AstraBatchPutRefOp>,
    },
    /// Batch of puts whose values are referenced by token id.
    PutBatchTokenized {
        /// Batch identifier; the log store only tracks timeline markers for ids > 0.
        #[serde(default)]
        batch_id: u64,
        /// Submission timestamp in microseconds; only values > 0 are used for latency
        /// tracking by the log store.
        #[serde(default)]
        submit_ts_micros: u64,
        /// NOTE(review): presumably the dictionary generation the token ids refer to — confirm.
        #[serde(default)]
        dict_epoch: u64,
        /// Token/value pairs shipped with this request.
        #[serde(default)]
        dict_additions: Vec<AstraTokenValue>,
        ops: Vec<AstraBatchPutTokenOp>,
    },
    /// Deletion of a key or key range.
    DeleteRange {
        key: Vec<u8>,
        range_end: Vec<u8>,
        prev_kv: bool,
    },
    /// Guarded transaction with `success` and `failure` branches.
    Txn {
        compare: Vec<AstraTxnCompare>,
        success: Vec<AstraTxnOp>,
        failure: Vec<AstraTxnOp>,
    },
    /// Compaction request for the given revision.
    Compact {
        revision: i64,
    },
    /// Lease creation with the given id and TTL.
    LeaseGrant {
        id: i64,
        ttl: i64,
    },
    /// Lease revocation.
    LeaseRevoke {
        id: i64,
    },
    /// Lease keep-alive.
    LeaseKeepAlive {
        id: i64,
    },
}
237
/// Application response produced for a replicated command (the `R` type of
/// [`AstraTypeConfig`]).
///
/// Variant and field order are part of the serialized layout; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AstraWriteResponse {
    /// Result of a single put.
    Put {
        revision: i64,
        prev: Option<ValueEntry>,
    },
    /// Per-operation results of a batched put.
    PutBatch {
        results: Vec<AstraBatchPutResult>,
    },
    /// Result of a delete: number of deleted keys and, optionally, their previous entries.
    Delete {
        revision: i64,
        deleted: i64,
        prev_kvs: Vec<(Vec<u8>, ValueEntry)>,
    },
    /// Result of a transaction, including the responses of the executed branch.
    Txn {
        revision: i64,
        succeeded: bool,
        responses: Vec<AstraTxnOpResponse>,
    },
    /// Result of a compaction.
    Compact {
        revision: i64,
        compact_revision: i64,
    },
    /// Result of a lease grant.
    LeaseGrant {
        revision: i64,
        id: i64,
        ttl: i64,
    },
    /// Result of a lease revoke.
    LeaseRevoke {
        revision: i64,
        deleted: i64,
    },
    /// Result of a lease keep-alive.
    LeaseKeepAlive {
        revision: i64,
        id: i64,
        ttl: i64,
    },
    /// Lease time-to-live information together with the keys attached to the lease.
    LeaseTtl {
        revision: i64,
        id: i64,
        ttl: i64,
        granted_ttl: i64,
        keys: Vec<Vec<u8>>,
    },
    /// List of lease ids.
    LeaseLeases {
        revision: i64,
        leases: Vec<i64>,
    },
    /// Response carrying no data.
    Empty,
}
288
/// Tuning knobs for the write-ahead-log batching and flushing path.
#[derive(Debug, Clone)]
pub struct WalBatchConfig {
    /// Maximum number of queued requests merged into one WAL flush.
    pub max_batch_requests: usize,
    /// Maximum combined payload size of one WAL flush (a single oversized item is still
    /// flushed on its own).
    pub max_batch_bytes: usize,
    /// Time the flush worker waits for more items before flushing a partially filled batch.
    pub max_linger: Duration,
    /// In-flight put count below which `low_linger` is used instead of `max_linger`.
    pub low_concurrency_threshold: usize,
    /// Linger applied while concurrency is below `low_concurrency_threshold`.
    pub low_linger: Duration,
    /// NOTE(review): not read anywhere in the visible part of this file — confirm usage.
    pub channel_capacity: usize,
    /// Maximum number of items allowed in the WAL queue; enqueuers wait when it is reached.
    pub pending_limit: usize,
    /// Size of each preallocation step of the WAL file, in bytes.
    pub segment_bytes: u64,
    /// Requested I/O engine; the visible code always ends up on the POSIX engine.
    pub io_engine: WalIoEngine,
}
301
302impl Default for WalBatchConfig {
303 fn default() -> Self {
304 Self {
305 max_batch_requests: 1_000,
306 max_batch_bytes: 8 * 1024 * 1024,
307 max_linger: Duration::from_millis(2),
308 low_concurrency_threshold: 5,
309 low_linger: Duration::ZERO,
310 channel_capacity: 8_192,
311 pending_limit: 2_000,
312 segment_bytes: 64 * 1024 * 1024,
313 io_engine: WalIoEngine::Auto,
314 }
315 }
316}
317
/// Holder for the validated openraft configuration shared across the node.
#[derive(Debug, Clone)]
pub struct RaftBootstrap {
    /// Validated openraft configuration.
    pub config: Arc<openraft::Config>,
}
322
323impl RaftBootstrap {
324 pub fn new(cfg: &AstraConfig) -> anyhow::Result<Self> {
325 let cfg = openraft::Config {
326 cluster_name: "astra".to_string(),
327 election_timeout_min: cfg.raft_election_timeout_min_ms,
328 election_timeout_max: cfg.raft_election_timeout_max_ms,
329 heartbeat_interval: cfg.raft_heartbeat_interval_ms,
330 snapshot_max_chunk_size: 2 * 1024 * 1024,
331 max_payload_entries: cfg.raft_max_payload_entries,
332 snapshot_policy: openraft::SnapshotPolicy::LogsSinceLast(512),
333 replication_lag_threshold: cfg.raft_replication_lag_threshold,
334 max_in_snapshot_log_to_keep: 64,
335 purge_batch_size: 256,
336 ..Default::default()
337 };
338
339 let cfg = cfg.validate()?;
340 Ok(Self {
341 config: Arc::new(cfg),
342 })
343 }
344}
345
346#[derive(Debug, Clone, Serialize, Deserialize)]
347#[serde(tag = "kind", rename_all = "snake_case")]
348enum DurableRecord {
349 Append {
350 entries: Vec<Entry<AstraTypeConfig>>,
351 },
352 Truncate {
353 since: LogId<u64>,
354 },
355 Purge {
356 upto: LogId<u64>,
357 },
358 Vote {
359 vote: Vote<u64>,
360 },
361 Committed {
362 committed: Option<LogId<u64>>,
363 },
364}
365
/// In-memory view of the Raft log store, rebuilt from the WAL on startup.
#[derive(Debug)]
struct LogStoreState {
    /// Log entries keyed by log index.
    logs: BTreeMap<u64, Entry<AstraTypeConfig>>,
    /// Id of the last purged entry, if any purge has happened.
    last_purged_log_id: Option<LogId<u64>>,
    /// Last saved vote, if any.
    vote: Option<Vote<u64>>,
    /// Last saved committed log id, if any.
    committed: Option<LogId<u64>>,
}
373
374impl Default for LogStoreState {
375 fn default() -> Self {
376 Self {
377 logs: BTreeMap::new(),
378 last_purged_log_id: None,
379 vote: None,
380 committed: None,
381 }
382 }
383}
384
/// Rounds `value` up to the next multiple of `align`.
///
/// A value that is already a multiple of `align` is returned unchanged. An `align` of 0 or 1
/// imposes no alignment, so `value` is returned as-is; a zero alignment previously panicked
/// with a division by zero.
fn align_up(value: usize, align: usize) -> usize {
    if align <= 1 {
        return value;
    }
    match value % align {
        0 => value,
        remainder => value + (align - remainder),
    }
}
392
/// Current wall-clock time as microseconds since the Unix epoch.
///
/// Yields 0 when the system clock reports a time before the epoch.
fn now_micros() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_micros() as u64,
        Err(_) => 0,
    }
}
399
400fn encode_wal_frame(payload: &[u8], block_size: usize) -> Vec<u8> {
401 let mut header = Vec::with_capacity(8);
402 header.extend_from_slice(b"ASTR");
403 header.extend_from_slice(&(payload.len() as u32).to_le_bytes());
404
405 let total = header.len() + payload.len();
406 let padded = align_up(total, block_size);
407 let mut out = Vec::with_capacity(padded);
408 out.extend_from_slice(&header);
409 out.extend_from_slice(payload);
410 out.resize(padded, 0_u8);
411 out
412}
413
/// Append-only WAL file written with positional writes and grown by preallocation.
#[derive(Debug)]
struct PosixBlockDevice {
    /// Open handle to the WAL file.
    file: std::fs::File,
    /// Byte offset at which the next frame will be written.
    offset: u64,
    /// Number of bytes currently allocated for the file.
    allocated: u64,
    /// Size of one preallocation step, in bytes.
    segment_bytes: u64,
    /// Frame alignment in bytes (set to 4096 by `open`).
    block_size: usize,
}
422
423impl PosixBlockDevice {
424 fn open(path: &Path, offset: u64, segment_bytes: u64) -> std::io::Result<Self> {
425 std::fs::create_dir_all(path.parent().unwrap_or_else(|| Path::new(".")))?;
426
427 let file = std::fs::OpenOptions::new()
428 .create(true)
429 .read(true)
430 .append(false)
431 .write(true)
432 .open(path)?;
433
434 let mut this = Self {
435 allocated: file.metadata()?.len(),
436 file,
437 offset,
438 segment_bytes,
439 block_size: 4096,
440 };
441
442 if this.allocated == 0 {
443 this.preallocate(this.segment_bytes)?;
444 }
445 this.ensure_capacity(this.offset)?;
446 Ok(this)
447 }
448
449 fn preallocate(&mut self, bytes: u64) -> std::io::Result<()> {
450 if bytes == 0 {
451 return Ok(());
452 }
453 #[cfg(target_os = "linux")]
454 {
455 use std::os::unix::io::AsRawFd;
456 let fd = self.file.as_raw_fd();
457 let rc = unsafe {
458 libc::fallocate(fd, 0, self.allocated as libc::off_t, bytes as libc::off_t)
459 };
460 if rc != 0 {
461 return Err(std::io::Error::last_os_error());
462 }
463 self.allocated = self.allocated.saturating_add(bytes);
464 info!(
465 allocated_bytes = self.allocated,
466 segment_bytes = self.segment_bytes,
467 "wal fallocate preallocation complete"
468 );
469 return Ok(());
470 }
471
472 #[cfg(not(target_os = "linux"))]
473 {
474 let next = self.allocated.saturating_add(bytes);
475 self.file.set_len(next)?;
476 self.allocated = next;
477 return Ok(());
478 }
479 }
480
481 fn ensure_capacity(&mut self, required_end: u64) -> std::io::Result<()> {
482 while required_end > self.allocated {
483 self.preallocate(self.segment_bytes)?;
484 }
485 Ok(())
486 }
487
488 fn append_payload(&mut self, payload: &[u8]) -> std::io::Result<()> {
489 use std::os::unix::io::AsRawFd;
490
491 let frame = encode_wal_frame(payload, self.block_size);
492 let required_end = self.offset.saturating_add(frame.len() as u64);
493 self.ensure_capacity(required_end)?;
494
495 let fd = self.file.as_raw_fd();
496 let wrote = unsafe {
497 libc::pwrite(
498 fd,
499 frame.as_ptr() as *const libc::c_void,
500 frame.len(),
501 self.offset as libc::off_t,
502 )
503 };
504 if wrote < 0 || wrote as usize != frame.len() {
505 return Err(std::io::Error::last_os_error());
506 }
507 self.offset = required_end;
508 Ok(())
509 }
510
511 fn sync_data(&mut self) -> std::io::Result<()> {
512 #[cfg(target_os = "linux")]
513 {
514 use std::os::unix::io::AsRawFd;
515 let fd = self.file.as_raw_fd();
516 let rc = unsafe { libc::fdatasync(fd) };
517 if rc != 0 {
518 return Err(std::io::Error::last_os_error());
519 }
520 return Ok(());
521 }
522
523 #[cfg(not(target_os = "linux"))]
524 {
525 self.file.sync_data()
526 }
527 }
528}
529
/// Storage backend used for WAL appends; currently only the POSIX implementation exists.
#[derive(Debug)]
enum WalDevice {
    /// Positional-write backend backed by a regular file.
    Posix(PosixBlockDevice),
}
534
535impl WalDevice {
536 fn open(
537 path: &Path,
538 offset: u64,
539 cfg: &WalBatchConfig,
540 ) -> std::io::Result<(Self, &'static str)> {
541 if cfg.io_engine == WalIoEngine::IoUring {
542 warn!(
543 "io_uring requested but direct append path runs on tokio runtime; using posix wal engine"
544 );
545 }
546
547 let d = PosixBlockDevice::open(path, offset, cfg.segment_bytes)?;
548 info!("wal io engine selected: posix");
549 Ok((WalDevice::Posix(d), "posix"))
550 }
551
552 fn append_payload(&mut self, payload: &[u8]) -> std::io::Result<()> {
553 match self {
554 WalDevice::Posix(d) => d.append_payload(payload),
555 }
556 }
557
558 fn sync_data(&mut self) -> std::io::Result<()> {
559 match self {
560 WalDevice::Posix(d) => d.sync_data(),
561 }
562 }
563}
564
/// How the outcome of a WAL flush is reported back to whoever queued the write.
enum WalWriteCompletion {
    /// openraft append callback, completed via `log_io_completed`.
    Append(LogFlushed<AstraTypeConfig>),
    /// One-shot channel on which the enqueuing task awaits the flush result.
    Waiter(oneshot::Sender<std::io::Result<()>>),
}
569
570impl WalWriteCompletion {
571 fn complete_ok(self) {
572 match self {
573 WalWriteCompletion::Append(callback) => callback.log_io_completed(Ok(())),
574 WalWriteCompletion::Waiter(tx) => {
575 let _ = tx.send(Ok(()));
576 }
577 }
578 }
579
580 fn complete_err(self, err_msg: &str) {
581 match self {
582 WalWriteCompletion::Append(callback) => {
583 callback.log_io_completed(Err(std::io::Error::other(err_msg.to_string())))
584 }
585 WalWriteCompletion::Waiter(tx) => {
586 let _ = tx.send(Err(std::io::Error::other(err_msg.to_string())));
587 }
588 }
589 }
590}
591
/// One pending WAL write waiting in the flush queue.
struct WalQueueItem {
    /// Encoded records to be written.
    payload: Vec<u8>,
    /// Number of log entries represented by this item (used for flush statistics).
    append_entries: usize,
    /// Batch id used for timeline logging, when the item carries one.
    timeline_batch_id: Option<u64>,
    /// Submission timestamp (microseconds) used for timeline logging, when available.
    timeline_submit_ts_micros: Option<u64>,
    /// How to report the flush outcome.
    completion: WalWriteCompletion,
}
599
/// Batch id and submission time remembered per log index for latency tracking.
#[derive(Debug, Clone, Copy)]
struct TimelineMarker {
    /// Identifier of the batch stored at the log index.
    batch_id: u64,
    /// Time the batch was submitted, in microseconds since the Unix epoch.
    submit_ts_micros: u64,
}
605
/// Mutable state of the WAL flush queue, protected by an async mutex.
#[derive(Default)]
struct WalQueueState {
    /// Pending items in submission order.
    queue: VecDeque<WalQueueItem>,
    /// Sum of the payload sizes of all queued items.
    queued_bytes: usize,
    /// True while a flush worker task is running (or about to be spawned).
    flushing: bool,
}
612
/// Raft log store that keeps the log in memory and persists changes through a batched WAL.
pub struct AstraLogStore {
    /// In-memory log state, shared with readers.
    inner: Arc<RwLock<LogStoreState>>,
    /// WAL file handle; locked only inside blocking tasks.
    wal_device: Arc<Mutex<WalDevice>>,
    /// Batching configuration (normalised by `open`).
    wal_cfg: WalBatchConfig,
    /// Queue of writes waiting for the flush worker.
    wal_queue: Arc<AsyncMutex<WalQueueState>>,
    /// Signalled by the flush worker whenever it takes items off the queue.
    wal_queue_space: Arc<Notify>,
    /// Committed log id most recently written to the WAL (initialised from replay).
    last_flushed_committed: Option<LogId<u64>>,
    /// Time of the last committed-id flush (initialised to the time of `open`).
    last_committed_flush_at: Instant,
    /// Timeline markers for batched puts, keyed by log index.
    timeline_by_log_index: BTreeMap<u64, TimelineMarker>,
}
623
/// Cheap, cloneable read handle onto the in-memory log shared with [`AstraLogStore`].
#[derive(Debug, Clone)]
pub struct AstraLogReader {
    /// Shared in-memory log state.
    inner: Arc<RwLock<LogStoreState>>,
}
628
629impl AstraLogStore {
630 const COMMITTED_FLUSH_INTERVAL: Duration = Duration::from_millis(25);
631 const COMMITTED_FLUSH_STEP: u64 = 64;
632
633 pub async fn open(data_dir: &Path, mut cfg: WalBatchConfig) -> Result<Self, StorageError<u64>> {
634 cfg.max_batch_requests = cfg.max_batch_requests.max(1);
635 cfg.max_batch_bytes = cfg.max_batch_bytes.max(4 * 1024);
636 cfg.pending_limit = cfg.pending_limit.max(1);
637 cfg.low_concurrency_threshold = cfg.low_concurrency_threshold.max(1);
638
639 std::fs::create_dir_all(data_dir).map_err(|e| {
640 StorageError::from_io_error(
641 openraft::ErrorSubject::Store,
642 openraft::ErrorVerb::Write,
643 e,
644 )
645 })?;
646
647 let wal_path = data_dir.join("unified-raft.wal");
648 let mut state = LogStoreState::default();
649 let wal_offset = replay_wal(&wal_path, &mut state)?;
650 let (device, engine) = WalDevice::open(&wal_path, wal_offset, &cfg).map_err(|e| {
651 StorageError::from_io_error(
652 openraft::ErrorSubject::Store,
653 openraft::ErrorVerb::Write,
654 e,
655 )
656 })?;
657
658 let last_flushed_committed = state.committed.clone();
659
660 let inner = Arc::new(RwLock::new(state));
661
662 info!(
663 wal_path = %data_dir.join("unified-raft.wal").display(),
664 max_batch_requests = cfg.max_batch_requests,
665 max_batch_bytes = cfg.max_batch_bytes,
666 max_linger_us = cfg.max_linger.as_micros(),
667 low_concurrency_threshold = cfg.low_concurrency_threshold,
668 low_linger_us = cfg.low_linger.as_micros(),
669 pending_limit = cfg.pending_limit,
670 segment_bytes = cfg.segment_bytes,
671 io_engine = %cfg.io_engine.as_str(),
672 selected_engine = engine,
673 "wal direct append path initialized"
674 );
675
676 let mut timeline_by_log_index = BTreeMap::new();
677 {
678 let guard = inner.read().await;
679 for (log_index, entry) in &guard.logs {
680 if let EntryPayload::Normal(req) = &entry.payload {
681 match req {
682 AstraWriteRequest::PutBatchRef {
683 batch_id,
684 submit_ts_micros,
685 ..
686 }
687 | AstraWriteRequest::PutBatchTokenized {
688 batch_id,
689 submit_ts_micros,
690 ..
691 } if *batch_id > 0 && *submit_ts_micros > 0 => {
692 timeline_by_log_index.insert(
693 *log_index,
694 TimelineMarker {
695 batch_id: *batch_id,
696 submit_ts_micros: *submit_ts_micros,
697 },
698 );
699 }
700 _ => {}
701 }
702 }
703 }
704 }
705
706 let store = Self {
707 inner,
708 wal_device: Arc::new(Mutex::new(device)),
709 wal_cfg: cfg,
710 wal_queue: Arc::new(AsyncMutex::new(WalQueueState::default())),
711 wal_queue_space: Arc::new(Notify::new()),
712 last_flushed_committed,
713 last_committed_flush_at: Instant::now(),
714 timeline_by_log_index,
715 };
716 metrics::set_wal_queue_depth(0);
717 metrics::set_wal_queue_bytes(0);
718 Ok(store)
719 }
720
721 pub async fn wal_queue_snapshot(&self) -> (usize, usize) {
722 let guard = self.wal_queue.lock().await;
723 (guard.queue.len(), guard.queued_bytes)
724 }
725
726 fn encode_records_payload(records: Vec<DurableRecord>) -> Result<Vec<u8>, StorageError<u64>> {
727 encode_durable_records(records).map_err(|e| {
728 StorageError::from_io_error(openraft::ErrorSubject::Logs, openraft::ErrorVerb::Write, e)
729 })
730 }
731
732 fn logical_write_ops(entry: &Entry<AstraTypeConfig>) -> usize {
733 match &entry.payload {
734 EntryPayload::Normal(AstraWriteRequest::PutBatch { ops }) => ops.len().max(1),
735 EntryPayload::Normal(AstraWriteRequest::PutBatchRef { ops, .. }) => ops.len().max(1),
736 EntryPayload::Normal(AstraWriteRequest::PutBatchTokenized { ops, .. }) => {
737 ops.len().max(1)
738 }
739 EntryPayload::Normal(_) => 1,
740 EntryPayload::Blank | EntryPayload::Membership(_) => 0,
741 }
742 }
743
744 fn extract_timeline_meta(entries: &[Entry<AstraTypeConfig>]) -> (Option<u64>, Option<u64>) {
745 for entry in entries {
746 if let EntryPayload::Normal(req) = &entry.payload {
747 match req {
748 AstraWriteRequest::PutBatchRef {
749 batch_id,
750 submit_ts_micros,
751 ..
752 }
753 | AstraWriteRequest::PutBatchTokenized {
754 batch_id,
755 submit_ts_micros,
756 ..
757 } => {
758 return (
759 if *batch_id > 0 { Some(*batch_id) } else { None },
760 if *submit_ts_micros > 0 {
761 Some(*submit_ts_micros)
762 } else {
763 None
764 },
765 );
766 }
767 _ => {}
768 }
769 }
770 }
771 (None, None)
772 }
773
774 fn extract_timeline_markers(entries: &[Entry<AstraTypeConfig>]) -> Vec<(u64, TimelineMarker)> {
775 let mut out = Vec::new();
776 for entry in entries {
777 if let EntryPayload::Normal(req) = &entry.payload {
778 match req {
779 AstraWriteRequest::PutBatchRef {
780 batch_id,
781 submit_ts_micros,
782 ..
783 }
784 | AstraWriteRequest::PutBatchTokenized {
785 batch_id,
786 submit_ts_micros,
787 ..
788 } if *batch_id > 0 && *submit_ts_micros > 0 => {
789 out.push((
790 entry.log_id.index,
791 TimelineMarker {
792 batch_id: *batch_id,
793 submit_ts_micros: *submit_ts_micros,
794 },
795 ));
796 }
797 _ => {}
798 }
799 }
800 }
801 out
802 }
803
804 async fn await_wal_flush(
805 rx: oneshot::Receiver<std::io::Result<()>>,
806 ) -> Result<(), StorageError<u64>> {
807 match rx.await {
808 Ok(result) => result.map_err(|e| {
809 StorageError::from_io_error(
810 openraft::ErrorSubject::Logs,
811 openraft::ErrorVerb::Write,
812 e,
813 )
814 }),
815 Err(e) => Err(StorageError::from_io_error(
816 openraft::ErrorSubject::Logs,
817 openraft::ErrorVerb::Write,
818 std::io::Error::other(format!("wal flush callback dropped: {e}")),
819 )),
820 }
821 }
822
823 fn spawn_wal_flush_worker(&self) {
824 let wal_device = self.wal_device.clone();
825 let wal_cfg = self.wal_cfg.clone();
826 let wal_queue = self.wal_queue.clone();
827 let wal_queue_space = self.wal_queue_space.clone();
828
829 tokio::spawn(async move {
830 loop {
831 let put_inflight = metrics::current_put_inflight_requests();
832 let mut effective_linger = wal_cfg.max_linger;
833 if put_inflight < wal_cfg.low_concurrency_threshold as u64 {
834 effective_linger = wal_cfg.low_linger;
835 }
836 metrics::set_wal_effective_linger_us(effective_linger.as_micros() as u64);
837
838 if effective_linger > Duration::ZERO {
839 let should_linger = {
840 let guard = wal_queue.lock().await;
841 !guard.queue.is_empty()
842 && guard.queue.len() < wal_cfg.max_batch_requests
843 && guard.queued_bytes < wal_cfg.max_batch_bytes
844 };
845 if should_linger {
846 tokio::time::sleep(effective_linger).await;
847 }
848 }
849
850 let batch = {
851 let mut guard = wal_queue.lock().await;
852 if guard.queue.is_empty() {
853 guard.flushing = false;
854 metrics::set_wal_queue_depth(0);
855 metrics::set_wal_queue_bytes(0);
856 wal_queue_space.notify_waiters();
857 break;
858 }
859
860 let mut drained = Vec::new();
861 let mut batch_bytes = 0_usize;
862 while let Some(front) = guard.queue.front() {
863 let front_bytes = front.payload.len();
864 let reached_request_limit = drained.len() >= wal_cfg.max_batch_requests;
865 let reached_byte_limit = !drained.is_empty()
866 && batch_bytes.saturating_add(front_bytes) > wal_cfg.max_batch_bytes;
867 if reached_request_limit || reached_byte_limit {
868 break;
869 }
870
871 let item = guard.queue.pop_front().expect("queue front exists");
872 guard.queued_bytes = guard.queued_bytes.saturating_sub(item.payload.len());
873 batch_bytes = batch_bytes.saturating_add(item.payload.len());
874 drained.push(item);
875 }
876
877 if drained.is_empty() {
878 if let Some(item) = guard.queue.pop_front() {
879 guard.queued_bytes =
880 guard.queued_bytes.saturating_sub(item.payload.len());
881 drained.push(item);
882 }
883 }
884 metrics::set_wal_queue_depth(guard.queue.len());
885 metrics::set_wal_queue_bytes(guard.queued_bytes);
886 wal_queue_space.notify_waiters();
887 drained
888 };
889
890 if batch.is_empty() {
891 continue;
892 }
893
894 let requests = batch.len();
895 let append_entries = batch.iter().map(|item| item.append_entries).sum::<usize>();
896 let payload_bytes = batch.iter().map(|item| item.payload.len()).sum::<usize>();
897 let timeline_batch_id = batch.iter().find_map(|item| item.timeline_batch_id);
898 let timeline_submit_ts_micros =
899 batch.iter().find_map(|item| item.timeline_submit_ts_micros);
900 let mut payload = Vec::with_capacity(payload_bytes);
901 for item in &batch {
902 payload.extend_from_slice(&item.payload);
903 }
904
905 let started = Instant::now();
906 let io_result = tokio::task::spawn_blocking({
907 let wal_device = wal_device.clone();
908 move || -> std::io::Result<()> {
909 let mut wal_device = wal_device
910 .lock()
911 .map_err(|_| std::io::Error::other("wal device lock poisoned"))?;
912 wal_device.append_payload(&payload)?;
913 wal_device.sync_data()?;
914 Ok(())
915 }
916 })
917 .await;
918
919 let wal_sync_duration_ms = started.elapsed().as_millis() as u64;
920
921 let flush_outcome = match io_result {
922 Ok(result) => result,
923 Err(e) => Err(std::io::Error::other(format!(
924 "wal sync worker join error: {e}"
925 ))),
926 };
927
928 match flush_outcome {
929 Ok(()) => {
930 let since_submit_ms = timeline_submit_ts_micros
931 .and_then(|submit| now_micros().checked_sub(submit))
932 .map(|delta| delta / 1_000);
933 if append_entries >= 50 {
934 info!(
935 requests,
936 append_entries,
937 payload_bytes,
938 wal_sync_duration_ms,
939 batch_id = timeline_batch_id,
940 since_submit_ms,
941 fdatasync_calls = 1,
942 "wal vectorized append flush complete"
943 );
944 } else {
945 debug!(
946 requests,
947 append_entries,
948 payload_bytes,
949 wal_sync_duration_ms,
950 batch_id = timeline_batch_id,
951 since_submit_ms,
952 fdatasync_calls = 1,
953 "wal direct flush complete"
954 );
955 }
956 for item in batch {
957 item.completion.complete_ok();
958 }
959 }
960 Err(err) => {
961 let err_msg = err.to_string();
962 warn!(
963 requests,
964 append_entries,
965 payload_bytes,
966 wal_sync_duration_ms,
967 error = %err_msg,
968 "wal batch flush failed"
969 );
970 for item in batch {
971 item.completion.complete_err(&err_msg);
972 }
973 }
974 }
975 }
976 });
977 }
978
979 async fn enqueue_payload(
980 &self,
981 payload: Vec<u8>,
982 append_entries: usize,
983 timeline_batch_id: Option<u64>,
984 timeline_submit_ts_micros: Option<u64>,
985 completion: WalWriteCompletion,
986 ) -> Result<(), StorageError<u64>> {
987 if payload.is_empty() {
988 completion.complete_ok();
989 return Ok(());
990 }
991
992 let payload_len = payload.len();
993 let pending_limit = self.wal_cfg.pending_limit.max(1);
994 let mut maybe_item = Some(WalQueueItem {
995 payload,
996 append_entries,
997 timeline_batch_id,
998 timeline_submit_ts_micros,
999 completion,
1000 });
1001
1002 loop {
1003 let mut start_worker = false;
1004 {
1005 let mut guard = self.wal_queue.lock().await;
1006 if guard.queue.len() < pending_limit {
1007 let item = maybe_item.take().expect("wal queue item available");
1008 guard.queued_bytes = guard.queued_bytes.saturating_add(payload_len);
1009 guard.queue.push_back(item);
1010 metrics::set_wal_queue_depth(guard.queue.len());
1011 metrics::set_wal_queue_bytes(guard.queued_bytes);
1012
1013 if !guard.flushing {
1014 guard.flushing = true;
1015 start_worker = true;
1016 }
1017 }
1018 }
1019
1020 if maybe_item.is_none() {
1021 if start_worker {
1022 self.spawn_wal_flush_worker();
1023 }
1024 return Ok(());
1025 }
1026
1027 self.wal_queue_space.notified().await;
1028 }
1029 }
1030
1031 async fn append_records(
1032 &self,
1033 records: Vec<DurableRecord>,
1034 append_entries: usize,
1035 ) -> Result<(), StorageError<u64>> {
1036 let payload = Self::encode_records_payload(records)?;
1037 let (tx, rx) = oneshot::channel::<std::io::Result<()>>();
1038 self.enqueue_payload(
1039 payload,
1040 append_entries,
1041 None,
1042 None,
1043 WalWriteCompletion::Waiter(tx),
1044 )
1045 .await?;
1046 Self::await_wal_flush(rx).await
1047 }
1048
1049 async fn append_records_with_callback(
1050 &self,
1051 records: Vec<DurableRecord>,
1052 append_entries: usize,
1053 timeline_batch_id: Option<u64>,
1054 timeline_submit_ts_micros: Option<u64>,
1055 callback: LogFlushed<AstraTypeConfig>,
1056 ) -> Result<(), StorageError<u64>> {
1057 let payload = match Self::encode_records_payload(records) {
1058 Ok(payload) => payload,
1059 Err(err) => {
1060 callback.log_io_completed(Err(std::io::Error::other(err.to_string())));
1061 return Err(err);
1062 }
1063 };
1064
1065 if let Err(err) = self
1066 .enqueue_payload(
1067 payload,
1068 append_entries,
1069 timeline_batch_id,
1070 timeline_submit_ts_micros,
1071 WalWriteCompletion::Append(callback),
1072 )
1073 .await
1074 {
1075 warn!(error = %err, "wal enqueue failed for append callback path");
1078 return Err(err);
1079 }
1080 Ok(())
1081 }
1082}
1083
1084fn replay_wal(path: &Path, state: &mut LogStoreState) -> Result<u64, StorageError<u64>> {
1085 if !path.exists() {
1086 return Ok(0);
1087 }
1088
1089 let bytes = std::fs::read(path).map_err(|e| {
1090 StorageError::from_io_error(openraft::ErrorSubject::Store, openraft::ErrorVerb::Read, e)
1091 })?;
1092
1093 let mut offset = 0_usize;
1094 let mut durable_end = 0_usize;
1095 while offset + 8 <= bytes.len() {
1096 let header = &bytes[offset..offset + 8];
1097 if header[..4] == [0, 0, 0, 0] {
1098 offset = align_up(offset + 1, 4096);
1099 continue;
1100 }
1101
1102 if &header[..4] != b"ASTR" {
1103 break;
1104 }
1105
1106 let payload_len = u32::from_le_bytes(header[4..8].try_into().expect("slice len")) as usize;
1107 let start = offset + 8;
1108 let end = start + payload_len;
1109 if end > bytes.len() {
1110 break;
1111 }
1112
1113 let payload = &bytes[start..end];
1114 let mut cursor = Cursor::new(payload);
1115 while (cursor.position() as usize) < payload.len() {
1116 let mut len_buf = [0_u8; 4];
1117 if cursor.read_exact(&mut len_buf).is_err() {
1118 break;
1119 }
1120 let rec_len = u32::from_le_bytes(len_buf) as usize;
1121 if rec_len == 0 {
1122 break;
1123 }
1124
1125 let mut rec_buf = vec![0_u8; rec_len];
1126 if cursor.read_exact(&mut rec_buf).is_err() {
1127 break;
1128 }
1129
1130 let rec: DurableRecord = match bincode::deserialize(&rec_buf) {
1131 Ok(rec) => rec,
1132 Err(_) => {
1133 break;
1136 }
1137 };
1138
1139 apply_durable_record(state, rec);
1140 }
1141
1142 offset = align_up(end, 4096);
1143 durable_end = offset;
1144 }
1145
1146 Ok(durable_end as u64)
1147}
1148
1149fn apply_durable_record(state: &mut LogStoreState, rec: DurableRecord) {
1150 match rec {
1151 DurableRecord::Append { entries } => {
1152 for entry in entries {
1153 state.logs.insert(entry.log_id.index, entry);
1154 }
1155 }
1156 DurableRecord::Truncate { since } => {
1157 let keys = state
1158 .logs
1159 .range(since.index..)
1160 .map(|(k, _)| *k)
1161 .collect::<Vec<_>>();
1162 for k in keys {
1163 state.logs.remove(&k);
1164 }
1165 }
1166 DurableRecord::Purge { upto } => {
1167 let keys = state
1168 .logs
1169 .range(..=upto.index)
1170 .map(|(k, _)| *k)
1171 .collect::<Vec<_>>();
1172 for k in keys {
1173 state.logs.remove(&k);
1174 }
1175 state.last_purged_log_id = Some(upto);
1176 }
1177 DurableRecord::Vote { vote } => {
1178 state.vote = Some(vote);
1179 }
1180 DurableRecord::Committed { committed } => {
1181 state.committed = committed;
1182 }
1183 }
1184}
1185
1186fn encode_durable_records(records: Vec<DurableRecord>) -> std::io::Result<Vec<u8>> {
1187 let mut encoded = Vec::new();
1188 for rec in records {
1189 let rec_bytes = bincode::serialize(&rec)
1190 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err.to_string()))?;
1191 encoded.extend_from_slice(&(rec_bytes.len() as u32).to_le_bytes());
1192 encoded.extend_from_slice(&rec_bytes);
1193 }
1194 Ok(encoded)
1195}
1196
1197impl RaftLogReader<AstraTypeConfig> for AstraLogReader {
1198 async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + openraft::OptionalSend>(
1199 &mut self,
1200 range: RB,
1201 ) -> Result<Vec<Entry<AstraTypeConfig>>, StorageError<u64>> {
1202 let (start, end) = range_to_bounds(range);
1203 let guard = self.inner.read().await;
1204 let out = guard
1205 .logs
1206 .range(start..end)
1207 .map(|(_, v)| v.clone())
1208 .collect::<Vec<_>>();
1209 Ok(out)
1210 }
1211}
1212
1213impl RaftLogReader<AstraTypeConfig> for AstraLogStore {
1214 async fn try_get_log_entries<RB: RangeBounds<u64> + Clone + Debug + openraft::OptionalSend>(
1215 &mut self,
1216 range: RB,
1217 ) -> Result<Vec<Entry<AstraTypeConfig>>, StorageError<u64>> {
1218 let mut reader = AstraLogReader {
1219 inner: self.inner.clone(),
1220 };
1221 reader.try_get_log_entries(range).await
1222 }
1223}
1224
/// Converts any `u64` range into a half-open `(start, end)` pair, i.e. `start..end`.
///
/// An unbounded start maps to 0 and an unbounded end to `u64::MAX`. Inclusive ends and
/// exclusive starts are shifted by one with saturation, so an inclusive end of `u64::MAX`
/// cannot be represented and excludes that single index.
///
/// The returned `end` is never smaller than `start`: an inverted range is reported as the
/// empty range `start..start`. Without this clamp, feeding the pair to `BTreeMap::range`
/// panics ("range start is greater than range end").
fn range_to_bounds<RB: RangeBounds<u64>>(range: RB) -> (u64, u64) {
    let start = match range.start_bound() {
        Bound::Included(v) => *v,
        Bound::Excluded(v) => v.saturating_add(1),
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        Bound::Included(v) => v.saturating_add(1),
        Bound::Excluded(v) => *v,
        Bound::Unbounded => u64::MAX,
    };
    (start, end.max(start))
}
1238
1239impl RaftLogStorage<AstraTypeConfig> for AstraLogStore {
1240 type LogReader = AstraLogReader;
1241
1242 async fn get_log_state(&mut self) -> Result<LogState<AstraTypeConfig>, StorageError<u64>> {
1243 let guard = self.inner.read().await;
1244 let last_log_id = guard
1245 .logs
1246 .values()
1247 .last()
1248 .map(|e| e.log_id.clone())
1249 .or_else(|| guard.last_purged_log_id.clone());
1250
1251 Ok(LogState {
1252 last_purged_log_id: guard.last_purged_log_id.clone(),
1253 last_log_id,
1254 })
1255 }
1256
1257 async fn get_log_reader(&mut self) -> Self::LogReader {
1258 AstraLogReader {
1259 inner: self.inner.clone(),
1260 }
1261 }
1262
1263 async fn save_vote(&mut self, vote: &Vote<u64>) -> Result<(), StorageError<u64>> {
1264 {
1265 let mut guard = self.inner.write().await;
1266 guard.vote = Some(vote.clone());
1267 }
1268 self.append_records(vec![DurableRecord::Vote { vote: vote.clone() }], 0)
1269 .await
1270 }
1271
1272 async fn read_vote(&mut self) -> Result<Option<Vote<u64>>, StorageError<u64>> {
1273 let guard = self.inner.read().await;
1274 Ok(guard.vote.clone())
1275 }
1276
1277 async fn save_committed(
1278 &mut self,
1279 committed: Option<LogId<u64>>,
1280 ) -> Result<(), StorageError<u64>> {
1281 {
1282 let mut guard = self.inner.write().await;
1283 guard.committed = committed.clone();
1284 }
1285
1286 if let Some(committed_log_id) = committed.as_ref() {
1287 if let Some((tracked_log_index, marker)) = self
1288 .timeline_by_log_index
1289 .range(..=committed_log_id.index)
1290 .next_back()
1291 .map(|(idx, marker)| (*idx, *marker))
1292 {
1293 let since_submit_ms = now_micros()
1294 .checked_sub(marker.submit_ts_micros)
1295 .map(|delta| delta / 1_000)
1296 .unwrap_or(0);
1297 info!(
1298 stage = "quorum_ack_commit_advanced",
1299 batch_id = marker.batch_id,
1300 committed_index = committed_log_id.index,
1301 tracked_log_index,
1302 since_submit_ms,
1303 "raft timeline"
1304 );
1305 metrics::observe_put_quorum_ack_ms(since_submit_ms);
1306 }
1307
1308 let keep_from = committed_log_id.index.saturating_add(1);
1309 self.timeline_by_log_index = self.timeline_by_log_index.split_off(&keep_from);
1310 }
1311
1312 let should_flush = match (&self.last_flushed_committed, &committed) {
1313 (None, None) => false,
1314 (None, Some(_)) => true,
1315 (Some(_), None) => true,
1316 (Some(prev), Some(next)) => {
1317 let step_reached =
1318 next.index >= prev.index.saturating_add(Self::COMMITTED_FLUSH_STEP);
1319 step_reached
1320 || self.last_committed_flush_at.elapsed() >= Self::COMMITTED_FLUSH_INTERVAL
1321 }
1322 };
1323
1324 if should_flush {
1325 self.append_records(
1326 vec![DurableRecord::Committed {
1327 committed: committed.clone(),
1328 }],
1329 0,
1330 )
1331 .await?;
1332 self.last_flushed_committed = committed;
1333 self.last_committed_flush_at = Instant::now();
1334 }
1335
1336 Ok(())
1337 }
1338
1339 async fn read_committed(&mut self) -> Result<Option<LogId<u64>>, StorageError<u64>> {
1340 let guard = self.inner.read().await;
1341 Ok(guard.committed.clone())
1342 }
1343
1344 async fn append<I>(
1345 &mut self,
1346 entries: I,
1347 callback: LogFlushed<AstraTypeConfig>,
1348 ) -> Result<(), StorageError<u64>>
1349 where
1350 I: IntoIterator<Item = Entry<AstraTypeConfig>> + openraft::OptionalSend,
1351 I::IntoIter: openraft::OptionalSend,
1352 {
1353 let entries_vec = entries.into_iter().collect::<Vec<_>>();
1354 if entries_vec.is_empty() {
1355 callback.log_io_completed(Ok(()));
1356 return Ok(());
1357 }
1358
1359 let timeline_markers = Self::extract_timeline_markers(&entries_vec);
1360
1361 {
1362 let mut guard = self.inner.write().await;
1363 for ent in &entries_vec {
1364 guard.logs.insert(ent.log_id.index, ent.clone());
1365 }
1366 }
1367 for (log_index, marker) in timeline_markers {
1368 self.timeline_by_log_index.insert(log_index, marker);
1369 }
1370
1371 let append_entries = entries_vec
1372 .iter()
1373 .map(Self::logical_write_ops)
1374 .sum::<usize>()
1375 .max(1);
1376 let (timeline_batch_id, timeline_submit_ts_micros) =
1377 Self::extract_timeline_meta(&entries_vec);
1378 if let Some(batch_id) = timeline_batch_id {
1379 let since_submit_ms = timeline_submit_ts_micros
1380 .and_then(|submit| now_micros().checked_sub(submit))
1381 .map(|delta| delta / 1_000);
1382 debug!(
1383 stage = "raft_log_append_in_memory",
1384 batch_id,
1385 append_entries,
1386 log_entries = entries_vec.len(),
1387 since_submit_ms,
1388 "raft timeline"
1389 );
1390 }
1391 self.append_records_with_callback(
1392 vec![DurableRecord::Append {
1393 entries: entries_vec,
1394 }],
1395 append_entries,
1396 timeline_batch_id,
1397 timeline_submit_ts_micros,
1398 callback,
1399 )
1400 .await
1401 }
1402
1403 async fn truncate(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
1404 {
1405 let mut guard = self.inner.write().await;
1406 let keys = guard
1407 .logs
1408 .range(log_id.index..)
1409 .map(|(k, _)| *k)
1410 .collect::<Vec<_>>();
1411 for k in keys {
1412 guard.logs.remove(&k);
1413 }
1414 }
1415 self.timeline_by_log_index = self.timeline_by_log_index.split_off(&log_id.index);
1416
1417 self.append_records(vec![DurableRecord::Truncate { since: log_id }], 0)
1418 .await
1419 }
1420
1421 async fn purge(&mut self, log_id: LogId<u64>) -> Result<(), StorageError<u64>> {
1422 {
1423 let mut guard = self.inner.write().await;
1424 let keys = guard
1425 .logs
1426 .range(..=log_id.index)
1427 .map(|(k, _)| *k)
1428 .collect::<Vec<_>>();
1429 for k in keys {
1430 guard.logs.remove(&k);
1431 }
1432 guard.last_purged_log_id = Some(log_id.clone());
1433 }
1434 let keep_from = log_id.index.saturating_add(1);
1435 self.timeline_by_log_index = self.timeline_by_log_index.split_off(&keep_from);
1436
1437 self.append_records(vec![DurableRecord::Purge { upto: log_id }], 0)
1438 .await
1439 }
1440}
1441
/// A snapshot kept in memory: its metadata together with the serialized bytes.
#[derive(Debug, Clone)]
struct SnapshotBlob {
    /// Metadata describing the snapshot (last log id, membership, snapshot id).
    meta: SnapshotMeta<u64, BasicNode>,
    /// Serialized snapshot payload exactly as it was built or received.
    bytes: Vec<u8>,
}
1447
/// Snapshot payload that carries the token dictionary alongside the store state.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SnapshotPayloadV2 {
    /// State captured from the key-value store.
    store: SnapshotState,
    /// Token dictionary flattened into `(token_id, value)` pairs.
    token_dict: Vec<(u32, Vec<u8>)>,
    /// Token dictionary epoch at the time the snapshot was taken.
    token_dict_epoch: u64,
}
1454
/// State shared between the state machine and its snapshot builders.
#[derive(Debug)]
struct AstraStateMachineShared {
    /// Key-value store that committed entries are applied to.
    store: Arc<KvStore>,
    /// Log id of the last applied entry, or of the last installed snapshot.
    last_applied_log: Option<LogId<u64>>,
    /// Most recent membership seen in an applied entry or installed snapshot.
    last_membership: StoredMembership<u64, BasicNode>,
    /// Latest snapshot that was built locally or installed, if any.
    current_snapshot: Option<SnapshotBlob>,
    /// Token id to value mapping used to resolve tokenized put batches.
    token_dict: HashMap<u32, Vec<u8>>,
    /// Highest dictionary epoch seen in a tokenized batch or snapshot.
    token_dict_epoch: u64,
}
1464
/// Raft state machine handle; clones share the same underlying state.
#[derive(Debug, Clone)]
pub struct AstraStateMachine {
    /// Shared state, also handed to every snapshot builder.
    shared: Arc<RwLock<AstraStateMachineShared>>,
}
1469
1470impl AstraStateMachine {
1471 pub fn new(store: Arc<KvStore>) -> Self {
1472 Self {
1473 shared: Arc::new(RwLock::new(AstraStateMachineShared {
1474 store,
1475 last_applied_log: None,
1476 last_membership: StoredMembership::default(),
1477 current_snapshot: None,
1478 token_dict: HashMap::new(),
1479 token_dict_epoch: 0,
1480 })),
1481 }
1482 }
1483}
1484
1485fn txn_compare_i64(lhs: i64, rhs: i64, result: &AstraTxnCmpResult) -> bool {
1486 match result {
1487 AstraTxnCmpResult::Equal => lhs == rhs,
1488 AstraTxnCmpResult::Greater => lhs > rhs,
1489 AstraTxnCmpResult::Less => lhs < rhs,
1490 AstraTxnCmpResult::NotEqual => lhs != rhs,
1491 }
1492}
1493
1494fn txn_compare_bytes(lhs: &[u8], rhs: &[u8], result: &AstraTxnCmpResult) -> bool {
1495 match result {
1496 AstraTxnCmpResult::Equal => lhs == rhs,
1497 AstraTxnCmpResult::Greater => lhs > rhs,
1498 AstraTxnCmpResult::Less => lhs < rhs,
1499 AstraTxnCmpResult::NotEqual => lhs != rhs,
1500 }
1501}
1502
1503fn txn_ops_have_write(ops: &[AstraTxnOp]) -> bool {
1504 for op in ops {
1505 match op {
1506 AstraTxnOp::Put { .. } | AstraTxnOp::DeleteRange { .. } => return true,
1507 AstraTxnOp::Txn {
1508 success, failure, ..
1509 } => {
1510 if txn_ops_have_write(success) || txn_ops_have_write(failure) {
1511 return true;
1512 }
1513 }
1514 AstraTxnOp::Range { .. } => {}
1515 }
1516 }
1517 false
1518}
1519
1520fn eval_txn_compare(store: &KvStore, cmp: &AstraTxnCompare) -> Result<bool, StoreError> {
1521 let out = store.range(&cmp.key, &cmp.range_end, 0, 0, false, false)?;
1522
1523 if out.kvs.is_empty() {
1524 return match (&cmp.target, &cmp.target_value) {
1525 (AstraTxnCmpTarget::Value, AstraTxnCmpValue::Bytes(rhs)) => {
1526 Ok(txn_compare_bytes(&[], rhs, &cmp.result))
1527 }
1528 (
1529 AstraTxnCmpTarget::Version
1530 | AstraTxnCmpTarget::CreateRevision
1531 | AstraTxnCmpTarget::ModRevision
1532 | AstraTxnCmpTarget::Lease,
1533 AstraTxnCmpValue::I64(rhs),
1534 ) => Ok(txn_compare_i64(0, *rhs, &cmp.result)),
1535 _ => Err(StoreError::InvalidArgument(
1536 "txn compare target and value type mismatch".to_string(),
1537 )),
1538 };
1539 }
1540
1541 for (_, kv) in out.kvs {
1542 let ok = match (&cmp.target, &cmp.target_value) {
1543 (AstraTxnCmpTarget::Version, AstraTxnCmpValue::I64(rhs)) => {
1544 txn_compare_i64(kv.version, *rhs, &cmp.result)
1545 }
1546 (AstraTxnCmpTarget::CreateRevision, AstraTxnCmpValue::I64(rhs)) => {
1547 txn_compare_i64(kv.create_revision, *rhs, &cmp.result)
1548 }
1549 (AstraTxnCmpTarget::ModRevision, AstraTxnCmpValue::I64(rhs)) => {
1550 txn_compare_i64(kv.mod_revision, *rhs, &cmp.result)
1551 }
1552 (AstraTxnCmpTarget::Lease, AstraTxnCmpValue::I64(rhs)) => {
1553 txn_compare_i64(kv.lease, *rhs, &cmp.result)
1554 }
1555 (AstraTxnCmpTarget::Value, AstraTxnCmpValue::Bytes(rhs)) => {
1556 txn_compare_bytes(&kv.value, rhs, &cmp.result)
1557 }
1558 _ => {
1559 return Err(StoreError::InvalidArgument(
1560 "txn compare target and value type mismatch".to_string(),
1561 ));
1562 }
1563 };
1564 if !ok {
1565 return Ok(false);
1566 }
1567 }
1568
1569 Ok(true)
1570}
1571
1572fn eval_txn_compares(store: &KvStore, compares: &[AstraTxnCompare]) -> Result<bool, StoreError> {
1573 for cmp in compares {
1574 if !eval_txn_compare(store, cmp)? {
1575 return Ok(false);
1576 }
1577 }
1578 Ok(true)
1579}
1580
1581fn apply_txn_ops(
1582 store: &KvStore,
1583 ops: &[AstraTxnOp],
1584 forced_revision: Option<i64>,
1585) -> Result<Vec<AstraTxnOpResponse>, StoreError> {
1586 let mut out = Vec::with_capacity(ops.len());
1587
1588 for op in ops {
1589 match op {
1590 AstraTxnOp::Range {
1591 key,
1592 range_end,
1593 limit,
1594 revision,
1595 keys_only,
1596 count_only,
1597 } => {
1598 let RangeOutput {
1599 revision,
1600 count,
1601 more,
1602 kvs,
1603 } = store.range(key, range_end, *limit, *revision, *keys_only, *count_only)?;
1604 out.push(AstraTxnOpResponse::Range {
1605 revision,
1606 count,
1607 more,
1608 kvs,
1609 });
1610 }
1611 AstraTxnOp::Put {
1612 key,
1613 value,
1614 lease,
1615 ignore_value,
1616 ignore_lease,
1617 prev_kv,
1618 } => {
1619 let PutOutput {
1620 revision,
1621 prev,
1622 current: _,
1623 } = if let Some(rev) = forced_revision {
1624 store.apply_put_at_revision(
1625 key.clone(),
1626 value.clone(),
1627 *lease,
1628 *ignore_value,
1629 *ignore_lease,
1630 rev,
1631 )?
1632 } else {
1633 store.put(
1634 key.clone(),
1635 value.clone(),
1636 *lease,
1637 *ignore_value,
1638 *ignore_lease,
1639 )?
1640 };
1641 out.push(AstraTxnOpResponse::Put {
1642 revision,
1643 prev: if *prev_kv { prev } else { None },
1644 });
1645 }
1646 AstraTxnOp::DeleteRange {
1647 key,
1648 range_end,
1649 prev_kv,
1650 } => {
1651 let DeleteOutput {
1652 revision,
1653 deleted,
1654 prev_kvs,
1655 } = if let Some(rev) = forced_revision {
1656 store.apply_delete_at_revision(key, range_end, *prev_kv, rev)?
1657 } else {
1658 store.delete_range(key, range_end, *prev_kv)?
1659 };
1660 out.push(AstraTxnOpResponse::Delete {
1661 revision,
1662 deleted,
1663 prev_kvs,
1664 });
1665 }
1666 AstraTxnOp::Txn {
1667 compare,
1668 success,
1669 failure,
1670 } => {
1671 let succeeded = eval_txn_compares(store, compare)?;
1672 let branch = if succeeded { success } else { failure };
1673 let nested_forced = if txn_ops_have_write(branch) {
1674 forced_revision
1675 } else {
1676 None
1677 };
1678 let responses = apply_txn_ops(store, branch, nested_forced)?;
1679 out.push(AstraTxnOpResponse::Txn {
1680 revision: forced_revision.unwrap_or_else(|| store.current_revision()),
1681 succeeded,
1682 responses,
1683 });
1684 }
1685 }
1686 }
1687
1688 Ok(out)
1689}
1690
/// Builds snapshots from the state shared with [`AstraStateMachine`].
#[derive(Debug, Clone)]
pub struct AstraSnapshotBuilder {
    /// Same shared state the state machine applies entries to.
    shared: Arc<RwLock<AstraStateMachineShared>>,
}
1695
1696impl RaftSnapshotBuilder<AstraTypeConfig> for AstraSnapshotBuilder {
1697 async fn build_snapshot(&mut self) -> Result<Snapshot<AstraTypeConfig>, StorageError<u64>> {
1698 let (snapshot_state, token_dict, token_dict_epoch, last_log_id, last_membership) = {
1699 let guard = self.shared.read().await;
1700 (
1701 guard.store.snapshot_state(),
1702 guard.token_dict.clone(),
1703 guard.token_dict_epoch,
1704 guard.last_applied_log.clone(),
1705 guard.last_membership.clone(),
1706 )
1707 };
1708
1709 let payload = SnapshotPayloadV2 {
1710 store: snapshot_state,
1711 token_dict: token_dict.into_iter().collect::<Vec<_>>(),
1712 token_dict_epoch,
1713 };
1714
1715 let bytes = bincode::serialize(&payload).map_err(|e| {
1716 StorageError::from_io_error(
1717 openraft::ErrorSubject::Snapshot(None),
1718 openraft::ErrorVerb::Write,
1719 std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
1720 )
1721 })?;
1722
1723 let snapshot_id = format!(
1724 "snapshot-{}-{}",
1725 last_log_id.as_ref().map(|v| v.index).unwrap_or_default(),
1726 SystemTime::now()
1727 .duration_since(UNIX_EPOCH)
1728 .map(|d| d.as_millis())
1729 .unwrap_or_default()
1730 );
1731
1732 let meta = SnapshotMeta {
1733 last_log_id: last_log_id.clone(),
1734 last_membership: last_membership.clone(),
1735 snapshot_id,
1736 };
1737
1738 {
1739 let mut guard = self.shared.write().await;
1740 guard.current_snapshot = Some(SnapshotBlob {
1741 meta: meta.clone(),
1742 bytes: bytes.clone(),
1743 });
1744 }
1745
1746 Ok(Snapshot {
1747 meta,
1748 snapshot: Box::new(Cursor::new(bytes)),
1749 })
1750 }
1751}
1752
1753impl RaftStateMachine<AstraTypeConfig> for AstraStateMachine {
1754 type SnapshotBuilder = AstraSnapshotBuilder;
1755
1756 async fn applied_state(
1757 &mut self,
1758 ) -> Result<(Option<LogId<u64>>, StoredMembership<u64, BasicNode>), StorageError<u64>> {
1759 let guard = self.shared.read().await;
1760 Ok((
1761 guard.last_applied_log.clone(),
1762 guard.last_membership.clone(),
1763 ))
1764 }
1765
1766 async fn apply<I>(&mut self, entries: I) -> Result<Vec<AstraWriteResponse>, StorageError<u64>>
1767 where
1768 I: IntoIterator<Item = Entry<AstraTypeConfig>> + openraft::OptionalSend,
1769 I::IntoIter: openraft::OptionalSend,
1770 {
1771 let entries = entries.into_iter().collect::<Vec<_>>();
1772
1773 let (store, mut token_dict, mut token_dict_epoch) = {
1774 let guard = self.shared.read().await;
1775 (
1776 guard.store.clone(),
1777 guard.token_dict.clone(),
1778 guard.token_dict_epoch,
1779 )
1780 };
1781
1782 let mut responses = Vec::with_capacity(entries.len());
1783
1784 let mut last_applied = None;
1785 let mut last_membership = None;
1786
1787 for ent in entries {
1788 let log_id = ent.log_id.clone();
1789
1790 match ent.payload {
1791 EntryPayload::Blank => {
1792 responses.push(AstraWriteResponse::Empty);
1793 }
1794 EntryPayload::Membership(membership) => {
1795 last_membership = Some(StoredMembership::new(Some(log_id.clone()), membership));
1796 responses.push(AstraWriteResponse::Empty);
1797 }
1798 EntryPayload::Normal(req) => match req {
1799 AstraWriteRequest::Put {
1800 key,
1801 value,
1802 lease,
1803 ignore_value,
1804 ignore_lease,
1805 prev_kv,
1806 } => {
1807 let PutOutput {
1808 revision,
1809 prev,
1810 current: _,
1811 } = store
1812 .put(key, value, lease, ignore_value, ignore_lease)
1813 .map_err(|e| {
1814 StorageError::from_io_error(
1815 openraft::ErrorSubject::StateMachine,
1816 openraft::ErrorVerb::Write,
1817 std::io::Error::other(e.to_string()),
1818 )
1819 })?;
1820
1821 responses.push(AstraWriteResponse::Put {
1822 revision,
1823 prev: if prev_kv { prev } else { None },
1824 });
1825 }
1826 AstraWriteRequest::PutBatch { ops } => {
1827 let mut results = Vec::with_capacity(ops.len());
1828 for op in ops {
1829 let PutOutput {
1830 revision,
1831 prev,
1832 current: _,
1833 } = store
1834 .put(op.key, op.value, op.lease, op.ignore_value, op.ignore_lease)
1835 .map_err(|e| {
1836 StorageError::from_io_error(
1837 openraft::ErrorSubject::StateMachine,
1838 openraft::ErrorVerb::Write,
1839 std::io::Error::other(e.to_string()),
1840 )
1841 })?;
1842
1843 results.push(AstraBatchPutResult {
1844 revision,
1845 prev: if op.prev_kv { prev } else { None },
1846 });
1847 }
1848 responses.push(AstraWriteResponse::PutBatch { results });
1849 }
1850 AstraWriteRequest::PutBatchRef {
1851 batch_id,
1852 submit_ts_micros,
1853 values,
1854 ops,
1855 } => {
1856 let op_count = ops.len();
1857 let unique_values = values.len();
1858 let mut results = Vec::with_capacity(ops.len());
1859 for op in ops {
1860 let value =
1861 values.get(op.value_idx as usize).cloned().ok_or_else(|| {
1862 StorageError::from_io_error(
1863 openraft::ErrorSubject::StateMachine,
1864 openraft::ErrorVerb::Write,
1865 std::io::Error::new(
1866 std::io::ErrorKind::InvalidData,
1867 format!(
1868 "put batch reference index out of bounds: {}",
1869 op.value_idx
1870 ),
1871 ),
1872 )
1873 })?;
1874
1875 let PutOutput {
1876 revision,
1877 prev,
1878 current: _,
1879 } = store
1880 .put(op.key, value, op.lease, op.ignore_value, op.ignore_lease)
1881 .map_err(|e| {
1882 StorageError::from_io_error(
1883 openraft::ErrorSubject::StateMachine,
1884 openraft::ErrorVerb::Write,
1885 std::io::Error::other(e.to_string()),
1886 )
1887 })?;
1888
1889 results.push(AstraBatchPutResult {
1890 revision,
1891 prev: if op.prev_kv { prev } else { None },
1892 });
1893 }
1894 if batch_id > 0 {
1895 let since_submit_ms = now_micros()
1896 .checked_sub(submit_ts_micros)
1897 .map(|delta| delta / 1_000);
1898 info!(
1899 stage = "sm_apply_done",
1900 batch_id,
1901 op_count,
1902 unique_values,
1903 tokenized = false,
1904 since_submit_ms,
1905 "raft timeline"
1906 );
1907 }
1908 responses.push(AstraWriteResponse::PutBatch { results });
1909 }
1910 AstraWriteRequest::PutBatchTokenized {
1911 batch_id,
1912 submit_ts_micros,
1913 dict_epoch,
1914 dict_additions,
1915 ops,
1916 } => {
1917 if dict_epoch > token_dict_epoch {
1918 token_dict_epoch = dict_epoch;
1919 }
1920 for add in &dict_additions {
1921 token_dict.insert(add.token_id, add.value.clone());
1922 }
1923
1924 let op_count = ops.len();
1925 let dict_add_count = dict_additions.len();
1926 let mut results = Vec::with_capacity(ops.len());
1927 for op in ops {
1928 let value = token_dict.get(&op.token_id).cloned().ok_or_else(|| {
1929 StorageError::from_io_error(
1930 openraft::ErrorSubject::StateMachine,
1931 openraft::ErrorVerb::Write,
1932 std::io::Error::new(
1933 std::io::ErrorKind::InvalidData,
1934 format!(
1935 "missing token in tokenized batch: {}",
1936 op.token_id
1937 ),
1938 ),
1939 )
1940 })?;
1941
1942 let PutOutput {
1943 revision,
1944 prev,
1945 current: _,
1946 } = store
1947 .put(op.key, value, op.lease, op.ignore_value, op.ignore_lease)
1948 .map_err(|e| {
1949 StorageError::from_io_error(
1950 openraft::ErrorSubject::StateMachine,
1951 openraft::ErrorVerb::Write,
1952 std::io::Error::other(e.to_string()),
1953 )
1954 })?;
1955
1956 results.push(AstraBatchPutResult {
1957 revision,
1958 prev: if op.prev_kv { prev } else { None },
1959 });
1960 }
1961
1962 if batch_id > 0 {
1963 let since_submit_ms = now_micros()
1964 .checked_sub(submit_ts_micros)
1965 .map(|delta| delta / 1_000);
1966 info!(
1967 stage = "sm_apply_done",
1968 batch_id,
1969 op_count,
1970 dict_add_count,
1971 dict_size = token_dict.len(),
1972 tokenized = true,
1973 since_submit_ms,
1974 "raft timeline"
1975 );
1976 }
1977 responses.push(AstraWriteResponse::PutBatch { results });
1978 }
1979 AstraWriteRequest::DeleteRange {
1980 key,
1981 range_end,
1982 prev_kv,
1983 } => {
1984 let DeleteOutput {
1985 revision,
1986 deleted,
1987 prev_kvs,
1988 } = store.delete_range(&key, &range_end, prev_kv).map_err(|e| {
1989 StorageError::from_io_error(
1990 openraft::ErrorSubject::StateMachine,
1991 openraft::ErrorVerb::Write,
1992 std::io::Error::other(e.to_string()),
1993 )
1994 })?;
1995
1996 responses.push(AstraWriteResponse::Delete {
1997 revision,
1998 deleted,
1999 prev_kvs,
2000 });
2001 }
2002 AstraWriteRequest::Txn {
2003 compare,
2004 success,
2005 failure,
2006 } => {
2007 let succeeded = eval_txn_compares(&store, &compare).map_err(|e| {
2008 StorageError::from_io_error(
2009 openraft::ErrorSubject::StateMachine,
2010 openraft::ErrorVerb::Write,
2011 std::io::Error::other(e.to_string()),
2012 )
2013 })?;
2014 let branch = if succeeded { &success } else { &failure };
2015 let forced_revision = if txn_ops_have_write(branch) {
2016 Some(store.reserve_revision())
2017 } else {
2018 None
2019 };
2020 let op_responses =
2021 apply_txn_ops(&store, branch, forced_revision).map_err(|e| {
2022 StorageError::from_io_error(
2023 openraft::ErrorSubject::StateMachine,
2024 openraft::ErrorVerb::Write,
2025 std::io::Error::other(e.to_string()),
2026 )
2027 })?;
2028 responses.push(AstraWriteResponse::Txn {
2029 revision: forced_revision.unwrap_or_else(|| store.current_revision()),
2030 succeeded,
2031 responses: op_responses,
2032 });
2033 }
2034 AstraWriteRequest::Compact { revision } => {
2035 let compact_revision = store.compact_to(revision).map_err(|e| {
2036 StorageError::from_io_error(
2037 openraft::ErrorSubject::StateMachine,
2038 openraft::ErrorVerb::Write,
2039 std::io::Error::other(e.to_string()),
2040 )
2041 })?;
2042 responses.push(AstraWriteResponse::Compact {
2043 revision: store.current_revision(),
2044 compact_revision,
2045 });
2046 }
2047 AstraWriteRequest::LeaseGrant { id, ttl } => {
2048 let LeaseGrantOutput { revision, id, ttl } =
2049 store.lease_grant(id, ttl).map_err(|e| {
2050 StorageError::from_io_error(
2051 openraft::ErrorSubject::StateMachine,
2052 openraft::ErrorVerb::Write,
2053 std::io::Error::other(e.to_string()),
2054 )
2055 })?;
2056 responses.push(AstraWriteResponse::LeaseGrant { revision, id, ttl });
2057 }
2058 AstraWriteRequest::LeaseRevoke { id } => {
2059 let LeaseRevokeOutput { revision, deleted } =
2060 store.lease_revoke(id).map_err(|e| {
2061 StorageError::from_io_error(
2062 openraft::ErrorSubject::StateMachine,
2063 openraft::ErrorVerb::Write,
2064 std::io::Error::other(e.to_string()),
2065 )
2066 })?;
2067 responses.push(AstraWriteResponse::LeaseRevoke { revision, deleted });
2068 }
2069 AstraWriteRequest::LeaseKeepAlive { id } => {
2070 let LeaseGrantOutput { revision, id, ttl } =
2071 store.lease_keep_alive(id).map_err(|e| {
2072 StorageError::from_io_error(
2073 openraft::ErrorSubject::StateMachine,
2074 openraft::ErrorVerb::Write,
2075 std::io::Error::other(e.to_string()),
2076 )
2077 })?;
2078 responses.push(AstraWriteResponse::LeaseKeepAlive { revision, id, ttl });
2079 }
2080 },
2081 }
2082
2083 last_applied = Some(log_id);
2084 }
2085
2086 let mut guard = self.shared.write().await;
2087 if let Some(last) = last_applied {
2088 guard.last_applied_log = Some(last);
2089 }
2090 if let Some(membership) = last_membership {
2091 guard.last_membership = membership;
2092 }
2093 guard.token_dict = token_dict;
2094 guard.token_dict_epoch = token_dict_epoch;
2095
2096 Ok(responses)
2097 }
2098
2099 async fn get_snapshot_builder(&mut self) -> Self::SnapshotBuilder {
2100 AstraSnapshotBuilder {
2101 shared: self.shared.clone(),
2102 }
2103 }
2104
2105 async fn begin_receiving_snapshot(
2106 &mut self,
2107 ) -> Result<Box<Cursor<Vec<u8>>>, StorageError<u64>> {
2108 Ok(Box::new(Cursor::new(Vec::new())))
2109 }
2110
2111 async fn install_snapshot(
2112 &mut self,
2113 meta: &SnapshotMeta<u64, BasicNode>,
2114 snapshot: Box<Cursor<Vec<u8>>>,
2115 ) -> Result<(), StorageError<u64>> {
2116 let cursor = *snapshot;
2117 let bytes = cursor.into_inner();
2118
2119 let (snapshot_state, token_dict, token_dict_epoch) =
2120 match bincode::deserialize::<SnapshotPayloadV2>(&bytes) {
2121 Ok(v2) => (
2122 v2.store,
2123 v2.token_dict.into_iter().collect::<HashMap<_, _>>(),
2124 v2.token_dict_epoch,
2125 ),
2126 Err(_) => {
2127 let snapshot_state: SnapshotState =
2129 bincode::deserialize(&bytes).map_err(|e| {
2130 StorageError::from_io_error(
2131 openraft::ErrorSubject::Snapshot(Some(meta.signature())),
2132 openraft::ErrorVerb::Read,
2133 std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()),
2134 )
2135 })?;
2136 (snapshot_state, HashMap::new(), 0)
2137 }
2138 };
2139
2140 let store = {
2141 let guard = self.shared.read().await;
2142 guard.store.clone()
2143 };
2144
2145 store.load_snapshot_state(snapshot_state).map_err(|e| {
2146 StorageError::from_io_error(
2147 openraft::ErrorSubject::StateMachine,
2148 openraft::ErrorVerb::Write,
2149 std::io::Error::other(e.to_string()),
2150 )
2151 })?;
2152
2153 let mut guard = self.shared.write().await;
2154 guard.last_applied_log = meta.last_log_id.clone();
2155 guard.last_membership = meta.last_membership.clone();
2156 guard.token_dict = token_dict;
2157 guard.token_dict_epoch = token_dict_epoch;
2158 guard.current_snapshot = Some(SnapshotBlob {
2159 meta: meta.clone(),
2160 bytes,
2161 });
2162
2163 Ok(())
2164 }
2165
2166 async fn get_current_snapshot(
2167 &mut self,
2168 ) -> Result<Option<Snapshot<AstraTypeConfig>>, StorageError<u64>> {
2169 let guard = self.shared.read().await;
2170 Ok(guard.current_snapshot.as_ref().map(|blob| Snapshot {
2171 meta: blob.meta.clone(),
2172 snapshot: Box::new(Cursor::new(blob.bytes.clone())),
2173 }))
2174 }
2175}
2176
/// Stateless factory that creates one [`AstraNetwork`] per target node.
#[derive(Debug, Clone)]
pub struct AstraNetworkFactory;
2179
/// Connection to a single peer's internal Raft gRPC service.
#[derive(Debug, Clone)]
pub struct AstraNetwork {
    /// Peer address; `http://` is prepended at connect time when no
    /// `http://` or `https://` scheme is present.
    addr: String,
    /// Lazily created client; reset to `None` after a failed RPC so the next
    /// call reconnects.
    client: Option<InternalRaftClient<Channel>>,
}
2185
2186impl RaftNetworkFactory<AstraTypeConfig> for AstraNetworkFactory {
2187 type Network = AstraNetwork;
2188
2189 async fn new_client(&mut self, _target: u64, node: &BasicNode) -> Self::Network {
2190 Self::Network {
2191 addr: node.addr.clone(),
2192 client: None,
2193 }
2194 }
2195}
2196
2197impl AstraNetwork {
2198 async fn ensure_client(
2199 &mut self,
2200 ) -> Result<&mut InternalRaftClient<Channel>, tonic::transport::Error> {
2201 let addr = if self.addr.starts_with("http://") || self.addr.starts_with("https://") {
2202 self.addr.clone()
2203 } else {
2204 format!("http://{}", self.addr)
2205 };
2206 if self.client.is_none() {
2207 self.client = Some(InternalRaftClient::connect(addr).await?);
2208 }
2209 Ok(self.client.as_mut().expect("client initialized"))
2210 }
2211
2212 fn invalidate_client(&mut self) {
2213 self.client = None;
2214 }
2215
2216 fn should_retry_status(status: &Status) -> bool {
2217 matches!(
2218 status.code(),
2219 Code::Unavailable | Code::Cancelled | Code::DeadlineExceeded | Code::Unknown
2220 )
2221 }
2222}
2223
2224impl RaftNetwork<AstraTypeConfig> for AstraNetwork {
2225 async fn append_entries(
2226 &mut self,
2227 rpc: AppendEntriesRequest<AstraTypeConfig>,
2228 option: RPCOption,
2229 ) -> Result<AppendEntriesResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
2230 let entry_count = rpc.entries.len();
2231 let first_log_index = rpc.entries.first().map(|e| e.log_id.index);
2232 let last_log_index = rpc.entries.last().map(|e| e.log_id.index);
2233 let payload = serde_json::to_vec(&rpc).map_err(|e| {
2234 RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2235 })?;
2236 let payload_bytes = payload.len();
2237
2238 let timeout = option.hard_ttl();
2239 for attempt in 0..2 {
2240 let started = Instant::now();
2241 let mut req = Request::new(RaftBytes {
2242 payload: payload.clone(),
2243 });
2244 req.set_timeout(timeout);
2245 let resp = self
2246 .ensure_client()
2247 .await
2248 .map_err(|e| RPCError::Network(NetworkError::new(&e)))?
2249 .append_entries(req)
2250 .await;
2251
2252 match resp {
2253 Ok(resp) => {
2254 let elapsed_ms = started.elapsed().as_millis() as u64;
2255 debug!(
2256 stage = "append_entries_rpc_send_done",
2257 entry_count,
2258 first_log_index,
2259 last_log_index,
2260 payload_bytes,
2261 elapsed_ms,
2262 attempt,
2263 "raft timeline"
2264 );
2265 return serde_json::from_slice::<AppendEntriesResponse<u64>>(
2266 &resp.into_inner().payload,
2267 )
2268 .map_err(|e| {
2269 RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2270 });
2271 }
2272 Err(status) => {
2273 let elapsed_ms = started.elapsed().as_millis() as u64;
2274 debug!(
2275 stage = "append_entries_rpc_send_done",
2276 entry_count,
2277 first_log_index,
2278 last_log_index,
2279 payload_bytes,
2280 elapsed_ms,
2281 attempt,
2282 code = %status.code(),
2283 success = false,
2284 "raft timeline"
2285 );
2286 self.invalidate_client();
2287 if attempt == 0 && Self::should_retry_status(&status) {
2288 continue;
2289 }
2290 return Err(RPCError::Network(NetworkError::new(&status)));
2291 }
2292 }
2293 }
2294
2295 Err(RPCError::Network(NetworkError::new(
2296 &std::io::Error::other("append_entries exhausted retries"),
2297 )))
2298 }
2299
2300 async fn install_snapshot(
2301 &mut self,
2302 rpc: InstallSnapshotRequest<AstraTypeConfig>,
2303 option: RPCOption,
2304 ) -> Result<
2305 InstallSnapshotResponse<u64>,
2306 RPCError<u64, BasicNode, RaftError<u64, InstallSnapshotError>>,
2307 > {
2308 let payload = serde_json::to_vec(&rpc).map_err(|e| {
2309 RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2310 })?;
2311
2312 let timeout = option.hard_ttl();
2313 for attempt in 0..2 {
2314 let mut req = Request::new(RaftBytes {
2315 payload: payload.clone(),
2316 });
2317 req.set_timeout(timeout);
2318 let resp = self
2319 .ensure_client()
2320 .await
2321 .map_err(|e| RPCError::Network(NetworkError::new(&e)))?
2322 .install_snapshot(req)
2323 .await;
2324
2325 match resp {
2326 Ok(resp) => {
2327 return serde_json::from_slice::<InstallSnapshotResponse<u64>>(
2328 &resp.into_inner().payload,
2329 )
2330 .map_err(|e| {
2331 RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2332 });
2333 }
2334 Err(status) => {
2335 self.invalidate_client();
2336 if attempt == 0 && Self::should_retry_status(&status) {
2337 continue;
2338 }
2339 return Err(RPCError::Network(NetworkError::new(&status)));
2340 }
2341 }
2342 }
2343
2344 Err(RPCError::Network(NetworkError::new(
2345 &std::io::Error::other("install_snapshot exhausted retries"),
2346 )))
2347 }
2348
2349 async fn vote(
2350 &mut self,
2351 rpc: VoteRequest<u64>,
2352 option: RPCOption,
2353 ) -> Result<VoteResponse<u64>, RPCError<u64, BasicNode, RaftError<u64>>> {
2354 let payload = serde_json::to_vec(&rpc).map_err(|e| {
2355 RPCError::Network(NetworkError::new(&std::io::Error::other(e.to_string())))
2356 })?;
2357
2358 let timeout = option.hard_ttl();
2359 for attempt in 0..2 {
2360 let mut req = Request::new(RaftBytes {
2361 payload: payload.clone(),
2362 });
2363 req.set_timeout(timeout);
2364 let resp = self
2365 .ensure_client()
2366 .await
2367 .map_err(|e| RPCError::Network(NetworkError::new(&e)))?
2368 .vote(req)
2369 .await;
2370
2371 match resp {
2372 Ok(resp) => {
2373 return serde_json::from_slice::<VoteResponse<u64>>(&resp.into_inner().payload)
2374 .map_err(|e| {
2375 RPCError::Network(NetworkError::new(&std::io::Error::other(
2376 e.to_string(),
2377 )))
2378 });
2379 }
2380 Err(status) => {
2381 self.invalidate_client();
2382 if attempt == 0 && Self::should_retry_status(&status) {
2383 continue;
2384 }
2385 return Err(RPCError::Network(NetworkError::new(&status)));
2386 }
2387 }
2388 }
2389
2390 Err(RPCError::Network(NetworkError::new(
2391 &std::io::Error::other("vote exhausted retries"),
2392 )))
2393 }
2394}
2395
/// gRPC-facing wrapper around the local Raft instance, with an optional
/// artificial delay applied to AppendEntries handling.
#[derive(Clone)]
pub struct AstraRaftService {
    /// Local Raft instance that inbound RPCs are forwarded to.
    pub raft: Raft<AstraTypeConfig>,
    /// Id of this node; compared against `chaos_append_ack_delay_node_id`.
    pub node_id: u64,
    /// Enables the artificial AppendEntries delay.
    pub chaos_append_ack_delay_enabled: bool,
    /// Lower bound of the artificial delay.
    pub chaos_append_ack_delay_min: Duration,
    /// Upper bound of the artificial delay; treated as at least the minimum.
    pub chaos_append_ack_delay_max: Duration,
    /// Node the delay is limited to; `0` applies it on every node.
    pub chaos_append_ack_delay_node_id: u64,
}
2405
2406impl AstraRaftService {
2407 fn append_ack_chaos_delay(&self, entry_count: usize) -> Option<Duration> {
2408 if !self.chaos_append_ack_delay_enabled || entry_count == 0 {
2409 return None;
2410 }
2411 if self.chaos_append_ack_delay_node_id != 0
2412 && self.chaos_append_ack_delay_node_id != self.node_id
2413 {
2414 return None;
2415 }
2416
2417 let min_ms = self.chaos_append_ack_delay_min.as_millis() as u64;
2418 let max_ms = self
2419 .chaos_append_ack_delay_max
2420 .as_millis()
2421 .max(min_ms as u128) as u64;
2422 let span = max_ms.saturating_sub(min_ms);
2423 let jitter = if span == 0 {
2424 0
2425 } else {
2426 let now_ns = SystemTime::now()
2427 .duration_since(UNIX_EPOCH)
2428 .map(|d| d.as_nanos() as u64)
2429 .unwrap_or_default();
2430 now_ns % (span + 1)
2431 };
2432 Some(Duration::from_millis(min_ms.saturating_add(jitter)))
2433 }
2434}
2435
2436#[tonic::async_trait]
2437impl InternalRaft for AstraRaftService {
2438 async fn append_entries(
2439 &self,
2440 request: Request<RaftBytes>,
2441 ) -> Result<Response<RaftBytes>, Status> {
2442 let started = Instant::now();
2443 let rpc: AppendEntriesRequest<AstraTypeConfig> =
2444 serde_json::from_slice(&request.into_inner().payload)
2445 .map_err(|e| Status::invalid_argument(e.to_string()))?;
2446 let entry_count = rpc.entries.len();
2447 let first_log_index = rpc.entries.first().map(|e| e.log_id.index);
2448 let last_log_index = rpc.entries.last().map(|e| e.log_id.index);
2449
2450 if let Some(delay) = self.append_ack_chaos_delay(entry_count) {
2451 tokio::time::sleep(delay).await;
2452 debug!(
2453 stage = "append_entries_chaos_delay",
2454 node_id = self.node_id,
2455 delay_ms = delay.as_millis() as u64,
2456 entry_count,
2457 first_log_index,
2458 last_log_index,
2459 "raft timeline"
2460 );
2461 }
2462
2463 let resp = self
2464 .raft
2465 .append_entries(rpc)
2466 .await
2467 .map_err(|e| Status::internal(e.to_string()))?;
2468 let elapsed_ms = started.elapsed().as_millis() as u64;
2469 debug!(
2470 stage = "append_entries_rpc_recv_done",
2471 entry_count, first_log_index, last_log_index, elapsed_ms, "raft timeline"
2472 );
2473
2474 let payload = serde_json::to_vec(&resp).map_err(|e| Status::internal(e.to_string()))?;
2475 Ok(Response::new(RaftBytes { payload }))
2476 }
2477
2478 async fn vote(&self, request: Request<RaftBytes>) -> Result<Response<RaftBytes>, Status> {
2479 let rpc: VoteRequest<u64> = serde_json::from_slice(&request.into_inner().payload)
2480 .map_err(|e| Status::invalid_argument(e.to_string()))?;
2481
2482 let resp = self
2483 .raft
2484 .vote(rpc)
2485 .await
2486 .map_err(|e| Status::internal(e.to_string()))?;
2487
2488 let payload = serde_json::to_vec(&resp).map_err(|e| Status::internal(e.to_string()))?;
2489 Ok(Response::new(RaftBytes { payload }))
2490 }
2491
2492 async fn install_snapshot(
2493 &self,
2494 request: Request<RaftBytes>,
2495 ) -> Result<Response<RaftBytes>, Status> {
2496 let rpc: InstallSnapshotRequest<AstraTypeConfig> =
2497 serde_json::from_slice(&request.into_inner().payload)
2498 .map_err(|e| Status::invalid_argument(e.to_string()))?;
2499
2500 let resp = self
2501 .raft
2502 .install_snapshot(rpc)
2503 .await
2504 .map_err(|e| Status::internal(e.to_string()))?;
2505
2506 let payload = serde_json::to_vec(&resp).map_err(|e| Status::internal(e.to_string()))?;
2507 Ok(Response::new(RaftBytes { payload }))
2508 }
2509}
2510
2511pub fn parse_raft_nodes(peers: &[String]) -> HashMap<u64, BasicNode> {
2512 peers
2513 .iter()
2514 .enumerate()
2515 .map(|(idx, p)| {
2516 let addr = p
2517 .strip_prefix("http://")
2518 .or_else(|| p.strip_prefix("https://"))
2519 .unwrap_or(p)
2520 .to_string();
2521 ((idx as u64) + 1, BasicNode { addr })
2522 })
2523 .collect::<HashMap<_, _>>()
2524}
2525
2526pub async fn maybe_initialize(
2527 raft: &Raft<AstraTypeConfig>,
2528 nodes: HashMap<u64, BasicNode>,
2529) -> Result<(), RaftError<u64, openraft::error::InitializeError<u64, BasicNode>>> {
2530 if raft.is_initialized().await.unwrap_or(false) {
2531 return Ok(());
2532 }
2533
2534 let members = nodes.into_iter().collect::<BTreeMap<_, _>>();
2535 raft.initialize(members).await
2536}