// nsq_async_rs/connection.rs
use std::net::ToSocketAddrs;
use std::sync::Arc;
use std::time::Duration;

use backoff::ExponentialBackoffBuilder;
use log::{error, info, warn};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio::time::timeout;

use crate::error::{Error, Result};
use crate::protocol::{Command, Frame, IdentifyConfig, Message, Protocol, ProtocolError, MAGIC_V2};

15#[derive(Debug)]
17pub struct Connection {
18 stream: Mutex<TcpStream>,
20 addr: String,
22 identify_config: IdentifyConfig,
24 auth_secret: Option<String>,
26 read_timeout: Duration,
27 write_timeout: Duration,
28}
29
30impl Connection {
31 pub async fn reconnect(&self) -> Result<()> {
33 let stream = Self::connect_with_retry(
34 &self.addr,
35 Duration::from_secs(5),
36 self.read_timeout,
37 self.write_timeout,
38 )
39 .await?;
40
41 let mut current_stream = self.stream.lock().await;
43 *current_stream = stream;
44
45 drop(current_stream); self.initialize().await?;
48
49 Ok(())
50 }
51
52 pub async fn new<A: ToSocketAddrs + std::fmt::Display>(
54 addr: A,
55 identify_config: Option<IdentifyConfig>,
56 auth_secret: Option<String>,
57 read_timeout: Duration,
58 write_timeout: Duration,
59 ) -> Result<Self> {
60 let addr_str = addr.to_string();
61 let stream = Self::connect_with_retry(
62 &addr_str,
63 Duration::from_secs(5),
64 read_timeout,
65 write_timeout,
66 )
67 .await?;
68
69 let connection = Self {
70 stream: Mutex::new(stream),
71 addr: addr_str,
72 identify_config: identify_config.unwrap_or_default(),
73 auth_secret,
74 read_timeout,
75 write_timeout,
76 };
77
78 connection.initialize().await?;
80
81 Ok(connection)
82 }
83
84 pub async fn connect_with_retry(
86 addr: &str,
87 timeout_duration: Duration,
88 _read_timeout: Duration,
89 _write_timeout: Duration,
90 ) -> Result<TcpStream> {
91 let backoff = ExponentialBackoffBuilder::new()
92 .with_initial_interval(Duration::from_millis(100))
93 .with_max_interval(Duration::from_secs(1))
94 .with_multiplier(2.0)
95 .with_max_elapsed_time(Some(timeout_duration))
96 .build();
97
98 let addr_clone = addr.to_string();
99 let result = backoff::future::retry_notify(
100 backoff,
101 || async {
102 match TcpStream::connect(&addr_clone).await {
103 Ok(stream) => Ok(stream),
104 Err(e) => Err(backoff::Error::transient(Error::Io(e))),
105 }
106 },
107 |err, duration| {
108 warn!(
109 "连接到 {} 失败: {:?}, 将在 {:?} 后重试",
110 addr_clone, err, duration
111 );
112 },
113 )
114 .await;
115
116 match result {
117 Ok(stream) => {
118 Ok(stream)
120 }
121 Err(e) => Err(Error::Connection(format!("无法连接到 {}: {:?}", addr, e))),
122 }
123 }
124
125 async fn initialize(&self) -> Result<()> {
127 let mut stream = self.stream.lock().await;
128
129 stream.write_all(MAGIC_V2).await?;
131
132 let identify_cmd = Command::Identify(self.identify_config.clone());
134 let identify_bytes = identify_cmd.to_bytes()?;
135 stream.write_all(&identify_bytes).await?;
136 stream.flush().await?;
137
138 let mut buf = [0u8; 4];
140 stream.read_exact(&mut buf).await?;
141 let size = u32::from_be_bytes(buf);
142
143 if size == 0 {
144 return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
145 }
146
147 stream.read_exact(&mut buf).await?;
149 let frame_type = u32::from_be_bytes(buf);
150
151 if frame_type != 0 {
152 return Err(Error::Protocol(ProtocolError::InvalidFrameType(
153 frame_type as i32,
154 )));
155 }
156
157 let mut response_data = vec![0u8; (size - 4) as usize];
159 stream.read_exact(&mut response_data).await?;
160
161 if let Some(ref secret) = self.auth_secret {
163 let auth_cmd = Command::Auth(Some(secret.clone()));
164 let auth_bytes = auth_cmd.to_bytes()?;
165 stream.write_all(&auth_bytes).await?;
166 stream.flush().await?;
167
168 let mut buf = [0u8; 4];
170 stream.read_exact(&mut buf).await?;
171 let size = u32::from_be_bytes(buf);
172
173 if size == 0 {
174 return Err(Error::Auth("认证响应大小为0".to_string()));
175 }
176
177 stream.read_exact(&mut buf).await?;
179 let frame_type = u32::from_be_bytes(buf);
180
181 if frame_type != 0 {
182 return Err(Error::Auth(format!("认证失败,帧类型 {}", frame_type)));
183 }
184
185 let mut response_data = vec![0u8; (size - 4) as usize];
187 stream.read_exact(&mut response_data).await?;
188 }
189
190 Ok(())
191 }
192
193 pub async fn send_command(&self, command: Command) -> Result<()> {
195 let mut stream = self.stream.lock().await;
196 let bytes = command.to_bytes()?;
197 stream.write_all(&bytes).await?;
198 stream.flush().await?;
199 Ok(())
200 }
201
202 pub async fn read_frame(&self) -> Result<Frame> {
204 let mut stream_guard = self.stream.lock().await;
205
206 let mut size_buf = [0u8; 4];
208 timeout(self.read_timeout, stream_guard.read_exact(&mut size_buf)).await??;
209 let size = u32::from_be_bytes(size_buf);
210
211 if size < 4 {
212 return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
213 }
214
215 let mut frame_type_buf = [0u8; 4];
217 timeout(
218 self.read_timeout,
219 stream_guard.read_exact(&mut frame_type_buf),
220 )
221 .await??;
222 let frame_type = i32::from_be_bytes(frame_type_buf);
223
224 match frame_type {
229 0..=2 => {
230 let data_size = size - 4; let mut data = vec![0u8; data_size as usize];
233 timeout(self.read_timeout, stream_guard.read_exact(&mut data)).await??;
234
235 let mut frame_data = Vec::with_capacity(size as usize);
237 frame_data.extend_from_slice(&frame_type_buf);
238 frame_data.extend_from_slice(&data);
239
240 Protocol::decode_frame(&frame_data).map_err(Error::from)
241 }
242 _ => Err(Error::Protocol(ProtocolError::InvalidFrameType(frame_type))),
243 }
244 }
245
246 pub async fn handle_heartbeat(&self) -> Result<()> {
248 self.send_command(Command::Nop).await
249 }
250
251 pub async fn ping(&self, timeout_duration: Option<Duration>) -> Result<()> {
262 let timeout_dur = timeout_duration.unwrap_or(Duration::from_secs(5));
263
264 match timeout(timeout_dur, self.send_command(Command::Nop)).await {
265 Ok(result) => result,
266 Err(_) => Err(Error::Timeout(format!(
267 "Ping 操作超时 ({}秒)",
268 timeout_dur.as_secs()
269 ))),
270 }
271 }
272
273 pub async fn read_message(&self) -> Result<Option<Message>> {
275 match self.read_frame().await {
276 Ok(Frame::Message(msg)) => {
277 info!(
278 "收到消息 [ID: {:?}, 尝试次数: {}, 时间戳: {}]",
279 &msg.id, msg.attempts, msg.timestamp
280 );
281 Ok(Some(msg))
282 }
283 Ok(Frame::Response(_)) => Ok(None),
284 Ok(Frame::Error(data)) => {
285 error!("NSQ错误响应: {:?}", String::from_utf8_lossy(&data));
286 Ok(None)
287 }
288 Err(e) => {
289 error!("读取消息错误: {:?}", e);
290 Err(e)
291 }
292 }
293 }
294
295 pub fn addr(&self) -> &str {
297 &self.addr
298 }
299
300 pub async fn write_all(&self, buf: &[u8]) -> Result<()> {
301 let mut stream = self.stream.lock().await;
302 timeout(self.write_timeout, stream.write_all(buf)).await??;
303 Ok(())
304 }
305
306 pub async fn read_exact(&self, buf: &mut [u8]) -> Result<()> {
307 let mut stream = self.stream.lock().await;
308 timeout(self.read_timeout, stream.read_exact(buf)).await??;
309 Ok(())
310 }
311
312 pub async fn write_command(
313 &self,
314 name: &str,
315 body: Option<&[u8]>,
316 params: &[&str],
317 ) -> Result<()> {
318 let cmd = Protocol::encode_command(name, body, params);
319 self.write_all(&cmd).await
320 }
321
322 pub async fn close(&self) -> Result<()> {
323 self.stream
324 .lock()
325 .await
326 .shutdown()
327 .await
328 .map_err(Error::from)
329 }
330}
331
// `Connection` intentionally has no custom `Drop`: dropping it drops the inner
// `TcpStream`, which closes the socket. For a graceful shutdown, call
// `close_connection` (sends CLS) or `Connection::close` before dropping.

339pub async fn close_connection(connection: &Arc<Connection>) -> Result<()> {
341 connection.send_command(Command::Cls).await
342}