rocketmq_remoting/protocol/body/
subscription_group_wrapper.rs1use std::sync::Arc;
16
17use cheetah_string::CheetahString;
18use dashmap::DashMap;
19use serde::de;
20use serde::Deserialize;
21use serde::Deserializer;
22use serde::Serialize;
23use serde::Serializer;
24
25use crate::protocol::subscription::subscription_group_config::SubscriptionGroupConfig;
26use crate::protocol::DataVersion;
27
28#[derive(Debug, Clone)]
29pub struct SubscriptionGroupWrapper {
30 pub subscription_group_table: DashMap<CheetahString, Arc<SubscriptionGroupConfig>>,
31 pub forbidden_table: DashMap<CheetahString, std::collections::HashMap<CheetahString, i32>>,
32 pub data_version: DataVersion,
33}
34
35impl Serialize for SubscriptionGroupWrapper {
37 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
38 where
39 S: Serializer,
40 {
41 use serde::ser::SerializeStruct;
42
43 let mut state = serializer.serialize_struct("SubscriptionGroupWrapper", 3)?;
44
45 let table: std::collections::HashMap<CheetahString, SubscriptionGroupConfig> = self
47 .subscription_group_table
48 .iter()
49 .map(|entry| (entry.key().clone(), (**entry.value()).clone()))
50 .collect();
51 let forbidden_table: std::collections::HashMap<CheetahString, std::collections::HashMap<CheetahString, i32>> =
52 self.forbidden_table
53 .iter()
54 .map(|entry| (entry.key().clone(), entry.value().clone()))
55 .collect();
56 state.serialize_field("subscriptionGroupTable", &table)?;
57 state.serialize_field("forbiddenTable", &forbidden_table)?;
58 state.serialize_field("dataVersion", &self.data_version)?;
59 state.end()
60 }
61}
62
63impl<'de> Deserialize<'de> for SubscriptionGroupWrapper {
65 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
66 where
67 D: Deserializer<'de>,
68 {
69 use std::fmt;
70
71 use serde::de::MapAccess;
72 use serde::de::Visitor;
73
74 #[derive(Deserialize)]
75 #[serde(field_identifier, rename_all = "camelCase")]
76 enum Field {
77 SubscriptionGroupTable,
78 ForbiddenTable,
79 DataVersion,
80 #[serde(other)]
81 Ignore,
82 }
83
84 struct SubscriptionGroupWrapperVisitor;
85
86 impl<'de> Visitor<'de> for SubscriptionGroupWrapperVisitor {
87 type Value = SubscriptionGroupWrapper;
88
89 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
90 formatter.write_str("struct SubscriptionGroupWrapper")
91 }
92
93 fn visit_map<V>(self, mut map: V) -> Result<SubscriptionGroupWrapper, V::Error>
94 where
95 V: MapAccess<'de>,
96 {
97 let mut subscription_group_table: Option<
98 std::collections::HashMap<CheetahString, SubscriptionGroupConfig>,
99 > = None;
100 let mut forbidden_table: Option<
101 std::collections::HashMap<CheetahString, std::collections::HashMap<CheetahString, i32>>,
102 > = None;
103 let mut data_version: Option<DataVersion> = None;
104
105 while let Some(key) = map.next_key()? {
106 match key {
107 Field::SubscriptionGroupTable => {
108 if subscription_group_table.is_some() {
109 return Err(de::Error::duplicate_field("subscriptionGroupTable"));
110 }
111 subscription_group_table = Some(map.next_value()?);
112 }
113 Field::ForbiddenTable => {
114 if forbidden_table.is_some() {
115 return Err(de::Error::duplicate_field("forbiddenTable"));
116 }
117 forbidden_table = Some(map.next_value()?);
118 }
119 Field::DataVersion => {
120 if data_version.is_some() {
121 return Err(de::Error::duplicate_field("dataVersion"));
122 }
123 data_version = Some(map.next_value()?);
124 }
125 Field::Ignore => {
126 let _: de::IgnoredAny = map.next_value()?;
127 }
128 }
129 }
130
131 let subscription_group_table =
132 subscription_group_table.ok_or_else(|| de::Error::missing_field("subscriptionGroupTable"))?;
133 let forbidden_table = forbidden_table.unwrap_or_default();
134 let data_version = data_version.ok_or_else(|| de::Error::missing_field("dataVersion"))?;
135
136 let dash_map = DashMap::new();
138 for (key, value) in subscription_group_table {
139 dash_map.insert(key, Arc::new(value));
140 }
141 let forbidden_dash_map = DashMap::new();
142 for (key, value) in forbidden_table {
143 forbidden_dash_map.insert(key, value);
144 }
145
146 Ok(SubscriptionGroupWrapper {
147 subscription_group_table: dash_map,
148 forbidden_table: forbidden_dash_map,
149 data_version,
150 })
151 }
152 }
153
154 const FIELDS: &[&str] = &["subscriptionGroupTable", "forbiddenTable", "dataVersion"];
155 deserializer.deserialize_struct("SubscriptionGroupWrapper", FIELDS, SubscriptionGroupWrapperVisitor)
156 }
157}
158
159impl Default for SubscriptionGroupWrapper {
160 fn default() -> Self {
161 Self::new()
162 }
163}
164
165impl SubscriptionGroupWrapper {
166 pub fn new() -> Self {
167 SubscriptionGroupWrapper {
168 subscription_group_table: DashMap::with_capacity(1024),
169 forbidden_table: DashMap::with_capacity(1024),
170 data_version: DataVersion::default(),
171 }
172 }
173
174 pub fn get_subscription_group_table(&self) -> &DashMap<CheetahString, Arc<SubscriptionGroupConfig>> {
175 &self.subscription_group_table
176 }
177
178 pub fn set_subscription_group_table(&mut self, table: DashMap<CheetahString, Arc<SubscriptionGroupConfig>>) {
179 self.subscription_group_table = table;
180 }
181
182 pub fn forbidden_table(&self) -> &DashMap<CheetahString, std::collections::HashMap<CheetahString, i32>> {
183 &self.forbidden_table
184 }
185
186 pub fn set_forbidden_table(
187 &mut self,
188 table: DashMap<CheetahString, std::collections::HashMap<CheetahString, i32>>,
189 ) {
190 self.forbidden_table = table;
191 }
192
193 pub fn data_version(&self) -> &DataVersion {
194 &self.data_version
195 }
196
197 pub fn set_data_version(&mut self, version: DataVersion) {
198 self.data_version = version;
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload containing every field the deserializer understands.
    const FULL_JSON: &str = r#"{
        "subscriptionGroupTable": { "group-a": { "groupName": "group-a" } },
        "forbiddenTable": { "group-a": { "topic-a": 1 } },
        "dataVersion": { "timestamp": 1, "counter": 1, "stateVersion": 0 }
    }"#;

    /// Returns the forbidden-table value stored for `group` / `topic`, if any.
    fn forbidden_entry(wrapper: &SubscriptionGroupWrapper, group: &str, topic: &str) -> Option<i32> {
        wrapper
            .forbidden_table()
            .get(group)
            .and_then(|topics| topics.value().get(topic).copied())
    }

    #[test]
    fn new_creates_wrapper_with_default_values() {
        let wrapper = SubscriptionGroupWrapper::new();

        assert!(wrapper.subscription_group_table.is_empty());
        assert!(wrapper.forbidden_table.is_empty());
        assert!(wrapper.data_version.timestamp <= DataVersion::default().timestamp);
    }

    #[test]
    fn default_creates_empty_wrapper() {
        let wrapper = SubscriptionGroupWrapper::default();

        assert!(wrapper.get_subscription_group_table().is_empty());
        assert!(wrapper.forbidden_table().is_empty());
    }

    #[test]
    fn get_subscription_group_table_returns_reference() {
        let wrapper = SubscriptionGroupWrapper::new();
        wrapper
            .subscription_group_table
            .insert("test_group".into(), Arc::new(SubscriptionGroupConfig::default()));

        let table = wrapper.get_subscription_group_table();
        assert_eq!(table.len(), 1);
        assert!(table.contains_key("test_group"));
    }

    #[test]
    fn deserialize_wrapper_accepts_forbidden_table() {
        let wrapper: SubscriptionGroupWrapper =
            serde_json::from_str(FULL_JSON).expect("subscription group wrapper should deserialize");

        assert!(wrapper.get_subscription_group_table().contains_key("group-a"));
        assert_eq!(forbidden_entry(&wrapper, "group-a", "topic-a"), Some(1));
    }

    #[test]
    fn deserialize_wrapper_defaults_missing_forbidden_table() {
        let json = r#"{
            "subscriptionGroupTable": { "group-a": { "groupName": "group-a" } },
            "dataVersion": { "timestamp": 1, "counter": 1, "stateVersion": 0 }
        }"#;

        let wrapper: SubscriptionGroupWrapper =
            serde_json::from_str(json).expect("forbiddenTable should be optional");

        assert!(wrapper.get_subscription_group_table().contains_key("group-a"));
        assert!(wrapper.forbidden_table().is_empty());
    }

    #[test]
    fn deserialize_wrapper_ignores_unknown_fields() {
        let json = r#"{
            "unknownField": [1, 2, 3],
            "subscriptionGroupTable": {},
            "dataVersion": { "timestamp": 1, "counter": 1, "stateVersion": 0 }
        }"#;

        let wrapper: SubscriptionGroupWrapper =
            serde_json::from_str(json).expect("unknown fields should be skipped");

        assert!(wrapper.get_subscription_group_table().is_empty());
    }

    #[test]
    fn deserialize_wrapper_rejects_missing_required_fields() {
        let no_table = r#"{ "dataVersion": { "timestamp": 1, "counter": 1, "stateVersion": 0 } }"#;
        let err = serde_json::from_str::<SubscriptionGroupWrapper>(no_table).unwrap_err();
        assert!(err.to_string().contains("missing field `subscriptionGroupTable`"), "{err}");

        let no_version = r#"{ "subscriptionGroupTable": {} }"#;
        let err = serde_json::from_str::<SubscriptionGroupWrapper>(no_version).unwrap_err();
        assert!(err.to_string().contains("missing field `dataVersion`"), "{err}");
    }

    #[test]
    fn deserialize_wrapper_rejects_duplicate_field() {
        let json = r#"{
            "subscriptionGroupTable": {},
            "forbiddenTable": {},
            "forbiddenTable": {},
            "dataVersion": { "timestamp": 1, "counter": 1, "stateVersion": 0 }
        }"#;

        let err = serde_json::from_str::<SubscriptionGroupWrapper>(json).unwrap_err();
        assert!(err.to_string().contains("duplicate field `forbiddenTable`"), "{err}");
    }

    #[test]
    fn serialize_round_trip_preserves_tables() {
        let wrapper = SubscriptionGroupWrapper::new();
        wrapper
            .subscription_group_table
            .insert("group-a".into(), Arc::new(SubscriptionGroupConfig::default()));
        let mut topics: std::collections::HashMap<CheetahString, i32> = std::collections::HashMap::new();
        topics.insert("topic-a".into(), 1);
        wrapper.forbidden_table.insert("group-a".into(), topics);

        let json = serde_json::to_string(&wrapper).expect("wrapper should serialize");
        assert!(json.contains("\"subscriptionGroupTable\""));
        assert!(json.contains("\"forbiddenTable\""));
        assert!(json.contains("\"dataVersion\""));

        let restored: SubscriptionGroupWrapper =
            serde_json::from_str(&json).expect("serialized wrapper should deserialize");
        assert!(restored.get_subscription_group_table().contains_key("group-a"));
        assert_eq!(forbidden_entry(&restored, "group-a", "topic-a"), Some(1));
        assert!(restored.data_version.timestamp == wrapper.data_version.timestamp);
    }
}