rocketmq_controller/raft/
transport.rs

1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18use std::collections::HashMap;
19use std::net::SocketAddr;
20use std::sync::Arc;
21
22use bytes::Bytes;
23use bytes::BytesMut;
24use protobuf::Message as ProtobufMessage;
25use raft::eraftpb;
26use raft::prelude::Message;
27use tokio::io::AsyncReadExt;
28use tokio::io::AsyncWriteExt;
29use tokio::net::TcpListener;
30use tokio::net::TcpStream;
31use tokio::sync::mpsc;
32use tokio::sync::RwLock;
33use tracing::debug;
34use tracing::error;
35use tracing::info;
36use tracing::warn;
37
38use crate::error::ControllerError;
39use crate::error::Result;
40
41/// Message codec for Raft messages
42pub struct MessageCodec;
43
44impl MessageCodec {
45    /// Encode a Raft message to bytes using protobuf
46    pub fn encode(msg: &Message) -> Result<Bytes> {
47        // Convert to protobuf message
48        let proto_msg: eraftpb::Message = msg.clone();
49
50        // Encode using protobuf v2
51        let encoded = proto_msg
52            .write_to_bytes()
53            .map_err(|e| ControllerError::SerializationError(e.to_string()))?;
54
55        // Length prefix (4 bytes) + message data
56        let len = encoded.len() as u32;
57        let mut result = BytesMut::with_capacity(4 + encoded.len());
58        result.extend_from_slice(&len.to_be_bytes());
59        result.extend_from_slice(&encoded);
60
61        Ok(result.freeze())
62    }
63
64    /// Decode bytes to a Raft message using protobuf
65    pub async fn decode(stream: &mut TcpStream) -> Result<Message> {
66        // Read length prefix
67        let mut len_buf = [0u8; 4];
68        stream
69            .read_exact(&mut len_buf)
70            .await
71            .map_err(|e| ControllerError::NetworkError(e.to_string()))?;
72
73        let len = u32::from_be_bytes(len_buf) as usize;
74
75        // Validate length
76        if len > 10 * 1024 * 1024 {
77            return Err(ControllerError::InvalidRequest(format!(
78                "Message too large: {} bytes",
79                len
80            )));
81        }
82
83        // Read message data
84        let mut buf = vec![0u8; len];
85        stream
86            .read_exact(&mut buf)
87            .await
88            .map_err(|e| ControllerError::NetworkError(e.to_string()))?;
89
90        // Deserialize using protobuf v2
91        let proto_msg = eraftpb::Message::parse_from_bytes(&buf)
92            .map_err(|e| ControllerError::SerializationError(e.to_string()))?;
93
94        // Convert to raft Message
95        Ok(proto_msg)
96    }
97}
98
99/// Connection to a peer
100pub struct PeerConnection {
101    /// Peer ID
102    peer_id: u64,
103
104    /// Peer address
105    addr: SocketAddr,
106
107    /// TCP stream
108    stream: Option<TcpStream>,
109
110    /// Send queue
111    tx: mpsc::UnboundedSender<Message>,
112
113    /// Receive handler
114    rx: mpsc::UnboundedReceiver<Message>,
115}
116
117impl PeerConnection {
118    /// Create a new peer connection
119    pub fn new(peer_id: u64, addr: SocketAddr) -> Self {
120        let (tx, rx) = mpsc::unbounded_channel();
121
122        Self {
123            peer_id,
124            addr,
125            stream: None,
126            tx,
127            rx,
128        }
129    }
130
131    /// Connect to the peer
132    pub async fn connect(&mut self) -> Result<()> {
133        debug!("Connecting to peer {} at {}", self.peer_id, self.addr);
134
135        match TcpStream::connect(self.addr).await {
136            Ok(stream) => {
137                info!(
138                    "Successfully connected to peer {} at {}",
139                    self.peer_id, self.addr
140                );
141                self.stream = Some(stream);
142                Ok(())
143            }
144            Err(e) => {
145                warn!(
146                    "Failed to connect to peer {} at {}: {}",
147                    self.peer_id, self.addr, e
148                );
149                Err(ControllerError::NetworkError(e.to_string()))
150            }
151        }
152    }
153
154    /// Send a message to the peer
155    pub async fn send(&mut self, msg: Message) -> Result<()> {
156        // Ensure we're connected
157        if self.stream.is_none() {
158            self.connect().await?;
159        }
160
161        let stream = self
162            .stream
163            .as_mut()
164            .ok_or_else(|| ControllerError::NetworkError("Not connected".to_string()))?;
165
166        // Encode message
167        let bytes = MessageCodec::encode(&msg)?;
168
169        // Send
170        stream.write_all(&bytes).await.map_err(|e| {
171            error!("Failed to send message to peer {}: {}", self.peer_id, e);
172            self.stream = None; // Reset connection on error
173            ControllerError::NetworkError(e.to_string())
174        })?;
175
176        debug!(
177            "Sent message to peer {}, type: {:?}",
178            self.peer_id,
179            msg.get_msg_type()
180        );
181        Ok(())
182    }
183
184    /// Receive a message from the peer
185    pub async fn receive(&mut self) -> Result<Message> {
186        let stream = self
187            .stream
188            .as_mut()
189            .ok_or_else(|| ControllerError::NetworkError("Not connected".to_string()))?;
190
191        MessageCodec::decode(stream).await
192    }
193
194    /// Get the sender channel
195    pub fn sender(&self) -> mpsc::UnboundedSender<Message> {
196        self.tx.clone()
197    }
198}
199
200/// Network transport for Raft
201pub struct RaftTransport {
202    /// Node ID
203    node_id: u64,
204
205    /// Listen address
206    listen_addr: SocketAddr,
207
208    /// Peer connections
209    peers: Arc<RwLock<HashMap<u64, Arc<RwLock<PeerConnection>>>>>,
210
211    /// Message receiver from Raft
212    message_tx: mpsc::UnboundedSender<Message>,
213
214    /// Incoming message sender to Raft
215    incoming_tx: mpsc::UnboundedSender<Message>,
216}
217
218impl RaftTransport {
219    /// Create a new transport
220    pub fn new(
221        node_id: u64,
222        listen_addr: SocketAddr,
223        peer_addrs: HashMap<u64, SocketAddr>,
224    ) -> (
225        Self,
226        mpsc::UnboundedReceiver<Message>,
227        mpsc::UnboundedReceiver<Message>,
228    ) {
229        let (message_tx, message_rx) = mpsc::unbounded_channel();
230        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
231
232        let mut peers = HashMap::new();
233        for (peer_id, addr) in peer_addrs {
234            if peer_id != node_id {
235                let conn = PeerConnection::new(peer_id, addr);
236                peers.insert(peer_id, Arc::new(RwLock::new(conn)));
237            }
238        }
239
240        let transport = Self {
241            node_id,
242            listen_addr,
243            peers: Arc::new(RwLock::new(peers)),
244            message_tx,
245            incoming_tx,
246        };
247
248        (transport, message_rx, incoming_rx)
249    }
250
251    /// Start the transport
252    pub async fn start(self: Arc<Self>) -> Result<()> {
253        info!("Starting Raft transport on {}", self.listen_addr);
254
255        // Start listening for incoming connections
256        let self_clone = self.clone();
257        tokio::spawn(async move {
258            if let Err(e) = self_clone.accept_loop().await {
259                error!("Accept loop error: {}", e);
260            }
261        });
262
263        // Start message sending loop
264        let self_clone = self.clone();
265        tokio::spawn(async move {
266            if let Err(e) = self_clone.send_loop().await {
267                error!("Send loop error: {}", e);
268            }
269        });
270
271        info!("Raft transport started successfully");
272        Ok(())
273    }
274
275    /// Accept incoming connections
276    async fn accept_loop(&self) -> Result<()> {
277        let listener = TcpListener::bind(self.listen_addr)
278            .await
279            .map_err(|e| ControllerError::NetworkError(e.to_string()))?;
280
281        info!("Listening for Raft connections on {}", self.listen_addr);
282
283        loop {
284            match listener.accept().await {
285                Ok((mut stream, addr)) => {
286                    debug!("Accepted connection from {}", addr);
287
288                    let incoming_tx = self.incoming_tx.clone();
289                    tokio::spawn(async move {
290                        loop {
291                            match MessageCodec::decode(&mut stream).await {
292                                Ok(msg) => {
293                                    debug!(
294                                        "Received message from {}: {:?}",
295                                        addr,
296                                        msg.get_msg_type()
297                                    );
298                                    if incoming_tx.send(msg).is_err() {
299                                        warn!("Failed to forward incoming message");
300                                        break;
301                                    }
302                                }
303                                Err(e) => {
304                                    error!("Failed to decode message from {}: {}", addr, e);
305                                    break;
306                                }
307                            }
308                        }
309                    });
310                }
311                Err(e) => {
312                    error!("Failed to accept connection: {}", e);
313                }
314            }
315        }
316    }
317
318    /// Send messages to peers
319    async fn send_loop(&self) -> Result<()> {
320        // This will be implemented to actually send messages
321        // For now, it's a placeholder
322        Ok(())
323    }
324
325    /// Send a message to a specific peer
326    pub async fn send_to_peer(&self, peer_id: u64, msg: Message) -> Result<()> {
327        debug!("Sending message to peer {}", peer_id);
328
329        let peers = self.peers.read().await;
330        let peer = peers
331            .get(&peer_id)
332            .ok_or_else(|| ControllerError::NetworkError(format!("Unknown peer: {}", peer_id)))?;
333
334        let mut conn = peer.write().await;
335        conn.send(msg).await
336    }
337
338    /// Broadcast a message to all peers
339    pub async fn broadcast(&self, msg: Message) -> Result<()> {
340        debug!("Broadcasting message to all peers");
341
342        let peers = self.peers.read().await;
343        for (peer_id, peer) in peers.iter() {
344            let mut conn = peer.write().await;
345            if let Err(e) = conn.send(msg.clone()).await {
346                warn!("Failed to send message to peer {}: {}", peer_id, e);
347            }
348        }
349
350        Ok(())
351    }
352
353    /// Get the message sender
354    pub fn message_sender(&self) -> mpsc::UnboundedSender<Message> {
355        self.message_tx.clone()
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a loopback TCP connection and returns its `(client, server)` ends.
    async fn loopback_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (client, accepted) = tokio::join!(TcpStream::connect(addr), listener.accept());
        (client.unwrap(), accepted.unwrap().0)
    }

    #[tokio::test]
    async fn test_peer_connection_creation() {
        let addr: SocketAddr = "127.0.0.1:9876".parse().unwrap();
        let conn = PeerConnection::new(1, addr);
        assert_eq!(conn.peer_id, 1);
        assert_eq!(conn.addr, addr);
    }

    #[tokio::test]
    async fn test_codec_round_trip() {
        let mut msg = Message::default();
        msg.set_msg_type(eraftpb::MessageType::MsgHeartbeat);
        msg.set_from(1);
        msg.set_to(2);
        msg.set_term(7);

        let frame = MessageCodec::encode(&msg).unwrap();
        let (mut client, mut server) = loopback_pair().await;
        client.write_all(&frame).await.unwrap();

        let decoded = MessageCodec::decode(&mut server).await.unwrap();
        assert_eq!(decoded, msg);
    }

    #[tokio::test]
    async fn test_decode_rejects_oversized_frame() {
        let (mut client, mut server) = loopback_pair().await;
        // Declare a body one byte over the 10 MiB limit; the body itself is
        // never sent because decode must reject the prefix up front.
        let declared = (10 * 1024 * 1024 + 1) as u32;
        client.write_all(&declared.to_be_bytes()).await.unwrap();

        let err = MessageCodec::decode(&mut server).await.unwrap_err();
        assert!(matches!(err, ControllerError::InvalidRequest(_)));
    }

    #[tokio::test]
    async fn test_decode_reports_truncated_frame() {
        let (mut client, mut server) = loopback_pair().await;
        // Promise 16 bytes, deliver 3, then close the socket.
        client.write_all(&16u32.to_be_bytes()).await.unwrap();
        client.write_all(&[0u8; 3]).await.unwrap();
        drop(client);

        let err = MessageCodec::decode(&mut server).await.unwrap_err();
        assert!(matches!(err, ControllerError::NetworkError(_)));
    }
}