Skip to main content

rocketmq_remoting/protocol/body/subscription_group_wrapper.rs

1// Copyright 2023 The RocketMQ Rust Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::sync::Arc;
16
17use cheetah_string::CheetahString;
18use dashmap::DashMap;
19use serde::de;
20use serde::Deserialize;
21use serde::Deserializer;
22use serde::Serialize;
23use serde::Serializer;
24
25use crate::protocol::subscription::subscription_group_config::SubscriptionGroupConfig;
26use crate::protocol::DataVersion;
27
28#[derive(Debug, Clone)]
29pub struct SubscriptionGroupWrapper {
30    pub subscription_group_table: DashMap<CheetahString, Arc<SubscriptionGroupConfig>>,
31    pub forbidden_table: DashMap<CheetahString, std::collections::HashMap<CheetahString, i32>>,
32    pub data_version: DataVersion,
33}
34
35// Custom serialization to handle Arc inside DashMap
36impl Serialize for SubscriptionGroupWrapper {
37    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
38    where
39        S: Serializer,
40    {
41        use serde::ser::SerializeStruct;
42
43        let mut state = serializer.serialize_struct("SubscriptionGroupWrapper", 3)?;
44
45        // Serialize DashMap by converting Arc values to direct values
46        let table: std::collections::HashMap<CheetahString, SubscriptionGroupConfig> = self
47            .subscription_group_table
48            .iter()
49            .map(|entry| (entry.key().clone(), (**entry.value()).clone()))
50            .collect();
51        let forbidden_table: std::collections::HashMap<CheetahString, std::collections::HashMap<CheetahString, i32>> =
52            self.forbidden_table
53                .iter()
54                .map(|entry| (entry.key().clone(), entry.value().clone()))
55                .collect();
56        state.serialize_field("subscriptionGroupTable", &table)?;
57        state.serialize_field("forbiddenTable", &forbidden_table)?;
58        state.serialize_field("dataVersion", &self.data_version)?;
59        state.end()
60    }
61}
62
63// Custom deserialization to wrap values in Arc
64impl<'de> Deserialize<'de> for SubscriptionGroupWrapper {
65    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
66    where
67        D: Deserializer<'de>,
68    {
69        use std::fmt;
70
71        use serde::de::MapAccess;
72        use serde::de::Visitor;
73
74        #[derive(Deserialize)]
75        #[serde(field_identifier, rename_all = "camelCase")]
76        enum Field {
77            SubscriptionGroupTable,
78            ForbiddenTable,
79            DataVersion,
80            #[serde(other)]
81            Ignore,
82        }
83
84        struct SubscriptionGroupWrapperVisitor;
85
86        impl<'de> Visitor<'de> for SubscriptionGroupWrapperVisitor {
87            type Value = SubscriptionGroupWrapper;
88
89            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
90                formatter.write_str("struct SubscriptionGroupWrapper")
91            }
92
93            fn visit_map<V>(self, mut map: V) -> Result<SubscriptionGroupWrapper, V::Error>
94            where
95                V: MapAccess<'de>,
96            {
97                let mut subscription_group_table: Option<
98                    std::collections::HashMap<CheetahString, SubscriptionGroupConfig>,
99                > = None;
100                let mut forbidden_table: Option<
101                    std::collections::HashMap<CheetahString, std::collections::HashMap<CheetahString, i32>>,
102                > = None;
103                let mut data_version: Option<DataVersion> = None;
104
105                while let Some(key) = map.next_key()? {
106                    match key {
107                        Field::SubscriptionGroupTable => {
108                            if subscription_group_table.is_some() {
109                                return Err(de::Error::duplicate_field("subscriptionGroupTable"));
110                            }
111                            subscription_group_table = Some(map.next_value()?);
112                        }
113                        Field::ForbiddenTable => {
114                            if forbidden_table.is_some() {
115                                return Err(de::Error::duplicate_field("forbiddenTable"));
116                            }
117                            forbidden_table = Some(map.next_value()?);
118                        }
119                        Field::DataVersion => {
120                            if data_version.is_some() {
121                                return Err(de::Error::duplicate_field("dataVersion"));
122                            }
123                            data_version = Some(map.next_value()?);
124                        }
125                        Field::Ignore => {
126                            let _: de::IgnoredAny = map.next_value()?;
127                        }
128                    }
129                }
130
131                let subscription_group_table =
132                    subscription_group_table.ok_or_else(|| de::Error::missing_field("subscriptionGroupTable"))?;
133                let forbidden_table = forbidden_table.unwrap_or_default();
134                let data_version = data_version.ok_or_else(|| de::Error::missing_field("dataVersion"))?;
135
136                // Convert HashMap to DashMap with Arc-wrapped values
137                let dash_map = DashMap::new();
138                for (key, value) in subscription_group_table {
139                    dash_map.insert(key, Arc::new(value));
140                }
141                let forbidden_dash_map = DashMap::new();
142                for (key, value) in forbidden_table {
143                    forbidden_dash_map.insert(key, value);
144                }
145
146                Ok(SubscriptionGroupWrapper {
147                    subscription_group_table: dash_map,
148                    forbidden_table: forbidden_dash_map,
149                    data_version,
150                })
151            }
152        }
153
154        const FIELDS: &[&str] = &["subscriptionGroupTable", "forbiddenTable", "dataVersion"];
155        deserializer.deserialize_struct("SubscriptionGroupWrapper", FIELDS, SubscriptionGroupWrapperVisitor)
156    }
157}
158
159impl Default for SubscriptionGroupWrapper {
160    fn default() -> Self {
161        Self::new()
162    }
163}
164
165impl SubscriptionGroupWrapper {
166    pub fn new() -> Self {
167        SubscriptionGroupWrapper {
168            subscription_group_table: DashMap::with_capacity(1024),
169            forbidden_table: DashMap::with_capacity(1024),
170            data_version: DataVersion::default(),
171        }
172    }
173
174    pub fn get_subscription_group_table(&self) -> &DashMap<CheetahString, Arc<SubscriptionGroupConfig>> {
175        &self.subscription_group_table
176    }
177
178    pub fn set_subscription_group_table(&mut self, table: DashMap<CheetahString, Arc<SubscriptionGroupConfig>>) {
179        self.subscription_group_table = table;
180    }
181
182    pub fn forbidden_table(&self) -> &DashMap<CheetahString, std::collections::HashMap<CheetahString, i32>> {
183        &self.forbidden_table
184    }
185
186    pub fn set_forbidden_table(
187        &mut self,
188        table: DashMap<CheetahString, std::collections::HashMap<CheetahString, i32>>,
189    ) {
190        self.forbidden_table = table;
191    }
192
193    pub fn data_version(&self) -> &DataVersion {
194        &self.data_version
195    }
196
197    pub fn set_data_version(&mut self, version: DataVersion) {
198        self.data_version = version;
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_creates_wrapper_with_default_values() {
        let wrapper = SubscriptionGroupWrapper::new();

        assert_eq!(wrapper.subscription_group_table.len(), 0);
        assert_eq!(wrapper.forbidden_table.len(), 0);
        assert!(wrapper.data_version.timestamp <= DataVersion::default().timestamp);
    }

    #[test]
    fn get_subscription_group_table_returns_reference() {
        let wrapper = SubscriptionGroupWrapper::new();
        wrapper
            .subscription_group_table
            .insert("test_group".into(), Arc::new(SubscriptionGroupConfig::default()));

        let table = wrapper.get_subscription_group_table();
        assert_eq!(table.len(), 1);
        assert!(table.contains_key("test_group"));
    }

    #[test]
    fn deserialize_wrapper_accepts_forbidden_table() {
        let json = r#"{
            "subscriptionGroupTable": {
                "group-a": {
                    "groupName": "group-a"
                }
            },
            "forbiddenTable": {
                "group-a": {
                    "topic-a": 1
                }
            },
            "dataVersion": {
                "timestamp": 1,
                "counter": 1,
                "stateVersion": 0
            }
        }"#;

        let wrapper: SubscriptionGroupWrapper =
            serde_json::from_str(json).expect("subscription group wrapper should deserialize");

        assert!(wrapper.get_subscription_group_table().contains_key("group-a"));
        let entries = wrapper
            .forbidden_table()
            .get("group-a")
            .expect("forbidden entry for group-a should exist");
        assert_eq!(entries.value().get("topic-a"), Some(&1));
        assert_eq!(wrapper.data_version().timestamp, 1);
    }

    #[test]
    fn deserialize_defaults_missing_forbidden_table_and_ignores_unknown_fields() {
        let json = r#"{
            "subscriptionGroupTable": {},
            "unknownField": [1, 2, 3],
            "dataVersion": {
                "timestamp": 1,
                "counter": 1,
                "stateVersion": 0
            }
        }"#;

        let wrapper: SubscriptionGroupWrapper =
            serde_json::from_str(json).expect("wrapper without forbiddenTable should deserialize");

        assert!(wrapper.get_subscription_group_table().is_empty());
        assert!(wrapper.forbidden_table().is_empty());
    }

    #[test]
    fn deserialize_rejects_missing_subscription_group_table() {
        let json = r#"{"dataVersion": {"timestamp": 1, "counter": 1, "stateVersion": 0}}"#;

        let err = serde_json::from_str::<SubscriptionGroupWrapper>(json)
            .expect_err("subscriptionGroupTable is required");

        assert!(err.to_string().contains("subscriptionGroupTable"));
    }

    #[test]
    fn deserialize_rejects_duplicate_fields() {
        let json = r#"{
            "subscriptionGroupTable": {},
            "subscriptionGroupTable": {},
            "dataVersion": {"timestamp": 1, "counter": 1, "stateVersion": 0}
        }"#;

        let err = serde_json::from_str::<SubscriptionGroupWrapper>(json)
            .expect_err("duplicate subscriptionGroupTable must be rejected");

        assert!(err.to_string().contains("duplicate field"));
    }

    #[test]
    fn serialize_round_trips_both_tables() {
        let wrapper = SubscriptionGroupWrapper::new();
        wrapper
            .subscription_group_table
            .insert("group-a".into(), Arc::new(SubscriptionGroupConfig::default()));
        let mut entries: std::collections::HashMap<CheetahString, i32> = std::collections::HashMap::new();
        entries.insert("topic-a".into(), 7);
        wrapper.forbidden_table.insert("group-a".into(), entries);

        let json = serde_json::to_string(&wrapper).expect("wrapper should serialize");
        assert!(json.contains("\"subscriptionGroupTable\""));
        assert!(json.contains("\"forbiddenTable\""));
        assert!(json.contains("\"dataVersion\""));

        let restored: SubscriptionGroupWrapper =
            serde_json::from_str(&json).expect("serialized wrapper should deserialize");

        assert!(restored.get_subscription_group_table().contains_key("group-a"));
        let restored_entries = restored
            .forbidden_table()
            .get("group-a")
            .expect("forbidden entry for group-a should survive the round trip");
        assert_eq!(restored_entries.value().get("topic-a"), Some(&7));
        assert_eq!(restored.data_version().timestamp, wrapper.data_version().timestamp);
    }
}