nsq_async_rs/
protocol.rs

1use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
2use serde::{Deserialize, Serialize};
3use std::io::{Cursor, Read};
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6use thiserror::Error;
7
8use crate::error::{Error, Result};
9
// NSQ protocol constants.

/// Magic bytes selecting protocol version V2 (two spaces followed by `V2`).
/// NOTE(review): presumably sent once right after connecting — confirm in the connection code.
pub const MAGIC_V2: &[u8] = b"  V2";
/// Heartbeat payload.
/// NOTE(review): presumably delivered inside a response frame and answered with `NOP` — confirm.
pub const HEARTBEAT: &[u8] = b"_heartbeat_";
/// `OK` response payload.
pub const OK: &[u8] = b"OK";
/// Frame-type code of a response frame (matched in `Protocol::decode_frame`).
pub const FRAME_TYPE_RESPONSE: i32 = 0;
/// Frame-type code of an error frame.
pub const FRAME_TYPE_ERROR: i32 = 1;
/// Frame-type code of a message frame.
pub const FRAME_TYPE_MESSAGE: i32 = 2;
17
/// Commands a client sends to nsqd; serialized to wire bytes by `Command::to_bytes`.
#[derive(Debug, Clone, PartialEq)]
pub enum Command {
    /// `IDENTIFY`: sends this client's metadata and requested features as a JSON body.
    Identify(IdentifyConfig),
    /// `SUB <topic> <channel>`: subscribes to a channel of a topic.
    Subscribe(String, String),
    /// `PUB <topic>`: publishes one message body to a topic.
    Publish(String, Vec<u8>),
    /// `DPUB <topic> <delay>`: publishes one message body with a delay.
    /// NOTE(review): delay unit is presumably milliseconds per the nsqd protocol — confirm.
    DelayedPublish(String, Vec<u8>, u32),
    /// `MPUB <topic>`: publishes several message bodies to a topic in one command.
    Mpublish(String, Vec<Vec<u8>>),
    /// `RDY <count>`: tells the server how many messages this connection is ready to receive.
    Ready(u32),
    /// `FIN <id>`: marks a message as successfully processed.
    Finish(String),
    /// `REQ <id> <delay>`: puts a message back in the queue after `delay`.
    Requeue(String, u32),
    /// `TOUCH <id>`: resets the processing timeout of an in-flight message.
    Touch(String),
    /// `NOP`: a no-op command.
    /// NOTE(review): presumably used to answer heartbeats — confirm against the caller.
    Nop,
    /// `CLS`: asks the server to cleanly close the connection.
    Cls,
    /// `AUTH`: authenticates with an optional secret; `None` sends a zero-length secret.
    Auth(Option<String>),
}
46
47impl Command {
48    /// 将命令转换为字节以便发送
49    pub fn to_bytes(&self) -> Result<Vec<u8>> {
50        let mut buf = Vec::new();
51        match self {
52            Command::Identify(config) => {
53                buf.extend_from_slice(b"IDENTIFY\n");
54                let json = serde_json::to_string(config)?;
55                buf.write_u32::<BigEndian>(json.len() as u32)?;
56                buf.extend_from_slice(json.as_bytes());
57            }
58            Command::Subscribe(topic, channel) => {
59                let cmd = format!("SUB {} {}\n", topic, channel);
60                buf.extend_from_slice(cmd.as_bytes());
61            }
62            Command::Publish(topic, body) => {
63                let cmd = format!("PUB {}\n", topic);
64                buf.extend_from_slice(cmd.as_bytes());
65                buf.write_u32::<BigEndian>(body.len() as u32)?;
66                buf.extend_from_slice(body.as_slice());
67            }
68            Command::DelayedPublish(topic, body, delay) => {
69                let cmd = format!("DPUB {} {}\n", topic, delay);
70                buf.extend_from_slice(cmd.as_bytes());
71                buf.write_u32::<BigEndian>(body.len() as u32)?;
72                buf.extend_from_slice(body.as_slice());
73            }
74            Command::Mpublish(topic, bodies) => {
75                let cmd = format!("MPUB {}\n", topic);
76                buf.extend_from_slice(cmd.as_bytes());
77
78                // 计算总大小: 4字节(消息数量) + 每个消息的(4字节大小 + 内容)
79                let mut total_size = 4;
80                for body in bodies {
81                    total_size += 4 + body.len();
82                }
83
84                buf.write_u32::<BigEndian>(total_size as u32)?;
85                buf.write_u32::<BigEndian>(bodies.len() as u32)?;
86
87                for body in bodies {
88                    buf.write_u32::<BigEndian>(body.len() as u32)?;
89                    buf.extend_from_slice(body);
90                }
91            }
92            Command::Ready(count) => {
93                let cmd = format!("RDY {}\n", count);
94                buf.extend_from_slice(cmd.as_bytes());
95            }
96            Command::Finish(id) => {
97                let cmd = format!("FIN {}\n", id);
98                buf.extend_from_slice(cmd.as_bytes());
99            }
100            Command::Requeue(id, delay) => {
101                let cmd = format!("REQ {} {}\n", id, delay);
102                buf.extend_from_slice(cmd.as_bytes());
103            }
104            Command::Touch(id) => {
105                let cmd = format!("TOUCH {}\n", id);
106                buf.extend_from_slice(cmd.as_bytes());
107            }
108            Command::Nop => {
109                buf.extend_from_slice(b"NOP\n");
110            }
111            Command::Cls => {
112                buf.extend_from_slice(b"CLS\n");
113            }
114            Command::Auth(secret) => {
115                buf.extend_from_slice(b"AUTH\n");
116                if let Some(s) = secret {
117                    buf.write_u32::<BigEndian>(s.len() as u32)?;
118                    buf.extend_from_slice(s.as_bytes());
119                } else {
120                    buf.write_u32::<BigEndian>(0)?;
121                }
122            }
123        }
124
125        Ok(buf)
126    }
127}
128
129/// NSQ消息格式
130#[derive(Debug, Clone)]
131pub struct Message {
132    /// 唯一消息ID
133    pub id: Vec<u8>,
134    /// 消息时间戳
135    pub timestamp: u64,
136    /// 消息尝试次数
137    pub attempts: u16,
138    /// 消息体
139    pub body: Vec<u8>,
140    /// 连接引用(用于手动确认)
141    #[allow(dead_code)]
142    connection: Option<Arc<MessageResponder>>,
143    /// 是否禁用自动响应
144    auto_response_disabled: bool,
145    /// 是否已经响应
146    responded: Arc<AtomicBool>,
147}
148
/// Message responder: sends FIN/REQ/TOUCH for one message over a shared connection
/// (used for manual acknowledgement).
#[derive(Debug)]
pub struct MessageResponder {
    /// Connection the commands are sent on.
    connection: Arc<crate::connection::Connection>,
    /// Message ID as text, used as the argument of every command.
    msg_id: String,
}
155
156impl MessageResponder {
157    /// 创建新的消息响应器
158    pub fn new(connection: Arc<crate::connection::Connection>, msg_id: String) -> Self {
159        Self { connection, msg_id }
160    }
161
162    /// 发送 FIN 命令,标记消息处理完成
163    pub async fn finish(&self) -> Result<()> {
164        let cmd = Command::Finish(self.msg_id.clone());
165        self.connection.send_command(cmd).await
166    }
167
168    /// 发送 REQ 命令,重新入队消息
169    pub async fn requeue(&self, delay: u32) -> Result<()> {
170        let cmd = Command::Requeue(self.msg_id.clone(), delay);
171        self.connection.send_command(cmd).await
172    }
173
174    /// 发送 TOUCH 命令,重置消息超时
175    pub async fn touch(&self) -> Result<()> {
176        let cmd = Command::Touch(self.msg_id.clone());
177        self.connection.send_command(cmd).await
178    }
179}
180
181impl Message {
182    /// 创建新消息(用于测试)
183    pub fn new(id: Vec<u8>, body: Vec<u8>, timestamp: u64, attempts: u16) -> Self {
184        Self {
185            id,
186            timestamp,
187            attempts,
188            body,
189            connection: None,
190            auto_response_disabled: false,
191            responded: Arc::new(AtomicBool::new(false)),
192        }
193    }
194
195    /// 设置消息响应器
196    pub fn with_responder(mut self, connection: Arc<crate::connection::Connection>) -> Self {
197        let msg_id = String::from_utf8_lossy(&self.id).to_string();
198        self.connection = Some(Arc::new(MessageResponder::new(connection, msg_id)));
199        self
200    }
201
202    /// 禁用自动响应
203    ///
204    /// 调用此方法后,消息不会自动发送 FIN/REQ,需要手动调用 finish() 或 requeue()
205    pub fn disable_auto_response(&mut self) {
206        self.auto_response_disabled = true;
207    }
208
209    /// 检查是否禁用了自动响应
210    pub fn is_auto_response_disabled(&self) -> bool {
211        self.auto_response_disabled
212    }
213
214    /// 检查消息是否已经响应
215    pub fn has_responded(&self) -> bool {
216        self.responded.load(Ordering::SeqCst)
217    }
218
219    /// 手动发送 FIN 命令
220    pub async fn finish(&self) -> Result<()> {
221        // 使用 CAS 确保只响应一次
222        if self
223            .responded
224            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err()
225        {
226            return Ok(()); // 已经响应过了
227        }
228
229        if let Some(responder) = &self.connection {
230            responder.finish().await
231        } else {
232            Err(Error::Other("消息没有关联的连接".to_string()))
233        }
234    }
235
236    /// 手动发送 REQ 命令
237    pub async fn requeue(&self, delay: u32) -> Result<()> {
238        // 使用 CAS 确保只响应一次
239        if self
240            .responded
241            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err()
242        {
243            return Ok(()); // 已经响应过了
244        }
245
246        if let Some(responder) = &self.connection {
247            responder.requeue(delay).await
248        } else {
249            Err(Error::Other("消息没有关联的连接".to_string()))
250        }
251    }
252
253    /// 手动发送 TOUCH 命令
254    pub async fn touch(&self) -> Result<()> {
255        if self.has_responded() {
256            return Ok(()); // 已经响应过了,不能再 touch
257        }
258
259        if let Some(responder) = &self.connection {
260            responder.touch().await
261        } else {
262            Err(Error::Other("消息没有关联的连接".to_string()))
263        }
264    }
265
266    /// 从字节流解析消息
267    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
268        if bytes.len() < 26 {
269            return Err(Error::Protocol(ProtocolError::Other(
270                "消息大小不足".to_string(),
271            )));
272        }
273
274        let mut cursor = Cursor::new(bytes);
275
276        // 跳过消息大小
277        cursor.set_position(4);
278
279        // 消息帧类型,2表示消息
280        let frame_type = cursor.read_u32::<BigEndian>()?;
281        if frame_type != 2 {
282            return Err(Error::Protocol(ProtocolError::Other(format!(
283                "无效的帧类型: {}",
284                frame_type
285            ))));
286        }
287
288        // 读取时间戳 (8字节)
289        let timestamp = cursor.read_u64::<BigEndian>()?;
290
291        // 读取尝试次数 (2字节)
292        let attempts = cursor.read_u16::<BigEndian>()?;
293
294        // 读取消息ID (16字节)
295        let mut id_bytes = [0u8; 16];
296        cursor.read_exact(&mut id_bytes)?;
297        let id = id_bytes.to_vec();
298
299        // 读取消息体
300        let mut body = Vec::new();
301        cursor.read_to_end(&mut body)?;
302
303        Ok(Self {
304            id,
305            timestamp,
306            attempts,
307            body,
308            connection: None,
309            auto_response_disabled: false,
310            responded: Arc::new(AtomicBool::new(false)),
311        })
312    }
313
314    /// 获取消息ID的字符串表示
315    pub fn id_string(&self) -> String {
316        String::from_utf8_lossy(&self.id).to_string()
317    }
318}
319
/// Configuration sent as the JSON body of the `IDENTIFY` command.
///
/// Every field is optional; `None` fields are omitted from the serialized JSON.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IdentifyConfig {
    /// Client identifier; `Default` sets it to the local hostname.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id: Option<String>,

    /// Client hostname.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hostname: Option<String>,

    /// Feature-negotiation flag.
    /// NOTE(review): presumably asks nsqd to reply with the negotiated settings — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub feature_negotiation: Option<bool>,

    /// Heartbeat interval in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub heartbeat_interval: Option<i32>,

    /// Output buffer size (presumably bytes — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_buffer_size: Option<i32>,

    /// Output buffer timeout in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_buffer_timeout: Option<i32>,

    /// Whether to request TLS.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tls_v1: Option<bool>,

    /// Whether to request Snappy compression.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub snappy: Option<bool>,

    /// Sample rate (the original comment described it as a latency sampling rate).
    /// NOTE(review): in nsqd this is presumably the percentage of messages delivered to
    /// this connection — confirm against the nsqd IDENTIFY documentation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_rate: Option<i32>,

    /// User-agent string reported to the server.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_agent: Option<String>,

    /// Message timeout in milliseconds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub msg_timeout: Option<i32>,
}
367
368impl Default for IdentifyConfig {
369    fn default() -> Self {
370        let hostname = hostname::get()
371            .ok()
372            .and_then(|h| h.into_string().ok())
373            .unwrap_or_else(|| "unknown".to_string());
374
375        Self {
376            client_id: Some(hostname.clone()),
377            hostname: Some(hostname),
378            feature_negotiation: Some(true),
379            heartbeat_interval: Some(30000),
380            output_buffer_size: Some(16384),
381            output_buffer_timeout: Some(250),
382            tls_v1: None,
383            snappy: None,
384            sample_rate: None,
385            user_agent: Some(format!("rust-nsq/{}", env!("CARGO_PKG_VERSION"))),
386            msg_timeout: Some(60000),
387        }
388    }
389}
390
/// Kind of an NSQ frame, decoded from the 4-byte frame-type field via `TryFrom<u32>`
/// (0 = response, 1 = error, 2 = message).
///
/// A fieldless enum, so it is cheap to copy and can be compared, hashed and used as a key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// Response frame.
    Response,
    /// Error frame.
    Error,
    /// Message frame.
    Message,
}
401
402impl TryFrom<u32> for FrameType {
403    type Error = Error;
404
405    fn try_from(value: u32) -> Result<Self> {
406        match value {
407            0 => Ok(FrameType::Response),
408            1 => Ok(FrameType::Error),
409            2 => Ok(FrameType::Message),
410            _ => Err(Error::Protocol(ProtocolError::Other(format!(
411                "未知帧类型: {}",
412                value
413            )))),
414        }
415    }
416}
417
418/// 读取NSQ帧
419pub fn read_frame(data: &[u8]) -> Result<(FrameType, &[u8])> {
420    if data.len() < 8 {
421        return Err(Error::Protocol(ProtocolError::Other(
422            "帧数据不完整".to_string(),
423        )));
424    }
425
426    let mut cursor = Cursor::new(data);
427    let size = cursor.read_u32::<BigEndian>()?;
428
429    if data.len() < (size as usize + 4) {
430        return Err(Error::Protocol(ProtocolError::Other(
431            "帧数据不完整".to_string(),
432        )));
433    }
434
435    let frame_type_raw = cursor.read_u32::<BigEndian>()?;
436    let frame_type = FrameType::try_from(frame_type_raw)?;
437
438    // 返回帧类型和帧数据(不包括大小和类型)
439    Ok((frame_type, &data[8..(size as usize + 4)]))
440}
441
/// Errors raised while encoding or decoding NSQ protocol data.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// Underlying I/O failure; convertible from `std::io::Error` via `#[from]`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Frame data too short for what must be read from it (used by `Protocol::decode_*`).
    #[error("Invalid frame size")]
    InvalidFrameSize,
    /// Unrecognized protocol magic/version; not raised anywhere in this file.
    #[error("Invalid magic version")]
    InvalidMagicVersion,
    /// Frame-type code other than response (0), error (1) or message (2).
    #[error("Invalid frame type: {0}")]
    InvalidFrameType(i32),
    /// Any other protocol violation, described by the message.
    #[error("Protocol error: {0}")]
    Other(String),
}
455
/// A decoded NSQ frame, as returned by `Protocol::decode_frame`.
#[derive(Debug)]
pub enum Frame {
    /// Payload of a response frame (raw bytes).
    Response(Vec<u8>),
    /// Payload of an error frame (raw bytes).
    Error(Vec<u8>),
    /// Decoded message frame.
    Message(Message),
}
462
/// Stateless namespace for building raw commands and decoding frames.
pub struct Protocol;
464
465impl Protocol {
466    pub fn write_command(cmd: &[u8], params: &[&[u8]]) -> Vec<u8> {
467        let mut buf = Vec::new();
468        buf.extend_from_slice(cmd);
469
470        if !params.is_empty() {
471            buf.push(b' ');
472            for (i, param) in params.iter().enumerate() {
473                if i > 0 {
474                    buf.push(b' ');
475                }
476                buf.extend_from_slice(param);
477            }
478        }
479
480        buf.extend_from_slice(b"\n");
481        buf
482    }
483
484    pub fn decode_message(data: &[u8]) -> Result<Message> {
485        if data.len() < 4 {
486            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
487        }
488
489        let timestamp = BigEndian::read_u64(&data[0..8]);
490        let attempts = BigEndian::read_u16(&data[8..10]);
491        let id = data[10..26].to_vec();
492        let body = data[26..].to_vec();
493
494        Ok(Message {
495            timestamp,
496            attempts,
497            id,
498            body,
499            connection: None,
500            auto_response_disabled: false,
501            responded: Arc::new(AtomicBool::new(false)),
502        })
503    }
504
505    pub fn decode_frame(data: &[u8]) -> Result<Frame> {
506        if data.len() < 4 {
507            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
508        }
509
510        let frame_type = BigEndian::read_i32(&data[0..4]);
511        let frame_data = &data[4..];
512
513        match frame_type {
514            FRAME_TYPE_RESPONSE => Ok(Frame::Response(frame_data.to_vec())),
515            FRAME_TYPE_ERROR => Ok(Frame::Error(frame_data.to_vec())),
516            FRAME_TYPE_MESSAGE => {
517                let msg = Self::decode_message(frame_data)?;
518                Ok(Frame::Message(msg))
519            }
520            _ => Err(Error::Protocol(ProtocolError::InvalidFrameType(frame_type))),
521        }
522    }
523
524    pub fn encode_command(name: &str, body: Option<&[u8]>, params: &[&str]) -> Vec<u8> {
525        let mut cmd = Vec::new();
526
527        // Write size placeholder
528        cmd.extend_from_slice(&[0; 4]);
529
530        // Write command name
531        cmd.extend_from_slice(name.as_bytes());
532
533        // Write parameters
534        for param in params {
535            cmd.push(b' ');
536            cmd.extend_from_slice(param.as_bytes());
537        }
538
539        cmd.push(b'\n');
540
541        // Write body if present
542        if let Some(body) = body {
543            cmd.extend_from_slice(body);
544        }
545
546        // Update size
547        let size = (cmd.len() - 4) as u32;
548        let mut size_bytes = [0; 4];
549        BigEndian::write_u32(&mut size_bytes, size);
550        cmd[0..4].copy_from_slice(&size_bytes);
551
552        cmd
553    }
554}
555
// Common NSQ command names, as they appear at the start of a command line
// (see `Command::to_bytes` for the full encodings).

/// `IDENTIFY` command name.
pub const IDENTIFY: &str = "IDENTIFY";
/// `SUB` command name.
pub const SUB: &str = "SUB";
/// `PUB` command name.
pub const PUB: &str = "PUB";
/// `DPUB` (delayed publish) command name; emitted by `Command::DelayedPublish` but
/// previously missing from this list.
pub const DPUB: &str = "DPUB";
/// `MPUB` command name.
pub const MPUB: &str = "MPUB";
/// `RDY` command name.
pub const RDY: &str = "RDY";
/// `FIN` command name.
pub const FIN: &str = "FIN";
/// `REQ` command name.
pub const REQ: &str = "REQ";
/// `TOUCH` command name.
pub const TOUCH: &str = "TOUCH";
/// `CLS` command name.
pub const CLS: &str = "CLS";
/// `NOP` command name.
pub const NOP: &str = "NOP";
/// `AUTH` command name.
pub const AUTH: &str = "AUTH";
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a full message frame (size, type, header, body) as `Message::from_bytes` expects.
    fn message_frame(timestamp: u64, attempts: u16, id: &[u8; 16], body: &[u8]) -> Vec<u8> {
        let mut frame = Vec::new();
        let size = (4 + 8 + 2 + 16 + body.len()) as u32;
        frame.extend_from_slice(&size.to_be_bytes());
        frame.extend_from_slice(&(FRAME_TYPE_MESSAGE as u32).to_be_bytes());
        frame.extend_from_slice(&timestamp.to_be_bytes());
        frame.extend_from_slice(&attempts.to_be_bytes());
        frame.extend_from_slice(id);
        frame.extend_from_slice(body);
        frame
    }

    #[test]
    fn test_identify_command() {
        let config = IdentifyConfig {
            client_id: Some("test_client".to_string()),
            hostname: Some("test_host".to_string()),
            feature_negotiation: Some(true),
            ..Default::default()
        };

        let bytes = Command::Identify(config).to_bytes().unwrap();

        // Command line, then a 4-byte JSON length, then the JSON itself.
        let header = b"IDENTIFY\n";
        assert!(bytes.starts_with(header));
        let json_start = header.len() + 4;
        let json_len = BigEndian::read_u32(&bytes[header.len()..json_start]) as usize;
        assert_eq!(json_len, bytes.len() - json_start);

        let json: serde_json::Value = serde_json::from_slice(&bytes[json_start..]).unwrap();
        assert_eq!(json["client_id"], "test_client");
        assert_eq!(json["hostname"], "test_host");
        // `None` fields are skipped during serialization.
        assert!(json.get("tls_v1").is_none());
    }

    #[test]
    fn test_publish_command() {
        let msg_body = b"test message".to_vec();
        let bytes = Command::Publish("test_topic".to_string(), msg_body.clone())
            .to_bytes()
            .unwrap();

        let header = b"PUB test_topic\n";
        assert!(bytes.starts_with(header));

        let size_end = header.len() + 4;
        let message_size = BigEndian::read_u32(&bytes[header.len()..size_end]) as usize;
        assert_eq!(message_size, msg_body.len());
        assert_eq!(&bytes[size_end..], msg_body.as_slice());
    }

    #[test]
    fn test_delayed_publish_and_auth_commands() {
        let dpub = Command::DelayedPublish("t".to_string(), b"x".to_vec(), 100)
            .to_bytes()
            .unwrap();
        assert_eq!(dpub, b"DPUB t 100\n\x00\x00\x00\x01x".to_vec());

        let auth_none = Command::Auth(None).to_bytes().unwrap();
        assert_eq!(auth_none, b"AUTH\n\x00\x00\x00\x00".to_vec());

        let auth_some = Command::Auth(Some("s".to_string())).to_bytes().unwrap();
        assert_eq!(auth_some, b"AUTH\n\x00\x00\x00\x01s".to_vec());
    }

    #[test]
    fn test_mpublish_command() {
        let bytes = Command::Mpublish("t".to_string(), vec![b"ab".to_vec(), b"c".to_vec()])
            .to_bytes()
            .unwrap();

        let header = b"MPUB t\n";
        assert!(bytes.starts_with(header));
        let payload = &bytes[header.len()..];

        // Size counts: 4 (count) + (4 + 2) + (4 + 1) = 15 bytes after the size field.
        assert_eq!(BigEndian::read_u32(&payload[0..4]), 15);
        assert_eq!(payload.len(), 4 + 15);
        assert_eq!(BigEndian::read_u32(&payload[4..8]), 2);
        assert_eq!(BigEndian::read_u32(&payload[8..12]), 2);
        assert_eq!(&payload[12..14], &b"ab"[..]);
        assert_eq!(BigEndian::read_u32(&payload[14..18]), 1);
        assert_eq!(&payload[18..], &b"c"[..]);
    }

    #[test]
    fn test_line_commands() {
        let cases = vec![
            (Command::Subscribe("t".to_string(), "c".to_string()), &b"SUB t c\n"[..]),
            (Command::Ready(5), &b"RDY 5\n"[..]),
            (Command::Finish("abc".to_string()), &b"FIN abc\n"[..]),
            (Command::Requeue("abc".to_string(), 10), &b"REQ abc 10\n"[..]),
            (Command::Touch("abc".to_string()), &b"TOUCH abc\n"[..]),
            (Command::Nop, &b"NOP\n"[..]),
            (Command::Cls, &b"CLS\n"[..]),
        ];
        for (cmd, expected) in cases {
            assert_eq!(cmd.to_bytes().unwrap(), expected.to_vec(), "{:?}", cmd);
        }
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::new(
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            b"test body".to_vec(),
            123456789,
            1,
        );

        assert_eq!(msg.attempts, 1);
        assert_eq!(msg.timestamp, 123456789);
        assert_eq!(msg.body, b"test body");
        assert!(!msg.is_auto_response_disabled());
        assert!(!msg.has_responded());
    }

    #[test]
    fn test_message_from_bytes() {
        let frame = message_frame(42, 3, b"0123456789abcdef", b"hello");
        let msg = Message::from_bytes(&frame).unwrap();
        assert_eq!(msg.timestamp, 42);
        assert_eq!(msg.attempts, 3);
        assert_eq!(msg.id_string(), "0123456789abcdef");
        assert_eq!(msg.body, b"hello");
        assert!(!msg.has_responded());

        assert!(Message::from_bytes(&frame[..20]).is_err());
    }

    #[test]
    fn test_decode_frame() {
        let frame = message_frame(7, 2, b"fedcba9876543210", b"payload");
        // `decode_frame` expects the size prefix to be stripped already.
        match Protocol::decode_frame(&frame[4..]).unwrap() {
            Frame::Message(msg) => {
                assert_eq!(msg.timestamp, 7);
                assert_eq!(msg.attempts, 2);
                assert_eq!(msg.id_string(), "fedcba9876543210");
                assert_eq!(msg.body, b"payload");
            }
            other => panic!("expected a message frame, got {:?}", other),
        }

        match Protocol::decode_frame(&[0, 0, 0, 0, b'O', b'K']).unwrap() {
            Frame::Response(payload) => assert_eq!(payload.as_slice(), &b"OK"[..]),
            other => panic!("expected a response frame, got {:?}", other),
        }

        assert!(Protocol::decode_frame(&[0, 0, 0, 9]).is_err());
    }

    #[test]
    fn test_read_frame() {
        let mut data = Vec::new();
        data.extend_from_slice(&6u32.to_be_bytes()); // frame type (4) + payload (2)
        data.extend_from_slice(&0u32.to_be_bytes());
        data.extend_from_slice(b"OK");
        data.extend_from_slice(b"trailing"); // not part of this frame

        let (frame_type, payload) = read_frame(&data).unwrap();
        assert_eq!(frame_type, FrameType::Response);
        assert_eq!(payload, &b"OK"[..]);

        // Too short for size + type, and shorter than the announced size.
        assert!(read_frame(&data[..7]).is_err());
        assert!(read_frame(&data[..9]).is_err());
    }
}