// nsq_async_rs/consumer.rs

1use async_trait::async_trait;
2use log::{error, info};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::mpsc;
9use tokio::sync::Mutex;
10
11use crate::connection::Connection;
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, Message as ProtocolMessage, ProtocolError};
14
15#[derive(Debug, Error)]
16pub enum ConsumerError {
17    #[error("Invalid topic name: {0}")]
18    InvalidTopic(String),
19    #[error("Invalid channel name: {0}")]
20    InvalidChannel(String),
21    #[error("Connection error: {0}")]
22    ConnectionError(String),
23    #[error("Protocol error: {0}")]
24    ProtocolError(String),
25}
26
/// A consumer-side message record: identifier, payload, attempt count and timestamp.
///
/// NOTE(review): the `Handler` callback receives `protocol::Message`, not this
/// type, and nothing else in this module uses it.
#[derive(Clone, Debug)]
pub struct Message {
    /// Raw message identifier bytes.
    pub id: Vec<u8>,
    /// Raw message payload.
    pub body: Vec<u8>,
    /// Delivery attempt count — presumably as reported by nsqd; confirm.
    pub attempts: u16,
    /// Message timestamp — units not established here (NOTE(review): confirm).
    pub timestamp: u64,
}
34
35#[async_trait]
36pub trait Handler: Send + Sync + 'static {
37    async fn handle_message(&self, message: ProtocolMessage) -> Result<()>;
38}
39
/// Point-in-time snapshot of consumer statistics, as returned by `Consumer::stats`.
///
/// All fields are plain integers, so the snapshot is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConsumerStats {
    /// Messages delivered to the consumer.
    pub messages_received: u64,
    /// Messages acknowledged with FIN after the handler succeeded.
    pub messages_finished: u64,
    /// Messages sent back with REQ after the handler failed.
    pub messages_requeued: u64,
    /// Number of nsqd connections currently registered with the consumer.
    pub connections: i32,
}
46
/// Tuning knobs for a `Consumer`.
///
/// NOTE(review): `max_attempts`, `dial_timeout`, `lookup_poll_jitter`,
/// `max_requeue_delay` and `default_requeue_delay` are not read anywhere in this
/// module yet; setting them currently has no effect here.
#[derive(Debug, Clone)]
pub struct ConsumerConfig {
    /// Maximum messages in flight; sent as the RDY count on each connection.
    pub max_in_flight: i32,
    /// Maximum delivery attempts (see the note above).
    pub max_attempts: u16,
    /// Timeout for establishing a connection (see the note above).
    pub dial_timeout: Duration,
    /// Read timeout handed to each nsqd connection.
    pub read_timeout: Duration,
    /// Write timeout handed to each nsqd connection.
    pub write_timeout: Duration,
    /// Period between nsqlookupd polls for new nsqd nodes.
    pub lookup_poll_interval: Duration,
    /// Fractional jitter for lookup polling (see the note above).
    pub lookup_poll_jitter: f64,
    /// Upper bound on requeue delay (see the note above).
    pub max_requeue_delay: Duration,
    /// Default requeue delay (see the note above).
    pub default_requeue_delay: Duration,
    /// Overall deadline for closing connections during shutdown.
    pub shutdown_timeout: Duration,
    /// Whether reconnection attempts back off exponentially (with jitter);
    /// when `false`, a fixed delay is used between attempts.
    pub backoff_strategy: bool,
}

impl Default for ConsumerConfig {
    /// One message in flight, 1 s dial/write timeouts, 60 s read timeout and
    /// lookup interval, 30 s shutdown deadline, exponential backoff enabled.
    fn default() -> Self {
        let secs = Duration::from_secs;
        Self {
            max_in_flight: 1,
            max_attempts: 5,
            dial_timeout: secs(1),
            read_timeout: secs(60),
            write_timeout: secs(1),
            lookup_poll_interval: secs(60),
            lookup_poll_jitter: 0.3,
            max_requeue_delay: secs(15 * 60),
            default_requeue_delay: secs(90),
            shutdown_timeout: secs(30),
            backoff_strategy: true,
        }
    }
}
80
81pub struct Consumer {
82    topic: String,
83    channel: String,
84    config: ConsumerConfig,
85    handler: Arc<dyn Handler + Send + Sync + 'static>,
86
87    // Stats
88    messages_received: AtomicU64,
89    messages_finished: AtomicU64,
90    messages_requeued: AtomicU64,
91
92    // Connection management
93    connections: Arc<Mutex<HashMap<String, Arc<Connection>>>>,
94    total_rdy_count: AtomicI32,
95    max_in_flight: AtomicI32,
96
97    // Control
98    is_running: AtomicBool,
99    stop_chan: mpsc::Sender<()>,
100}
101
102struct ConnectionHandler {
103    topic: String,
104    channel: String,
105    handler: Arc<dyn Handler + Send + Sync + 'static>,
106    messages_received: Arc<AtomicU64>,
107    messages_finished: Arc<AtomicU64>,
108    messages_requeued: Arc<AtomicU64>,
109    total_rdy_count: Arc<AtomicI32>,
110    max_in_flight: Arc<AtomicI32>,
111}
112
113impl ConnectionHandler {
114    fn new(consumer: &Consumer) -> Self {
115        Self {
116            topic: consumer.topic.clone(),
117            channel: consumer.channel.clone(),
118            handler: consumer.handler.clone(),
119            messages_received: Arc::new(AtomicU64::new(0)),
120            messages_finished: Arc::new(AtomicU64::new(0)),
121            messages_requeued: Arc::new(AtomicU64::new(0)),
122            total_rdy_count: Arc::new(AtomicI32::new(0)),
123            max_in_flight: Arc::new(AtomicI32::new(consumer.config.max_in_flight)),
124        }
125    }
126
127    async fn handle_connection(&self, conn: Arc<Connection>) -> Result<()> {
128        // 发送订阅命令
129        let sub_cmd = Command::Subscribe(self.topic.clone(), self.channel.clone());
130        conn.send_command(sub_cmd).await?;
131
132        // 发送就绪命令
133        let rdy_count = self.max_in_flight.load(Ordering::Relaxed);
134        let rdy_cmd = Command::Ready(rdy_count as u32);
135        conn.send_command(rdy_cmd).await?;
136
137        // 创建心跳间隔
138        let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(30));
139
140        loop {
141            tokio::select! {
142                // 主动心跳检测
143                _ = heartbeat_interval.tick() => {
144                    if let Err(e) = conn.handle_heartbeat().await {
145                        error!("心跳检测失败: {}", e);
146                        return Err(e);
147                    }
148                }
149                // 接收并处理消息
150                frame = conn.read_frame() =>
151                    match frame {
152                        Ok(Frame::Response(data)) => {
153                            // 检查是否是心跳消息
154                            if data == b"_heartbeat_" {
155                                if let Err(e) = conn.send_command(Command::Nop).await {
156                                    error!("发送心跳响应失败: {}", e);
157                                    return Err(e);
158                                }
159                            }
160                        }
161                        Ok(Frame::Error(data)) => {
162                            error!("NSQ错误: {:?}", String::from_utf8_lossy(&data));
163                            // 如果是致命错误,需要重新连接
164                            if String::from_utf8_lossy(&data).contains("E_INVALID") {
165                                return Err(Error::Protocol(ProtocolError::Other(
166                                    String::from_utf8_lossy(&data).to_string()
167                                )));
168                            }
169                        }
170                        Ok(Frame::Message(msg)) => {
171                            self.messages_received.fetch_add(1, Ordering::Relaxed);
172
173                            // 处理消息
174                            match self.handler.handle_message(msg.clone()).await {
175                                Ok(_) => {
176                                    let msg_id = String::from_utf8_lossy(&msg.id).to_string();
177                                    let fin_cmd = Command::Finish(msg_id);
178                                    if let Err(e) = conn.send_command(fin_cmd).await {
179                                        error!("发送 FIN 命令失败: {}", e);
180                                        return Err(e);
181                                    } else {
182                                        self.messages_finished.fetch_add(1, Ordering::Relaxed);
183                                    }
184                                }
185                                Err(e) => {
186                                    error!("消息处理失败: {}", e);
187                                    let msg_id = String::from_utf8_lossy(&msg.id).to_string();
188                                    let req_cmd = Command::Requeue(msg_id, 0);
189                                    if let Err(e) = conn.send_command(req_cmd).await {
190                                        error!("发送 REQ 命令失败: {}", e);
191                                        return Err(e);
192                                    } else {
193                                        self.messages_requeued.fetch_add(1, Ordering::Relaxed);
194                                    }
195                                }
196                            }
197
198                            // 更新 RDY 计数
199                            let current_rdy = self.total_rdy_count.fetch_sub(1, Ordering::Relaxed);
200                            if current_rdy <= self.max_in_flight.load(Ordering::Relaxed) / 2 {
201                                let new_rdy = self.max_in_flight.load(Ordering::Relaxed);
202                                let rdy_cmd = Command::Ready(new_rdy as u32);
203                                if let Err(e) = conn.send_command(rdy_cmd).await {
204                                    error!("发送 RDY 命令失败: {}", e);
205                                    return Err(e);
206                                } else {
207                                    self.total_rdy_count.store(new_rdy, Ordering::Relaxed);
208                                }
209                            }
210                        }
211                        Err(e) => {
212                            error!("读取帧失败: {}", e);
213                            return Err(e);
214                        }
215                    }
216            }
217        }
218    }
219}
220
221impl Consumer {
222    pub fn new(
223        topic: String,
224        channel: String,
225        config: ConsumerConfig,
226        handler: impl Handler,
227    ) -> Result<Self> {
228        if !Self::is_valid_topic_name(&topic) {
229            return Err(Error::Other(format!("Invalid topic name: {}", topic)));
230        }
231        if !Self::is_valid_channel_name(&channel) {
232            return Err(Error::Other(format!("Invalid channel name: {}", channel)));
233        }
234
235        let (stop_tx, _) = mpsc::channel(1);
236
237        Ok(Consumer {
238            topic,
239            channel,
240            config: config.clone(),
241            handler: Arc::new(handler),
242            messages_received: AtomicU64::new(0),
243            messages_finished: AtomicU64::new(0),
244            messages_requeued: AtomicU64::new(0),
245            connections: Arc::new(Mutex::new(HashMap::new())),
246            total_rdy_count: AtomicI32::new(0),
247            max_in_flight: AtomicI32::new(config.max_in_flight),
248            is_running: AtomicBool::new(false),
249            stop_chan: stop_tx,
250        })
251    }
252
253    fn is_valid_topic_name(topic: &str) -> bool {
254        if topic.is_empty() || topic.len() > 64 {
255            return false;
256        }
257        topic
258            .chars()
259            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.')
260    }
261
262    fn is_valid_channel_name(channel: &str) -> bool {
263        if channel.is_empty() || channel.len() > 64 {
264            return false;
265        }
266        channel.chars().all(|c| {
267            c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' || c == '#' || c == '*'
268        })
269    }
270
271    pub fn stats(&self) -> ConsumerStats {
272        ConsumerStats {
273            messages_received: self.messages_received.load(Ordering::Relaxed),
274            messages_finished: self.messages_finished.load(Ordering::Relaxed),
275            messages_requeued: self.messages_requeued.load(Ordering::Relaxed),
276            connections: self.connections.blocking_lock().len() as i32,
277        }
278    }
279
280    pub async fn connect_to_nsqd(&self, addr: String) -> Result<()> {
281        let mut conns = self.connections.lock().await;
282        if conns.contains_key(&addr) {
283            return Ok(());
284        }
285
286        let conn = Arc::new(
287            Connection::new(
288                &addr,
289                None,
290                None,
291                self.config.read_timeout,
292                self.config.write_timeout,
293            )
294            .await?,
295        );
296
297        let conn_clone = Arc::clone(&conn);
298        let handler = Arc::new(ConnectionHandler::new(self));
299        let addr_clone = addr.clone();
300        let config_clone = self.config.clone();
301
302        // 启动消息处理循环
303        tokio::spawn(async move {
304            // 初始重试延迟(秒)
305            let mut retry_delay = 1;
306            // 最大重试延迟(秒)
307            let max_retry_delay = 60;
308            // 重试计数
309            let mut retry_count = 0;
310
311            loop {
312                match handler.handle_connection(Arc::clone(&conn_clone)).await {
313                    Ok(_) => {
314                        info!("连接循环正常结束");
315                        break;
316                    }
317                    Err(e) => {
318                        retry_count += 1;
319                        let is_connection_error = matches!(e,
320                            Error::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::BrokenPipe
321                            || io_err.kind() == std::io::ErrorKind::ConnectionReset
322                            || io_err.kind() == std::io::ErrorKind::ConnectionAborted
323                            || io_err.kind() == std::io::ErrorKind::UnexpectedEof
324                        ) || e.to_string().contains("early eof");
325
326                        // 根据错误类型决定是否需要重连
327                        if is_connection_error || matches!(e, Error::Timeout(_)) {
328                            error!("连接错误 (尝试 #{}) 到 {}: {}", retry_count, addr_clone, e);
329
330                            // 指数退避策略
331                            let sleep_duration = if config_clone.backoff_strategy {
332                                let jitter = rand::random::<f32>() * 0.3;
333                                let delay = (retry_delay as f32 * (1.0 + jitter)) as u64;
334                                retry_delay = std::cmp::min(retry_delay * 2, max_retry_delay);
335                                delay
336                            } else {
337                                retry_delay
338                            };
339
340                            info!("将在 {}秒 后尝试重新连接到 {}", sleep_duration, addr_clone);
341                            tokio::time::sleep(Duration::from_secs(sleep_duration)).await;
342
343                            // 尝试重新建立连接
344                            match conn_clone.reconnect().await {
345                                Ok(_) => {
346                                    // 重置重试计数和延迟
347                                    info!("成功重新连接到 {}", addr_clone);
348                                    retry_delay = 1;
349                                    retry_count = 0;
350                                }
351                                Err(conn_err) => {
352                                    error!("重新连接失败: {}", conn_err);
353                                    continue;
354                                }
355                            }
356                        } else {
357                            // 对于其他类型的错误,记录并中断
358                            error!("非连接错误,停止重试: {}", e);
359                            break;
360                        }
361                    }
362                }
363            }
364        });
365
366        conns.insert(addr, conn);
367        Ok(())
368    }
369
370    pub async fn disconnect_from_nsqd(&self, addr: String) -> Result<()> {
371        let mut conns = self.connections.lock().await;
372        if let Some(conn) = conns.remove(&addr) {
373            conn.close().await?;
374        }
375        Ok(())
376    }
377
378    pub async fn start(&self) -> Result<()> {
379        self.is_running.store(true, Ordering::Relaxed);
380        Ok(())
381    }
382
383    pub async fn stop(&self) -> Result<()> {
384        info!("开始优雅关闭消费者...");
385        self.is_running.store(false, Ordering::Relaxed);
386
387        // 发送停止信号
388        let _ = self.stop_chan.send(()).await;
389
390        // 等待所有连接关闭或超时
391        let shutdown_deadline = tokio::time::sleep(self.config.shutdown_timeout);
392        tokio::pin!(shutdown_deadline);
393
394        let mut conns = self.connections.lock().await;
395        for (addr, conn) in conns.drain() {
396            info!("正在关闭到 {} 的连接", addr);
397
398            tokio::select! {
399                _ = &mut shutdown_deadline => {
400                    error!("关闭连接超时");
401                    break;
402                }
403                result = conn.close() => {
404                    if let Err(e) = result {
405                        error!("关闭到 {} 的连接时出错: {}", addr, e);
406                    } else {
407                        info!("成功关闭到 {} 的连接", addr);
408                    }
409                }
410            }
411        }
412
413        info!("消费者已关闭");
414        Ok(())
415    }
416
417    pub async fn connect_to_nsqlookupd(&self, lookupd_url: String) -> Result<()> {
418        info!("正在从 nsqlookupd 获取 nsqd 节点列表...");
419        let nodes = crate::lookup::lookup_nodes(&lookupd_url, &self.topic).await?;
420
421        for node in nodes {
422            info!("发现 nsqd 节点: {}", node);
423            if let Err(e) = self.connect_to_nsqd(node.clone()).await {
424                error!("连接到 nsqd 节点 {} 失败: {}", node, e);
425            }
426        }
427
428        // 启动定期更新节点的任务
429        let consumer = self.clone();
430        let lookupd_url = lookupd_url.clone();
431        tokio::spawn(async move {
432            let mut interval = tokio::time::interval(consumer.config.lookup_poll_interval);
433            loop {
434                interval.tick().await;
435                match crate::lookup::lookup_nodes(&lookupd_url, &consumer.topic).await {
436                    Ok(nodes) => {
437                        for node in nodes {
438                            if let Err(e) = consumer.connect_to_nsqd(node.clone()).await {
439                                error!("连接到 nsqd 节点 {} 失败: {}", node, e);
440                            }
441                        }
442                    }
443                    Err(e) => {
444                        error!("从 nsqlookupd 获取节点列表失败: {}", e);
445                    }
446                }
447            }
448        });
449
450        Ok(())
451    }
452}
453
454impl Clone for Consumer {
455    fn clone(&self) -> Self {
456        Consumer {
457            topic: self.topic.clone(),
458            channel: self.channel.clone(),
459            config: self.config.clone(),
460            handler: self.handler.clone(),
461            messages_received: AtomicU64::new(self.messages_received.load(Ordering::Relaxed)),
462            messages_finished: AtomicU64::new(self.messages_finished.load(Ordering::Relaxed)),
463            messages_requeued: AtomicU64::new(self.messages_requeued.load(Ordering::Relaxed)),
464            connections: self.connections.clone(),
465            total_rdy_count: AtomicI32::new(self.total_rdy_count.load(Ordering::Relaxed)),
466            max_in_flight: AtomicI32::new(self.max_in_flight.load(Ordering::Relaxed)),
467            is_running: AtomicBool::new(self.is_running.load(Ordering::Relaxed)),
468            stop_chan: self.stop_chan.clone(),
469        }
470    }
471}