use std::collections::BTreeSet;

use serde::{Deserialize, Serialize};

use crate::chain::{Address, Chain};
use crate::message::MessageType;

6#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9pub struct Authorization {
10 pub address: Address,
11 #[serde(skip_serializing_if = "Option::is_none")]
12 pub chain: Option<Chain>,
13 #[serde(default, skip_serializing_if = "Vec::is_empty")]
14 pub channels: Vec<String>,
15 #[serde(default, skip_serializing_if = "Vec::is_empty")]
16 pub types: Vec<MessageType>,
17 #[serde(default, skip_serializing_if = "Vec::is_empty")]
18 pub post_types: Vec<String>,
19 #[serde(default, skip_serializing_if = "Vec::is_empty")]
20 pub aggregate_keys: Vec<String>,
21}
22
23impl Authorization {
24 pub fn try_merge(&self, other: &Self) -> Option<Self> {
35 if self.address != other.address || self.chain != other.chain {
36 return None;
37 }
38
39 let channels_eq = set_eq(&self.channels, &other.channels);
40 let types_eq = set_eq(&self.types, &other.types);
41 let post_types_eq = set_eq(&self.post_types, &other.post_types);
42 let aggregate_keys_eq = set_eq(&self.aggregate_keys, &other.aggregate_keys);
43
44 let differing = (!channels_eq) as u8
45 + (!types_eq) as u8
46 + (!post_types_eq) as u8
47 + (!aggregate_keys_eq) as u8;
48
49 match differing {
50 0 => Some(self.clone()),
51 1 => {
52 let mut merged = self.clone();
53 if !channels_eq {
54 if self.channels.is_empty() || other.channels.is_empty() {
55 return None;
56 }
57 merged.channels = sorted_union(&self.channels, &other.channels);
58 } else if !types_eq {
59 if self.types.is_empty() || other.types.is_empty() {
60 return None;
61 }
62 merged.types = sorted_union(&self.types, &other.types);
63 } else if !post_types_eq {
64 if self.post_types.is_empty() || other.post_types.is_empty() {
65 return None;
66 }
67 merged.post_types = sorted_union(&self.post_types, &other.post_types);
68 } else if !aggregate_keys_eq {
69 if self.aggregate_keys.is_empty() || other.aggregate_keys.is_empty() {
70 return None;
71 }
72 merged.aggregate_keys =
73 sorted_union(&self.aggregate_keys, &other.aggregate_keys);
74 }
75 Some(merged)
76 }
77 _ => None,
78 }
79 }
80}
81
/// Reports whether `a` and `b` hold the same distinct elements, ignoring order and
/// duplicates (set equality).
///
/// Collects ordered sets of *references*, so no element is cloned. The previous
/// version deep-copied both slices on every call. Only `T: Ord` is required.
fn set_eq<T: Ord>(a: &[T], b: &[T]) -> bool {
    a.iter().collect::<BTreeSet<&T>>() == b.iter().collect::<BTreeSet<&T>>()
}

/// Returns every distinct element found in `a` or `b`, in ascending order.
fn sorted_union<T: Ord + Clone>(a: &[T], b: &[T]) -> Vec<T> {
    let mut combined = Vec::with_capacity(a.len() + b.len());
    combined.extend_from_slice(a);
    combined.extend_from_slice(b);
    // A stable sort followed by dedup keeps the first of each run of equal items.
    combined.sort();
    combined.dedup();
    combined
}

99#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
101pub struct SecurityAggregateContent {
102 #[serde(default)]
103 pub authorizations: Vec<Authorization>,
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_minimal_authorization_round_trip() {
        let auth = Authorization {
            address: Address::from("0xabc123".to_string()),
            chain: None,
            channels: vec![],
            types: vec![],
            post_types: vec![],
            aggregate_keys: vec![],
        };
        let json = serde_json::to_string(&auth).unwrap();
        // A `None` chain and empty lists are omitted entirely.
        assert_eq!(json, r#"{"address":"0xabc123"}"#);

        let deserialized: Authorization = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, deserialized);
    }

    #[test]
    fn test_full_authorization_round_trip() {
        let auth = Authorization {
            address: Address::from("0xdelegate".to_string()),
            chain: Some(Chain::Ethereum),
            channels: vec!["my-channel".to_string()],
            types: vec![MessageType::Post, MessageType::Aggregate],
            post_types: vec!["blog".to_string()],
            aggregate_keys: vec!["profile".to_string()],
        };
        let json = serde_json::to_string(&auth).unwrap();
        let deserialized: Authorization = serde_json::from_str(&json).unwrap();
        assert_eq!(auth, deserialized);
    }

    #[test]
    fn test_python_sdk_wire_format_compatibility() {
        let python_json = r#"{
            "address": "0xdelegate",
            "chain": "ETH",
            "channels": ["aleph-test"],
            "types": ["POST", "AGGREGATE"],
            "post_types": ["blog"],
            "aggregate_keys": ["profile"]
        }"#;
        let auth: Authorization = serde_json::from_str(python_json).unwrap();
        assert_eq!(auth.address, Address::from("0xdelegate".to_string()));
        assert_eq!(auth.chain, Some(Chain::Ethereum));
        assert_eq!(auth.channels, vec!["aleph-test".to_string()]);
        assert_eq!(auth.types, vec![MessageType::Post, MessageType::Aggregate]);
        assert_eq!(auth.post_types, vec!["blog".to_string()]);
        assert_eq!(auth.aggregate_keys, vec!["profile".to_string()]);
    }

    #[test]
    fn test_explicit_null_chain_and_empty_lists_deserialize() {
        // A client may send every field explicitly, using `null` / `[]` for unset ones.
        let json = r#"{
            "address": "0xdelegate",
            "chain": null,
            "channels": [],
            "types": [],
            "post_types": [],
            "aggregate_keys": []
        }"#;
        let parsed: Authorization = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.chain, None);
        assert!(parsed.channels.is_empty());
        assert!(parsed.types.is_empty());
        assert!(parsed.post_types.is_empty());
        assert!(parsed.aggregate_keys.is_empty());
        // Re-serializing normalizes to the minimal form.
        assert_eq!(
            serde_json::to_string(&parsed).unwrap(),
            r#"{"address":"0xdelegate"}"#
        );
    }

    #[test]
    fn test_security_aggregate_content_round_trip() {
        let content = SecurityAggregateContent {
            authorizations: vec![Authorization {
                address: Address::from("0xabc".to_string()),
                chain: None,
                channels: vec![],
                types: vec![MessageType::Post],
                post_types: vec![],
                aggregate_keys: vec![],
            }],
        };
        let json = serde_json::to_string(&content).unwrap();
        let deserialized: SecurityAggregateContent = serde_json::from_str(&json).unwrap();
        assert_eq!(content, deserialized);
    }

    #[test]
    fn test_empty_security_aggregate_deserialization() {
        let json = r#"{"authorizations":[]}"#;
        let content: SecurityAggregateContent = serde_json::from_str(json).unwrap();
        assert!(content.authorizations.is_empty());
    }

    #[test]
    fn test_security_aggregate_missing_authorizations_defaults_to_empty() {
        let content: SecurityAggregateContent = serde_json::from_str("{}").unwrap();
        assert!(content.authorizations.is_empty());
    }

    /// Builds an [`Authorization`] from borrowed string lists to keep cases compact.
    fn auth(
        address: &str,
        chain: Option<Chain>,
        channels: Vec<&str>,
        types: Vec<MessageType>,
        post_types: Vec<&str>,
        aggregate_keys: Vec<&str>,
    ) -> Authorization {
        Authorization {
            address: Address::from(address.to_string()),
            chain,
            channels: channels.into_iter().map(String::from).collect(),
            types,
            post_types: post_types.into_iter().map(String::from).collect(),
            aggregate_keys: aggregate_keys.into_iter().map(String::from).collect(),
        }
    }

    #[test]
    fn try_merge_identical_returns_clone() {
        let a = auth(
            "0xD",
            None,
            vec![],
            vec![MessageType::Post],
            vec![],
            vec!["k1"],
        );
        let merged = a.try_merge(&a).unwrap();
        assert_eq!(merged, a);
    }

    #[test]
    fn try_merge_set_equal_different_order() {
        let a = auth("0xD", None, vec![], vec![], vec![], vec!["A", "B"]);
        let b = auth("0xD", None, vec![], vec![], vec![], vec!["B", "A"]);
        let merged = a.try_merge(&b).expect("set-equal entries merge");
        let mut got = merged.aggregate_keys.clone();
        got.sort();
        assert_eq!(got, vec!["A".to_string(), "B".to_string()]);
    }

    #[test]
    fn try_merge_set_equal_with_duplicates() {
        let a = auth("0xD", None, vec![], vec![], vec![], vec!["A", "A", "B"]);
        let b = auth("0xD", None, vec![], vec![], vec![], vec!["A", "B"]);
        assert!(a.try_merge(&b).is_some());
    }

    #[test]
    fn try_merge_aggregate_keys_one_differs_unions() {
        let a = auth("0xD", None, vec![], vec![], vec![], vec!["A", "B"]);
        let b = auth("0xD", None, vec![], vec![], vec![], vec!["C"]);
        let merged = a.try_merge(&b).expect("aggregate_keys merge");
        let mut got = merged.aggregate_keys.clone();
        got.sort();
        assert_eq!(got, vec!["A".to_string(), "B".to_string(), "C".to_string()]);
        assert!(merged.channels.is_empty());
        assert!(merged.types.is_empty());
        assert!(merged.post_types.is_empty());
    }

    #[test]
    fn try_merge_post_types_one_differs_unions() {
        let a = auth("0xD", None, vec![], vec![], vec!["blog"], vec![]);
        let b = auth("0xD", None, vec![], vec![], vec!["comment"], vec![]);
        let merged = a.try_merge(&b).expect("post_types merge");
        let mut got = merged.post_types.clone();
        got.sort();
        assert_eq!(got, vec!["blog".to_string(), "comment".to_string()]);
    }

    #[test]
    fn try_merge_channels_one_differs_unions() {
        let a = auth("0xD", None, vec!["c1"], vec![], vec![], vec![]);
        let b = auth("0xD", None, vec!["c2"], vec![], vec![], vec![]);
        let merged = a.try_merge(&b).expect("channels merge");
        let mut got = merged.channels.clone();
        got.sort();
        assert_eq!(got, vec!["c1".to_string(), "c2".to_string()]);
    }

    #[test]
    fn try_merge_types_one_differs_unions() {
        let a = auth("0xD", None, vec![], vec![MessageType::Post], vec![], vec![]);
        let b = auth(
            "0xD",
            None,
            vec![],
            vec![MessageType::Aggregate],
            vec![],
            vec![],
        );
        let merged = a.try_merge(&b).expect("types merge");
        assert!(merged.types.contains(&MessageType::Post));
        assert!(merged.types.contains(&MessageType::Aggregate));
        assert_eq!(merged.types.len(), 2);
    }

    #[test]
    fn try_merge_single_difference_is_symmetric() {
        let a = auth("0xD", None, vec!["c1"], vec![], vec![], vec![]);
        let b = auth("0xD", None, vec!["c2"], vec![], vec![], vec![]);
        assert_eq!(a.try_merge(&b), b.try_merge(&a));
    }

    #[test]
    fn try_merge_same_chain_merges() {
        let a = auth(
            "0xD",
            Some(Chain::Ethereum),
            vec![],
            vec![],
            vec![],
            vec!["A"],
        );
        let b = auth(
            "0xD",
            Some(Chain::Ethereum),
            vec![],
            vec![],
            vec![],
            vec!["B"],
        );
        let merged = a.try_merge(&b).expect("same chain merges");
        assert_eq!(merged.chain, Some(Chain::Ethereum));
        assert_eq!(merged.aggregate_keys, vec!["A".to_string(), "B".to_string()]);
    }

    #[test]
    fn try_merge_one_side_empty_returns_none() {
        let a = auth("0xD", None, vec![], vec![], vec![], vec!["A"]);
        let b = auth("0xD", None, vec![], vec![], vec![], vec![]);
        assert!(
            a.try_merge(&b).is_none(),
            "restricted vs wildcard must not merge"
        );
        assert!(
            b.try_merge(&a).is_none(),
            "wildcard vs restricted must not merge"
        );
    }

    #[test]
    fn try_merge_two_fields_differ_returns_none() {
        let a = auth(
            "0xD",
            None,
            vec!["c1"],
            vec![MessageType::Post],
            vec![],
            vec![],
        );
        let b = auth(
            "0xD",
            None,
            vec!["c2"],
            vec![MessageType::Aggregate],
            vec![],
            vec![],
        );
        assert!(a.try_merge(&b).is_none());
    }

    #[test]
    fn try_merge_two_fields_differ_with_wildcard_returns_none() {
        let a = auth("0xD", None, vec!["c1"], vec![], vec![], vec!["A"]);
        let b = auth("0xD", None, vec![], vec![], vec![], vec!["B"]);
        assert!(a.try_merge(&b).is_none());
        assert!(b.try_merge(&a).is_none());
    }

    #[test]
    fn try_merge_different_address_returns_none() {
        let a = auth("0xD", None, vec![], vec![], vec![], vec!["A"]);
        let b = auth("0xE", None, vec![], vec![], vec![], vec!["A"]);
        assert!(a.try_merge(&b).is_none());
    }

    #[test]
    fn try_merge_different_chain_some_vs_none_returns_none() {
        let a = auth(
            "0xD",
            Some(Chain::Ethereum),
            vec![],
            vec![],
            vec![],
            vec!["A"],
        );
        let b = auth("0xD", None, vec![], vec![], vec![], vec!["A"]);
        assert!(a.try_merge(&b).is_none());
    }

    #[test]
    fn try_merge_different_chain_some_vs_some_returns_none() {
        let a = auth(
            "0xD",
            Some(Chain::Ethereum),
            vec![],
            vec![],
            vec![],
            vec!["A"],
        );
        let b = auth("0xD", Some(Chain::Sol), vec![], vec![], vec![], vec!["B"]);
        assert!(a.try_merge(&b).is_none());
    }
}