1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use backoff::ExponentialBackoffBuilder;
7use log::debug;
8use tokio::sync::RwLock;
9
10use crate::commands::lookup_nsqd_nodes;
11use crate::connection::Connection;
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, IdentifyConfig, ProtocolError};
14
15#[derive(Debug, Clone)]
17pub struct ProducerConfig {
18    pub nsqd_addresses: Vec<String>,
20    pub nsqlookupd_addresses: Vec<String>,
22    pub connection_timeout: Duration,
24    pub auth_secret: Option<String>,
26    pub identify_config: Option<IdentifyConfig>,
28    pub backoff_config: BackoffConfig,
30}
31
32impl Default for ProducerConfig {
33    fn default() -> Self {
34        Self {
35            nsqd_addresses: vec![],
36            nsqlookupd_addresses: vec![],
37            connection_timeout: Duration::from_secs(5),
38            auth_secret: None,
39            identify_config: None,
40            backoff_config: BackoffConfig::default(),
41        }
42    }
43}
44
/// Exponential-backoff parameters used when retrying publish operations.
///
/// The values are fed into `backoff::ExponentialBackoffBuilder`.
#[derive(Debug, Clone)]
pub struct BackoffConfig {
    /// Delay before the first retry.
    pub initial_interval: Duration,
    /// Upper bound for any single retry delay.
    pub max_interval: Duration,
    /// Growth factor applied to the delay after each failed attempt.
    pub multiplier: f64,
    /// Total retry budget. In the `backoff` crate, `None` means retry without
    /// a time limit.
    pub max_elapsed_time: Option<Duration>,
}

impl Default for BackoffConfig {
    /// Start at 100 ms, double on each failure, cap single delays at 10 s and
    /// give up after 60 s in total.
    fn default() -> Self {
        let initial_interval = Duration::from_millis(100);
        let max_interval = Duration::from_secs(10);
        let max_elapsed_time = Some(Duration::from_secs(60));
        BackoffConfig {
            initial_interval,
            max_interval,
            multiplier: 2.0,
            max_elapsed_time,
        }
    }
}
68
/// Asynchronous interface for publishing messages to NSQ.
///
/// NOTE(review): the publish methods are generic over the payload type, so this
/// trait cannot be used as a trait object (`dyn Producer`); confirm that is intended.
#[async_trait]
pub trait Producer: Send + Sync {
    /// Publishes a single `message` to `topic`.
    async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()>;

    /// Publishes a single `message` to `topic`, to be delivered after `delay`.
    ///
    /// The `NsqProducer` implementation sends the delay in whole milliseconds.
    async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
        &self,
        topic: &str,
        message: T,
        delay: Duration,
    ) -> Result<()>;

    /// Publishes a batch of `messages` to `topic` (the `NsqProducer`
    /// implementation sends them as one multi-publish command).
    async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
        &self,
        topic: &str,
        messages: Vec<T>,
    ) -> Result<()>;

    /// Checks that a server is reachable.
    ///
    /// `addr` selects the server; `None` lets the implementation choose a
    /// default. `timeout` is passed through to the underlying ping.
    async fn ping(&self, addr: Option<&str>, timeout: Option<Duration>) -> Result<()>;
}
101
/// NSQ producer that opens connections lazily and caches one per server address.
pub struct NsqProducer {
    /// Configuration supplied at construction time.
    config: ProducerConfig,
    /// Cached connections keyed by the address string used to open them.
    connections: RwLock<HashMap<String, Arc<Connection>>>,
}
109
110impl NsqProducer {
111    pub fn new(config: ProducerConfig) -> Self {
113        Self {
114            config,
115            connections: RwLock::new(HashMap::new()),
116        }
117    }
118
119    async fn get_or_create_connection(&self, addr: &str) -> Result<Arc<Connection>> {
121        let mut connections = self.connections.write().await;
122
123        if let Some(conn) = connections.get(addr) {
125            match conn.ping(None).await {
127                Ok(_) => {
128                    return Ok(conn.clone());
130                }
131                Err(_) => {
132                    connections.remove(addr);
134                }
135            }
136        }
137
138        let conn = Connection::new(
140            addr.to_string(),
141            self.config.identify_config.clone(),
142            self.config.auth_secret.clone(),
143            self.config.connection_timeout,
144            self.config.connection_timeout,
145        )
146        .await?;
147
148        let conn = Arc::new(conn);
149        connections.insert(addr.to_string(), conn.clone());
150        Ok(conn)
151    }
152
153    async fn get_publish_connection(&self, topic: &str) -> Result<Arc<Connection>> {
155        if !self.config.nsqd_addresses.is_empty() {
157            return self
158                .get_or_create_connection(&self.config.nsqd_addresses[0])
159                .await;
160        }
161
162        if !self.config.nsqlookupd_addresses.is_empty() {
164            let addr = &self.config.nsqlookupd_addresses[0];
165            let nodes = lookup_nsqd_nodes(addr, topic).await?;
166
167            if nodes.is_empty() {
168                return Err(Error::Connection(format!(
169                    "nsqlookupd未找到主题 {} 的生产者",
170                    topic
171                )));
172            }
173
174            return self.get_or_create_connection(&nodes[0]).await;
175        }
176
177        Err(Error::Config("未配置nsqd或nsqlookupd地址".to_string()))
178    }
179}
180
181#[async_trait]
182impl Producer for NsqProducer {
183    async fn ping(&self, addr: Option<&str>, timeout: Option<Duration>) -> Result<()> {
184        let target_addr = match addr {
185            Some(a) => a.to_string(),
186            None => {
187                if !self.config.nsqd_addresses.is_empty() {
188                    self.config.nsqd_addresses[0].clone()
189                } else if !self.config.nsqlookupd_addresses.is_empty() {
190                    let nsqd_nodes =
193                        lookup_nsqd_nodes(&self.config.nsqlookupd_addresses[0], "_ping_topic")
194                            .await?;
195                    if nsqd_nodes.is_empty() {
196                        return Err(Error::Connection("没有可用的 NSQ 服务器地址".to_string()));
197                    }
198                    nsqd_nodes[0].clone()
199                } else {
200                    return Err(Error::Config("没有配置 NSQ 服务器地址".to_string()));
201                }
202            }
203        };
204
205        let connection = self.get_or_create_connection(&target_addr).await?;
207        connection.ping(timeout).await
208    }
209
210    async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()> {
211        let backoff = ExponentialBackoffBuilder::new()
212            .with_initial_interval(self.config.backoff_config.initial_interval)
213            .with_max_interval(self.config.backoff_config.max_interval)
214            .with_multiplier(self.config.backoff_config.multiplier)
215            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
216            .build();
217
218        let topic_owned = topic.to_string();
219        let message_bytes = message.as_ref().to_vec();
220
221        let result = backoff::future::retry(backoff, || async {
222            let connection = match self.get_publish_connection(&topic_owned).await {
223                Ok(conn) => conn,
224                Err(e) => return Err(backoff::Error::permanent(e)),
225            };
226
227            let cmd = Command::Publish(topic_owned.clone(), message_bytes.clone());
228            match connection.send_command(cmd).await {
229                Ok(_) => {
230                    match connection.read_frame().await {
232                        Ok(Frame::Response(_)) => Ok(()),
233                        Ok(Frame::Error(data)) => {
234                            let error_msg = String::from_utf8_lossy(&data);
235                            Err(backoff::Error::transient(Error::Protocol(
236                                ProtocolError::Other(error_msg.to_string()),
237                            )))
238                        }
239                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
240                            ProtocolError::Other("收到意外响应".to_string()),
241                        ))),
242                        Err(e) => Err(backoff::Error::transient(e)),
243                    }
244                }
245                Err(e) => Err(backoff::Error::transient(e)),
246            }
247        })
248        .await;
249
250        match result {
251            Ok(_) => Ok(()),
252            Err(e) => Err(e),
253        }
254    }
255
256    async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
257        &self,
258        topic: &str,
259        message: T,
260        delay: Duration,
261    ) -> Result<()> {
262        let backoff = ExponentialBackoffBuilder::new()
264            .with_initial_interval(self.config.backoff_config.initial_interval)
265            .with_max_interval(self.config.backoff_config.max_interval)
266            .with_multiplier(self.config.backoff_config.multiplier)
267            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
268            .build();
269
270        let topic_owned = topic.to_string();
271        let message_bytes = message.as_ref().to_vec();
272
273        let result = backoff::future::retry(backoff, || async {
274            let connection = match self.get_publish_connection(&topic_owned).await {
275                Ok(conn) => conn,
276                Err(e) => return Err(backoff::Error::permanent(e)),
277            };
278
279            let cmd = Command::DelayedPublish(
280                topic_owned.clone(),
281                message_bytes.clone(),
282                delay.as_millis() as u32,
283            );
284            match connection.send_command(cmd).await {
285                Ok(_) => {
286                    match connection.read_frame().await {
288                        Ok(Frame::Response(_)) => Ok(()),
289                        Ok(Frame::Error(data)) => {
290                            let error_msg = String::from_utf8_lossy(&data);
291                            Err(backoff::Error::transient(Error::Protocol(
292                                ProtocolError::Other(error_msg.to_string()),
293                            )))
294                        }
295                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
296                            ProtocolError::Other("收到意外响应".to_string()),
297                        ))),
298                        Err(e) => Err(backoff::Error::transient(e)),
299                    }
300                }
301                Err(e) => Err(backoff::Error::transient(e)),
302            }
303        })
304        .await;
305
306        match result {
307            Ok(_) => Ok(()),
308            Err(e) => Err(e),
309        }
310    }
311
312    async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
313        &self,
314        topic: &str,
315        messages: Vec<T>,
316    ) -> Result<()> {
317        if messages.is_empty() {
318            debug!("忽略空消息列表");
319            return Ok(());
320        }
321
322        let byte_messages: Vec<Vec<u8>> =
324            messages.iter().map(|msg| msg.as_ref().to_vec()).collect();
325
326        let backoff = ExponentialBackoffBuilder::new()
328            .with_initial_interval(self.config.backoff_config.initial_interval)
329            .with_max_interval(self.config.backoff_config.max_interval)
330            .with_multiplier(self.config.backoff_config.multiplier)
331            .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
332            .build();
333
334        let topic_owned = topic.to_string();
335
336        let result = backoff::future::retry(backoff, || async {
337            let connection = match self.get_publish_connection(&topic_owned).await {
338                Ok(conn) => conn,
339                Err(e) => return Err(backoff::Error::permanent(e)),
340            };
341
342            let cmd = Command::Mpublish(topic_owned.clone(), byte_messages.clone());
343            match connection.send_command(cmd).await {
344                Ok(_) => {
345                    match connection.read_frame().await {
347                        Ok(Frame::Response(_)) => Ok(()),
348                        Ok(Frame::Error(data)) => {
349                            let error_msg = String::from_utf8_lossy(&data);
350                            Err(backoff::Error::transient(Error::Protocol(
351                                ProtocolError::Other(error_msg.to_string()),
352                            )))
353                        }
354                        Ok(_) => Err(backoff::Error::transient(Error::Protocol(
355                            ProtocolError::Other("收到意外响应".to_string()),
356                        ))),
357                        Err(e) => Err(backoff::Error::transient(e)),
358                    }
359                }
360                Err(e) => Err(backoff::Error::transient(e)),
361            }
362        })
363        .await;
364
365        match result {
366            Ok(_) => Ok(()),
367            Err(e) => Err(e),
368        }
369    }
370}
371
372impl NsqProducer {
373    pub fn config(&self) -> &ProducerConfig {
375        &self.config
376    }
377
378    pub async fn get_connection_pool_size(&self) -> usize {
380        self.connections.read().await.len()
381    }
382}
383
384pub fn new_producer(config: ProducerConfig) -> NsqProducer {
386    NsqProducer::new(config)
387}