Skip to main content

amq_protocol_uri/
lib.rs

1#![deny(missing_docs)]
2
3//! # AMQP URI manipulation library
4//!
//! amq-protocol-uri is a library that provides tools to help manage
//! AMQP URIs.
7
8use amq_protocol_types::{ChannelId, FrameSize, Heartbeat};
9use url::Url;
10
11use std::{fmt, num::ParseIntError, str::FromStr};
12
13/// An AMQP Uri
14#[derive(Clone, Debug, PartialEq, Eq)]
15pub struct AMQPUri {
16    /// The scheme used by the AMQP connection
17    pub scheme: AMQPScheme,
18    /// The connection information
19    pub authority: AMQPAuthority,
20    /// The target vhost
21    pub vhost: String,
22    /// The optional query string to pass parameters to the server
23    pub query: AMQPQueryString,
24}
25
/// The transport selected by the URI scheme
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum AMQPScheme {
    /// Plain AMQP (`amqp`)
    #[default]
    AMQP,
    /// AMQP encrypted with TLS (`amqps`)
    AMQPS,
}

impl FromStr for AMQPScheme {
    type Err = String;

    /// Parses the exact, lowercase scheme names `amqp` and `amqps`.
    ///
    /// Any other input yields `"Invalid AMQP scheme: <input>"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "amqp" {
            Ok(AMQPScheme::AMQP)
        } else if s == "amqps" {
            Ok(AMQPScheme::AMQPS)
        } else {
            Err(format!("Invalid AMQP scheme: {s}"))
        }
    }
}

impl fmt::Display for AMQPScheme {
    /// Writes the lowercase scheme name, the same text accepted by `from_str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match *self {
            AMQPScheme::AMQPS => "amqps",
            AMQPScheme::AMQP => "amqp",
        };
        f.write_str(name)
    }
}
56
57/// The connection information
58#[derive(Clone, Debug, PartialEq, Eq)]
59pub struct AMQPAuthority {
60    /// The credentials used to connect to the server
61    pub userinfo: AMQPUserInfo,
62    /// The server's host
63    pub host: String,
64    /// The port the server listens on
65    pub port: u16,
66}
67
/// The username/password pair used to authenticate against the server
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct AMQPUserInfo {
    /// The login name
    pub username: String,
    /// The password associated with `username`
    pub password: String,
}
76
77/// The optional query string to pass parameters to the server
78#[derive(Clone, Debug, Default, PartialEq, Eq)]
79pub struct AMQPQueryString {
80    /// The maximum size of an AMQP Frame
81    pub frame_max: Option<FrameSize>,
82    /// The maximum number of open channels
83    pub channel_max: Option<ChannelId>,
84    /// The maximum time between two heartbeats
85    pub heartbeat: Option<Heartbeat>,
86    /// The maximum time to wait (in milliseconds) for the connection to succeed
87    pub connection_timeout: Option<u64>,
88    /// The SASL mechanism used for authentication
89    pub auth_mechanism: Option<SASLMechanism>,
90    // Fields available in Erlang implementation for SSL settings:
91    // cacertfile, certfile, keyfile, verify, fail_if_no_peer_cert, password,
92    // server_name_indication, depth
93}
94
/// The SASL mechanisms supported by RabbitMQ
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SASLMechanism {
    /// This is a legacy mechanism kept for backward compatibility
    AMQPlain,
    /// Anonymous authentication if supported by the RabbitMQ server
    Anonymous,
    /// Delegate all authentication to the transport instead of the RabbitMQ server
    External,
    /// Default plain login, this should be supported everywhere
    #[default]
    Plain,
    /// A demo of RabbitMQ SecureOk mechanism, offers the same level of security as Plain
    RabbitCrDemo,
}

impl SASLMechanism {
    /// Get the name of the SASL mechanism as str
    ///
    /// The name is upper-case (e.g. `"PLAIN"`); `from_str` accepts it in any case.
    pub fn name(&self) -> &'static str {
        match self {
            SASLMechanism::AMQPlain => "AMQPLAIN",
            SASLMechanism::Anonymous => "ANONYMOUS",
            SASLMechanism::External => "EXTERNAL",
            SASLMechanism::Plain => "PLAIN",
            SASLMechanism::RabbitCrDemo => "RABBIT-CR-DEMO",
        }
    }
}

impl fmt::Display for SASLMechanism {
    /// Writes the same text as [`SASLMechanism::name`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.name())
    }
}

impl FromStr for SASLMechanism {
    type Err = String;

    /// Parses a mechanism name, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns `"Invalid SASL mechanism: <input>"`, quoting the input exactly as given.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Match against a lowercased copy, but keep `s` untouched so the error
        // reports what the caller actually wrote (previously the lowercased copy
        // shadowed `s` and ended up in the message).
        match s.to_lowercase().as_str() {
            "amqplain" => Ok(SASLMechanism::AMQPlain),
            "anonymous" => Ok(SASLMechanism::Anonymous),
            "external" => Ok(SASLMechanism::External),
            "plain" => Ok(SASLMechanism::Plain),
            "rabbit-cr-demo" => Ok(SASLMechanism::RabbitCrDemo),
            _ => Err(format!("Invalid SASL mechanism: {s}")),
        }
    }
}
144
145fn percent_decode(s: &str) -> Result<String, String> {
146    percent_encoding::percent_decode(s.as_bytes())
147        .decode_utf8()
148        .map(|s| s.to_string())
149        .map_err(|e| e.to_string())
150}
151
152fn percent_encode<'a>(s: &'a str) -> percent_encoding::PercentEncode<'a> {
153    percent_encoding::utf8_percent_encode(s, percent_encoding::NON_ALPHANUMERIC)
154}
155
156impl Default for AMQPUri {
157    fn default() -> Self {
158        AMQPUri {
159            scheme: Default::default(),
160            authority: Default::default(),
161            vhost: "/".to_string(),
162            query: Default::default(),
163        }
164    }
165}
166
167fn int_queryparam<T: FromStr<Err = ParseIntError>>(
168    url: &Url,
169    param: &str,
170) -> Result<Option<T>, String> {
171    url.query_pairs()
172        .find(|(key, _)| key == param)
173        .map_or(Ok(None), |(_, ref value)| value.parse::<T>().map(Some))
174        .map_err(|e: ParseIntError| e.to_string())
175}
176
177impl FromStr for AMQPUri {
178    type Err = String;
179
180    fn from_str(s: &str) -> Result<Self, Self::Err> {
181        let url = Url::parse(s).map_err(|e| e.to_string())?;
182        if url.cannot_be_a_base() {
183            return Err(format!("Invalid URL: '{s}'"));
184        }
185        let default = AMQPUri::default();
186        let scheme = url.scheme().parse::<AMQPScheme>()?;
187        let username = match url.username() {
188            "" => default.authority.userinfo.username,
189            username => percent_decode(username)?,
190        };
191        let password = url
192            .password()
193            .map_or(Ok(default.authority.userinfo.password), percent_decode)?;
194        let host = url
195            .domain()
196            .map_or(Ok(default.authority.host), percent_decode)?;
197        let port = url.port().unwrap_or_else(|| scheme.default_port());
198        let vhost = percent_decode(url.path().get(1..).unwrap_or("/"))?;
199        let frame_max = int_queryparam(&url, "frame_max")?;
200        let channel_max = int_queryparam(&url, "channel_max")?;
201        let heartbeat = int_queryparam(&url, "heartbeat")?;
202        let connection_timeout = int_queryparam(&url, "connection_timeout")?;
203        let auth_mechanism = url
204            .query_pairs()
205            .find(|(key, _)| key == "auth_mechanism")
206            .map_or(Ok(None), |(_, ref value)| value.parse().map(Some))?;
207
208        Ok(AMQPUri {
209            scheme,
210            authority: AMQPAuthority {
211                userinfo: AMQPUserInfo { username, password },
212                host,
213                port,
214            },
215            vhost,
216            query: AMQPQueryString {
217                frame_max,
218                channel_max,
219                heartbeat,
220                connection_timeout,
221                auth_mechanism,
222            },
223        })
224    }
225}
226
227impl fmt::Display for AMQPUri {
228    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229        write!(
230            f,
231            "{}://{}:{}@{}:{}/{}",
232            self.scheme,
233            percent_encode(&self.authority.userinfo.username),
234            percent_encode(&self.authority.userinfo.password),
235            self.authority.host,
236            self.authority.port,
237            percent_encode(&self.vhost),
238        )?;
239        let mut sep = '?';
240        if let Some(v) = self.query.frame_max {
241            write!(f, "{sep}frame_max={v}")?;
242            sep = '&';
243        }
244        if let Some(v) = self.query.channel_max {
245            write!(f, "{sep}channel_max={v}")?;
246            sep = '&';
247        }
248        if let Some(v) = self.query.heartbeat {
249            write!(f, "{sep}heartbeat={v}")?;
250            sep = '&';
251        }
252        if let Some(v) = self.query.connection_timeout {
253            write!(f, "{sep}connection_timeout={v}")?;
254            sep = '&';
255        }
256        if let Some(v) = self.query.auth_mechanism {
257            write!(f, "{sep}auth_mechanism={v}")?;
258        }
259        Ok(())
260    }
261}
262
263impl AMQPScheme {
264    /// The default port for this scheme
265    pub fn default_port(&self) -> u16 {
266        match *self {
267            AMQPScheme::AMQP => 5672,
268            AMQPScheme::AMQPS => 5671,
269        }
270    }
271}
272
273impl Default for AMQPAuthority {
274    fn default() -> Self {
275        AMQPAuthority {
276            userinfo: Default::default(),
277            host: "localhost".to_string(),
278            port: AMQPScheme::default().default_port(),
279        }
280    }
281}
282
283impl Default for AMQPUserInfo {
284    fn default() -> Self {
285        AMQPUserInfo {
286            username: "guest".to_string(),
287            password: "guest".to_string(),
288        }
289    }
290}
291
#[cfg(test)]
mod test {
    use super::*;

    /// Shorthand so every test parses into the same concrete result type.
    fn parse(input: &str) -> Result<AMQPUri, String> {
        input.parse::<AMQPUri>()
    }

    #[test]
    fn test_parse_amqp_no_path() {
        assert_eq!(parse("amqp://localhost"), Ok(AMQPUri::default()));
    }

    #[test]
    fn test_parse_amqp() {
        assert_eq!(parse("amqp://localhost/%2f"), Ok(AMQPUri::default()));
    }

    #[test]
    fn test_parse_amqps() {
        // A bare "/" path yields the empty vhost, not the default "/" one.
        let expected = AMQPUri {
            scheme: AMQPScheme::AMQPS,
            authority: AMQPAuthority {
                port: 5671,
                ..AMQPAuthority::default()
            },
            vhost: String::new(),
            ..AMQPUri::default()
        };
        assert_eq!(parse("amqps://localhost/"), Ok(expected));
    }

    #[test]
    fn test_parse_amqps_with_creds() {
        let expected = AMQPUri {
            scheme: AMQPScheme::AMQPS,
            authority: AMQPAuthority {
                userinfo: AMQPUserInfo {
                    username: String::from("user"),
                    password: String::from("pass"),
                },
                host: String::from("hostname"),
                port: 5671,
            },
            vhost: String::from("v"),
            ..AMQPUri::default()
        };
        // Unknown query parameters such as `foo` are ignored.
        assert_eq!(parse("amqps://user:pass@hostname/v?foo=bar"), Ok(expected));
    }

    #[test]
    fn test_parse_amqps_with_creds_percent() {
        // `%61` is 'a' and `%2f` is '/': every component must come out decoded.
        let expected = AMQPUri {
            scheme: AMQPScheme::AMQP,
            authority: AMQPAuthority {
                userinfo: AMQPUserInfo {
                    username: String::from("usera"),
                    password: String::from("apass"),
                },
                host: String::from("hoast"),
                port: 10000,
            },
            vhost: String::from("v/host"),
            ..AMQPUri::default()
        };
        assert_eq!(
            parse("amqp://user%61:%61pass@ho%61st:10000/v%2fhost"),
            Ok(expected)
        );
    }

    #[test]
    fn test_parse_with_heartbeat_frame_max() {
        let expected = AMQPUri {
            query: AMQPQueryString {
                frame_max: Some(64),
                heartbeat: Some(42),
                connection_timeout: Some(30000),
                ..AMQPQueryString::default()
            },
            ..AMQPUri::default()
        };
        assert_eq!(
            parse("amqp://localhost/%2f?heartbeat=42&frame_max=64&connection_timeout=30000"),
            Ok(expected)
        );
    }

    #[test]
    fn test_url_with_no_base() {
        assert_eq!(
            parse("foo"),
            Err(String::from("relative URL without a base"))
        );
    }

    #[test]
    fn test_invalid_url() {
        assert_eq!(
            parse("foo:bar"),
            Err(String::from("Invalid URL: 'foo:bar'"))
        );
    }

    #[test]
    fn test_invalid_scheme() {
        assert_eq!(
            parse("http://localhost/"),
            Err(String::from("Invalid AMQP scheme: http"))
        );
    }
}