// rocketmq_controller/raft/transport.rs
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use bytes::Bytes;
use bytes::BytesMut;
use protobuf::Message as ProtobufMessage;
use raft::eraftpb;
use raft::prelude::Message;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::RwLock;
use tracing::debug;
use tracing::error;
use tracing::info;
use tracing::warn;

use crate::error::ControllerError;
use crate::error::Result;
40
41pub struct MessageCodec;
43
44impl MessageCodec {
45 pub fn encode(msg: &Message) -> Result<Bytes> {
47 let proto_msg: eraftpb::Message = msg.clone();
49
50 let encoded = proto_msg
52 .write_to_bytes()
53 .map_err(|e| ControllerError::SerializationError(e.to_string()))?;
54
55 let len = encoded.len() as u32;
57 let mut result = BytesMut::with_capacity(4 + encoded.len());
58 result.extend_from_slice(&len.to_be_bytes());
59 result.extend_from_slice(&encoded);
60
61 Ok(result.freeze())
62 }
63
64 pub async fn decode(stream: &mut TcpStream) -> Result<Message> {
66 let mut len_buf = [0u8; 4];
68 stream
69 .read_exact(&mut len_buf)
70 .await
71 .map_err(|e| ControllerError::NetworkError(e.to_string()))?;
72
73 let len = u32::from_be_bytes(len_buf) as usize;
74
75 if len > 10 * 1024 * 1024 {
77 return Err(ControllerError::InvalidRequest(format!(
78 "Message too large: {} bytes",
79 len
80 )));
81 }
82
83 let mut buf = vec![0u8; len];
85 stream
86 .read_exact(&mut buf)
87 .await
88 .map_err(|e| ControllerError::NetworkError(e.to_string()))?;
89
90 let proto_msg = eraftpb::Message::parse_from_bytes(&buf)
92 .map_err(|e| ControllerError::SerializationError(e.to_string()))?;
93
94 Ok(proto_msg)
96 }
97}
98
99pub struct PeerConnection {
101 peer_id: u64,
103
104 addr: SocketAddr,
106
107 stream: Option<TcpStream>,
109
110 tx: mpsc::UnboundedSender<Message>,
112
113 rx: mpsc::UnboundedReceiver<Message>,
115}
116
117impl PeerConnection {
118 pub fn new(peer_id: u64, addr: SocketAddr) -> Self {
120 let (tx, rx) = mpsc::unbounded_channel();
121
122 Self {
123 peer_id,
124 addr,
125 stream: None,
126 tx,
127 rx,
128 }
129 }
130
131 pub async fn connect(&mut self) -> Result<()> {
133 debug!("Connecting to peer {} at {}", self.peer_id, self.addr);
134
135 match TcpStream::connect(self.addr).await {
136 Ok(stream) => {
137 info!(
138 "Successfully connected to peer {} at {}",
139 self.peer_id, self.addr
140 );
141 self.stream = Some(stream);
142 Ok(())
143 }
144 Err(e) => {
145 warn!(
146 "Failed to connect to peer {} at {}: {}",
147 self.peer_id, self.addr, e
148 );
149 Err(ControllerError::NetworkError(e.to_string()))
150 }
151 }
152 }
153
154 pub async fn send(&mut self, msg: Message) -> Result<()> {
156 if self.stream.is_none() {
158 self.connect().await?;
159 }
160
161 let stream = self
162 .stream
163 .as_mut()
164 .ok_or_else(|| ControllerError::NetworkError("Not connected".to_string()))?;
165
166 let bytes = MessageCodec::encode(&msg)?;
168
169 stream.write_all(&bytes).await.map_err(|e| {
171 error!("Failed to send message to peer {}: {}", self.peer_id, e);
172 self.stream = None; ControllerError::NetworkError(e.to_string())
174 })?;
175
176 debug!(
177 "Sent message to peer {}, type: {:?}",
178 self.peer_id,
179 msg.get_msg_type()
180 );
181 Ok(())
182 }
183
184 pub async fn receive(&mut self) -> Result<Message> {
186 let stream = self
187 .stream
188 .as_mut()
189 .ok_or_else(|| ControllerError::NetworkError("Not connected".to_string()))?;
190
191 MessageCodec::decode(stream).await
192 }
193
194 pub fn sender(&self) -> mpsc::UnboundedSender<Message> {
196 self.tx.clone()
197 }
198}
199
200pub struct RaftTransport {
202 node_id: u64,
204
205 listen_addr: SocketAddr,
207
208 peers: Arc<RwLock<HashMap<u64, Arc<RwLock<PeerConnection>>>>>,
210
211 message_tx: mpsc::UnboundedSender<Message>,
213
214 incoming_tx: mpsc::UnboundedSender<Message>,
216}
217
218impl RaftTransport {
219 pub fn new(
221 node_id: u64,
222 listen_addr: SocketAddr,
223 peer_addrs: HashMap<u64, SocketAddr>,
224 ) -> (
225 Self,
226 mpsc::UnboundedReceiver<Message>,
227 mpsc::UnboundedReceiver<Message>,
228 ) {
229 let (message_tx, message_rx) = mpsc::unbounded_channel();
230 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
231
232 let mut peers = HashMap::new();
233 for (peer_id, addr) in peer_addrs {
234 if peer_id != node_id {
235 let conn = PeerConnection::new(peer_id, addr);
236 peers.insert(peer_id, Arc::new(RwLock::new(conn)));
237 }
238 }
239
240 let transport = Self {
241 node_id,
242 listen_addr,
243 peers: Arc::new(RwLock::new(peers)),
244 message_tx,
245 incoming_tx,
246 };
247
248 (transport, message_rx, incoming_rx)
249 }
250
251 pub async fn start(self: Arc<Self>) -> Result<()> {
253 info!("Starting Raft transport on {}", self.listen_addr);
254
255 let self_clone = self.clone();
257 tokio::spawn(async move {
258 if let Err(e) = self_clone.accept_loop().await {
259 error!("Accept loop error: {}", e);
260 }
261 });
262
263 let self_clone = self.clone();
265 tokio::spawn(async move {
266 if let Err(e) = self_clone.send_loop().await {
267 error!("Send loop error: {}", e);
268 }
269 });
270
271 info!("Raft transport started successfully");
272 Ok(())
273 }
274
275 async fn accept_loop(&self) -> Result<()> {
277 let listener = TcpListener::bind(self.listen_addr)
278 .await
279 .map_err(|e| ControllerError::NetworkError(e.to_string()))?;
280
281 info!("Listening for Raft connections on {}", self.listen_addr);
282
283 loop {
284 match listener.accept().await {
285 Ok((mut stream, addr)) => {
286 debug!("Accepted connection from {}", addr);
287
288 let incoming_tx = self.incoming_tx.clone();
289 tokio::spawn(async move {
290 loop {
291 match MessageCodec::decode(&mut stream).await {
292 Ok(msg) => {
293 debug!(
294 "Received message from {}: {:?}",
295 addr,
296 msg.get_msg_type()
297 );
298 if incoming_tx.send(msg).is_err() {
299 warn!("Failed to forward incoming message");
300 break;
301 }
302 }
303 Err(e) => {
304 error!("Failed to decode message from {}: {}", addr, e);
305 break;
306 }
307 }
308 }
309 });
310 }
311 Err(e) => {
312 error!("Failed to accept connection: {}", e);
313 }
314 }
315 }
316 }
317
318 async fn send_loop(&self) -> Result<()> {
320 Ok(())
323 }
324
325 pub async fn send_to_peer(&self, peer_id: u64, msg: Message) -> Result<()> {
327 debug!("Sending message to peer {}", peer_id);
328
329 let peers = self.peers.read().await;
330 let peer = peers
331 .get(&peer_id)
332 .ok_or_else(|| ControllerError::NetworkError(format!("Unknown peer: {}", peer_id)))?;
333
334 let mut conn = peer.write().await;
335 conn.send(msg).await
336 }
337
338 pub async fn broadcast(&self, msg: Message) -> Result<()> {
340 debug!("Broadcasting message to all peers");
341
342 let peers = self.peers.read().await;
343 for (peer_id, peer) in peers.iter() {
344 let mut conn = peer.write().await;
345 if let Err(e) = conn.send(msg.clone()).await {
346 warn!("Failed to send message to peer {}: {}", peer_id, e);
347 }
348 }
349
350 Ok(())
351 }
352
353 pub fn message_sender(&self) -> mpsc::UnboundedSender<Message> {
355 self.message_tx.clone()
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// `PeerConnection::new` records the peer id and address it was given.
    #[tokio::test]
    async fn test_peer_connection_creation() {
        let peer_addr = SocketAddr::from(([127, 0, 0, 1], 9876));
        let connection = PeerConnection::new(1, peer_addr);

        assert_eq!(connection.peer_id, 1);
        assert_eq!(connection.addr, peer_addr);
    }
}