1use percent_encoding::{utf8_percent_encode, AsciiSet, CONTROLS};
2use secure_string::SecureString;
3use serde::de::{DeserializeSeed, Error, MapAccess, Visitor};
4use serde::{Deserialize, Deserializer};
5use std::any::type_name;
6use std::borrow::Cow;
7use std::collections::HashMap;
8use std::fmt::{Debug, Display, Formatter};
9use std::sync::Arc;
10use strut_deserialize::{Slug, SlugMap};
11use strut_factory::impl_deserialize_field;
12use strut_util::BackoffConfig;
13
14const VHOST_ENCODE_SET: &AsciiSet = &CONTROLS
15 .add(b'/') .add(b'?') .add(b'#') .add(b'%'); #[derive(Debug, Default, Clone, PartialEq)]
22pub struct HandleCollection {
23 handles: SlugMap<Handle>,
24}
25
26#[derive(Clone, PartialEq)]
32pub struct Handle {
33 name: Arc<str>,
34 identifier: Arc<str>,
35 dsn: SecureString,
36 backoff: BackoffConfig,
37}
38
/// The individual pieces from which a [`Handle`]'s connection string is assembled.
///
/// Supply raw (unencoded) values; [`Handle::new`] percent-encodes the virtual host
/// itself. The trait bounds needed to consume the chunks are declared on the impls
/// that use them (e.g. [`Handle::new`]) rather than on the struct, so the type can be
/// named and constructed without repeating them.
pub struct DsnChunks<H, U, P, VH> {
    /// Broker host name or address.
    pub host: H,
    /// Broker port.
    pub port: u16,
    /// User name to authenticate with.
    pub user: U,
    /// Password to authenticate with.
    pub password: P,
    /// Virtual host name, unencoded (e.g. `/`).
    pub vhost: VH,
}
66
67impl Handle {
68 pub fn new<H, U, P, VH>(name: impl AsRef<str>, chunks: DsnChunks<H, U, P, VH>) -> Self
75 where
76 H: AsRef<str>,
77 U: AsRef<str>,
78 P: Into<SecureString>,
79 VH: AsRef<str>,
80 {
81 let name = Arc::from(name.as_ref());
82
83 let vhost = Self::ensure_encoded_vhost(chunks.vhost.as_ref());
84 let identifier = Self::compose_identifier(
85 chunks.host.as_ref(),
86 chunks.port,
87 chunks.user.as_ref(),
88 vhost.as_ref(),
89 );
90
91 let password = chunks.password.into();
92 let dsn = Self::compose_dsn(
93 chunks.host.as_ref(),
94 chunks.port,
95 chunks.user.as_ref(),
96 &password,
97 vhost.as_ref(),
98 );
99
100 let backoff = BackoffConfig::default();
101
102 Self {
103 name,
104 identifier,
105 dsn,
106 backoff,
107 }
108 }
109
110 pub fn with_backoff(self, backoff: BackoffConfig) -> Self {
112 Self { backoff, ..self }
113 }
114
115 fn ensure_encoded_vhost(vhost: &str) -> Cow<'_, str> {
118 utf8_percent_encode(vhost, VHOST_ENCODE_SET).into()
119 }
120
121 fn compose_identifier(host: &str, port: u16, user: &str, vhost: &str) -> Arc<str> {
123 Arc::from(format!("{}@{}:{}/{}", user, host, port, vhost))
124 }
125
126 fn compose_dsn(
128 host: &str,
129 port: u16,
130 user: &str,
131 password: &SecureString,
132 vhost: &str,
133 ) -> SecureString {
134 SecureString::from(format!(
135 "amqp://{}:{}@{}:{}/{}",
136 user,
137 password.unsecure(),
138 host,
139 port,
140 vhost,
141 ))
142 }
143}
144
145impl HandleCollection {
146 pub fn contains(&self, name: &str) -> bool {
149 self.handles.contains_key(name)
150 }
151
152 pub fn get(&self, name: &str) -> Option<&Handle> {
156 self.handles.get(name)
157 }
158
159 pub fn expect(&self, name: &str) -> &Handle {
162 self.get(name)
163 .unwrap_or_else(|| panic!("requested an undefined RabbitMQ handle '{}'", name))
164 }
165}
166
167impl Handle {
168 pub fn name(&self) -> &str {
170 &self.name
171 }
172
173 pub fn identifier(&self) -> &str {
177 &self.identifier
178 }
179
180 pub fn dsn(&self) -> &SecureString {
182 &self.dsn
183 }
184
185 pub fn backoff(&self) -> &BackoffConfig {
188 &self.backoff
189 }
190}
191
192impl Default for DsnChunks<&str, &str, &str, &str> {
194 fn default() -> Self {
195 Self {
196 host: Handle::default_host(),
197 port: Handle::default_port(),
198 user: Handle::default_user(),
199 password: Handle::default_password(),
200 vhost: Handle::default_vhost(),
201 }
202 }
203}
204
205impl Handle {
206 fn default_name() -> &'static str {
207 "default"
208 }
209
210 fn default_host() -> &'static str {
211 "localhost"
212 }
213
214 fn default_port() -> u16 {
215 5672
216 }
217
218 fn default_user() -> &'static str {
219 "guest"
220 }
221
222 fn default_password() -> &'static str {
223 "guest"
224 }
225
226 fn default_vhost() -> &'static str {
227 "/"
228 }
229}
230
231impl Default for Handle {
232 fn default() -> Self {
233 Self::new(Self::default_name(), DsnChunks::default())
234 }
235}
236
237impl Debug for Handle {
240 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
241 f.debug_struct(type_name::<Self>())
242 .field("name", &self.name)
243 .field("identifier", &self.identifier)
244 .field("backoff", &self.backoff)
245 .finish()
246 }
247}
248
249impl Display for Handle {
250 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
251 f.write_str(&self.identifier)
252 }
253}
254
255impl AsRef<Handle> for Handle {
256 fn as_ref(&self) -> &Handle {
257 self
258 }
259}
260
261impl AsRef<HandleCollection> for HandleCollection {
262 fn as_ref(&self) -> &HandleCollection {
263 self
264 }
265}
266
267const _: () = {
268 impl<S> FromIterator<(S, Handle)> for HandleCollection
269 where
270 S: Into<Slug>,
271 {
272 fn from_iter<T: IntoIterator<Item = (S, Handle)>>(iter: T) -> Self {
273 let handles = iter.into_iter().collect();
274
275 Self { handles }
276 }
277 }
278
279 impl<const N: usize, S> From<[(S, Handle); N]> for HandleCollection
280 where
281 S: Into<Slug>,
282 {
283 fn from(value: [(S, Handle); N]) -> Self {
284 value.into_iter().collect()
285 }
286 }
287};
288
289const _: () = {
290 impl<'de> Deserialize<'de> for HandleCollection {
291 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
292 where
293 D: Deserializer<'de>,
294 {
295 deserializer.deserialize_map(HandleCollectionVisitor)
296 }
297 }
298
299 struct HandleCollectionVisitor;
300
301 impl<'de> Visitor<'de> for HandleCollectionVisitor {
302 type Value = HandleCollection;
303
304 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
305 formatter.write_str("a map of RabbitMQ handles")
306 }
307
308 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
309 where
310 A: MapAccess<'de>,
311 {
312 let grouped = Slug::group_map(map)?;
313 let mut handles = HashMap::with_capacity(grouped.len());
314
315 for (key, value) in grouped {
316 let seed = HandleSeed {
317 name: key.original(),
318 };
319 let handle = seed.deserialize(value).map_err(Error::custom)?;
320 handles.insert(key, handle);
321 }
322
323 Ok(HandleCollection {
324 handles: SlugMap::new(handles),
325 })
326 }
327 }
328
329 struct HandleSeed<'a> {
330 name: &'a str,
331 }
332
333 impl<'de> DeserializeSeed<'de> for HandleSeed<'_> {
334 type Value = Handle;
335
336 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
337 where
338 D: Deserializer<'de>,
339 {
340 deserializer.deserialize_map(HandleSeedVisitor { name: self.name })
341 }
342 }
343
344 struct HandleSeedVisitor<'a> {
345 name: &'a str,
346 }
347
348 impl<'de> Visitor<'de> for HandleSeedVisitor<'_> {
349 type Value = Handle;
350
351 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
352 formatter.write_str("a map of RabbitMQ handle")
353 }
354
355 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
356 where
357 A: MapAccess<'de>,
358 {
359 visit_handle(map, Some(self.name))
360 }
361 }
362
363 impl<'de> Deserialize<'de> for Handle {
364 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
365 where
366 D: Deserializer<'de>,
367 {
368 deserializer.deserialize_map(HandleVisitor)
369 }
370 }
371
372 struct HandleVisitor;
373
374 impl<'de> Visitor<'de> for HandleVisitor {
375 type Value = Handle;
376
377 fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
378 formatter.write_str("a map of RabbitMQ handle")
379 }
380
381 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
382 where
383 A: MapAccess<'de>,
384 {
385 visit_handle(map, None)
386 }
387 }
388
389 fn visit_handle<'de, A>(mut map: A, known_name: Option<&str>) -> Result<Handle, A::Error>
390 where
391 A: MapAccess<'de>,
392 {
393 let mut name: Option<String> = None;
396 let mut host: Option<String> = None;
397 let mut port = None;
398 let mut user: Option<String> = None;
399 let mut password: Option<SecureString> = None;
400 let mut vhost: Option<String> = None;
401
402 while let Some(key) = map.next_key()? {
403 match key {
404 HandleField::name => key.poll(&mut map, &mut name)?,
405 HandleField::host => key.poll(&mut map, &mut host)?,
406 HandleField::port => key.poll(&mut map, &mut port)?,
407 HandleField::user => key.poll(&mut map, &mut user)?,
408 HandleField::password => key.poll(&mut map, &mut password)?,
409 HandleField::vhost => key.poll(&mut map, &mut vhost)?,
410 HandleField::__ignore => map.next_value()?,
411 };
412 }
413
414 let name = match known_name {
415 Some(known_name) => known_name,
416 None => name.as_deref().unwrap_or_else(|| Handle::default_name()),
417 };
418
419 let chunks = DsnChunks {
421 host: host.as_deref().unwrap_or_else(|| Handle::default_host()),
422 port: port.unwrap_or_else(Handle::default_port),
423 user: user.as_deref().unwrap_or_else(|| Handle::default_user()),
424 password: password.unwrap_or_else(|| Handle::default_password().into()),
425 vhost: vhost.as_deref().unwrap_or_else(|| Handle::default_vhost()),
426 };
427
428 Ok(Handle::new(name, chunks))
429 }
430
431 impl_deserialize_field!(
432 HandleField,
433 strut_deserialize::Slug::eq_as_slugs,
434 name,
435 host | hostname,
436 port,
437 user | username,
438 password,
439 vhost,
440 );
441};
442
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Deserializes `yaml` into `T`, failing the test on any parse error.
    fn parse<T: serde::de::DeserializeOwned>(yaml: &str) -> T {
        serde_yml::from_str::<T>(yaml).unwrap()
    }

    #[test]
    fn deserialize_from_empty() {
        let want = Handle::default();

        let got: Handle = parse("");

        assert_eq!(want, got);
    }

    #[test]
    fn deserialize_from_full() {
        let yaml = r#"
name: test_handle
host: test_host
port: 8080
user: test_user
password: test_password
vhost: test_vhost
"#;
        let chunks = DsnChunks {
            host: "test_host",
            port: 8080,
            user: "test_user",
            password: "test_password",
            vhost: "test_vhost",
        };
        let want = Handle::new("test_handle", chunks);

        let got: Handle = parse(yaml);

        assert_eq!(want, got);
    }

    #[test]
    fn deserialize_collection_from_empty() {
        let want = HandleCollection::default();

        let got: HandleCollection = parse("");

        assert_eq!(want, got);
    }

    #[test]
    fn deserialize_collection_from_full() {
        let yaml = r#"
test_handle_a: {}
test_handle_b:
  host: test_host
  port: 8080
"#;
        let handle_a = Handle::new("test_handle_a", DsnChunks::default());
        let handle_b = Handle::new(
            "test_handle_b",
            DsnChunks {
                host: "test_host",
                port: 8080,
                ..Default::default()
            },
        );
        let want =
            HandleCollection::from([("test_handle_a", handle_a), ("test_handle_b", handle_b)]);

        let got: HandleCollection = parse(yaml);

        assert_eq!(want, got);
    }
}