nsq_async_rs/
producer.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use backoff::ExponentialBackoffBuilder;
7use log::debug;
8use tokio::sync::RwLock;
9
10use crate::commands::lookup_nsqd_nodes;
11use crate::connection::Connection;
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, IdentifyConfig, ProtocolError};
14
/// Configuration for an NSQ producer.
///
/// Bundles the server addresses the producer may talk to together with the
/// connection, authentication and retry settings it uses.
#[derive(Debug, Clone)]
pub struct ProducerConfig {
    /// Addresses of nsqd servers to publish to directly.
    ///
    /// When non-empty, the producer in this file uses the first entry.
    pub nsqd_addresses: Vec<String>,
    /// Addresses of nsqlookupd services used to discover nsqd nodes.
    ///
    /// Consulted only when `nsqd_addresses` is empty; the first entry is queried.
    pub nsqlookupd_addresses: Vec<String>,
    /// Connection timeout.
    ///
    /// Passed to `Connection::new` for both of its timeout arguments.
    pub connection_timeout: Duration,
    /// Authentication secret, if the server requires one.
    pub auth_secret: Option<String>,
    /// Identify configuration forwarded to each new connection, if any.
    pub identify_config: Option<IdentifyConfig>,
    /// Backoff policy applied when a publish attempt is retried.
    pub backoff_config: BackoffConfig,
}
31
32impl Default for ProducerConfig {
33    fn default() -> Self {
34        Self {
35            nsqd_addresses: vec![],
36            nsqlookupd_addresses: vec![],
37            connection_timeout: Duration::from_secs(5),
38            auth_secret: None,
39            identify_config: None,
40            backoff_config: BackoffConfig::default(),
41        }
42    }
43}
44
/// Exponential-backoff settings for retrying publish attempts.
///
/// The fields are handed unchanged to the `backoff` crate's
/// `ExponentialBackoffBuilder` by the producer.
#[derive(Debug, Clone)]
pub struct BackoffConfig {
    /// Initial retry interval.
    pub initial_interval: Duration,
    /// Maximum retry interval.
    pub max_interval: Duration,
    /// Multiplier applied to the interval between retries.
    pub multiplier: f64,
    /// Maximum total time spent retrying, if bounded.
    pub max_elapsed_time: Option<Duration>,
}

impl Default for BackoffConfig {
    /// Defaults: 100 ms initial interval, 10 s maximum interval, multiplier
    /// of 2.0, and a 60 s overall retry budget.
    fn default() -> Self {
        let initial_interval = Duration::from_millis(100);
        let max_interval = Duration::from_secs(10);
        let max_elapsed_time = Some(Duration::from_secs(60));

        BackoffConfig {
            initial_interval,
            max_interval,
            multiplier: 2.0,
            max_elapsed_time,
        }
    }
}
68
/// Interface for publishing messages to NSQ.
///
/// Implementors must be `Send + Sync` so a producer can be shared between tasks.
#[async_trait]
pub trait Producer: Send + Sync {
    /// Publishes a single message to `topic`.
    async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()>;

    /// Publishes a message to `topic` with a delivery delay of `delay`.
    async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
        &self,
        topic: &str,
        message: T,
        delay: Duration,
    ) -> Result<()>;

    /// Publishes a batch of messages to `topic`.
    async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
        &self,
        topic: &str,
        messages: Vec<T>,
    ) -> Result<()>;

    /// Sends a ping command to check the state of a connection.
    ///
    /// # Arguments
    /// * `addr` - Address of the NSQ server to ping; when `None`, the first
    ///   configured address is used.
    /// * `timeout` - Timeout for the ping. The original documentation states a
    ///   default of 5 seconds when `None`.
    ///   NOTE(review): that default is not visible in this file — confirm
    ///   against `Connection::ping`.
    ///
    /// # Returns
    /// * `Ok(())` - the connection is healthy.
    /// * `Err(Error)` - the connection is broken or the ping timed out.
    async fn ping(&self, addr: Option<&str>, timeout: Option<Duration>) -> Result<()>;
}
101
/// NSQ producer implementation.
///
/// Keeps one connection per server address and reuses it across publishes.
pub struct NsqProducer {
    /// Producer configuration.
    config: ProducerConfig,
    /// Internal connection pool, keyed by server address.
    connections: RwLock<HashMap<String, Arc<Connection>>>,
}
109
110impl NsqProducer {
111    /// 创建新的NSQ生产者
112    pub fn new(config: ProducerConfig) -> Self {
113        Self {
114            config,
115            connections: RwLock::new(HashMap::new()),
116        }
117    }
118
119    /// 获取或创建到NSQ服务器的连接
120    async fn get_or_create_connection(&self, addr: &str) -> Result<Arc<Connection>> {
121        let mut connections = self.connections.write().await;
122
123        // 检查现有连接
124        if let Some(conn) = connections.get(addr) {
125            // 简单的 ping 检测
126            match conn.ping(None).await {
127                Ok(_) => {
128                    // 连接有效,直接复用
129                    return Ok(conn.clone());
130                }
131                Err(_) => {
132                    // 连接无效,移除并重新创建
133                    connections.remove(addr);
134                }
135            }
136        }
137
138        // 创建新连接
139        let conn = Connection::new(
140            addr.to_string(),
141            self.config.identify_config.clone(),
142            self.config.auth_secret.clone(),
143            self.config.connection_timeout,
144            self.config.connection_timeout,
145        )
146        .await?;
147
148        let conn = Arc::new(conn);
149        connections.insert(addr.to_string(), conn.clone());
150        Ok(conn)
151    }
152
153    /// 获取用于发布消息的连接
154    async fn get_publish_connection(&self, topic: &str) -> Result<Arc<Connection>> {
155        // 如果直接配置了nsqd地址,使用第一个
156        if !self.config.nsqd_addresses.is_empty() {
157            return self
158                .get_or_create_connection(&self.config.nsqd_addresses[0])
159                .await;
160        }
161
162        // 如果配置了nsqlookupd,使用它查找nsqd
163        if !self.config.nsqlookupd_addresses.is_empty() {
164            let addr = &self.config.nsqlookupd_addresses[0];
165            let nodes = lookup_nsqd_nodes(addr, topic).await?;
166
167            if nodes.is_empty() {
168                return Err(Error::Connection(format!(
169                    "nsqlookupd未找到主题 {} 的生产者",
170                    topic
171                )));
172            }
173
174            return self.get_or_create_connection(&nodes[0]).await;
175        }
176
177        Err(Error::Config("未配置nsqd或nsqlookupd地址".to_string()))
178    }
179}
180
181#[async_trait]
182impl Producer for NsqProducer {
183    async fn ping(&self, addr: Option<&str>, timeout: Option<Duration>) -> Result<()> {
184        let target_addr = match addr {
185            Some(a) => a.to_string(),
186            None => {
187                if !self.config.nsqd_addresses.is_empty() {
188                    self.config.nsqd_addresses[0].clone()
189                } else if !self.config.nsqlookupd_addresses.is_empty() {
190                    // 如果没有直接的 nsqd 地址,尝试从 nsqlookupd 获取
191                    // 使用一个特殊的主题名仅用于 ping 目的
192                    let nsqd_nodes =
193                        lookup_nsqd_nodes(&self.config.nsqlookupd_addresses[0], "_ping_topic")
194                            .await?;
195                    if nsqd_nodes.is_empty() {
196                        return Err(Error::Connection("没有可用的 NSQ 服务器地址".to_string()));
197                    }
198                    nsqd_nodes[0].clone()
199                } else {
200                    return Err(Error::Config("没有配置 NSQ 服务器地址".to_string()));
201                }
202            }
203        };
204
205        // 获取连接并发送 ping
206        let connection = self.get_or_create_connection(&target_addr).await?;
207        connection.ping(timeout).await
208    }
209
210    async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()> {
211        let backoff = ExponentialBackoffBuilder::new()
212            .with_initial_interval(self.config.backoff_config.initial_interval)
213            .with_max_interval(self.config.backoff_config.max_interval)
214            .with_multiplier(self.config.backoff_config.multiplier)
215            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
216            .build();
217
218        let topic_owned = topic.to_string();
219        let message_bytes = message.as_ref().to_vec();
220
221        let result = backoff::future::retry(backoff, || async {
222            let connection = match self.get_publish_connection(&topic_owned).await {
223                Ok(conn) => conn,
224                Err(e) => return Err(backoff::Error::permanent(e)),
225            };
226
227            let cmd = Command::Publish(topic_owned.clone(), message_bytes.clone());
228            match connection.send_command(cmd).await {
229                Ok(_) => {
230                    // 读取响应
231                    match connection.read_frame().await {
232                        Ok(Frame::Response(_)) => Ok(()),
233                        Ok(Frame::Error(data)) => {
234                            let error_msg = String::from_utf8_lossy(&data);
235                            Err(backoff::Error::transient(Error::Protocol(
236                                ProtocolError::Other(error_msg.to_string()),
237                            )))
238                        }
239                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
240                            ProtocolError::Other("收到意外响应".to_string()),
241                        ))),
242                        Err(e) => Err(backoff::Error::transient(e)),
243                    }
244                }
245                Err(e) => Err(backoff::Error::transient(e)),
246            }
247        })
248        .await;
249
250        match result {
251            Ok(_) => Ok(()),
252            Err(e) => Err(e),
253        }
254    }
255
256    async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
257        &self,
258        topic: &str,
259        message: T,
260        delay: Duration,
261    ) -> Result<()> {
262        // 初始化退避策略
263        let backoff = ExponentialBackoffBuilder::new()
264            .with_initial_interval(self.config.backoff_config.initial_interval)
265            .with_max_interval(self.config.backoff_config.max_interval)
266            .with_multiplier(self.config.backoff_config.multiplier)
267            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
268            .build();
269
270        let topic_owned = topic.to_string();
271        let message_bytes = message.as_ref().to_vec();
272
273        let result = backoff::future::retry(backoff, || async {
274            let connection = match self.get_publish_connection(&topic_owned).await {
275                Ok(conn) => conn,
276                Err(e) => return Err(backoff::Error::permanent(e)),
277            };
278
279            let cmd = Command::DelayedPublish(
280                topic_owned.clone(),
281                message_bytes.clone(),
282                delay.as_millis() as u32,
283            );
284            match connection.send_command(cmd).await {
285                Ok(_) => {
286                    // 读取响应
287                    match connection.read_frame().await {
288                        Ok(Frame::Response(_)) => Ok(()),
289                        Ok(Frame::Error(data)) => {
290                            let error_msg = String::from_utf8_lossy(&data);
291                            Err(backoff::Error::transient(Error::Protocol(
292                                ProtocolError::Other(error_msg.to_string()),
293                            )))
294                        }
295                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
296                            ProtocolError::Other("收到意外响应".to_string()),
297                        ))),
298                        Err(e) => Err(backoff::Error::transient(e)),
299                    }
300                }
301                Err(e) => Err(backoff::Error::transient(e)),
302            }
303        })
304        .await;
305
306        match result {
307            Ok(_) => Ok(()),
308            Err(e) => Err(e),
309        }
310    }
311
312    async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
313        &self,
314        topic: &str,
315        messages: Vec<T>,
316    ) -> Result<()> {
317        if messages.is_empty() {
318            debug!("忽略空消息列表");
319            return Ok(());
320        }
321
322        // 将消息转换为字节向量
323        let byte_messages: Vec<Vec<u8>> =
324            messages.iter().map(|msg| msg.as_ref().to_vec()).collect();
325
326        // 使用与批量发送相同的逻辑处理
327        let backoff = ExponentialBackoffBuilder::new()
328            .with_initial_interval(self.config.backoff_config.initial_interval)
329            .with_max_interval(self.config.backoff_config.max_interval)
330            .with_multiplier(self.config.backoff_config.multiplier)
331            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
332            .build();
333
334        let topic_owned = topic.to_string();
335
336        let result = backoff::future::retry(backoff, || async {
337            let connection = match self.get_publish_connection(&topic_owned).await {
338                Ok(conn) => conn,
339                Err(e) => return Err(backoff::Error::permanent(e)),
340            };
341
342            let cmd = Command::Mpublish(topic_owned.clone(), byte_messages.clone());
343            match connection.send_command(cmd).await {
344                Ok(_) => {
345                    // 读取响应
346                    match connection.read_frame().await {
347                        Ok(Frame::Response(_)) => Ok(()),
348                        Ok(Frame::Error(data)) => {
349                            let error_msg = String::from_utf8_lossy(&data);
350                            Err(backoff::Error::transient(Error::Protocol(
351                                ProtocolError::Other(error_msg.to_string()),
352                            )))
353                        }
354                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
355                            ProtocolError::Other("收到意外响应".to_string()),
356                        ))),
357                        Err(e) => Err(backoff::Error::transient(e)),
358                    }
359                }
360                Err(e) => Err(backoff::Error::transient(e)),
361            }
362        })
363        .await;
364
365        match result {
366            Ok(_) => Ok(()),
367            Err(e) => Err(e),
368        }
369    }
370}
371
372impl NsqProducer {
373    /// 获取生产者配置
374    pub fn config(&self) -> &ProducerConfig {
375        &self.config
376    }
377
378    /// 获取连接池大小
379    pub async fn get_connection_pool_size(&self) -> usize {
380        self.connections.read().await.len()
381    }
382}
383
/// Creates a new NSQ producer from `config`.
///
/// Convenience wrapper around `NsqProducer::new`.
pub fn new_producer(config: ProducerConfig) -> NsqProducer {
    NsqProducer::new(config)
}
387}