rocketmq_client/
connection.rs1use crate::error::{self, ClientError};
7use crate::frame::{self, Frame};
8use bytes::{self, Buf, BytesMut};
9use std::collections::HashMap;
10use std::io::Cursor;
11use std::net::SocketAddr;
12use std::sync::{Arc, Mutex};
13use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
14use tokio::net::TcpStream;
15
16pub struct Connection {
17 stream: BufWriter<TcpStream>,
18 buffer: BytesMut,
19}
20
21impl Connection {
22 pub async fn new(endpoint: &SocketAddr) -> Result<Self, error::ClientError> {
42 let tcp_stream = TcpStream::connect(endpoint)
43 .await
44 .map_err(|e| error::ClientError::ConnectTimeout(e))?;
45
46 Ok(Connection {
47 stream: BufWriter::new(tcp_stream),
48 buffer: BytesMut::with_capacity(1024 * 1024),
49 })
50 }
51
52 pub async fn read_frame(&mut self) -> Result<Option<frame::Frame>, ClientError> {
53 loop {
54 if let Some(frame) = self.parse_frame()? {
55 return Ok(Some(frame));
56 }
57
58 if 0 == self.stream.read_buf(&mut self.buffer).await? {
59 if self.buffer.is_empty() {
60 return Ok(None);
61 } else {
62 return Err(ClientError::ConnectionReset);
63 }
64 }
65 }
66 }
67
68 pub async fn write_frame(&mut self, frame: &Frame) -> Result<(), ClientError> {
69 if let Some(buf) = frame.encode()? {
70 self.stream.write_all(&buf.slice(..)).await?;
71 self.stream.flush().await?;
72 }
73 Ok(())
74 }
75
76 fn parse_frame(&mut self) -> Result<Option<frame::Frame>, ClientError> {
77 let mut buf = Cursor::new(&self.buffer[..]);
78 match Frame::check(&mut buf) {
79 Ok(_) => {
80 let len = buf.position() as usize;
81 buf.set_position(0);
82 let frame = Frame::parse(&mut buf)?;
83 self.buffer.advance(len);
84 return Ok(frame);
85 }
86
87 Err(frame::Error::Incomplete) => {
88 return Ok(None);
89 }
90
91 Err(frame::Error::Other(e)) => {
92 return Err(e);
93 }
94 }
95 }
96}
97
98pub(crate) struct ConnectionManager {
99 connections: Arc<Mutex<HashMap<String, Connection>>>,
100}
101
102impl ConnectionManager {
103 pub(crate) fn new() -> Self {
104 Self {
105 connections: Arc::new(Mutex::new(HashMap::new())),
106 }
107 }
108}
109
#[cfg(test)]
mod tests {
    //! Integration tests against a live RocketMQ deployment: a name server on
    //! `127.0.0.1:9876` and a broker on `127.0.0.1:10911` must be running for
    //! these tests to pass.

    use crate::protocol::{SendMessageRequestHeader, TopicRouteData};

    use super::*;
    use std::net::SocketAddr;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Name server queried for route information.
    const NAME_SERVER_ADDR: &str = "127.0.0.1:9876";
    /// Broker that messages are sent to.
    const BROKER_ADDR: &str = "127.0.0.1:10911";

    /// Parses `addr` and opens a [`Connection`] to it.
    ///
    /// An unparsable address is reported as [`ClientError::BadAddress`].
    async fn connect(addr: &str) -> Result<Connection, ClientError> {
        let endpoint: SocketAddr = addr
            .parse()
            .map_err(|_| ClientError::BadAddress(addr.to_string()))?;
        Connection::new(&endpoint).await
    }

    #[tokio::test]
    async fn test_connection_new() -> Result<(), error::ClientError> {
        let _connection = connect(NAME_SERVER_ADDR).await?;
        Ok(())
    }

    #[tokio::test]
    async fn test_read_write_frame() -> Result<(), ClientError> {
        let mut frame = Frame::new();
        frame.code = frame::RequestCode::GetRouteInfoByTopic as i32;
        frame.language = crate::frame::Language::CPP;
        frame.put_ext_field("topic", "T1");

        let mut connection = connect(NAME_SERVER_ADDR).await?;
        connection.write_frame(&frame).await?;
        if let Some(response) = connection.read_frame().await? {
            assert_eq!(response.frame_type(), frame::Type::Response);
            if response.code == 0 {
                let body = response.body();
                let topic_route_data: TopicRouteData = serde_json::from_reader(body.reader())
                    .map_err(|_| {
                        ClientError::InvalidFrame("Response body is invalid JSON".to_owned())
                    })?;
                for broker_data in &topic_route_data.broker_datas {
                    println!("{:#?}", broker_data);
                }
                for queue_data in &topic_route_data.queue_datas {
                    println!("{:#?}", queue_data);
                }
            }
            println!("Remark: {}", response.remark());
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_send_message() -> Result<(), Box<dyn std::error::Error>> {
        let mut frame = Frame::new();
        frame.code = frame::RequestCode::SendMessage as i32;
        frame.language = crate::frame::Language::CPP;

        // Wall-clock time in milliseconds since the Unix epoch.
        // `SystemTime::now().elapsed()` would instead measure the near-zero
        // time since `now()` was called, and it can fail if the clock moves
        // backwards. Epoch millis fit comfortably in an i64.
        let born_timestamp = SystemTime::now().duration_since(UNIX_EPOCH)?.as_millis() as i64;

        let send_message_header = SendMessageRequestHeader {
            producer_group: String::from("Default"),
            topic: String::from("T1"),
            default_topic: String::from("TBW102"),
            default_topic_queue_nums: 8,
            queue_id: 0,
            sys_flag: 0,
            born_timestamp,
            flag: 0,
            properties: None,
            reconsume_times: None,
            unit_mode: None,
            batch: Some(false),
            max_reconsume_times: None,
        };
        frame.add_ext_headers(send_message_header);
        frame.body = bytes::Bytes::from("Test Body");

        let mut connection = connect(BROKER_ADDR).await?;
        connection.write_frame(&frame).await?;
        if let Some(response) = connection.read_frame().await? {
            assert_eq!(response.frame_type(), frame::Type::Response);
            for (key, value) in response.ext_fields.iter() {
                println!("{} ==> {}", key, value);
            }
        }
        Ok(())
    }
}