1use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
2use serde::{Deserialize, Serialize};
3use std::io::{Cursor, Read};
4use thiserror::Error;
5
6use crate::error::{Error, Result};
7
/// Protocol magic announcing version 2 of the NSQ wire protocol (`"  V2"`).
pub const MAGIC_V2: &[u8] = b"  V2";
/// Response payload that marks a heartbeat.
pub const HEARTBEAT: &[u8] = b"_heartbeat_";
/// Response payload signalling success.
pub const OK: &[u8] = b"OK";

/// Frame-type code for a response frame.
pub const FRAME_TYPE_RESPONSE: i32 = 0;
/// Frame-type code for an error frame.
pub const FRAME_TYPE_ERROR: i32 = 1;
/// Frame-type code for a message frame.
pub const FRAME_TYPE_MESSAGE: i32 = 2;
15
16#[derive(Debug, Clone, PartialEq)]
18pub enum Command {
19    Identify(IdentifyConfig),
21    Subscribe(String, String),
23    Publish(String, Vec<u8>),
25    DelayedPublish(String, Vec<u8>, u32),
27    Mpublish(String, Vec<Vec<u8>>),
29    Ready(u32),
31    Finish(String),
33    Requeue(String, u32),
35    Touch(String),
37    Nop,
39    Cls,
41    Auth(Option<String>),
43}
44
45impl Command {
46    pub fn to_bytes(&self) -> Result<Vec<u8>> {
48        let mut buf = Vec::new();
49        match self {
50            Command::Identify(config) => {
51                buf.extend_from_slice(b"IDENTIFY\n");
52                let json = serde_json::to_string(config)?;
53                buf.write_u32::<BigEndian>(json.len() as u32)?;
54                buf.extend_from_slice(json.as_bytes());
55            }
56            Command::Subscribe(topic, channel) => {
57                let cmd = format!("SUB {} {}\n", topic, channel);
58                buf.extend_from_slice(cmd.as_bytes());
59            }
60            Command::Publish(topic, body) => {
61                let cmd = format!("PUB {}\n", topic);
62                buf.extend_from_slice(cmd.as_bytes());
63                buf.write_u32::<BigEndian>(body.len() as u32)?;
64                buf.extend_from_slice(body.as_slice());
65            }
66            Command::DelayedPublish(topic, body, delay) => {
67                let cmd = format!("DPUB {} {}\n", topic, delay);
68                buf.extend_from_slice(cmd.as_bytes());
69                buf.write_u32::<BigEndian>(body.len() as u32)?;
70                buf.extend_from_slice(body.as_slice());
71            }
72            Command::Mpublish(topic, bodies) => {
73                let cmd = format!("MPUB {}\n", topic);
74                buf.extend_from_slice(cmd.as_bytes());
75
76                let mut total_size = 4;
78                for body in bodies {
79                    total_size += 4 + body.len();
80                }
81
82                buf.write_u32::<BigEndian>(total_size as u32)?;
83                buf.write_u32::<BigEndian>(bodies.len() as u32)?;
84
85                for body in bodies {
86                    buf.write_u32::<BigEndian>(body.len() as u32)?;
87                    buf.extend_from_slice(body);
88                }
89            }
90            Command::Ready(count) => {
91                let cmd = format!("RDY {}\n", count);
92                buf.extend_from_slice(cmd.as_bytes());
93            }
94            Command::Finish(id) => {
95                let cmd = format!("FIN {}\n", id);
96                buf.extend_from_slice(cmd.as_bytes());
97            }
98            Command::Requeue(id, delay) => {
99                let cmd = format!("REQ {} {}\n", id, delay);
100                buf.extend_from_slice(cmd.as_bytes());
101            }
102            Command::Touch(id) => {
103                let cmd = format!("TOUCH {}\n", id);
104                buf.extend_from_slice(cmd.as_bytes());
105            }
106            Command::Nop => {
107                buf.extend_from_slice(b"NOP\n");
108            }
109            Command::Cls => {
110                buf.extend_from_slice(b"CLS\n");
111            }
112            Command::Auth(secret) => {
113                buf.extend_from_slice(b"AUTH\n");
114                if let Some(s) = secret {
115                    buf.write_u32::<BigEndian>(s.len() as u32)?;
116                    buf.extend_from_slice(s.as_bytes());
117                } else {
118                    buf.write_u32::<BigEndian>(0)?;
119                }
120            }
121        }
122
123        Ok(buf)
124    }
125}
126
/// A message decoded from a message-type frame.
#[derive(Debug, Clone)]
pub struct Message {
    /// Raw message ID bytes. The decoders copy exactly 16 bytes into this field.
    pub id: Vec<u8>,
    /// Timestamp as read from the frame (big-endian `u64`).
    /// Nanoseconds per the NSQ spec; it is not interpreted here.
    pub timestamp: u64,
    /// Attempt counter as read from the frame (big-endian `u16`).
    pub attempts: u16,
    /// Message payload: every byte after the fixed header.
    pub body: Vec<u8>,
}
139
140impl Message {
141    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
143        if bytes.len() < 26 {
144            return Err(Error::Protocol(ProtocolError::Other(
145                "消息大小不足".to_string(),
146            )));
147        }
148
149        let mut cursor = Cursor::new(bytes);
150
151        cursor.set_position(4);
153
154        let frame_type = cursor.read_u32::<BigEndian>()?;
156        if frame_type != 2 {
157            return Err(Error::Protocol(ProtocolError::Other(format!(
158                "无效的帧类型: {}",
159                frame_type
160            ))));
161        }
162
163        let timestamp = cursor.read_u64::<BigEndian>()?;
165
166        let attempts = cursor.read_u16::<BigEndian>()?;
168
169        let mut id_bytes = [0u8; 16];
171        cursor.read_exact(&mut id_bytes)?;
172        let id = id_bytes.to_vec();
173
174        let mut body = Vec::new();
176        cursor.read_to_end(&mut body)?;
177
178        Ok(Self {
179            id,
180            timestamp,
181            attempts,
182            body,
183        })
184    }
185}
186
187#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
189pub struct IdentifyConfig {
190    #[serde(skip_serializing_if = "Option::is_none")]
192    pub client_id: Option<String>,
193
194    #[serde(skip_serializing_if = "Option::is_none")]
196    pub hostname: Option<String>,
197
198    #[serde(skip_serializing_if = "Option::is_none")]
200    pub feature_negotiation: Option<bool>,
201
202    #[serde(skip_serializing_if = "Option::is_none")]
204    pub heartbeat_interval: Option<i32>,
205
206    #[serde(skip_serializing_if = "Option::is_none")]
208    pub output_buffer_size: Option<i32>,
209
210    #[serde(skip_serializing_if = "Option::is_none")]
212    pub output_buffer_timeout: Option<i32>,
213
214    #[serde(skip_serializing_if = "Option::is_none")]
216    pub tls_v1: Option<bool>,
217
218    #[serde(skip_serializing_if = "Option::is_none")]
220    pub snappy: Option<bool>,
221
222    #[serde(skip_serializing_if = "Option::is_none")]
224    pub sample_rate: Option<i32>,
225
226    #[serde(skip_serializing_if = "Option::is_none")]
228    pub user_agent: Option<String>,
229
230    #[serde(skip_serializing_if = "Option::is_none")]
232    pub msg_timeout: Option<i32>,
233}
234
235impl Default for IdentifyConfig {
236    fn default() -> Self {
237        let hostname = hostname::get()
238            .ok()
239            .and_then(|h| h.into_string().ok())
240            .unwrap_or_else(|| "unknown".to_string());
241
242        Self {
243            client_id: Some(hostname.clone()),
244            hostname: Some(hostname),
245            feature_negotiation: Some(true),
246            heartbeat_interval: Some(30000),
247            output_buffer_size: Some(16384),
248            output_buffer_timeout: Some(250),
249            tls_v1: None,
250            snappy: None,
251            sample_rate: None,
252            user_agent: Some(format!("rust-nsq/{}", env!("CARGO_PKG_VERSION"))),
253            msg_timeout: Some(60000),
254        }
255    }
256}
257
/// Kind of frame received from nsqd, identified by the `u32` code that follows the size prefix.
///
/// Codes: `0` = `Response`, `1` = `Error`, `2` = `Message`.
///
/// This is a fieldless enum, so it is `Copy` and `Eq`/`Hash`. It can be passed
/// by value and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FrameType {
    /// Code 0: a response frame.
    Response,
    /// Code 1: an error frame.
    Error,
    /// Code 2: a message frame.
    Message,
}
268
269impl TryFrom<u32> for FrameType {
270    type Error = Error;
271
272    fn try_from(value: u32) -> Result<Self> {
273        match value {
274            0 => Ok(FrameType::Response),
275            1 => Ok(FrameType::Error),
276            2 => Ok(FrameType::Message),
277            _ => Err(Error::Protocol(ProtocolError::Other(format!(
278                "未知帧类型: {}",
279                value
280            )))),
281        }
282    }
283}
284
285pub fn read_frame(data: &[u8]) -> Result<(FrameType, &[u8])> {
287    if data.len() < 8 {
288        return Err(Error::Protocol(ProtocolError::Other(
289            "帧数据不完整".to_string(),
290        )));
291    }
292
293    let mut cursor = Cursor::new(data);
294    let size = cursor.read_u32::<BigEndian>()?;
295
296    if data.len() < (size as usize + 4) {
297        return Err(Error::Protocol(ProtocolError::Other(
298            "帧数据不完整".to_string(),
299        )));
300    }
301
302    let frame_type_raw = cursor.read_u32::<BigEndian>()?;
303    let frame_type = FrameType::try_from(frame_type_raw)?;
304
305    Ok((frame_type, &data[8..(size as usize + 4)]))
307}
308
309#[derive(Debug, Error)]
310pub enum ProtocolError {
311    #[error("IO error: {0}")]
312    Io(#[from] std::io::Error),
313    #[error("Invalid frame size")]
314    InvalidFrameSize,
315    #[error("Invalid magic version")]
316    InvalidMagicVersion,
317    #[error("Invalid frame type: {0}")]
318    InvalidFrameType(i32),
319    #[error("Protocol error: {0}")]
320    Other(String),
321}
322
323#[derive(Debug)]
324pub enum Frame {
325    Response(Vec<u8>),
326    Error(Vec<u8>),
327    Message(Message),
328}
329
330pub struct Protocol;
331
332impl Protocol {
333    pub fn write_command(cmd: &[u8], params: &[&[u8]]) -> Vec<u8> {
334        let mut buf = Vec::new();
335        buf.extend_from_slice(cmd);
336
337        if !params.is_empty() {
338            buf.push(b' ');
339            for (i, param) in params.iter().enumerate() {
340                if i > 0 {
341                    buf.push(b' ');
342                }
343                buf.extend_from_slice(param);
344            }
345        }
346
347        buf.extend_from_slice(b"\n");
348        buf
349    }
350
351    pub fn decode_message(data: &[u8]) -> Result<Message> {
352        if data.len() < 4 {
353            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
354        }
355
356        let timestamp = BigEndian::read_u64(&data[0..8]);
357        let attempts = BigEndian::read_u16(&data[8..10]);
358        let id = data[10..26].to_vec();
359        let body = data[26..].to_vec();
360
361        Ok(Message {
362            timestamp,
363            attempts,
364            id,
365            body,
366        })
367    }
368
369    pub fn decode_frame(data: &[u8]) -> Result<Frame> {
370        if data.len() < 4 {
371            return Err(Error::Protocol(ProtocolError::InvalidFrameSize));
372        }
373
374        let frame_type = BigEndian::read_i32(&data[0..4]);
375        let frame_data = &data[4..];
376
377        match frame_type {
378            FRAME_TYPE_RESPONSE => Ok(Frame::Response(frame_data.to_vec())),
379            FRAME_TYPE_ERROR => Ok(Frame::Error(frame_data.to_vec())),
380            FRAME_TYPE_MESSAGE => {
381                let msg = Self::decode_message(frame_data)?;
382                Ok(Frame::Message(msg))
383            }
384            _ => Err(Error::Protocol(ProtocolError::InvalidFrameType(frame_type))),
385        }
386    }
387
388    pub fn encode_command(name: &str, body: Option<&[u8]>, params: &[&str]) -> Vec<u8> {
389        let mut cmd = Vec::new();
390
391        cmd.extend_from_slice(&[0; 4]);
393
394        cmd.extend_from_slice(name.as_bytes());
396
397        for param in params {
399            cmd.push(b' ');
400            cmd.extend_from_slice(param.as_bytes());
401        }
402
403        cmd.push(b'\n');
404
405        if let Some(body) = body {
407            cmd.extend_from_slice(body);
408        }
409
410        let size = (cmd.len() - 4) as u32;
412        let mut size_bytes = [0; 4];
413        BigEndian::write_u32(&mut size_bytes, size);
414        cmd[0..4].copy_from_slice(&size_bytes);
415
416        cmd
417    }
418}
419
/// `IDENTIFY`: send client metadata as a JSON body.
pub const IDENTIFY: &str = "IDENTIFY";
/// `SUB`: subscribe to a topic/channel.
pub const SUB: &str = "SUB";
/// `PUB`: publish a single message.
pub const PUB: &str = "PUB";
/// `DPUB`: publish a single message with a delay (the name `Command::DelayedPublish` emits).
pub const DPUB: &str = "DPUB";
/// `MPUB`: publish several messages at once.
pub const MPUB: &str = "MPUB";
/// `RDY`: update the number of in-flight messages the client will accept.
pub const RDY: &str = "RDY";
/// `FIN`: finish a message.
pub const FIN: &str = "FIN";
/// `REQ`: requeue a message.
pub const REQ: &str = "REQ";
/// `TOUCH`: touch an in-flight message.
pub const TOUCH: &str = "TOUCH";
/// `CLS`: begin a clean close.
pub const CLS: &str = "CLS";
/// `NOP`: no-op.
pub const NOP: &str = "NOP";
/// `AUTH`: authenticate with a length-prefixed secret.
pub const AUTH: &str = "AUTH";
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads the big-endian `u32` starting at `offset` in `bytes`.
    fn be_u32_at(bytes: &[u8], offset: usize) -> u32 {
        BigEndian::read_u32(&bytes[offset..offset + 4])
    }

    #[test]
    fn test_identify_command() {
        let config = IdentifyConfig {
            client_id: Some("test_client".to_string()),
            hostname: Some("test_host".to_string()),
            feature_negotiation: Some(true),
            ..Default::default()
        };

        let bytes = Command::Identify(config.clone()).to_bytes().unwrap();

        let header: &[u8] = b"IDENTIFY\n";
        assert!(bytes.starts_with(header));

        // The JSON payload is prefixed with its exact length and must decode
        // back into the same config.
        let json_len = be_u32_at(&bytes, header.len()) as usize;
        let json = &bytes[header.len() + 4..];
        assert_eq!(json.len(), json_len);
        let decoded: IdentifyConfig = serde_json::from_slice(json).unwrap();
        assert_eq!(decoded, config);
    }

    #[test]
    fn test_publish_command() {
        let topic = "test_topic".to_string();
        let message = b"test message".to_vec();

        let bytes = Command::Publish(topic, message.clone()).to_bytes().unwrap();

        let header: &[u8] = b"PUB test_topic\n";
        assert!(bytes.starts_with(header));

        let message_size = be_u32_at(&bytes, header.len());
        assert_eq!(message_size as usize, message.len());

        let actual_message = &bytes[header.len() + 4..];
        assert_eq!(actual_message, message.as_slice());
    }

    #[test]
    fn test_mpublish_command_layout() {
        let bodies = vec![b"a".to_vec(), b"bcd".to_vec()];
        let bytes = Command::Mpublish("t".to_string(), bodies).to_bytes().unwrap();

        let mut expected = b"MPUB t\n".to_vec();
        // Body size = 4 (count) + (4 + 1) + (4 + 3) = 16.
        expected.extend_from_slice(&16u32.to_be_bytes());
        expected.extend_from_slice(&2u32.to_be_bytes());
        expected.extend_from_slice(&1u32.to_be_bytes());
        expected.extend_from_slice(b"a");
        expected.extend_from_slice(&3u32.to_be_bytes());
        expected.extend_from_slice(b"bcd");
        assert_eq!(bytes, expected);
    }

    #[test]
    fn test_auth_without_secret_sends_zero_length() {
        let bytes = Command::Auth(None).to_bytes().unwrap();
        assert_eq!(bytes, b"AUTH\n\0\0\0\0".to_vec());
    }

    #[test]
    fn test_read_frame_response_ignores_trailing_bytes() {
        // size = frame type (4) + payload "OK" (2) = 6.
        let mut data = 6u32.to_be_bytes().to_vec();
        data.extend_from_slice(&0u32.to_be_bytes());
        data.extend_from_slice(b"OK");
        data.extend_from_slice(b"trailing");

        let (frame_type, payload) = read_frame(&data).unwrap();
        assert_eq!(frame_type, FrameType::Response);
        assert_eq!(payload, b"OK");
    }
}