nsq_async_rs/
connection.rs1use std::net::ToSocketAddrs;
2use std::sync::Arc;
3use std::time::Duration;
4
5use backoff::ExponentialBackoffBuilder;
6use log::{error, info, warn};
7use tokio::io::{AsyncReadExt, AsyncWriteExt};
8use tokio::net::TcpStream;
9use tokio::sync::Mutex;
10use tokio::time::timeout;
11
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, IdentifyConfig, Message, Protocol, ProtocolError, MAGIC_V2};
14
15#[derive(Debug)]
17pub struct Connection {
18    stream: Mutex<TcpStream>,
20    addr: String,
22    identify_config: IdentifyConfig,
24    auth_secret: Option<String>,
26    read_timeout: Duration,
27    write_timeout: Duration,
28}
29
30impl Connection {
31    pub async fn reconnect(&self) -> Result<()> {
33        let stream = Self::connect_with_retry(
34            &self.addr,
35            Duration::from_secs(5),
36            self.read_timeout,
37            self.write_timeout,
38        )
39        .await?;
40
41        let mut current_stream = self.stream.lock().await;
43        *current_stream = stream;
44
45        drop(current_stream); self.initialize().await?;
48
49        Ok(())
50    }
51
52    pub async fn new<A: ToSocketAddrs + std::fmt::Display>(
54        addr: A,
55        identify_config: Option<IdentifyConfig>,
56        auth_secret: Option<String>,
57        read_timeout: Duration,
58        write_timeout: Duration,
59    ) -> Result<Self> {
60        let addr_str = addr.to_string();
61        let stream = Self::connect_with_retry(
62            &addr_str,
63            Duration::from_secs(5),
64            read_timeout,
65            write_timeout,
66        )
67        .await?;
68
69        let connection = Self {
70            stream: Mutex::new(stream),
71            addr: addr_str,
72            identify_config: identify_config.unwrap_or_default(),
73            auth_secret,
74            read_timeout,
75            write_timeout,
76        };
77
78        connection.initialize().await?;
80
81        Ok(connection)
82    }
83
84    pub async fn connect_with_retry(
86        addr: &str,
87        timeout_duration: Duration,
88        _read_timeout: Duration,
89        _write_timeout: Duration,
90    ) -> Result<TcpStream> {
91        let backoff = ExponentialBackoffBuilder::new()
92            .with_initial_interval(Duration::from_millis(100))
93            .with_max_interval(Duration::from_secs(1))
94            .with_multiplier(2.0)
95            .with_max_elapsed_time(Some(timeout_duration))
96            .build();
97
98        let addr_clone = addr.to_string();
99        let result = backoff::future::retry_notify(
100            backoff,
101            || async {
102                match TcpStream::connect(&addr_clone).await {
103                    Ok(stream) => Ok(stream),
104                    Err(e) => Err(backoff::Error::transient(Error::Io(e))),
105                }
106            },
107            |err, duration| {
108                warn!(
109                    "连接到 {} 失败: {:?}, 将在 {:?} 后重试",
110                    addr_clone, err, duration
111                );
112            },
113        )
114        .await;
115
116        match result {
117            Ok(stream) => {
118                Ok(stream)
120            }
121            Err(e) => Err(Error::Connection(format!("无法连接到 {}: {:?}", addr, e))),
122        }
123    }
124
125    async fn initialize(&self) -> Result<()> {
127        let mut stream = self.stream.lock().await;
128
129        stream.write_all(MAGIC_V2).await?;
131
132        let identify_cmd = Command::Identify(self.identify_config.clone());
134        let identify_bytes = identify_cmd.to_bytes()?;
135        stream.write_all(&identify_bytes).await?;
136        stream.flush().await?;
137
138        let mut buf = [0u8; 4];
140        stream.read_exact(&mut buf).await?;
141        let size = u32::from_be_bytes(buf);
142
143        if size == 0 {
144            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
145        }
146
147        stream.read_exact(&mut buf).await?;
149        let frame_type = u32::from_be_bytes(buf);
150
151        if frame_type != 0 {
152            return Err(Error::Protocol(ProtocolError::InvalidFrameType(
153                frame_type as i32,
154            )));
155        }
156
157        let mut response_data = vec![0u8; (size - 4) as usize];
159        stream.read_exact(&mut response_data).await?;
160
161        if let Some(ref secret) = self.auth_secret {
163            let auth_cmd = Command::Auth(Some(secret.clone()));
164            let auth_bytes = auth_cmd.to_bytes()?;
165            stream.write_all(&auth_bytes).await?;
166            stream.flush().await?;
167
168            let mut buf = [0u8; 4];
170            stream.read_exact(&mut buf).await?;
171            let size = u32::from_be_bytes(buf);
172
173            if size == 0 {
174                return Err(Error::Auth("认证响应大小为0".to_string()));
175            }
176
177            stream.read_exact(&mut buf).await?;
179            let frame_type = u32::from_be_bytes(buf);
180
181            if frame_type != 0 {
182                return Err(Error::Auth(format!("认证失败,帧类型 {}", frame_type)));
183            }
184
185            let mut response_data = vec![0u8; (size - 4) as usize];
187            stream.read_exact(&mut response_data).await?;
188        }
189
190        Ok(())
191    }
192
193    pub async fn send_command(&self, command: Command) -> Result<()> {
195        let mut stream = self.stream.lock().await;
196        let bytes = command.to_bytes()?;
197        stream.write_all(&bytes).await?;
198        stream.flush().await?;
199        Ok(())
200    }
201
202    pub async fn read_frame(&self) -> Result<Frame> {
204        let mut stream_guard = self.stream.lock().await;
205
206        let mut size_buf = [0u8; 4];
208        timeout(self.read_timeout, stream_guard.read_exact(&mut size_buf)).await??;
209        let size = u32::from_be_bytes(size_buf);
210
211        if size < 4 {
212            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
213        }
214
215        let mut frame_type_buf = [0u8; 4];
217        timeout(
218            self.read_timeout,
219            stream_guard.read_exact(&mut frame_type_buf),
220        )
221        .await??;
222        let frame_type = i32::from_be_bytes(frame_type_buf);
223
224        match frame_type {
229            0..=2 => {
230                let data_size = size - 4; let mut data = vec![0u8; data_size as usize];
233                timeout(self.read_timeout, stream_guard.read_exact(&mut data)).await??;
234
235                let mut frame_data = Vec::with_capacity(size as usize);
237                frame_data.extend_from_slice(&frame_type_buf);
238                frame_data.extend_from_slice(&data);
239
240                Protocol::decode_frame(&frame_data).map_err(Error::from)
241            }
242            _ => Err(Error::Protocol(ProtocolError::InvalidFrameType(frame_type))),
243        }
244    }
245
246    pub async fn handle_heartbeat(&self) -> Result<()> {
248        self.send_command(Command::Nop).await
249    }
250
251    pub async fn read_message(&self) -> Result<Option<Message>> {
253        match self.read_frame().await {
254            Ok(Frame::Message(msg)) => {
255                info!(
256                    "收到消息 [ID: {:?}, 尝试次数: {}, 时间戳: {}]",
257                    &msg.id, msg.attempts, msg.timestamp
258                );
259                Ok(Some(msg))
260            }
261            Ok(Frame::Response(_)) => Ok(None),
262            Ok(Frame::Error(data)) => {
263                error!("NSQ错误响应: {:?}", String::from_utf8_lossy(&data));
264                Ok(None)
265            }
266            Err(e) => {
267                error!("读取消息错误: {:?}", e);
268                Err(e)
269            }
270        }
271    }
272
273    pub fn addr(&self) -> &str {
275        &self.addr
276    }
277
278    pub async fn write_all(&self, buf: &[u8]) -> Result<()> {
279        let mut stream = self.stream.lock().await;
280        timeout(self.write_timeout, stream.write_all(buf)).await??;
281        Ok(())
282    }
283
284    pub async fn read_exact(&self, buf: &mut [u8]) -> Result<()> {
285        let mut stream = self.stream.lock().await;
286        timeout(self.read_timeout, stream.read_exact(buf)).await??;
287        Ok(())
288    }
289
290    pub async fn write_command(
291        &self,
292        name: &str,
293        body: Option<&[u8]>,
294        params: &[&str],
295    ) -> Result<()> {
296        let cmd = Protocol::encode_command(name, body, params);
297        self.write_all(&cmd).await
298    }
299
300    pub async fn close(&self) -> Result<()> {
301        self.stream
302            .lock()
303            .await
304            .shutdown()
305            .await
306            .map_err(Error::from)
307    }
308}
309
310impl Drop for Connection {
311    fn drop(&mut self) {
312        }
315}
316
317pub async fn close_connection(connection: &Arc<Connection>) -> Result<()> {
319    connection.send_command(Command::Cls).await
320}