stry_common/
uri.rs

1//! A simple parser for database connection URIs.
2
3use std::{
4    collections::BTreeMap,
5    convert::TryFrom,
6    error::Error,
7    fmt::{self, Display, Formatter},
8    num::ParseIntError,
9    str::{FromStr, Utf8Error},
10};
11
/// A parsed database connection URI.
///
/// # Format
///
/// ```not_rust
/// scheme://[username:password@]host:port[,...hostN:portN][/[database]][?options]
/// ```
///
/// Every host must be followed by an explicit port, and `hosts[i]` is paired
/// with `ports[i]`. Options are `key=value` pairs separated by `&`; they may
/// follow the host list directly when no database is given
/// (e.g. `postgres://localhost:5432?tls=true`).
///
/// # Warning
///
/// The parser does **not** support IPv6 addresses or Unix domain socket
/// paths; build the config manually instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Uri {
    /// The scheme preceding `://`, e.g. `postgres`.
    pub scheme: String,

    /// The user name from the credentials section, if one was given.
    pub username: Option<String>,
    /// The password from the credentials section, if one was given.
    pub password: Option<String>,

    /// Host names in the order they appeared; paired index-wise with `ports`.
    pub hosts: Vec<String>,
    /// Ports in the order they appeared; paired index-wise with `hosts`.
    pub ports: Vec<u16>,

    /// The database name taken from the path, if any.
    pub database: Option<String>,

    /// Connection options, or `None` when no option is set.
    pub options: Option<BTreeMap<String, String>>,
}
38
39impl Uri {
40    pub fn new<S, H>(scheme: S, host: H, port: u16) -> Uri
41    where
42        S: Into<String>,
43        H: Into<String>,
44    {
45        Uri {
46            scheme: scheme.into(),
47            username: None,
48            password: None,
49            hosts: vec![host.into()],
50            ports: vec![port],
51            database: None,
52            options: None,
53        }
54    }
55
56    pub fn parse<S>(text: S) -> Result<Self, UriError>
57    where
58        S: AsRef<str>,
59    {
60        let text = text.as_ref();
61
62        let config = Parser::parse(text)?;
63
64        Ok(config)
65    }
66
67    pub fn username<U>(mut self, username: U) -> Self
68    where
69        U: Into<String>,
70    {
71        self.username = Some(username.into());
72
73        self
74    }
75
76    pub fn password<P>(mut self, password: P) -> Self
77    where
78        P: Into<String>,
79    {
80        self.password = Some(password.into());
81
82        self
83    }
84
85    pub fn database<D>(mut self, database: D) -> Self
86    where
87        D: Into<String>,
88    {
89        self.database = Some(database.into());
90
91        self
92    }
93
94    pub fn option<K, V>(mut self, key: K, value: V) -> Self
95    where
96        K: Into<String>,
97        V: Into<String>,
98    {
99        let tree = self.options.get_or_insert_with(BTreeMap::new);
100
101        tree.insert(key.into(), value.into());
102
103        self
104    }
105}
106
107impl FromStr for Uri {
108    type Err = UriError;
109
110    fn from_str(s: &str) -> Result<Self, Self::Err> {
111        let config = Uri::parse(s)?;
112
113        Ok(config)
114    }
115}
116
117impl<'s> TryFrom<&'s str> for Uri {
118    type Error = UriError;
119
120    fn try_from(value: &'s str) -> Result<Self, Self::Error> {
121        let config = Uri::parse(value)?;
122
123        Ok(config)
124    }
125}
126
127impl TryFrom<String> for Uri {
128    type Error = UriError;
129
130    fn try_from(value: String) -> Result<Self, Self::Error> {
131        let config = Uri::parse(value)?;
132
133        Ok(config)
134    }
135}
136
// Splits `$text` at the first match of `$patt`. This is a macro rather than a
// generic function because naming the pattern type needs the unstable
// `Pattern` API.
//
// On a match, `$text` is advanced so it starts AT the match (the matched
// character is kept) and the text before it is yielded as `Some`. Otherwise
// `$text` is left untouched and `None` is yielded.
macro_rules! take_until {
    ($text:expr, $patt:expr) => {{
        if let Some(split_at) = $text.find($patt) {
            let (before, rest) = $text.split_at(split_at);
            $text = rest;
            Some(before)
        } else {
            None
        }
    }};
}
152
/// Cursor over a connection string, consumed from the front.
///
/// Adapted from the tokio-postgres URI parser, trimmed down so it does not
/// pull in that crate's dependencies.
struct Parser<'a> {
    /// The portion of the input that has not been consumed yet.
    text: &'a str,
}
158
159impl<'s> Parser<'s> {
160    fn parse(text: &'s str) -> Result<Uri, UriError> {
161        let mut parser = Parser { text };
162
163        let scheme = take_until!(parser.text, ':').ok_or(UriError::MissingScheme)?;
164
165        parser.eat(':')?;
166        parser.eat('/')?;
167        parser.eat('/')?;
168
169        let (username, password) = if parser.text.contains('@') {
170            parser.parse_credentials()?
171        } else {
172            (None, None)
173        };
174        let (hosts, ports) = parser.parse_hosts()?;
175        let database = parser.parse_path();
176        let options = parser.parse_params()?;
177
178        Ok(Uri {
179            scheme: scheme.to_string(),
180
181            username,
182            password,
183
184            hosts,
185            ports,
186
187            database,
188
189            options,
190        })
191    }
192
193    fn eat(&mut self, target: char) -> Result<(), UriError> {
194        if self.text.starts_with(target) {
195            let (_, tail) = self.text.split_at(1);
196
197            self.text = tail;
198
199            Ok(())
200        } else {
201            Err(UriError::UnexpectedCharacter {
202                expected: target,
203                got: self.text.chars().next().unwrap(),
204            })
205        }
206    }
207
208    fn parse_credentials(&mut self) -> Result<(Option<String>, Option<String>), UriError> {
209        match take_until!(self.text, '@') {
210            Some(taken) => {
211                let mut it = taken.splitn(2, ':');
212
213                let username = it.next().ok_or(UriError::MissingUsername)?;
214                let password = percent_encoding::percent_decode(
215                    it.next().ok_or(UriError::MissingPassword)?.as_bytes(),
216                );
217
218                self.eat('@')?;
219
220                Ok((
221                    Some(username.to_string()),
222                    Some(password.decode_utf8()?.to_string()),
223                ))
224            }
225            None => Ok((None, None)),
226        }
227    }
228
229    fn parse_hosts(&mut self) -> Result<(Vec<String>, Vec<u16>), UriError> {
230        match take_until!(self.text, &['/', '?'] as &[char]) {
231            Some(taken) => {
232                let pairs = taken.split(',');
233
234                let mut hosts = Vec::new();
235                let mut ports = Vec::new();
236
237                for pair in pairs {
238                    if let Some(index) = pair.find(':') {
239                        let (head, tail) = pair.split_at(index);
240
241                        hosts.push(head.to_string());
242                        ports.push(
243                            (tail[1..])
244                                .parse()
245                                .map_err(|err| (tail[1..].to_string(), err))?,
246                        );
247                    }
248                }
249
250                Ok((hosts, ports))
251            }
252            None => {
253                if self.text.is_empty() {
254                    Err(UriError::MissingHostPort)
255                } else {
256                    let mut hosts = Vec::new();
257                    let mut ports = Vec::new();
258
259                    if let Some(index) = self.text.find(':') {
260                        let (head, tail) = self.text.split_at(index);
261
262                        hosts.push(head.to_string());
263                        ports.push(
264                            (tail[1..])
265                                .parse()
266                                .map_err(|err| (tail[1..].to_string(), err))?,
267                        );
268                    }
269
270                    Ok((hosts, ports))
271                }
272            }
273        }
274    }
275
276    fn parse_path(&mut self) -> Option<String> {
277        if self.text.starts_with('/') {
278            self.text = &self.text[1..];
279
280            if self.text.is_empty() {
281                None
282            } else if let Some(index) = self.text.find('?') {
283                let (head, tail) = self.text.split_at(index);
284
285                self.text = tail;
286
287                Some(String::from(head))
288            } else {
289                Some(String::from(self.text))
290            }
291        } else {
292            None
293        }
294    }
295
296    fn parse_params(&mut self) -> Result<Option<BTreeMap<String, String>>, UriError> {
297        if self.text.starts_with('?') {
298            self.text = &self.text[1..];
299
300            let mut tree = BTreeMap::new();
301
302            for pair in self.text.split('&') {
303                let mut splitter = pair.split('=');
304
305                if let (Some(key), Some(value)) = (splitter.next(), splitter.next()) {
306                    let key = percent_encoding::percent_decode(key.as_bytes()).decode_utf8()?;
307                    let value = percent_encoding::percent_decode(value.as_bytes()).decode_utf8()?;
308
309                    tree.insert(key.to_string(), value.to_string());
310                }
311            }
312
313            Ok(if tree.is_empty() { None } else { Some(tree) })
314        } else {
315            Ok(None)
316        }
317    }
318}
319
/// Errors produced while parsing a [`Uri`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UriError {
    /// A port could not be parsed as a `u16`; `port` holds the offending text.
    InvalidHostPort { port: String, err: ParseIntError },
    /// A percent-encoded component did not decode to valid UTF-8.
    InvalidEncoding { err: Utf8Error },
    /// The input has no scheme before `://`.
    MissingScheme,
    /// The credentials section has no username.
    MissingUsername,
    /// The credentials section has no `:password` part.
    MissingPassword,
    /// No usable `host:port` section was found.
    MissingHostPort,
    /// The input ended where more characters were expected.
    UnexpectedEof,
    /// The parser expected `expected` but found `got`.
    UnexpectedCharacter { expected: char, got: char },
}
331
332impl Display for UriError {
333    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
334        match self {
335            UriError::InvalidHostPort { port, .. } => write!(f, "invalid host port: `{}`", port)?,
336            UriError::InvalidEncoding { .. } => write!(f, "invalid param encoding")?,
337            UriError::MissingScheme => write!(f, "missing scheme")?,
338            UriError::MissingUsername => write!(f, "missing username from credentials")?,
339            UriError::MissingPassword => write!(f, "missing password from credentials")?,
340            UriError::MissingHostPort => write!(f, "missing host and or port")?,
341            UriError::UnexpectedEof => write!(f, "unexpected EOF")?,
342            UriError::UnexpectedCharacter { expected, got } => write!(
343                f,
344                "unexpected character: expected `{}` but got `{}`",
345                expected, got,
346            )?,
347        }
348
349        Ok(())
350    }
351}
352
353impl Error for UriError {}
354
355impl From<(String, ParseIntError)> for UriError {
356    fn from((port, err): (String, ParseIntError)) -> Self {
357        Self::InvalidHostPort { port, err }
358    }
359}
360
361impl From<Utf8Error> for UriError {
362    fn from(err: Utf8Error) -> Self {
363        Self::InvalidEncoding { err }
364    }
365}
366
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_minimal_no_encoding() {
        let expected = Uri::new("postgres", "localhost", 54123);
        let actual = Uri::parse("postgres://localhost:54123");

        assert_eq!(Ok(expected), actual);
    }

    #[test]
    fn test_options_encoding() {
        let expected = Uri::new("postgres", "localhost", 54123).option("with a space", "for sure");
        let actual = Uri::parse("postgres://localhost:54123?with%20a%20space=for%20sure");

        assert_eq!(Ok(expected), actual);
    }

    #[test]
    fn test_all_no_encoding() {
        let expected = Uri::new("postgres", "localhost", 54123)
            .username("username")
            .password("password")
            .database("database")
            .option("tls", "true");
        let actual = Uri::parse("postgres://username:password@localhost:54123/database?tls=true");

        assert_eq!(Ok(expected), actual);
    }

    #[test]
    fn test_password_encoding() {
        let expected = Uri::new("postgres", "localhost", 5432)
            .username("user")
            .password("p@ss word");
        let actual = Uri::parse("postgres://user:p%40ss%20word@localhost:5432");

        assert_eq!(Ok(expected), actual);
    }

    #[test]
    fn test_multiple_hosts_with_database() {
        let mut expected = Uri::new("postgres", "primary", 5432).database("database");
        expected.hosts.push(String::from("replica"));
        expected.ports.push(5433);

        let actual = Uri::parse("postgres://primary:5432,replica:5433/database");

        assert_eq!(Ok(expected), actual);
    }

    #[test]
    fn test_trait_conversions_match_parse() {
        let text = "postgres://localhost:54123/database?tls=true";
        let expected = Uri::parse(text);

        assert_eq!(expected, text.parse::<Uri>());
        assert_eq!(expected, Uri::try_from(text));
        assert_eq!(expected, Uri::try_from(text.to_string()));
    }

    #[test]
    fn test_missing_scheme() {
        assert_eq!(Err(UriError::MissingScheme), Uri::parse("localhost"));
    }

    #[test]
    fn test_unexpected_character() {
        assert_eq!(
            Err(UriError::UnexpectedCharacter {
                expected: '/',
                got: 'l',
            }),
            Uri::parse("postgres:/localhost:5432"),
        );
    }

    #[test]
    fn test_invalid_port() {
        let err = "port".parse::<u16>().unwrap_err();

        assert_eq!(
            Err(UriError::InvalidHostPort {
                port: String::from("port"),
                err,
            }),
            Uri::parse("postgres://localhost:port/database"),
        );
    }
}