1use async_trait::async_trait;
2use log::{error, info};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicBool, AtomicI32, AtomicU64, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7use thiserror::Error;
8use tokio::sync::mpsc;
9use tokio::sync::Mutex;
10
11use crate::connection::Connection;
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, Message as ProtocolMessage, ProtocolError};
14
15#[derive(Debug, Error)]
16pub enum ConsumerError {
17 #[error("Invalid topic name: {0}")]
18 InvalidTopic(String),
19 #[error("Invalid channel name: {0}")]
20 InvalidChannel(String),
21 #[error("Connection error: {0}")]
22 ConnectionError(String),
23 #[error("Protocol error: {0}")]
24 ProtocolError(String),
25}
26
/// An owned message record: identifier, payload and delivery metadata.
// NOTE(review): the connection loop in this file hands `ProtocolMessage` (from
// `crate::protocol`) to handlers, not this type — confirm where `Message` is used.
#[derive(Debug, Clone)]
pub struct Message {
    /// Raw message identifier bytes.
    pub id: Vec<u8>,
    /// Raw message payload bytes.
    pub body: Vec<u8>,
    /// Delivery attempt counter (presumably as reported by nsqd — confirm).
    pub attempts: u16,
    /// Message timestamp (unit not established by this file — TODO confirm).
    pub timestamp: u64,
}
34
35#[async_trait]
36pub trait Handler: Send + Sync + 'static {
37 async fn handle_message(&self, message: ProtocolMessage) -> Result<()>;
38}
39
/// Point-in-time snapshot of a consumer's counters, as returned by `Consumer::stats`.
///
/// The type is a plain value: it derives `Debug`, `Clone`, `PartialEq`, `Eq` and
/// `Default` so snapshots can be logged, stored and compared (for example in tests
/// or when computing deltas between two polls).
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ConsumerStats {
    /// Number of messages received.
    pub messages_received: u64,
    /// Number of messages acknowledged with a finish command.
    pub messages_finished: u64,
    /// Number of messages handed back with a requeue command.
    pub messages_requeued: u64,
    /// Number of entries in the consumer's connection table.
    pub connections: i32,
}
46
/// Tunables for a `Consumer`.
///
/// In the visible code only `max_in_flight`, `read_timeout`, `write_timeout`,
/// `lookup_poll_interval`, `shutdown_timeout` and `backoff_strategy` are read;
/// the remaining fields are carried but not consulted here.
#[derive(Debug, Clone)]
pub struct ConsumerConfig {
    /// Value advertised to nsqd in the `Ready` command of each connection.
    pub max_in_flight: i32,
    /// Maximum delivery attempts (not read in this file — presumably enforced elsewhere).
    pub max_attempts: u16,
    /// Dial timeout (not read in this file — TODO confirm where it is applied).
    pub dial_timeout: Duration,
    /// Read timeout handed to `Connection::new`.
    pub read_timeout: Duration,
    /// Write timeout handed to `Connection::new`.
    pub write_timeout: Duration,
    /// Period of the background nsqlookupd polling task.
    pub lookup_poll_interval: Duration,
    /// Jitter factor for lookup polling (not read in this file).
    pub lookup_poll_jitter: f64,
    /// Upper bound for requeue delays (not read in this file).
    pub max_requeue_delay: Duration,
    /// Default requeue delay (not read in this file; requeues here use `0`).
    pub default_requeue_delay: Duration,
    /// Overall time budget `Consumer::stop` allows for closing connections.
    pub shutdown_timeout: Duration,
    /// When `true`, reconnect delays double (capped) with random jitter;
    /// when `false`, a fixed delay is used.
    pub backoff_strategy: bool,
}
62
63impl Default for ConsumerConfig {
64 fn default() -> Self {
65 ConsumerConfig {
66 max_in_flight: 1,
67 max_attempts: 5,
68 dial_timeout: Duration::from_secs(1),
69 read_timeout: Duration::from_secs(60),
70 write_timeout: Duration::from_secs(1),
71 lookup_poll_interval: Duration::from_secs(60),
72 lookup_poll_jitter: 0.3,
73 max_requeue_delay: Duration::from_secs(15 * 60),
74 default_requeue_delay: Duration::from_secs(90),
75 shutdown_timeout: Duration::from_secs(30),
76 backoff_strategy: true,
77 }
78 }
79}
80
81pub struct Consumer {
82 topic: String,
83 channel: String,
84 config: ConsumerConfig,
85 handler: Arc<dyn Handler + Send + Sync + 'static>,
86
87 messages_received: AtomicU64,
89 messages_finished: AtomicU64,
90 messages_requeued: AtomicU64,
91
92 connections: Arc<Mutex<HashMap<String, Arc<Connection>>>>,
94 total_rdy_count: AtomicI32,
95 max_in_flight: AtomicI32,
96
97 is_running: AtomicBool,
99 stop_chan: mpsc::Sender<()>,
100}
101
102struct ConnectionHandler {
103 topic: String,
104 channel: String,
105 handler: Arc<dyn Handler + Send + Sync + 'static>,
106 messages_received: Arc<AtomicU64>,
107 messages_finished: Arc<AtomicU64>,
108 messages_requeued: Arc<AtomicU64>,
109 total_rdy_count: Arc<AtomicI32>,
110 max_in_flight: Arc<AtomicI32>,
111}
112
113impl ConnectionHandler {
114 fn new(consumer: &Consumer) -> Self {
115 Self {
116 topic: consumer.topic.clone(),
117 channel: consumer.channel.clone(),
118 handler: consumer.handler.clone(),
119 messages_received: Arc::new(AtomicU64::new(0)),
120 messages_finished: Arc::new(AtomicU64::new(0)),
121 messages_requeued: Arc::new(AtomicU64::new(0)),
122 total_rdy_count: Arc::new(AtomicI32::new(0)),
123 max_in_flight: Arc::new(AtomicI32::new(consumer.config.max_in_flight)),
124 }
125 }
126
127 async fn handle_connection(&self, conn: Arc<Connection>) -> Result<()> {
128 let sub_cmd = Command::Subscribe(self.topic.clone(), self.channel.clone());
130 conn.send_command(sub_cmd).await?;
131
132 let rdy_count = self.max_in_flight.load(Ordering::Relaxed);
134 let rdy_cmd = Command::Ready(rdy_count as u32);
135 conn.send_command(rdy_cmd).await?;
136
137 let mut heartbeat_interval = tokio::time::interval(Duration::from_secs(30));
139
140 loop {
141 tokio::select! {
142 _ = heartbeat_interval.tick() => {
144 if let Err(e) = conn.handle_heartbeat().await {
145 error!("心跳检测失败: {}", e);
146 return Err(e);
147 }
148 }
149 frame = conn.read_frame() =>
151 match frame {
152 Ok(Frame::Response(data)) => {
153 if data == b"_heartbeat_" {
155 if let Err(e) = conn.send_command(Command::Nop).await {
156 error!("发送心跳响应失败: {}", e);
157 return Err(e);
158 }
159 }
160 }
161 Ok(Frame::Error(data)) => {
162 error!("NSQ错误: {:?}", String::from_utf8_lossy(&data));
163 if String::from_utf8_lossy(&data).contains("E_INVALID") {
165 return Err(Error::Protocol(ProtocolError::Other(
166 String::from_utf8_lossy(&data).to_string()
167 )));
168 }
169 }
170 Ok(Frame::Message(msg)) => {
171 self.messages_received.fetch_add(1, Ordering::Relaxed);
172
173 match self.handler.handle_message(msg.clone()).await {
175 Ok(_) => {
176 let msg_id = String::from_utf8_lossy(&msg.id).to_string();
177 let fin_cmd = Command::Finish(msg_id);
178 if let Err(e) = conn.send_command(fin_cmd).await {
179 error!("发送 FIN 命令失败: {}", e);
180 return Err(e);
181 } else {
182 self.messages_finished.fetch_add(1, Ordering::Relaxed);
183 }
184 }
185 Err(e) => {
186 error!("消息处理失败: {}", e);
187 let msg_id = String::from_utf8_lossy(&msg.id).to_string();
188 let req_cmd = Command::Requeue(msg_id, 0);
189 if let Err(e) = conn.send_command(req_cmd).await {
190 error!("发送 REQ 命令失败: {}", e);
191 return Err(e);
192 } else {
193 self.messages_requeued.fetch_add(1, Ordering::Relaxed);
194 }
195 }
196 }
197
198 let current_rdy = self.total_rdy_count.fetch_sub(1, Ordering::Relaxed);
200 if current_rdy <= self.max_in_flight.load(Ordering::Relaxed) / 2 {
201 let new_rdy = self.max_in_flight.load(Ordering::Relaxed);
202 let rdy_cmd = Command::Ready(new_rdy as u32);
203 if let Err(e) = conn.send_command(rdy_cmd).await {
204 error!("发送 RDY 命令失败: {}", e);
205 return Err(e);
206 } else {
207 self.total_rdy_count.store(new_rdy, Ordering::Relaxed);
208 }
209 }
210 }
211 Err(e) => {
212 error!("读取帧失败: {}", e);
213 return Err(e);
214 }
215 }
216 }
217 }
218 }
219}
220
221impl Consumer {
222 pub fn new(
223 topic: String,
224 channel: String,
225 config: ConsumerConfig,
226 handler: impl Handler,
227 ) -> Result<Self> {
228 if !Self::is_valid_topic_name(&topic) {
229 return Err(Error::Other(format!("Invalid topic name: {}", topic)));
230 }
231 if !Self::is_valid_channel_name(&channel) {
232 return Err(Error::Other(format!("Invalid channel name: {}", channel)));
233 }
234
235 let (stop_tx, _) = mpsc::channel(1);
236
237 Ok(Consumer {
238 topic,
239 channel,
240 config: config.clone(),
241 handler: Arc::new(handler),
242 messages_received: AtomicU64::new(0),
243 messages_finished: AtomicU64::new(0),
244 messages_requeued: AtomicU64::new(0),
245 connections: Arc::new(Mutex::new(HashMap::new())),
246 total_rdy_count: AtomicI32::new(0),
247 max_in_flight: AtomicI32::new(config.max_in_flight),
248 is_running: AtomicBool::new(false),
249 stop_chan: stop_tx,
250 })
251 }
252
/// Returns `true` when `topic` is a legal NSQ topic name.
///
/// A legal name is 1–64 bytes long (suffix included), consists of ASCII letters,
/// digits, `_`, `-` or `.`, and may end with the literal suffix `#ephemeral`.
/// The part before the suffix must not be empty.
fn is_valid_topic_name(topic: &str) -> bool {
    if topic.is_empty() || topic.len() > 64 {
        return false;
    }

    // `#` is only legal as part of the trailing `#ephemeral` marker.
    let base = topic.strip_suffix("#ephemeral").unwrap_or(topic);

    !base.is_empty()
        && base
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.'))
}
261
/// Returns `true` when `channel` is a legal NSQ channel name.
///
/// A legal name is 1–64 bytes long (suffix included), consists of ASCII letters,
/// digits, `_`, `-` or `.`, and may end with the literal suffix `#ephemeral`.
/// `#` anywhere else, and `*`, are rejected: nsqd refuses such names, so
/// accepting them here would only postpone the failure to subscribe time.
fn is_valid_channel_name(channel: &str) -> bool {
    if channel.is_empty() || channel.len() > 64 {
        return false;
    }

    // `#` is only legal as part of the trailing `#ephemeral` marker.
    let base = channel.strip_suffix("#ephemeral").unwrap_or(channel);

    !base.is_empty()
        && base
            .bytes()
            .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'_' | b'-' | b'.'))
}
270
271 pub fn stats(&self) -> ConsumerStats {
272 ConsumerStats {
273 messages_received: self.messages_received.load(Ordering::Relaxed),
274 messages_finished: self.messages_finished.load(Ordering::Relaxed),
275 messages_requeued: self.messages_requeued.load(Ordering::Relaxed),
276 connections: self.connections.blocking_lock().len() as i32,
277 }
278 }
279
280 pub async fn connect_to_nsqd(&self, addr: String) -> Result<()> {
281 let mut conns = self.connections.lock().await;
282 if conns.contains_key(&addr) {
283 return Ok(());
284 }
285
286 let conn = Arc::new(
287 Connection::new(
288 &addr,
289 None,
290 None,
291 self.config.read_timeout,
292 self.config.write_timeout,
293 )
294 .await?,
295 );
296
297 let conn_clone = Arc::clone(&conn);
298 let handler = Arc::new(ConnectionHandler::new(self));
299 let addr_clone = addr.clone();
300 let config_clone = self.config.clone();
301
302 tokio::spawn(async move {
304 let mut retry_delay = 1;
306 let max_retry_delay = 60;
308 let mut retry_count = 0;
310
311 loop {
312 match handler.handle_connection(Arc::clone(&conn_clone)).await {
313 Ok(_) => {
314 info!("连接循环正常结束");
315 break;
316 }
317 Err(e) => {
318 retry_count += 1;
319 let is_connection_error = matches!(e,
320 Error::Io(ref io_err) if io_err.kind() == std::io::ErrorKind::BrokenPipe
321 || io_err.kind() == std::io::ErrorKind::ConnectionReset
322 || io_err.kind() == std::io::ErrorKind::ConnectionAborted
323 || io_err.kind() == std::io::ErrorKind::UnexpectedEof
324 ) || e.to_string().contains("early eof");
325
326 if is_connection_error || matches!(e, Error::Timeout(_)) {
328 error!("连接错误 (尝试 #{}) 到 {}: {}", retry_count, addr_clone, e);
329
330 let sleep_duration = if config_clone.backoff_strategy {
332 let jitter = rand::random::<f32>() * 0.3;
333 let delay = (retry_delay as f32 * (1.0 + jitter)) as u64;
334 retry_delay = std::cmp::min(retry_delay * 2, max_retry_delay);
335 delay
336 } else {
337 retry_delay
338 };
339
340 info!("将在 {}秒 后尝试重新连接到 {}", sleep_duration, addr_clone);
341 tokio::time::sleep(Duration::from_secs(sleep_duration)).await;
342
343 match conn_clone.reconnect().await {
345 Ok(_) => {
346 info!("成功重新连接到 {}", addr_clone);
348 retry_delay = 1;
349 retry_count = 0;
350 }
351 Err(conn_err) => {
352 error!("重新连接失败: {}", conn_err);
353 continue;
354 }
355 }
356 } else {
357 error!("非连接错误,停止重试: {}", e);
359 break;
360 }
361 }
362 }
363 }
364 });
365
366 conns.insert(addr, conn);
367 Ok(())
368 }
369
370 pub async fn disconnect_from_nsqd(&self, addr: String) -> Result<()> {
371 let mut conns = self.connections.lock().await;
372 if let Some(conn) = conns.remove(&addr) {
373 conn.close().await?;
374 }
375 Ok(())
376 }
377
378 pub async fn start(&self) -> Result<()> {
379 self.is_running.store(true, Ordering::Relaxed);
380 Ok(())
381 }
382
383 pub async fn stop(&self) -> Result<()> {
384 info!("开始优雅关闭消费者...");
385 self.is_running.store(false, Ordering::Relaxed);
386
387 let _ = self.stop_chan.send(()).await;
389
390 let shutdown_deadline = tokio::time::sleep(self.config.shutdown_timeout);
392 tokio::pin!(shutdown_deadline);
393
394 let mut conns = self.connections.lock().await;
395 for (addr, conn) in conns.drain() {
396 info!("正在关闭到 {} 的连接", addr);
397
398 tokio::select! {
399 _ = &mut shutdown_deadline => {
400 error!("关闭连接超时");
401 break;
402 }
403 result = conn.close() => {
404 if let Err(e) = result {
405 error!("关闭到 {} 的连接时出错: {}", addr, e);
406 } else {
407 info!("成功关闭到 {} 的连接", addr);
408 }
409 }
410 }
411 }
412
413 info!("消费者已关闭");
414 Ok(())
415 }
416
417 pub async fn connect_to_nsqlookupd(&self, lookupd_url: String) -> Result<()> {
418 info!("正在从 nsqlookupd 获取 nsqd 节点列表...");
419 let nodes = crate::lookup::lookup_nodes(&lookupd_url, &self.topic).await?;
420
421 for node in nodes {
422 info!("发现 nsqd 节点: {}", node);
423 if let Err(e) = self.connect_to_nsqd(node.clone()).await {
424 error!("连接到 nsqd 节点 {} 失败: {}", node, e);
425 }
426 }
427
428 let consumer = self.clone();
430 let lookupd_url = lookupd_url.clone();
431 tokio::spawn(async move {
432 let mut interval = tokio::time::interval(consumer.config.lookup_poll_interval);
433 loop {
434 interval.tick().await;
435 match crate::lookup::lookup_nodes(&lookupd_url, &consumer.topic).await {
436 Ok(nodes) => {
437 for node in nodes {
438 if let Err(e) = consumer.connect_to_nsqd(node.clone()).await {
439 error!("连接到 nsqd 节点 {} 失败: {}", node, e);
440 }
441 }
442 }
443 Err(e) => {
444 error!("从 nsqlookupd 获取节点列表失败: {}", e);
445 }
446 }
447 }
448 });
449
450 Ok(())
451 }
452}
453
454impl Clone for Consumer {
455 fn clone(&self) -> Self {
456 Consumer {
457 topic: self.topic.clone(),
458 channel: self.channel.clone(),
459 config: self.config.clone(),
460 handler: self.handler.clone(),
461 messages_received: AtomicU64::new(self.messages_received.load(Ordering::Relaxed)),
462 messages_finished: AtomicU64::new(self.messages_finished.load(Ordering::Relaxed)),
463 messages_requeued: AtomicU64::new(self.messages_requeued.load(Ordering::Relaxed)),
464 connections: self.connections.clone(),
465 total_rdy_count: AtomicI32::new(self.total_rdy_count.load(Ordering::Relaxed)),
466 max_in_flight: AtomicI32::new(self.max_in_flight.load(Ordering::Relaxed)),
467 is_running: AtomicBool::new(self.is_running.load(Ordering::Relaxed)),
468 stop_chan: self.stop_chan.clone(),
469 }
470 }
471}