1use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
2use serde::{Deserialize, Serialize};
3use std::io::{Cursor, Read};
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6use thiserror::Error;
7
8use crate::error::{Error, Result};
9
/// Protocol magic sent once, right after connecting, to select the V2 protocol.
///
/// The NSQ protocol defines this identifier as exactly four bytes: two spaces followed by
/// `V2`. A shorter magic would shift every following byte of the stream.
pub const MAGIC_V2: &[u8] = b"  V2";
/// Heartbeat marker payload (`_heartbeat_`); presumably compared against response-frame
/// payloads by the connection code — confirm against callers.
pub const HEARTBEAT: &[u8] = b"_heartbeat_";
/// `OK` response payload.
pub const OK: &[u8] = b"OK";
/// Wire value of the frame-type field for response frames.
pub const FRAME_TYPE_RESPONSE: i32 = 0;
/// Wire value of the frame-type field for error frames.
pub const FRAME_TYPE_ERROR: i32 = 1;
/// Wire value of the frame-type field for message frames.
pub const FRAME_TYPE_MESSAGE: i32 = 2;
17
/// A client command, serialized to wire bytes by `Command::to_bytes`.
#[derive(Debug, Clone, PartialEq)]
pub enum Command {
    /// `IDENTIFY`, followed by the config as length-prefixed JSON.
    Identify(IdentifyConfig),
    /// `SUB <topic> <channel>`.
    Subscribe(String, String),
    /// `PUB <topic>`, followed by the length-prefixed body.
    Publish(String, Vec<u8>),
    /// `DPUB <topic> <delay>`, followed by the length-prefixed body.
    /// NOTE(review): delay units are not visible here (NSQ documents milliseconds) — confirm.
    DelayedPublish(String, Vec<u8>, u32),
    /// `MPUB <topic>`, followed by a total size, a message count and each length-prefixed body.
    Mpublish(String, Vec<Vec<u8>>),
    /// `RDY <count>`.
    Ready(u32),
    /// `FIN <message id>`.
    Finish(String),
    /// `REQ <message id> <delay>`.
    Requeue(String, u32),
    /// `TOUCH <message id>`.
    Touch(String),
    /// `NOP`.
    Nop,
    /// `CLS`.
    Cls,
    /// `AUTH`, followed by the length-prefixed secret; `None` sends a zero length.
    Auth(Option<String>),
}
46
47impl Command {
48 pub fn to_bytes(&self) -> Result<Vec<u8>> {
50 let mut buf = Vec::new();
51 match self {
52 Command::Identify(config) => {
53 buf.extend_from_slice(b"IDENTIFY\n");
54 let json = serde_json::to_string(config)?;
55 buf.write_u32::<BigEndian>(json.len() as u32)?;
56 buf.extend_from_slice(json.as_bytes());
57 }
58 Command::Subscribe(topic, channel) => {
59 let cmd = format!("SUB {} {}\n", topic, channel);
60 buf.extend_from_slice(cmd.as_bytes());
61 }
62 Command::Publish(topic, body) => {
63 let cmd = format!("PUB {}\n", topic);
64 buf.extend_from_slice(cmd.as_bytes());
65 buf.write_u32::<BigEndian>(body.len() as u32)?;
66 buf.extend_from_slice(body.as_slice());
67 }
68 Command::DelayedPublish(topic, body, delay) => {
69 let cmd = format!("DPUB {} {}\n", topic, delay);
70 buf.extend_from_slice(cmd.as_bytes());
71 buf.write_u32::<BigEndian>(body.len() as u32)?;
72 buf.extend_from_slice(body.as_slice());
73 }
74 Command::Mpublish(topic, bodies) => {
75 let cmd = format!("MPUB {}\n", topic);
76 buf.extend_from_slice(cmd.as_bytes());
77
78 let mut total_size = 4;
80 for body in bodies {
81 total_size += 4 + body.len();
82 }
83
84 buf.write_u32::<BigEndian>(total_size as u32)?;
85 buf.write_u32::<BigEndian>(bodies.len() as u32)?;
86
87 for body in bodies {
88 buf.write_u32::<BigEndian>(body.len() as u32)?;
89 buf.extend_from_slice(body);
90 }
91 }
92 Command::Ready(count) => {
93 let cmd = format!("RDY {}\n", count);
94 buf.extend_from_slice(cmd.as_bytes());
95 }
96 Command::Finish(id) => {
97 let cmd = format!("FIN {}\n", id);
98 buf.extend_from_slice(cmd.as_bytes());
99 }
100 Command::Requeue(id, delay) => {
101 let cmd = format!("REQ {} {}\n", id, delay);
102 buf.extend_from_slice(cmd.as_bytes());
103 }
104 Command::Touch(id) => {
105 let cmd = format!("TOUCH {}\n", id);
106 buf.extend_from_slice(cmd.as_bytes());
107 }
108 Command::Nop => {
109 buf.extend_from_slice(b"NOP\n");
110 }
111 Command::Cls => {
112 buf.extend_from_slice(b"CLS\n");
113 }
114 Command::Auth(secret) => {
115 buf.extend_from_slice(b"AUTH\n");
116 if let Some(s) = secret {
117 buf.write_u32::<BigEndian>(s.len() as u32)?;
118 buf.extend_from_slice(s.as_bytes());
119 } else {
120 buf.write_u32::<BigEndian>(0)?;
121 }
122 }
123 }
124
125 Ok(buf)
126 }
127}
128
129#[derive(Debug, Clone)]
131pub struct Message {
132 pub id: Vec<u8>,
134 pub timestamp: u64,
136 pub attempts: u16,
138 pub body: Vec<u8>,
140 #[allow(dead_code)]
142 connection: Option<Arc<MessageResponder>>,
143 auto_response_disabled: bool,
145 responded: Arc<AtomicBool>,
147}
148
/// Sends the per-message responses (`FIN`, `REQ`, `TOUCH`) for one message over a connection.
#[derive(Debug)]
pub struct MessageResponder {
    /// Connection the response commands are sent through.
    connection: Arc<crate::connection::Connection>,
    /// Message id, as text, embedded in each response command.
    msg_id: String,
}
155
156impl MessageResponder {
157 pub fn new(connection: Arc<crate::connection::Connection>, msg_id: String) -> Self {
159 Self { connection, msg_id }
160 }
161
162 pub async fn finish(&self) -> Result<()> {
164 let cmd = Command::Finish(self.msg_id.clone());
165 self.connection.send_command(cmd).await
166 }
167
168 pub async fn requeue(&self, delay: u32) -> Result<()> {
170 let cmd = Command::Requeue(self.msg_id.clone(), delay);
171 self.connection.send_command(cmd).await
172 }
173
174 pub async fn touch(&self) -> Result<()> {
176 let cmd = Command::Touch(self.msg_id.clone());
177 self.connection.send_command(cmd).await
178 }
179}
180
181impl Message {
182 pub fn new(id: Vec<u8>, body: Vec<u8>, timestamp: u64, attempts: u16) -> Self {
184 Self {
185 id,
186 timestamp,
187 attempts,
188 body,
189 connection: None,
190 auto_response_disabled: false,
191 responded: Arc::new(AtomicBool::new(false)),
192 }
193 }
194
195 pub fn with_responder(mut self, connection: Arc<crate::connection::Connection>) -> Self {
197 let msg_id = String::from_utf8_lossy(&self.id).to_string();
198 self.connection = Some(Arc::new(MessageResponder::new(connection, msg_id)));
199 self
200 }
201
202 pub fn disable_auto_response(&mut self) {
206 self.auto_response_disabled = true;
207 }
208
209 pub fn is_auto_response_disabled(&self) -> bool {
211 self.auto_response_disabled
212 }
213
214 pub fn has_responded(&self) -> bool {
216 self.responded.load(Ordering::SeqCst)
217 }
218
219 pub async fn finish(&self) -> Result<()> {
221 if self
223 .responded
224 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err()
225 {
226 return Ok(()); }
228
229 if let Some(responder) = &self.connection {
230 responder.finish().await
231 } else {
232 Err(Error::Other("消息没有关联的连接".to_string()))
233 }
234 }
235
236 pub async fn requeue(&self, delay: u32) -> Result<()> {
238 if self
240 .responded
241 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err()
242 {
243 return Ok(()); }
245
246 if let Some(responder) = &self.connection {
247 responder.requeue(delay).await
248 } else {
249 Err(Error::Other("消息没有关联的连接".to_string()))
250 }
251 }
252
253 pub async fn touch(&self) -> Result<()> {
255 if self.has_responded() {
256 return Ok(()); }
258
259 if let Some(responder) = &self.connection {
260 responder.touch().await
261 } else {
262 Err(Error::Other("消息没有关联的连接".to_string()))
263 }
264 }
265
266 pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
268 if bytes.len() < 26 {
269 return Err(Error::Protocol(ProtocolError::Other(
270 "消息大小不足".to_string(),
271 )));
272 }
273
274 let mut cursor = Cursor::new(bytes);
275
276 cursor.set_position(4);
278
279 let frame_type = cursor.read_u32::<BigEndian>()?;
281 if frame_type != 2 {
282 return Err(Error::Protocol(ProtocolError::Other(format!(
283 "无效的帧类型: {}",
284 frame_type
285 ))));
286 }
287
288 let timestamp = cursor.read_u64::<BigEndian>()?;
290
291 let attempts = cursor.read_u16::<BigEndian>()?;
293
294 let mut id_bytes = [0u8; 16];
296 cursor.read_exact(&mut id_bytes)?;
297 let id = id_bytes.to_vec();
298
299 let mut body = Vec::new();
301 cursor.read_to_end(&mut body)?;
302
303 Ok(Self {
304 id,
305 timestamp,
306 attempts,
307 body,
308 connection: None,
309 auto_response_disabled: false,
310 responded: Arc::new(AtomicBool::new(false)),
311 })
312 }
313
314 pub fn id_string(&self) -> String {
316 String::from_utf8_lossy(&self.id).to_string()
317 }
318}
319
/// Client settings sent as JSON in the `IDENTIFY` command.
///
/// Every field is optional; `None` fields are omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IdentifyConfig {
    /// Client identifier (the default uses the local hostname).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,

    /// Client hostname (the default uses the local hostname).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hostname: Option<String>,

    /// Feature-negotiation flag (default `true`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feature_negotiation: Option<bool>,

    /// Heartbeat interval (default 30000; presumably milliseconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub heartbeat_interval: Option<i32>,

    /// Output buffer size (default 16384; presumably bytes — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_buffer_size: Option<i32>,

    /// Output buffer timeout (default 250; presumably milliseconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_buffer_timeout: Option<i32>,

    /// TLS flag (unset by default).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_v1: Option<bool>,

    /// Snappy compression flag (unset by default).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub snappy: Option<bool>,

    /// Sample rate (unset by default).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_rate: Option<i32>,

    /// User-agent string (default `rust-nsq/<crate version>`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_agent: Option<String>,

    /// Message timeout (default 60000; presumably milliseconds — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub msg_timeout: Option<i32>,
}
367
368impl Default for IdentifyConfig {
369 fn default() -> Self {
370 let hostname = hostname::get()
371 .ok()
372 .and_then(|h| h.into_string().ok())
373 .unwrap_or_else(|| "unknown".to_string());
374
375 Self {
376 client_id: Some(hostname.clone()),
377 hostname: Some(hostname),
378 feature_negotiation: Some(true),
379 heartbeat_interval: Some(30000),
380 output_buffer_size: Some(16384),
381 output_buffer_timeout: Some(250),
382 tls_v1: None,
383 snappy: None,
384 sample_rate: None,
385 user_agent: Some(format!("rust-nsq/{}", env!("CARGO_PKG_VERSION"))),
386 msg_timeout: Some(60000),
387 }
388 }
389}
390
/// Frame kind, decoded from the 4-byte frame-type field via `TryFrom<u32>`.
#[derive(Debug, Clone, PartialEq)]
pub enum FrameType {
    /// Frame-type value 0.
    Response,
    /// Frame-type value 1.
    Error,
    /// Frame-type value 2.
    Message,
}
401
402impl TryFrom<u32> for FrameType {
403 type Error = Error;
404
405 fn try_from(value: u32) -> Result<Self> {
406 match value {
407 0 => Ok(FrameType::Response),
408 1 => Ok(FrameType::Error),
409 2 => Ok(FrameType::Message),
410 _ => Err(Error::Protocol(ProtocolError::Other(format!(
411 "未知帧类型: {}",
412 value
413 )))),
414 }
415 }
416}
417
418pub fn read_frame(data: &[u8]) -> Result<(FrameType, &[u8])> {
420 if data.len() < 8 {
421 return Err(Error::Protocol(ProtocolError::Other(
422 "帧数据不完整".to_string(),
423 )));
424 }
425
426 let mut cursor = Cursor::new(data);
427 let size = cursor.read_u32::<BigEndian>()?;
428
429 if data.len() < (size as usize + 4) {
430 return Err(Error::Protocol(ProtocolError::Other(
431 "帧数据不完整".to_string(),
432 )));
433 }
434
435 let frame_type_raw = cursor.read_u32::<BigEndian>()?;
436 let frame_type = FrameType::try_from(frame_type_raw)?;
437
438 Ok((frame_type, &data[8..(size as usize + 4)]))
440}
441
/// Errors raised while encoding or decoding protocol data.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// Underlying I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Frame data is too short or its declared size is inconsistent.
    #[error("Invalid frame size")]
    InvalidFrameSize,
    /// NOTE(review): presumably raised when the protocol magic is rejected; not used in this file.
    #[error("Invalid magic version")]
    InvalidMagicVersion,
    /// Frame-type field holds an unrecognized value.
    #[error("Invalid frame type: {0}")]
    InvalidFrameType(i32),
    /// Any other protocol violation, described by the message.
    #[error("Protocol error: {0}")]
    Other(String),
}
455
/// A decoded frame, as produced by `Protocol::decode_frame`.
#[derive(Debug)]
pub enum Frame {
    /// Response frame with its raw payload.
    Response(Vec<u8>),
    /// Error frame with its raw payload.
    Error(Vec<u8>),
    /// Message frame decoded into a [`Message`].
    Message(Message),
}
462
/// Namespace for stateless helpers that encode commands and decode frames.
pub struct Protocol;
464
465impl Protocol {
466 pub fn write_command(cmd: &[u8], params: &[&[u8]]) -> Vec<u8> {
467 let mut buf = Vec::new();
468 buf.extend_from_slice(cmd);
469
470 if !params.is_empty() {
471 buf.push(b' ');
472 for (i, param) in params.iter().enumerate() {
473 if i > 0 {
474 buf.push(b' ');
475 }
476 buf.extend_from_slice(param);
477 }
478 }
479
480 buf.extend_from_slice(b"\n");
481 buf
482 }
483
484 pub fn decode_message(data: &[u8]) -> Result<Message> {
485 if data.len() < 4 {
486 return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
487 }
488
489 let timestamp = BigEndian::read_u64(&data[0..8]);
490 let attempts = BigEndian::read_u16(&data[8..10]);
491 let id = data[10..26].to_vec();
492 let body = data[26..].to_vec();
493
494 Ok(Message {
495 timestamp,
496 attempts,
497 id,
498 body,
499 connection: None,
500 auto_response_disabled: false,
501 responded: Arc::new(AtomicBool::new(false)),
502 })
503 }
504
505 pub fn decode_frame(data: &[u8]) -> Result<Frame> {
506 if data.len() < 4 {
507 return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
508 }
509
510 let frame_type = BigEndian::read_i32(&data[0..4]);
511 let frame_data = &data[4..];
512
513 match frame_type {
514 FRAME_TYPE_RESPONSE => Ok(Frame::Response(frame_data.to_vec())),
515 FRAME_TYPE_ERROR => Ok(Frame::Error(frame_data.to_vec())),
516 FRAME_TYPE_MESSAGE => {
517 let msg = Self::decode_message(frame_data)?;
518 Ok(Frame::Message(msg))
519 }
520 _ => Err(Error::Protocol(ProtocolError::InvalidFrameType(frame_type))),
521 }
522 }
523
524 pub fn encode_command(name: &str, body: Option<&[u8]>, params: &[&str]) -> Vec<u8> {
525 let mut cmd = Vec::new();
526
527 cmd.extend_from_slice(&[0; 4]);
529
530 cmd.extend_from_slice(name.as_bytes());
532
533 for param in params {
535 cmd.push(b' ');
536 cmd.extend_from_slice(param.as_bytes());
537 }
538
539 cmd.push(b'\n');
540
541 if let Some(body) = body {
543 cmd.extend_from_slice(body);
544 }
545
546 let size = (cmd.len() - 4) as u32;
548 let mut size_bytes = [0; 4];
549 BigEndian::write_u32(&mut size_bytes, size);
550 cmd[0..4].copy_from_slice(&size_bytes);
551
552 cmd
553 }
554}
555
/// Command name `IDENTIFY`.
pub const IDENTIFY: &str = "IDENTIFY";
/// Command name `SUB`.
pub const SUB: &str = "SUB";
/// Command name `PUB`.
pub const PUB: &str = "PUB";
/// Command name `MPUB`.
pub const MPUB: &str = "MPUB";
/// Command name `RDY`.
pub const RDY: &str = "RDY";
/// Command name `FIN`.
pub const FIN: &str = "FIN";
/// Command name `REQ`.
pub const REQ: &str = "REQ";
/// Command name `TOUCH`.
pub const TOUCH: &str = "TOUCH";
/// Command name `CLS`.
pub const CLS: &str = "CLS";
/// Command name `NOP`.
pub const NOP: &str = "NOP";
/// Command name `AUTH`.
pub const AUTH: &str = "AUTH";
568
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identify_command() {
        let config = IdentifyConfig {
            client_id: Some("test_client".to_string()),
            hostname: Some("test_host".to_string()),
            feature_negotiation: Some(true),
            ..Default::default()
        };

        let cmd = Command::Identify(config.clone());
        let bytes = cmd.to_bytes().unwrap();

        assert!(bytes.starts_with(b"IDENTIFY\n"));

        // `IDENTIFY\n` is followed by a 4-byte big-endian length and exactly that much JSON.
        let header_len = b"IDENTIFY\n".len();
        let json_len = BigEndian::read_u32(&bytes[header_len..header_len + 4]) as usize;
        let json = &bytes[header_len + 4..];
        assert_eq!(json.len(), json_len);

        let decoded: IdentifyConfig = serde_json::from_slice(json).unwrap();
        assert_eq!(decoded, config);
    }

    #[test]
    fn test_publish_command() {
        let topic = "test_topic".to_string();
        let msg_body = b"test message".to_vec();

        let cmd = Command::Publish(topic, msg_body.clone());
        let bytes = cmd.to_bytes().unwrap();

        assert!(bytes.starts_with(b"PUB test_topic\n"));

        let message_size_bytes = &bytes[15..19];
        let mut cursor = Cursor::new(message_size_bytes);
        let message_size = cursor.read_u32::<BigEndian>().unwrap();
        assert_eq!(message_size as usize, msg_body.len());

        let actual_message = &bytes[19..];
        assert_eq!(actual_message, msg_body.as_slice());
    }

    #[test]
    fn test_mpublish_command() {
        let cmd = Command::Mpublish("t".to_string(), vec![b"ab".to_vec(), b"cde".to_vec()]);
        let bytes = cmd.to_bytes().unwrap();

        let header: &[u8] = b"MPUB t\n";
        assert!(bytes.starts_with(header));
        let rest = &bytes[header.len()..];

        // Body size = message count (4) + each message's length prefix (4) and payload.
        assert_eq!(BigEndian::read_u32(&rest[0..4]), 4 + (4 + 2) + (4 + 3));
        assert_eq!(BigEndian::read_u32(&rest[4..8]), 2);
        assert_eq!(BigEndian::read_u32(&rest[8..12]), 2);
        assert_eq!(&rest[12..14], b"ab");
        assert_eq!(BigEndian::read_u32(&rest[14..18]), 3);
        assert_eq!(&rest[18..21], b"cde");
        assert_eq!(rest.len(), 21);
    }

    #[test]
    fn test_auth_command() {
        let with_secret = Command::Auth(Some("s3cret".to_string())).to_bytes().unwrap();
        assert!(with_secret.starts_with(b"AUTH\n"));
        assert_eq!(BigEndian::read_u32(&with_secret[5..9]), 6);
        assert_eq!(&with_secret[9..], b"s3cret");

        let without_secret = Command::Auth(None).to_bytes().unwrap();
        assert_eq!(without_secret, b"AUTH\n\0\0\0\0");
    }

    #[test]
    fn test_read_frame_response() {
        let mut frame = Vec::new();
        frame.write_u32::<BigEndian>(4 + OK.len() as u32).unwrap();
        frame.write_i32::<BigEndian>(FRAME_TYPE_RESPONSE).unwrap();
        frame.extend_from_slice(OK);

        let (frame_type, payload) = read_frame(&frame).unwrap();
        assert_eq!(frame_type, FrameType::Response);
        assert_eq!(payload, OK);
    }

    #[test]
    fn test_message_from_bytes() {
        let id = *b"0123456789abcdef";
        let body = b"hello";

        let mut raw = Vec::new();
        raw.write_u32::<BigEndian>((4 + 8 + 2 + id.len() + body.len()) as u32)
            .unwrap();
        raw.write_i32::<BigEndian>(FRAME_TYPE_MESSAGE).unwrap();
        raw.write_u64::<BigEndian>(1_700_000_000).unwrap();
        raw.write_u16::<BigEndian>(3).unwrap();
        raw.extend_from_slice(&id);
        raw.extend_from_slice(body);

        let msg = Message::from_bytes(&raw).unwrap();
        assert_eq!(msg.timestamp, 1_700_000_000);
        assert_eq!(msg.attempts, 3);
        assert_eq!(msg.id_string(), "0123456789abcdef");
        assert_eq!(msg.body, body);
        assert!(!msg.has_responded());
    }

    #[test]
    fn test_decode_frame_message() {
        let mut raw = Vec::new();
        raw.write_i32::<BigEndian>(FRAME_TYPE_MESSAGE).unwrap();
        raw.write_u64::<BigEndian>(42).unwrap();
        raw.write_u16::<BigEndian>(1).unwrap();
        raw.extend_from_slice(b"fedcba9876543210");
        raw.extend_from_slice(b"payload");

        match Protocol::decode_frame(&raw).unwrap() {
            Frame::Message(msg) => {
                assert_eq!(msg.timestamp, 42);
                assert_eq!(msg.attempts, 1);
                assert_eq!(msg.id, b"fedcba9876543210");
                assert_eq!(msg.body, b"payload");
            }
            other => panic!("expected a message frame, got {:?}", other),
        }
    }

    #[test]
    fn test_decode_frame_rejects_unknown_type() {
        let raw = [0u8, 0, 0, 9, 1, 2];
        assert!(Protocol::decode_frame(&raw).is_err());
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::new(
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            b"test body".to_vec(),
            123456789,
            1,
        );

        assert_eq!(msg.attempts, 1);
        assert_eq!(msg.timestamp, 123456789);
        assert_eq!(msg.body, b"test body");
        assert!(!msg.is_auto_response_disabled());
        assert!(!msg.has_responded());
    }
}