Skip to main content

rocketmq_common/common/message/
message_enum.rs

// Copyright 2023 The RocketMQ Rust Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt;
16
17use serde::de;
18use serde::de::Visitor;
19use serde::Deserialize;
20use serde::Deserializer;
21use serde::Serialize;
22use serde::Serializer;
23
/// Type of a message, identified textually by a short name.
///
/// [`MessageType::get_short_name`] renders a variant as its short name and
/// [`MessageType::get_by_short_name`] performs the reverse lookup.
///
/// `Eq` and `Hash` are derived (matching `MessageRequestMode`) so the type can be
/// used as a key in hashed collections.
#[derive(Debug, PartialEq, Eq, Hash, Copy, Clone, Default)]
pub enum MessageType {
    /// Short name `"Normal"`. This is the default variant and the fallback for
    /// unrecognized short names.
    #[default]
    NormalMsg,
    /// Short name `"Trans"`.
    TransMsgHalf,
    /// Short name `"TransCommit"`.
    TransMsgCommit,
    /// Short name `"Delay"`.
    DelayMsg,
    /// Short name `"Order"`.
    OrderMsg,
}

impl MessageType {
    /// Returns the short textual name of this message type.
    ///
    /// Every variant maps to a distinct name, and
    /// [`MessageType::get_by_short_name`] maps that name back to the same variant.
    pub fn get_short_name(&self) -> &'static str {
        match self {
            MessageType::NormalMsg => "Normal",
            MessageType::TransMsgHalf => "Trans",
            MessageType::TransMsgCommit => "TransCommit",
            MessageType::DelayMsg => "Delay",
            MessageType::OrderMsg => "Order",
        }
    }

    /// Looks up a message type by its short name.
    ///
    /// Matching is exact and case-sensitive. Any unrecognized input, including
    /// the empty string, yields [`MessageType::NormalMsg`] instead of an error.
    pub fn get_by_short_name(short_name: &str) -> MessageType {
        match short_name {
            "Normal" => MessageType::NormalMsg,
            "Trans" => MessageType::TransMsgHalf,
            "TransCommit" => MessageType::TransMsgCommit,
            "Delay" => MessageType::DelayMsg,
            "Order" => MessageType::OrderMsg,
            // Deliberate lenient fallback rather than a failure.
            _ => MessageType::NormalMsg,
        }
    }
}

/// Mode used when requesting messages, named `"PULL"` or `"POP"`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MessageRequestMode {
    /// Named `"PULL"`.
    Pull,
    /// Named `"POP"`.
    Pop,
}

impl MessageRequestMode {
    /// Returns the upper-case name of this mode: `"PULL"` or `"POP"`.
    pub fn get_name(&self) -> &'static str {
        match self {
            Self::Pull => "PULL",
            Self::Pop => "POP",
        }
    }
}

72impl Serialize for MessageRequestMode {
73    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
74    where
75        S: Serializer,
76    {
77        serializer.serialize_str(match *self {
78            MessageRequestMode::Pull => "PULL",
79            MessageRequestMode::Pop => "POP",
80        })
81    }
82}
83
84impl<'de> Deserialize<'de> for MessageRequestMode {
85    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
86    where
87        D: Deserializer<'de>,
88    {
89        struct MessageRequestModeVisitor;
90
91        impl Visitor<'_> for MessageRequestModeVisitor {
92            type Value = MessageRequestMode;
93
94            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
95                formatter.write_str("a string representing a MessageRequestMode")
96            }
97
98            fn visit_str<E>(self, value: &str) -> Result<MessageRequestMode, E>
99            where
100                E: de::Error,
101            {
102                match value {
103                    "PULL" | "Pull" => Ok(MessageRequestMode::Pull),
104                    "POP" | "Pop" => Ok(MessageRequestMode::Pop),
105                    _ => Err(de::Error::unknown_variant(value, &["PULL/Pull", "POP/Pop"])),
106                }
107            }
108        }
109
110        deserializer.deserialize_str(MessageRequestModeVisitor)
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `MessageType` paired with its expected short name.
    const SHORT_NAMES: [(MessageType, &str); 5] = [
        (MessageType::NormalMsg, "Normal"),
        (MessageType::TransMsgHalf, "Trans"),
        (MessageType::TransMsgCommit, "TransCommit"),
        (MessageType::DelayMsg, "Delay"),
        (MessageType::OrderMsg, "Order"),
    ];

    /// Serializes a request mode to a JSON string, panicking on failure.
    fn to_json(mode: MessageRequestMode) -> String {
        serde_json::to_string(&mode).unwrap()
    }

    /// Attempts to deserialize a request mode from a JSON document.
    fn from_json(json: &str) -> Result<MessageRequestMode, serde_json::Error> {
        serde_json::from_str(json)
    }

    #[test]
    fn test_get_short_name() {
        for (message_type, expected) in SHORT_NAMES {
            assert_eq!(message_type.get_short_name(), expected);
        }
    }

    #[test]
    fn test_get_by_short_name() {
        for (expected, short_name) in SHORT_NAMES {
            assert_eq!(MessageType::get_by_short_name(short_name), expected);
        }
        // Unrecognized names fall back to the normal type rather than failing.
        assert_eq!(MessageType::get_by_short_name("Invalid"), MessageType::NormalMsg);
    }

    #[test]
    fn test_get_name() {
        assert_eq!(MessageRequestMode::Pull.get_name(), "PULL");
        assert_eq!(MessageRequestMode::Pop.get_name(), "POP");
    }

    #[test]
    fn serialize_message_request_mode_pull() {
        assert_eq!(to_json(MessageRequestMode::Pull), "\"PULL\"");
    }

    #[test]
    fn serialize_message_request_mode_pop() {
        assert_eq!(to_json(MessageRequestMode::Pop), "\"POP\"");
    }

    #[test]
    fn deserialize_message_request_mode_pull() {
        for json in ["\"PULL\"", "\"Pull\""] {
            assert_eq!(from_json(json).unwrap(), MessageRequestMode::Pull);
        }
    }

    #[test]
    fn deserialize_message_request_mode_pop() {
        for json in ["\"POP\"", "\"Pop\""] {
            assert_eq!(from_json(json).unwrap(), MessageRequestMode::Pop);
        }
    }

    #[test]
    fn deserialize_message_request_mode_invalid() {
        assert!(from_json("\"INVALID\"").is_err());
    }
}