rocketmq_controller/raft/
mod.rs1mod network;
19mod node;
20mod storage;
21mod transport;
22
23use std::sync::Arc;
24use std::time::Duration;
25
26pub use network::NetworkManager;
27pub use node::RaftNode;
28use raft::prelude::*;
29pub use storage::MemStorage;
30use tokio::sync::mpsc;
31use tokio::sync::RwLock;
32use tokio::time;
33use tracing::debug;
34use tracing::error;
35use tracing::info;
36pub use transport::MessageCodec;
37pub use transport::PeerConnection;
38pub use transport::RaftTransport;
39
40use crate::config::ControllerConfig;
41use crate::error::ControllerError;
42use crate::error::Result;
43
44#[derive(Debug)]
46pub enum RaftMessage {
47 Propose {
49 data: Vec<u8>,
50 response: tokio::sync::oneshot::Sender<Result<Vec<u8>>>,
51 },
52 Step { message: Message },
54 Tick,
56 Query {
58 data: Vec<u8>,
59 response: tokio::sync::oneshot::Sender<Result<Vec<u8>>>,
60 },
61 Shutdown,
63}
64
65pub struct RaftController {
75 node_id: u64,
77
78 node: Arc<RwLock<Option<RaftNode>>>,
80
81 network: Arc<RwLock<Option<NetworkManager>>>,
83
84 tx: mpsc::UnboundedSender<RaftMessage>,
86
87 config: Arc<ControllerConfig>,
89}
90
91impl RaftController {
92 pub async fn new(config: Arc<ControllerConfig>) -> Result<Self> {
94 let node_id = config.node_id;
95 let (tx, rx) = mpsc::unbounded_channel();
96
97 let controller = Self {
98 node_id,
99 node: Arc::new(RwLock::new(None)),
100 network: Arc::new(RwLock::new(None)),
101 tx,
102 config: config.clone(),
103 };
104
105 let node = RaftNode::new(node_id, config.clone()).await?;
107 *controller.node.write().await = Some(node);
108
109 let (network_manager, incoming_rx) = NetworkManager::new(config.clone());
111 *controller.network.write().await = Some(network_manager);
112
113 let node_clone = controller.node.clone();
115 tokio::spawn(async move {
116 Self::message_loop(node_clone, rx).await;
117 });
118
119 let tx_clone = controller.tx.clone();
121 tokio::spawn(async move {
122 Self::incoming_message_loop(incoming_rx, tx_clone).await;
123 });
124
125 Ok(controller)
126 }
127
128 pub async fn start(&self) -> Result<()> {
130 info!("Starting Raft controller for node {}", self.node_id);
131
132 if let Some(network) = self.network.write().await.as_mut() {
134 network.start().await?;
135 }
136
137 let tx = self.tx.clone();
139 tokio::spawn(async move {
140 let mut interval = time::interval(Duration::from_millis(100));
141 loop {
142 interval.tick().await;
143 if tx.send(RaftMessage::Tick).is_err() {
144 break;
145 }
146 }
147 });
148
149 Ok(())
150 }
151
152 pub async fn shutdown(&self) -> Result<()> {
154 info!("Shutting down Raft controller for node {}", self.node_id);
155
156 if let Some(network) = self.network.read().await.as_ref() {
158 network.shutdown().await?;
159 }
160
161 self.tx
162 .send(RaftMessage::Shutdown)
163 .map_err(|_| ControllerError::Shutdown)?;
164 Ok(())
165 }
166
167 pub async fn propose(&self, data: Vec<u8>) -> Result<Vec<u8>> {
169 let (tx, rx) = tokio::sync::oneshot::channel();
170 self.tx
171 .send(RaftMessage::Propose { data, response: tx })
172 .map_err(|_| ControllerError::Shutdown)?;
173
174 rx.await
175 .map_err(|_| ControllerError::Internal("Response channel closed".to_string()))?
176 }
177
178 pub async fn query(&self, data: Vec<u8>) -> Result<Vec<u8>> {
180 let (tx, rx) = tokio::sync::oneshot::channel();
181 self.tx
182 .send(RaftMessage::Query { data, response: tx })
183 .map_err(|_| ControllerError::Shutdown)?;
184
185 rx.await
186 .map_err(|_| ControllerError::Internal("Response channel closed".to_string()))?
187 }
188
189 pub async fn is_leader(&self) -> bool {
191 if let Some(node) = self.node.read().await.as_ref() {
192 node.is_leader().await
193 } else {
194 false
195 }
196 }
197
198 pub async fn get_leader(&self) -> Option<u64> {
200 if let Some(node) = self.node.read().await.as_ref() {
201 node.get_leader().await
202 } else {
203 None
204 }
205 }
206
207 pub async fn step(&self, message: Message) -> Result<()> {
209 self.tx
210 .send(RaftMessage::Step { message })
211 .map_err(|_| ControllerError::Shutdown)?;
212 Ok(())
213 }
214
215 async fn message_loop(
217 node: Arc<RwLock<Option<RaftNode>>>,
218 mut rx: mpsc::UnboundedReceiver<RaftMessage>,
219 ) {
220 while let Some(msg) = rx.recv().await {
221 match msg {
222 RaftMessage::Propose { data, response } => {
223 let result = if let Some(n) = node.read().await.as_ref() {
224 n.propose(data).await
225 } else {
226 Err(ControllerError::Internal(
227 "Node not initialized".to_string(),
228 ))
229 };
230 let _ = response.send(result);
231 }
232 RaftMessage::Step { message } => {
233 if let Some(n) = node.read().await.as_ref() {
234 if let Err(e) = n.step(message).await {
235 error!("Failed to step Raft: {}", e);
236 }
237 }
238 }
239 RaftMessage::Tick => {
240 if let Some(n) = node.read().await.as_ref() {
241 if let Err(e) = n.tick().await {
242 error!("Failed to tick Raft: {}", e);
243 }
244 }
245 }
246 RaftMessage::Query { data, response } => {
247 let result = if let Some(n) = node.read().await.as_ref() {
248 n.query(data).await
249 } else {
250 Err(ControllerError::Internal(
251 "Node not initialized".to_string(),
252 ))
253 };
254 let _ = response.send(result);
255 }
256 RaftMessage::Shutdown => {
257 info!("Raft controller shutting down");
258 break;
259 }
260 }
261 }
262 }
263
264 async fn incoming_message_loop(
266 mut incoming_rx: mpsc::UnboundedReceiver<Message>,
267 tx: mpsc::UnboundedSender<RaftMessage>,
268 ) {
269 info!("Starting incoming message loop");
270
271 while let Some(message) = incoming_rx.recv().await {
272 debug!(
273 "Received Raft message from network: {:?}",
274 message.get_msg_type()
275 );
276
277 if tx.send(RaftMessage::Step { message }).is_err() {
278 error!("Failed to forward incoming message to Raft");
279 break;
280 }
281 }
282
283 info!("Incoming message loop stopped");
284 }
285}
286
287