1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use async_trait::async_trait;
6use backoff::ExponentialBackoffBuilder;
7use log::debug;
8use tokio::sync::RwLock;
9
10use crate::commands::lookup_nsqd_nodes;
11use crate::connection::Connection;
12use crate::error::{Error, Result};
13use crate::protocol::{Command, Frame, IdentifyConfig, ProtocolError};
14
15#[derive(Debug, Clone)]
17pub struct ProducerConfig {
18 pub nsqd_addresses: Vec<String>,
20 pub nsqlookupd_addresses: Vec<String>,
22 pub connection_timeout: Duration,
24 pub auth_secret: Option<String>,
26 pub identify_config: Option<IdentifyConfig>,
28 pub backoff_config: BackoffConfig,
30}
31
32impl Default for ProducerConfig {
33 fn default() -> Self {
34 Self {
35 nsqd_addresses: vec![],
36 nsqlookupd_addresses: vec![],
37 connection_timeout: Duration::from_secs(5),
38 auth_secret: None,
39 identify_config: None,
40 backoff_config: BackoffConfig::default(),
41 }
42 }
43}
44
/// Exponential-backoff policy applied to publish retries.
///
/// Each field is handed to `backoff::ExponentialBackoffBuilder` when a publish starts.
#[derive(Clone, Debug)]
pub struct BackoffConfig {
    /// Wait before the first retry.
    pub initial_interval: Duration,
    /// Upper bound on any single wait between retries.
    pub max_interval: Duration,
    /// Factor by which the wait grows after each retry.
    pub multiplier: f64,
    /// Total time budget for retrying; `None` removes the limit (retries never give up).
    pub max_elapsed_time: Option<Duration>,
}

impl Default for BackoffConfig {
    /// 100 ms initial wait, doubling up to 10 s per wait, giving up after 60 s in total.
    fn default() -> Self {
        let initial_interval = Duration::from_millis(100);
        let max_interval = Duration::from_secs(10);
        let max_elapsed_time = Some(Duration::from_secs(60));
        BackoffConfig {
            initial_interval,
            max_interval,
            multiplier: 2.0,
            max_elapsed_time,
        }
    }
}
68
69#[async_trait]
71pub trait Producer: Send + Sync {
72 async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()>;
74
75 async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
77 &self,
78 topic: &str,
79 message: T,
80 delay: Duration,
81 ) -> Result<()>;
82
83 async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
85 &self,
86 topic: &str,
87 messages: Vec<T>,
88 ) -> Result<()>;
89
90 async fn ping(&self, addr: Option<&str>, timeout: Option<Duration>) -> Result<()>;
100}
101
102pub struct NsqProducer {
104 config: ProducerConfig,
106 connections: RwLock<HashMap<String, Arc<Connection>>>,
108}
109
110impl NsqProducer {
111 pub fn new(config: ProducerConfig) -> Self {
113 Self {
114 config,
115 connections: RwLock::new(HashMap::new()),
116 }
117 }
118
119 async fn get_or_create_connection(&self, addr: &str) -> Result<Arc<Connection>> {
121 let mut connections = self.connections.write().await;
122
123 if let Some(conn) = connections.get(addr) {
125 match conn.ping(None).await {
127 Ok(_) => {
128 return Ok(conn.clone());
130 }
131 Err(_) => {
132 connections.remove(addr);
134 }
135 }
136 }
137
138 let conn = Connection::new(
140 addr.to_string(),
141 self.config.identify_config.clone(),
142 self.config.auth_secret.clone(),
143 self.config.connection_timeout,
144 self.config.connection_timeout,
145 )
146 .await?;
147
148 let conn = Arc::new(conn);
149 connections.insert(addr.to_string(), conn.clone());
150 Ok(conn)
151 }
152
153 async fn get_publish_connection(&self, topic: &str) -> Result<Arc<Connection>> {
155 if !self.config.nsqd_addresses.is_empty() {
157 return self
158 .get_or_create_connection(&self.config.nsqd_addresses[0])
159 .await;
160 }
161
162 if !self.config.nsqlookupd_addresses.is_empty() {
164 let addr = &self.config.nsqlookupd_addresses[0];
165 let nodes = lookup_nsqd_nodes(addr, topic).await?;
166
167 if nodes.is_empty() {
168 return Err(Error::Connection(format!(
169 "nsqlookupd未找到主题 {} 的生产者",
170 topic
171 )));
172 }
173
174 return self.get_or_create_connection(&nodes[0]).await;
175 }
176
177 Err(Error::Config("未配置nsqd或nsqlookupd地址".to_string()))
178 }
179}
180
181#[async_trait]
182impl Producer for NsqProducer {
183 async fn ping(&self, addr: Option<&str>, timeout: Option<Duration>) -> Result<()> {
184 let target_addr = match addr {
185 Some(a) => a.to_string(),
186 None => {
187 if !self.config.nsqd_addresses.is_empty() {
188 self.config.nsqd_addresses[0].clone()
189 } else if !self.config.nsqlookupd_addresses.is_empty() {
190 let nsqd_nodes =
193 lookup_nsqd_nodes(&self.config.nsqlookupd_addresses[0], "_ping_topic")
194 .await?;
195 if nsqd_nodes.is_empty() {
196 return Err(Error::Connection("没有可用的 NSQ 服务器地址".to_string()));
197 }
198 nsqd_nodes[0].clone()
199 } else {
200 return Err(Error::Config("没有配置 NSQ 服务器地址".to_string()));
201 }
202 }
203 };
204
205 let connection = self.get_or_create_connection(&target_addr).await?;
207 connection.ping(timeout).await
208 }
209
210 async fn publish<T: AsRef<[u8]> + Send + Sync>(&self, topic: &str, message: T) -> Result<()> {
211 let backoff = ExponentialBackoffBuilder::new()
212 .with_initial_interval(self.config.backoff_config.initial_interval)
213 .with_max_interval(self.config.backoff_config.max_interval)
214 .with_multiplier(self.config.backoff_config.multiplier)
215 .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
216 .build();
217
218 let topic_owned = topic.to_string();
219 let message_bytes = message.as_ref().to_vec();
220
221 let result = backoff::future::retry(backoff, || async {
222 let connection = match self.get_publish_connection(&topic_owned).await {
223 Ok(conn) => conn,
224 Err(e) => return Err(backoff::Error::permanent(e)),
225 };
226
227 let cmd = Command::Publish(topic_owned.clone(), message_bytes.clone());
228 match connection.send_command(cmd).await {
229 Ok(_) => {
230 match connection.read_frame().await {
232 Ok(Frame::Response(_)) => Ok(()),
233 Ok(Frame::Error(data)) => {
234 let error_msg = String::from_utf8_lossy(&data);
235 Err(backoff::Error::transient(Error::Protocol(
236 ProtocolError::Other(error_msg.to_string()),
237 )))
238 }
239 Ok(_) => Err(backoff::Error::transient(Error::Protocol(
240 ProtocolError::Other("收到意外响应".to_string()),
241 ))),
242 Err(e) => Err(backoff::Error::transient(e)),
243 }
244 }
245 Err(e) => Err(backoff::Error::transient(e)),
246 }
247 })
248 .await;
249
250 match result {
251 Ok(_) => Ok(()),
252 Err(e) => Err(e),
253 }
254 }
255
256 async fn publish_delayed<T: AsRef<[u8]> + Send + Sync>(
257 &self,
258 topic: &str,
259 message: T,
260 delay: Duration,
261 ) -> Result<()> {
262 let backoff = ExponentialBackoffBuilder::new()
264 .with_initial_interval(self.config.backoff_config.initial_interval)
265 .with_max_interval(self.config.backoff_config.max_interval)
266 .with_multiplier(self.config.backoff_config.multiplier)
267 .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
268 .build();
269
270 let topic_owned = topic.to_string();
271 let message_bytes = message.as_ref().to_vec();
272
273 let result = backoff::future::retry(backoff, || async {
274 let connection = match self.get_publish_connection(&topic_owned).await {
275 Ok(conn) => conn,
276 Err(e) => return Err(backoff::Error::permanent(e)),
277 };
278
279 let cmd = Command::DelayedPublish(
280 topic_owned.clone(),
281 message_bytes.clone(),
282 delay.as_millis() as u32,
283 );
284 match connection.send_command(cmd).await {
285 Ok(_) => {
286 match connection.read_frame().await {
288 Ok(Frame::Response(_)) => Ok(()),
289 Ok(Frame::Error(data)) => {
290 let error_msg = String::from_utf8_lossy(&data);
291 Err(backoff::Error::transient(Error::Protocol(
292 ProtocolError::Other(error_msg.to_string()),
293 )))
294 }
295 Ok(_) => Err(backoff::Error::transient(Error::Protocol(
296 ProtocolError::Other("收到意外响应".to_string()),
297 ))),
298 Err(e) => Err(backoff::Error::transient(e)),
299 }
300 }
301 Err(e) => Err(backoff::Error::transient(e)),
302 }
303 })
304 .await;
305
306 match result {
307 Ok(_) => Ok(()),
308 Err(e) => Err(e),
309 }
310 }
311
312 async fn publish_multi<T: AsRef<[u8]> + Send + Sync>(
313 &self,
314 topic: &str,
315 messages: Vec<T>,
316 ) -> Result<()> {
317 if messages.is_empty() {
318 debug!("忽略空消息列表");
319 return Ok(());
320 }
321
322 let byte_messages: Vec<Vec<u8>> =
324 messages.iter().map(|msg| msg.as_ref().to_vec()).collect();
325
326 let backoff = ExponentialBackoffBuilder::new()
328 .with_initial_interval(self.config.backoff_config.initial_interval)
329 .with_max_interval(self.config.backoff_config.max_interval)
330 .with_multiplier(self.config.backoff_config.multiplier)
331 .with_max_elapsed_time(self.config.backoff_config.max_elapsed_time)
332 .build();
333
334 let topic_owned = topic.to_string();
335
336 let result = backoff::future::retry(backoff, || async {
337 let connection = match self.get_publish_connection(&topic_owned).await {
338 Ok(conn) => conn,
339 Err(e) => return Err(backoff::Error::permanent(e)),
340 };
341
342 let cmd = Command::Mpublish(topic_owned.clone(), byte_messages.clone());
343 match connection.send_command(cmd).await {
344 Ok(_) => {
345 match connection.read_frame().await {
347 Ok(Frame::Response(_)) => Ok(()),
348 Ok(Frame::Error(data)) => {
349 let error_msg = String::from_utf8_lossy(&data);
350 Err(backoff::Error::transient(Error::Protocol(
351 ProtocolError::Other(error_msg.to_string()),
352 )))
353 }
354 Ok(_) => Err(backoff::Error::transient(Error::Protocol(
355 ProtocolError::Other("收到意外响应".to_string()),
356 ))),
357 Err(e) => Err(backoff::Error::transient(e)),
358 }
359 }
360 Err(e) => Err(backoff::Error::transient(e)),
361 }
362 })
363 .await;
364
365 match result {
366 Ok(_) => Ok(()),
367 Err(e) => Err(e),
368 }
369 }
370}
371
372impl NsqProducer {
373 pub fn config(&self) -> &ProducerConfig {
375 &self.config
376 }
377
378 pub async fn get_connection_pool_size(&self) -> usize {
380 self.connections.read().await.len()
381 }
382}
383
384pub fn new_producer(config: ProducerConfig) -> NsqProducer {
386 NsqProducer::new(config)
387}