alopex_chirps/raft/
node.rs

1use crate::raft::metrics::{RaftMetricsCollector, RaftMetricsUpdate};
2use crate::raft::transport::{ChirpsRaftTransport, RaftFramePayload};
3use crate::raft::{
4    AppendEntriesRequest, AppendEntriesResponse, BasicNode, ChirpsNodeId, ChirpsTypeConfig,
5    GroupId, InstallSnapshotRequest, InstallSnapshotResponse, RaftConfig, RaftError, RaftResult,
6    VoteRequest, VoteResponse,
7};
8use anyhow::anyhow;
9use openraft::Raft;
10use openraft::error::{ClientWriteError, RaftError as OpenRaftError};
11use openraft::metrics::RaftMetrics as OpenRaftMetrics;
12use openraft::network::RaftNetworkFactory;
13use openraft::raft::ClientWriteResponse;
14use openraft::storage::{RaftLogStorage, RaftStateMachine};
15use openraft::{Config, ConfigError, LogId, MessageSummary, ServerState, SnapshotPolicy};
16use serde::{Deserialize, Serialize};
17use std::collections::BTreeSet;
18use std::sync::{Arc, Mutex};
19use tokio::sync::watch::Receiver;
20use tokio::task::JoinHandle;
21use tracing::info;
22
/// Raft RPC messages exchanged inside Chirps. The enum carries both requests and
/// their responses, and every variant is tagged with the [`GroupId`] of the Raft
/// group it belongs to.
///
/// NOTE(review): the tests in this file serialize a payload containing this enum
/// with bincode; if that is also the wire format, the variant order is presumably
/// part of the encoding — confirm before reordering or inserting variants.
///
/// # Examples
///
/// ```rust,ignore
/// use alopex_chirps::raft::{RaftMessage, GroupId};
/// use alopex_chirps_raft_storage::types::{VoteRequest, Vote};
///
/// let msg = RaftMessage::Vote {
///     group_id: GroupId(1),
///     request: VoteRequest {
///         vote: Vote::new(1, 1),
///         last_log_id: None,
///     },
/// };
/// assert_eq!(msg.group_id(), GroupId(1));
/// ```
#[derive(Debug, Serialize, Deserialize)]
pub enum RaftMessage {
    /// AppendEntries request.
    AppendEntries {
        group_id: GroupId,
        request: AppendEntriesRequest<ChirpsTypeConfig>,
    },
    /// Response to an AppendEntries request.
    AppendEntriesResponse {
        group_id: GroupId,
        response: AppendEntriesResponse<ChirpsNodeId>,
    },
    /// Vote request.
    Vote {
        group_id: GroupId,
        request: VoteRequest<ChirpsNodeId>,
    },
    /// Response to a Vote request.
    VoteResponse {
        group_id: GroupId,
        response: VoteResponse<ChirpsNodeId>,
    },
    /// InstallSnapshot request.
    InstallSnapshot {
        group_id: GroupId,
        request: InstallSnapshotRequest<ChirpsTypeConfig>,
    },
    /// Response to an InstallSnapshot request.
    InstallSnapshotResponse {
        group_id: GroupId,
        response: InstallSnapshotResponse<ChirpsNodeId>,
    },
}
67
68impl RaftMessage {
69    pub fn group_id(&self) -> GroupId {
70        match self {
71            RaftMessage::AppendEntries { group_id, .. }
72            | RaftMessage::AppendEntriesResponse { group_id, .. }
73            | RaftMessage::Vote { group_id, .. }
74            | RaftMessage::VoteResponse { group_id, .. }
75            | RaftMessage::InstallSnapshot { group_id, .. }
76            | RaftMessage::InstallSnapshotResponse { group_id, .. } => *group_id,
77        }
78    }
79}
80
/// Wraps an openraft `Raft` handle and exposes Chirps-specific error and
/// configuration types.
///
/// # Examples
///
/// ```rust,ignore
/// use alopex_chirps::raft::{RaftConfig, RaftNode};
/// use alopex_chirps::raft::transport::ChirpsRaftTransport;
/// use alopex_chirps_raft_storage::types::GroupId;
/// use std::sync::Arc;
///
/// # async fn build() -> anyhow::Result<()> {
/// let transport = Arc::new(ChirpsRaftTransport::new(mock_backend(), GroupId(1), 1));
/// let network = ChirpsRaftTransport::factory(transport.clone());
/// let log_store = build_log_store();      // use a type that implements RaftLogStorage
/// let state_machine = build_state_machine(); // use a type that implements RaftStateMachine
/// let mut node = RaftNode::new(
///     RaftConfig { group_id: GroupId(1), node_id: 1, ..Default::default() },
///     network,
///     log_store,
///     state_machine,
///     transport,
/// ).await?;
/// node.start().await?;
/// # Ok(()) }
/// ```
pub struct RaftNode {
    /// Chirps-level configuration the node was created with (group id, node id, ...).
    pub(crate) config: RaftConfig,
    /// Underlying openraft handle.
    pub(crate) raft: Raft<ChirpsTypeConfig>,
    // Not read anywhere in this file, hence the lint allowance.
    // NOTE(review): presumably retained so the transport lives as long as the node — confirm.
    #[allow(dead_code)]
    pub(crate) transport: Arc<ChirpsRaftTransport>,
    /// Optional metrics collector slot, shared with the background observer task.
    /// `None` until `set_metrics_collector` stores a collector.
    metrics_collector: Arc<Mutex<Option<Arc<RaftMetricsCollector>>>>,
    // Handle of the task returned by `spawn_metrics_observer`; it is stored but never
    // awaited or aborted in this file, hence the lint allowance.
    #[allow(dead_code)]
    observer_handle: JoinHandle<()>,
}
115
116impl RaftNode {
117    /// Raftノードを初期化する。openraft::Raftの生成に必要なConfigを組み立てる。
118    pub async fn new<NF, LS, SM>(
119        config: RaftConfig,
120        network: NF,
121        log_store: LS,
122        state_machine: SM,
123        transport: Arc<ChirpsRaftTransport>,
124    ) -> RaftResult<Self>
125    where
126        NF: RaftNetworkFactory<ChirpsTypeConfig> + Clone + Send + Sync + 'static,
127        NF::Network: Send + Sync,
128        LS: RaftLogStorage<ChirpsTypeConfig> + Send + Sync + 'static,
129        SM: RaftStateMachine<ChirpsTypeConfig> + Send + Sync + 'static,
130    {
131        let cfg = build_openraft_config(&config)
132            .map_err(|e| RaftError::Internal(anyhow!("config error: {e}")))?;
133        let raft = Raft::new(config.node_id, cfg, network, log_store, state_machine)
134            .await
135            .map_err(RaftError::from)?;
136
137        let collector = Arc::new(Mutex::new(None));
138        let observer_handle =
139            spawn_metrics_observer(config.group_id, raft.metrics(), Arc::clone(&collector));
140        info!(
141            target: "raft",
142            event = "raft_initialized",
143            group_id = %config.group_id.0,
144            node_id = %config.node_id,
145            term = %raft.metrics().borrow().current_term,
146            "Raft node initialized"
147        );
148
149        Ok(Self {
150            config,
151            raft,
152            transport,
153            metrics_collector: collector,
154            observer_handle,
155        })
156    }
157
158    /// Raft起動を行う。openraftでは生成時に起動するため、ここではNOP。
159    pub async fn start(&mut self) -> RaftResult<()> {
160        Ok(())
161    }
162
163    /// クラスターを初期化する。初回のみ呼び出すこと。
164    pub async fn initialize(&self, members: BTreeSet<ChirpsNodeId>) -> RaftResult<()> {
165        self.raft.initialize(members).await.map_err(RaftError::from)
166    }
167
168    /// 最新のメトリクススナップショットを取得する。
169    pub fn metrics(&self) -> OpenRaftMetrics<ChirpsNodeId, BasicNode> {
170        self.raft.metrics().borrow().clone()
171    }
172
173    /// 最終適用ログIDを返す。
174    pub fn last_applied_log(&self) -> Option<LogId<ChirpsNodeId>> {
175        self.raft.metrics().borrow().last_applied
176    }
177
178    /// メトリクスコレクタを登録する。登録後は状態変化に応じて自動更新される。
179    pub fn set_metrics_collector(&self, collector: Arc<RaftMetricsCollector>) {
180        if let Ok(mut slot) = self.metrics_collector.lock() {
181            *slot = Some(collector);
182        }
183    }
184
185    /// クライアントコマンドを提案する。NotLeaderの場合はリーダーIDを返す。
186    pub async fn propose(&self, command: Vec<u8>) -> RaftResult<Vec<u8>> {
187        match self.raft.client_write(command).await {
188            Ok(ClientWriteResponse { data, .. }) => {
189                self.push_metrics_update(RaftMetricsUpdate {
190                    proposals_total: 1,
191                    ..Default::default()
192                });
193                Ok(data)
194            }
195            Err(OpenRaftError::APIError(ClientWriteError::ForwardToLeader(fwd))) => {
196                Err(RaftError::NotLeader(fwd.leader_id))
197            }
198            Err(other) => {
199                let reason = other.to_string();
200                tracing::warn!(
201                    target: "raft",
202                    event = "raft_propose_failed",
203                    group_id = %self.config.group_id.0,
204                    node_id = %self.config.node_id,
205                    term = %self.raft.metrics().borrow().current_term,
206                    reason = %reason,
207                    "Proposal failed"
208                );
209                self.push_metrics_update(RaftMetricsUpdate {
210                    proposals_failed_total: 1,
211                    proposals_failed_reason: Some(reason.clone()),
212                    ..Default::default()
213                });
214                Err(RaftError::Internal(anyhow!(reason)))
215            }
216        }
217    }
218
219    /// 現在のリーダーIDを返す。
220    pub fn leader_id(&self) -> Option<ChirpsNodeId> {
221        self.raft.metrics().borrow().current_leader
222    }
223
224    /// 自ノードがリーダーか判定する。
225    pub fn is_leader(&self) -> bool {
226        self.leader_id() == Some(self.config.node_id)
227    }
228
229    /// メンバーシップ変更(Joint Consensus対応)。
230    pub async fn change_membership(&self, members: BTreeSet<ChirpsNodeId>) -> RaftResult<()> {
231        self.raft
232            .change_membership(members, false)
233            .await
234            .map(|_| ())
235            .map_err(RaftError::from)
236    }
237
238    /// Learner追加。
239    pub async fn add_learner(&self, node_id: ChirpsNodeId, node: BasicNode) -> RaftResult<()> {
240        self.raft
241            .add_learner(node_id, node, true)
242            .await
243            .map(|_| ())
244            .map_err(RaftError::from)
245    }
246
247    /// 受信メッセージをopenraftへ橋渡しし、レスポンスを返す。
248    pub async fn handle_message(&self, payload: RaftFramePayload) -> RaftResult<RaftMessage> {
249        if payload.message.group_id() != self.config.group_id {
250            return Err(RaftError::InvalidMessage(format!(
251                "group mismatch: expected {}, got {:?}",
252                self.config.group_id.0,
253                payload.message.group_id()
254            )));
255        }
256        match payload.message {
257            RaftMessage::AppendEntries { request, .. } => {
258                let resp = self
259                    .raft
260                    .append_entries(request)
261                    .await
262                    .map_err(RaftError::from)?;
263                Ok(RaftMessage::AppendEntriesResponse {
264                    group_id: self.config.group_id,
265                    response: resp,
266                })
267            }
268            RaftMessage::Vote { request, .. } => {
269                let resp = self.raft.vote(request).await.map_err(RaftError::from)?;
270                Ok(RaftMessage::VoteResponse {
271                    group_id: self.config.group_id,
272                    response: resp,
273                })
274            }
275            RaftMessage::InstallSnapshot { request, .. } => {
276                let resp = self
277                    .raft
278                    .install_snapshot(request)
279                    .await
280                    .map_err(RaftError::from)?;
281                Ok(RaftMessage::InstallSnapshotResponse {
282                    group_id: self.config.group_id,
283                    response: resp,
284                })
285            }
286            RaftMessage::AppendEntriesResponse { response, .. } => {
287                Ok(RaftMessage::AppendEntriesResponse {
288                    group_id: self.config.group_id,
289                    response,
290                })
291            }
292            RaftMessage::VoteResponse { response, .. } => Ok(RaftMessage::VoteResponse {
293                group_id: self.config.group_id,
294                response,
295            }),
296            RaftMessage::InstallSnapshotResponse { response, .. } => {
297                Ok(RaftMessage::InstallSnapshotResponse {
298                    group_id: self.config.group_id,
299                    response,
300                })
301            }
302        }
303    }
304
305    /// openraftトリガーをそのまま公開。現在はハートビートのみ。
306    pub async fn tick(&self) -> RaftResult<()> {
307        self.raft
308            .trigger()
309            .heartbeat()
310            .await
311            .map_err(RaftError::from)
312    }
313
314    /// スナップショット生成を手動でトリガーする。
315    pub async fn trigger_snapshot(&self) -> RaftResult<()> {
316        self.raft
317            .trigger()
318            .snapshot()
319            .await
320            .map_err(RaftError::from)?;
321
322        let last_log = self.raft.metrics().borrow().last_log_index;
323        tracing::info!(
324            target: "raft",
325            event = "raft_snapshot_created",
326            group_id = %self.config.group_id.0,
327            node_id = %self.config.node_id,
328            log_id = ?last_log,
329            "Snapshot triggered"
330        );
331        self.push_metrics_update(RaftMetricsUpdate {
332            snapshot_total: 1,
333            ..Default::default()
334        });
335        Ok(())
336    }
337}
338
339fn build_openraft_config(src: &RaftConfig) -> Result<Arc<Config>, Box<ConfigError>> {
340    let cfg = Config {
341        cluster_name: format!("chirps-raft-{}", src.group_id.0),
342        election_timeout_min: src.election_timeout_ms,
343        election_timeout_max: src.election_timeout_ms * 2,
344        heartbeat_interval: src.heartbeat_interval_ms,
345        max_payload_entries: src.max_batch_size as u64,
346        snapshot_policy: SnapshotPolicy::LogsSinceLast(src.snapshot_threshold),
347        max_in_snapshot_log_to_keep: src.max_in_snapshot_log_to_keep,
348        ..Default::default()
349    };
350    Ok(Arc::new(cfg.validate().map_err(Box::new)?))
351}
352
353fn spawn_metrics_observer(
354    group_id: GroupId,
355    mut rx: Receiver<OpenRaftMetrics<ChirpsNodeId, BasicNode>>,
356    collector: Arc<Mutex<Option<Arc<RaftMetricsCollector>>>>,
357) -> JoinHandle<()> {
358    tokio::spawn(async move {
359        let mut obs_state = ObservationState::default();
360
361        loop {
362            {
363                let metrics = rx.borrow().clone();
364                if let Ok(slot) = collector.lock()
365                    && let Some(col) = slot.as_ref()
366                {
367                    let update = RaftMetricsUpdate::from((group_id, metrics.clone()));
368                    col.update(&update);
369                }
370                obs_state.handle(group_id, &metrics);
371            }
372
373            if rx.changed().await.is_err() {
374                break;
375            }
376        }
377    })
378}
379
/// Remembers the most recently observed values of selected metrics fields so that
/// `handle` can log only when one of them changes.
#[derive(Default)]
struct ObservationState {
    /// Last observed server state; `None` until the first observation.
    last_state: Option<ServerState>,
    /// Last observed leader id.
    last_leader: Option<ChirpsNodeId>,
    /// Summary string of the last observed membership config; empty by default.
    last_membership: String,
    /// Last observed snapshot log id.
    last_snapshot: Option<LogId<ChirpsNodeId>>,
    /// Last observed purged log id.
    last_purged: Option<LogId<ChirpsNodeId>>,
}
388
389impl ObservationState {
390    fn handle(&mut self, group_id: GroupId, metrics: &OpenRaftMetrics<ChirpsNodeId, BasicNode>) {
391        if self.last_state != Some(metrics.state) {
392            tracing::info!(
393                target: "raft",
394                event = "raft_state_changed",
395                group_id = %group_id.0,
396                node_id = %metrics.id,
397                term = %metrics.current_term,
398                old_state = ?self.last_state,
399                new_state = ?metrics.state,
400                "Raft state changed"
401            );
402            self.last_state = Some(metrics.state);
403        }
404
405        if metrics.current_leader != self.last_leader {
406            if let Some(leader_id) = metrics.current_leader {
407                tracing::info!(
408                    target: "raft",
409                    event = "raft_leader_elected",
410                    group_id = %group_id.0,
411                    node_id = %metrics.id,
412                    term = %metrics.current_term,
413                    leader_id = %leader_id,
414                    "Leader elected"
415                );
416            }
417            self.last_leader = metrics.current_leader;
418        }
419
420        let membership_summary = metrics.membership_config.summary();
421        if membership_summary != self.last_membership {
422            let membership = metrics.membership_config.membership();
423            let voter_ids = membership
424                .get_joint_config()
425                .iter()
426                .flatten()
427                .cloned()
428                .collect::<BTreeSet<_>>();
429            let learners = membership
430                .nodes()
431                .filter(|(id, _)| !voter_ids.contains(id))
432                .map(|(id, _)| *id)
433                .collect::<Vec<_>>();
434            tracing::info!(
435                target: "raft",
436                event = "raft_membership_changed",
437                group_id = %group_id.0,
438                node_id = %metrics.id,
439                term = %metrics.current_term,
440                voters = ?membership.get_joint_config(),
441                learners = ?learners,
442                "Membership changed"
443            );
444            self.last_membership = membership_summary;
445        }
446
447        if metrics.snapshot != self.last_snapshot {
448            if let Some(log_id) = metrics.snapshot {
449                tracing::info!(
450                    target: "raft",
451                    event = "raft_snapshot_installed",
452                    group_id = %group_id.0,
453                    node_id = %metrics.id,
454                    term = %metrics.current_term,
455                    log_id = ?log_id,
456                    "Snapshot installed"
457                );
458            }
459            self.last_snapshot = metrics.snapshot;
460        }
461
462        if metrics.purged != self.last_purged {
463            if let Some(log_id) = metrics.purged {
464                tracing::info!(
465                    target: "raft",
466                    event = "raft_log_compacted",
467                    group_id = %group_id.0,
468                    node_id = %metrics.id,
469                    term = %metrics.current_term,
470                    up_to_log_id = ?log_id,
471                    "Log compacted"
472                );
473            }
474            self.last_purged = metrics.purged;
475        }
476    }
477}
478
479impl RaftNode {
480    fn push_metrics_update(&self, update: RaftMetricsUpdate) {
481        if let Ok(slot) = self.metrics_collector.lock()
482            && let Some(col) = slot.as_ref()
483        {
484            let mut base = RaftMetricsUpdate::from((
485                self.config.group_id,
486                self.raft.metrics().borrow().clone(),
487            ));
488            base.snapshot_total = update.snapshot_total;
489            base.proposals_total = update.proposals_total;
490            base.proposals_failed_total = update.proposals_failed_total;
491            base.proposals_failed_reason = update.proposals_failed_reason;
492            col.update(&base);
493        }
494    }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use alopex_chirps_wire::frame::{Frame, RaftFrame};
    use bincode;
    use openraft::CommittedLeaderId;
    use openraft::ServerState;
    use openraft::metrics::RaftMetrics as OpenRaftMetrics;
    use serde_json::Value;
    use std::io;
    use tracing_subscriber::FmtSubscriber;
    use tracing_subscriber::fmt::writer::MakeWriter;

    /// The default configuration must carry the values fixed by the design.
    #[test]
    fn config_defaults_match_design() {
        let defaults = RaftConfig::default();
        assert_eq!(defaults.election_timeout_ms, 150);
        assert_eq!(defaults.heartbeat_interval_ms, 50);
        assert_eq!(defaults.max_batch_size, 1_000);
        assert_eq!(defaults.snapshot_threshold, 10_000);
        assert_eq!(defaults.max_in_snapshot_log_to_keep, 1_000);
    }

    /// `group_id()` returns the id stored in the variant.
    #[test]
    fn raft_message_reports_group() {
        let request = VoteRequest {
            vote: alopex_chirps_raft_storage::types::Vote::new(0, 0),
            last_log_id: None,
        };
        let msg = RaftMessage::Vote {
            group_id: GroupId(42),
            request,
        };
        assert_eq!(msg.group_id(), GroupId(42));
    }

    /// A bincode-encoded payload wrapped in a Raft frame decodes back to the same
    /// correlation id and group id.
    #[test]
    fn decode_frame_roundtrip() {
        let request = AppendEntriesRequest {
            vote: alopex_chirps_raft_storage::types::Vote::new(0, 0),
            prev_log_id: None,
            entries: Vec::new(),
            leader_commit: None,
        };
        let payload = RaftFramePayload {
            correlation_id: 7,
            message: RaftMessage::AppendEntries {
                group_id: GroupId(1),
                request,
            },
        };
        let encoded = bincode::serialize(&payload).expect("serialize");
        let frame = Frame::Raft(RaftFrame {
            group_id: 1,
            payload: encoded,
        });

        let decoded = ChirpsRaftTransport::decode_frame(frame).expect("decode");
        assert_eq!(decoded.correlation_id, 7);
        assert_eq!(decoded.message.group_id(), GroupId(1));
    }

    /// A first observation of leader metrics emits the state, leader, snapshot and
    /// compaction events as JSON logs, all with target `raft`.
    #[test]
    fn observation_state_emits_structured_logs() {
        // In-memory sink so the JSON output of the subscriber can be inspected.
        #[derive(Clone)]
        struct MemoryMakeWriter(Arc<Mutex<Vec<u8>>>);
        struct MemoryWriter(Arc<Mutex<Vec<u8>>>);

        impl<'a> MakeWriter<'a> for MemoryMakeWriter {
            type Writer = MemoryWriter;

            fn make_writer(&'a self) -> Self::Writer {
                MemoryWriter(Arc::clone(&self.0))
            }
        }

        impl io::Write for MemoryWriter {
            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
                self.0.lock().unwrap().extend_from_slice(buf);
                Ok(buf.len())
            }

            fn flush(&mut self) -> io::Result<()> {
                Ok(())
            }
        }

        let buffer = Arc::new(Mutex::new(Vec::new()));
        let sink = MemoryMakeWriter(Arc::clone(&buffer));
        let subscriber = FmtSubscriber::builder().json().with_writer(sink).finish();

        tracing::subscriber::with_default(subscriber, || {
            let mut metrics = OpenRaftMetrics::new_initial(1);
            metrics.state = ServerState::Leader;
            metrics.current_term = 3;
            metrics.current_leader = Some(1);
            metrics.snapshot = Some(LogId::new(CommittedLeaderId::new(3, 1), 2));
            metrics.purged = Some(LogId::new(CommittedLeaderId::new(2, 1), 1));

            let mut obs = ObservationState::default();
            obs.handle(GroupId(9), &metrics);
        });

        let captured = buffer.lock().unwrap().clone();
        let logs = String::from_utf8(captured).expect("utf8");

        let mut events: Vec<String> = Vec::new();
        for line in logs.lines() {
            let record: Value = serde_json::from_str(line).expect("json");
            let event = record
                .get("fields")
                .and_then(|fields| fields.get("event"))
                .and_then(Value::as_str);
            events.extend(event.map(str::to_owned));
            if let Some(target) = record.get("target").and_then(Value::as_str) {
                assert_eq!(target, "raft", "log target should be raft");
            }
        }

        let expectations = [
            ("raft_state_changed", "state change event expected"),
            ("raft_leader_elected", "leader election event expected"),
            ("raft_snapshot_installed", "snapshot installed event expected"),
            ("raft_log_compacted", "log compacted event expected"),
        ];
        for (event_name, failure_message) in expectations {
            assert!(
                events.iter().any(|seen| seen == event_name),
                "{failure_message}"
            );
        }
    }
}
633}