1use async_trait::async_trait;
2use log::{error, info};
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::Mutex;
9use tokio::sync::mpsc;
10
11use crate::connection::Connection;
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, Message as ProtocolMessage, ProtocolError};
14
15#[derive(Debug, Error)]
16pub enum ConsumerError {
17 #[error("Invalid topic name: {0}")]
18 InvalidTopic(String),
19 #[error("Invalid channel name: {0}")]
20 InvalidChannel(String),
21 #[error("Connection error: {0}")]
22 ConnectionError(String),
23 #[error("Protocol error: {0}")]
24 ProtocolError(String),
25}
26
/// Owned, plain-data representation of a consumed message.
///
/// NOTE(review): the fields presumably mirror the NSQ message frame, and the consumer in
/// this file hands `crate::protocol::Message` (not this type) to handlers — verify usage.
#[derive(Clone, Debug)]
pub struct Message {
    /// Raw message ID bytes.
    pub id: Vec<u8>,
    /// Message payload, uninterpreted.
    pub body: Vec<u8>,
    /// Delivery attempt counter — NOTE(review): confirm whether it is 1-based.
    pub attempts: u16,
    /// Message timestamp — NOTE(review): unit is not established in this file.
    pub timestamp: u64,
}
34
35#[async_trait]
100pub trait Handler: Send + Sync + 'static {
101 async fn handle_message(&self, message: ProtocolMessage) -> Result<()>;
102}
103
/// Point-in-time snapshot of a consumer's counters, as returned by `Consumer::stats`.
///
/// All fields are plain integers, so the snapshot is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ConsumerStats {
    /// Messages delivered to the consumer by nsqd.
    pub messages_received: u64,
    /// Messages the consumer acknowledged with `FIN`, plus messages the handler responded
    /// to itself after returning `Ok`.
    pub messages_finished: u64,
    /// Messages the consumer requeued with `REQ`, plus messages the handler responded to
    /// itself after returning an error.
    pub messages_requeued: u64,
    /// Number of nsqd connections currently registered with the consumer.
    pub connections: i32,
}
110
/// Tunables for a [`Consumer`]; see [`ConsumerConfig::default`] for the stock values.
#[derive(Debug, Clone)]
pub struct ConsumerConfig {
    /// RDY count sent on each nsqd connection after subscribing and whenever the
    /// consumer tops the connection's budget back up.
    pub max_in_flight: i32,
    /// Maximum delivery attempts per message.
    /// NOTE(review): not read by the consumer code in this file — confirm it is enforced.
    pub max_attempts: u16,
    /// Intended connect timeout.
    /// NOTE(review): not passed to `Connection::new` in this file.
    pub dial_timeout: Duration,
    /// Read timeout handed to `Connection::new` for every nsqd connection.
    pub read_timeout: Duration,
    /// Write timeout handed to `Connection::new` for every nsqd connection.
    pub write_timeout: Duration,
    /// Period at which nsqlookupd is re-queried for nsqd nodes in the background.
    pub lookup_poll_interval: Duration,
    /// Presumably the fractional jitter applied to lookupd polling.
    /// NOTE(review): not read in this file.
    pub lookup_poll_jitter: f64,
    /// Upper bound for requeue delays.
    /// NOTE(review): not read in this file; automatic `REQ` is sent with delay 0.
    pub max_requeue_delay: Duration,
    /// Default requeue delay.
    /// NOTE(review): not read in this file; automatic `REQ` is sent with delay 0.
    pub default_requeue_delay: Duration,
    /// Overall deadline for closing connections in `Consumer::stop`.
    pub shutdown_timeout: Duration,
    /// When `true`, reconnect delays double up to 60 s with up to 30% random jitter;
    /// when `false`, every reconnect waits 1 s.
    pub backoff_strategy: bool,
    /// When `true`, the consumer never sends `FIN`/`REQ` itself; the handler must respond.
    pub disable_auto_response: bool,
}

impl Default for ConsumerConfig {
    /// Stock configuration: one message in flight, 5 attempts, 1 s dial/write timeouts,
    /// 60 s read timeout and lookupd poll interval, 30 s shutdown deadline.
    fn default() -> Self {
        const ONE_SECOND: Duration = Duration::from_secs(1);
        const ONE_MINUTE: Duration = Duration::from_secs(60);

        Self {
            max_in_flight: 1,
            max_attempts: 5,
            dial_timeout: ONE_SECOND,
            read_timeout: ONE_MINUTE,
            write_timeout: ONE_SECOND,
            lookup_poll_interval: ONE_MINUTE,
            lookup_poll_jitter: 0.3,
            max_requeue_delay: ONE_MINUTE * 15,
            default_requeue_delay: Duration::from_secs(90),
            shutdown_timeout: Duration::from_secs(30),
            backoff_strategy: true,
            disable_auto_response: false,
        }
    }
}
155
156pub struct Consumer {
157 topic: String,
158 channel: String,
159 config: ConsumerConfig,
160 handler: Arc<dyn Handler + Send + Sync + 'static>,
161
162 messages_received: AtomicU64,
164 messages_finished: AtomicU64,
165 messages_requeued: AtomicU64,
166
167 connections: Arc<Mutex<HashMap<String, Arc<Connection>>>>,
169 total_rdy_count: AtomicI32,
170 max_in_flight: AtomicI32,
171
172 is_running: AtomicBool,
174 stop_chan: mpsc::Sender<()>,
175}
176
177struct ConnectionHandler {
178 topic: String,
179 channel: String,
180 handler: Arc<dyn Handler + Send + Sync + 'static>,
181 messages_received: Arc<AtomicU64>,
182 messages_finished: Arc<AtomicU64>,
183 messages_requeued: Arc<AtomicU64>,
184 total_rdy_count: Arc<AtomicI32>,
185 max_in_flight: Arc<AtomicI32>,
186 disable_auto_response: bool,
187}
188
189impl ConnectionHandler {
190 fn new(consumer: &Consumer) -> Self {
191 Self {
192 topic: consumer.topic.clone(),
193 channel: consumer.channel.clone(),
194 handler: consumer.handler.clone(),
195 messages_received: Arc::new(AtomicU64::new(0)),
196 messages_finished: Arc::new(AtomicU64::new(0)),
197 messages_requeued: Arc::new(AtomicU64::new(0)),
198 total_rdy_count: Arc::new(AtomicI32::new(0)),
199 max_in_flight: Arc::new(AtomicI32::new(consumer.config.max_in_flight)),
200 disable_auto_response: consumer.config.disable_auto_response,
201 }
202 }
203
204 async fn handle_connection(&self, conn: Arc<Connection>) -> Result<()> {
205 let sub_cmd = Command::Subscribe(self.topic.clone(), self.channel.clone());
207 conn.send_command(sub_cmd).await?;
208
209 let rdy_count = self.max_in_flight.load(Ordering::Relaxed);
211 let rdy_cmd = Command::Ready(rdy_count as u32);
212 conn.send_command(rdy_cmd).await?;
213
214 let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(30));
216
217 loop {
218 tokio::select! {
219 _ = heartbeat_interval.tick() => {
221 if let Err(e) = conn.handle_heartbeat().await {
222 error!("心跳检测失败: {}", e);
223 return Err(e);
224 }
225 }
226 frame = conn.read_frame() =>
228 match frame {
229 Ok(Frame::Response(data)) => {
230 if data == b"_heartbeat_"
232 && let Err(e) = conn.send_command(Command::Nop).await {
233 error!("发送心跳响应失败: {}", e);
234 return Err(e);
235 }
236 }
237 Ok(Frame::Error(data)) => {
238 error!("NSQ错误: {:?}", String::from_utf8_lossy(&data));
239 if String::from_utf8_lossy(&data).contains("E_INVALID") {
241 return Err(Error::Protocol(ProtocolError::Other(
242 String::from_utf8_lossy(&data).to_string()
243 )));
244 }
245 }
246 Ok(Frame::Message(msg)) => {
247 self.messages_received.fetch_add(1, Ordering::Relaxed);
248
249 let msg_with_conn = msg.with_responder(Arc::clone(&conn));
251
252 match self.handler.handle_message(msg_with_conn.clone()).await {
254 Ok(_) => {
255 if !self.disable_auto_response && !msg_with_conn.is_auto_response_disabled() && !msg_with_conn.has_responded() {
257 let msg_id = msg_with_conn.id_string();
259 let fin_cmd = Command::Finish(msg_id);
260 if let Err(e) = conn.send_command(fin_cmd).await {
261 error!("发送 FIN 命令失败: {}", e);
262 return Err(e);
263 } else {
264 self.messages_finished.fetch_add(1, Ordering::Relaxed);
265 }
266 } else if msg_with_conn.has_responded() {
267 self.messages_finished.fetch_add(1, Ordering::Relaxed);
269 }
270 }
271 Err(e) => {
272 error!("消息处理失败: {}", e);
273
274 if !self.disable_auto_response && !msg_with_conn.is_auto_response_disabled() && !msg_with_conn.has_responded() {
276 let msg_id = msg_with_conn.id_string();
278 let req_cmd = Command::Requeue(msg_id, 0);
279 if let Err(e) = conn.send_command(req_cmd).await {
280 error!("发送 REQ 命令失败: {}", e);
281 return Err(e);
282 } else {
283 self.messages_requeued.fetch_add(1, Ordering::Relaxed);
284 }
285 } else if msg_with_conn.has_responded() {
286 self.messages_requeued.fetch_add(1, Ordering::Relaxed);
288 }
289 }
290 }
291
292 let current_rdy = self.total_rdy_count.fetch_sub(1, Ordering::Relaxed);
294 if current_rdy <= self.max_in_flight.load(Ordering::Relaxed) / 2 {
295 let new_rdy = self.max_in_flight.load(Ordering::Relaxed);
296 let rdy_cmd = Command::Ready(new_rdy as u32);
297 if let Err(e) = conn.send_command(rdy_cmd).await {
298 error!("发送 RDY 命令失败: {}", e);
299 return Err(e);
300 } else {
301 self.total_rdy_count.store(new_rdy, Ordering::Relaxed);
302 }
303 }
304 }
305 Err(e) => {
306 error!("读取帧失败: {}", e);
307 return Err(e);
308 }
309 }
310 }
311 }
312 }
313}
314
315impl Consumer {
316 pub fn new(
317 topic: String,
318 channel: String,
319 config: ConsumerConfig,
320 handler: impl Handler,
321 ) -> Result<Self> {
322 if !Self::is_valid_topic_name(&topic) {
323 return Err(Error::Other(format!("Invalid topic name: {}", topic)));
324 }
325 if !Self::is_valid_channel_name(&channel) {
326 return Err(Error::Other(format!("Invalid channel name: {}", channel)));
327 }
328
329 let (stop_tx, _) = mpsc::channel(1);
330
331 Ok(Consumer {
332 topic,
333 channel,
334 config: config.clone(),
335 handler: Arc::new(handler),
336 messages_received: AtomicU64::new(0),
337 messages_finished: AtomicU64::new(0),
338 messages_requeued: AtomicU64::new(0),
339 connections: Arc::new(Mutex::new(HashMap::new())),
340 total_rdy_count: AtomicI32::new(0),
341 max_in_flight: AtomicI32::new(config.max_in_flight),
342 is_running: AtomicBool::new(false),
343 stop_chan: stop_tx,
344 })
345 }
346
347 fn is_valid_topic_name(topic: &str) -> bool {
348 if topic.is_empty() || topic.len() > 64 {
349 return false;
350 }
351 topic
352 .chars()
353 .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.')
354 }
355
356 fn is_valid_channel_name(channel: &str) -> bool {
357 if channel.is_empty() || channel.len() > 64 {
358 return false;
359 }
360 channel.chars().all(|c| {
361 c.is_ascii_alphanumeric() || c == '_' || c == '-' || c == '.' || c == '#' || c == '*'
362 })
363 }
364
365 pub fn stats(&self) -> ConsumerStats {
366 ConsumerStats {
367 messages_received: self.messages_received.load(Ordering::Relaxed),
368 messages_finished: self.messages_finished.load(Ordering::Relaxed),
369 messages_requeued: self.messages_requeued.load(Ordering::Relaxed),
370 connections: self.connections.blocking_lock().len() as i32,
371 }
372 }
373
374 pub async fn connect_to_nsqd(&self, addr: String) -> Result<()> {
375 let mut conns = self.connections.lock().await;
376 if conns.contains_key(&addr) {
377 return Ok(());
378 }
379
380 let conn = Arc::new(
381 Connection::new(
382 &addr,
383 None,
384 None,
385 self.config.read_timeout,
386 self.config.write_timeout,
387 )
388 .await?,
389 );
390
391 let conn_clone = Arc::clone(&conn);
392 let handler = Arc::new(ConnectionHandler::new(self));
393 let addr_clone = addr.clone();
394 let config_clone = self.config.clone();
395
396 tokio::spawn(async move {
398 let mut retry_delay = 1;
400 let max_retry_delay = 60;
402 let mut retry_count = 0;
404
405 loop {
406 match handler.handle_connection(Arc::clone(&conn_clone)).await {
407 Ok(_) => {
408 info!("连接循环正常结束");
409 break;
410 }
411 Err(e) => {
412 retry_count += 1;
413 let is_connection_error = matches!(e,
414 Error::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::BrokenPipe
415 || io_err.kind() == std::io::ErrorKind::ConnectionReset
416 || io_err.kind() == std::io::ErrorKind::ConnectionAborted
417 || io_err.kind() == std::io::ErrorKind::UnexpectedEof
418 ) || e.to_string().contains("early eof");
419
420 if is_connection_error || matches!(e, Error::Timeout(_)) {
422 error!("连接错误 (尝试 #{}) 到 {}: {}", retry_count, addr_clone, e);
423
424 let sleep_duration = if config_clone.backoff_strategy {
426 let jitter = rand::random::<f32>() * 0.3;
427 let delay = (retry_delay as f32 * (1.0 + jitter)) as u64;
428 retry_delay = std::cmp::min(retry_delay * 2, max_retry_delay);
429 delay
430 } else {
431 retry_delay
432 };
433
434 info!("将在 {}秒 后尝试重新连接到 {}", sleep_duration, addr_clone);
435 tokio::time::sleep(Duration::from_secs(sleep_duration)).await;
436
437 match conn_clone.reconnect().await {
439 Ok(_) => {
440 info!("成功重新连接到 {}", addr_clone);
442 retry_delay = 1;
443 retry_count = 0;
444 }
445 Err(conn_err) => {
446 error!("重新连接失败: {}", conn_err);
447 continue;
448 }
449 }
450 } else {
451 error!("非连接错误,停止重试: {}", e);
453 break;
454 }
455 }
456 }
457 }
458 });
459
460 conns.insert(addr, conn);
461 Ok(())
462 }
463
464 pub async fn disconnect_from_nsqd(&self, addr: String) -> Result<()> {
465 let mut conns = self.connections.lock().await;
466 if let Some(conn) = conns.remove(&addr) {
467 conn.close().await?;
468 }
469 Ok(())
470 }
471
472 pub async fn start(&self) -> Result<()> {
473 self.is_running.store(true, Ordering::Relaxed);
474 Ok(())
475 }
476
477 pub async fn stop(&self) -> Result<()> {
478 info!("开始优雅关闭消费者...");
479 self.is_running.store(false, Ordering::Relaxed);
480
481 let _ = self.stop_chan.send(()).await;
483
484 let shutdown_deadline = tokio::time::sleep(self.config.shutdown_timeout);
486 tokio::pin!(shutdown_deadline);
487
488 let mut conns = self.connections.lock().await;
489 for (addr, conn) in conns.drain() {
490 info!("正在关闭到 {} 的连接", addr);
491
492 tokio::select! {
493 _ = &mut shutdown_deadline => {
494 error!("关闭连接超时");
495 break;
496 }
497 result = conn.close() => {
498 if let Err(e) = result {
499 error!("关闭到 {} 的连接时出错: {}", addr, e);
500 } else {
501 info!("成功关闭到 {} 的连接", addr);
502 }
503 }
504 }
505 }
506
507 info!("消费者已关闭");
508 Ok(())
509 }
510
511 pub async fn connect_to_nsqlookupd(&self, lookupd_url: String) -> Result<()> {
512 info!("正在从 nsqlookupd 获取 nsqd 节点列表...");
513 let nodes = crate::lookup::lookup_nodes(&lookupd_url, &self.topic).await?;
514
515 for node in nodes {
516 info!("发现 nsqd 节点: {}", node);
517 if let Err(e) = self.connect_to_nsqd(node.clone()).await {
518 error!("连接到 nsqd 节点 {} 失败: {}", node, e);
519 }
520 }
521
522 let consumer = self.clone();
524 let lookupd_url = lookupd_url.clone();
525 tokio::spawn(async move {
526 let mut interval = tokio::time::interval(consumer.config.lookup_poll_interval);
527 loop {
528 interval.tick().await;
529 match crate::lookup::lookup_nodes(&lookupd_url, &consumer.topic).await {
530 Ok(nodes) => {
531 for node in nodes {
532 if let Err(e) = consumer.connect_to_nsqd(node.clone()).await {
533 error!("连接到 nsqd 节点 {} 失败: {}", node, e);
534 }
535 }
536 }
537 Err(e) => {
538 error!("从 nsqlookupd 获取节点列表失败: {}", e);
539 }
540 }
541 }
542 });
543
544 Ok(())
545 }
546}
547
548impl Clone for Consumer {
549 fn clone(&self) -> Self {
550 Consumer {
551 topic: self.topic.clone(),
552 channel: self.channel.clone(),
553 config: self.config.clone(),
554 handler: self.handler.clone(),
555 messages_received: AtomicU64::new(self.messages_received.load(Ordering::Relaxed)),
556 messages_finished: AtomicU64::new(self.messages_finished.load(Ordering::Relaxed)),
557 messages_requeued: AtomicU64::new(self.messages_requeued.load(Ordering::Relaxed)),
558 connections: self.connections.clone(),
559 total_rdy_count: AtomicI32::new(self.total_rdy_count.load(Ordering::Relaxed)),
560 max_in_flight: AtomicI32::new(self.max_in_flight.load(Ordering::Relaxed)),
561 is_running: AtomicBool::new(self.is_running.load(Ordering::Relaxed)),
562 stop_chan: self.stop_chan.clone(),
563 }
564 }
565}