nsq_async_rs/
connection.rs

1use std::net::ToSocketAddrs;
2use std::sync::Arc;
3use std::time::Duration;
4
5use backoff::ExponentialBackoffBuilder;
6use log::{error, info, warn};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use tokio::time::timeout;
11
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, IdentifyConfig, Message, Protocol, ProtocolError, MAGIC_V2};
14
/// Manages a single TCP connection to an NSQ server.
///
/// The stream is kept behind an async mutex, so every read or write
/// performed through this type first acquires that lock.
#[derive(Debug)]
pub struct Connection {
    /// The underlying TCP stream.
    stream: Mutex<TcpStream>,
    /// Address of the NSQ server; reused when the connection is re-established.
    addr: String,
    /// Configuration sent with the IDENTIFY command during initialization.
    identify_config: IdentifyConfig,
    /// Secret sent with the AUTH command during initialization; `None` skips
    /// the authentication step.
    auth_secret: Option<String>,
    /// Time limit used for timed reads from the stream.
    read_timeout: Duration,
    /// Time limit used for timed writes to the stream.
    write_timeout: Duration,
}
29
30impl Connection {
31    /// 重新建立连接
32    pub async fn reconnect(&self) -> Result<()> {
33        let stream = Self::connect_with_retry(
34            &self.addr,
35            Duration::from_secs(5),
36            self.read_timeout,
37            self.write_timeout,
38        )
39        .await?;
40
41        // 替换现有的流
42        let mut current_stream = self.stream.lock().await;
43        *current_stream = stream;
44
45        // 重新初始化连接
46        drop(current_stream); // 释放锁,避免死锁
47        self.initialize().await?;
48
49        Ok(())
50    }
51
52    /// 创建新的连接
53    pub async fn new<A: ToSocketAddrs + std::fmt::Display>(
54        addr: A,
55        identify_config: Option<IdentifyConfig>,
56        auth_secret: Option<String>,
57        read_timeout: Duration,
58        write_timeout: Duration,
59    ) -> Result<Self> {
60        let addr_str = addr.to_string();
61        let stream = Self::connect_with_retry(
62            &addr_str,
63            Duration::from_secs(5),
64            read_timeout,
65            write_timeout,
66        )
67        .await?;
68
69        let connection = Self {
70            stream: Mutex::new(stream),
71            addr: addr_str,
72            identify_config: identify_config.unwrap_or_default(),
73            auth_secret,
74            read_timeout,
75            write_timeout,
76        };
77
78        // 初始化连接
79        connection.initialize().await?;
80
81        Ok(connection)
82    }
83
84    /// 使用重试机制连接到NSQ服务器
85    pub async fn connect_with_retry(
86        addr: &str,
87        timeout_duration: Duration,
88        _read_timeout: Duration,
89        _write_timeout: Duration,
90    ) -> Result<TcpStream> {
91        let backoff = ExponentialBackoffBuilder::new()
92            .with_initial_interval(Duration::from_millis(100))
93            .with_max_interval(Duration::from_secs(1))
94            .with_multiplier(2.0)
95            .with_max_elapsed_time(Some(timeout_duration))
96            .build();
97
98        let addr_clone = addr.to_string();
99        let result = backoff::future::retry_notify(
100            backoff,
101            || async {
102                match TcpStream::connect(&addr_clone).await {
103                    Ok(stream) => Ok(stream),
104                    Err(e) => Err(backoff::Error::transient(Error::Io(e))),
105                }
106            },
107            |err, duration| {
108                warn!(
109                    "连接到 {} 失败: {:?}, 将在 {:?} 后重试",
110                    addr_clone, err, duration
111                );
112            },
113        )
114        .await;
115
116        match result {
117            Ok(stream) => {
118                // info!("成功连接到 NSQ 服务器: {}", addr);
119                Ok(stream)
120            }
121            Err(e) => Err(Error::Connection(format!("无法连接到 {}: {:?}", addr, e))),
122        }
123    }
124
125    /// 初始化连接
126    async fn initialize(&self) -> Result<()> {
127        let mut stream = self.stream.lock().await;
128
129        // 发送魔术字
130        stream.write_all(MAGIC_V2).await?;
131
132        // 发送识别信息
133        let identify_cmd = Command::Identify(self.identify_config.clone());
134        let identify_bytes = identify_cmd.to_bytes()?;
135        stream.write_all(&identify_bytes).await?;
136        stream.flush().await?;
137
138        // 读取和处理响应
139        let mut buf = [0u8; 4];
140        stream.read_exact(&mut buf).await?;
141        let size = u32::from_be_bytes(buf);
142
143        if size == 0 {
144            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
145        }
146
147        // 读取帧类型
148        stream.read_exact(&mut buf).await?;
149        let frame_type = u32::from_be_bytes(buf);
150
151        if frame_type != 0 {
152            return Err(Error::Protocol(ProtocolError::InvalidFrameType(
153                frame_type as i32,
154            )));
155        }
156
157        // 读取响应体
158        let mut response_data = vec![0u8; (size - 4) as usize];
159        stream.read_exact(&mut response_data).await?;
160
161        // 如果需要认证,发送认证命令
162        if let Some(ref secret) = self.auth_secret {
163            let auth_cmd = Command::Auth(Some(secret.clone()));
164            let auth_bytes = auth_cmd.to_bytes()?;
165            stream.write_all(&auth_bytes).await?;
166            stream.flush().await?;
167
168            // 读取认证响应 (简化版)
169            let mut buf = [0u8; 4];
170            stream.read_exact(&mut buf).await?;
171            let size = u32::from_be_bytes(buf);
172
173            if size == 0 {
174                return Err(Error::Auth("认证响应大小为0".to_string()));
175            }
176
177            // 读取帧类型
178            stream.read_exact(&mut buf).await?;
179            let frame_type = u32::from_be_bytes(buf);
180
181            if frame_type != 0 {
182                return Err(Error::Auth(format!("认证失败,帧类型 {}", frame_type)));
183            }
184
185            // 读取响应体
186            let mut response_data = vec![0u8; (size - 4) as usize];
187            stream.read_exact(&mut response_data).await?;
188        }
189
190        Ok(())
191    }
192
193    /// 发送命令到NSQ服务器
194    pub async fn send_command(&self, command: Command) -> Result<()> {
195        let mut stream = self.stream.lock().await;
196        let bytes = command.to_bytes()?;
197        stream.write_all(&bytes).await?;
198        stream.flush().await?;
199        Ok(())
200    }
201
202    /// 读取下一个NSQ帧
203    pub async fn read_frame(&self) -> Result<Frame> {
204        let mut stream_guard = self.stream.lock().await;
205
206        // 读取帧大小 (4字节)
207        let mut size_buf = [0u8; 4];
208        timeout(self.read_timeout, stream_guard.read_exact(&mut size_buf)).await??;
209        let size = u32::from_be_bytes(size_buf);
210
211        if size < 4 {
212            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
213        }
214
215        // 读取帧类型 (4字节)
216        let mut frame_type_buf = [0u8; 4];
217        timeout(
218            self.read_timeout,
219            stream_guard.read_exact(&mut frame_type_buf),
220        )
221        .await??;
222        let frame_type = i32::from_be_bytes(frame_type_buf);
223
224        // 根据 NSQ 协议,帧类型应该是以下之一:
225        // FrameTypeResponse = 0
226        // FrameTypeError = 1
227        // FrameTypeMessage = 2
228        match frame_type {
229            0..=2 => {
230                // 读取帧数据
231                let data_size = size - 4; // 减去帧类型的4字节
232                let mut data = vec![0u8; data_size as usize];
233                timeout(self.read_timeout, stream_guard.read_exact(&mut data)).await??;
234
235                // 构造完整的帧数据(包括帧类型)
236                let mut frame_data = Vec::with_capacity(size as usize);
237                frame_data.extend_from_slice(&frame_type_buf);
238                frame_data.extend_from_slice(&data);
239
240                Protocol::decode_frame(&frame_data).map_err(Error::from)
241            }
242            _ => Err(Error::Protocol(ProtocolError::InvalidFrameType(frame_type))),
243        }
244    }
245
246    /// 处理心跳帧
247    pub async fn handle_heartbeat(&self) -> Result<()> {
248        self.send_command(Command::Nop).await
249    }
250
251    /// 发送 ping 命令并等待响应,用于检测连接是否活跃
252    ///
253    /// 使用 NOP 命令实现,并添加超时机制
254    ///
255    /// # 参数
256    /// * `timeout_duration` - 超时时间,默认为 5 秒
257    ///
258    /// # 返回
259    /// * `Ok(())` - 如果连接正常
260    /// * `Err(Error)` - 如果连接异常或超时
261    pub async fn ping(&self, timeout_duration: Option<Duration>) -> Result<()> {
262        let timeout_dur = timeout_duration.unwrap_or(Duration::from_secs(5));
263
264        match timeout(timeout_dur, self.send_command(Command::Nop)).await {
265            Ok(result) => result,
266            Err(_) => Err(Error::Timeout(format!(
267                "Ping 操作超时 ({}秒)",
268                timeout_dur.as_secs()
269            ))),
270        }
271    }
272
273    /// 读取消息 - 参考Go客户端中的readLoop实现
274    pub async fn read_message(&self) -> Result<Option<Message>> {
275        match self.read_frame().await {
276            Ok(Frame::Message(msg)) => {
277                info!(
278                    "收到消息 [ID: {:?}, 尝试次数: {}, 时间戳: {}]",
279                    &msg.id, msg.attempts, msg.timestamp
280                );
281                Ok(Some(msg))
282            }
283            Ok(Frame::Response(_)) => Ok(None),
284            Ok(Frame::Error(data)) => {
285                error!("NSQ错误响应: {:?}", String::from_utf8_lossy(&data));
286                Ok(None)
287            }
288            Err(e) => {
289                error!("读取消息错误: {:?}", e);
290                Err(e)
291            }
292        }
293    }
294
295    /// 获取连接的地址
296    pub fn addr(&self) -> &str {
297        &self.addr
298    }
299
300    pub async fn write_all(&self, buf: &[u8]) -> Result<()> {
301        let mut stream = self.stream.lock().await;
302        timeout(self.write_timeout, stream.write_all(buf)).await??;
303        Ok(())
304    }
305
306    pub async fn read_exact(&self, buf: &mut [u8]) -> Result<()> {
307        let mut stream = self.stream.lock().await;
308        timeout(self.read_timeout, stream.read_exact(buf)).await??;
309        Ok(())
310    }
311
312    pub async fn write_command(
313        &self,
314        name: &str,
315        body: Option<&[u8]>,
316        params: &[&str],
317    ) -> Result<()> {
318        let cmd = Protocol::encode_command(name, body, params);
319        self.write_all(&cmd).await
320    }
321
322    pub async fn close(&self) -> Result<()> {
323        self.stream
324            .lock()
325            .await
326            .shutdown()
327            .await
328            .map_err(Error::from)
329    }
330}
331
impl Drop for Connection {
    /// Intentionally performs no work.
    fn drop(&mut self) {
        // A CLS command cannot be sent from a destructor because sending it
        // requires an async context. Callers should invoke an explicit close
        // method before the connection is dropped.
    }
}
338
339/// 在异步上下文中安全关闭连接
340pub async fn close_connection(connection: &Arc<Connection>) -> Result<()> {
341    connection.send_command(Command::Cls).await
342}