nsq_async_rs/
producer.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use backoff::ExponentialBackoffBuilder;
7use log::debug;
8use tokio::sync::RwLock;
9
10use crate::commands::{create_nsqd_connection, lookup_nsqd_nodes};
11use crate::connection::Connection;
12use crate::connection_pool::ConnectionPool;
13use crate::error::{Error, Result};
14use crate::protocol::{Command, Frame, IdentifyConfig, ProtocolError};
15
16/// 生产者配置
17#[derive(Debug, Clone)]
18pub struct ProducerConfig {
19    /// NSQ服务器地址
20    pub nsqd_addresses: Vec<String>,
21    /// NSQ查询服务地址
22    pub nsqlookupd_addresses: Vec<String>,
23    /// 连接超时时间
24    pub connection_timeout: Duration,
25    /// 认证密钥
26    pub auth_secret: Option<String>,
27    /// 身份配置
28    pub identify_config: Option<IdentifyConfig>,
29    /// 重连策略
30    pub backoff_config: BackoffConfig,
31}
32
33impl Default for ProducerConfig {
34    fn default() -> Self {
35        Self {
36            nsqd_addresses: vec![],
37            nsqlookupd_addresses: vec![],
38            connection_timeout: Duration::from_secs(5),
39            auth_secret: None,
40            identify_config: None,
41            backoff_config: BackoffConfig::default(),
42        }
43    }
44}
45
/// Exponential backoff parameters used when retrying failed publishes.
#[derive(Debug, Clone)]
pub struct BackoffConfig {
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Upper bound on the delay between two retries.
    pub max_interval: Duration,
    /// Factor by which the delay grows after each failed attempt.
    pub multiplier: f64,
    /// Total time budget for retrying. This is passed unchanged to the backoff
    /// builder; the `backoff` crate treats `None` as "no limit".
    pub max_elapsed_time: Option<Duration>,
}
58
59impl Default for BackoffConfig {
60    fn default() -> Self {
61        Self {
62            initial_interval: Duration::from_millis(100),
63            max_interval: Duration::from_secs(10),
64            multiplier: 2.0,
65            max_elapsed_time: Some(Duration::from_secs(60)),
66        }
67    }
68}
69
70/// NSQ生产者特性
71#[async_trait]
72pub trait Producer: Send + Sync {
73    /// 向NSQ发布消息
74    async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()>;
75
76    /// 向NSQ发布延迟消息
77    async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
78        &self,
79        topic: &str,
80        message: T,
81        delay: Duration,
82    ) -> Result<()>;
83
84    /// 批量发布消息
85    async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
86        &self,
87        topic: &str,
88        messages: Vec<T>,
89    ) -> Result<()>;
90
91    /// 设置外部连接池 - 返回原始类型,不能直接使用
92    fn with_connection_pool(self, _pool: Arc<ConnectionPool>) -> Self
93    where
94        Self: Sized,
95    {
96        self
97    }
98}
99
100/// NSQ生产者实现
101pub struct NsqProducer {
102    /// 生产者配置
103    config: ProducerConfig,
104    /// 内部连接池
105    connections: RwLock<HashMap<String, Arc<Connection>>>,
106    /// 外部连接池(可选)
107    external_pool: Option<Arc<ConnectionPool>>,
108}
109
110impl NsqProducer {
111    /// 创建新的NSQ生产者
112    pub fn new(config: ProducerConfig) -> Self {
113        Self {
114            config,
115            connections: RwLock::new(HashMap::new()),
116            external_pool: None,
117        }
118    }
119
120    /// 设置外部连接池
121    pub fn with_connection_pool(mut self, pool: Arc<ConnectionPool>) -> Self {
122        self.external_pool = Some(pool);
123        self
124    }
125
126    /// 获取或创建到NSQ服务器的连接
127    async fn get_or_create_connection(&self, addr: &str) -> Result<Arc<Connection>> {
128        // 优先使用外部连接池(如果提供了)
129        if let Some(pool) = &self.external_pool {
130            debug!("使用外部连接池获取连接: {}", addr);
131            return pool
132                .get_connection(
133                    addr,
134                    self.config.identify_config.clone(),
135                    self.config.auth_secret.clone(),
136                )
137                .await;
138        }
139
140        // 尝试从内部连接池获取连接
141        let connections = self.connections.read().await;
142        if let Some(connection) = connections.get(addr) {
143            return Ok(connection.clone());
144        }
145        drop(connections);
146
147        // 创建新连接
148        debug!("为地址 {} 创建新连接", addr);
149        let connection = create_nsqd_connection(
150            addr,
151            self.config.identify_config.clone(),
152            self.config.auth_secret.clone(),
153        )
154        .await?;
155
156        // 添加连接到内部连接池
157        let mut connections = self.connections.write().await;
158        connections.insert(addr.to_string(), connection.clone());
159
160        Ok(connection)
161    }
162
163    /// 获取用于发布消息的连接
164    async fn get_publish_connection(&self, topic: &str) -> Result<Arc<Connection>> {
165        // 如果直接配置了nsqd地址,使用第一个
166        if !self.config.nsqd_addresses.is_empty() {
167            return self
168                .get_or_create_connection(&self.config.nsqd_addresses[0])
169                .await;
170        }
171
172        // 如果配置了nsqlookupd,使用它查找nsqd
173        if !self.config.nsqlookupd_addresses.is_empty() {
174            let addr = &self.config.nsqlookupd_addresses[0];
175            let nodes = lookup_nsqd_nodes(addr, topic).await?;
176
177            if nodes.is_empty() {
178                return Err(Error::Connection(format!(
179                    "nsqlookupd未找到主题 {} 的生产者",
180                    topic
181                )));
182            }
183
184            return self.get_or_create_connection(&nodes[0]).await;
185        }
186
187        Err(Error::Config("未配置nsqd或nsqlookupd地址".to_string()))
188    }
189}
190
191#[async_trait]
192impl Producer for NsqProducer {
193    async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()> {
194        let backoff = ExponentialBackoffBuilder::new()
195            .with_initial_interval(self.config.backoff_config.initial_interval)
196            .with_max_interval(self.config.backoff_config.max_interval)
197            .with_multiplier(self.config.backoff_config.multiplier)
198            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
199            .build();
200
201        let topic_owned = topic.to_string();
202        let message_bytes = message.as_ref().to_vec();
203
204        let result = backoff::future::retry(backoff, || async {
205            let connection = match self.get_publish_connection(&topic_owned).await {
206                Ok(conn) => conn,
207                Err(e) => return Err(backoff::Error::permanent(e)),
208            };
209
210            let cmd = Command::Publish(topic_owned.clone(), message_bytes.clone());
211            match connection.send_command(cmd).await {
212                Ok(_) => {
213                    // 读取响应
214                    match connection.read_frame().await {
215                        Ok(Frame::Response(_)) => Ok(()),
216                        Ok(Frame::Error(data)) => {
217                            let error_msg = String::from_utf8_lossy(&data);
218                            Err(backoff::Error::transient(Error::Protocol(
219                                ProtocolError::Other(error_msg.to_string()),
220                            )))
221                        }
222                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
223                            ProtocolError::Other("收到意外响应".to_string()),
224                        ))),
225                        Err(e) => Err(backoff::Error::transient(e)),
226                    }
227                }
228                Err(e) => Err(backoff::Error::transient(e)),
229            }
230        })
231        .await;
232
233        match result {
234            Ok(_) => Ok(()),
235            Err(e) => Err(e),
236        }
237    }
238
239    async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
240        &self,
241        topic: &str,
242        message: T,
243        delay: Duration,
244    ) -> Result<()> {
245        // 初始化退避策略
246        let backoff = ExponentialBackoffBuilder::new()
247            .with_initial_interval(self.config.backoff_config.initial_interval)
248            .with_max_interval(self.config.backoff_config.max_interval)
249            .with_multiplier(self.config.backoff_config.multiplier)
250            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
251            .build();
252
253        let topic_owned = topic.to_string();
254        let message_bytes = message.as_ref().to_vec();
255
256        let result = backoff::future::retry(backoff, || async {
257            let connection = match self.get_publish_connection(&topic_owned).await {
258                Ok(conn) => conn,
259                Err(e) => return Err(backoff::Error::permanent(e)),
260            };
261
262            let cmd = Command::DelayedPublish(
263                topic_owned.clone(),
264                message_bytes.clone(),
265                delay.as_millis() as u32,
266            );
267            match connection.send_command(cmd).await {
268                Ok(_) => {
269                    // 读取响应
270                    match connection.read_frame().await {
271                        Ok(Frame::Response(_)) => Ok(()),
272                        Ok(Frame::Error(data)) => {
273                            let error_msg = String::from_utf8_lossy(&data);
274                            Err(backoff::Error::transient(Error::Protocol(
275                                ProtocolError::Other(error_msg.to_string()),
276                            )))
277                        }
278                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
279                            ProtocolError::Other("收到意外响应".to_string()),
280                        ))),
281                        Err(e) => Err(backoff::Error::transient(e)),
282                    }
283                }
284                Err(e) => Err(backoff::Error::transient(e)),
285            }
286        })
287        .await;
288
289        match result {
290            Ok(_) => Ok(()),
291            Err(e) => Err(e),
292        }
293    }
294
295    async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
296        &self,
297        topic: &str,
298        messages: Vec<T>,
299    ) -> Result<()> {
300        if messages.is_empty() {
301            debug!("忽略空消息列表");
302            return Ok(());
303        }
304
305        // 将消息转换为字节向量
306        let byte_messages: Vec<Vec<u8>> =
307            messages.iter().map(|msg| msg.as_ref().to_vec()).collect();
308
309        // 使用与批量发送相同的逻辑处理
310        let backoff = ExponentialBackoffBuilder::new()
311            .with_initial_interval(self.config.backoff_config.initial_interval)
312            .with_max_interval(self.config.backoff_config.max_interval)
313            .with_multiplier(self.config.backoff_config.multiplier)
314            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
315            .build();
316
317        let topic_owned = topic.to_string();
318
319        let result = backoff::future::retry(backoff, || async {
320            let connection = match self.get_publish_connection(&topic_owned).await {
321                Ok(conn) => conn,
322                Err(e) => return Err(backoff::Error::permanent(e)),
323            };
324
325            let cmd = Command::Mpublish(topic_owned.clone(), byte_messages.clone());
326            match connection.send_command(cmd).await {
327                Ok(_) => {
328                    // 读取响应
329                    match connection.read_frame().await {
330                        Ok(Frame::Response(_)) => Ok(()),
331                        Ok(Frame::Error(data)) => {
332                            let error_msg = String::from_utf8_lossy(&data);
333                            Err(backoff::Error::transient(Error::Protocol(
334                                ProtocolError::Other(error_msg.to_string()),
335                            )))
336                        }
337                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
338                            ProtocolError::Other("收到意外响应".to_string()),
339                        ))),
340                        Err(e) => Err(backoff::Error::transient(e)),
341                    }
342                }
343                Err(e) => Err(backoff::Error::transient(e)),
344            }
345        })
346        .await;
347
348        match result {
349            Ok(_) => Ok(()),
350            Err(e) => Err(e),
351        }
352    }
353}
354
355impl NsqProducer {
356    /// 获取生产者配置
357    pub fn config(&self) -> &ProducerConfig {
358        &self.config
359    }
360}
361
362/// 创建一个新的NSQ生产者
363pub fn new_producer(config: ProducerConfig) -> NsqProducer {
364    NsqProducer::new(config)
365}