rocketmq_common/common/message/message_enum.rs

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17use std::fmt;
18
19use serde::de;
20use serde::de::Visitor;
21use serde::Deserialize;
22use serde::Deserializer;
23use serde::Serialize;
24use serde::Serializer;
25
/// Kind of a message.
///
/// `NormalMsg` is the [`Default`] variant. `Eq` and `Hash` are derived (in
/// addition to `PartialEq`) so the type can be used as a key in hash-based
/// collections, matching the derives on `MessageRequestMode`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum MessageType {
    /// Normal message (short name `"Normal"`); the default variant.
    #[default]
    NormalMsg,
    /// Transactional half message (short name `"Trans"`).
    TransMsgHalf,
    /// Transactional commit message (short name `"TransCommit"`).
    TransMsgCommit,
    /// Delay message (short name `"Delay"`).
    DelayMsg,
    /// Order message (short name `"Order"`).
    OrderMsg,
}
35
36impl MessageType {
37    pub fn get_short_name(&self) -> &'static str {
38        match self {
39            MessageType::NormalMsg => "Normal",
40            MessageType::TransMsgHalf => "Trans",
41            MessageType::TransMsgCommit => "TransCommit",
42            MessageType::DelayMsg => "Delay",
43            MessageType::OrderMsg => "Order",
44        }
45    }
46
47    pub fn get_by_short_name(short_name: &str) -> MessageType {
48        match short_name {
49            "Normal" => MessageType::NormalMsg,
50            "Trans" => MessageType::TransMsgHalf,
51            "TransCommit" => MessageType::TransMsgCommit,
52            "Delay" => MessageType::DelayMsg,
53            "Order" => MessageType::OrderMsg,
54            _ => MessageType::NormalMsg,
55        }
56    }
57}
58
/// Mode in which messages are requested.
///
/// Its textual name is `"PULL"` or `"POP"`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MessageRequestMode {
    /// Pull mode (`"PULL"`).
    Pull,
    /// Pop mode (`"POP"`).
    Pop,
}
64
65impl MessageRequestMode {
66    pub fn get_name(&self) -> &'static str {
67        match self {
68            MessageRequestMode::Pull => "PULL",
69            MessageRequestMode::Pop => "POP",
70        }
71    }
72}
73
74impl Serialize for MessageRequestMode {
75    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
76    where
77        S: Serializer,
78    {
79        serializer.serialize_str(match *self {
80            MessageRequestMode::Pull => "PULL",
81            MessageRequestMode::Pop => "POP",
82        })
83    }
84}
85
86impl<'de> Deserialize<'de> for MessageRequestMode {
87    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
88    where
89        D: Deserializer<'de>,
90    {
91        struct MessageRequestModeVisitor;
92
93        impl Visitor<'_> for MessageRequestModeVisitor {
94            type Value = MessageRequestMode;
95
96            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
97                formatter.write_str("a string representing a MessageRequestMode")
98            }
99
100            fn visit_str<E>(self, value: &str) -> Result<MessageRequestMode, E>
101            where
102                E: de::Error,
103            {
104                match value {
105                    "PULL" | "Pull" => Ok(MessageRequestMode::Pull),
106                    "POP" | "Pop" => Ok(MessageRequestMode::Pop),
107                    _ => Err(de::Error::unknown_variant(value, &["PULL/Pull", "POP/Pop"])),
108                }
109            }
110        }
111
112        deserializer.deserialize_str(MessageRequestModeVisitor)
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `MessageType` paired with the short name it must map to.
    const MESSAGE_TYPE_CASES: [(MessageType, &str); 5] = [
        (MessageType::NormalMsg, "Normal"),
        (MessageType::TransMsgHalf, "Trans"),
        (MessageType::TransMsgCommit, "TransCommit"),
        (MessageType::DelayMsg, "Delay"),
        (MessageType::OrderMsg, "Order"),
    ];

    /// Serializes `mode` to JSON and checks the exact output text.
    fn assert_serializes_to(mode: MessageRequestMode, expected_json: &str) {
        let json = serde_json::to_string(&mode).unwrap();
        assert_eq!(json, expected_json);
    }

    /// Deserializes each JSON input and checks that it yields `expected`.
    fn assert_deserializes_to(inputs: &[&str], expected: MessageRequestMode) {
        for input in inputs {
            let mode: MessageRequestMode = serde_json::from_str(input).unwrap();
            assert_eq!(mode, expected);
        }
    }

    #[test]
    fn test_get_short_name() {
        for &(message_type, short_name) in MESSAGE_TYPE_CASES.iter() {
            assert_eq!(message_type.get_short_name(), short_name);
        }
    }

    #[test]
    fn test_get_by_short_name() {
        for &(message_type, short_name) in MESSAGE_TYPE_CASES.iter() {
            assert_eq!(MessageType::get_by_short_name(short_name), message_type);
        }
        // Unknown names fall back to the normal message type.
        assert_eq!(
            MessageType::get_by_short_name("Invalid"),
            MessageType::NormalMsg
        );
    }

    #[test]
    fn test_get_name() {
        let cases = [
            (MessageRequestMode::Pull, "PULL"),
            (MessageRequestMode::Pop, "POP"),
        ];
        for &(mode, name) in cases.iter() {
            assert_eq!(mode.get_name(), name);
        }
    }

    #[test]
    fn serialize_message_request_mode_pull() {
        assert_serializes_to(MessageRequestMode::Pull, "\"PULL\"");
    }

    #[test]
    fn serialize_message_request_mode_pop() {
        assert_serializes_to(MessageRequestMode::Pop, "\"POP\"");
    }

    #[test]
    fn deserialize_message_request_mode_pull() {
        assert_deserializes_to(&["\"PULL\"", "\"Pull\""], MessageRequestMode::Pull);
    }

    #[test]
    fn deserialize_message_request_mode_pop() {
        assert_deserializes_to(&["\"POP\"", "\"Pop\""], MessageRequestMode::Pop);
    }

    #[test]
    fn deserialize_message_request_mode_invalid() {
        let outcome = serde_json::from_str::<MessageRequestMode>("\"INVALID\"");
        assert!(outcome.is_err());
    }
}