1use crate::raft::metrics::{RaftMetricsCollector, RaftMetricsUpdate};
2use crate::raft::transport::{ChirpsRaftTransport, RaftFramePayload};
3use crate::raft::{
4 AppendEntriesRequest, AppendEntriesResponse, BasicNode, ChirpsNodeId, ChirpsTypeConfig,
5 GroupId, InstallSnapshotRequest, InstallSnapshotResponse, RaftConfig, RaftError, RaftResult,
6 VoteRequest, VoteResponse,
7};
8use anyhow::anyhow;
9use openraft::Raft;
10use openraft::error::{ClientWriteError, RaftError as OpenRaftError};
11use openraft::metrics::RaftMetrics as OpenRaftMetrics;
12use openraft::network::RaftNetworkFactory;
13use openraft::raft::ClientWriteResponse;
14use openraft::storage::{RaftLogStorage, RaftStateMachine};
15use openraft::{Config, ConfigError, LogId, MessageSummary, ServerState, SnapshotPolicy};
16use serde::{Deserialize, Serialize};
17use std::collections::BTreeSet;
18use std::sync::{Arc, Mutex};
19use tokio::sync::watch::Receiver;
20use tokio::task::JoinHandle;
21use tracing::info;
22
/// Raft RPC envelope exchanged between nodes, tagged with the Raft group it targets.
///
/// The request variants (`AppendEntries`, `Vote`, `InstallSnapshot`) are dispatched to
/// the local openraft instance by `RaftNode::handle_message`. Each one has a matching
/// `*Response` variant carrying openraft's reply.
// NOTE(review): this type is bincode-serialized (see the `decode_frame_roundtrip`
// test), and bincode identifies enum variants by their position. Reordering or
// inserting variants in the middle would therefore break compatibility with peers
// running older builds. Append new variants at the end.
#[derive(Debug, Serialize, Deserialize)]
pub enum RaftMessage {
    /// openraft `AppendEntriesRequest` for `group_id`.
    AppendEntries {
        group_id: GroupId,
        request: AppendEntriesRequest<ChirpsTypeConfig>,
    },
    /// Reply to [`RaftMessage::AppendEntries`].
    AppendEntriesResponse {
        group_id: GroupId,
        response: AppendEntriesResponse<ChirpsNodeId>,
    },
    /// openraft `VoteRequest` for `group_id`.
    Vote {
        group_id: GroupId,
        request: VoteRequest<ChirpsNodeId>,
    },
    /// Reply to [`RaftMessage::Vote`].
    VoteResponse {
        group_id: GroupId,
        response: VoteResponse<ChirpsNodeId>,
    },
    /// openraft `InstallSnapshotRequest` for `group_id`.
    InstallSnapshot {
        group_id: GroupId,
        request: InstallSnapshotRequest<ChirpsTypeConfig>,
    },
    /// Reply to [`RaftMessage::InstallSnapshot`].
    InstallSnapshotResponse {
        group_id: GroupId,
        response: InstallSnapshotResponse<ChirpsNodeId>,
    },
}
67
68impl RaftMessage {
69 pub fn group_id(&self) -> GroupId {
70 match self {
71 RaftMessage::AppendEntries { group_id, .. }
72 | RaftMessage::AppendEntriesResponse { group_id, .. }
73 | RaftMessage::Vote { group_id, .. }
74 | RaftMessage::VoteResponse { group_id, .. }
75 | RaftMessage::InstallSnapshot { group_id, .. }
76 | RaftMessage::InstallSnapshotResponse { group_id, .. } => *group_id,
77 }
78 }
79}
80
/// One member of a single Raft group.
///
/// Wraps the openraft [`Raft`] handle together with the transport and the metrics
/// plumbing for that group.
pub struct RaftNode {
    // Group/node identity and tuning parameters this node was built from.
    pub(crate) config: RaftConfig,
    // Handle to the underlying openraft instance.
    pub(crate) raft: Raft<ChirpsTypeConfig>,
    // NOTE(review): never read in this file. Presumably kept so the transport lives
    // as long as the node; confirm before removing the `allow`.
    #[allow(dead_code)]
    pub(crate) transport: Arc<ChirpsRaftTransport>,
    // Collector slot shared with the metrics observer task. `None` until
    // `set_metrics_collector` is called.
    metrics_collector: Arc<Mutex<Option<Arc<RaftMetricsCollector>>>>,
    // Handle of the background metrics observer task. It is only stored, never awaited
    // or aborted, hence the `allow`. Dropping a tokio `JoinHandle` detaches the task
    // rather than stopping it.
    #[allow(dead_code)]
    observer_handle: JoinHandle<()>,
}
115
116impl RaftNode {
117 pub async fn new<NF, LS, SM>(
119 config: RaftConfig,
120 network: NF,
121 log_store: LS,
122 state_machine: SM,
123 transport: Arc<ChirpsRaftTransport>,
124 ) -> RaftResult<Self>
125 where
126 NF: RaftNetworkFactory<ChirpsTypeConfig> + Clone + Send + Sync + 'static,
127 NF::Network: Send + Sync,
128 LS: RaftLogStorage<ChirpsTypeConfig> + Send + Sync + 'static,
129 SM: RaftStateMachine<ChirpsTypeConfig> + Send + Sync + 'static,
130 {
131 let cfg = build_openraft_config(&config)
132 .map_err(|e| RaftError::Internal(anyhow!("config error: {e}")))?;
133 let raft = Raft::new(config.node_id, cfg, network, log_store, state_machine)
134 .await
135 .map_err(RaftError::from)?;
136
137 let collector = Arc::new(Mutex::new(None));
138 let observer_handle =
139 spawn_metrics_observer(config.group_id, raft.metrics(), Arc::clone(&collector));
140 info!(
141 target: "raft",
142 event = "raft_initialized",
143 group_id = %config.group_id.0,
144 node_id = %config.node_id,
145 term = %raft.metrics().borrow().current_term,
146 "Raft node initialized"
147 );
148
149 Ok(Self {
150 config,
151 raft,
152 transport,
153 metrics_collector: collector,
154 observer_handle,
155 })
156 }
157
158 pub async fn start(&mut self) -> RaftResult<()> {
160 Ok(())
161 }
162
163 pub async fn initialize(&self, members: BTreeSet<ChirpsNodeId>) -> RaftResult<()> {
165 self.raft.initialize(members).await.map_err(RaftError::from)
166 }
167
168 pub fn metrics(&self) -> OpenRaftMetrics<ChirpsNodeId, BasicNode> {
170 self.raft.metrics().borrow().clone()
171 }
172
173 pub fn last_applied_log(&self) -> Option<LogId<ChirpsNodeId>> {
175 self.raft.metrics().borrow().last_applied
176 }
177
178 pub fn set_metrics_collector(&self, collector: Arc<RaftMetricsCollector>) {
180 if let Ok(mut slot) = self.metrics_collector.lock() {
181 *slot = Some(collector);
182 }
183 }
184
185 pub async fn propose(&self, command: Vec<u8>) -> RaftResult<Vec<u8>> {
187 match self.raft.client_write(command).await {
188 Ok(ClientWriteResponse { data, .. }) => {
189 self.push_metrics_update(RaftMetricsUpdate {
190 proposals_total: 1,
191 ..Default::default()
192 });
193 Ok(data)
194 }
195 Err(OpenRaftError::APIError(ClientWriteError::ForwardToLeader(fwd))) => {
196 Err(RaftError::NotLeader(fwd.leader_id))
197 }
198 Err(other) => {
199 let reason = other.to_string();
200 tracing::warn!(
201 target: "raft",
202 event = "raft_propose_failed",
203 group_id = %self.config.group_id.0,
204 node_id = %self.config.node_id,
205 term = %self.raft.metrics().borrow().current_term,
206 reason = %reason,
207 "Proposal failed"
208 );
209 self.push_metrics_update(RaftMetricsUpdate {
210 proposals_failed_total: 1,
211 proposals_failed_reason: Some(reason.clone()),
212 ..Default::default()
213 });
214 Err(RaftError::Internal(anyhow!(reason)))
215 }
216 }
217 }
218
219 pub fn leader_id(&self) -> Option<ChirpsNodeId> {
221 self.raft.metrics().borrow().current_leader
222 }
223
224 pub fn is_leader(&self) -> bool {
226 self.leader_id() == Some(self.config.node_id)
227 }
228
229 pub async fn change_membership(&self, members: BTreeSet<ChirpsNodeId>) -> RaftResult<()> {
231 self.raft
232 .change_membership(members, false)
233 .await
234 .map(|_| ())
235 .map_err(RaftError::from)
236 }
237
238 pub async fn add_learner(&self, node_id: ChirpsNodeId, node: BasicNode) -> RaftResult<()> {
240 self.raft
241 .add_learner(node_id, node, true)
242 .await
243 .map(|_| ())
244 .map_err(RaftError::from)
245 }
246
247 pub async fn handle_message(&self, payload: RaftFramePayload) -> RaftResult<RaftMessage> {
249 if payload.message.group_id() != self.config.group_id {
250 return Err(RaftError::InvalidMessage(format!(
251 "group mismatch: expected {}, got {:?}",
252 self.config.group_id.0,
253 payload.message.group_id()
254 )));
255 }
256 match payload.message {
257 RaftMessage::AppendEntries { request, .. } => {
258 let resp = self
259 .raft
260 .append_entries(request)
261 .await
262 .map_err(RaftError::from)?;
263 Ok(RaftMessage::AppendEntriesResponse {
264 group_id: self.config.group_id,
265 response: resp,
266 })
267 }
268 RaftMessage::Vote { request, .. } => {
269 let resp = self.raft.vote(request).await.map_err(RaftError::from)?;
270 Ok(RaftMessage::VoteResponse {
271 group_id: self.config.group_id,
272 response: resp,
273 })
274 }
275 RaftMessage::InstallSnapshot { request, .. } => {
276 let resp = self
277 .raft
278 .install_snapshot(request)
279 .await
280 .map_err(RaftError::from)?;
281 Ok(RaftMessage::InstallSnapshotResponse {
282 group_id: self.config.group_id,
283 response: resp,
284 })
285 }
286 RaftMessage::AppendEntriesResponse { response, .. } => {
287 Ok(RaftMessage::AppendEntriesResponse {
288 group_id: self.config.group_id,
289 response,
290 })
291 }
292 RaftMessage::VoteResponse { response, .. } => Ok(RaftMessage::VoteResponse {
293 group_id: self.config.group_id,
294 response,
295 }),
296 RaftMessage::InstallSnapshotResponse { response, .. } => {
297 Ok(RaftMessage::InstallSnapshotResponse {
298 group_id: self.config.group_id,
299 response,
300 })
301 }
302 }
303 }
304
305 pub async fn tick(&self) -> RaftResult<()> {
307 self.raft
308 .trigger()
309 .heartbeat()
310 .await
311 .map_err(RaftError::from)
312 }
313
314 pub async fn trigger_snapshot(&self) -> RaftResult<()> {
316 self.raft
317 .trigger()
318 .snapshot()
319 .await
320 .map_err(RaftError::from)?;
321
322 let last_log = self.raft.metrics().borrow().last_log_index;
323 tracing::info!(
324 target: "raft",
325 event = "raft_snapshot_created",
326 group_id = %self.config.group_id.0,
327 node_id = %self.config.node_id,
328 log_id = ?last_log,
329 "Snapshot triggered"
330 );
331 self.push_metrics_update(RaftMetricsUpdate {
332 snapshot_total: 1,
333 ..Default::default()
334 });
335 Ok(())
336 }
337}
338
339fn build_openraft_config(src: &RaftConfig) -> Result<Arc<Config>, Box<ConfigError>> {
340 let cfg = Config {
341 cluster_name: format!("chirps-raft-{}", src.group_id.0),
342 election_timeout_min: src.election_timeout_ms,
343 election_timeout_max: src.election_timeout_ms * 2,
344 heartbeat_interval: src.heartbeat_interval_ms,
345 max_payload_entries: src.max_batch_size as u64,
346 snapshot_policy: SnapshotPolicy::LogsSinceLast(src.snapshot_threshold),
347 max_in_snapshot_log_to_keep: src.max_in_snapshot_log_to_keep,
348 ..Default::default()
349 };
350 Ok(Arc::new(cfg.validate().map_err(Box::new)?))
351}
352
353fn spawn_metrics_observer(
354 group_id: GroupId,
355 mut rx: Receiver<OpenRaftMetrics<ChirpsNodeId, BasicNode>>,
356 collector: Arc<Mutex<Option<Arc<RaftMetricsCollector>>>>,
357) -> JoinHandle<()> {
358 tokio::spawn(async move {
359 let mut obs_state = ObservationState::default();
360
361 loop {
362 {
363 let metrics = rx.borrow().clone();
364 if let Ok(slot) = collector.lock()
365 && let Some(col) = slot.as_ref()
366 {
367 let update = RaftMetricsUpdate::from((group_id, metrics.clone()));
368 col.update(&update);
369 }
370 obs_state.handle(group_id, &metrics);
371 }
372
373 if rx.changed().await.is_err() {
374 break;
375 }
376 }
377 })
378}
379
/// Last-seen metric values, used by [`ObservationState::handle`] to log only transitions.
///
/// All fields start empty. The first observation therefore always reports a state
/// change, plus a leader, snapshot or purge event when those are present.
#[derive(Default)]
struct ObservationState {
    // Server role from the previous observation; `None` before the first one.
    last_state: Option<ServerState>,
    // Leader seen in the previous observation.
    last_leader: Option<ChirpsNodeId>,
    // `MessageSummary` rendering of the last membership config, compared as a string
    // to detect membership changes.
    last_membership: String,
    // Snapshot log id from the previous observation.
    last_snapshot: Option<LogId<ChirpsNodeId>>,
    // Purged (compacted) log id from the previous observation.
    last_purged: Option<LogId<ChirpsNodeId>>,
}
388
389impl ObservationState {
390 fn handle(&mut self, group_id: GroupId, metrics: &OpenRaftMetrics<ChirpsNodeId, BasicNode>) {
391 if self.last_state != Some(metrics.state) {
392 tracing::info!(
393 target: "raft",
394 event = "raft_state_changed",
395 group_id = %group_id.0,
396 node_id = %metrics.id,
397 term = %metrics.current_term,
398 old_state = ?self.last_state,
399 new_state = ?metrics.state,
400 "Raft state changed"
401 );
402 self.last_state = Some(metrics.state);
403 }
404
405 if metrics.current_leader != self.last_leader {
406 if let Some(leader_id) = metrics.current_leader {
407 tracing::info!(
408 target: "raft",
409 event = "raft_leader_elected",
410 group_id = %group_id.0,
411 node_id = %metrics.id,
412 term = %metrics.current_term,
413 leader_id = %leader_id,
414 "Leader elected"
415 );
416 }
417 self.last_leader = metrics.current_leader;
418 }
419
420 let membership_summary = metrics.membership_config.summary();
421 if membership_summary != self.last_membership {
422 let membership = metrics.membership_config.membership();
423 let voter_ids = membership
424 .get_joint_config()
425 .iter()
426 .flatten()
427 .cloned()
428 .collect::<BTreeSet<_>>();
429 let learners = membership
430 .nodes()
431 .filter(|(id, _)| !voter_ids.contains(id))
432 .map(|(id, _)| *id)
433 .collect::<Vec<_>>();
434 tracing::info!(
435 target: "raft",
436 event = "raft_membership_changed",
437 group_id = %group_id.0,
438 node_id = %metrics.id,
439 term = %metrics.current_term,
440 voters = ?membership.get_joint_config(),
441 learners = ?learners,
442 "Membership changed"
443 );
444 self.last_membership = membership_summary;
445 }
446
447 if metrics.snapshot != self.last_snapshot {
448 if let Some(log_id) = metrics.snapshot {
449 tracing::info!(
450 target: "raft",
451 event = "raft_snapshot_installed",
452 group_id = %group_id.0,
453 node_id = %metrics.id,
454 term = %metrics.current_term,
455 log_id = ?log_id,
456 "Snapshot installed"
457 );
458 }
459 self.last_snapshot = metrics.snapshot;
460 }
461
462 if metrics.purged != self.last_purged {
463 if let Some(log_id) = metrics.purged {
464 tracing::info!(
465 target: "raft",
466 event = "raft_log_compacted",
467 group_id = %group_id.0,
468 node_id = %metrics.id,
469 term = %metrics.current_term,
470 up_to_log_id = ?log_id,
471 "Log compacted"
472 );
473 }
474 self.last_purged = metrics.purged;
475 }
476 }
477}
478
479impl RaftNode {
480 fn push_metrics_update(&self, update: RaftMetricsUpdate) {
481 if let Ok(slot) = self.metrics_collector.lock()
482 && let Some(col) = slot.as_ref()
483 {
484 let mut base = RaftMetricsUpdate::from((
485 self.config.group_id,
486 self.raft.metrics().borrow().clone(),
487 ));
488 base.snapshot_total = update.snapshot_total;
489 base.proposals_total = update.proposals_total;
490 base.proposals_failed_total = update.proposals_failed_total;
491 base.proposals_failed_reason = update.proposals_failed_reason;
492 col.update(&base);
493 }
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use alopex_chirps_wire::frame::{Frame, RaftFrame};
    use openraft::CommittedLeaderId;
    use openraft::ServerState;
    use openraft::metrics::RaftMetrics as OpenRaftMetrics;
    use serde_json::Value;
    use std::io;
    use tracing_subscriber::FmtSubscriber;
    use tracing_subscriber::fmt::writer::MakeWriter;

    /// `RaftConfig::default()` keeps the tuning values the design specifies.
    #[test]
    fn config_defaults_match_design() {
        let cfg = RaftConfig::default();
        assert_eq!(cfg.election_timeout_ms, 150);
        assert_eq!(cfg.heartbeat_interval_ms, 50);
        assert_eq!(cfg.max_batch_size, 1_000);
        assert_eq!(cfg.snapshot_threshold, 10_000);
        assert_eq!(cfg.max_in_snapshot_log_to_keep, 1_000);
    }

    /// A `Vote` message reports the group id it was built with.
    #[test]
    fn raft_message_reports_group() {
        let msg = RaftMessage::Vote {
            group_id: GroupId(42),
            request: VoteRequest {
                vote: alopex_chirps_raft_storage::types::Vote::new(0, 0),
                last_log_id: None,
            },
        };
        assert_eq!(msg.group_id(), GroupId(42));
    }

    /// A bincode-encoded payload wrapped in a Raft frame decodes back to the same
    /// correlation id and group id.
    #[test]
    fn decode_frame_roundtrip() {
        let payload = RaftFramePayload {
            correlation_id: 7,
            message: RaftMessage::AppendEntries {
                group_id: GroupId(1),
                request: AppendEntriesRequest {
                    vote: alopex_chirps_raft_storage::types::Vote::new(0, 0),
                    prev_log_id: None,
                    entries: Vec::new(),
                    leader_commit: None,
                },
            },
        };
        // `bincode` resolves through the extern prelude; no `use bincode;` is needed.
        let bytes = bincode::serialize(&payload).expect("serialize");
        let frame = Frame::Raft(RaftFrame {
            group_id: 1,
            payload: bytes,
        });
        let decoded = ChirpsRaftTransport::decode_frame(frame).expect("decode");
        assert_eq!(decoded.correlation_id, 7);
        assert_eq!(decoded.message.group_id(), GroupId(1));
    }

    /// `ObservationState::handle` emits JSON-structured events on the `raft` target
    /// for state, leader, snapshot and compaction transitions.
    #[test]
    fn observation_state_emits_structured_logs() {
        /// `MakeWriter` whose writers all append to one shared in-memory buffer.
        #[derive(Clone)]
        struct MemoryMakeWriter(Arc<Mutex<Vec<u8>>>);
        /// Writer handed out by [`MemoryMakeWriter`].
        struct MemoryWriter(Arc<Mutex<Vec<u8>>>);

        impl<'a> MakeWriter<'a> for MemoryMakeWriter {
            type Writer = MemoryWriter;

            fn make_writer(&'a self) -> Self::Writer {
                MemoryWriter(Arc::clone(&self.0))
            }
        }

        impl io::Write for MemoryWriter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                let mut lock = self.0.lock().unwrap();
                lock.extend_from_slice(buf);
                Ok(buf.len())
            }

            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        let buffer = Arc::new(Mutex::new(Vec::new()));
        let subscriber = FmtSubscriber::builder()
            .json()
            .with_writer(MemoryMakeWriter(Arc::clone(&buffer)))
            .finish();

        tracing::subscriber::with_default(subscriber, || {
            let mut obs = ObservationState::default();
            let mut metrics = OpenRaftMetrics::new_initial(1);
            metrics.state = ServerState::Leader;
            metrics.current_term = 3;
            metrics.current_leader = Some(1);
            metrics.snapshot = Some(LogId::new(CommittedLeaderId::new(3, 1), 2));
            metrics.purged = Some(LogId::new(CommittedLeaderId::new(2, 1), 1));

            obs.handle(GroupId(9), &metrics);
        });

        let logs = String::from_utf8(buffer.lock().unwrap().clone()).expect("utf8");
        let mut events = Vec::new();
        for line in logs.lines() {
            let v: Value = serde_json::from_str(line).expect("json");
            if let Some(ev) = v
                .get("fields")
                .and_then(|fields| fields.get("event"))
                .and_then(|e| e.as_str())
            {
                events.push(ev.to_string());
            }
            if let Some(target) = v.get("target").and_then(|t| t.as_str()) {
                assert_eq!(target, "raft", "log target should be raft");
            }
        }

        // Compare against borrowed names instead of allocating a `String` per assert.
        let saw_event = |name: &str| events.iter().any(|event| event == name);
        assert!(
            saw_event("raft_state_changed"),
            "state change event expected"
        );
        assert!(
            saw_event("raft_leader_elected"),
            "leader election event expected"
        );
        assert!(
            saw_event("raft_snapshot_installed"),
            "snapshot installed event expected"
        );
        assert!(
            saw_event("raft_log_compacted"),
            "log compacted event expected"
        );
    }
}