nsq_async_rs/
consumer.rs

1use async_trait::async_trait;
2use log::{error, info};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::Mutex;
9use tokio::sync::mpsc;
10
11use crate::connection::Connection;
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, Message as ProtocolMessage, ProtocolError};
14
15#[derive(Debug, Error)]
16pub enum ConsumerError {
17    #[error("Invalid topic name: {0}")]
18    InvalidTopic(String),
19    #[error("Invalid channel name: {0}")]
20    InvalidChannel(String),
21    #[error("Connection error: {0}")]
22    ConnectionError(String),
23    #[error("Protocol error: {0}")]
24    ProtocolError(String),
25}
26
/// A plain NSQ message record: id, payload, attempt counter and timestamp.
///
/// Note: [`Handler::handle_message`] receives `crate::protocol::Message`
/// (imported here as `ProtocolMessage`), not this type.
#[derive(Clone, Debug)]
pub struct Message {
    /// Raw message id bytes.
    pub id: Vec<u8>,
    /// Raw message payload.
    pub body: Vec<u8>,
    /// Number of delivery attempts.
    pub attempts: u16,
    /// Message timestamp. NOTE(review): unit not established in this file — confirm.
    pub timestamp: u64,
}
34
35/// 消息处理器 trait
36///
37/// 实现此 trait 来处理从 NSQ 接收的消息。
38///
39/// # 自动响应模式(默认)
40///
41/// 当 `ConsumerConfig::disable_auto_response` 为 `false` 时(默认):
42/// - 如果 `handle_message` 返回 `Ok(())`,消息会自动发送 FIN 命令
43/// - 如果 `handle_message` 返回 `Err(_)`,消息会自动发送 REQ 命令
44///
45/// # 手动响应模式
46///
47/// 当 `ConsumerConfig::disable_auto_response` 为 `true` 时:
48/// - 消息不会自动发送 FIN/REQ
49/// - 需要在 handler 中手动调用:
50///   - `message.finish()` 完成消息处理
51///   - `message.requeue(delay)` 重新入队消息
52///   - `message.touch()` 重置消息超时
53///
54/// 或者,在自动响应模式下,也可以在单个消息上调用 `message.disable_auto_response()` 来禁用该消息的自动响应。
55///
56/// # 示例
57///
58/// ## 自动响应模式
59///
60/// ```rust,ignore
61/// struct MyHandler;
62///
63/// #[async_trait]
64/// impl Handler for MyHandler {
65///     async fn handle_message(&self, message: Message) -> Result<()> {
66///         // 处理消息
67///         println!("收到消息: {:?}", String::from_utf8_lossy(&message.body));
68///         
69///         // 返回 Ok 会自动发送 FIN
70///         Ok(())
71///     }
72/// }
73/// ```
74///
75/// ## 手动响应模式
76///
77/// ```rust,ignore
78/// struct MyHandler;
79///
80/// #[async_trait]
81/// impl Handler for MyHandler {
82///     async fn handle_message(&self, message: Message) -> Result<()> {
83///         // 处理消息
84///         match process(&message.body) {
85///             Ok(_) => {
86///                 // 手动发送 FIN
87///                 message.finish().await?;
88///             }
89///             Err(_) => {
90///                 // 手动重新入队,延迟 5 秒
91///                 message.requeue(5000).await?;
92///             }
93///         }
94///         
95///         Ok(())
96///     }
97/// }
98/// ```
99#[async_trait]
100pub trait Handler: Send + Sync + 'static {
101    async fn handle_message(&self, message: ProtocolMessage) -> Result<()>;
102}
103
/// Snapshot of a consumer's counters, as returned by `Consumer::stats`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConsumerStats {
    /// Messages received from nsqd.
    pub messages_received: u64,
    /// Messages counted as finished.
    pub messages_finished: u64,
    /// Messages counted as requeued.
    pub messages_requeued: u64,
    /// Number of nsqd connections currently registered.
    pub connections: i32,
}
110
/// Tuning knobs for a [`Consumer`].
#[derive(Debug, Clone)]
pub struct ConsumerConfig {
    /// Maximum number of messages in flight on a connection; sent to nsqd as the RDY count.
    pub max_in_flight: i32,
    /// Maximum number of delivery attempts for a message.
    /// NOTE(review): not currently read by `Consumer`.
    pub max_attempts: u16,
    /// Timeout for establishing a connection to nsqd.
    /// NOTE(review): not currently passed to `Connection::new`.
    pub dial_timeout: Duration,
    /// Read timeout handed to each nsqd `Connection`.
    pub read_timeout: Duration,
    /// Write timeout handed to each nsqd `Connection`.
    pub write_timeout: Duration,
    /// Interval between nsqlookupd polls for nsqd nodes.
    pub lookup_poll_interval: Duration,
    /// Jitter fraction intended for the nsqlookupd poll interval.
    /// NOTE(review): not currently read by `Consumer`.
    pub lookup_poll_jitter: f64,
    /// Intended upper bound for requeue delays.
    pub max_requeue_delay: Duration,
    /// Intended requeue delay when none is specified.
    pub default_requeue_delay: Duration,
    /// How long `Consumer::stop` waits for connections to close.
    pub shutdown_timeout: Duration,
    /// Whether reconnects use exponential, jittered back-off instead of a fixed delay.
    pub backoff_strategy: bool,
    /// Disables automatic responses.
    ///
    /// When `true`, messages are not FIN'd/REQ'd according to the handler's return
    /// value; the handler must call `message.finish()` or `message.requeue()` itself.
    ///
    /// Useful when:
    /// - messages are processed concurrently and acknowledged asynchronously,
    /// - messages are processed in batches,
    /// - the moment of acknowledgement must be controlled precisely.
    pub disable_auto_response: bool,
}
136
137impl Default for ConsumerConfig {
138    fn default() -> Self {
139        ConsumerConfig {
140            max_in_flight: 1,
141            max_attempts: 5,
142            dial_timeout: Duration::from_secs(1),
143            read_timeout: Duration::from_secs(60),
144            write_timeout: Duration::from_secs(1),
145            lookup_poll_interval: Duration::from_secs(60),
146            lookup_poll_jitter: 0.3,
147            max_requeue_delay: Duration::from_secs(15 * 60),
148            default_requeue_delay: Duration::from_secs(90),
149            shutdown_timeout: Duration::from_secs(30),
150            backoff_strategy: true,
151            disable_auto_response: false,
152        }
153    }
154}
155
156pub struct Consumer {
157    topic: String,
158    channel: String,
159    config: ConsumerConfig,
160    handler: Arc<dyn Handler + Send + Sync + 'static>,
161
162    // Stats
163    messages_received: AtomicU64,
164    messages_finished: AtomicU64,
165    messages_requeued: AtomicU64,
166
167    // Connection management
168    connections: Arc<Mutex<HashMap<String, Arc<Connection>>>>,
169    total_rdy_count: AtomicI32,
170    max_in_flight: AtomicI32,
171
172    // Control
173    is_running: AtomicBool,
174    stop_chan: mpsc::Sender<()>,
175}
176
177struct ConnectionHandler {
178    topic: String,
179    channel: String,
180    handler: Arc<dyn Handler + Send + Sync + 'static>,
181    messages_received: Arc<AtomicU64>,
182    messages_finished: Arc<AtomicU64>,
183    messages_requeued: Arc<AtomicU64>,
184    total_rdy_count: Arc<AtomicI32>,
185    max_in_flight: Arc<AtomicI32>,
186    disable_auto_response: bool,
187}
188
189impl ConnectionHandler {
190    fn new(consumer: &Consumer) -> Self {
191        Self {
192            topic: consumer.topic.clone(),
193            channel: consumer.channel.clone(),
194            handler: consumer.handler.clone(),
195            messages_received: Arc::new(AtomicU64::new(0)),
196            messages_finished: Arc::new(AtomicU64::new(0)),
197            messages_requeued: Arc::new(AtomicU64::new(0)),
198            total_rdy_count: Arc::new(AtomicI32::new(0)),
199            max_in_flight: Arc::new(AtomicI32::new(consumer.config.max_in_flight)),
200            disable_auto_response: consumer.config.disable_auto_response,
201        }
202    }
203
204    async fn handle_connection(&self, conn: Arc<Connection>) -> Result<()> {
205        // 发送订阅命令
206        let sub_cmd = Command::Subscribe(self.topic.clone(), self.channel.clone());
207        conn.send_command(sub_cmd).await?;
208
209        // 发送就绪命令
210        let rdy_count = self.max_in_flight.load(Ordering::Relaxed);
211        let rdy_cmd = Command::Ready(rdy_count as u32);
212        conn.send_command(rdy_cmd).await?;
213
214        // 创建心跳间隔
215        let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(30));
216
217        loop {
218            tokio::select! {
219                // 主动心跳检测
220                _ = heartbeat_interval.tick() => {
221                    if let Err(e) = conn.handle_heartbeat().await {
222                        error!("心跳检测失败: {}", e);
223                        return Err(e);
224                    }
225                }
226                // 接收并处理消息
227                frame = conn.read_frame() =>
228                    match frame {
229                        Ok(Frame::Response(data)) => {
230                            // 检查是否是心跳消息
231                            if data == b"_heartbeat_"
232                                && let Err(e) = conn.send_command(Command::Nop).await {
233                                    error!("发送心跳响应失败: {}", e);
234                                    return Err(e);
235                                }
236                        }
237                        Ok(Frame::Error(data)) => {
238                            error!("NSQ错误: {:?}", String::from_utf8_lossy(&data));
239                            // 如果是致命错误,需要重新连接
240                            if String::from_utf8_lossy(&data).contains("E_INVALID") {
241                                return Err(Error::Protocol(ProtocolError::Other(
242                                    String::from_utf8_lossy(&data).to_string()
243                                )));
244                            }
245                        }
246                        Ok(Frame::Message(msg)) => {
247                            self.messages_received.fetch_add(1, Ordering::Relaxed);
248
249                            // 为消息附加连接引用(用于手动确认)
250                            let msg_with_conn = msg.with_responder(Arc::clone(&conn));
251
252                            // 处理消息
253                            match self.handler.handle_message(msg_with_conn.clone()).await {
254                                Ok(_) => {
255                                    // 检查是否需要自动响应
256                                    if !self.disable_auto_response && !msg_with_conn.is_auto_response_disabled() && !msg_with_conn.has_responded() {
257                                        // 自动发送 FIN
258                                        let msg_id = msg_with_conn.id_string();
259                                        let fin_cmd = Command::Finish(msg_id);
260                                        if let Err(e) = conn.send_command(fin_cmd).await {
261                                            error!("发送 FIN 命令失败: {}", e);
262                                            return Err(e);
263                                        } else {
264                                            self.messages_finished.fetch_add(1, Ordering::Relaxed);
265                                        }
266                                    } else if msg_with_conn.has_responded() {
267                                        // 消息已经手动响应过了,更新统计
268                                        self.messages_finished.fetch_add(1, Ordering::Relaxed);
269                                    }
270                                }
271                                Err(e) => {
272                                    error!("消息处理失败: {}", e);
273
274                                    // 检查是否需要自动响应
275                                    if !self.disable_auto_response && !msg_with_conn.is_auto_response_disabled() && !msg_with_conn.has_responded() {
276                                        // 自动发送 REQ
277                                        let msg_id = msg_with_conn.id_string();
278                                        let req_cmd = Command::Requeue(msg_id, 0);
279                                        if let Err(e) = conn.send_command(req_cmd).await {
280                                            error!("发送 REQ 命令失败: {}", e);
281                                            return Err(e);
282                                        } else {
283                                            self.messages_requeued.fetch_add(1, Ordering::Relaxed);
284                                        }
285                                    } else if msg_with_conn.has_responded() {
286                                        // 消息已经手动响应过了,更新统计
287                                        self.messages_requeued.fetch_add(1, Ordering::Relaxed);
288                                    }
289                                }
290                            }
291
292                            // 更新 RDY 计数
293                            let current_rdy = self.total_rdy_count.fetch_sub(1, Ordering::Relaxed);
294                            if current_rdy <= self.max_in_flight.load(Ordering::Relaxed) / 2 {
295                                let new_rdy = self.max_in_flight.load(Ordering::Relaxed);
296                                let rdy_cmd = Command::Ready(new_rdy as u32);
297                                if let Err(e) = conn.send_command(rdy_cmd).await {
298                                    error!("发送 RDY 命令失败: {}", e);
299                                    return Err(e);
300                                } else {
301                                    self.total_rdy_count.store(new_rdy, Ordering::Relaxed);
302                                }
303                            }
304                        }
305                        Err(e) => {
306                            error!("读取帧失败: {}", e);
307                            return Err(e);
308                        }
309                    }
310            }
311        }
312    }
313}
314
315impl Consumer {
316    pub fn new(
317        topic: String,
318        channel: String,
319        config: ConsumerConfig,
320        handler: impl Handler,
321    ) -> Result<Self> {
322        if !Self::is_valid_topic_name(&topic) {
323            return Err(Error::Other(format!("Invalid topic name: {}", topic)));
324        }
325        if !Self::is_valid_channel_name(&channel) {
326            return Err(Error::Other(format!("Invalid channel name: {}", channel)));
327        }
328
329        let (stop_tx, _) = mpsc::channel(1);
330
331        Ok(Consumer {
332            topic,
333            channel,
334            config: config.clone(),
335            handler: Arc::new(handler),
336            messages_received: AtomicU64::new(0),
337            messages_finished: AtomicU64::new(0),
338            messages_requeued: AtomicU64::new(0),
339            connections: Arc::new(Mutex::new(HashMap::new())),
340            total_rdy_count: AtomicI32::new(0),
341            max_in_flight: AtomicI32::new(config.max_in_flight),
342            is_running: AtomicBool::new(false),
343            stop_chan: stop_tx,
344        })
345    }
346
347    fn is_valid_topic_name(topic: &str) -> bool {
348        if topic.is_empty() || topic.len() > 64 {
349            return false;
350        }
351        topic
352            .chars()
353            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.')
354    }
355
356    fn is_valid_channel_name(channel: &str) -> bool {
357        if channel.is_empty() || channel.len() > 64 {
358            return false;
359        }
360        channel.chars().all(|c| {
361            c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' || c == '#' || c == '*'
362        })
363    }
364
365    pub fn stats(&self) -> ConsumerStats {
366        ConsumerStats {
367            messages_received: self.messages_received.load(Ordering::Relaxed),
368            messages_finished: self.messages_finished.load(Ordering::Relaxed),
369            messages_requeued: self.messages_requeued.load(Ordering::Relaxed),
370            connections: self.connections.blocking_lock().len() as i32,
371        }
372    }
373
374    pub async fn connect_to_nsqd(&self, addr: String) -> Result<()> {
375        let mut conns = self.connections.lock().await;
376        if conns.contains_key(&addr) {
377            return Ok(());
378        }
379
380        let conn = Arc::new(
381            Connection::new(
382                &addr,
383                None,
384                None,
385                self.config.read_timeout,
386                self.config.write_timeout,
387            )
388            .await?,
389        );
390
391        let conn_clone = Arc::clone(&conn);
392        let handler = Arc::new(ConnectionHandler::new(self));
393        let addr_clone = addr.clone();
394        let config_clone = self.config.clone();
395
396        // 启动消息处理循环
397        tokio::spawn(async move {
398            // 初始重试延迟(秒)
399            let mut retry_delay = 1;
400            // 最大重试延迟(秒)
401            let max_retry_delay = 60;
402            // 重试计数
403            let mut retry_count = 0;
404
405            loop {
406                match handler.handle_connection(Arc::clone(&conn_clone)).await {
407                    Ok(_) => {
408                        info!("连接循环正常结束");
409                        break;
410                    }
411                    Err(e) => {
412                        retry_count += 1;
413                        let is_connection_error = matches!(e,
414                            Error::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::BrokenPipe
415                            || io_err.kind() == std::io::ErrorKind::ConnectionReset
416                            || io_err.kind() == std::io::ErrorKind::ConnectionAborted
417                            || io_err.kind() == std::io::ErrorKind::UnexpectedEof
418                        ) || e.to_string().contains("early eof");
419
420                        // 根据错误类型决定是否需要重连
421                        if is_connection_error || matches!(e, Error::Timeout(_)) {
422                            error!("连接错误 (尝试 #{}) 到 {}: {}", retry_count, addr_clone, e);
423
424                            // 指数退避策略
425                            let sleep_duration = if config_clone.backoff_strategy {
426                                let jitter = rand::random::<f32>() * 0.3;
427                                let delay = (retry_delay as f32 * (1.0 + jitter)) as u64;
428                                retry_delay = std::cmp::min(retry_delay * 2, max_retry_delay);
429                                delay
430                            } else {
431                                retry_delay
432                            };
433
434                            info!("将在 {}秒 后尝试重新连接到 {}", sleep_duration, addr_clone);
435                            tokio::time::sleep(Duration::from_secs(sleep_duration)).await;
436
437                            // 尝试重新建立连接
438                            match conn_clone.reconnect().await {
439                                Ok(_) => {
440                                    // 重置重试计数和延迟
441                                    info!("成功重新连接到 {}", addr_clone);
442                                    retry_delay = 1;
443                                    retry_count = 0;
444                                }
445                                Err(conn_err) => {
446                                    error!("重新连接失败: {}", conn_err);
447                                    continue;
448                                }
449                            }
450                        } else {
451                            // 对于其他类型的错误,记录并中断
452                            error!("非连接错误,停止重试: {}", e);
453                            break;
454                        }
455                    }
456                }
457            }
458        });
459
460        conns.insert(addr, conn);
461        Ok(())
462    }
463
464    pub async fn disconnect_from_nsqd(&self, addr: String) -> Result<()> {
465        let mut conns = self.connections.lock().await;
466        if let Some(conn) = conns.remove(&addr) {
467            conn.close().await?;
468        }
469        Ok(())
470    }
471
472    pub async fn start(&self) -> Result<()> {
473        self.is_running.store(true, Ordering::Relaxed);
474        Ok(())
475    }
476
477    pub async fn stop(&self) -> Result<()> {
478        info!("开始优雅关闭消费者...");
479        self.is_running.store(false, Ordering::Relaxed);
480
481        // 发送停止信号
482        let _ = self.stop_chan.send(()).await;
483
484        // 等待所有连接关闭或超时
485        let shutdown_deadline = tokio::time::sleep(self.config.shutdown_timeout);
486        tokio::pin!(shutdown_deadline);
487
488        let mut conns = self.connections.lock().await;
489        for (addr, conn) in conns.drain() {
490            info!("正在关闭到 {} 的连接", addr);
491
492            tokio::select! {
493                _ = &mut shutdown_deadline => {
494                    error!("关闭连接超时");
495                    break;
496                }
497                result = conn.close() => {
498                    if let Err(e) = result {
499                        error!("关闭到 {} 的连接时出错: {}", addr, e);
500                    } else {
501                        info!("成功关闭到 {} 的连接", addr);
502                    }
503                }
504            }
505        }
506
507        info!("消费者已关闭");
508        Ok(())
509    }
510
511    pub async fn connect_to_nsqlookupd(&self, lookupd_url: String) -> Result<()> {
512        info!("正在从 nsqlookupd 获取 nsqd 节点列表...");
513        let nodes = crate::lookup::lookup_nodes(&lookupd_url, &self.topic).await?;
514
515        for node in nodes {
516            info!("发现 nsqd 节点: {}", node);
517            if let Err(e) = self.connect_to_nsqd(node.clone()).await {
518                error!("连接到 nsqd 节点 {} 失败: {}", node, e);
519            }
520        }
521
522        // 启动定期更新节点的任务
523        let consumer = self.clone();
524        let lookupd_url = lookupd_url.clone();
525        tokio::spawn(async move {
526            let mut interval = tokio::time::interval(consumer.config.lookup_poll_interval);
527            loop {
528                interval.tick().await;
529                match crate::lookup::lookup_nodes(&lookupd_url, &consumer.topic).await {
530                    Ok(nodes) => {
531                        for node in nodes {
532                            if let Err(e) = consumer.connect_to_nsqd(node.clone()).await {
533                                error!("连接到 nsqd 节点 {} 失败: {}", node, e);
534                            }
535                        }
536                    }
537                    Err(e) => {
538                        error!("从 nsqlookupd 获取节点列表失败: {}", e);
539                    }
540                }
541            }
542        });
543
544        Ok(())
545    }
546}
547
548impl Clone for Consumer {
549    fn clone(&self) -> Self {
550        Consumer {
551            topic: self.topic.clone(),
552            channel: self.channel.clone(),
553            config: self.config.clone(),
554            handler: self.handler.clone(),
555            messages_received: AtomicU64::new(self.messages_received.load(Ordering::Relaxed)),
556            messages_finished: AtomicU64::new(self.messages_finished.load(Ordering::Relaxed)),
557            messages_requeued: AtomicU64::new(self.messages_requeued.load(Ordering::Relaxed)),
558            connections: self.connections.clone(),
559            total_rdy_count: AtomicI32::new(self.total_rdy_count.load(Ordering::Relaxed)),
560            max_in_flight: AtomicI32::new(self.max_in_flight.load(Ordering::Relaxed)),
561            is_running: AtomicBool::new(self.is_running.load(Ordering::Relaxed)),
562            stop_chan: self.stop_chan.clone(),
563        }
564    }
565}