rocketmq_client/
connection.rs

//!
//! This module defines connection-related structs: [`Connection`], a framed TCP
//! connection, and the crate-internal `ConnectionManager`.
//!
5
6use crate::error::{self, ClientError};
7use crate::frame::{self, Frame};
8use bytes::{self, Buf, BytesMut};
9use std::collections::HashMap;
10use std::io::Cursor;
11use std::net::SocketAddr;
12use std::sync::{Arc, Mutex};
13use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
14use tokio::net::TcpStream;
15
16pub struct Connection {
17    stream: BufWriter<TcpStream>,
18    buffer: BytesMut,
19}
20
21impl Connection {
22    /// Establish a connection to the given socket address.
23    ///
24    /// # Examples
25    ///
26    /// ```no_run
27    /// use rocketmq_client::connection::Connection;
28    /// use std::net::SocketAddr;
29    ///
30    /// #[tokio::main]
31    /// fn main() {
32    ///    let endpoint = "127.0.0.1:80";
33    ///    let socket_addr = endpoint.parse::<std::net::SocketAddr>().unwrap();
34    ///    let connection = rocketmq_client::connection::Connection::new(&socket_addr).await.unwrap();
35    /// }
36    ///
37    /// ```
38    ///
39    /// # Errors
40    /// Raise ClientError::ConnectTimeout if connection may not be established within reasonable amount of time.
41    pub async fn new(endpoint: &SocketAddr) -> Result<Self, error::ClientError> {
42        let tcp_stream = TcpStream::connect(endpoint)
43            .await
44            .map_err(|e| error::ClientError::ConnectTimeout(e))?;
45
46        Ok(Connection {
47            stream: BufWriter::new(tcp_stream),
48            buffer: BytesMut::with_capacity(1024 * 1024),
49        })
50    }
51
52    pub async fn read_frame(&mut self) -> Result<Option<frame::Frame>, ClientError> {
53        loop {
54            if let Some(frame) = self.parse_frame()? {
55                return Ok(Some(frame));
56            }
57
58            if 0 == self.stream.read_buf(&mut self.buffer).await? {
59                if self.buffer.is_empty() {
60                    return Ok(None);
61                } else {
62                    return Err(ClientError::ConnectionReset);
63                }
64            }
65        }
66    }
67
68    pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), ClientError> {
69        if let Some(buf) = frame.encode()? {
70            self.stream.write_all(&buf.slice(..)).await?;
71            self.stream.flush().await?;
72        }
73        Ok(())
74    }
75
76    fn parse_frame(&mut self) -> Result<Option<frame::Frame>, ClientError> {
77        let mut buf = Cursor::new(&self.buffer[..]);
78        match Frame::check(&mut buf) {
79            Ok(_) => {
80                let len = buf.position() as usize;
81                buf.set_position(0);
82                let frame = Frame::parse(&mut buf)?;
83                self.buffer.advance(len);
84                return Ok(frame);
85            }
86
87            Err(frame::Error::Incomplete) => {
88                return Ok(None);
89            }
90
91            Err(frame::Error::Other(e)) => {
92                return Err(e);
93            }
94        }
95    }
96}
97
98pub(crate) struct ConnectionManager {
99    connections: Arc<Mutex<HashMap<String, Connection>>>,
100}
101
102impl ConnectionManager {
103    pub(crate) fn new() -> Self {
104        Self {
105            connections: Arc::new(Mutex::new(HashMap::new())),
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    // These are integration tests against a live RocketMQ deployment on localhost
    // (127.0.0.1:9876 and 127.0.0.1:10911). They are `#[ignore]`d so a plain
    // `cargo test` does not fail without those services; run them with
    // `cargo test -- --ignored`.

    use crate::protocol::{SendMessageRequestHeader, TopicRouteData};

    use super::*;
    use std::net::SocketAddr;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Parse `addr` into a socket address, reporting failures as `ClientError::BadAddress`.
    fn parse_endpoint(addr: &str) -> Result<SocketAddr, ClientError> {
        addr.parse()
            .map_err(|_e| ClientError::BadAddress(addr.to_string()))
    }

    #[tokio::test]
    #[ignore = "requires a RocketMQ server listening on 127.0.0.1:9876"]
    async fn test_connection_new() -> Result<(), ClientError> {
        let endpoint = parse_endpoint("127.0.0.1:9876")?;
        let _connection = Connection::new(&endpoint).await?;
        Ok(())
    }

    #[tokio::test]
    #[ignore = "requires a RocketMQ server listening on 127.0.0.1:9876"]
    async fn test_read_write_frame() -> Result<(), ClientError> {
        let mut request = Frame::new();
        request.code = frame::RequestCode::GetRouteInfoByTopic as i32;
        request.language = frame::Language::CPP;
        request.put_ext_field("topic", "T1");

        let endpoint = parse_endpoint("127.0.0.1:9876")?;
        let mut connection = Connection::new(&endpoint).await?;
        connection.write_frame(&request).await?;

        if let Some(response) = connection.read_frame().await? {
            assert_eq!(response.frame_type(), frame::Type::Response);
            // The body is decoded as route data only for response code 0
            // (presumably SUCCESS).
            if response.code == 0 {
                let body = response.body();
                let topic_route_data: TopicRouteData = serde_json::from_reader(body.reader())
                    .map_err(|_e| {
                        ClientError::InvalidFrame("Response body is invalid JSON".to_owned())
                    })?;
                for broker_data in topic_route_data.broker_datas.iter() {
                    println!("{:#?}", broker_data);
                }
                for queue_data in topic_route_data.queue_datas.iter() {
                    println!("{:#?}", queue_data);
                }
            }
            println!("Remark: {}", response.remark());
        }

        Ok(())
    }

    #[tokio::test]
    #[ignore = "requires a RocketMQ server listening on 127.0.0.1:10911"]
    async fn test_send_message() -> Result<(), Box<dyn std::error::Error>> {
        let mut request = Frame::new();
        request.code = frame::RequestCode::SendMessage as i32;
        request.language = frame::Language::CPP;

        // Milliseconds since the Unix epoch. `SystemTime::now().elapsed()` would
        // instead measure the time elapsed since "now" itself, i.e. roughly zero.
        let born_timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as i64;
        let send_message_header = SendMessageRequestHeader {
            producer_group: String::from("Default"),
            topic: String::from("T1"),
            default_topic: String::from("TBW102"),
            default_topic_queue_nums: 8,
            queue_id: 0,
            sys_flag: 0,
            born_timestamp,
            flag: 0,
            properties: None,
            reconsume_times: None,
            unit_mode: None,
            batch: Some(false),
            max_reconsume_times: None,
        };
        request.add_ext_headers(send_message_header);
        request.body = bytes::Bytes::from("Test Body");

        let endpoint = parse_endpoint("127.0.0.1:10911")?;
        let mut connection = Connection::new(&endpoint).await?;
        connection.write_frame(&request).await?;
        if let Some(response) = connection.read_frame().await? {
            assert_eq!(response.frame_type(), frame::Type::Response);
            for (k, v) in response.ext_fields.iter() {
                println!("{} ==> {}", k, v);
            }
        }
        Ok(())
    }
}